ts-lyceum-back/auth.py
2023-11-30 21:01:57 +03:00

111 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException, Depends, Request, status
from fastapi.openapi.models import OAuthFlows
from fastapi.security import OAuth2, OAuth2PasswordRequestForm
from fastapi.security.utils import get_authorization_scheme_param
from jose import jwt, JWTError
from passlib.context import CryptContext
from tortoise.exceptions import DoesNotExist

from db.models.user import User
from schemas.token_schemas import TokenData
from schemas.user_schemas import UserInfo, UserDatabase
# Key used to sign and verify JWTs. It is read from the environment so that
# deployments do not ship with the placeholder; "test" remains only as a
# development fallback for backward compatibility.
# TODO: refuse to start when SECRET_KEY is unset outside of development.
SECRET_KEY = os.environ.get("SECRET_KEY", "test")
# JWS algorithm used both when encoding and when decoding tokens
# (HS256 = HMAC with SHA-256, a symmetric signature).
ALGORITHM = "HS256"
# Access-token lifetime in minutes, configurable via the environment
# (default 30). A non-integer value fails fast at import time.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
# Password hashing context; bcrypt is the only configured scheme.
pass_context = CryptContext(schemes=["bcrypt"])
class OAuth2PasswordBearerCookie(OAuth2):
    """OAuth2 password-flow security scheme that reads the token from a cookie.

    The token is taken from the ``Authorization`` cookie, whose value is
    expected to look like ``"Bearer <token>"``. The password flow (with its
    token URL and scopes) is registered so it appears in the OpenAPI schema.
    """

    def __init__(
        self,
        token_url: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[dict] = None,
        auto_error: bool = True,
    ):
        """Configure the password flow.

        Args:
            token_url: URL of the endpoint that issues tokens.
            scheme_name: Optional name of the security scheme.
            scopes: Optional mapping of scope name to description.
            auto_error: When True, a missing or malformed cookie raises 401;
                when False, the dependency resolves to ``None`` instead.
        """
        # None (the default) and an empty mapping both mean "no scopes".
        if not scopes:
            scopes = {}
        flows = OAuthFlows(password={"tokenUrl": token_url, "scopes": scopes})
        super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)

    async def __call__(self, request: Request) -> Optional[str]:
        """Extract the bearer token from the ``Authorization`` cookie.

        Returns:
            The token part of the cookie, or ``None`` when the cookie is
            missing/not a Bearer value and ``auto_error`` is False.

        Raises:
            HTTPException: 401 when the cookie is missing or its scheme is
                not ``Bearer`` and ``auto_error`` is True.
        """
        authorization: Optional[str] = request.cookies.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        if not authorization or scheme.lower() != "bearer":
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Not authenticated",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            return None
        return param
# Shared dependency instance: pulls the bearer token out of the
# "Authorization" cookie and advertises "/login" as the token URL.
security = OAuth2PasswordBearerCookie("/login")
async def get_current_user(token: str = Depends(security)):
    """Resolve the authenticated user from the JWT in the auth cookie.

    Args:
        token: Raw JWT extracted by the cookie-based ``security`` scheme.

    Returns:
        The user loaded as a ``UserInfo`` model.

    Raises:
        HTTPException: 401 if the token cannot be decoded/verified, has no
            ``sub`` claim, or names a user that does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        # Only decoding can raise JWTError (bad signature, malformed or
        # expired token), so keep it alone inside the try.
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_exception from err

    # The username is carried in the "sub" claim; it may be absent.
    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception
    token_data = TokenData(username=username)

    try:
        user = await UserInfo.from_queryset_single(
            User.get(username=token_data.username)
        )
    except DoesNotExist as err:
        # Token is valid but the user is gone (e.g. deleted): reject it.
        raise credentials_exception from err
    return user
async def get_user(username: str):
    """Fetch the user named ``username`` as a ``UserDatabase`` model.

    The returned model carries the stored password hash used during login.

    Raises:
        DoesNotExist: if no user with that username exists (propagated from
            the ORM query; ``validate_user`` catches it).
    """
    lookup = User.get(username=username)
    return await UserDatabase.from_queryset_single(lookup)
def verify_password(plain_password, hashed_password) -> bool:
    """Check a plaintext password against a hash from ``pass_context``.

    Returns:
        True when the password matches the hash, False otherwise.
    """
    is_match = pass_context.verify(plain_password, hashed_password)
    return is_match
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode ``data`` plus an ``exp`` claim into a signed JWT.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``); the caller's
            dict is not modified.
        expires_delta: Token lifetime. When ``None`` (the default) the token
            is valid for 15 minutes. Any explicit value, including
            ``timedelta(0)``, is honored as given.

    Returns:
        The encoded token string.
    """
    to_encode = data.copy()  # shallow copy: don't add "exp" to caller's dict
    # Compare against None: timedelta(0) is falsy and must not fall back.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Timezone-aware UTC "now" instead of the deprecated naive utcnow();
    # it converts to the same "exp" timestamp.
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def validate_user(user: OAuth2PasswordRequestForm = Depends()):
    """Authenticate submitted login form credentials.

    Args:
        user: Username/password form data from the login request.

    Returns:
        The matching ``UserDatabase`` record.

    Raises:
        HTTPException: 401 with the same message whether the username is
            unknown or the password is wrong, so the two cases are
            indistinguishable to the client.
    """
    login_failed = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        # 401 responses must carry a challenge (RFC 7235); matches get_current_user.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        db_user = await get_user(user.username)
    except DoesNotExist:
        # Spend comparable hashing time as a real check so response timing
        # does not reveal whether the username exists.
        pass_context.dummy_verify()
        raise login_failed from None
    if not verify_password(user.password, db_user.password):
        raise login_failed
    return db_user