Files
smart-city-digital-twin-mar…/smart-app-city/backend/app/routes/auth.py

320 lines
13 KiB
Python

"""Authentication routes — register, login, refresh, profile, logout."""
from datetime import datetime, timezone
from typing import Optional
from uuid import UUID
import redis.asyncio as redis
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.auth.jwt import (
ACCESS_TOKEN_EXPIRE_MINUTES,
REFRESH_TOKEN_EXPIRE_DAYS,
create_access_token,
create_refresh_token,
verify_token,
)
from app.config import settings
from app.database import get_session
from app.models.models import User
from app.schemas.schemas import TokenResponse, UserCreate, UserLogin, UserResponse
from passlib.context import CryptContext
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
router = APIRouter(prefix="/auth", tags=["Authentication"])
# ---------------------------------------------------------------------------
# Password hashing
# ---------------------------------------------------------------------------
_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def _hash_password(password: str) -> str:
return _pwd_context.hash(password)
def _verify_password(plain: str, hashed: str) -> bool:
return _pwd_context.verify(plain, hashed)
# ---------------------------------------------------------------------------
# Redis helpers (lazy singleton client)
# ---------------------------------------------------------------------------
_redis_client: Optional[redis.Redis] = None
def _get_redis() -> redis.Redis:
"""Return a singleton Redis client built from *settings.REDIS_URL*."""
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True)
return _redis_client
async def _blacklist_token(token: str, ttl: int) -> None:
"""Store *token* in the blacklist with *ttl* seconds expiry."""
r = _get_redis()
await r.setex(f"token:blacklist:{token}", ttl, "1")
async def _is_token_blacklisted(token: str) -> bool:
"""Return True if *token* has been blacklisted."""
r = _get_redis()
return await r.exists(f"token:blacklist:{token}") > 0
async def _enforce_not_blacklisted(token: str) -> None:
if await _is_token_blacklisted(token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has been revoked",
headers={"WWW-Authenticate": "Bearer"},
)
# ---------------------------------------------------------------------------
# OAuth2 scheme
# ---------------------------------------------------------------------------
from fastapi.security import OAuth2PasswordBearer
_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
# ---------------------------------------------------------------------------
# Dependency: load current user from JWT (with blacklist check)
# ---------------------------------------------------------------------------
async def get_current_user_dep(
token: str = Depends(_oauth2_scheme),
session: AsyncSession = Depends(get_session),
) -> User:
"""Full dependency: verify token string, check blacklist, load user."""
await _enforce_not_blacklisted(token)
payload = verify_token(token)
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload",
)
stmt = select(User).where(User.id == UUID(user_id))
result = await session.execute(stmt)
user = result.scalar_one_or_none()
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive",
)
return user
# ═══════════════════════════════════════════════════════════════════════════
# Request / Response models
# ═══════════════════════════════════════════════════════════════════════════
class RefreshRequest(BaseModel):
refresh_token: str
class UserUpdate(BaseModel):
first_name: Optional[str] = None
last_name: Optional[str] = None
phone: Optional[str] = None
avatar_url: Optional[str] = None
# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/register
# ═══════════════════════════════════════════════════════════════════════════
@router.post(
"/register",
response_model=TokenResponse,
status_code=status.HTTP_201_CREATED,
)
async def register(
payload: UserCreate,
session: AsyncSession = Depends(get_session),
) -> TokenResponse:
"""Create a new user, hash the password, return access + refresh tokens."""
# — uniqueness checks —
for field, value, label in [
("email", payload.email, "Email"),
("username", payload.username, "Username"),
]:
stmt = select(User).where(getattr(User, field) == value)
result = await session.execute(stmt)
if result.scalar_one_or_none() is not None:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"{label} already registered",
)
# — create user —
user = User(
email=payload.email,
username=payload.username,
hashed_password=_hash_password(payload.password),
first_name=payload.first_name,
last_name=payload.last_name,
phone=payload.phone,
role="user",
is_active=True,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
session.add(user)
await session.commit()
await session.refresh(user)
# — issue tokens —
token_data = {"sub": str(user.id), "role": user.role}
access_token = create_access_token(token_data)
refresh_token = create_refresh_token(token_data)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
)
# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/login
# ═══════════════════════════════════════════════════════════════════════════
@router.post("/login", response_model=TokenResponse)
async def login(
payload: UserLogin,
session: AsyncSession = Depends(get_session),
) -> TokenResponse:
"""Verify credentials and return access + refresh tokens."""
stmt = select(User).where(User.email == payload.email)
result = await session.execute(stmt)
user = result.scalar_one_or_none()
if user is None or not _verify_password(payload.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
headers={"WWW-Authenticate": "Bearer"},
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="User account is deactivated",
)
token_data = {"sub": str(user.id), "role": user.role}
access_token = create_access_token(token_data)
refresh_token = create_refresh_token(token_data)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
)
# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/refresh
# ═══════════════════════════════════════════════════════════════════════════
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
payload: RefreshRequest,
session: AsyncSession = Depends(get_session),
) -> TokenResponse:
"""Rotate a refresh token → new access token (+ new refresh token)."""
data = verify_token(payload.refresh_token)
if data.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token is not a refresh token",
)
# Blacklist check
await _enforce_not_blacklisted(payload.refresh_token)
user_id = data.get("sub")
if user_id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload",
)
# Load user
stmt = select(User).where(User.id == UUID(user_id))
result = await session.execute(stmt)
user = result.scalar_one_or_none()
if user is None or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found or inactive",
)
# Rotate: issue new pair, blacklist old refresh token
token_data = {"sub": str(user.id), "role": user.role}
new_access = create_access_token(token_data)
new_refresh = create_refresh_token(token_data)
# Blacklist the used refresh token until it naturally expires
await _blacklist_token(payload.refresh_token, REFRESH_TOKEN_EXPIRE_DAYS * 86400)
return TokenResponse(
access_token=new_access,
refresh_token=new_refresh,
token_type="bearer",
)
# ═══════════════════════════════════════════════════════════════════════════
# GET /auth/me
# ═══════════════════════════════════════════════════════════════════════════
@router.get("/me", response_model=UserResponse)
async def read_current_user(
current_user: User = Depends(get_current_user_dep),
) -> UserResponse:
"""Return the authenticated user's profile."""
return UserResponse.model_validate(current_user)
# ═══════════════════════════════════════════════════════════════════════════
# PUT /auth/me
# ═══════════════════════════════════════════════════════════════════════════
@router.put("/me", response_model=UserResponse)
async def update_current_user(
payload: UserUpdate,
current_user: User = Depends(get_current_user_dep),
session: AsyncSession = Depends(get_session),
) -> UserResponse:
"""Update the authenticated user's profile fields."""
update_data = payload.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(current_user, field, value)
current_user.updated_at = datetime.now(timezone.utc)
session.add(current_user)
await session.commit()
await session.refresh(current_user)
return UserResponse.model_validate(current_user)
# ═══════════════════════════════════════════════════════════════════════════
# POST /auth/logout
# ═══════════════════════════════════════════════════════════════════════════
@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
async def logout(
current_user: User = Depends(get_current_user_dep),
token: str = Depends(_oauth2_scheme),
) -> None:
"""Blacklist the current access token (Redis) so it cannot be reused."""
await _blacklist_token(token, ACCESS_TOKEN_EXPIRE_MINUTES * 60)