]> git.parisson.com Git - django-social-auth.git/commitdiff
Improve and fix partial pipeline arguments management. Closes #251
authorMatías Aguirre <matiasaguirre@gmail.com>
Tue, 14 Feb 2012 23:59:42 +0000 (21:59 -0200)
committerMatías Aguirre <matiasaguirre@gmail.com>
Tue, 14 Feb 2012 23:59:42 +0000 (21:59 -0200)
social_auth/backends/__init__.py
social_auth/backends/pipeline/misc.py
social_auth/views.py

index b52869cccfc0a07ebae66d8929abcc3a8b24db10..af1922b0e3d9f7479fa46d9d2581308ba7537aeb 100644 (file)
@@ -95,22 +95,18 @@ class SocialAuthBackend(ModelBackend):
             return None
 
         response = kwargs.get('response')
+        pipeline = PIPELINE
+        kwargs = kwargs.copy()
+        kwargs['backend'] = self
 
         if 'pipeline_index' in kwargs:
-            details = kwargs.pop('details')
-            uid = kwargs.pop('uid')
-            is_new = kwargs.pop('is_new')
-            pipeline = PIPELINE[kwargs['pipeline_index']:]
+            pipeline = pipeline[kwargs['pipeline_index']:]
         else:
-            details = self.get_user_details(response)
-            uid = self.get_user_id(details, response)
-            is_new = False
-            pipeline = PIPELINE
-
-        out = self.pipeline(pipeline, backend=self, uid=uid,
-                            details=details, is_new=is_new,
-                            *args, **kwargs)
+            kwargs['details'] = self.get_user_details(response)
+            kwargs['uid'] = self.get_user_id(kwargs['details'], response)
+            kwargs['is_new'] = False
 
+        out = self.pipeline(pipeline, *args, **kwargs)
         if not isinstance(out, dict):
             return out
 
@@ -311,9 +307,32 @@ class BaseAuth(object):
         """Completes loging process, must return user instance"""
         raise NotImplementedError('Implement in subclass')
 
+    def to_session_dict(self, next_idx, *args, **kwargs):
+        """Returns dict to store on session for partial pipeline."""
+        return {
+            'next': next_idx,
+            'backend': self.AUTH_BACKEND.name,
+            'args': args,
+            'kwargs': kwargs
+        }
+
+    def from_session_dict(self, entry, *args, **kwargs):
+        """Takes session saved entry to continue pipeline and merges with
+        any new extra argument needed. Returns tuple with next pipeline
+        index entry, arguments and keyword arguments to continue the
+        process."""
+        session_kwargs = entry['kwargs']
+        session_kwargs.update(kwargs)
+        return ( entry['next'],
+                 list(entry['args']) + list(args),
+                 session_kwargs )
+
     def continue_pipeline(self, *args, **kwargs):
-        """Continue previos halted pipeline"""
-        kwargs.update({ self.AUTH_BACKEND.name: True })
+        """Continue previous halted pipeline"""
+        kwargs.update({
+            'auth': self,
+            self.AUTH_BACKEND.name: True
+        })
         return authenticate(*args, **kwargs)
 
     def request_token_extra_arguments(self):
@@ -377,6 +396,17 @@ class OpenIdAuth(BaseAuth):
         return setting('OPENID_TRUST_ROOT') or \
                self.request.build_absolute_uri('/')
 
+    def continue_pipeline(self, *args, **kwargs):
+        """Continue previous halted pipeline"""
+        response = self.consumer().complete(dict(self.data.items()),
+                                            self.request.build_absolute_uri())
+        kwargs.update({
+            'auth': self,
+            'response': response,
+            self.AUTH_BACKEND.name: True
+        })
+        return authenticate(*args, **kwargs)
+
     def auth_complete(self, *args, **kwargs):
         """Complete auth process"""
         response = self.consumer().complete(dict(self.data.items()),
@@ -384,7 +414,11 @@ class OpenIdAuth(BaseAuth):
         if not response:
             raise ValueError('This is an OpenID relying party endpoint')
         elif response.status == SUCCESS:
-            kwargs.update({'response': response, self.AUTH_BACKEND.name: True})
+            kwargs.update({
+                'auth': self,
+                'response': response,
+                self.AUTH_BACKEND.name: True
+            })
             return authenticate(*args, **kwargs)
         elif response.status == FAILURE:
             raise ValueError('OpenID authentication failed: %s' % \
@@ -492,7 +526,11 @@ class ConsumerBasedOAuth(BaseOAuth):
         if data is not None:
             data['access_token'] = access_token.to_string()
 
-        kwargs.update({'response': data, self.AUTH_BACKEND.name: True})
+        kwargs.update({
+            'auth': self,
+            'response': data,
+            self.AUTH_BACKEND.name: True
+        })
         return authenticate(*args, **kwargs)
 
     def unauthorized_token(self):
@@ -612,7 +650,11 @@ class BaseOAuth2(BaseOAuth):
             raise ValueError('OAuth2 authentication failed: %s' % error)
         else:
             response.update(self.user_data(response['access_token']) or {})
-            kwargs.update({'response': response, self.AUTH_BACKEND.name: True})
+            kwargs.update({
+                'auth': self,
+                'response': response,
+                self.AUTH_BACKEND.name: True
+            })
             return authenticate(*args, **kwargs)
 
     def get_scope(self):
index 3b30debf93672b249fc1d6526149cf73f2bed2d0..22b0e19740a866948a0b5cb53dd91f213ad74d85 100644 (file)
@@ -5,8 +5,7 @@ from social_auth.utils import setting
 PIPELINE_ENTRY = 'social_auth.backends.pipeline.misc.save_status_to_session'
 
 
-def save_status_to_session(request, backend, details, response, uid,
-                           *args, **kwargs):
+def save_status_to_session(request, auth, *args, **kwargs):
     """Saves current social-auth status to session."""
     next_entry = setting('SOCIAL_AUTH_PIPELINE_RESUME_ENTRY')
 
@@ -18,13 +17,8 @@ def save_status_to_session(request, backend, details, response, uid,
     except ValueError:
         idx = None
 
+    data = auth.to_session_dict(idx, *args, **kwargs)
+
     name = setting('SOCIAL_AUTH_PARTIAL_PIPELINE_KEY', 'partial_pipeline')
-    request.session[name] = {
-        'backend': backend.name,
-        'uid': uid,
-        'details': details,
-        'response': response,
-        'is_new': kwargs.get('is_new', True),
-        'next_index': idx
-    }
+    request.session[name] = data
     request.session.modified = True
index d3e273467a0777ef4d3ffccaf99f5c2ef45569a5..45140ee95621e928252eff691c26600c40ac6a47 100644 (file)
@@ -192,13 +192,9 @@ def auth_complete(request, backend, user=None, *args, **kwargs):
     if request.session.get(name):
         data = request.session.pop(name)
         request.session.modified = True
-        return backend.continue_pipeline(pipeline_index=data['next_index'],
-                                         user=user,
-                                         request=request,
-                                         uid=data['uid'],
-                                         details=data['details'],
-                                         is_new=data['is_new'],
-                                         response=data['response'],
-                                         *args, **kwargs)
+        idx, args, kwargs = backend.from_session_dict(data, user=user,
+                                                      request=request,
+                                                      *args, **kwargs)
+        return backend.continue_pipeline(pipeline_index=idx, *args, **kwargs)
     else:
         return backend.auth_complete(user=user, request=request, *args, **kwargs)