"""Authentication routes — register, login, refresh, profile, logout."""

from datetime import datetime, timezone
from typing import Optional
from uuid import UUID

import redis.asyncio as redis
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth.jwt import (
    ACCESS_TOKEN_EXPIRE_MINUTES,
    REFRESH_TOKEN_EXPIRE_DAYS,
    create_access_token,
    create_refresh_token,
    verify_token,
)
from app.config import settings
from app.database import get_session
from app.models.models import User
from app.schemas.schemas import TokenResponse, UserCreate, UserLogin, UserResponse

# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------

router = APIRouter(prefix="/auth", tags=["Authentication"])

# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------

_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def _hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* for storage."""
    return _pwd_context.hash(password)


def _verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the stored hash *hashed*."""
    return _pwd_context.verify(plain, hashed)


# ---------------------------------------------------------------------------
# Redis helpers (lazy singleton client)
# ---------------------------------------------------------------------------

_redis_client: Optional[redis.Redis] = None

# Key prefix for revoked tokens; unchanged so existing entries stay effective.
_BLACKLIST_PREFIX = "token:blacklist:"


def _get_redis() -> redis.Redis:
    """Return a singleton Redis client built from *settings.REDIS_URL*."""
    global _redis_client
    if _redis_client is None:
        _redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis_client


def _blacklist_key(token: str) -> str:
    """Return the Redis key under which *token*'s revocation is stored."""
    return f"{_BLACKLIST_PREFIX}{token}"


async def _blacklist_token(token: str, ttl: int) -> None:
    """Store *token* in the blacklist with *ttl* seconds expiry."""
    await _get_redis().setex(_blacklist_key(token), ttl, "1")


async def _claim_token_once(token: str, ttl: int) -> bool:
    """Atomically blacklist *token*; return True only for the first caller.

    Uses ``SET key 1 NX EX ttl`` so that when two requests race with the same
    token, exactly one of them sees the key created (True) and the other gets
    None back (False). A separate exists-then-setex pair cannot guarantee that.
    """
    claimed = await _get_redis().set(_blacklist_key(token), "1", ex=ttl, nx=True)
    return bool(claimed)


async def _is_token_blacklisted(token: str) -> bool:
    """Return True if *token* has been blacklisted."""
    return await _get_redis().exists(_blacklist_key(token)) > 0


def _remaining_ttl(claims: dict, fallback: int) -> int:
    """Return seconds until the token's ``exp`` claim, or *fallback*.

    Blacklist entries only need to outlive the token itself. If ``exp`` is
    present as a numeric timestamp, the entry expires together with the
    token. Otherwise the conservative full-lifetime *fallback* is used.
    The result is at least 1 because Redis rejects non-positive expiries.
    """
    exp = claims.get("exp")
    if isinstance(exp, (int, float)) and not isinstance(exp, bool):
        return max(int(exp - datetime.now(timezone.utc).timestamp()), 1)
    return fallback


# ---------------------------------------------------------------------------
# Shared auth helpers
# ---------------------------------------------------------------------------


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the Bearer challenge header."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def _enforce_not_blacklisted(token: str) -> None:
    """Raise 401 if *token* has been revoked."""
    if await _is_token_blacklisted(token):
        raise _unauthorized("Token has been revoked")


def _user_id_from_claims(claims: dict) -> UUID:
    """Extract the user id from the ``sub`` claim.

    Raises:
        HTTPException: 401 if ``sub`` is missing or is not a valid UUID.
            Without this, a malformed value would surface as a 500.
    """
    raw = claims.get("sub")
    if raw is None:
        raise _unauthorized("Invalid token payload")
    try:
        return UUID(str(raw))
    except ValueError:
        raise _unauthorized("Invalid token payload") from None


async def _load_active_user(session: AsyncSession, user_id: UUID) -> User:
    """Fetch the user with *user_id*; raise 401 if missing or deactivated."""
    result = await session.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise _unauthorized("User not found or inactive")
    return user


def _issue_tokens(user: User) -> TokenResponse:
    """Mint a fresh access + refresh token pair for *user*."""
    token_data = {"sub": str(user.id), "role": user.role}
    return TokenResponse(
        access_token=create_access_token(token_data),
        refresh_token=create_refresh_token(token_data),
        token_type="bearer",
    )


# ---------------------------------------------------------------------------
# OAuth2 scheme
# ---------------------------------------------------------------------------

_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")

# ---------------------------------------------------------------------------
# Dependency: load current user from JWT (with blacklist check)
# ---------------------------------------------------------------------------


async def get_current_user_dep(
    token: str = Depends(_oauth2_scheme),
    session: AsyncSession = Depends(get_session),
) -> User:
    """Full dependency: verify token string, check blacklist, load user.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token is revoked;
            - the token is a refresh token;
            - the token carries a bad ``sub``;
            - the user is missing or inactive.
            ``verify_token`` may raise its own error for an invalid signature
            or expiry.
    """
    await _enforce_not_blacklisted(token)
    payload = verify_token(token)
    # Refresh tokens carry type="refresh" (see /auth/refresh). They must not be
    # usable as bearer access tokens, or a long-lived refresh token would grant
    # direct API access.
    if payload.get("type") == "refresh":
        raise _unauthorized("Invalid token type")
    user_id = _user_id_from_claims(payload)
    return await _load_active_user(session, user_id)


# ═══════════════════════════════════════════════════════════════════════════
# Request / Response models
# ═══════════════════════════════════════════════════════════════════════════


class RefreshRequest(BaseModel):
    """Body of POST /auth/refresh."""

    refresh_token: str


class UserUpdate(BaseModel):
    """Partial profile update; only fields present in the request are applied."""

    first_name: Optional[str] = None
    last_name: Optional[str] = None
    phone: Optional[str] = None
    avatar_url: Optional[str] = None


# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/register
# ═══════════════════════════════════════════════════════════════════════════


@router.post(
    "/register",
    response_model=TokenResponse,
    status_code=status.HTTP_201_CREATED,
)
async def register(
    payload: UserCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenResponse:
    """Create a new user, hash the password, return access + refresh tokens.

    Raises:
        HTTPException: 409 if the email or username is already taken.
    """
    # — uniqueness checks (fast path with a field-specific message) —
    for field, value, label in (
        ("email", payload.email, "Email"),
        ("username", payload.username, "Username"),
    ):
        result = await session.execute(select(User).where(getattr(User, field) == value))
        if result.scalar_one_or_none() is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"{label} already registered",
            )

    # — create user —
    now = datetime.now(timezone.utc)
    user = User(
        email=payload.email,
        username=payload.username,
        hashed_password=_hash_password(payload.password),
        first_name=payload.first_name,
        last_name=payload.last_name,
        phone=payload.phone,
        role="user",
        is_active=True,
        created_at=now,
        updated_at=now,
    )
    session.add(user)
    try:
        await session.commit()
    except IntegrityError:
        # A concurrent registration can slip past the checks above. In that
        # case the constraint violation surfaces here; report it as a conflict
        # instead of a 500. This assumes the DB enforces uniqueness on
        # email/username — TODO confirm.
        await session.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email or username already registered",
        ) from None
    await session.refresh(user)

    # — issue tokens —
    return _issue_tokens(user)


# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/login
# ═══════════════════════════════════════════════════════════════════════════


@router.post("/login", response_model=TokenResponse)
async def login(
    payload: UserLogin,
    session: AsyncSession = Depends(get_session),
) -> TokenResponse:
    """Verify credentials and return access + refresh tokens.

    Raises:
        HTTPException: 401 for an unknown email or a wrong password (same
            message for both); 403 if the account is deactivated.
    """
    # NOTE(review): bcrypt runs synchronously inside this async handler and
    # blocks the event loop while hashing; consider a threadpool.
    result = await session.execute(select(User).where(User.email == payload.email))
    user = result.scalar_one_or_none()

    if user is None:
        # Burn comparable hashing time so response latency does not reveal
        # whether the email is registered.
        _pwd_context.dummy_verify()
        password_ok = False
    else:
        password_ok = _verify_password(payload.password, user.hashed_password)

    if not password_ok:
        raise _unauthorized("Invalid email or password")

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is deactivated",
        )

    return _issue_tokens(user)


# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/refresh
# ═══════════════════════════════════════════════════════════════════════════


@router.post("/refresh", response_model=TokenResponse)
async def refresh(
    payload: RefreshRequest,
    session: AsyncSession = Depends(get_session),
) -> TokenResponse:
    """Rotate a refresh token → new access token (+ new refresh token).

    The presented refresh token is single-use. It is blacklisted atomically
    before the new pair is issued, so concurrent replays of the same token
    yield at most one new pair.

    Raises:
        HTTPException: 401 in any of these cases:
            - the token is not a refresh token;
            - the token is revoked or already used;
            - the token carries a bad ``sub``;
            - the user is missing or inactive.
    """
    data = verify_token(payload.refresh_token)
    if data.get("type") != "refresh":
        raise _unauthorized("Token is not a refresh token")

    # Cheap early rejection of already-revoked tokens
    await _enforce_not_blacklisted(payload.refresh_token)

    user_id = _user_id_from_claims(data)
    user = await _load_active_user(session, user_id)

    # Claim (blacklist) the token atomically only after the user checks pass,
    # so a rejected request does not burn a still-valid token.
    ttl = _remaining_ttl(data, REFRESH_TOKEN_EXPIRE_DAYS * 86400)
    if not await _claim_token_once(payload.refresh_token, ttl):
        raise _unauthorized("Token has been revoked")

    return _issue_tokens(user)


# ═══════════════════════════════════════════════════════════════════════════
# GET /auth/me
# ═══════════════════════════════════════════════════════════════════════════


@router.get("/me", response_model=UserResponse)
async def read_current_user(
    current_user: User = Depends(get_current_user_dep),
) -> UserResponse:
    """Return the authenticated user's profile."""
    return UserResponse.model_validate(current_user)


# ═══════════════════════════════════════════════════════════════════════════
# PUT /auth/me
# ═══════════════════════════════════════════════════════════════════════════


@router.put("/me", response_model=UserResponse)
async def update_current_user(
    payload: UserUpdate,
    current_user: User = Depends(get_current_user_dep),
    session: AsyncSession = Depends(get_session),
) -> UserResponse:
    """Update the authenticated user's profile fields.

    Only fields explicitly sent in the request body are written
    (``exclude_unset=True``). ``updated_at`` is always bumped.
    """
    for field, value in payload.model_dump(exclude_unset=True).items():
        setattr(current_user, field, value)
    current_user.updated_at = datetime.now(timezone.utc)

    session.add(current_user)
    await session.commit()
    await session.refresh(current_user)
    return UserResponse.model_validate(current_user)


# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/logout
# ═══════════════════════════════════════════════════════════════════════════


@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
    current_user: User = Depends(get_current_user_dep),
    token: str = Depends(_oauth2_scheme),
) -> None:
    """Blacklist the current access token (Redis) so it cannot be reused.

    ``current_user`` is unused in the body. It is declared so the token is
    fully authenticated (signature, blacklist, user) before revocation.
    NOTE(review): the client's refresh token is not revoked here.
    """
    claims = verify_token(token)
    ttl = _remaining_ttl(claims, ACCESS_TOKEN_EXPIRE_MINUTES * 60)
    await _blacklist_token(token, ttl)