expose reset data to admin (#1602)
This commit is contained in:
@@ -60,6 +60,7 @@ from shared.api.models.management.responses import (
|
||||
WrappedGetPromptsResponse,
|
||||
WrappedLogResponse,
|
||||
WrappedPromptMessageResponse,
|
||||
WrappedResetDataResult,
|
||||
WrappedServerStatsResponse,
|
||||
WrappedUserCollectionResponse,
|
||||
WrappedUserOverviewResponse,
|
||||
@@ -86,6 +87,7 @@ __all__ = [
|
||||
"WrappedUserResponse",
|
||||
"WrappedVerificationResult",
|
||||
"WrappedGenericMessageResponse",
|
||||
"WrappedResetDataResult",
|
||||
# Ingestion Responses
|
||||
"IngestionResponse",
|
||||
"WrappedIngestionResponse",
|
||||
|
||||
@@ -495,7 +495,7 @@ class UserHandler(Handler):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_verification_data(
|
||||
async def get_user_validation_data(
|
||||
self, user_id: UUID, *args, **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -1393,10 +1393,10 @@ class DatabaseProvider(Provider):
|
||||
user_ids, offset, limit
|
||||
)
|
||||
|
||||
async def get_user_verification_data(
|
||||
async def get_user_validation_data(
|
||||
self, user_id: UUID, *args, **kwargs
|
||||
) -> dict:
|
||||
return await self.user_handler.get_user_verification_data(user_id)
|
||||
return await self.user_handler.get_user_validation_data(user_id)
|
||||
|
||||
# Vector handler methods
|
||||
async def upsert(self, entry: VectorEntry) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.base import R2RException
|
||||
from core.base.api.models import (
|
||||
GenericMessageResponse,
|
||||
WrappedGenericMessageResponse,
|
||||
WrappedResetDataResult,
|
||||
WrappedTokenResponse,
|
||||
WrappedUserResponse,
|
||||
WrappedVerificationResult,
|
||||
@@ -280,7 +281,36 @@ class AuthRouter(BaseRouter):
|
||||
raise R2RException(
|
||||
status_code=400, message="Invalid user ID format"
|
||||
)
|
||||
result = await self.service.get_user_verification_data(user_uuid)
|
||||
result = await self.service.get_user_verification_code(user_uuid)
|
||||
return result
|
||||
|
||||
@self.router.get("/user/{user_id}/reset_token")
|
||||
@self.base_endpoint
|
||||
async def get_user_reset_token(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
auth_user=Depends(self.service.providers.auth.auth_wrapper),
|
||||
) -> WrappedResetDataResult:
|
||||
"""
|
||||
Get only the verification code for a specific user.
|
||||
Only accessible by superusers.
|
||||
"""
|
||||
if not auth_user.is_superuser:
|
||||
raise R2RException(
|
||||
status_code=403,
|
||||
message="Only superusers can access verification codes",
|
||||
)
|
||||
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
except ValueError:
|
||||
raise R2RException(
|
||||
status_code=400, message="Invalid user ID format"
|
||||
)
|
||||
result = await self.service.get_user_reset_token(user_uuid)
|
||||
if not result["reset_token"]:
|
||||
raise R2RException(
|
||||
status_code=404, message="No reset token found"
|
||||
)
|
||||
return result
|
||||
|
||||
# Add to AuthRouter class (auth_router.py)
|
||||
|
||||
@@ -185,7 +185,7 @@ class AuthService(Service):
|
||||
)
|
||||
|
||||
@telemetry_event("GetUserVerificationCode")
|
||||
async def get_user_verification_data(
|
||||
async def get_user_verification_code(
|
||||
self, user_id: UUID, *args, **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -193,7 +193,7 @@ class AuthService(Service):
|
||||
This method should be called after superuser authorization has been verified.
|
||||
"""
|
||||
verification_data = (
|
||||
await self.providers.database.get_user_verification_data(user_id)
|
||||
await self.providers.database.get_user_validation_data(user_id)
|
||||
)
|
||||
return {
|
||||
"verification_code": verification_data["verification_data"][
|
||||
@@ -204,6 +204,26 @@ class AuthService(Service):
|
||||
],
|
||||
}
|
||||
|
||||
@telemetry_event("GetUserVerificationCode")
|
||||
async def get_user_reset_token(
|
||||
self, user_id: UUID, *args, **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
Get only the verification code data for a specific user.
|
||||
This method should be called after superuser authorization has been verified.
|
||||
"""
|
||||
verification_data = (
|
||||
await self.providers.database.get_user_validation_data(user_id)
|
||||
)
|
||||
return {
|
||||
"reset_token": verification_data["verification_data"][
|
||||
"reset_token"
|
||||
],
|
||||
"expiry": verification_data["verification_data"][
|
||||
"reset_token_expiry"
|
||||
],
|
||||
}
|
||||
|
||||
@telemetry_event("SendResetEmail")
|
||||
async def send_reset_email(self, email: str) -> dict:
|
||||
"""
|
||||
|
||||
@@ -289,7 +289,8 @@ class KGTriplesExtractionPipe(AsyncPipe[dict]):
|
||||
|
||||
# sort the extractions accroding to chunk_order field in metadata in ascending order
|
||||
extractions = sorted(
|
||||
extractions, key=lambda x: x.metadata.get("chunk_order", float('inf'))
|
||||
extractions,
|
||||
key=lambda x: x.metadata.get("chunk_order", float("inf")),
|
||||
)
|
||||
|
||||
# group these extractions into groups of extraction_merge_count
|
||||
|
||||
@@ -580,7 +580,7 @@ class PostgresUserHandler(UserHandler):
|
||||
)
|
||||
return result is not None
|
||||
|
||||
async def get_user_verification_data(
|
||||
async def get_user_validation_data(
|
||||
self, user_id: UUID, *args, **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
|
||||
@@ -213,6 +213,20 @@ class AuthMixins:
|
||||
"GET", f"user/{user_id}/verification_data"
|
||||
)
|
||||
|
||||
async def get_user_reset_token(self, user_id: Union[str, UUID]) -> dict:
|
||||
"""
|
||||
Retrieves only the verification code for a specific user. Requires superuser access.
|
||||
|
||||
Args:
|
||||
user_id (Union[str, UUID]): The ID of the user to get verification code for.
|
||||
|
||||
Returns:
|
||||
dict: Contains verification code and its expiry date
|
||||
"""
|
||||
return await self._make_request( # type: ignore
|
||||
"GET", f"user/{user_id}/reset_token"
|
||||
)
|
||||
|
||||
async def send_reset_email(self, email: str) -> dict:
|
||||
"""
|
||||
Generates a new verification code and sends a reset email to the user.
|
||||
|
||||
@@ -145,6 +145,18 @@ class VerificationResult(BaseModel):
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
verification_code: str
|
||||
expiry: datetime
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class ResetDataResult(BaseModel):
|
||||
reset_token: str
|
||||
expiry: datetime
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class AddUserResponse(BaseModel):
|
||||
result: bool
|
||||
|
||||
@@ -178,6 +190,7 @@ WrappedDocumentChunkResponse = PaginatedResultsWrapper[
|
||||
]
|
||||
WrappedDeleteResponse = ResultsWrapper[None]
|
||||
WrappedVerificationResult = ResultsWrapper[VerificationResult]
|
||||
WrappedResetDataResult = ResultsWrapper[ResetDataResult]
|
||||
WrappedConversationsOverviewResponse = PaginatedResultsWrapper[
|
||||
list[ConversationOverviewResponse]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user