]> git.parisson.com Git - telemeta.git/commitdiff
#67: improve OAI-PMH validation
authorolivier <>
Thu, 2 Apr 2009 16:31:26 +0000 (16:31 +0000)
committerolivier <>
Thu, 2 Apr 2009 16:31:26 +0000 (16:31 +0000)
telemeta/interop/oai.py

index e43ecd22b3a3d92b19bea9b629453af38cfef735..56a0eb4741125d70ee26b16ae0931f84c23f845d 100644 (file)
@@ -16,6 +16,87 @@ class IDataSource(object):
            or None if the record doesn't exist"""
         pass
 
+class ArgumentValidator(object):
+    """OAI-PMH request argument validator"""
+
+    def __init__(self, request, response):
+        self.response = response
+        self.opt_args = []
+        self.required_args = ['verb']
+        self.request = request
+        self.format = None
+
+    def optional(self, *args):
+        """Add optional arguments"""
+        self.opt_args.extend(args)
+
+    def require(self, *args):
+        """Add required arguments"""
+        self.required_args.extend(args)
+
+    def accept_format(self, format):
+        """Indicate which metadata format is supported"""
+        self.format = format
+
+    def has_verb(self):
+        """Check if the request includes a valid Verb, return True if it does, False otherwise, 
+           setting an error into the response"""
+
+        valid = ['GetRecord', 'Identify', 'ListIdentifiers', 'ListMetadataFormats', 'ListRecords', 'ListSets']
+
+        result = False
+        if self.request.has_key('verb'):
+            try:
+                valid.index(self.request['verb'])
+                result = True
+            except ValueError:
+                pass
+
+        if not result:
+            self.response.error('badVerb')
+
+        return result
+
+    def validate(self):
+        """Perform validation, return True if successfull, False otherwise, setting appropriate
+           errors into the response"""
+        all_args    = []
+        all_args[:] = self.opt_args[:]
+        all_args.extend(self.required_args)
+        for k in self.request:
+            try:
+                all_args.index(k)
+            except ValueError:
+                self.response.error('badArgument', 'Invalid argument: %s' % k)
+                return False
+
+        return self.pre_validate()
+
+    def pre_validate(self):
+        """Same as validate(), but doesn't not check for unknown arguments"""
+
+        for k in self.required_args:
+            if not self.request.has_key(k):
+                self.response.error('badArgument', 'Missing required argument: %s' % k)
+                return False
+
+        for k in self.request:
+            if k == 'metadataPrefix':
+                if self.format:
+                    if self.format != self.request[k]:
+                        self.response.error('cannotDisseminateFormat')
+                        return False
+                else:
+                    raise Exception('Can\'t validate metadataPrefix argument: supported format isn\'t defined')
+            elif (k == 'from') or (k == 'until'):
+                try:
+                    datetime.strptime(self.request[k], '%Y-%m-%dT%H-%M-%SZ')
+                except ValueError:
+                    self.response.error('badArgument', "Invalid ISO8601 time format in '%s' argument" % k)
+                    return False
+
+        return True         
+
 class DataProvider(object):
     """OAI-PMH Data Provider"""
 
@@ -29,39 +110,28 @@ class DataProvider(object):
             'granularity':      'YYYY-MM-DDThh:mm:ssZ'
         }
 
-    def require_argument(self, response, args, required):
-        """Return True if the required argument is present in args, False otherwise, setting
-           an error into the response"""
-        if not args.has_key(required):
-            response.error("badArgument", msg="Missing required argument '%s'" % required)
-            return False
-        return True
-
-    def validate_format(self, response, args):
-        """Return True if the metadataPrefix argument is present in args and a supported format,
-           False otherwise, setting an error into the response"""
-        arg = args.get('metadataPrefix')
-        if not self.require_argument(response, args, 'metadataPrefix'):
-            return False
-        if arg != 'oai_dc':
-            response.error('cannotDisseminateFormat')
-            return False
-
-        return True            
+    def parse_time(self, str):
+        """Parse an ISO8601 date string into a datetime object"""
+        return datetime.strptime(str, '%Y-%m-%dT%H-%M-%SZ')
 
     def handle(self, args, datasource):
         """Handle a request and return the response as a DOM document"""
+
         response = Response(self.identity, datasource)
-        if self.require_argument(response, args, 'verb'):
-            verb = args.get('verb')
+
+        validator = ArgumentValidator(args, response)
+        validator.accept_format('oai_dc')
+
+        if validator.has_verb():
+
+            verb = args['verb']
             response.set_verb(verb)
+
             if verb == 'Identify':
-                response.identify()
+                validator.validate() and response.identify()
             elif verb == 'GetRecord':
-                if self.require_argument(response, args, 'identifier') and self.validate_format(response, args):
-                    response.get_record(args['identifier'])
-            else:
-                response.error('badVerb')
+                validator.require('identifier', 'metadataPrefix')
+                validator.validate() and response.get_record(args['identifier'])
 
         doc = libxml2.parseDoc(response.doc.toxml(encoding="utf-8"))
         response.free()
@@ -130,7 +200,6 @@ class Response(object):
             date_time = datetime.now()
         return date_time.strftime('%Y-%m-%dT%H-%M-%SZ')
 
-
     def error(self, code, msg = None):
         """Add error tag using code. If msg is not provided, use a default error message."""