]> git.parisson.com Git - django-social-auth.git/commitdiff
Ensure that redirect_uri has state argument on other backends. Refs #386
authorMatías Aguirre <matiasaguirre@gmail.com>
Wed, 4 Jul 2012 19:03:16 +0000 (16:03 -0300)
committerMatías Aguirre <matiasaguirre@gmail.com>
Wed, 4 Jul 2012 19:03:16 +0000 (16:03 -0300)
social_auth/backends/__init__.py
social_auth/backends/contrib/github.py
social_auth/backends/contrib/instagram.py
social_auth/backends/contrib/yandex.py
social_auth/backends/facebook.py
social_auth/utils.py

index 93117c27fb55ca550b796c20cd62101b59265027..2ff137e7e8edb2c5320a3f38c41f78bfe72ed23b 100644 (file)
@@ -11,7 +11,6 @@ enabled.
 """
 from urllib2 import Request, urlopen, HTTPError
 from urllib import urlencode
-from urlparse import urlsplit
 
 from openid.consumer.consumer import Consumer, SUCCESS, CANCEL, FAILURE
 from openid.consumer.discover import DiscoveryFailure
@@ -28,7 +27,7 @@ from django.utils.crypto import constant_time_compare, get_random_string
 from django.middleware.csrf import CSRF_KEY_LENGTH
 
 from social_auth.utils import setting, log, model_to_ctype, ctype_to_model, \
-                              clean_partial_pipeline
+                              clean_partial_pipeline, url_add_parameters
 from social_auth.store import DjangoOpenIDStore
 from social_auth.backends.exceptions import StopPipeline, AuthException, \
                                             AuthFailed, AuthCanceled, \
@@ -514,13 +513,9 @@ class OpenIdAuth(BaseAuth):
 
     def openid_request(self, extra_params=None):
         """Return openid request"""
-        openid_url = self.openid_url()
-        if extra_params:
-            query = urlsplit(openid_url).query
-            openid_url += (query and '&' or '?') + urlencode(extra_params)
-
         try:
-            return self.consumer().begin(openid_url)
+            return self.consumer().begin(url_add_parameters(self.openid_url(),
+                                                            extra_params))
         except DiscoveryFailure, err:
             raise AuthException(self, 'OpenID discovery error: %s' % err)
 
@@ -659,8 +654,6 @@ class BaseOAuth2(BaseOAuth):
     Attributes:
         AUTHORIZATION_URL       Authorization service url
         ACCESS_TOKEN_URL        Token URL
-        FORCE_STATE_CHECK       Ensure state argument check (check issue #386
-                                for further details)
     """
     AUTHORIZATION_URL = None
     ACCESS_TOKEN_URL = None
@@ -668,21 +661,29 @@ class BaseOAuth2(BaseOAuth):
     RESPONSE_TYPE = 'code'
     SCOPE_VAR_NAME = None
     DEFAULT_SCOPE = None
-    FORCE_STATE_CHECK = True
 
-    def csrf_token(self):
+    def state_token(self):
         """Generate csrf token to include as state parameter."""
         return get_random_string(CSRF_KEY_LENGTH)
 
+    def get_redirect_uri(self, state):
+        """Build redirect_uri with redirect_state parameter."""
+        return url_add_parameters(self.redirect_uri, {'redirect_state': state})
+
     def auth_url(self):
         """Return redirect url"""
         client_id, client_secret = self.get_key_and_secret()
-        args = {'client_id': client_id, 'redirect_uri': self.redirect_uri}
-
-        if self.FORCE_STATE_CHECK:
-            state = self.csrf_token()
-            args['state'] = state
-            self.request.session[self.AUTH_BACKEND.name + '_state'] = state
+        state = self.state_token()
+        # Store state in session for further request validation. The state
+        # value is passed as state parameter (as specified in OAuth2 spec), but
+        # also added to redirect_uri, that way we can still verify the request
+        # if the provider doesn't implement the state parameter.
+        self.request.session[self.AUTH_BACKEND.name + '_state'] = state
+        args = {
+            'client_id': client_id,
+            'state': state,
+            'redirect_uri': self.get_redirect_uri(state)
+        }
 
         scope = self.get_scope()
         if scope:
@@ -693,28 +694,35 @@ class BaseOAuth2(BaseOAuth):
         args.update(self.auth_extra_arguments())
         return self.AUTHORIZATION_URL + '?' + urlencode(args)
 
+    def validate_state(self):
+        """Validate state value. Raises exception on error, returns state
+        value if valid."""
+        state = self.request.session.get(self.AUTH_BACKEND.name + '_state')
+        request_state = self.data.get('state') or \
+                        self.data.get('redirect_state')
+        if not request_state:
+            raise AuthMissingParameter(self, 'state')
+        elif not state:
+            raise AuthStateMissing(self, 'state')
+        elif not constant_time_compare(request_state, state):
+            raise AuthStateForbidden(self)
+        return state
+
     def auth_complete(self, *args, **kwargs):
         """Completes loging process, must return user instance"""
         if self.data.get('error'):
             error = self.data.get('error_description') or self.data['error']
             raise AuthFailed(self, error)
 
-        if self.FORCE_STATE_CHECK:
-            request_state = self.data.get('state')
-            state = self.request.session.get(self.AUTH_BACKEND.name + '_state')
-            if not request_state:
-                raise AuthMissingParameter(self, 'state')
-            elif not state:
-                raise AuthStateMissing(self, 'state')
-            elif not constant_time_compare(request_state, state):
-                raise AuthStateForbidden(self)
-
+        state = self.validate_state()
         client_id, client_secret = self.get_key_and_secret()
-        params = {'grant_type': 'authorization_code',  # request auth code
-                  'code': self.data.get('code', ''),  # server response code
-                  'client_id': client_id,
-                  'client_secret': client_secret,
-                  'redirect_uri': self.redirect_uri}
+        params = {
+            'grant_type': 'authorization_code',  # request auth code
+            'code': self.data.get('code', ''),  # server response code
+            'client_id': client_id,
+            'client_secret': client_secret,
+            'redirect_uri': self.get_redirect_uri(state)
+        }
         headers = {'Content-Type': 'application/x-www-form-urlencoded',
                     'Accept': 'application/json'}
         request = Request(self.ACCESS_TOKEN_URL, data=urlencode(params),
index af9371a89514d93cb6b3d9cb46111f6d8f8acf48..0580b2f3217ba0d7f618846830430d6de3da3ffc 100644 (file)
@@ -53,8 +53,6 @@ class GithubAuth(BaseOAuth2):
     SCOPE_SEPARATOR = ','
     # Look at http://developer.github.com/v3/oauth/
     SCOPE_VAR_NAME = 'GITHUB_EXTENDED_PERMISSIONS'
-    # Github doesn't return the state paramenter if specified :(
-    FORCE_STATE_CHECK = False
 
     def user_data(self, access_token, *args, **kwargs):
         """Loads user data from service"""
index 3cf1689a946edb830a56c6d804ac165f398e543f..faa2e82cb3264f59128724cce2168cda0ce53f56 100644 (file)
@@ -37,7 +37,6 @@ class InstagramAuth(BaseOAuth2):
     AUTH_BACKEND = InstagramBackend
     SETTINGS_KEY_NAME = 'INSTAGRAM_CLIENT_ID'
     SETTINGS_SECRET_NAME = 'INSTAGRAM_CLIENT_SECRET'
-    FORCE_STATE_CHECK = False
 
     def user_data(self, access_token, *args, **kwargs):
         """Loads user data from service"""
index b13f99252fcbe95a7e3fddb1d4dd280a4998ac2d..ea6e4578282b6edfc8a119a8bf993d33822b2022 100644 (file)
@@ -12,8 +12,8 @@ from urllib import urlencode
 from urllib2 import urlopen
 from urlparse import urlparse, urlsplit
 
-from social_auth.backends import OpenIDBackend, OpenIdAuth, USERNAME,\
-    OAuthBackend, BaseOAuth2
+from social_auth.backends import OpenIDBackend, OpenIdAuth, USERNAME, \
+                                 OAuthBackend, BaseOAuth2
 
 from social_auth.utils import setting, log
 
index ba9a84438a096a092373316cdb9f1bdebc2fb5ce..303eda41638a0a9558c8b5b565c2aba3e92c2e6c 100644 (file)
@@ -91,9 +91,10 @@ class FacebookAuth(BaseOAuth2):
         expires = None
 
         if 'code' in self.data:
+            state = self.validate_state()
             url = ACCESS_TOKEN + urlencode({
                 'client_id': setting('FACEBOOK_APP_ID'),
-                'redirect_uri': self.redirect_uri,
+                'redirect_uri': self.get_redirect_uri(state),
                 'client_secret': setting('FACEBOOK_API_SECRET'),
                 'code': self.data['code']
             })
index 8c383253a498d66e84cff445ae6772ad2b120b20..9c4b37d7244199ef276c7510737ec3f930aebed8 100644 (file)
@@ -1,4 +1,5 @@
 import urlparse
+import urllib
 import logging
 from collections import defaultdict
 from datetime import timedelta, tzinfo
@@ -174,6 +175,16 @@ def log_exceptions_to_messages(request, backend, err):
         error(request, unicode(err), extra_tags='social-auth %s' % name)
 
 
+def url_add_parameters(url, params):
+    """Adds parameters to URL, parameter will be repeated if already present"""
+    if params:
+        fragments = list(urlparse.urlparse(url))
+        fragments[4] = urllib.urlencode(urlparse.parse_qsl(fragments[4]) +
+                                        params.items())
+        url = urlparse.urlunparse(fragments)
+    return url
+
+
 if __name__ == '__main__':
     import doctest
     doctest.testmod()