320 lines
13 KiB
Python
320 lines
13 KiB
Python
"""Authentication routes — register, login, refresh, profile, logout."""
|
|
|
|
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
import redis.asyncio as redis
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.auth.jwt import (
|
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
|
REFRESH_TOKEN_EXPIRE_DAYS,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
verify_token,
|
|
)
|
|
from app.config import settings
|
|
from app.database import get_session
|
|
from app.models.models import User
|
|
from app.schemas.schemas import TokenResponse, UserCreate, UserLogin, UserResponse
|
|
from passlib.context import CryptContext
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Router
|
|
# ---------------------------------------------------------------------------
|
|
# Every route below is mounted under /auth and grouped under the
# "Authentication" tag in the generated OpenAPI docs.
router = APIRouter(
    tags=["Authentication"],
    prefix="/auth",
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Password hashing
|
|
# ---------------------------------------------------------------------------
|
|
# Shared passlib context. bcrypt is the only configured scheme;
# deprecated="auto" treats every non-default scheme as deprecated.
_pwd_context = CryptContext(
    deprecated="auto",
    schemes=["bcrypt"],
)
|
|
|
|
|
|
def _hash_password(password: str) -> str:
    """Hash *password* with the module's bcrypt context and return the hash string."""
    digest = _pwd_context.hash(password)
    return digest
|
|
|
|
|
|
def _verify_password(plain: str, hashed: str) -> bool:
    """Check the plaintext *plain* against the stored hash *hashed*.

    Returns the result of the passlib context's ``verify`` (True on a match).
    """
    is_match = _pwd_context.verify(plain, hashed)
    return is_match
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Redis helpers (lazy singleton client)
|
|
# ---------------------------------------------------------------------------
|
|
# Module-wide Redis client; stays None until _get_redis() is first called.
_redis_client: Optional[redis.Redis] = None


def _get_redis() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    The client is built from ``settings.REDIS_URL`` with
    ``decode_responses=True``, so replies come back as ``str``.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
    return _redis_client
|
|
|
|
|
|
async def _blacklist_token(token: str, ttl: int) -> None:
    """Mark *token* as revoked for *ttl* seconds.

    Writes the key ``token:blacklist:<token>`` with value "1" via SETEX, so
    Redis drops the entry once the TTL runs out.
    """
    key = f"token:blacklist:{token}"
    await _get_redis().setex(key, ttl, "1")
|
|
|
|
|
|
async def _is_token_blacklisted(token: str) -> bool:
    """Report whether a revocation entry for *token* currently exists in Redis."""
    client = _get_redis()
    hits = await client.exists(f"token:blacklist:{token}")
    return hits > 0
|
|
|
|
|
|
async def _enforce_not_blacklisted(token: str) -> None:
    """Return silently when *token* is not revoked; otherwise reject the request.

    Raises:
        HTTPException: 401 "Token has been revoked" with a
            ``WWW-Authenticate: Bearer`` header when the token is blacklisted.
    """
    revoked = await _is_token_blacklisted(token)
    if not revoked:
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Token has been revoked",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OAuth2 scheme
|
|
# ---------------------------------------------------------------------------
|
|
from fastapi.security import OAuth2PasswordBearer

# Pulls the bearer token out of the Authorization header. tokenUrl is what the
# OpenAPI schema advertises as the token endpoint.
# NOTE(review): the login route in this module takes a JSON body (UserLogin),
# while the docs UI's OAuth2 password flow posts form data; the hard-coded path
# also assumes this router is mounted at the app root — confirm both.
_oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/auth/login",
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dependency: load current user from JWT (with blacklist check)
|
|
# ---------------------------------------------------------------------------
|
|
async def get_current_user_dep(
    token: str = Depends(_oauth2_scheme),
    session: AsyncSession = Depends(get_session),
) -> User:
    """Resolve the authenticated, active user behind the request's bearer token.

    Steps: reject blacklisted tokens, verify the JWT, reject refresh tokens,
    parse the ``sub`` claim as a UUID, then load the user row.

    Raises:
        HTTPException: 401 when the token is revoked, is a refresh token, has a
            missing or malformed ``sub`` claim, or the user is missing/inactive.
            ``verify_token`` may raise its own error for an invalid token.
    """
    await _enforce_not_blacklisted(token)

    payload = verify_token(token)

    # Refresh tokens in this module carry a "type": "refresh" claim. Without
    # this check a refresh token would also be accepted as an access token.
    if payload.get("type") == "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token cannot be used for authentication",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
        )

    # A non-UUID "sub" would otherwise raise ValueError and surface as a 500.
    try:
        user_uuid = UUID(str(user_id))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
        ) from None

    stmt = select(User).where(User.id == user_uuid)
    result = await session.execute(stmt)
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
        )
    return user
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# Request / Response models
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
class RefreshRequest(BaseModel):
    """Request body for POST /auth/refresh: the refresh token to rotate."""

    refresh_token: str


class UserUpdate(BaseModel):
    """Partial profile update accepted by PUT /auth/me.

    Every field is optional and defaults to None.
    """

    first_name: Optional[str] = None
    last_name: Optional[str] = None
    phone: Optional[str] = None
    avatar_url: Optional[str] = None
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# POST /auth/register
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
@router.post(
    "/register",
    response_model=TokenResponse,
    status_code=status.HTTP_201_CREATED,
)
async def register(
    payload: UserCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenResponse:
    """Create a new user account and return an access + refresh token pair.

    The password is stored only as a hash produced by ``_hash_password``.
    New accounts get ``role="user"`` and ``is_active=True``.

    Raises:
        HTTPException: 409 when the email or username is already taken. This
            includes the case where a concurrent registration claims it
            between the pre-check and the commit.
    """
    # Local import keeps this handler's extra dependency next to its use.
    from sqlalchemy.exc import IntegrityError

    # — uniqueness pre-checks (give a field-specific 409 message) —
    for field, value, label in (
        ("email", payload.email, "Email"),
        ("username", payload.username, "Username"),
    ):
        stmt = select(User.id).where(getattr(User, field) == value).limit(1)
        result = await session.execute(stmt)
        if result.scalar_one_or_none() is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"{label} already registered",
            )

    # — create user — one timestamp so created_at and updated_at agree exactly
    now = datetime.now(timezone.utc)
    user = User(
        email=payload.email,
        username=payload.username,
        hashed_password=_hash_password(payload.password),
        first_name=payload.first_name,
        last_name=payload.last_name,
        phone=payload.phone,
        role="user",
        is_active=True,
        created_at=now,
        updated_at=now,
    )
    session.add(user)
    try:
        await session.commit()
    except IntegrityError:
        # The pre-checks above are not atomic with the INSERT. A concurrent
        # request can register the same email/username in between.
        # NOTE(review): assumes the DB enforces uniqueness on email/username;
        # any other integrity violation would also be reported as this 409.
        await session.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email or username already registered",
        ) from None
    await session.refresh(user)

    # — issue tokens —
    token_data = {"sub": str(user.id), "role": user.role}
    access_token = create_access_token(token_data)
    refresh_token = create_refresh_token(token_data)

    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
    )
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# POST /auth/login
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
@router.post("/login", response_model=TokenResponse)
async def login(
    payload: UserLogin,
    session: AsyncSession = Depends(get_session),
) -> TokenResponse:
    """Verify email/password credentials and return an access + refresh token pair.

    Raises:
        HTTPException: 401 "Invalid email or password" for an unknown email or
            a wrong password (same response for both); 403 when the account
            is deactivated.
    """
    stmt = select(User).where(User.email == payload.email)
    result = await session.execute(stmt)
    user = result.scalar_one_or_none()

    if user is None:
        # Spend roughly one hash verification's worth of time anyway, so that
        # response latency does not reveal whether the email is registered.
        _pwd_context.dummy_verify()
        password_ok = False
    else:
        password_ok = _verify_password(payload.password, user.hashed_password)

    if not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is deactivated",
        )

    token_data = {"sub": str(user.id), "role": user.role}
    access_token = create_access_token(token_data)
    refresh_token = create_refresh_token(token_data)

    return TokenResponse(
        access_token=access_token,
        refresh_token=refresh_token,
        token_type="bearer",
    )
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# POST /auth/refresh
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
    payload: RefreshRequest,
    session: AsyncSession = Depends(get_session),
) -> TokenResponse:
    """Rotate a refresh token: issue a new access + refresh pair and retire the old one.

    Each refresh token can be used once. The old token is blacklisted in
    Redis with an atomic SET NX, so concurrent replays of the same token
    cannot both succeed.

    Raises:
        HTTPException: 401 when the token is not a refresh token, was already
            used or revoked, carries a missing or malformed ``sub``, or the
            user is missing/inactive. ``verify_token`` may raise its own error
            for an invalid token.
    """
    token = payload.refresh_token
    data = verify_token(token)

    if data.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token is not a refresh token",
        )

    # Cheap early rejection of already-used tokens, before hitting the DB.
    await _enforce_not_blacklisted(token)

    user_id = data.get("sub")
    if user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
        )
    # A non-UUID "sub" would otherwise raise ValueError and surface as a 500.
    try:
        user_uuid = UUID(str(user_id))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token payload",
        ) from None

    # Load user
    stmt = select(User).where(User.id == user_uuid)
    result = await session.execute(stmt)
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
        )

    # Atomically claim (blacklist) the used refresh token. SET ... NX succeeds
    # for exactly one caller. A separate exists() + setex() pair would let two
    # concurrent requests with the same token both rotate it.
    # The key format must stay in sync with _blacklist_token/_is_token_blacklisted.
    # The TTL is the full refresh lifetime, presumably >= the token's remaining
    # life (NOTE(review): confirm against create_refresh_token).
    claimed = await _get_redis().set(
        f"token:blacklist:{token}",
        "1",
        ex=REFRESH_TOKEN_EXPIRE_DAYS * 86400,
        nx=True,
    )
    if not claimed:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Rotate: issue a fresh pair for the same subject/role.
    token_data = {"sub": str(user.id), "role": user.role}
    new_access = create_access_token(token_data)
    new_refresh = create_refresh_token(token_data)

    return TokenResponse(
        access_token=new_access,
        refresh_token=new_refresh,
        token_type="bearer",
    )
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# GET /auth/me
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
@router.get("/me", response_model=UserResponse)
async def read_current_user(
    current_user: User = Depends(get_current_user_dep),
) -> UserResponse:
    """Serialize and return the profile of the user who owns the bearer token."""
    profile = UserResponse.model_validate(current_user)
    return profile
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# PUT /auth/me
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
@router.put("/me", response_model=UserResponse)
async def update_current_user(
    payload: UserUpdate,
    current_user: User = Depends(get_current_user_dep),
    session: AsyncSession = Depends(get_session),
) -> UserResponse:
    """Apply a partial profile update to the authenticated user.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``). ``updated_at`` is then stamped with the current
    UTC time, and the row is committed, reloaded and returned.
    """
    changes = payload.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(current_user, attr_name, changes[attr_name])

    current_user.updated_at = datetime.now(timezone.utc)

    session.add(current_user)
    await session.commit()
    await session.refresh(current_user)

    refreshed = UserResponse.model_validate(current_user)
    return refreshed
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
# POST /auth/logout
|
|
# ═══════════════════════════════════════════════════════════════════════════
|
|
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
    current_user: User = Depends(get_current_user_dep),
    token: str = Depends(_oauth2_scheme),
) -> None:
    """Revoke the caller's current access token.

    ``current_user`` is not used in the body. Resolving it authenticates the
    request first, so revoked tokens and tokens of missing or inactive users
    are rejected. The bearer token is then blacklisted in Redis for
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    NOTE(review): only the access token is blacklisted here. A refresh token
    the client still holds is not revoked and remains accepted by /refresh.
    """
    lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60
    await _blacklist_token(token, lifetime_seconds)
|