]> git.parisson.com Git - timeside.git/commitdiff
Decoder: Add ArrayDecoder : a fake decoder that take a numpy array as input
authorThomas Fillon <thomas@parisson.com>
Tue, 19 Nov 2013 14:54:07 +0000 (15:54 +0100)
committerThomas Fillon <thomas@parisson.com>
Tue, 19 Nov 2013 14:54:07 +0000 (15:54 +0100)
tests/test_array_decoding.py [new file with mode: 0644]
timeside/decoder/core.py

diff --git a/tests/test_array_decoding.py b/tests/test_array_decoding.py
new file mode 100644 (file)
index 0000000..64d21b1
--- /dev/null
@@ -0,0 +1,156 @@
+#! /usr/bin/env python
+from __future__ import division
+
+from timeside.decoder.core import ArrayDecoder
+from unit_timeside import *
+
+
+import numpy as np
+
+
+class TestDecoding(TestCase):
+
+    "Test decoding for ArrayDecoder"
+
+    def setUp(self):
+        self.samplerate, self.channels, self.blocksize = None, None, None
+        self.start = 0
+        self.duration = None
+        self.array_duration = 8
+        self.expected_duration = self.array_duration
+        self.expected_is_segment = False
+
+
+    def test1DArray(self):
+        "Test 1D Array decoding"
+        self.source_samplerate = 44100
+        self.source = np.random.randn(
+            self.array_duration * self.source_samplerate,)
+        self.source_channels = 1
+
+
+    def test2DArrayMono(self):
+        "Test 2D Array mono decoding"
+        self.source_samplerate = 32000
+        self.source = np.random.randn(
+            self.array_duration * self.source_samplerate, 1)
+        self.source_channels = 1
+
+    def test2DArrayStereo(self):
+        "Test 2D Array stereo decoding"
+        self.source_samplerate = 22050
+        self.source = np.random.randn(
+            self.array_duration * self.source_samplerate, 2)
+        self.source_channels = 2
+
+    def test2DArrayMultiChannel(self):
+        "Test 2D Array multi-channel decoding"
+        self.source_samplerate = 16000
+        self.source = np.random.randn(
+            self.array_duration * self.source_samplerate, 5)
+        self.source_channels = 5
+
+    def tearDown(self):
+        decoder = ArrayDecoder(samples=self.source,
+                               samplerate=self.source_samplerate,
+                               start=self.start,
+                               duration=self.duration)
+
+        decoder.setup(samplerate=self.samplerate, channels=self.channels,
+                      blocksize=self.blocksize)
+
+
+        # Check input
+        self.assertEqual(self.source_samplerate, decoder.input_samplerate)
+        self.assertEqual(self.expected_is_segment, decoder.is_segment)
+        self.assertEqual(self.expected_duration, decoder.input_duration)
+        self.assertEqual(self.source_channels, decoder.input_channels)
+        # Check output
+        self.assertEqual(self.source_samplerate, decoder.samplerate())
+        self.assertEqual(self.source_channels, decoder.channels())
+
+        # Check Idecoder interface
+        self.assertIsInstance(decoder.mediainfo(), dict)
+        self.assertIsInstance(decoder.format(), str)
+        self.assertIsInstance(decoder.encoding(), str)
+        self.assertIsInstance(decoder.resolution(), int)
+        self.assertIsNone(decoder.metadata())
+
+
+        totalframes = 0
+
+        while True:
+            frames, eod = decoder.process()
+            totalframes += frames.shape[0]
+            if eod:
+                break
+            self.assertEqual(frames.shape[0], decoder.blocksize())
+            self.assertEqual(frames.shape[1], decoder.channels())
+
+        if self.channels:
+            # when specified, check that the channels are the ones requested
+            self.assertEqual(self.channels, decoder.output_channels)
+        else:
+            # otherwise check that the channels are preserved, if not specified
+            self.assertEqual(decoder.input_channels, decoder.output_channels)
+            # and if we know the expected channels, check the output match
+            if self.source_channels:
+                self.assertEqual(
+                    self.source_channels, decoder.output_channels)
+        # do the same with the sampling rate
+        if self.samplerate:
+            self.assertEqual(self.samplerate, decoder.output_samplerate)
+        else:
+            self.assertEqual(
+                decoder.input_samplerate, decoder.output_samplerate)
+
+
+        self.assertEqual(totalframes, self.expected_duration * decoder.output_samplerate)
+
+
+class TestDecodingSegment(TestDecoding):
+
+    def setUp(self):
+        super(TestDecodingSegment, self).setUp()
+        self.start = 1
+        self.duration = 3
+        self.expected_is_segment = True
+        self.expected_duration = self.duration
+
+
+
+
+class TestDecodingSegmentDefaultStart(TestDecodingSegment):
+
+    def setUp(self):
+        super(TestDecodingSegmentDefaultStart, self).setUp()
+        self.start = 0
+        self.duration = 1
+        self.expected_duration = self.duration
+
+
+class TestDecodingSegmentDefaultDuration(TestDecodingSegment):
+
+    def setUp(self):
+        super(TestDecodingSegmentDefaultDuration, self).setUp()
+        self.start = 1
+        self.duration = None
+        self.expected_duration = self.array_duration - self.start
+
+
+class TestDecodingShortBlock(TestDecoding):
+
+    def setUp(self):
+        super(TestDecodingShortBlock, self).setUp()
+        self.blocksize = 256
+
+
+class TestDecodingLongBlock(TestDecoding):
+
+    def setUp(self):
+        super(TestDecodingLongBlock, self).setUp()
+        self.blocksize = 1024 * 8 * 2
+
+
+if __name__ == '__main__':
+    unittest.main(testRunner=TestRunner())
index d4cff1cf34357e6aaa22e20d103b40cf83e04f31..616da054c087cee973dc0909817a5202a95956bc 100644 (file)
@@ -1,10 +1,10 @@
 #!/usr/bin/python
 # -*- coding: utf-8 -*-
 
-# Copyright (c) 2007-2011 Parisson
+# Copyright (c) 2007-2013 Parisson
 # Copyright (c) 2007 Olivier Guilyardi <olivier@samalyse.com>
-# Copyright (c) 2007-2011 Guillaume Pellerin <pellerin@parisson.com>
-# Copyright (c) 2010-2011 Paul Brossier <piem@piem.org>
+# Copyright (c) 2007-2013 Guillaume Pellerin <pellerin@parisson.com>
+# Copyright (c) 2010-2013 Paul Brossier <piem@piem.org>
 #
 # This file is part of TimeSide.
 
 # You should have received a copy of the GNU General Public License
 # along with TimeSide.  If not, see <http://www.gnu.org/licenses/>.
 
-# Authors: Paul Brossier <piem@piem.org>
+# Authors:
+# Paul Brossier <piem@piem.org>
 # Guillaume Pellerin <yomguy@parisson.com>
+# Thomas Fillon <thomas@parisson.com>
+
+from __future__ import division
 
 from timeside.core import Processor, implements, interfacedoc
 from timeside.api import IDecoder
@@ -30,7 +34,7 @@ from timeside.tools import *
 
 import Queue
 from gst import _gst as gst
-from numpy import int64, uint64
+import numpy as np
 
 
 GST_APPSINK_MAX_BUFFERS = 10
@@ -105,8 +109,8 @@ class FileDecoder(Processor):
             uri_info = uri_discoverer.discover_uri(self.uri)
         except  GError as e:
             raise IOError(e)
-        self.uri_duration = (uri_info.get_duration() / float(gst.SECOND) -
-                            self.uri_start)
+        self.uri_duration = (uri_info.get_duration() / gst.SECOND
+                                - self.uri_start)
 
     def setup(self, channels=None, samplerate=None, blocksize=None):
 
@@ -137,8 +141,8 @@ class FileDecoder(Processor):
                             ! audioresample
                             ! appsink name=sink sync=False async=True
                             '''.format(uri = self.uri,
-                                       uri_start = uint64(round(self.uri_start * gst.SECOND)),
-                                       uri_duration = int64(round(self.uri_duration * gst.SECOND)))
+                                       uri_start = np.uint64(round(self.uri_start * gst.SECOND)),
+                                       uri_duration = np.int64(round(self.uri_duration * gst.SECOND)))
                                        # convert uri_start and uri_duration to nanoseconds
         else:
             # Create the pipe with standard Gstreamer uridecodbin
@@ -278,14 +282,13 @@ class FileDecoder(Processor):
             pass
 
     def _on_new_buffer_cb(self, sink):
-        from numpy import concatenate
         buf = sink.emit('pull-buffer')
         new_array = gst_buffer_to_numpy_array(buf, self.output_channels)
         #print 'processing new buffer', new_array.shape
         if self.last_buffer is None:
             self.last_buffer = new_array
         else:
-            self.last_buffer = concatenate((self.last_buffer, new_array), axis=0)
+            self.last_buffer = np.concatenate((self.last_buffer, new_array), axis=0)
         while self.last_buffer.shape[0] >= self.output_blocksize:
             new_block = self.last_buffer[:self.output_blocksize]
             self.last_buffer = self.last_buffer[self.output_blocksize:]
@@ -317,7 +320,7 @@ class FileDecoder(Processor):
         if self.input_samplerate == self.output_samplerate:
             return self.input_totalframes
         else:
-            ratio = float(self.output_samplerate) / self.input_samplerate
+            ratio = self.output_samplerate / self.input_samplerate
             return int(self.input_totalframes * ratio)
 
     @interfacedoc
@@ -352,9 +355,174 @@ class FileDecoder(Processor):
     @interfacedoc
     def resolution(self):
         # TODO check: width or depth?
-        return self.audiowidth
+        return self.input_width
 
     @interfacedoc
     def metadata(self):
         # TODO check
         return self.tags
+
+
+class ArrayDecoder(Processor):
+    """ Decoder taking Numpy array as input"""
+    implements(IDecoder)
+
+    mimetype = ''
+    output_blocksize = 8*1024
+    output_samplerate = None
+    output_channels = None
+
+    # IProcessor methods
+
+    @staticmethod
+    @interfacedoc
+    def id():
+        return "array_dec"
+
+    def __init__(self, samples, samplerate=44100, start=0, duration=None):
+        '''
+            Construct a new ArrayDecoder from an numpy array
+
+            Parameters
+            ----------
+            samples : numpy array of dimension 1 (mono) or 2 (multichannel)
+                    if shape = (n) or (n,1) : n samples, mono
+                    if shape = (n,m) : n samples with m channels
+            start : float
+                start time of the segment in seconds
+            duration : float
+                duration of the segment in seconds
+        '''
+        super(ArrayDecoder, self).__init__()
+
+        # Check array dimension
+        if samples.ndim > 2:
+            raise TypeError('Wrong number of dimensions for argument samples')
+        if samples.ndim == 1:
+            samples = samples[:, np.newaxis]  # reshape to 2D array
+
+        self.samples = samples  # Create a 2 dimensions array
+        self.input_samplerate = samplerate
+        self.input_channels = self.samples.shape[1]
+
+        self.uri = '_'.join(['raw_audio_array',
+                            'x'.join([str(dim) for dim in samples.shape]),
+                             samples.dtype.type.__name__])
+
+        self.uri_start = float(start)
+        if duration:
+            self.uri_duration = float(duration)
+        else:
+            self.uri_duration = duration
+
+        if start == 0 and duration is None:
+            self.is_segment = False
+        else:
+            self.is_segment = True
+
+        self.frames = self.get_frames()
+
+    def setup(self, channels=None, samplerate=None, blocksize=None):
+
+        # the output data format we want
+        if blocksize:
+            self.output_blocksize = blocksize
+        if samplerate:
+            self.output_samplerate = int(samplerate)
+        if channels:
+            self.output_channels = int(channels)
+
+        if self.uri_duration is None:
+            self.uri_duration = (len(self.samples) / self.input_samplerate
+                                 - self.uri_start)
+
+        if self.is_segment:
+            start_index = self.uri_start * self.input_samplerate
+            stop_index = start_index + int(np.ceil(self.uri_duration
+                                           * self.input_samplerate))
+            stop_index = min(stop_index, len(self.samples))
+            self.samples = self.samples[start_index:stop_index]
+
+        if not self.output_samplerate:
+            self.output_samplerate = self.input_samplerate
+
+        if not self.output_channels:
+            self.output_channels = self.input_channels
+
+        self.input_totalframes = len(self.samples)
+        self.input_duration = self.input_totalframes / self.input_samplerate
+
+        self.input_width = self.samples.itemsize * 8
+
+    def get_frames(self):
+        "Define an iterator that will return frames at the given blocksize"
+        nb_frames = self.input_totalframes // self.output_blocksize
+
+        if self.input_totalframes % self.output_blocksize == 0:
+            nb_frames -= 1  # Last frame must send eod=True
+
+        for index in xrange(0,
+                            nb_frames * self.output_blocksize,
+                            self.output_blocksize):
+            yield (self.samples[index:index+self.output_blocksize], False)
+
+        yield (self.samples[nb_frames * self.output_blocksize:], True)
+
+    @interfacedoc
+    def process(self, frames=None, eod=False):
+
+        return self.frames.next()
+
+    @interfacedoc
+    def channels(self):
+        return self.output_channels
+
+    @interfacedoc
+    def samplerate(self):
+        return self.output_samplerate
+
+    @interfacedoc
+    def blocksize(self):
+        return self.output_blocksize
+
+    @interfacedoc
+    def totalframes(self):
+        if self.input_samplerate == self.output_samplerate:
+            return self.input_totalframes
+        else:
+            ratio = self.output_samplerate / self.input_samplerate
+            return int(self.input_totalframes * ratio)
+
+    @interfacedoc
+    def release(self):
+        pass
+
+    @interfacedoc
+    def mediainfo(self):
+        return dict(uri=self.uri,
+                    duration=self.uri_duration,
+                    start=self.uri_start,
+                    is_segment=self.is_segment,
+                    samplerate=self.input_samplerate)
+
+    def __del__(self):
+        self.release()
+
+    ## IDecoder methods
+    @interfacedoc
+    def format(self):
+        import re
+        base_type = re.search('^[a-z]*', self.samples.dtype.name).group(0)
+        return 'audio/x-raw-'+base_type
+
+    @interfacedoc
+    def encoding(self):
+        return self.format().split('/')[-1]
+
+    @interfacedoc
+    def resolution(self):
+        return self.input_width
+
+    @interfacedoc
+    def metadata(self):
+        return None