from app.utils.file_storage import get_file_storage, read_storage_bytes

from sqlalchemy.orm import Session

from app.core.exceptions import NotFoundError, ValidationError
from app.models.extraction_result import ExtractionResult
from app.models.parser_correction import ParserCorrection
from app.parsers.base import ExtractedField
from app.parsers.registry import COMPANY_SIGNATURES, PARSER_REGISTRY, detect_company
from app.schemas.parser_lab import (
    KEY_PARSER_FIELDS,
    PARSER_LAB_FIELD_DEFINITIONS,
    ReExtractRequest,
    SaveParserCorrectionsRequest,
)
from app.services.extraction_correction_service import (
    ExtractionCorrectionService,
    SOURCE_PARSER_LAB,
    build_field_corrections,
)
from app.services.parser_training_service import ParserTrainingService
from app.services.upload_service import UploadService
from app.utils.pdf_text import pack_raw_text_fields, raw_text_pages_from_joined


class ParserLabService:
    """Parser-lab operations: inspect extractions, re-run parsers, and record corrections.

    Extraction lookups are delegated to ``UploadService.get_extraction`` with the
    caller's ``agency_id``; correction queries are filtered by ``agency_id`` or by
    an extraction already fetched for that agency.
    """

    def __init__(self, db: Session):
        self.db = db
        self.upload = UploadService(db)

    def extract_pdf(self, file_bytes: bytes, filename: str, user_id: int, agency_id: int) -> dict:
        """Run the regular upload extraction with ``lab_mode=True`` and return its result."""
        return self.upload.extract_pdf(file_bytes, filename, user_id, agency_id, lab_mode=True)

    def list_extractions(self, agency_id: int, limit: int = 30) -> list[dict]:
        """Return summaries of the agency's most recent extractions, newest first.

        Args:
            agency_id: Agency whose extractions are listed.
            limit: Maximum number of rows returned.
        """
        rows = (
            self.db.query(ExtractionResult)
            .filter(ExtractionResult.agency_id == agency_id)
            .order_by(ExtractionResult.created_at.desc())
            .limit(limit)
            .all()
        )
        return [self._extraction_summary(row) for row in rows]

    def get_extraction_detail(self, extraction_id: int, agency_id: int) -> dict:
        """Build the parser-lab view of one extraction.

        Extracted fields (with learned patterns applied) are merged with the most
        recent saved ``ParserCorrection`` for this extraction. Every field listed in
        ``PARSER_LAB_FIELD_DEFINITIONS`` appears in the result, even if the parser
        did not extract it.

        Returns:
            A dict with extraction metadata, the per-field rows, the raw text and
            its pages, and the latest correction (or ``None``).
        """
        extraction = self.upload.get_extraction(extraction_id, agency_id)
        raw = extraction.raw_data or {}
        field_map = {f["field_name"]: f for f in raw.get("fields", [])}
        raw_text, raw_text_pages, page_count = self._resolve_raw_text(extraction, raw)
        field_map = self._apply_learned_patterns_to_field_map(
            agency_id=agency_id,
            raw=raw,
            field_map=field_map,
            raw_text=raw_text,
            raw_text_pages=raw_text_pages,
        )

        latest_correction = (
            self.db.query(ParserCorrection)
            .filter(ParserCorrection.extraction_id == extraction_id)
            .order_by(ParserCorrection.created_at.desc())
            .first()
        )

        corrected_values = {}
        if latest_correction:
            for item in latest_correction.fields or []:
                corrected_values[item["field_name"]] = item.get("corrected_value")

        fields = []
        for section in PARSER_LAB_FIELD_DEFINITIONS:
            for field in section["fields"]:
                key = field["key"]
                extracted = field_map.get(key, {})
                fields.append(
                    {
                        "field_name": key,
                        "label": field["label"],
                        "section": section["title"],
                        "extracted_value": extracted.get("value"),
                        "confidence": extracted.get("confidence"),
                        "source": extracted.get("source"),
                        # Without a saved correction, the extracted value is the default.
                        "corrected_value": corrected_values.get(key, extracted.get("value")),
                        "is_key_field": key in KEY_PARSER_FIELDS,
                    }
                )

        return {
            "id": extraction.id,
            "status": extraction.status.value,
            "original_filename": extraction.original_filename,
            "policy_id": extraction.policy_id,
            "insurance_company_id": extraction.insurance_company_id,
            "company_detection": raw.get("company_detection"),
            "parser_key": raw.get("company_code"),
            "fields": fields,
            "raw_text": raw_text,
            "raw_text_pages": raw_text_pages,
            "page_count": page_count,
            "latest_correction": ExtractionCorrectionService.correction_to_dict(latest_correction)
            if latest_correction
            else None,
            "field_definitions": PARSER_LAB_FIELD_DEFINITIONS,
            "key_fields": KEY_PARSER_FIELDS,
        }

    def re_extract(self, extraction_id: int, payload: ReExtractRequest, agency_id: int) -> dict:
        """Re-run a parser over the stored PDF and overwrite the extraction's raw data.

        The parser is chosen from ``payload.company_code`` when given (confidence
        1.0). Otherwise it comes from ``detect_company``. Unknown codes fall back to
        the ``"generic"`` parser. Learned patterns are applied only when requested and
        a company code is known. The checksum, file size and lab-mode flag from the
        previous raw data are carried over.

        Raises:
            NotFoundError: The extraction has no storage path, or the stored file
                no longer exists.

        Returns:
            The refreshed detail view (see ``get_extraction_detail``).
        """
        extraction = self.upload.get_extraction(extraction_id, agency_id)
        path = extraction.storage_path
        # Check storage up front (same check as _read_pdf_text). A missing blob then
        # surfaces as the documented NotFoundError rather than a storage-layer error.
        if not path or not get_file_storage().exists(path):
            raise NotFoundError("PDF file not found for this extraction")

        text_pack = pack_raw_text_fields(read_storage_bytes(path))
        text = text_pack["raw_text"]

        if payload.company_code:
            company_code = payload.company_code
            company_name = next(
                (name for code, name, _ in COMPANY_SIGNATURES if code == company_code),
                company_code.replace("_", " ").title(),
            )
            company_confidence = 1.0
        else:
            company_code, company_name, company_confidence = detect_company(text, extraction.original_filename)

        parser = PARSER_REGISTRY.get(company_code or "generic", PARSER_REGISTRY["generic"])
        result = parser(text)

        if payload.apply_learned_patterns and company_code:
            result.fields = ParserTrainingService.apply_learned_patterns(
                self.db,
                agency_id,
                company_code,
                text,
                result.fields,
                raw_text_pages=text_pack["raw_text_pages"],
            )

        raw = extraction.raw_data or {}
        # A brand-new dict is assigned, so the ORM sees the column as changed.
        extraction.raw_data = {
            **result.to_dict(),
            "company_detection": {
                "code": company_code,
                "name": company_name,
                "confidence": company_confidence,
            },
            "checksum": raw.get("checksum"),
            "file_size": raw.get("file_size"),
            "raw_text": text,
            "raw_text_pages": text_pack["raw_text_pages"],
            "page_count": text_pack["page_count"],
            "lab_mode": raw.get("lab_mode", True),
        }
        self.db.commit()
        self.db.refresh(extraction)
        return self.get_extraction_detail(extraction_id, agency_id)

    def save_corrections(
        self,
        extraction_id: int,
        agency_id: int,
        user_id: int,
        payload: SaveParserCorrectionsRequest,
    ) -> dict:
        """Persist a new ``ParserCorrection`` snapshot for an extraction.

        Each submitted field is stored with its extracted and corrected values. A
        ``was_corrected`` flag is computed by ``build_field_corrections``. The
        company code comes from the payload, falling back to the extraction's raw
        data (``company_code``, then ``company_detection.code``).

        Raises:
            ValidationError: ``payload.fields`` is empty.

        Returns:
            The stored correction as a dict.
        """
        if not payload.fields:
            raise ValidationError("At least one field correction is required")

        extraction = self.upload.get_extraction(extraction_id, agency_id)
        raw = extraction.raw_data or {}
        company_code = payload.company_code or self._company_code_from_raw(raw)

        stored_fields = [
            {
                "field_name": item.field_name,
                "extracted_value": item.extracted_value,
                "corrected_value": item.corrected_value,
                "was_corrected": build_field_corrections(
                    {item.field_name: item.extracted_value},
                    {item.field_name: item.corrected_value},
                )[0]["was_corrected"],
            }
            for item in payload.fields
        ]

        record = ParserCorrection(
            extraction_id=extraction_id,
            agency_id=agency_id,
            corrected_by=user_id,
            company_code=company_code,
            parser_key=company_code,
            original_filename=extraction.original_filename,
            fields=stored_fields,
            notes=payload.notes,
            source=SOURCE_PARSER_LAB,
        )
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return ExtractionCorrectionService.correction_to_dict(record)

    def correction_stats(self, agency_id: int, company_code: str | None = None) -> dict:
        """Delegate to ``ExtractionCorrectionService.correction_stats``."""
        return ExtractionCorrectionService(self.db).correction_stats(agency_id, company_code=company_code)

    def list_corrections(self, agency_id: int, company_code: str | None = None, limit: int = 50) -> list[dict]:
        """Return the agency's most recent corrections, newest first.

        Args:
            agency_id: Agency whose corrections are listed.
            company_code: When truthy, only corrections for that company are returned.
            limit: Maximum number of corrections returned.
        """
        query = self.db.query(ParserCorrection).filter(ParserCorrection.agency_id == agency_id)
        if company_code:
            query = query.filter(ParserCorrection.company_code == company_code)
        rows = query.order_by(ParserCorrection.created_at.desc()).limit(limit).all()
        return [ExtractionCorrectionService.correction_to_dict(row) for row in rows]

    def export_corrections(
        self, agency_id: int, company_code: str | None = None, limit: int = 500
    ) -> list[dict]:
        """Flatten recent corrections into one row per field that was actually corrected.

        Args:
            agency_id: Agency whose corrections are exported.
            company_code: Optional company filter (see ``list_corrections``).
            limit: Maximum number of correction records scanned (default 500).
        """
        export_rows = []
        for correction in self.list_corrections(agency_id, company_code, limit=limit):
            for field in correction["fields"]:
                if not field.get("was_corrected"):
                    continue
                export_rows.append(
                    {
                        "correction_id": correction["id"],
                        "extraction_id": correction["extraction_id"],
                        "company_code": correction["company_code"],
                        "source": correction.get("source"),
                        "original_filename": correction["original_filename"],
                        "field_name": field["field_name"],
                        "extracted_value": field["extracted_value"],
                        "corrected_value": field["corrected_value"],
                        "notes": correction.get("notes"),
                        "created_at": correction["created_at"],
                    }
                )
        return export_rows

    def list_parsers(self) -> list[dict]:
        """List every known company signature with its registered parser's function name.

        ``parser_function`` is ``None`` when no parser is registered for the code.
        """
        parsers = []
        for code, name, _ in COMPANY_SIGNATURES:
            parser_fn = PARSER_REGISTRY.get(code)
            parsers.append(
                {
                    "code": code,
                    "name": name,
                    "parser_function": getattr(parser_fn, "__name__", None) if parser_fn else None,
                }
            )
        return parsers

    def _resolve_raw_text(self, extraction: ExtractionResult, raw: dict) -> tuple[str, list[dict], int]:
        """Return ``(raw_text, raw_text_pages, page_count)`` for an extraction.

        Cached pages in ``raw`` are preferred. Next, pages are rebuilt from cached
        joined text. Otherwise the stored PDF is read (see ``_read_pdf_text``).
        """
        if raw.get("raw_text_pages"):
            pages = raw["raw_text_pages"]
            return raw.get("raw_text") or "", pages, raw.get("page_count") or len(pages)
        if raw.get("raw_text"):
            pages = raw_text_pages_from_joined(raw["raw_text"])
            return raw["raw_text"], pages, len(pages)
        return self._read_pdf_text(extraction)

    def _read_pdf_text(self, extraction: ExtractionResult) -> tuple[str, list[dict], int]:
        """Extract text from the stored PDF and cache it in ``extraction.raw_data``.

        Returns ``("", [], 0)`` when there is no storage path or the file is missing.
        On success the cache is committed.
        """
        path = extraction.storage_path
        if not path or not get_file_storage().exists(path):
            return "", [], 0
        text_pack = pack_raw_text_fields(read_storage_bytes(path))
        # Assign a *new* dict rather than mutating raw_data in place. If raw_data is a
        # plain JSON column (no MutableDict wrapper; not visible here), SQLAlchemy
        # compares the old and new values. Re-assigning the same mutated object looks
        # unchanged, so the cached text would never be written.
        extraction.raw_data = {
            **(extraction.raw_data or {}),
            "raw_text": text_pack["raw_text"],
            "raw_text_pages": text_pack["raw_text_pages"],
            "page_count": text_pack["page_count"],
        }
        self.db.commit()
        return text_pack["raw_text"], text_pack["raw_text_pages"], text_pack["page_count"]

    def _apply_learned_patterns_to_field_map(
        self,
        *,
        agency_id: int,
        raw: dict,
        field_map: dict,
        raw_text: str,
        raw_text_pages: list[dict],
    ) -> dict:
        """Overlay learned-pattern results onto a ``field_name -> field dict`` map.

        Only fields whose source comes back as ``"learned_pattern"`` or
        ``"learned_example"`` replace or extend entries. Other entries keep their
        original data. The input map is not mutated. Without raw text the map is
        returned unchanged.
        """
        if not raw_text:
            return field_map
        company_code = self._company_code_from_raw(raw)

        existing_fields = [
            ExtractedField(
                field_name=name,
                value=data.get("value"),
                confidence=data.get("confidence") or 0.0,
                source_page=data.get("source_page"),
                source=data.get("source") or "extracted",
            )
            for name, data in field_map.items()
        ]
        updated_fields = ParserTrainingService.apply_learned_patterns(
            self.db,
            agency_id,
            company_code or "",
            raw_text,
            existing_fields,
            raw_text_pages=raw_text_pages,
        )
        updated = dict(field_map)
        for field in updated_fields:
            if field.source not in ("learned_pattern", "learned_example"):
                continue
            updated[field.field_name] = {
                **updated.get(field.field_name, {}),
                "field_name": field.field_name,
                "value": field.value,
                "confidence": field.confidence,
                "source_page": field.source_page,
                "source": field.source,
            }
        return updated

    @staticmethod
    def _company_code_from_raw(raw: dict) -> str | None:
        """Return ``raw["company_code"]``, else ``raw["company_detection"]["code"]``, else None."""
        return raw.get("company_code") or (raw.get("company_detection") or {}).get("code")

    @staticmethod
    def _extraction_summary(extraction: ExtractionResult) -> dict:
        """Summarize an extraction for list views.

        The summary includes how many of ``KEY_PARSER_FIELDS`` have a truthy
        extracted value.
        """
        raw = extraction.raw_data or {}
        company = raw.get("company_detection") or {}
        field_map = {f["field_name"]: f.get("value") for f in raw.get("fields", [])}
        key_found = sum(1 for key in KEY_PARSER_FIELDS if field_map.get(key))
        return {
            "id": extraction.id,
            "original_filename": extraction.original_filename,
            "status": extraction.status.value,
            "company_code": company.get("code"),
            "company_name": company.get("name"),
            "key_fields_found": key_found,
            "key_fields_total": len(KEY_PARSER_FIELDS),
            "created_at": extraction.created_at.isoformat() if extraction.created_at else None,
            "lab_mode": raw.get("lab_mode", False),
        }

    @staticmethod
    def _correction_to_dict(row: ParserCorrection | None) -> dict | None:
        """Thin alias for ``ExtractionCorrectionService.correction_to_dict``.

        Not called within this class. Presumably kept for external callers; verify
        before removing.
        """
        return ExtractionCorrectionService.correction_to_dict(row)
