from datetime import datetime, timedelta, timezone

from fastapi import Request
from sqlalchemy.orm import Session

from app.core.config import get_settings
from app.core.exceptions import UnauthorizedError, ValidationError
from app.core.security import create_access_token, create_refresh_token, decode_token, get_password_hash, verify_password
from app.models.user import User
from app.repositories.user_repository import UserRepository
from app.utils.rate_limit import auth_rate_limiter


class AuthService:
    """Email/password login, refresh-token exchange and password changes.

    Wraps a SQLAlchemy ``Session`` plus a ``UserRepository`` built on it.
    Failed logins are counted on the ``User`` row; once the count reaches
    ``settings.auth_lockout_max_attempts`` the account is locked for
    ``settings.auth_lockout_minutes`` minutes.
    """

    # Lazily computed hash of a throwaway password. Used by ``login`` so that
    # an unknown email still pays for one password verification.
    _dummy_password_hash = None

    def __init__(self, db: Session):
        self.db = db
        self.users = UserRepository(db)
        self.settings = get_settings()

    def login(self, email: str, password: str) -> dict:
        """Authenticate by email/password and return a fresh token pair.

        Returns:
            ``{"access_token": ..., "refresh_token": ..., "token_type": "bearer"}``.

        Raises:
            UnauthorizedError: unknown email or wrong password (same message
                for both), account currently locked, or account inactive.
        """
        user = self.users.get_by_email(email)
        if user is None:
            # Spend roughly the same hashing time as a real check; otherwise
            # response latency reveals which emails are registered.
            self._burn_password_check(password)
            raise UnauthorizedError("Invalid email or password")

        # NOTE(review): the distinct "locked" message still discloses that the
        # email exists once it has been locked — confirm this is acceptable.
        self._ensure_not_locked(user)

        if not verify_password(password, user.hashed_password):
            self._record_failed_login(user)
            raise UnauthorizedError("Invalid email or password")

        # Checked only after the password so inactive status is not revealed
        # to callers who do not know the password.
        if not user.is_active:
            raise UnauthorizedError("Account is inactive")

        self._reset_lockout(user)
        return self._issue_tokens(user)

    def refresh(self, refresh_token: str) -> dict:
        """Exchange a refresh token for a new access/refresh token pair.

        Raises:
            UnauthorizedError: the token cannot be decoded, its ``type`` claim
                is not ``"refresh"``, its ``sub`` is missing or not an int,
                or the user is missing, inactive or locked.
        """
        try:
            payload = decode_token(refresh_token)
            if payload.get("type") != "refresh":
                raise UnauthorizedError("Invalid refresh token")
            user_id = int(payload["sub"])
        except Exception as exc:
            # decode_token's failure exceptions are not visible here, so any
            # decode/parse failure is normalized to the same auth error.
            raise UnauthorizedError("Invalid refresh token") from exc

        user = self.users.get_by_id(user_id)
        if not user or not user.is_active:
            raise UnauthorizedError("User not found or inactive")
        self._ensure_not_locked(user)

        return self._issue_tokens(user)

    def change_password(self, user_id: int, current_password: str, new_password: str) -> None:
        """Replace the user's password hash after verifying the current password.

        Raises:
            ValidationError: the user does not exist or ``current_password``
                does not match.
        """
        user = self.users.get_by_id(user_id)
        if not user or not verify_password(current_password, user.hashed_password):
            raise ValidationError("Current password is incorrect")
        user.hashed_password = get_password_hash(new_password)
        self.db.commit()

    def _issue_tokens(self, user: User) -> dict:
        """Build the bearer-token response shared by ``login`` and ``refresh``."""
        subject = str(user.id)
        claims = {"role": user.role.value, "agency_id": user.agency_id}
        return {
            "access_token": create_access_token(subject, claims),
            "refresh_token": create_refresh_token(subject),
            "token_type": "bearer",
        }

    @classmethod
    def _burn_password_check(cls, password: str) -> None:
        """Verify *password* against a throwaway hash and discard the result."""
        if cls._dummy_password_hash is None:
            # Produced by the same get_password_hash as real users, so the
            # verification cost should be comparable. Benign race on first use.
            cls._dummy_password_hash = get_password_hash("timing-equalization-dummy-password")
        verify_password(password, cls._dummy_password_hash)

    def _ensure_not_locked(self, user: User) -> None:
        """Raise if *user* is locked; clear an expired lock (and commit) otherwise."""
        now = datetime.now(timezone.utc)
        locked_until = user.locked_until
        if locked_until is not None:
            # Naive values are interpreted as UTC so they can be compared with
            # an aware ``now``. NOTE(review): presumably the DB column drops
            # tzinfo — confirm.
            if locked_until.tzinfo is None:
                locked_until = locked_until.replace(tzinfo=timezone.utc)
            if locked_until > now:
                raise UnauthorizedError("Account temporarily locked due to failed login attempts. Try again later.")
            # Lock has expired: start over with a clean failure counter.
            user.locked_until = None
            user.failed_login_attempts = 0
            self.db.commit()

    def _record_failed_login(self, user: User) -> None:
        """Count a failed login and lock the account once the threshold is hit."""
        # Treat a NULL counter as 0 rather than failing with TypeError.
        user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
        if user.failed_login_attempts >= self.settings.auth_lockout_max_attempts:
            user.locked_until = datetime.now(timezone.utc) + timedelta(minutes=self.settings.auth_lockout_minutes)
            # The lock itself now carries the penalty; restart the count.
            user.failed_login_attempts = 0
        self.db.commit()

    def _reset_lockout(self, user: User) -> None:
        """Clear failure count and lock after a successful login, committing only if needed."""
        if user.failed_login_attempts or user.locked_until is not None:
            user.failed_login_attempts = 0
            user.locked_until = None
            self.db.commit()


def check_auth_rate_limit(request: Request) -> None:
    """Apply the auth rate limiter to *request*, keyed on client IP and URL path.

    The client IP is the first comma-separated entry of ``X-Forwarded-For``
    when that header is non-empty; otherwise it is the connection peer's host,
    or ``"unknown"`` when the request has no client info. The limiter receives
    ``auth_rate_limit_requests`` and ``auth_rate_limit_window_seconds`` from
    settings. Over-limit handling lives in ``auth_rate_limiter.check``
    (presumably it raises — see app.utils.rate_limit).
    """
    cfg = get_settings()

    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
    # proxy overwrites it. Taking the left-most entry lets a caller send a
    # different fake value per request and get a fresh bucket each time.
    # Consider relying on uvicorn's proxy-header handling (request.client.host
    # with --forwarded-allow-ips) instead.
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        client_ip = forwarded_for.split(",")[0].strip()
    elif request.client:
        client_ip = request.client.host
    else:
        client_ip = "unknown"

    bucket = f"{client_ip}:{request.url.path}"
    auth_rate_limiter.check(
        bucket,
        cfg.auth_rate_limit_requests,
        cfg.auth_rate_limit_window_seconds,
    )
