expose reset data to admin (#1602)

This commit is contained in:
emrgnt-cmplxty
2024-11-17 20:45:35 -08:00
committed by GitHub
parent 8a2723a657
commit 2a3f06a541
8 changed files with 88 additions and 8 deletions
+2
View File
@@ -60,6 +60,7 @@ from shared.api.models.management.responses import (
WrappedGetPromptsResponse,
WrappedLogResponse,
WrappedPromptMessageResponse,
WrappedResetDataResult,
WrappedServerStatsResponse,
WrappedUserCollectionResponse,
WrappedUserOverviewResponse,
@@ -86,6 +87,7 @@ __all__ = [
"WrappedUserResponse",
"WrappedVerificationResult",
"WrappedGenericMessageResponse",
"WrappedResetDataResult",
# Ingestion Responses
"IngestionResponse",
"WrappedIngestionResponse",
+3 -3
View File
@@ -495,7 +495,7 @@ class UserHandler(Handler):
pass
@abstractmethod
async def get_user_verification_data(
async def get_user_validation_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
@@ -1393,10 +1393,10 @@ class DatabaseProvider(Provider):
user_ids, offset, limit
)
async def get_user_verification_data(
async def get_user_validation_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
return await self.user_handler.get_user_verification_data(user_id)
return await self.user_handler.get_user_validation_data(user_id)
# Vector handler methods
async def upsert(self, entry: VectorEntry) -> None:
+31 -1
View File
@@ -9,6 +9,7 @@ from core.base import R2RException
from core.base.api.models import (
GenericMessageResponse,
WrappedGenericMessageResponse,
WrappedResetDataResult,
WrappedTokenResponse,
WrappedUserResponse,
WrappedVerificationResult,
@@ -280,7 +281,36 @@ class AuthRouter(BaseRouter):
raise R2RException(
status_code=400, message="Invalid user ID format"
)
result = await self.service.get_user_verification_data(user_uuid)
result = await self.service.get_user_verification_code(user_uuid)
return result
@self.router.get("/user/{user_id}/reset_token")
@self.base_endpoint
async def get_user_reset_token(
user_id: str = Path(..., description="User ID"),
auth_user=Depends(self.service.providers.auth.auth_wrapper),
) -> WrappedResetDataResult:
"""
Get only the verification code for a specific user.
Only accessible by superusers.
"""
if not auth_user.is_superuser:
raise R2RException(
status_code=403,
message="Only superusers can access verification codes",
)
try:
user_uuid = UUID(user_id)
except ValueError:
raise R2RException(
status_code=400, message="Invalid user ID format"
)
result = await self.service.get_user_reset_token(user_uuid)
if not result["reset_token"]:
raise R2RException(
status_code=404, message="No reset token found"
)
return result
# Add to AuthRouter class (auth_router.py)
+22 -2
View File
@@ -185,7 +185,7 @@ class AuthService(Service):
)
@telemetry_event("GetUserVerificationCode")
async def get_user_verification_data(
async def get_user_verification_code(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
@@ -193,7 +193,7 @@ class AuthService(Service):
This method should be called after superuser authorization has been verified.
"""
verification_data = (
await self.providers.database.get_user_verification_data(user_id)
await self.providers.database.get_user_validation_data(user_id)
)
return {
"verification_code": verification_data["verification_data"][
@@ -204,6 +204,26 @@ class AuthService(Service):
],
}
@telemetry_event("GetUserVerificationCode")
async def get_user_reset_token(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
Get only the verification code data for a specific user.
This method should be called after superuser authorization has been verified.
"""
verification_data = (
await self.providers.database.get_user_validation_data(user_id)
)
return {
"reset_token": verification_data["verification_data"][
"reset_token"
],
"expiry": verification_data["verification_data"][
"reset_token_expiry"
],
}
@telemetry_event("SendResetEmail")
async def send_reset_email(self, email: str) -> dict:
"""
+2 -1
View File
@@ -289,7 +289,8 @@ class KGTriplesExtractionPipe(AsyncPipe[dict]):
# sort the extractions accroding to chunk_order field in metadata in ascending order
extractions = sorted(
extractions, key=lambda x: x.metadata.get("chunk_order", float('inf'))
extractions,
key=lambda x: x.metadata.get("chunk_order", float("inf")),
)
# group these extractions into groups of extraction_merge_count
+1 -1
View File
@@ -580,7 +580,7 @@ class PostgresUserHandler(UserHandler):
)
return result is not None
async def get_user_verification_data(
async def get_user_validation_data(
self, user_id: UUID, *args, **kwargs
) -> dict:
"""
+14
View File
@@ -213,6 +213,20 @@ class AuthMixins:
"GET", f"user/{user_id}/verification_data"
)
async def get_user_reset_token(self, user_id: Union[str, UUID]) -> dict:
"""
Retrieves only the verification code for a specific user. Requires superuser access.
Args:
user_id (Union[str, UUID]): The ID of the user to get verification code for.
Returns:
dict: Contains verification code and its expiry date
"""
return await self._make_request( # type: ignore
"GET", f"user/{user_id}/reset_token"
)
async def send_reset_email(self, email: str) -> dict:
"""
Generates a new verification code and sends a reset email to the user.
@@ -145,6 +145,18 @@ class VerificationResult(BaseModel):
message: Optional[str] = None
class VerificationResult(BaseModel):
verification_code: str
expiry: datetime
message: Optional[str] = None
class ResetDataResult(BaseModel):
reset_token: str
expiry: datetime
message: Optional[str] = None
class AddUserResponse(BaseModel):
result: bool
@@ -178,6 +190,7 @@ WrappedDocumentChunkResponse = PaginatedResultsWrapper[
]
WrappedDeleteResponse = ResultsWrapper[None]
WrappedVerificationResult = ResultsWrapper[VerificationResult]
WrappedResetDataResult = ResultsWrapper[ResetDataResult]
WrappedConversationsOverviewResponse = PaginatedResultsWrapper[
list[ConversationOverviewResponse]
]