]> git.parisson.com Git - django-social-auth.git/commitdiff
Refactored backend loading to avoid a race condition. Fixes #204
authorStephen McDonald <steve@jupo.org>
Sat, 24 Dec 2011 20:26:23 +0000 (07:26 +1100)
committerStephen McDonald <steve@jupo.org>
Sat, 24 Dec 2011 20:26:23 +0000 (07:26 +1100)
social_auth/backends/__init__.py

index 20ec356cd3a490ff0fd65404a7648a7549b45066..83713ca4eaf353d4da4f713958132a54a632acc3 100644 (file)
@@ -287,6 +287,7 @@ class BaseAuth(object):
 
         @AUTH_BACKEND   Authorization backend related with this service
     """
+
     AUTH_BACKEND = None
 
     def __init__(self, request, redirect):
@@ -607,49 +608,49 @@ class BaseOAuth2(BaseOAuth):
                setting(self.SETTINGS_SECRET_NAME)
 
 
-# import sources from where check for auth backends
-SOCIAL_AUTH_IMPORT_SOURCES = (
-    'social_auth.backends',
-    'social_auth.backends.contrib',
-) + setting('SOCIAL_AUTH_IMPORT_BACKENDS', ())
-
-def get_backends():
-    enabled = setting('SOCIAL_AUTH_ENABLED_BACKENDS')
-    if enabled:
-        enabled = defaultdict(lambda: False, ((bak, True) for bak in enabled))
-    else:
-        enabled = defaultdict(lambda: True)
-
-    backends = {}
-    for mod_name in SOCIAL_AUTH_IMPORT_SOURCES:
-        try:
-            mod = import_module(mod_name)
-        except ImportError:
-            logger.exception('Error importing %s', mod_name)
-            continue
-
-        for directory, subdir, files in walk(mod.__path__[0]):
-            for name in filter(lambda name: name.endswith('.py'), files):
-                try:
-                    name = basename(name).replace('.py', '')
-                    sub = import_module(mod_name + '.' + name)
-
-                    # register only enabled backends
-                    new = ((key, val) for key, val in sub.BACKENDS.items()
-                                if val.enabled() and enabled[key])
-                    backends.update(new)
-                except (ImportError, AttributeError):
-                    pass
-
-    if enabled[OpenIdAuth.AUTH_BACKEND.name]:
-        backends[OpenIdAuth.AUTH_BACKEND.name] = OpenIdAuth
-    return backends
-
-
-# load backends from defined modules
-BACKENDS = get_backends()
+# Backend loading was previously performed via the
+# SOCIAL_AUTH_IMPORT_BACKENDS setting - as it's no longer used,
+# provide a deprecation warning.
+if setting('SOCIAL_AUTH_IMPORT_BACKENDS'):
+    from warnings import warn
+    warn("SOCIAL_AUTH_IMPORT_SOURCES is deprecated")
 
+# Cache for discovered backends.
+BACKENDS = {}
 
 def get_backend(name, *args, **kwargs):
-    """Return auth backend instance *if* it's registered, None in other case"""
-    return BACKENDS.get(name, lambda *args, **kwargs: None)(*args, **kwargs)
+    """Returns a backend by name. Backends are stored in the BACKENDS
+    cache dict. If not found, each of the modules referenced in
+    AUTHENTICATION_BACKENDS is imported and checked for a BACKENDS
+    definition. If the named backend is found in the module's BACKENDS
+    definition, it's then stored in the cache for future access.
+
+    Previously all backends were attempted to be loaded at
+    import time of this module, which meant that backends that subclass
+    bases found in this module would not have the chance to be loaded
+    by the time they were added to this module's BACKENDS dict. See:
+    https://github.com/omab/django-social-auth/issues/204
+
+    This new approach ensures that backends are allowed to subclass from
+    bases in this module and still be picked up.
+    """
+    try:
+        # Cached backend which has previously been discovered.
+        return BACKENDS[name](*args, **kwargs)
+    except KeyError:
+        pass
+    # Look for a BACKENDS definition on each of the modules for
+    # AUTHENTICATION_BACKENDS.
+    for auth_backend in settings.AUTHENTICATION_BACKENDS:
+        module = import_module(auth_backend.rsplit(".", 1)[0])
+        backends = getattr(module, "BACKENDS", {})
+        try:
+            backend = backends[name]
+        except KeyError:
+            pass
+        else:
+            # If the backend is enabled, add it to the cache and
+            # return it.
+            if backend.enabled():
+                BACKENDS[name] = backend
+                return backend(*args, **kwargs)