]> git.parisson.com Git - timeside.git/commitdiff
Fix issue #63
authorThomas Fillon <thomas@parisson.com>
Mon, 15 Dec 2014 14:40:15 +0000 (15:40 +0100)
committerThomas Fillon <thomas@parisson.com>
Mon, 15 Dec 2014 14:41:13 +0000 (15:41 +0100)
timeside/analyzer/core.py

index 01595fda350a2efb228452bbe6f6495e0b1e1cb8..8e9483d28aad3a43604d1817e44030e63c02c290 100644 (file)
@@ -661,7 +661,7 @@ class AnalyzerResult(MetadataObject):
             result[subgroup_name].from_hdf5(h5subgroup)
         return result
 
-    def _render_plot(self, ax):
+    def _render_plot(self, ax, size=(1024,256)):
         return NotImplemented
 
     def render(self):
@@ -669,10 +669,6 @@ class AnalyzerResult(MetadataObject):
 
            Return the figure, use fig.show() to display if neeeded
         '''
-        # TODO : this may crash if the data array is too large
-        # possible workaround downsampled the data
-        #  and plot center, min, max values
-        # see http://stackoverflow.com/a/8881973
 
         fig, ax = plt.subplots()
         self.data_object._render_plot(ax)
@@ -689,7 +685,7 @@ class AnalyzerResult(MetadataObject):
 
         ax = fig.add_axes([0, 0, 1, 1], frame_on=False)
 
-        self.data_object._render_plot(ax)
+        self.data_object._render_plot(ax, size)
 
         ax.autoscale(axis='x', tight=True)
 
@@ -804,7 +800,7 @@ class EventObject(DataObject):
     def duration(self):
         return np.zeros(len(self.data))
 
-    def _render_plot(self, ax):
+    def _render_plot(self, ax, size=(1024,256)):
         ax.stem(self.time, self.data)
 
 
@@ -830,9 +826,39 @@ class FrameValueObject(ValueObject, FramewiseObject):
                                   ('y_value', None),
                                   ('frame_metadata', None)])
 
-    def _render_plot(self, ax):
+    def _render_plot(self, ax, size=(1024, 256)):
         if not self.y_value.size:
-            ax.plot(self.time, self.data)
+            # This was crashing if the data array is too large
+            # workaround consists in downsampling the data
+            #  and plot center, min, max values
+            # see http://stackoverflow.com/a/8881973
+            #  TODO: mean may not be appropriate for waveform ... (mean~=0)
+            nb_frames = self.data.shape[0]
+            chunksize = size[0]
+
+            numchunks = nb_frames // chunksize
+
+            if self.data.ndim <= 1:
+                ychunks = self.data[:chunksize*numchunks].reshape((-1,
+                                                                   chunksize))
+            else:
+                # Take only first channel
+                ychunks = self.data[:chunksize*numchunks, 0].reshape((-1, chunksize))
+
+            xchunks = self.time[:chunksize*numchunks].reshape((-1, chunksize))
+
+            # Calculate the max, min, and means of chunksize-element chunks...
+            max_env = ychunks.max(axis=1)
+            min_env = ychunks.min(axis=1)
+            ycenters = ychunks.mean(axis=1)
+            xcenters = xchunks.mean(axis=1)
+
+            # Now plot the bounds and the mean...
+            ax.fill_between(xcenters, min_env, max_env, color='gray',
+                            edgecolor='none', alpha=0.5)
+            ax.plot(xcenters, ycenters)
+
+            #ax.plot(self.time, self.data)
         else:
             ax.imshow(20 * np.log10(self.data.T),
                       origin='lower',
@@ -847,7 +873,7 @@ class FrameLabelObject(LabelObject, FramewiseObject):
                                   ('label_metadata', None),
                                   ('frame_metadata', None)])
 
-    def _render_plot(self, ax):
+    def _render_plot(self, ax, size=(1024,256)):
         pass
 
 
@@ -872,7 +898,7 @@ class SegmentValueObject(ValueObject, SegmentObject):
                                   ('time', None),
                                   ('duration', None)])
 
-    def _render_plot(self, ax):
+    def _render_plot(self, ax, size=(1024,256)):
         for time, value in (self.time, self.data):
             ax.axvline(time, ymin=0, ymax=value, color='r')
             # TODO : check value shape !!!
@@ -885,7 +911,7 @@ class SegmentLabelObject(LabelObject, SegmentObject):
                                   ('time', None),
                                   ('duration', None)])
 
-    def _render_plot(self, ax):
+    def _render_plot(self, ax, size=(1024,256)):
         import itertools
         colors = itertools.cycle(['b', 'g', 'r', 'c', 'm', 'y', 'k'])
         ax_color = {}