]> git.parisson.com Git - django-social-auth.git/commitdiff
Save model instances into session in a easy format to retrieve it later. Refs #251
authorMatías Aguirre <matiasaguirre@gmail.com>
Wed, 15 Feb 2012 14:57:12 +0000 (12:57 -0200)
committerMatías Aguirre <matiasaguirre@gmail.com>
Wed, 15 Feb 2012 14:57:12 +0000 (12:57 -0200)
social_auth/backends/__init__.py
social_auth/utils.py

index af1922b0e3d9f7479fa46d9d2581308ba7537aeb..b4a59a4944a5b2a3d1a20ca43e84cc9ccd57f84f 100644 (file)
@@ -26,7 +26,7 @@ from django.contrib.auth.backends import ModelBackend
 from django.utils import simplejson
 from django.utils.importlib import import_module
 
-from social_auth.utils import setting, log
+from social_auth.utils import setting, log, model_to_ctype, ctype_to_model
 from social_auth.store import DjangoOpenIDStore
 from social_auth.backends.exceptions import StopPipeline
 
@@ -312,8 +312,9 @@ class BaseAuth(object):
         return {
             'next': next_idx,
             'backend': self.AUTH_BACKEND.name,
-            'args': args,
-            'kwargs': kwargs
+            'args': tuple(map(model_to_ctype, args)),
+            'kwargs': dict((key, model_to_ctype(val))
+                                for key, val in kwargs.iteritems())
         }
 
     def from_session_dict(self, entry, *args, **kwargs):
@@ -321,11 +322,12 @@ class BaseAuth(object):
         any new extra argument needed. Returns tuple with next pipeline
         index entry, arguments and keyword arguments to continue the
         process."""
-        session_kwargs = entry['kwargs']
-        session_kwargs.update(kwargs)
-        return ( entry['next'],
-                 list(entry['args']) + list(args),
-                 session_kwargs )
+        args = args[:] + tuple(map(ctype_to_model, entry['args']))
+
+        kwargs = kwargs.copy()
+        kwargs.update((key, ctype_to_model(val))
+                            for key, val in entry['kwargs'].iteritems())
+        return (entry['next'], args, kwargs)
 
     def continue_pipeline(self, *args, **kwargs):
         """Continue previous halted pipeline"""
index b5de003079cc8451d1ead326fae838add887e1ae..9f288ed89c96e36564bbc2e72e46d956524cd749 100644 (file)
@@ -3,6 +3,8 @@ import logging
 from collections import defaultdict
 
 from django.conf import settings
+from django.db.models import Model
+from django.contrib.contenttypes.models import ContentType
 
 
 def sanitize_log_data(secret, data=None, leave_characters=4):
@@ -102,6 +104,26 @@ def log(level, *args, **kwargs):
       'warn': logger.warn }[level](*args, **kwargs)
 
 
+def model_to_ctype(val):
+    """Converts values that are instance of Model to a dictionary
+    with enough information to retrieve the instance back later."""
+    if isinstance(val, Model):
+        val = {
+            'pk': val.pk,
+            'ctype': ContentType.objects.get_for_model(val).pk
+        }
+    return val
+
+
+def ctype_to_model(val):
+    """Converts back the instance saved by model_to_ctype function."""
+    if isinstance(val, dict) and 'pk' in val and 'ctype' in val:
+        ctype = ContentType.objects.get_for_id(val['ctype'])
+        ModelClass = ctype.model_class()
+        val = ModelClass.objects.get(pk=val['pk'])
+    return val
+
+
 if __name__ == '__main__':
     import doctest
     doctest.testmod()