]> git.parisson.com Git - timeside.git/commitdiff
Add a validate method for parameters + unittest
authorThomas Fillon <thomas@parisson.com>
Mon, 26 May 2014 10:17:15 +0000 (12:17 +0200)
committerThomas Fillon <thomas@parisson.com>
Mon, 26 May 2014 10:17:15 +0000 (12:17 +0200)
tests/test_tools_parameters.py
timeside/tools/parameters.py

index 3b994ad92e39796236066454f27b900c224c4bf5..145de7bcdc17610394d604c7d20338a237217367 100644 (file)
@@ -5,6 +5,8 @@ from unit_timeside import unittest, TestRunner
 from timeside.tools.parameters import HasParam, HasTraits
 from timeside.tools.parameters import Unicode, Int, Float, Range
 
+import simplejson as json
+
 
 class TestHasParam(unittest.TestCase):
 
@@ -17,24 +19,31 @@ class TestHasParam(unittest.TestCase):
                 param3 = Float()
                 param4 = Range(low=0, high=10, value=3)
 
+        self.param_dict = {"param1": "", "param2": 0, "param3": 0.0,
+                           "param4": 3}
         self.has_param_cls = ParamClass()
 
     def test_get_parameters(self):
         "get_parameters method"
         param_json = self.has_param_cls.get_parameters()
-        self.assertEqual(param_json,
-                         ('{"param4": 3, "param3": 0.0, '
-                          '"param2": 0, "param1": ""}'))
+        self.assertEqual(json.loads(param_json),
+                         self.param_dict)
 
     def test_set_parameters(self):
         "set_parameters method"
-        new_param_json = ('{"param1": "plop", "param2": 7, '
-                          '"param3": 0.5, "param4": 8}')
+        new_param_dict = {"param1": "plop", "param2": 7,
+                          "param3": 0.5, "param4": 8}
+        new_param_json = json.dumps(new_param_dict)
+        # Set from dict
+        self.has_param_cls.set_parameters(new_param_dict)
+        param_json = self.has_param_cls.get_parameters()
+        param_dict = json.loads(param_json)
+        self.assertEqual(param_dict, new_param_dict)
+        # set from JSON
         self.has_param_cls.set_parameters(new_param_json)
         param_json = self.has_param_cls.get_parameters()
-        self.assertEqual(param_json,
-                        ('{"param4": 8, "param3": 0.5, '
-                         '"param2": 7, "param1": "plop"}'))
+        param_dict = json.loads(param_json)
+        self.assertEqual(param_dict, new_param_dict)
 
     def test_param_view(self):
         "param_view method"
@@ -76,6 +85,27 @@ class TestHasParam(unittest.TestCase):
         self.assertRaises(AttributeError, _parameters.__getattribute__, name)
         self.assertNotIn(name, _parameters.trait_names())
 
+    def test_validate_True(self):
+        "Validate parameters format against Traits specification : pass"
+        # Validate from dict
+        self.assertEqual(self.param_dict,
+                         self.has_param_cls.validate_parameters(self.param_dict))
+        # Validate from JSON
+        param_json = json.dumps(self.param_dict)
+        self.assertEqual(self.param_dict,
+                         self.has_param_cls.validate_parameters(param_json))
+
+    def test_validate_False(self):
+        "Validate parameters format against Traits specification : reject"
+        bad_param = {"param1": "", "param2": 0, "param3": 0.0,
+                     "param4": 3.3}  # Param4 is a Float (it should be a int)
+        # Validate from dict
+        self.assertRaises(ValueError, self.has_param_cls.validate_parameters, bad_param)
+        # Validate from JSON
+        bad_param_json = json.dumps(bad_param)
+        self.assertRaises(ValueError, self.has_param_cls.validate_parameters,
+                          bad_param_json)
+
 
 if __name__ == '__main__':
     unittest.main(testRunner=TestRunner())
index ef8efcd5bf50f51fc36adc9af56af0a67476cef0..75f43ae5f210b853d680c1899948276cc4711071 100644 (file)
@@ -23,6 +23,8 @@
 
 
 from traits.api import HasTraits, Unicode, Int, Float, Range
+from traits.api import TraitError
+
 import simplejson as json
 
 
@@ -81,9 +83,32 @@ class HasParam(object):
         param_dict = self._parameters.get(list_traits)
         return json.dumps(param_dict)
 
-    def set_parameters(self, param_str):
-        param_dict = json.loads(param_str)
-        self._parameters.set(**param_dict)
+    def set_parameters(self, parameters):
+        if isinstance(parameters, basestring):
+            self.set_parameters(json.loads(parameters))
+        else:
+            self._parameters.set(**parameters)
+
+    def validate_parameters(self, parameters):
+        """Validate parameters format against Traits specification
+        Input can be either a dictionary or a JSON string
+        Returns the validated parameters or raises a ValueError"""
+
+        if isinstance(parameters, basestring):
+            return self.validate_parameters(json.loads(parameters))
+        # Check key against traits name
+        traits_name = self._parameters.editable_traits()
+        for name in parameters:
+            if name not in traits_name:
+                raise KeyError(name)
+
+        try:
+            valid_params = {name: self._parameters.validate_trait(name, value)
+                            for name, value in parameters.items()}
+        except TraitError as e:
+            raise ValueError(str(e))
+
+        return valid_params
 
     def param_view(self):
         list_traits = self._parameters.editable_traits()