From 5284150608695da7bf74146f5f9d0ed91b71655d Mon Sep 17 00:00:00 2001 From: Thomas Fillon Date: Mon, 26 May 2014 12:17:15 +0200 Subject: [PATCH] Add a validate method for parameters + unittest --- tests/test_tools_parameters.py | 46 ++++++++++++++++++++++++++++------ timeside/tools/parameters.py | 31 ++++++++++++++++++++--- 2 files changed, 66 insertions(+), 11 deletions(-) diff --git a/tests/test_tools_parameters.py b/tests/test_tools_parameters.py index 3b994ad..145de7b 100644 --- a/tests/test_tools_parameters.py +++ b/tests/test_tools_parameters.py @@ -5,6 +5,8 @@ from unit_timeside import unittest, TestRunner from timeside.tools.parameters import HasParam, HasTraits from timeside.tools.parameters import Unicode, Int, Float, Range +import simplejson as json + class TestHasParam(unittest.TestCase): @@ -17,24 +19,31 @@ class TestHasParam(unittest.TestCase): param3 = Float() param4 = Range(low=0, high=10, value=3) + self.param_dict = {"param1": "", "param2": 0, "param3": 0.0, + "param4": 3} self.has_param_cls = ParamClass() def test_get_parameters(self): "get_parameters method" param_json = self.has_param_cls.get_parameters() - self.assertEqual(param_json, - ('{"param4": 3, "param3": 0.0, ' - '"param2": 0, "param1": ""}')) + self.assertEqual(json.loads(param_json), + self.param_dict) def test_set_parameters(self): "set_parameters method" - new_param_json = ('{"param1": "plop", "param2": 7, ' - '"param3": 0.5, "param4": 8}') + new_param_dict = {"param1": "plop", "param2": 7, + "param3": 0.5, "param4": 8} + new_param_json = json.dumps(new_param_dict) + # Set from dict + self.has_param_cls.set_parameters(new_param_dict) + param_json = self.has_param_cls.get_parameters() + param_dict = json.loads(param_json) + self.assertEqual(param_dict, new_param_dict) + # set from JSON self.has_param_cls.set_parameters(new_param_json) param_json = self.has_param_cls.get_parameters() - self.assertEqual(param_json, - ('{"param4": 8, "param3": 0.5, ' - '"param2": 7, "param1": "plop"}')) + param_dict = json.loads(param_json) + self.assertEqual(param_dict, new_param_dict) def test_param_view(self): "param_view method" @@ -76,6 +85,27 @@ class TestHasParam(unittest.TestCase): self.assertRaises(AttributeError, _parameters.__getattribute__, name) self.assertNotIn(name, _parameters.trait_names()) + def test_validate_True(self): + "Validate parameters format against Traits specification : pass" + # Validate from dict + self.assertEqual(self.param_dict, + self.has_param_cls.validate_parameters(self.param_dict)) + # Validate from JSON + param_json = json.dumps(self.param_dict) + self.assertEqual(self.param_dict, + self.has_param_cls.validate_parameters(param_json)) + + def test_validate_False(self): + "Validate parameters format against Traits specification : reject" + bad_param = {"param1": "", "param2": 0, "param3": 0.0, + "param4": 3.3} # Param4 is a Float (it should be a int) + # Validate from dict + self.assertRaises(ValueError, self.has_param_cls.validate_parameters, bad_param) + # Validate from JSON + bad_param_json = json.dumps(bad_param) + self.assertRaises(ValueError, self.has_param_cls.validate_parameters, + bad_param_json) + if __name__ == '__main__': unittest.main(testRunner=TestRunner()) diff --git a/timeside/tools/parameters.py b/timeside/tools/parameters.py index ef8efcd..75f43ae 100644 --- a/timeside/tools/parameters.py +++ b/timeside/tools/parameters.py @@ -23,6 +23,8 @@ from traits.api import HasTraits, Unicode, Int, Float, Range +from traits.api import TraitError + import simplejson as json @@ -81,9 +83,32 @@ class HasParam(object): param_dict = self._parameters.get(list_traits) return json.dumps(param_dict) - def set_parameters(self, param_str): - param_dict = json.loads(param_str) - self._parameters.set(**param_dict) + def set_parameters(self, parameters): + if isinstance(parameters, basestring): + self.set_parameters(json.loads(parameters)) + else: + self._parameters.set(**parameters) + + def validate_parameters(self, parameters): + """Validate parameters format against Traits specification + Input can be either a dictionary or a JSON string + Returns the validated parameters or raises a ValueError""" + + if isinstance(parameters, basestring): + return self.validate_parameters(json.loads(parameters)) + # Check key against traits name + traits_name = self._parameters.editable_traits() + for name in parameters: + if name not in traits_name: + raise KeyError(name) + + try: + valid_params = {name: self._parameters.validate_trait(name, value) + for name, value in parameters.items()} + except TraitError as e: + raise ValueError(str(e)) + + return valid_params def param_view(self): list_traits = self._parameters.editable_traits() -- 2.39.5