]> git.parisson.com Git - timeside.git/commitdiff
Decoder: Code refactoring on core, put some functions in decoder.utils + improve...
authorThomas Fillon <thomas@parisson.com>
Fri, 22 Nov 2013 13:39:14 +0000 (14:39 +0100)
committerThomas Fillon <thomas@parisson.com>
Fri, 22 Nov 2013 13:39:14 +0000 (14:39 +0100)
tests/test_decoder_utils.py [new file with mode: 0644]
timeside/decoder/core.py
timeside/decoder/utils.py

diff --git a/tests/test_decoder_utils.py b/tests/test_decoder_utils.py
new file mode 100644 (file)
index 0000000..44a03d7
--- /dev/null
@@ -0,0 +1,103 @@
+#! /usr/bin/env python
+
+# author: Thomas Fillon <thomas@parisson.com>
+
+from __future__ import division
+
+from numpy import arange, sin
+from unit_timeside import *
+from timeside.decoder.utils import get_uri, get_media_uri_info, path2uri
+import os.path
+
+
+class TestGetUri(TestCase):
+    "Test get_uri function"
+    def testFileName(self):
+        "Retrieve the uri from a filename"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep.wav")
+
+        self.uri = path2uri(os.path.abspath(self.source))
+
+    def testUri(self):
+        "Retrieve the uri from an uri"
+        self.uri = 'file://already/an/uri/file.wav'
+        self.source = self.uri
+
+    def tearDown(self):
+        self.assertEqual(self.uri, get_uri(self.source))
+
+
+class TestGetUriWrongUri(TestCase):
+    def testMissingFile(self):
+        "Missing file raise IOerror"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "a_missing_file_blahblah.wav")
+    def testNotValidUri(self):
+        "Not valid uri raise IOerror"
+        self.source = os.path.join("://not/a/valid/uri/parisson.com")
+
+    def testNotSupportedUriProtocol(self):
+        "Not supported uri protocol raise IOerror"
+        self.source = os.path.join("mailto://john.doe@parisson.com")
+
+    def tearDown(self):
+        self.assertRaises(IOError, get_uri, self.source)
+
+
+class TestGetMediaInfo(TestCase):
+    "Test get_media_uri_info function on an uri"
+
+    def setUp(self):
+        self.test_exact_duration = True
+        self.source_duration = 8
+        self.expected_channels = 2
+        self.expected_samplerate = 44100
+        self.expected_depth = 16
+
+    def testWav(self):
+        "Test wav decoding"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep.wav")
+
+
+    def testWavMono(self):
+        "Test mono wav decoding"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep_mono.wav")
+
+        self.expected_channels = 1
+
+    def testWav32k(self):
+        "Test 32kHz wav decoding"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep_32000.wav")
+        self.expected_samplerate = 32000
+
+    def testFlac(self):
+        "Test flac decoding"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep.flac")
+
+    def testOgg(self):
+        "Test ogg decoding"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep.ogg")
+        self.expected_depth = 0  # ?
+
+    def testMp3(self):
+        "Test mp3 decoding"
+        self.source = os.path.join(os.path.dirname(__file__),
+                                   "samples/sweep.mp3")
+        self.expected_depth = 32
+
+    def tearDown(self):
+        uri = get_uri(self.source)
+        uri_info = get_media_uri_info(uri)
+        self.assertEqual(self.source_duration, uri_info['duration'])
+        self.assertEqual(self.expected_channels, uri_info['streams'][0]['channels'])
+        self.assertEqual(self.expected_samplerate, uri_info['streams'][0]['samplerate'])
+        self.assertEqual(self.expected_depth, uri_info['streams'][0]['depth'])
+
+if __name__ == '__main__':
+    unittest.main(testRunner=TestRunner())
index ef61fc3783fb16ec5edb44cdebb40c3d1af2459f..82c9284449dc90b37612bd7a796347448bd69000 100644 (file)
@@ -32,6 +32,8 @@ from timeside.core import Processor, implements, interfacedoc
 from timeside.api import IDecoder
 from timeside.tools import *
 
+from utils import get_uri, get_media_uri_info
+
 import Queue
 from gst import _gst as gst
 import numpy as np
@@ -39,7 +41,6 @@ import numpy as np
 
 GST_APPSINK_MAX_BUFFERS = 10
 QUEUE_SIZE = 10
-GST_DISCOVER_TIMEOUT = 5000000000L
 
 
 class FileDecoder(Processor):
@@ -78,18 +79,7 @@ class FileDecoder(Processor):
 
         super(FileDecoder, self).__init__()
 
-        # is this a file?
-        import os.path
-        if os.path.exists(uri):
-            # get the absolute path
-            uri = os.path.abspath(uri)
-            # and make a uri of it
-            from urllib import quote
-            self.uri = 'file://' + quote(uri)
-        elif '://' in uri:
-            self.uri = uri
-        else:
-            raise IOError('File not found!')
+        self.uri = get_uri(uri)
 
         self.uri_start = float(start)
         if duration:
@@ -104,16 +94,8 @@ class FileDecoder(Processor):
 
     def set_uri_default_duration(self):
         # Set the duration from the length of the file
-        from gst.pbutils import Discoverer
-        from glib import GError
-        #import gobject
-        uri_discoverer = Discoverer(GST_DISCOVER_TIMEOUT)
-        try:
-            uri_info = uri_discoverer.discover_uri(self.uri)
-        except  GError as e:
-            raise IOError(e)
-        self.uri_duration = (uri_info.get_duration() / gst.SECOND
-                                - self.uri_start)
+        uri_total_duration = get_media_uri_info(self.uri)['duration']
+        self.uri_duration = uri_total_duration - self.uri_start
 
     def setup(self, channels=None, samplerate=None, blocksize=None):
 
@@ -254,6 +236,7 @@ class FileDecoder(Processor):
             if not self.output_channels:
                 self.output_channels = self.input_channels
             self.input_duration = length / 1.e9
+
             self.input_totalframes = int(self.input_duration * self.input_samplerate)
             if "x-raw-float" in caps.to_string():
                 self.input_width = caps[0]["width"]
@@ -454,7 +437,6 @@ class ArrayDecoder(Processor):
 
         self.input_totalframes = len(self.samples)
         self.input_duration = self.input_totalframes / self.input_samplerate
-
         self.input_width = self.samples.itemsize * 8
 
     def get_frames(self):
index b806e1ed25a945fab23eafa2d6c9336ffb149d0d..a3bba41602c5d72de6d340797af0d4689089f092 100644 (file)
@@ -1,3 +1,29 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2007-2013 Parisson
+# Copyright (c) 2007-2013 Guillaume Pellerin <pellerin@parisson.com>
+# Copyright (c) 2010-2013 Paul Brossier <piem@piem.org>
+#
+# This file is part of TimeSide.
+
+# TimeSide is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+
+# TimeSide is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+
+# You should have received a copy of the GNU General Public License
+# along with TimeSide.  If not, see <http://www.gnu.org/licenses/>.
+
+# Authors:
+# Paul Brossier <piem@piem.org>
+# Guillaume Pellerin <yomguy@parisson.com>
+# Thomas Fillon <thomas@parisson.com>
+
 import numpy
 
 
@@ -35,3 +61,86 @@ class Noise(object):
         self.seekpoint += will_read
         return numpy.random.random(will_read)*2 - 1
 
+
+def path2uri(path):
+    """
+    Return a valid uri (file scheme) from absolute path name of a file
+
+    >>> path2uri('/home/john/my_file.wav')
+    'file:///home/john/my_file.wav'
+
+    >>> path2uri('C:\Windows\my_file.wav')
+    'file:///C%3A%5CWindows%5Cmy_file.wav'
+    """
+    import urlparse, urllib
+
+    return urlparse.urljoin('file:', urllib.pathname2url(path))
+
+
+def get_uri(source):
+    """
+    Check a media source as a valid file or uri and return the proper uri
+    """
+
+    import gst
+    # Is this an valid URI source
+    if gst.uri_is_valid(source):
+        uri_protocol = gst.uri_get_protocol(source)
+        if gst.uri_protocol_is_supported(gst.URI_SRC, uri_protocol):
+            return source
+        else:
+            raise IOError('Invalid URI source for Gstreamer')
+
+    # is this a file?
+    import os.path
+    if os.path.exists(source):
+        # get the absolute path
+        pathname = os.path.abspath(source)
+        # and make a uri of it
+        uri = path2uri(pathname)
+
+        return get_uri(uri)
+    else:
+        raise IOError('File not found!')
+
+    return uri
+
+def get_media_uri_info(uri):
+        from gst.pbutils import Discoverer
+        from gst import SECOND as GST_SECOND
+        from glib import GError
+        #import gobject
+        GST_DISCOVER_TIMEOUT = 5000000000L
+        uri_discoverer = Discoverer(GST_DISCOVER_TIMEOUT)
+        try:
+            uri_info = uri_discoverer.discover_uri(uri)
+        except  GError as e:
+            raise IOError(e)
+        info = dict()
+
+        # Duration in seconds
+        info['duration'] = uri_info.get_duration() / GST_SECOND
+
+        audio_streams = uri_info.get_audio_streams()
+        info['streams'] = []
+        for stream in audio_streams:
+            stream_info = {'bitrate': stream.get_bitrate (),
+                           'channels': stream.get_channels (),
+                           'depth': stream.get_depth (),
+                           'max_bitrate': stream.get_max_bitrate(),
+                           'samplerate': stream.get_sample_rate()
+                           }
+            info['streams'].append(stream_info)
+
+        return info
+
+
+
+if __name__ == "__main__":
+    # Run doctest from __main__ and unittest from tests
+    from tests.unit_timeside import runTestModule
+    # load corresponding tests
+    from tests import test_decoder_utils
+
+    runTestModule('__main__', test_decoder_utils)
+