"""Parser extraction quality metrics for agency admins."""

from __future__ import annotations

from sqlalchemy.orm import Session

from app.models.enums import ExtractionStatus
from app.models.extraction_result import ExtractionResult
from app.models.insurance_company import InsuranceCompany
from app.models.parser_correction import ParserCorrection
from app.services.extraction_correction_service import (
    ExtractionCorrectionService,
    TRACKED_FIELDS,
    summarize_field_corrections,
)


def _empty_status() -> dict[str, int]:
    """Return a fresh zeroed counter keyed by each tracked extraction status value."""
    tracked = (ExtractionStatus.SUCCESS, ExtractionStatus.PARTIAL, ExtractionStatus.FAILED)
    return {member.value: 0 for member in tracked}


def _status_breakdown(counts: dict[str, int]) -> dict[str, int]:
    """Project raw per-status counts onto the fixed ``success``/``partial``/``failed`` keys.

    ``counts`` is keyed by ``ExtractionStatus`` values; statuses missing from it
    are reported as 0.
    """
    labelled = (
        ("success", ExtractionStatus.SUCCESS),
        ("partial", ExtractionStatus.PARTIAL),
        ("failed", ExtractionStatus.FAILED),
    )
    return {label: counts.get(member.value, 0) for label, member in labelled}


def _success_rate(counts: dict[str, int]) -> float:
    """Return the share of extractions that ended as success or partial.

    The denominator is the sum of every value in ``counts``; an empty or
    all-zero counter yields 0.0. The result is rounded to four decimals.
    """
    denominator = sum(counts.values())
    if denominator:
        usable = sum(
            counts.get(member.value, 0)
            for member in (ExtractionStatus.SUCCESS, ExtractionStatus.PARTIAL)
        )
        return round(usable / denominator, 4)
    return 0.0


def _accuracy_rate(change_rate: float) -> float:
    """Convert a change rate into an accuracy rate (``1 - change_rate``).

    The result is floored at 0.0 and rounded to four decimals.
    """
    unchanged_share = 1.0 - change_rate
    floored = unchanged_share if unchanged_share > 0.0 else 0.0
    return round(floored, 4)


class ParserPerformanceService:
    def __init__(self, db: Session) -> None:
        """Bind the service to the SQLAlchemy session used for every query it issues."""
        self.db = db

    def _company_lookup(self) -> dict[int, InsuranceCompany]:
        rows = self.db.query(InsuranceCompany).all()
        return {row.id: row for row in rows}

    @staticmethod
    def _company_code_for_extraction(extraction: ExtractionResult, companies: dict[int, InsuranceCompany]) -> str:
        raw = extraction.raw_data or {}
        code = raw.get("company_code") or raw.get("company_detection", {}).get("code")
        if code:
            return str(code)
        if extraction.insurance_company_id and extraction.insurance_company_id in companies:
            return companies[extraction.insurance_company_id].code
        return "unknown"

    @staticmethod
    def _company_name_for_code(code: str, companies: dict[int, InsuranceCompany]) -> str | None:
        if code == "unknown":
            return None
        for company in companies.values():
            if company.code == code:
                return company.name
        return code.replace("_", " ").title()

    def get_performance(self, agency_id: int) -> dict:
        companies = self._company_lookup()
        extractions = (
            self.db.query(ExtractionResult)
            .filter(ExtractionResult.agency_id == agency_id)
            .order_by(ExtractionResult.created_at.desc())
            .all()
        )

        overall_status = _empty_status()
        confirmed = 0
        company_extractions: dict[str, dict] = {}

        extraction_status_by_id: dict[int, str] = {}
        for row in extractions:
            status_key = row.status.value
            overall_status[status_key] = overall_status.get(status_key, 0) + 1
            extraction_status_by_id[row.id] = status_key
            if row.policy_id:
                confirmed += 1

            code = self._company_code_for_extraction(row, companies)
            bucket = company_extractions.setdefault(
                code,
                {"status": _empty_status(), "company_name": self._company_name_for_code(code, companies)},
            )
            bucket["status"][status_key] = bucket["status"].get(status_key, 0) + 1

        correction_stats = ExtractionCorrectionService(self.db).correction_stats(agency_id)
        field_accuracy = _accuracy_rate(correction_stats["overall_change_rate"])
        extraction_success = _success_rate(overall_status)
        overall_score = round((extraction_success + field_accuracy) / 2, 4) if extractions else 0.0

        correction_rows = (
            self.db.query(ParserCorrection)
            .filter(ParserCorrection.agency_id == agency_id)
            .order_by(ParserCorrection.created_at.desc())
            .all()
        )
        company_field_changes: dict[str, dict[str, int]] = {}
        for row in correction_rows:
            code = row.company_code or "unknown"
            for field in row.fields or []:
                if not field.get("was_corrected"):
                    continue
                name = field.get("field_name")
                if not name:
                    continue
                company_field_changes.setdefault(code, {})
                company_field_changes[code][name] = company_field_changes[code].get(name, 0) + 1

        by_field = []
        for item in correction_stats["by_field"]:
            if item["comparisons"] == 0:
                continue
            by_field.append(
                {
                    "field_name": item["field_name"],
                    "comparisons": item["comparisons"],
                    "changes": item["changes"],
                    "accuracy_rate": _accuracy_rate(item["change_rate"]),
                }
            )
        by_field.sort(key=lambda row: (row["accuracy_rate"], -row["comparisons"]))

        company_correction_map = {row["company_code"]: row for row in correction_stats["by_company"]}
        by_company = []
        all_company_codes = sorted(set(company_extractions) | set(company_correction_map))
        for code in all_company_codes:
            ext_bucket = company_extractions.get(code, {"status": _empty_status(), "company_name": None})
            corr = company_correction_map.get(code, {"reviews": 0, "change_rate": 0.0})
            status = ext_bucket["status"]
            ext_total = sum(status.values())
            top_fields = [
                name
                for name, _count in sorted(
                    company_field_changes.get(code, {}).items(),
                    key=lambda item: -item[1],
                )[:5]
            ]
            by_company.append(
                {
                    "company_code": code,
                    "company_name": ext_bucket.get("company_name") or self._company_name_for_code(code, companies),
                    "extractions": ext_total,
                    "status_breakdown": _status_breakdown(status),
                    "extraction_success_rate": _success_rate(status),
                    "reviews": corr.get("reviews", 0),
                    "field_accuracy_rate": _accuracy_rate(corr.get("change_rate", 0.0)),
                    "top_corrected_fields": top_fields,
                }
            )
        by_company.sort(key=lambda row: (-row["extractions"], -row["reviews"]))

        recent_rows = correction_rows[:20]
        recent_reviews = []
        for row in recent_rows:
            summary = summarize_field_corrections(row.fields or [])
            compared = summary["fields_compared"] or 1
            changed = summary["fields_changed"]
            recent_reviews.append(
                {
                    "id": row.id,
                    "extraction_id": row.extraction_id,
                    "original_filename": row.original_filename,
                    "company_code": row.company_code,
                    "extraction_status": extraction_status_by_id.get(row.extraction_id),
                    "fields_changed": changed,
                    "fields_compared": summary["fields_compared"],
                    "accuracy_rate": round((compared - changed) / compared, 4),
                    "source": row.source or "parser_lab",
                    "created_at": row.created_at.isoformat() if row.created_at else None,
                }
            )

        return {
            "summary": {
                "total_extractions": len(extractions),
                "confirmed_policies": confirmed,
                "pending_review": len(extractions) - confirmed,
                "total_reviews": correction_stats["total_reviews"],
                "extraction_success_rate": extraction_success,
                "field_accuracy_rate": field_accuracy,
                "overall_parser_score": overall_score,
                "status_breakdown": _status_breakdown(overall_status),
            },
            "by_company": by_company,
            "by_field": by_field,
            "recent_reviews": recent_reviews,
        }

    def get_platform_performance(self) -> dict:
        from app.models.agency import Agency

        agencies = (
            self.db.query(Agency)
            .filter(Agency.name != "Platform Operations")
            .order_by(Agency.name.asc())
            .all()
        )
        if not agencies:
            return {
                "summary": {
                    "total_extractions": 0,
                    "confirmed_policies": 0,
                    "pending_review": 0,
                    "total_reviews": 0,
                    "extraction_success_rate": 0.0,
                    "field_accuracy_rate": 0.0,
                    "overall_parser_score": 0.0,
                    "status_breakdown": {"success": 0, "partial": 0, "failed": 0},
                },
                "by_company": [],
                "by_field": [],
                "recent_reviews": [],
                "scope": "platform",
            }

        combined_status = _empty_status()
        total_extractions = 0
        confirmed = 0
        total_reviews = 0
        total_comparisons = 0
        total_changes = 0
        by_company: list[dict] = []
        field_totals: dict[str, dict[str, int]] = {}
        recent_reviews: list[dict] = []

        for agency in agencies:
            perf = self.get_performance(agency.id)
            summary = perf["summary"]
            total_extractions += summary["total_extractions"]
            confirmed += summary["confirmed_policies"]
            total_reviews += summary["total_reviews"]
            for key, value in summary["status_breakdown"].items():
                combined_status[key] = combined_status.get(key, 0) + value
            for row in perf["by_company"]:
                by_company.append({**row, "agency_id": agency.id, "agency_name": agency.name})
            for row in perf["by_field"]:
                bucket = field_totals.setdefault(row["field_name"], {"comparisons": 0, "changes": 0})
                bucket["comparisons"] += row["comparisons"]
                bucket["changes"] += row["changes"]
                total_comparisons += row["comparisons"]
                total_changes += row["changes"]
            for row in perf["recent_reviews"]:
                recent_reviews.append({**row, "agency_name": agency.name})

        extraction_success = _success_rate(combined_status)
        field_accuracy = _accuracy_rate(total_changes / total_comparisons if total_comparisons else 0.0)
        by_field = []
        for name in TRACKED_FIELDS:
            stats = field_totals.get(name, {"comparisons": 0, "changes": 0})
            comparisons = stats["comparisons"]
            changes = stats["changes"]
            if comparisons == 0:
                continue
            by_field.append(
                {
                    "field_name": name,
                    "comparisons": comparisons,
                    "changes": changes,
                    "accuracy_rate": _accuracy_rate(changes / comparisons),
                }
            )
        by_field.sort(key=lambda row: (row["accuracy_rate"], -row["comparisons"]))
        recent_reviews.sort(key=lambda row: row.get("created_at") or "", reverse=True)

        return {
            "summary": {
                "total_extractions": total_extractions,
                "confirmed_policies": confirmed,
                "pending_review": total_extractions - confirmed,
                "total_reviews": total_reviews,
                "extraction_success_rate": extraction_success,
                "field_accuracy_rate": field_accuracy,
                "overall_parser_score": round((extraction_success + field_accuracy) / 2, 4),
                "status_breakdown": _status_breakdown(combined_status),
            },
            "by_company": sorted(by_company, key=lambda row: (-row.get("extractions", 0), row.get("company_code", ""))),
            "by_field": by_field,
            "recent_reviews": recent_reviews[:30],
            "scope": "platform",
        }

    def get_parser_success_report(self, agency_id: int | None = None) -> dict:
        from datetime import datetime, timezone

        from app.models.agency import Agency

        companies = self._company_lookup()
        if agency_id is not None:
            agency_ids = [agency_id]
            perf = self.get_performance(agency_id)
            scope = "agency"
            agency = self.db.query(Agency).filter(Agency.id == agency_id).first()
            scope_name = agency.name if agency else f"Agency {agency_id}"
        else:
            agencies = (
                self.db.query(Agency)
                .filter(Agency.name != "Platform Operations")
                .order_by(Agency.name.asc())
                .all()
            )
            agency_ids = [row.id for row in agencies]
            perf = self.get_platform_performance()
            scope = "platform"
            scope_name = "All agencies"

        summary = perf["summary"]
        total_ext = summary["total_extractions"] or 0
        funnel = {
            "extractions": total_ext,
            "confirmed_policies": summary["confirmed_policies"],
            "pending_review": summary["pending_review"],
            "conversion_rate": round(summary["confirmed_policies"] / total_ext, 4) if total_ext else 0.0,
        }

        status = summary["status_breakdown"]
        status_total = status["success"] + status["partial"] + status["failed"]
        status_matrix = [
            {
                "status": key,
                "count": status[key],
                "share_rate": round(status[key] / status_total, 4) if status_total else 0.0,
            }
            for key in ("success", "partial", "failed")
        ]

        correction_stats = ExtractionCorrectionService(self.db).correction_stats(agency_id)
        by_field = []
        for name in TRACKED_FIELDS:
            match = next((row for row in perf["by_field"] if row["field_name"] == name), None)
            corr = next((row for row in correction_stats["by_field"] if row["field_name"] == name), None)
            comparisons = match["comparisons"] if match else corr["comparisons"] if corr else 0
            changes = match["changes"] if match else corr["changes"] if corr else 0
            by_field.append(
                {
                    "field_name": name,
                    "comparisons": comparisons,
                    "changes": changes,
                    "accuracy_rate": _accuracy_rate(changes / comparisons) if comparisons else None,
                    "change_rate": round(changes / comparisons, 4) if comparisons else None,
                }
            )

        by_source = [
            {
                "source": source,
                **stats,
                "accuracy_rate": _accuracy_rate(stats["change_rate"]),
            }
            for source, stats in sorted(correction_stats.get("by_source", {}).items())
        ]

        by_agency = []
        for aid in agency_ids:
            agency = self.db.query(Agency).filter(Agency.id == aid).first()
            if not agency:
                continue
            agency_perf = self.get_performance(aid)
            agency_summary = agency_perf["summary"]
            ext_count = agency_summary["total_extractions"]
            by_agency.append(
                {
                    "agency_id": aid,
                    "agency_name": agency.name,
                    **agency_summary,
                    "conversion_rate": round(agency_summary["confirmed_policies"] / ext_count, 4) if ext_count else 0.0,
                }
            )

        by_insurer = self._aggregate_by_insurer(agency_ids, companies, agency_id)

        agency_insurer_matrix = []
        for row in perf.get("by_company", []):
            agency_insurer_matrix.append(
                {
                    "agency_id": row.get("agency_id"),
                    "agency_name": row.get("agency_name"),
                    "company_code": row.get("company_code"),
                    "company_name": row.get("company_name"),
                    "extractions": row.get("extractions", 0),
                    "extraction_success_rate": row.get("extraction_success_rate", 0.0),
                    "reviews": row.get("reviews", 0),
                    "field_accuracy_rate": row.get("field_accuracy_rate", 0.0),
                    "status_breakdown": row.get("status_breakdown", {}),
                }
            )

        return {
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "scope": scope,
            "scope_name": scope_name,
            "agency_id": agency_id,
            "summary": summary,
            "funnel": funnel,
            "status_matrix": status_matrix,
            "by_agency": by_agency,
            "by_insurer": by_insurer,
            "by_field": by_field,
            "by_source": by_source,
            "agency_insurer_matrix": agency_insurer_matrix,
            "recent_reviews": perf.get("recent_reviews", []),
        }

    def _aggregate_by_insurer(
        self,
        agency_ids: list[int],
        companies: dict[int, InsuranceCompany],
        filter_agency_id: int | None,
    ) -> list[dict]:
        if not agency_ids:
            return []

        query = self.db.query(ExtractionResult).filter(ExtractionResult.agency_id.in_(agency_ids))
        extractions = query.all()
        buckets: dict[str, dict] = {}
        for ext in extractions:
            code = self._company_code_for_extraction(ext, companies)
            bucket = buckets.setdefault(
                code,
                {
                    "company_name": self._company_name_for_code(code, companies),
                    "status": _empty_status(),
                    "confirmed": 0,
                },
            )
            bucket["status"][ext.status.value] = bucket["status"].get(ext.status.value, 0) + 1
            if ext.policy_id:
                bucket["confirmed"] += 1

        correction_stats = ExtractionCorrectionService(self.db).correction_stats(filter_agency_id)
        corr_map = {row["company_code"]: row for row in correction_stats.get("by_company", [])}

        rows: list[dict] = []
        for code, data in buckets.items():
            status = data["status"]
            ext_total = sum(status.values())
            corr = corr_map.get(code, {"reviews": 0, "change_rate": 0.0})
            rows.append(
                {
                    "company_code": code,
                    "company_name": data["company_name"],
                    "extractions": ext_total,
                    "status_breakdown": _status_breakdown(status),
                    "extraction_success_rate": _success_rate(status),
                    "confirmed_policies": data["confirmed"],
                    "conversion_rate": round(data["confirmed"] / ext_total, 4) if ext_total else 0.0,
                    "reviews": corr.get("reviews", 0),
                    "field_accuracy_rate": _accuracy_rate(corr.get("change_rate", 0.0)),
                }
            )
        rows.sort(key=lambda row: (-row["extractions"], row["company_code"]))
        return rows
