from datetime import date, timedelta

from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session, joinedload

from app.models.customer import Customer
from app.models.enums import PaymentStatus, PolicyStatus
from app.models.insurance_company import InsuranceCompany
from app.models.policy import Policy
from app.models.vehicle import Vehicle
from app.utils.policy_status import compute_policy_status


class PolicyRepository:
    """Data-access helpers for ``Policy`` rows, always scoped to one agency.

    Every query filters on ``agency_id`` and ignores soft-deleted rows
    (``deleted_at`` set). Write helpers flush the session but never commit;
    transaction control is left to the caller.
    """

    # Statuses treated as "no longer open" by the duplicate-detection lookups
    # and by the expiry filters of ``list_paginated``.
    _CLOSED_STATUSES = (PolicyStatus.RENEWED, PolicyStatus.CANCELLED)

    # expiry_preset name -> inclusive (start, end) offsets in days from today.
    _EXPIRY_PRESET_WINDOWS = {
        "today": (0, 0),
        "tomorrow": (1, 1),
        "next_3_days": (0, 3),
        "next_7_days": (0, 7),
        "next_15_days": (0, 15),
        "next_30_days": (0, 30),
    }

    # Escape character used so user-typed ``%`` / ``_`` are matched literally.
    _LIKE_ESCAPE = "\\"

    def __init__(self, db: Session):
        self.db = db

    def _open_policies_query(self, agency_id: int):
        """Return the base query for the agency's non-deleted, non-closed policies.

        "Closed" means RENEWED or CANCELLED. Customer and vehicle are
        eager-loaded because every caller of this helper returns them.
        """
        return (
            self.db.query(Policy)
            .options(joinedload(Policy.customer), joinedload(Policy.vehicle))
            .filter(
                Policy.agency_id == agency_id,
                Policy.deleted_at.is_(None),
                Policy.status.notin_(self._CLOSED_STATUSES),
            )
        )

    @classmethod
    def _contains_pattern(cls, term: str) -> str:
        """Build a ``%term%`` LIKE pattern with LIKE metacharacters escaped."""
        esc = cls._LIKE_ESCAPE
        escaped = term.replace(esc, esc + esc).replace("%", esc + "%").replace("_", esc + "_")
        return f"%{escaped}%"

    def get_by_id(self, policy_id: int, agency_id: int) -> Policy | None:
        """Return one non-deleted policy of the agency, or ``None``.

        Customer, vehicle and insurance company are eager-loaded.
        """
        return (
            self.db.query(Policy)
            .options(
                joinedload(Policy.customer),
                joinedload(Policy.vehicle),
                joinedload(Policy.insurance_company),
            )
            .filter(Policy.id == policy_id, Policy.agency_id == agency_id, Policy.deleted_at.is_(None))
            .first()
        )

    def find_by_policy_number(
        self,
        agency_id: int,
        policy_number: str,
        *,
        exclude_policy_id: int | None = None,
    ) -> Policy | None:
        """Find an open policy whose number matches, ignoring case and edge spaces.

        Args:
            agency_id: Agency scope.
            policy_number: Number to look up; blank input returns ``None``.
            exclude_policy_id: Optional policy id to ignore (e.g. the policy
                currently being edited).

        Returns:
            A matching non-deleted, non-RENEWED/CANCELLED policy, or ``None``.
        """
        normalized = policy_number.strip()
        if not normalized:
            return None
        query = self._open_policies_query(agency_id).filter(
            func.lower(func.trim(Policy.policy_number)) == normalized.lower(),
        )
        if exclude_policy_id is not None:
            query = query.filter(Policy.id != exclude_policy_id)
        return query.first()

    def find_by_vehicle_registration(
        self,
        agency_id: int,
        registration_number: str,
        *,
        exclude_policy_id: int | None = None,
    ) -> Policy | None:
        """Find an open policy whose vehicle registration matches (case/space-insensitive).

        Args:
            agency_id: Agency scope.
            registration_number: Registration to look up; blank input returns ``None``.
            exclude_policy_id: Optional policy id to ignore.

        Returns:
            A matching non-deleted, non-RENEWED/CANCELLED policy, or ``None``.
        """
        normalized = registration_number.strip()
        if not normalized:
            return None
        query = (
            self._open_policies_query(agency_id)
            .join(Vehicle)
            .filter(func.lower(func.trim(Vehicle.registration_number)) == normalized.lower())
        )
        if exclude_policy_id is not None:
            query = query.filter(Policy.id != exclude_policy_id)
        return query.first()

    def find_active_matches(
        self,
        agency_id: int,
        *,
        policy_numbers: list[str],
        vehicle_registrations: list[str],
    ) -> tuple[dict[str, Policy], dict[str, Policy]]:
        """Bulk version of the two ``find_by_*`` lookups.

        Blank inputs are ignored. At most one query runs per input list.

        Returns:
            ``(by_number, by_vehicle)``. Both dicts are keyed by the
            stripped, lower-cased policy number / registration. When several
            open policies share a key, the first row returned by the database
            is kept.
        """
        by_number: dict[str, Policy] = {}
        by_vehicle: dict[str, Policy] = {}

        normalized_numbers = {value.strip().lower() for value in policy_numbers if value and value.strip()}
        if normalized_numbers:
            rows = (
                self._open_policies_query(agency_id)
                .filter(func.lower(func.trim(Policy.policy_number)).in_(normalized_numbers))
                .all()
            )
            for row in rows:
                key = (row.policy_number or "").strip().lower()
                if key:
                    by_number.setdefault(key, row)

        normalized_regs = {value.strip().lower() for value in vehicle_registrations if value and value.strip()}
        if normalized_regs:
            rows = (
                self._open_policies_query(agency_id)
                .join(Vehicle)
                .filter(func.lower(func.trim(Vehicle.registration_number)).in_(normalized_regs))
                .all()
            )
            for row in rows:
                reg = (row.vehicle.registration_number if row.vehicle else "") or ""
                key = reg.strip().lower()
                if key:
                    by_vehicle.setdefault(key, row)

        return by_number, by_vehicle

    def list_paginated(
        self,
        agency_id: int,
        page: int,
        page_size: int,
        search: str | None = None,
        status: str | None = None,
        company_id: int | None = None,
        payment_status: str | None = None,
        expiry_preset: str | None = None,
        expiry_from: date | None = None,
        expiry_to: date | None = None,
        vehicle_type: str | None = None,
        sort_by: str = "policy_end_date",
        sort_order: str = "asc",
    ) -> tuple[list[Policy], int]:
        """Return one page of the agency's non-deleted policies plus the total count.

        Args:
            agency_id: Agency scope.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page (assumed positive).
            search: Case-insensitive substring matched against customer
                name/mobile, policy number, vehicle registration/engine/chassis
                and insurer name. ``%`` and ``_`` are matched literally.
            status, payment_status: Values coerced through
                ``PolicyStatus(...)`` / ``PaymentStatus(...)``.
            company_id: Insurance company filter (falsy means no filter).
            expiry_preset: One of the keys of ``_EXPIRY_PRESET_WINDOWS``.
            expiry_from, expiry_to: Inclusive bounds on ``policy_end_date``.
            vehicle_type: Exact match on the vehicle type.
            sort_by: ``expiry_date``, ``upload_date`` or ``customer_name``;
                anything else sorts by end date.
            sort_order: ``"desc"`` for descending, anything else ascending.

        Any expiry filter also excludes RENEWED/CANCELLED policies.

        Returns:
            ``(items, total)``, where ``total`` counts all matching rows.
            Items are eager-loaded with customer, vehicle and insurer.
        """
        query = (
            self.db.query(Policy)
            .join(Customer)
            .outerjoin(Vehicle)
            .outerjoin(InsuranceCompany)
            .filter(Policy.agency_id == agency_id, Policy.deleted_at.is_(None))
        )

        term = search.strip() if search else ""
        if term:
            pattern = self._contains_pattern(term)
            searchable = (
                Customer.name,
                Customer.mobile,
                Policy.policy_number,
                Vehicle.registration_number,
                Vehicle.engine_number,
                Vehicle.chassis_number,
                InsuranceCompany.name,
            )
            query = query.filter(or_(*(col.ilike(pattern, escape=self._LIKE_ESCAPE) for col in searchable)))

        if status:
            query = query.filter(Policy.status == PolicyStatus(status))

        if company_id:
            query = query.filter(Policy.insurance_company_id == company_id)

        if payment_status:
            query = query.filter(Policy.payment_status == PaymentStatus(payment_status))

        window = self._EXPIRY_PRESET_WINDOWS.get(expiry_preset) if expiry_preset else None
        if window is not None:
            today = date.today()
            start_offset, end_offset = window
            query = query.filter(
                Policy.policy_end_date.between(
                    today + timedelta(days=start_offset),
                    today + timedelta(days=end_offset),
                )
            )

        if expiry_from:
            query = query.filter(Policy.policy_end_date >= expiry_from)
        if expiry_to:
            query = query.filter(Policy.policy_end_date <= expiry_to)

        if vehicle_type:
            query = query.filter(Vehicle.vehicle_type == vehicle_type)

        if expiry_preset or expiry_from or expiry_to:
            query = query.filter(Policy.status.notin_(self._CLOSED_STATUSES))

        # Count before ordering: the ORDER BY only matters for the page fetch.
        total = query.count()

        sort_map = {
            "expiry_date": Policy.policy_end_date,
            "upload_date": Policy.created_at,
            "customer_name": Customer.name,
        }
        sort_col = sort_map.get(sort_by, Policy.policy_end_date)
        descending = sort_order == "desc"
        query = query.order_by(
            sort_col.desc() if descending else sort_col.asc(),
            # Unique tie-breaker so rows sharing the sort value keep a stable
            # position; otherwise they can repeat or vanish across pages.
            Policy.id.desc() if descending else Policy.id.asc(),
        )

        page = max(page, 1)  # avoid a negative OFFSET
        items = (
            query.options(joinedload(Policy.customer), joinedload(Policy.vehicle), joinedload(Policy.insurance_company))
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        return items, total

    def create(self, agency_id: int, data: dict) -> Policy:
        """Add a new policy for the agency built from ``data`` and flush (no commit)."""
        policy = Policy(agency_id=agency_id, **data)
        self.db.add(policy)
        self.db.flush()
        return policy

    def update(self, policy: Policy, data: dict) -> Policy:
        """Apply ``data`` as attribute assignments on ``policy`` and flush (no commit).

        When the end date or status is touched, the status is recomputed via
        ``compute_policy_status``.
        """
        for key, value in data.items():
            setattr(policy, key, value)
        if "policy_end_date" in data or "status" in data:
            policy.status = compute_policy_status(policy.policy_end_date, policy.status)
        self.db.flush()
        return policy

    def soft_delete(self, policy: Policy) -> None:
        """Mark ``policy`` deleted by stamping ``deleted_at`` with the current UTC time."""
        from datetime import datetime, timezone

        policy.deleted_at = datetime.now(timezone.utc)
        self.db.flush()

    def count_expiring_between(self, agency_id: int, start: date, end: date) -> int:
        """Count policies ending in ``[start, end]``, excluding RENEWED/CANCELLED/DRAFT."""
        return (
            self.db.query(func.count(Policy.id))
            .filter(
                Policy.agency_id == agency_id,
                Policy.deleted_at.is_(None),
                Policy.status.notin_([*self._CLOSED_STATUSES, PolicyStatus.DRAFT]),
                Policy.policy_end_date.between(start, end),
            )
            .scalar()
            or 0
        )

    def count_active(self, agency_id: int) -> int:
        """Count non-deleted policies whose status is ACTIVE or EXPIRING_SOON."""
        return (
            self.db.query(func.count(Policy.id))
            .filter(
                Policy.agency_id == agency_id,
                Policy.deleted_at.is_(None),
                Policy.status.in_([PolicyStatus.ACTIVE, PolicyStatus.EXPIRING_SOON]),
            )
            .scalar()
            or 0
        )

    def sum_pending_payments(self, agency_id: int) -> float:
        """Sum ``pending_amount`` over non-deleted policies with a positive balance."""
        result = (
            self.db.query(func.coalesce(func.sum(Policy.pending_amount), 0))
            .filter(
                Policy.agency_id == agency_id,
                Policy.deleted_at.is_(None),
                Policy.pending_amount > 0,
            )
            .scalar()
        )
        return float(result or 0)

    def count_pending_payment_policies(self, agency_id: int) -> int:
        """Count non-deleted policies with ``pending_amount > 0``."""
        return (
            self.db.query(func.count(Policy.id))
            .filter(
                Policy.agency_id == agency_id,
                Policy.deleted_at.is_(None),
                Policy.pending_amount > 0,
            )
            .scalar()
            or 0
        )

    def company_wise_counts(self, agency_id: int) -> list[tuple]:
        """Return ``(company_id, company_name, policy_count)`` rows.

        Each row counts the agency's non-deleted policies for one insurer.
        Insurers with no such policies are absent because an inner join is used.
        """
        return (
            self.db.query(InsuranceCompany.id, InsuranceCompany.name, func.count(Policy.id))
            .join(Policy, Policy.insurance_company_id == InsuranceCompany.id)
            .filter(Policy.agency_id == agency_id, Policy.deleted_at.is_(None))
            .group_by(InsuranceCompany.id, InsuranceCompany.name)
            .all()
        )

    def recent_uploads(self, agency_id: int, limit: int = 10) -> list[Policy]:
        """Return the ``limit`` most recently created non-deleted policies, newest first."""
        return (
            self.db.query(Policy)
            .options(joinedload(Policy.customer))
            .filter(Policy.agency_id == agency_id, Policy.deleted_at.is_(None))
            .order_by(Policy.created_at.desc())
            .limit(limit)
            .all()
        )

    def list_by_customer(self, agency_id: int, customer_id: int) -> list[Policy]:
        """Return all non-deleted policies of a customer, latest end date first."""
        return (
            self.db.query(Policy)
            .options(
                joinedload(Policy.customer),
                joinedload(Policy.vehicle),
                joinedload(Policy.insurance_company),
            )
            .filter(
                Policy.agency_id == agency_id,
                Policy.customer_id == customer_id,
                Policy.deleted_at.is_(None),
            )
            .order_by(Policy.policy_end_date.desc())
            .all()
        )