From: Matías Aguirre Date: Wed, 4 Jul 2012 19:03:16 +0000 (-0300) Subject: Ensure that redirect_uri has state argument on other backends. Refs #386 X-Git-Url: https://git.parisson.com/?a=commitdiff_plain;h=2977bc35989256fc6bd6f4368c050be9a2ebb0ac;p=django-social-auth.git Ensure that redirect_uri has state argument on other backends. Refs #386 --- diff --git a/social_auth/backends/__init__.py b/social_auth/backends/__init__.py index 93117c2..2ff137e 100644 --- a/social_auth/backends/__init__.py +++ b/social_auth/backends/__init__.py @@ -11,7 +11,6 @@ enabled. """ from urllib2 import Request, urlopen, HTTPError from urllib import urlencode -from urlparse import urlsplit from openid.consumer.consumer import Consumer, SUCCESS, CANCEL, FAILURE from openid.consumer.discover import DiscoveryFailure @@ -28,7 +27,7 @@ from django.utils.crypto import constant_time_compare, get_random_string from django.middleware.csrf import CSRF_KEY_LENGTH from social_auth.utils import setting, log, model_to_ctype, ctype_to_model, \ - clean_partial_pipeline + clean_partial_pipeline, url_add_parameters from social_auth.store import DjangoOpenIDStore from social_auth.backends.exceptions import StopPipeline, AuthException, \ AuthFailed, AuthCanceled, \ @@ -514,13 +513,9 @@ class OpenIdAuth(BaseAuth): def openid_request(self, extra_params=None): """Return openid request""" - openid_url = self.openid_url() - if extra_params: - query = urlsplit(openid_url).query - openid_url += (query and '&' or '?') + urlencode(extra_params) - try: - return self.consumer().begin(openid_url) + return self.consumer().begin(url_add_parameters(self.openid_url(), + extra_params)) except DiscoveryFailure, err: raise AuthException(self, 'OpenID discovery error: %s' % err) @@ -659,8 +654,6 @@ class BaseOAuth2(BaseOAuth): Attributes: AUTHORIZATION_URL Authorization service url ACCESS_TOKEN_URL Token URL - FORCE_STATE_CHECK Ensure state argument check (check issue #386 - for further details) """ AUTHORIZATION_URL = None ACCESS_TOKEN_URL = None @@ -668,21 +661,29 @@ class BaseOAuth2(BaseOAuth): RESPONSE_TYPE = 'code' SCOPE_VAR_NAME = None DEFAULT_SCOPE = None - FORCE_STATE_CHECK = True - def csrf_token(self): + def state_token(self): """Generate csrf token to include as state parameter.""" return get_random_string(CSRF_KEY_LENGTH) + def get_redirect_uri(self, state): + """Build redirect_uri with redirect_state parameter.""" + return url_add_parameters(self.redirect_uri, {'redirect_state': state}) + def auth_url(self): """Return redirect url""" client_id, client_secret = self.get_key_and_secret() - args = {'client_id': client_id, 'redirect_uri': self.redirect_uri} - - if self.FORCE_STATE_CHECK: - state = self.csrf_token() - args['state'] = state - self.request.session[self.AUTH_BACKEND.name + '_state'] = state + state = self.state_token() + # Store state in session for further request validation. The state + # value is passed as state parameter (as specified in OAuth2 spec), but + # also added to redirect_uri, that way we can still verify the request + # if the provider doesn't implement the state parameter. + self.request.session[self.AUTH_BACKEND.name + '_state'] = state + args = { + 'client_id': client_id, + 'state': state, + 'redirect_uri': self.get_redirect_uri(state) + } scope = self.get_scope() if scope: @@ -693,28 +694,35 @@ class BaseOAuth2(BaseOAuth): args.update(self.auth_extra_arguments()) return self.AUTHORIZATION_URL + '?' + urlencode(args) + def validate_state(self): + """Validate state value. Raises exception on error, returns state + value if valid.""" + state = self.request.session.get(self.AUTH_BACKEND.name + '_state') + request_state = self.data.get('state') or \ + self.data.get('redirect_state') + if not request_state: + raise AuthMissingParameter(self, 'state') + elif not state: + raise AuthStateMissing(self, 'state') + elif not constant_time_compare(request_state, state): + raise AuthStateForbidden(self) + return state + def auth_complete(self, *args, **kwargs): """Completes loging process, must return user instance""" if self.data.get('error'): error = self.data.get('error_description') or self.data['error'] raise AuthFailed(self, error) - if self.FORCE_STATE_CHECK: - request_state = self.data.get('state') - state = self.request.session.get(self.AUTH_BACKEND.name + '_state') - if not request_state: - raise AuthMissingParameter(self, 'state') - elif not state: - raise AuthStateMissing(self, 'state') - elif not constant_time_compare(request_state, state): - raise AuthStateForbidden(self) - + state = self.validate_state() client_id, client_secret = self.get_key_and_secret() - params = {'grant_type': 'authorization_code', # request auth code - 'code': self.data.get('code', ''), # server response code - 'client_id': client_id, - 'client_secret': client_secret, - 'redirect_uri': self.redirect_uri} + params = { + 'grant_type': 'authorization_code', # request auth code + 'code': self.data.get('code', ''), # server response code + 'client_id': client_id, + 'client_secret': client_secret, + 'redirect_uri': self.get_redirect_uri(state) + } headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'} request = Request(self.ACCESS_TOKEN_URL, data=urlencode(params), diff --git a/social_auth/backends/contrib/github.py b/social_auth/backends/contrib/github.py index af9371a..0580b2f 100644 --- a/social_auth/backends/contrib/github.py +++ b/social_auth/backends/contrib/github.py @@ -53,8 +53,6 @@ class GithubAuth(BaseOAuth2): SCOPE_SEPARATOR = ',' # Look at http://developer.github.com/v3/oauth/ SCOPE_VAR_NAME = 'GITHUB_EXTENDED_PERMISSIONS' - # Github doesn't return the state paramenter if specified :( - FORCE_STATE_CHECK = False def user_data(self, access_token, *args, **kwargs): """Loads user data from service""" diff --git a/social_auth/backends/contrib/instagram.py b/social_auth/backends/contrib/instagram.py index 3cf1689..faa2e82 100644 --- a/social_auth/backends/contrib/instagram.py +++ b/social_auth/backends/contrib/instagram.py @@ -37,7 +37,6 @@ class InstagramAuth(BaseOAuth2): AUTH_BACKEND = InstagramBackend SETTINGS_KEY_NAME = 'INSTAGRAM_CLIENT_ID' SETTINGS_SECRET_NAME = 'INSTAGRAM_CLIENT_SECRET' - FORCE_STATE_CHECK = False def user_data(self, access_token, *args, **kwargs): """Loads user data from service""" diff --git a/social_auth/backends/contrib/yandex.py b/social_auth/backends/contrib/yandex.py index b13f992..ea6e457 100644 --- a/social_auth/backends/contrib/yandex.py +++ b/social_auth/backends/contrib/yandex.py @@ -12,8 +12,8 @@ from urllib import urlencode from urllib2 import urlopen from urlparse import urlparse, urlsplit -from social_auth.backends import OpenIDBackend, OpenIdAuth, USERNAME,\ - OAuthBackend, BaseOAuth2 +from social_auth.backends import OpenIDBackend, OpenIdAuth, USERNAME, \ + OAuthBackend, BaseOAuth2 from social_auth.utils import setting, log diff --git a/social_auth/backends/facebook.py b/social_auth/backends/facebook.py index ba9a844..303eda4 100644 --- a/social_auth/backends/facebook.py +++ b/social_auth/backends/facebook.py @@ -91,9 +91,10 @@ class FacebookAuth(BaseOAuth2): expires = None if 'code' in self.data: + state = self.validate_state() url = ACCESS_TOKEN + urlencode({ 'client_id': setting('FACEBOOK_APP_ID'), - 'redirect_uri': self.redirect_uri, + 'redirect_uri': self.get_redirect_uri(state), 'client_secret': setting('FACEBOOK_API_SECRET'), 'code': self.data['code'] }) diff --git a/social_auth/utils.py b/social_auth/utils.py index 8c38325..9c4b37d 100644 --- a/social_auth/utils.py +++ b/social_auth/utils.py @@ -1,4 +1,5 @@ import urlparse +import urllib import logging from collections import defaultdict from datetime import timedelta, tzinfo @@ -174,6 +175,16 @@ def log_exceptions_to_messages(request, backend, err): error(request, unicode(err), extra_tags='social-auth %s' % name) +def url_add_parameters(url, params): + """Adds parameters to URL, parameter will be repeated if already present""" + if params: + fragments = list(urlparse.urlparse(url)) + fragments[4] = urllib.urlencode(urlparse.parse_qsl(fragments[4]) + + params.items()) + url = urlparse.urlunparse(fragments) + return url + + if __name__ == '__main__': import doctest doctest.testmod()