|
@@ -40,7 +40,6 @@ from sentry.services.hybrid_cloud.organization import (
|
|
RpcOrganizationMemberSummary,
|
|
RpcOrganizationMemberSummary,
|
|
organization_service,
|
|
organization_service,
|
|
)
|
|
)
|
|
-from sentry.services.hybrid_cloud.user import RpcUser
|
|
|
|
from sentry.services.hybrid_cloud.user.service import user_service
|
|
from sentry.services.hybrid_cloud.user.service import user_service
|
|
from sentry.silo import unguarded_write
|
|
from sentry.silo import unguarded_write
|
|
from sentry.silo.base import SiloMode
|
|
from sentry.silo.base import SiloMode
|
|
@@ -122,15 +121,17 @@ class DatabaseBackedAuthService(AuthService):
|
|
def get_org_auth_config(
|
|
def get_org_auth_config(
|
|
self, *, organization_ids: List[int]
|
|
self, *, organization_ids: List[int]
|
|
) -> List[RpcOrganizationAuthConfig]:
|
|
) -> List[RpcOrganizationAuthConfig]:
|
|
- aps: Mapping[int, AuthProvider] = {
|
|
|
|
|
|
+ aps = {
|
|
ap.organization_id: ap
|
|
ap.organization_id: ap
|
|
for ap in AuthProvider.objects.filter(organization_id__in=organization_ids)
|
|
for ap in AuthProvider.objects.filter(organization_id__in=organization_ids)
|
|
}
|
|
}
|
|
- qs: Mapping[int, int] = {
|
|
|
|
|
|
+ qs = {
|
|
row["organization_id"]: row["id__count"]
|
|
row["organization_id"]: row["id__count"]
|
|
- for row in ApiKey.objects.filter(organization_id__in=organization_ids)
|
|
|
|
- .values("organization_id")
|
|
|
|
- .annotate(Count("id"))
|
|
|
|
|
|
+ for row in (
|
|
|
|
+ ApiKey.objects.filter(organization_id__in=organization_ids)
|
|
|
|
+ .values("organization_id")
|
|
|
|
+ .annotate(Count("id"))
|
|
|
|
+ )
|
|
}
|
|
}
|
|
return [
|
|
return [
|
|
RpcOrganizationAuthConfig(
|
|
RpcOrganizationAuthConfig(
|
|
@@ -141,29 +142,21 @@ class DatabaseBackedAuthService(AuthService):
|
|
for oid in organization_ids
|
|
for oid in organization_ids
|
|
]
|
|
]
|
|
|
|
|
|
- def _load_auth_user(self, user: User) -> RpcUser | None:
|
|
|
|
- rpc_user: RpcUser | None = None
|
|
|
|
- if user is not None:
|
|
|
|
- return user_service.get_user(user_id=user.id)
|
|
|
|
- return rpc_user
|
|
|
|
-
|
|
|
|
def authenticate_with(
|
|
def authenticate_with(
|
|
self, *, request: AuthenticationRequest, authenticator_types: List[RpcAuthenticatorType]
|
|
self, *, request: AuthenticationRequest, authenticator_types: List[RpcAuthenticatorType]
|
|
) -> AuthenticationContext:
|
|
) -> AuthenticationContext:
|
|
fake_request = FakeAuthenticationRequest(request)
|
|
fake_request = FakeAuthenticationRequest(request)
|
|
- user: User | None = None
|
|
|
|
- token: Any = None
|
|
|
|
|
|
|
|
for authenticator_type in authenticator_types:
|
|
for authenticator_type in authenticator_types:
|
|
- t = authenticator_type.as_authenticator().authenticate(fake_request)
|
|
|
|
|
|
+ t = authenticator_type.as_authenticator().authenticate(fake_request) # type: ignore[arg-type]
|
|
if t is not None:
|
|
if t is not None:
|
|
user, token = t
|
|
user, token = t
|
|
- break
|
|
|
|
|
|
+ return AuthenticationContext(
|
|
|
|
+ auth=AuthenticatedToken.from_token(token),
|
|
|
|
+ user=user_service.get_user(user_id=user.id),
|
|
|
|
+ )
|
|
|
|
|
|
- return AuthenticationContext(
|
|
|
|
- auth=AuthenticatedToken.from_token(token) if token else None,
|
|
|
|
- user=self._load_auth_user(user),
|
|
|
|
- )
|
|
|
|
|
|
+ return AuthenticationContext(auth=None, user=None)
|
|
|
|
|
|
def token_has_org_access(self, *, token: AuthenticatedToken, organization_id: int) -> bool:
|
|
def token_has_org_access(self, *, token: AuthenticatedToken, organization_id: int) -> bool:
|
|
return SentryAppInstallationToken.objects.has_organization_access(token, organization_id)
|
|
return SentryAppInstallationToken.objects.has_organization_access(token, organization_id)
|
|
@@ -171,17 +164,17 @@ class DatabaseBackedAuthService(AuthService):
|
|
def authenticate(self, *, request: AuthenticationRequest) -> MiddlewareAuthenticationResponse:
|
|
def authenticate(self, *, request: AuthenticationRequest) -> MiddlewareAuthenticationResponse:
|
|
fake_request = FakeAuthenticationRequest(request)
|
|
fake_request = FakeAuthenticationRequest(request)
|
|
handler = RequestAuthenticationMiddleware(placeholder_get_response)
|
|
handler = RequestAuthenticationMiddleware(placeholder_get_response)
|
|
- expired_user: User | None = None
|
|
|
|
|
|
+ expired_user = None
|
|
try:
|
|
try:
|
|
# Hahaha. Yes. You're reading this right. I'm calling, the middleware, from the service method, that is
|
|
# Hahaha. Yes. You're reading this right. I'm calling, the middleware, from the service method, that is
|
|
# called, from slightly different, middleware.
|
|
# called, from slightly different, middleware.
|
|
- handler.process_request(fake_request)
|
|
|
|
|
|
+ handler.process_request(fake_request) # type: ignore[arg-type]
|
|
except AuthUserPasswordExpired as e:
|
|
except AuthUserPasswordExpired as e:
|
|
expired_user = e.user
|
|
expired_user = e.user
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise Exception("Unexpected error processing handler") from e
|
|
raise Exception("Unexpected error processing handler") from e
|
|
|
|
|
|
- auth: AuthenticatedToken | None = None
|
|
|
|
|
|
+ auth = None
|
|
if fake_request.auth is not None:
|
|
if fake_request.auth is not None:
|
|
auth = AuthenticatedToken.from_token(fake_request.auth)
|
|
auth = AuthenticatedToken.from_token(fake_request.auth)
|
|
|
|
|
|
@@ -192,11 +185,11 @@ class DatabaseBackedAuthService(AuthService):
|
|
)
|
|
)
|
|
|
|
|
|
if expired_user is not None:
|
|
if expired_user is not None:
|
|
- result.user = self._load_auth_user(expired_user)
|
|
|
|
|
|
+ result.user = user_service.get_user(user_id=expired_user.id)
|
|
result.expired = True
|
|
result.expired = True
|
|
elif fake_request.user is not None and not fake_request.user.is_anonymous:
|
|
elif fake_request.user is not None and not fake_request.user.is_anonymous:
|
|
with transaction.atomic(using=router.db_for_read(User)):
|
|
with transaction.atomic(using=router.db_for_read(User)):
|
|
- result.user = self._load_auth_user(fake_request.user)
|
|
|
|
|
|
+ result.user = user_service.get_user(user_id=fake_request.user.id)
|
|
transaction.set_rollback(True, using=router.db_for_read(User))
|
|
transaction.set_rollback(True, using=router.db_for_read(User))
|
|
if SiloMode.single_process_silo_mode():
|
|
if SiloMode.single_process_silo_mode():
|
|
connections.close_all()
|
|
connections.close_all()
|
|
@@ -214,16 +207,15 @@ class DatabaseBackedAuthService(AuthService):
|
|
sso_state = _query_sso_state(
|
|
sso_state = _query_sso_state(
|
|
organization_id=organization_id, is_super_user=is_superuser, member=org_member
|
|
organization_id=organization_id, is_super_user=is_superuser, member=org_member
|
|
)
|
|
)
|
|
- permissions: List[str] = list()
|
|
|
|
- # "permissions" is a bit of a misnomer -- these are all admin level permissions, and the intent is that if you
|
|
|
|
- # have them, you can only use them when you are acting, as a superuser. This is intentional.
|
|
|
|
|
|
+
|
|
if is_superuser:
|
|
if is_superuser:
|
|
- permissions.extend(get_permissions_for_user(user_id))
|
|
|
|
|
|
+ # "permissions" is a bit of a misnomer -- these are all admin level permissions, and the intent is that if you
|
|
|
|
+ # have them, you can only use them when you are acting, as a superuser. This is intentional.
|
|
|
|
+ permissions = list(get_permissions_for_user(user_id))
|
|
|
|
+ else:
|
|
|
|
+ permissions = []
|
|
|
|
|
|
- return RpcAuthState(
|
|
|
|
- sso_state=sso_state,
|
|
|
|
- permissions=permissions,
|
|
|
|
- )
|
|
|
|
|
|
+ return RpcAuthState(sso_state=sso_state, permissions=permissions)
|
|
|
|
|
|
def get_org_ids_with_scim(
|
|
def get_org_ids_with_scim(
|
|
self,
|
|
self,
|
|
@@ -262,7 +254,7 @@ class DatabaseBackedAuthService(AuthService):
|
|
def disable_provider(self, *, provider_id: int) -> None:
|
|
def disable_provider(self, *, provider_id: int) -> None:
|
|
with outbox_context(transaction.atomic(router.db_for_write(AuthProvider))):
|
|
with outbox_context(transaction.atomic(router.db_for_write(AuthProvider))):
|
|
try:
|
|
try:
|
|
- auth_provider: AuthProvider = AuthProvider.objects.get(id=provider_id)
|
|
|
|
|
|
+ auth_provider = AuthProvider.objects.get(id=provider_id)
|
|
except AuthProvider.DoesNotExist:
|
|
except AuthProvider.DoesNotExist:
|
|
return
|
|
return
|
|
|
|
|