]> git.parisson.com Git - timeside.git/commitdiff
Analyzers: Code refactoring for Analyzer Result support in analyzer/core.py
authorThomas Fillon <thomas@parisson.com>
Mon, 2 Dec 2013 17:43:49 +0000 (18:43 +0100)
committerThomas Fillon <thomas@parisson.com>
Mon, 2 Dec 2013 17:46:52 +0000 (18:46 +0100)
doc/source/tutorial/AnalyzerResult.rst
tests/test_AnalyzerResult.py
timeside/analyzer/core.py

index 80d63ed97802dafa11e35c18dcaba0fd0491902a..c6af44e11ed8ed8d30ca297c840ffe80baebefdc 100644 (file)
@@ -5,9 +5,9 @@
  Analyzer Result examples
 ==========================
 
-Example of use of the Aanalyzer Result structure
+Example of use of the Analyzer Result structure
 
-Usage : analyzer_result_factory(data_mode=None, time_mode=None)
+Usage : AnalyzerResult.factory(data_mode=None, time_mode=None)
 
 Four different *time_mode* can be specified :
 
@@ -22,15 +22,15 @@ Two different *data_mode* can be specified :
 - 'label' : Data are returned as label indexes (specified by the label_metadata key)
 
 
-See : :func:`timeside.analyzer.core.analyzer_result_factory`, :class:`timeside.analyzer.core.AnalyzerResult`
+See : :func:`timeside.analyzer.core.AnalyzerResult`, :class:`timeside.analyzer.core.AnalyzerResult`
 
 Default
 =======
 
 Create a new analyzer result without arguments
 
-   >>> from timeside.analyzer.core import analyzer_result_factory
-   >>> res = analyzer_result_factory()
+   >>> from timeside.analyzer.core import AnalyzerResult
+   >>> res = AnalyzerResult()
 
 This default result has all the metadata and dataObject attribute
 
@@ -60,7 +60,7 @@ Four different time_mode can be specified :
 Framewise
 ---------
 
->>> res = analyzer_result_factory(time_mode='framewise')
+>>> res = AnalyzerResult(time_mode='framewise')
 >>> res.keys()
 ['id_metadata', 'data_object', 'audio_metadata', 'frame_metadata', 'parameters']
 
@@ -70,7 +70,7 @@ Global
 No frame metadata information is needed for these modes.
 The 'frame_metadata' key/attribute is deleted.
 
->>> res = analyzer_result_factory(time_mode='global')
+>>> res = AnalyzerResult(time_mode='global')
 >>> res.keys()
 ['id_metadata', 'data_object', 'audio_metadata', 'parameters']
 >>> res.data_object
@@ -79,7 +79,7 @@ DataObject(value=array([], dtype=float64))
 Segment
 -------
 
->>> res = analyzer_result_factory(time_mode='segment')
+>>> res = AnalyzerResult(time_mode='segment')
 >>> res.keys()
 ['id_metadata', 'data_object', 'audio_metadata', 'parameters']
 >>> res.data_object
@@ -88,7 +88,7 @@ DataObject(value=array([], dtype=float64), time=array([], dtype=float64), durati
 Event
 -----
 
->>> res = analyzer_result_factory(time_mode='event')
+>>> res = AnalyzerResult(time_mode='event')
 >>> res.keys()
 ['id_metadata', 'data_object', 'audio_metadata', 'parameters']
 >>> res.data_object
@@ -105,7 +105,7 @@ Value
 -----
 The label_metadata key is deleted.
 
->>> res = analyzer_result_factory(data_mode='value')
+>>> res = AnalyzerResult.factory(data_mode='value')
 >>> res.keys()
 ['id_metadata', 'data_object', 'audio_metadata', 'frame_metadata', 'parameters']
 
@@ -116,7 +116,7 @@ DataObject(value=array([], dtype=float64))
 
 Label
 -----
->>> res = analyzer_result_factory(data_mode='label')
+>>> res = AnalyzerResult.factory(data_mode='label')
 >>> res.keys()
 ['id_metadata', 'data_object', 'audio_metadata', 'frame_metadata', 'label_metadata', 'parameters']
 
index c718ba6bf107b32032ea960690ebe91673928da3..45cc2c7cc0d60ed733276f09f01cb54d6c2e18de 100755 (executable)
@@ -14,7 +14,7 @@ class TestAnalyzerResult(unittest.TestCase):
     """ test AnalyzerResult """
 
     def setUp(self):
-        self.result = analyzer_result_factory(data_mode='value', time_mode='framewise')
+        self.result = AnalyzerResult.factory(data_mode='value', time_mode='framewise')
 
         from datetime import datetime
         self.result.id_metadata = dict(date=datetime.now().replace(microsecond=0).isoformat(' '),
@@ -225,5 +225,15 @@ class TestAnalyzerResultJson(TestAnalyzerResult):
         #for i in range(len(d_json)):
         self.assertEqual(d_json, results)
 
+
+class TestAnalyzerResultAsDict(TestAnalyzerResult):
+    """ test AnalyzerResult as Dictionnary"""
+
+    def tearDown(self):
+
+        self.assertIsInstance(self.result.as_dict(), dict)
+        self.assertItemsEqual(self.result.keys() + ['data_mode', 'time_mode'],
+                              self.result.as_dict().keys())
+
 if __name__ == '__main__':
     unittest.main(testRunner=TestRunner())
\ No newline at end of file
index bb88eba76fcb802709d3b28564093d778af56ff3..c1bf83410d5eb3853c89c326db300cb43da19b5a 100644 (file)
@@ -521,8 +521,19 @@ class AnalyzerResult(MetadataObject):
         self.label_metadata = LabelMetadata()
         self.parameters = AnalyzerParameters()
 
-        self._data_mode = data_mode
-        self._time_mode = time_mode
+    @staticmethod
+    def factory(data_mode='value', time_mode='framewise'):
+        """
+        Factory function for Analyzer result
+        """
+        for result_cls in AnalyzerResult.__subclasses__():
+            if (hasattr(result_cls, '_time_mode') and
+                hasattr(result_cls, '_data_mode') and
+                (result_cls._data_mode, result_cls._time_mode) == (data_mode,
+                                                                   time_mode)):
+                return result_cls()
+        print data_mode, time_mode
+        raise ValueError('Wrong arguments')
 
     def __setattr__(self, name, value):
         if name in ['_data_mode', '_time_mode']:
@@ -575,8 +586,8 @@ class AnalyzerResult(MetadataObject):
 
         data_mode_child = root.find('data_mode')
         time_mode_child = root.find('time_mode')
-        result = analyzer_result_factory(data_mode=data_mode_child.text,
-                                         time_mode=time_mode_child.text)
+        result = AnalyzerResult.factory(data_mode=data_mode_child.text,
+                                        time_mode=time_mode_child.text)
         for child in root:
             key = child.tag
             if key not in ['data_mode', 'time_mode']:
@@ -596,10 +607,15 @@ class AnalyzerResult(MetadataObject):
             subgroup = group.create_group(key)
             self.__getattribute__(key).to_hdf5(subgroup)
 
-    def from_hdf5(self, h5group):
+    @staticmethod
+    def from_hdf5(h5group):
         # Read Sub-Group
+        result = AnalyzerResult.factory(
+                                data_mode=h5group.attrs['data_mode'],
+                                time_mode=h5group.attrs['time_mode'])
         for subgroup_name, h5subgroup in h5group.items():
-            self.__getattribute__(subgroup_name).from_hdf5(h5subgroup)
+            result[subgroup_name].from_hdf5(h5subgroup)
+        return result
 
     @property
     def data_mode(self):
@@ -634,7 +650,8 @@ class AnalyzerResult(MetadataObject):
         return self.id_metadata.unit
 
 
-class ValueObject(AnalyzerResult):
+class ValueObject(object):
+    _data_mode = 'value'
 
     def __init__(self):
         super(ValueObject, self).__init__()
@@ -656,7 +673,8 @@ class ValueObject(AnalyzerResult):
                     )
 
 
-class LabelObject(AnalyzerResult):
+class LabelObject(object):
+    _data_mode = 'label'
 
     def __init__(self):
         super(LabelObject, self).__init__()
@@ -667,7 +685,8 @@ class LabelObject(AnalyzerResult):
         return self.data_object.label
 
 
-class GlobalObject(AnalyzerResult):
+class GlobalObject(object):
+    _time_mode = 'global'
 
     def __init__(self):
         super(GlobalObject, self).__init__()
@@ -684,7 +703,8 @@ class GlobalObject(AnalyzerResult):
         return self.audio_metadata.duration
 
 
-class FramewiseObject(AnalyzerResult):
+class FramewiseObject(object):
+    _time_mode = 'framewise'
 
     def __init__(self):
         super(FramewiseObject, self).__init__()
@@ -704,7 +724,8 @@ class FramewiseObject(AnalyzerResult):
                 * numpy.ones(len(self)))
 
 
-class EventObject(AnalyzerResult):
+class EventObject(object):
+    _time_mode = 'event'
 
     def __init__(self):
         super(EventObject, self).__init__()
@@ -721,6 +742,7 @@ class EventObject(AnalyzerResult):
 
 
 class SegmentObject(EventObject):
+    _time_mode = 'segment'
 
     def __init__(self):
         super(EventObject, self).__init__()
@@ -731,86 +753,49 @@ class SegmentObject(EventObject):
         return self.data_object.duration
 
 
-class GlobalValueResult(ValueObject, GlobalObject):
+class GlobalValueResult(ValueObject, GlobalObject, AnalyzerResult):
     pass
 
 
-class GlobalLabelResult(LabelObject, GlobalObject):
+class GlobalLabelResult(LabelObject, GlobalObject, AnalyzerResult):
     pass
 
 
-class FrameValueResult(ValueObject, FramewiseObject):
+class FrameValueResult(ValueObject, FramewiseObject, AnalyzerResult):
     pass
 
 
-class FrameLabelResult(LabelObject, FramewiseObject):
+class FrameLabelResult(LabelObject, FramewiseObject, AnalyzerResult):
     pass
 
 
-class EventValueResult(ValueObject, EventObject):
+class EventValueResult(ValueObject, EventObject, AnalyzerResult):
     pass
 
 
-class EventLabelResult(LabelObject, EventObject):
+class EventLabelResult(LabelObject, EventObject, AnalyzerResult):
     pass
 
 
-class SegmentValueResult(ValueObject, SegmentObject):
+class SegmentValueResult(ValueObject, SegmentObject, AnalyzerResult):
     pass
 
 
-class SegmentLabelResult(LabelObject, SegmentObject):
+class SegmentLabelResult(LabelObject, SegmentObject, AnalyzerResult):
     pass
 
 
-def analyzer_result_factory(data_mode='value', time_mode='framewise'):
-    '''
-    Analyzer result Factory function
-    '''
-    if (data_mode, time_mode) == ('value', 'framewise'):
-        result = FrameValueResult()
-
-    elif (data_mode, time_mode) == ('label', 'framewise'):
-        result = FrameLabelResult()
-
-    elif (data_mode, time_mode) == ('value', 'global'):
-        result = GlobalValueResult()
-
-    elif (data_mode, time_mode) == ('label', 'global'):
-        result = GlobalLabelResult()
-
-    elif (data_mode, time_mode) == ('value', 'event'):
-        result = EventValueResult()
-
-    elif (data_mode, time_mode) == ('label', 'event'):
-        result = EventLabelResult()
-
-    elif (data_mode, time_mode) == ('value', 'segment'):
-        result = SegmentValueResult()
-
-    elif (data_mode, time_mode) == ('label', 'segment'):
-        result = SegmentLabelResult()
-
-    else:
-        raise ValueError('Wrong arguments')
-
-    result._time_mode = time_mode
-    result._data_mode = data_mode
-
-    return result
-
-
 class AnalyzerResultContainer(dict):
 
     '''
     >>> import timeside
-    >>> wavFile = 'tests/samples/sweep.wav'
+    >>> wavFile = 'http://github.com/yomguy/timeside-samples/raw/master/samples/sweep.mp3'
     >>> d = timeside.decoder.FileDecoder(wavFile, start=1)
 
     >>> a = timeside.analyzer.Analyzer()
     >>> (d|a).run() #doctest: +ELLIPSIS
     >>> a.new_result() #doctest: +ELLIPSIS
-    FrameValueResult(id_metadata=IdMetadata(id='analyzer', name='Generic analyzer', unit='', description='', date='...', version='...', author='TimeSide', uuid='...'), data_object=DataObject(value=array([], dtype=float64)), audio_metadata=AudioMetadata(uri='file:///...', start=1.0, duration=7.0, is_segment=True, channels=None, channelsManagement=''), frame_metadata=FrameMetadata(samplerate=44100, blocksize=8192, stepsize=8192), parameters={})
+    FrameValueResult(id_metadata=IdMetadata(id='analyzer', name='Generic analyzer', unit='', description='', date='...', version='...', author='TimeSide', uuid='...'), data_object=DataObject(value=array([], dtype=float64)), audio_metadata=AudioMetadata(uri='http://...', start=1.0, duration=7..., is_segment=True, channels=None, channelsManagement=''), frame_metadata=FrameMetadata(samplerate=44100, blocksize=8192, stepsize=8192), parameters={})
     >>> resContainer = timeside.analyzer.core.AnalyzerResultContainer()
 
     '''
@@ -845,7 +830,8 @@ class AnalyzerResultContainer(dict):
 
         return ET.tostring(root, encoding="utf-8", method="xml")
 
-    def from_xml(self, xml_string):
+    @staticmethod
+    def from_xml(xml_string):
         import xml.etree.ElementTree as ET
 
         results = AnalyzerResultContainer()
@@ -872,7 +858,8 @@ class AnalyzerResultContainer(dict):
         return json.dumps([res.as_dict() for res in self.values()],
                           default=NumpyArrayEncoder)
 
-    def from_json(self, json_str):
+    @staticmethod
+    def from_json(json_str):
         import simplejson as json
 
         # Define Specialize JSON decoder for numpy array
@@ -888,8 +875,8 @@ class AnalyzerResultContainer(dict):
         results = AnalyzerResultContainer()
         for res_json in results_json:
 
-            res = analyzer_result_factory(data_mode=res_json['data_mode'],
-                                        time_mode=res_json['time_mode'])
+            res = AnalyzerResult.factory(data_mode=res_json['data_mode'],
+                                         time_mode=res_json['time_mode'])
             for key in res_json.keys():
                 if key not in ['data_mode', 'time_mode']:
                     res[key] = res_json[key]
@@ -911,7 +898,8 @@ class AnalyzerResultContainer(dict):
 
         return yaml.dump([res.as_dict() for res in self.values()])
 
-    def from_yaml(self, yaml_str):
+    @staticmethod
+    def from_yaml(yaml_str):
         import yaml
 
         # Define Specialize Yaml encoder for numpy array
@@ -924,8 +912,8 @@ class AnalyzerResultContainer(dict):
         results_yaml = yaml.load(yaml_str)
         results = AnalyzerResultContainer()
         for res_yaml in results_yaml:
-            res = analyzer_result_factory(data_mode=res_yaml['data_mode'],
-                                        time_mode=res_yaml['time_mode'])
+            res = AnalyzerResult.factory(data_mode=res_yaml['data_mode'],
+                                         time_mode=res_yaml['time_mode'])
             for key in res_yaml.keys():
                 if key not in ['data_mode', 'time_mode']:
                     res[key] = res_yaml[key]
@@ -935,7 +923,8 @@ class AnalyzerResultContainer(dict):
     def to_numpy(self, output_file):
         numpy.save(output_file, self)
 
-    def from_numpy(self, input_file):
+    @staticmethod
+    def from_numpy(input_file):
         return numpy.load(input_file)
 
     def to_hdf5(self, output_file):
@@ -944,27 +933,24 @@ class AnalyzerResultContainer(dict):
             for res in self.values():
                 res.to_hdf5(h5_file)
 
-    def from_hdf5(self, input_file):
+    @staticmethod
+    def from_hdf5(input_file):
         import h5py
         # TODO : enable import for yaafe hdf5 format
 
         # Open HDF5 file for reading and get results
         h5_file = h5py.File(input_file, 'r')
-        data_list = AnalyzerResultContainer()
+        results = AnalyzerResultContainer()
         try:
             for group in h5_file.values():
-
-                result = analyzer_result_factory(data_mode=group.attrs['data_mode'],
-                                        time_mode=group.attrs['time_mode'])
-                result.from_hdf5(group)
-
-                data_list.add(result)
+                result = AnalyzerResult.from_hdf5(group)
+                results.add(result)
         except TypeError:
             print('TypeError for HDF5 serialization')
         finally:
             h5_file.close()  # Close the HDF5 file
 
-        return data_list
+        return results
 
 
 class Analyzer(Processor):
@@ -1024,8 +1010,8 @@ class Analyzer(Processor):
 
         from datetime import datetime
 
-        result = analyzer_result_factory(data_mode=data_mode,
-                                         time_mode=time_mode)
+        result = AnalyzerResult.factory(data_mode=data_mode,
+                                time_mode=time_mode)
 
         # Automatically write known metadata
         result.id_metadata.date = datetime.now().replace(