73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
from typing import Annotated
|
|
|
|
from fastapi import Depends, HTTPException
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
from jose import JWTError, jwt
|
|
from pydantic import ValidationError
|
|
from setech.utils import get_logger
|
|
from starlette import status
|
|
|
|
from service.api.models.auth import TokenPayload
|
|
from service.config import settings
|
|
from service.constants import security
|
|
from service.constants.types import PaginationParams
|
|
from service.database.models import AnonymousUser, User
|
|
|
|
# Names exported by this module: the dependency aliases and the request-user type.
__all__ = ["LoggedInUser", "QueryParams", "CurrentRequestUser", "RequestUser"]
|
|
# Module-level logger, obtained from setech's get_logger under the name "api".
_l = get_logger("api")
|
|
|
|
# OAuth2 bearer-token scheme whose token URL lives under settings.root_path.
# auto_error=False: per FastAPI's OAuth2PasswordBearer, a request without a
# token yields None instead of an automatic error, so the dependencies below
# decide for themselves how to treat a missing token.
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.root_path}/login/access-token", auto_error=False)
|
|
|
|
# Dependency alias: the bearer token extracted by reusable_oauth2, or None.
TokenDep = Annotated[str | None, Depends(reusable_oauth2)]
|
|
# Type of the user associated with a request: a User or an AnonymousUser.
RequestUser = User | AnonymousUser
|
|
|
|
|
|
async def get_current_user(token: TokenDep) -> User:
    """Resolve a bearer token to the ``User`` it identifies.

    The token is decoded with ``settings.secret_key`` using
    ``security.ALGORITHM``, validated as a ``TokenPayload``, and its ``sub``
    field is then looked up as a username.

    Args:
        token: Bearer token supplied by ``reusable_oauth2``, or ``None``.

    Returns:
        The first ``User`` whose ``username`` equals the token's ``sub``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token is missing, cannot be decoded or validated, or matches no
            user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    if token is None:
        raise credentials_exception

    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=[security.ALGORITHM])
        token_data = TokenPayload(**payload)
    except (JWTError, ValidationError, AttributeError) as exc:
        # Chain explicitly so the underlying decode/validation failure is
        # recorded as the cause rather than as an incidental context.
        raise credentials_exception from exc

    user = await User.filter(username=token_data.sub).first()
    if user is None:
        raise credentials_exception
    return user
|
|
|
|
|
|
# Dependency alias: the User returned by get_current_user, which raises a 401
# HTTPException when the token is missing or invalid.
LoggedInUser = Annotated[User, Depends(get_current_user)]
|
|
|
|
|
|
async def get_request_user(token: TokenDep) -> RequestUser:
    """Return the user behind ``token``, or an ``AnonymousUser`` as a fallback.

    Unlike ``get_current_user``, this never rejects the request: a missing or
    empty token, and any ``HTTPException`` raised while resolving the token,
    both produce an ``AnonymousUser``.
    """
    if token:
        try:
            return await get_current_user(token)
        except HTTPException:
            # Deliberate: an unresolvable token is treated as "not logged in".
            pass
    return AnonymousUser()
|
|
|
|
|
|
# Dependency alias: the result of get_request_user, i.e. the authenticated
# User or an AnonymousUser when the token is absent or cannot be resolved.
CurrentRequestUser = Annotated[RequestUser, Depends(get_request_user)]
|
|
|
|
|
|
def query_params(q: str | None = None, page: int = 1, limit: int = 10, order: str | None = None) -> PaginationParams:
    """Turn raw listing query parameters into ``PaginationParams``.

    Args:
        q: Optional filter text, passed through unchanged.
        page: One-based page number; values below 1 are treated as the first page.
        limit: Page size, clamped to the range 1..250.
        order: Optional ordering expression, passed through unchanged.

    Returns:
        ``PaginationParams`` whose ``offset`` is the zero-based page index
        multiplied by the clamped ``limit``.
    """
    # Convert to a zero-based page index, never going below the first page.
    page = max(page - 1, 0)
    # Clamp the page size. The lower bound is 1 (not 0): a zero limit would
    # yield empty pages and a constant offset of 0 for every page.
    limit = min(max(limit, 1), 250)
    _l.info(f"Filtering by: {q=}, {page=}, {limit=} | Ordering by: {order}")
    return PaginationParams(q=q, offset=page * limit, limit=limit, order=order)
|
|
|
|
|
|
# Dependency alias: the PaginationParams built by query_params from the
# request's q / page / limit / order parameters.
QueryParams = Annotated[PaginationParams, Depends(query_params)]
|