|
@@ -1,6 +1,5 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
-import hashlib
|
|
|
|
from collections.abc import Callable, Iterable
|
|
from collections.abc import Callable, Iterable
|
|
from typing import Any, ClassVar
|
|
from typing import Any, ClassVar
|
|
|
|
|
|
@@ -299,55 +298,6 @@ class ClientIdSecretAuthentication(QuietBasicAuthentication):
|
|
class UserAuthTokenAuthentication(StandardAuthentication):
|
|
class UserAuthTokenAuthentication(StandardAuthentication):
|
|
token_name = b"bearer"
|
|
token_name = b"bearer"
|
|
|
|
|
|
- def _find_or_update_token_by_hash(self, token_str: str) -> ApiToken | ApiTokenReplica:
|
|
|
|
- """
|
|
|
|
- Find token by hash or update token's hash value if only found via plaintext.
|
|
|
|
-
|
|
|
|
- 1. Hash provided plaintext token.
|
|
|
|
- 2. Perform lookup based on hashed value.
|
|
|
|
- 3. If found, return the token.
|
|
|
|
- 4. If not found, search for the token based on its plaintext value.
|
|
|
|
- 5. If found, update the token's hashed value and return the token.
|
|
|
|
- 6. If not found via hash or plaintext value, raise AuthenticationFailed
|
|
|
|
-
|
|
|
|
- Returns `ApiTokenReplica` if running in REGION silo or
|
|
|
|
- `ApiToken` if running in CONTROL silo.
|
|
|
|
- """
|
|
|
|
-
|
|
|
|
- hashed_token = hashlib.sha256(token_str.encode()).hexdigest()
|
|
|
|
-
|
|
|
|
- if SiloMode.get_current_mode() == SiloMode.REGION:
|
|
|
|
- try:
|
|
|
|
- # Try to find the token by its hashed value first
|
|
|
|
- return ApiTokenReplica.objects.get(hashed_token=hashed_token)
|
|
|
|
- except ApiTokenReplica.DoesNotExist:
|
|
|
|
- try:
|
|
|
|
- # If we can't find it by hash, use the plaintext string
|
|
|
|
- return ApiTokenReplica.objects.get(token=token_str)
|
|
|
|
- except ApiTokenReplica.DoesNotExist:
|
|
|
|
- # If the token does not exist by plaintext either, it is not a valid token
|
|
|
|
- raise AuthenticationFailed("Invalid token")
|
|
|
|
- else:
|
|
|
|
- try:
|
|
|
|
- # Try to find the token by its hashed value first
|
|
|
|
- return ApiToken.objects.select_related("user", "application").get(
|
|
|
|
- hashed_token=hashed_token
|
|
|
|
- )
|
|
|
|
- except ApiToken.DoesNotExist:
|
|
|
|
- try:
|
|
|
|
- # If we can't find it by hash, use the plaintext string
|
|
|
|
- api_token = ApiToken.objects.select_related("user", "application").get(
|
|
|
|
- token=token_str
|
|
|
|
- )
|
|
|
|
- except ApiToken.DoesNotExist:
|
|
|
|
- # If the token does not exist by plaintext either, it is not a valid token
|
|
|
|
- raise AuthenticationFailed("Invalid token")
|
|
|
|
- else:
|
|
|
|
- # Update it with the hashed value if found by plaintext
|
|
|
|
- api_token.hashed_token = hashed_token
|
|
|
|
- api_token.save(update_fields=["hashed_token"])
|
|
|
|
- return api_token
|
|
|
|
-
|
|
|
|
def accepts_auth(self, auth: list[bytes]) -> bool:
|
|
def accepts_auth(self, auth: list[bytes]) -> bool:
|
|
if not super().accepts_auth(auth):
|
|
if not super().accepts_auth(auth):
|
|
return False
|
|
return False
|
|
@@ -370,16 +320,26 @@ class UserAuthTokenAuthentication(StandardAuthentication):
|
|
application_is_inactive = False
|
|
application_is_inactive = False
|
|
|
|
|
|
if not token:
|
|
if not token:
|
|
- token = self._find_or_update_token_by_hash(token_str)
|
|
|
|
- if isinstance(token, ApiTokenReplica): # we're running as a REGION silo
|
|
|
|
- user = user_service.get_user(user_id=token.user_id)
|
|
|
|
- application_is_inactive = not token.application_is_active
|
|
|
|
- else: # the token returned is an ApiToken from the CONTROL silo
|
|
|
|
- user = token.user
|
|
|
|
|
|
+ if SiloMode.get_current_mode() == SiloMode.REGION:
|
|
|
|
+ try:
|
|
|
|
+ atr = token = ApiTokenReplica.objects.get(token=token_str)
|
|
|
|
+ except ApiTokenReplica.DoesNotExist:
|
|
|
|
+ raise AuthenticationFailed("Invalid token")
|
|
|
|
+ user = user_service.get_user(user_id=atr.user_id)
|
|
|
|
+ application_is_inactive = not atr.application_is_active
|
|
|
|
+ else:
|
|
|
|
+ try:
|
|
|
|
+ at = token = (
|
|
|
|
+ ApiToken.objects.filter(token=token_str)
|
|
|
|
+ .select_related("user", "application")
|
|
|
|
+ .get()
|
|
|
|
+ )
|
|
|
|
+ except ApiToken.DoesNotExist:
|
|
|
|
+ raise AuthenticationFailed("Invalid token")
|
|
|
|
+ user = at.user
|
|
application_is_inactive = (
|
|
application_is_inactive = (
|
|
- token.application is not None and not token.application.is_active
|
|
|
|
|
|
+ at.application is not None and not at.application.is_active
|
|
)
|
|
)
|
|
-
|
|
|
|
elif isinstance(token, SystemToken):
|
|
elif isinstance(token, SystemToken):
|
|
user = token.user
|
|
user = token.user
|
|
|
|
|
|
@@ -429,9 +389,9 @@ class OrgAuthTokenAuthentication(StandardAuthentication):
|
|
raise AuthenticationFailed("Invalid org token")
|
|
raise AuthenticationFailed("Invalid org token")
|
|
else:
|
|
else:
|
|
try:
|
|
try:
|
|
- token = OrgAuthToken.objects.get(
|
|
|
|
|
|
+ token = OrgAuthToken.objects.filter(
|
|
token_hashed=token_hashed, date_deactivated__isnull=True
|
|
token_hashed=token_hashed, date_deactivated__isnull=True
|
|
- )
|
|
|
|
|
|
+ ).get()
|
|
except OrgAuthToken.DoesNotExist:
|
|
except OrgAuthToken.DoesNotExist:
|
|
raise AuthenticationFailed("Invalid org token")
|
|
raise AuthenticationFailed("Invalid org token")
|
|
|
|
|