"""Parser training: save examples, suggest/test/apply patterns."""

import logging
import re

from sqlalchemy import func
from sqlalchemy.orm import Session

from app.core.exceptions import NotFoundError, ValidationError
from app.models.parser_pattern import ParserPattern
from app.models.parser_training_example import ParserTrainingExample
from app.parsers.base import ExtractedField
from app.schemas.parser_lab import PARSER_LAB_FIELD_DEFINITIONS
from app.services.upload_service import UploadService
from app.utils.pdf_bbox import get_pdf_page_meta, resolve_bbox_selection
from app.utils.pattern_suggestion import (
    compute_selection_context,
    extract_mobile_from_match,
    extract_with_regex,
    generic_field_patterns,
    suggest_regex,
)
from app.utils.pattern_tester import evaluate_regression, run_pattern_on_samples, test_pattern_on_text
from app.utils.pattern_promote import build_promote_preview
from app.utils.pdf_text import pack_raw_text_fields

# Every field key declared in the parser-lab schema; any other field name is
# rejected as unknown.
VALID_FIELD_NAMES = {
    field["key"]
    for section in PARSER_LAB_FIELD_DEFINITIONS
    for field in section["fields"]
}

logger = logging.getLogger(__name__)

# Placeholder labels a parser may return in place of a real value. When training
# exists for the field, an extracted value matching one of these is replaced.
_SUSPICIOUS_EXTRACTED_VALUES = frozenset(
    (
        "type of body",
        "year of mfg",
        "make/model",
        "engine no",
        "chassis no",
        "registration no",
        "obsolete vehicle",
    )
)


class ParserTrainingService:
    """Parser-lab training workflow scoped to an agency.

    Instance methods save/list user-labelled training examples, suggest, test
    and save regex patterns, and manage saved patterns. The static
    ``apply_learned_patterns`` overlays an agency's learned patterns and
    examples onto a parser's extracted fields.
    """

    def __init__(self, db: Session):
        self.db = db
        self.upload = UploadService(db)

    def save_training_example(
        self,
        extraction_id: int,
        agency_id: int,
        user_id: int,
        *,
        field_name: str,
        selection_start: int | None = None,
        selection_end: int | None = None,
        company_code: str | None = None,
        page: int = 1,
        notes: str | None = None,
        bbox: dict | None = None,
    ) -> dict:
        """Store a labelled selection of an extraction's text as a training example.

        The selection is given either as character offsets into the raw text
        (``selection_start``/``selection_end``) or as a bounding box on
        ``page``, which is resolved to offsets (and may replace ``page``).
        When ``company_code`` is omitted it is taken from the extraction's
        stored ``company_code`` or ``company_detection.code``.

        Returns:
            The serialized example plus a ``suggested_regex`` for the field.

        Raises:
            ValidationError: unknown field, no raw text, no selection given,
                a bounding box or offsets that cannot be resolved, or no
                company code available.
        """
        if field_name not in VALID_FIELD_NAMES:
            raise ValidationError(f"Unknown field: {field_name}")

        extraction = self.upload.get_extraction(extraction_id, agency_id)
        raw_text, raw_text_pages = self._raw_text_and_pages(extraction)
        if not raw_text:
            raise ValidationError("No raw text available for this extraction")
        # Read after loading: _load_raw_text may have stored a new raw_data dict.
        raw = extraction.raw_data or {}

        bbox_payload = None
        if bbox:
            resolved = self._resolve_selection(
                self._read_pdf_bytes(extraction),
                raw_text,
                raw_text_pages,
                page=page,
                bbox=bbox,
            )
            selection_start = resolved["selection_start"]
            selection_end = resolved["selection_end"]
            page = resolved["page"]
            bbox_payload = resolved["bbox"]
        elif selection_start is None or selection_end is None:
            raise ValidationError("Provide selection offsets or a bounding box")

        # `company_detection` may be present but null, so guard before `.get`.
        company = (
            company_code
            or raw.get("company_code")
            or (raw.get("company_detection") or {}).get("code")
        )
        if not company:
            raise ValidationError("Company code is required (select parser/company first)")

        try:
            context = compute_selection_context(raw_text, selection_start, selection_end)
        except ValueError as exc:
            raise ValidationError(str(exc)) from exc

        record = ParserTrainingExample(
            extraction_id=extraction_id,
            agency_id=agency_id,
            created_by=user_id,
            company_code=company,
            field_name=field_name,
            page=page,
            value=context["value"],
            anchor_before=context["anchor_before"],
            anchor_after=context["anchor_after"],
            snippet=context["snippet"],
            selection_start=context["selection_start"],
            selection_end=context["selection_end"],
            bbox=bbox_payload,
            notes=notes,
        )
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)

        suggested = suggest_regex(
            field_name,
            context["value"],
            context["anchor_before"],
            context["anchor_after"],
        )
        return {
            **self._example_to_dict(record),
            "suggested_regex": suggested,
        }

    def resolve_bbox(self, extraction_id: int, agency_id: int, *, page: int, bbox: dict) -> dict:
        """Resolve a bounding box on ``page`` to a text selection without saving it.

        Returns the mapping produced by ``resolve_bbox_selection``.

        Raises:
            ValidationError: no raw text, no stored PDF, or an unresolvable box.
        """
        extraction = self.upload.get_extraction(extraction_id, agency_id)
        raw_text, raw_text_pages = self._raw_text_and_pages(extraction)
        if not raw_text:
            raise ValidationError("No raw text available for this extraction")
        file_bytes = self._read_pdf_bytes(extraction)
        return self._resolve_selection(file_bytes, raw_text, raw_text_pages, page=page, bbox=bbox)

    def get_pdf_page_meta(self, extraction_id: int, agency_id: int) -> list[dict]:
        """Return per-page metadata for the extraction's stored PDF.

        Raises:
            ValidationError: if the PDF is not stored.
        """
        extraction = self.upload.get_extraction(extraction_id, agency_id)
        file_bytes = self._read_pdf_bytes(extraction)
        # Resolves to the module-level helper imported from app.utils.pdf_bbox,
        # not to this method (class attributes are not in a method's scope).
        return get_pdf_page_meta(file_bytes)

    def list_training_examples(
        self,
        agency_id: int,
        *,
        extraction_id: int | None = None,
        company_code: str | None = None,
        limit: int = 100,
    ) -> list[dict]:
        """List the agency's training examples, newest first, optionally filtered."""
        query = self.db.query(ParserTrainingExample).filter(ParserTrainingExample.agency_id == agency_id)
        if extraction_id:
            query = query.filter(ParserTrainingExample.extraction_id == extraction_id)
        if company_code:
            query = query.filter(ParserTrainingExample.company_code == company_code)
        rows = query.order_by(ParserTrainingExample.created_at.desc()).limit(limit).all()
        return [self._example_to_dict(row) for row in rows]

    def suggest_pattern(
        self,
        *,
        field_name: str | None = None,
        value: str | None = None,
        anchor_before: str | None = None,
        anchor_after: str | None = None,
        training_example_id: int | None = None,
        agency_id: int | None = None,
    ) -> dict:
        """Suggest a regex for a field from a value and its surrounding anchors.

        When ``training_example_id`` is given, the field, value and anchors are
        taken from that example and override the other arguments.

        NOTE(review): with ``agency_id=None`` the example lookup is not scoped
        to an agency; confirm callers always pass it for tenant users.

        Raises:
            ValidationError: unknown field or empty value.
            NotFoundError: the training example does not exist.
        """
        if training_example_id:
            example = self._get_example(training_example_id, agency_id)
            field_name = example.field_name
            value = example.value
            anchor_before = example.anchor_before
            anchor_after = example.anchor_after

        if not field_name or field_name not in VALID_FIELD_NAMES:
            raise ValidationError(f"Unknown field: {field_name}")
        if not (value or "").strip():
            raise ValidationError("Value is required to suggest a pattern")

        regex = suggest_regex(field_name, value, anchor_before, anchor_after)
        return {
            "field_name": field_name,
            "value": value,
            "anchor_before": anchor_before,
            "anchor_after": anchor_after,
            "suggested_regex": regex,
        }

    def test_pattern(
        self,
        *,
        regex: str,
        company_code: str,
        field_name: str,
        extraction_id: int | None = None,
        agency_id: int | None = None,
    ) -> dict:
        """Run ``regex`` against the company's sample PDFs and, optionally, one extraction.

        Returns the regex/field/company, ``current_pdf`` (result on the given
        extraction, or ``None``), and the sample-run results merged in.

        Raises:
            ValidationError: empty regex, unknown field, or ``extraction_id``
                given without ``agency_id``.
        """
        if not regex.strip():
            raise ValidationError("Regex is required")
        if field_name not in VALID_FIELD_NAMES:
            raise ValidationError(f"Unknown field: {field_name}")

        current = None
        if extraction_id:
            if agency_id is None:
                raise ValidationError("Agency context required for extraction test")
            extraction = self.upload.get_extraction(extraction_id, agency_id)
            raw_text, _pages = self._raw_text_and_pages(extraction)
            current = test_pattern_on_text(regex, raw_text)

        sample_results = run_pattern_on_samples(company_code, field_name, regex)
        return {
            "field_name": field_name,
            "company_code": company_code,
            "regex": regex,
            "current_pdf": current,
            **sample_results,
        }

    def promote_pattern_preview(self, pattern_id: int, agency_id: int) -> dict:
        """Build the promotion preview for one of the agency's saved patterns.

        Raises:
            NotFoundError: the pattern does not exist for this agency.
        """
        return build_promote_preview(self._get_pattern(pattern_id, agency_id))

    def save_pattern(
        self,
        agency_id: int,
        user_id: int,
        *,
        company_code: str,
        field_name: str,
        regex: str,
        page_hint: int | None = None,
        priority: int = 100,
        source_training_example_id: int | None = None,
        notes: str | None = None,
        require_sample_pass: bool = True,
    ) -> dict:
        """Validate and store an active learned pattern.

        When ``require_sample_pass`` is true and the company has sample PDFs,
        the regex must pass the regression check on all of them.

        Raises:
            ValidationError: unknown field, empty or invalid regex, or a failed
                regression gate (with per-file details).
            NotFoundError: ``source_training_example_id`` does not exist for
                this agency.
        """
        if field_name not in VALID_FIELD_NAMES:
            raise ValidationError(f"Unknown field: {field_name}")
        if not regex.strip():
            raise ValidationError("Regex is required")

        try:
            re.compile(regex)
        except re.error as exc:
            raise ValidationError(f"Invalid regex: {exc}") from exc

        if source_training_example_id:
            self._get_example(source_training_example_id, agency_id)

        if require_sample_pass:
            regression = evaluate_regression(company_code, field_name, regex)
            if regression["total_files"] > 0 and not regression["passed"]:
                failed_preview = ", ".join(regression["failed_files"][:5])
                suffix = "…" if regression["failed_count"] > 5 else ""
                raise ValidationError(
                    "Regression gate failed: "
                    f"{regression['failed_count']}/{regression['total_files']} sample PDFs did not match. "
                    f"Failed: {failed_preview}{suffix}. "
                    "Fix the regex, re-test, or save with require_sample_pass=false.",
                    details=[
                        {
                            "field": "regex",
                            "error": "regression_failed",
                            "failed_files": regression["failed_files"],
                            "failed_count": regression["failed_count"],
                            "total_files": regression["total_files"],
                        }
                    ],
                )

        record = ParserPattern(
            agency_id=agency_id,
            created_by=user_id,
            company_code=company_code,
            field_name=field_name,
            regex=regex,
            page_hint=page_hint,
            priority=priority,
            active=True,
            source_training_example_id=source_training_example_id,
            notes=notes,
        )
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return self._pattern_to_dict(record)

    def list_patterns(
        self,
        agency_id: int,
        *,
        company_code: str | None = None,
        active_only: bool = True,
        limit: int = 100,
    ) -> list[dict]:
        """List the agency's patterns ordered by company, field, then priority."""
        query = self.db.query(ParserPattern).filter(ParserPattern.agency_id == agency_id)
        if company_code:
            query = query.filter(ParserPattern.company_code == company_code)
        if active_only:
            query = query.filter(ParserPattern.active.is_(True))
        rows = (
            query.order_by(ParserPattern.company_code, ParserPattern.field_name, ParserPattern.priority)
            .limit(limit)
            .all()
        )
        return [self._pattern_to_dict(row) for row in rows]

    def deactivate_pattern(self, pattern_id: int, agency_id: int) -> dict:
        """Mark one of the agency's patterns inactive and return it.

        Raises:
            NotFoundError: the pattern does not exist for this agency.
        """
        record = self._get_pattern(pattern_id, agency_id)
        record.active = False
        self.db.commit()
        self.db.refresh(record)
        return self._pattern_to_dict(record)

    def export_training(self, agency_id: int, company_code: str | None = None) -> dict:
        """Export up to 500 training examples and 500 patterns (active and inactive)."""
        examples = self.list_training_examples(agency_id, company_code=company_code, limit=500)
        patterns = self.list_patterns(agency_id, company_code=company_code, active_only=False, limit=500)
        return {"training_examples": examples, "patterns": patterns}

    @staticmethod
    def _resolve_training_company_code(db: Session, agency_id: int, detected_code: str | None) -> str | None:
        """Pick company code used to load training data (patterns + examples).

        Returns ``detected_code`` when the agency has training examples or
        active patterns for it. Otherwise returns the company with the most
        training examples in the agency (ties broken by company code, so the
        choice is deterministic). Otherwise returns ``detected_code`` as-is.
        """
        if detected_code:
            has_examples = (
                db.query(ParserTrainingExample.id)
                .filter(
                    ParserTrainingExample.agency_id == agency_id,
                    ParserTrainingExample.company_code == detected_code,
                )
                .first()
            )
            has_patterns = (
                db.query(ParserPattern.id)
                .filter(
                    ParserPattern.agency_id == agency_id,
                    ParserPattern.company_code == detected_code,
                    ParserPattern.active.is_(True),
                )
                .first()
            )
            if has_examples or has_patterns:
                return detected_code

        row = (
            db.query(ParserTrainingExample.company_code, func.count().label("cnt"))
            .filter(ParserTrainingExample.agency_id == agency_id)
            .group_by(ParserTrainingExample.company_code)
            .order_by(func.count().desc(), ParserTrainingExample.company_code.asc())
            .first()
        )
        if row:
            return row[0]
        return detected_code

    @staticmethod
    def _examples_by_field(
        db: Session, agency_id: int, company_code: str
    ) -> dict[str, list[ParserTrainingExample]]:
        """Group the agency's examples for ``company_code`` by field, newest first."""
        rows = (
            db.query(ParserTrainingExample)
            .filter(
                ParserTrainingExample.agency_id == agency_id,
                ParserTrainingExample.company_code == company_code,
            )
            .order_by(ParserTrainingExample.created_at.desc(), ParserTrainingExample.id.desc())
            .all()
        )
        by_field: dict[str, list[ParserTrainingExample]] = {}
        for row in rows:
            by_field.setdefault(row.field_name, []).append(row)
        return by_field

    @staticmethod
    def _extract_field_with_regex(
        *,
        regex: str,
        field_name: str,
        text: str,
        raw_text_pages: list[dict] | None,
        page_hint: int | None,
        confidence: float,
        source: str,
    ) -> ExtractedField | None:
        """Apply ``regex`` (IGNORECASE | DOTALL) and wrap the match as an ExtractedField.

        When ``page_hint`` names a page present in ``raw_text_pages``, that
        page's text is searched first. If it yields nothing, the whole
        ``text`` is searched instead. The value is capture group 1 when the
        regex has groups, else the whole match, stripped. ``mobile_number``
        values go through ``extract_mobile_from_match``.

        Returns ``None`` when the regex is invalid, nothing matches, or the
        extracted value is empty.
        """
        search_text = text
        source_page = None
        narrowed_to_page = False
        if page_hint and raw_text_pages:
            page_data = next((p for p in raw_text_pages if p.get("page") == page_hint), None)
            if page_data:
                search_text = page_data.get("text") or ""
                source_page = page_hint
                narrowed_to_page = True

        def _match_value(source_text: str) -> tuple[str | None, bool]:
            try:
                match = re.search(regex, source_text, re.IGNORECASE | re.DOTALL)
            except re.error:
                # Invalid stored regex: treat as "no match" rather than failing.
                return None, False
            if not match:
                return None, False
            if field_name == "mobile_number":
                value = extract_mobile_from_match(match)
                return value, bool(value)
            if match.lastindex and match.lastindex >= 1:
                return (match.group(1) or "").strip(), True
            return (match.group(0) or "").strip(), True

        value, matched = _match_value(search_text)
        # Retry on the full text only if the search was narrowed to a page.
        if (not matched or not value) and narrowed_to_page:
            value, matched = _match_value(text)
            if matched:
                source_page = None

        if not matched or not value:
            return None

        return ExtractedField(
            field_name=field_name,
            value=value,
            confidence=confidence,
            source_page=source_page,
            source=source,
        )

    @staticmethod
    def apply_learned_patterns(
        db: Session,
        agency_id: int,
        company_code: str,
        text: str,
        existing_fields: list[ExtractedField],
        raw_text_pages: list[dict] | None = None,
    ) -> list[ExtractedField]:
        """Overlay the agency's learned training onto a parser's extracted fields.

        Three layers are applied in order:

        1. Active saved patterns, by ascending priority then newest first. The
           first pattern that matches a field wins. Patterns derived from a
           training example (confidence 0.96) always replace the parser value.
           Other patterns (0.88) only replace it when it is empty or has lower
           confidence.
        2. For fields not filled by a pattern: regexes suggested from training
           examples (newest first, confidence 0.94), first match wins.
        3. For trained fields still unfilled: generic per-field regexes
           (confidence 0.9), only when ``_should_apply_training_fallback``
           accepts the current value.

        Failures in one pattern/example are logged and skipped. Returns
        ``existing_fields`` unchanged when no training company can be resolved.
        """
        training_company = ParserTrainingService._resolve_training_company_code(db, agency_id, company_code)
        if not training_company:
            return existing_fields

        patterns = (
            db.query(ParserPattern)
            .filter(
                ParserPattern.agency_id == agency_id,
                ParserPattern.company_code == training_company,
                ParserPattern.active.is_(True),
            )
            .order_by(ParserPattern.priority.asc(), ParserPattern.id.desc())
            .all()
        )

        merged: dict[str, ExtractedField] = {f.field_name: f for f in existing_fields}
        matched_from_pattern: set[str] = set()

        for pattern in patterns:
            # Patterns arrive in precedence order, so a field already filled by
            # an earlier pattern must not be overwritten by a later one.
            if pattern.field_name in matched_from_pattern:
                continue
            trained = pattern.source_training_example_id is not None
            try:
                extracted = ParserTrainingService._extract_field_with_regex(
                    regex=pattern.regex,
                    field_name=pattern.field_name,
                    text=text,
                    raw_text_pages=raw_text_pages,
                    page_hint=pattern.page_hint,
                    confidence=0.96 if trained else 0.88,
                    source="learned_pattern",
                )
                if not extracted:
                    continue

                existing = merged.get(pattern.field_name)
                if (
                    not trained
                    and existing
                    and (existing.value or "").strip()
                    and (existing.confidence or 0) >= extracted.confidence
                ):
                    continue

                merged[pattern.field_name] = extracted
                matched_from_pattern.add(pattern.field_name)
            except Exception:
                logger.exception(
                    "Learned pattern application failed",
                    extra={
                        "pattern_id": pattern.id,
                        "field_name": pattern.field_name,
                        "company_code": training_company,
                        "agency_id": agency_id,
                    },
                )

        examples_by_field = ParserTrainingService._examples_by_field(db, agency_id, training_company)
        matched_from_example: set[str] = set()
        examples_applied = 0
        for field_name, field_examples in examples_by_field.items():
            if field_name in matched_from_pattern:
                continue
            for example in field_examples:
                try:
                    regex = suggest_regex(
                        example.field_name,
                        example.value,
                        example.anchor_before,
                        example.anchor_after,
                    )
                    extracted = ParserTrainingService._extract_field_with_regex(
                        regex=regex,
                        field_name=field_name,
                        text=text,
                        raw_text_pages=raw_text_pages,
                        page_hint=example.page,
                        confidence=0.94,
                        source="learned_example",
                    )
                    if not extracted:
                        continue
                    merged[field_name] = extracted
                    matched_from_example.add(field_name)
                    examples_applied += 1
                    break
                except Exception:
                    logger.exception(
                        "Training example application failed",
                        extra={
                            "example_id": example.id,
                            "field_name": field_name,
                            "company_code": training_company,
                            "agency_id": agency_id,
                        },
                    )

        generic_applied = 0
        for field_name in examples_by_field:
            if field_name in matched_from_pattern or field_name in matched_from_example:
                continue
            if not ParserTrainingService._should_apply_training_fallback(merged.get(field_name)):
                continue
            for regex in generic_field_patterns(field_name):
                try:
                    extracted = ParserTrainingService._extract_field_with_regex(
                        regex=regex,
                        field_name=field_name,
                        text=text,
                        raw_text_pages=raw_text_pages,
                        page_hint=None,
                        confidence=0.9,
                        source="learned_generic",
                    )
                    if not extracted:
                        continue
                    merged[field_name] = extracted
                    generic_applied += 1
                    break
                except Exception:
                    logger.exception(
                        "Generic training fallback failed",
                        extra={
                            "field_name": field_name,
                            "company_code": training_company,
                            "agency_id": agency_id,
                        },
                    )

        if patterns or examples_applied or generic_applied:
            logger.debug(
                "Applied training: company=%s patterns=%d examples=%d generic=%d",
                training_company,
                len(matched_from_pattern),
                examples_applied,
                generic_applied,
                extra={"company_code": training_company, "agency_id": agency_id},
            )

        return list(merged.values())

    @staticmethod
    def _should_apply_training_fallback(existing: ExtractedField | None) -> bool:
        """Decide whether a generic fallback may replace the current field value.

        Returns True for a missing/blank value or a known placeholder label.
        Returns False for values already produced by learned training, or when
        confidence is >= 0.9 (or >= 0.85 with at least 3 characters). Returns
        True otherwise.
        """
        if not existing or not (existing.value or "").strip():
            return True
        if (existing.source or "") in {"learned_pattern", "learned_example", "learned_generic"}:
            return False
        normalized = str(existing.value).strip().lower()
        if normalized in _SUSPICIOUS_EXTRACTED_VALUES:
            return True
        if (existing.confidence or 0) >= 0.9:
            return False
        if (existing.confidence or 0) >= 0.85 and len(normalized) >= 3:
            return False
        return True

    def _raw_text_and_pages(self, extraction) -> tuple[str, list[dict]]:
        """Return ``(raw_text, raw_text_pages)``, loading them from the PDF if not cached.

        ``raw_text`` may be ``""`` when nothing is cached and the stored file is
        unavailable; callers decide whether that is an error.
        """
        raw_text = (extraction.raw_data or {}).get("raw_text") or self._load_raw_text(extraction)
        # Read pages after loading: _load_raw_text assigns a new raw_data dict.
        raw_text_pages = (extraction.raw_data or {}).get("raw_text_pages") or []
        return raw_text, raw_text_pages

    @staticmethod
    def _resolve_selection(
        file_bytes: bytes,
        raw_text: str,
        raw_text_pages: list[dict],
        *,
        page: int,
        bbox: dict,
    ) -> dict:
        """Resolve a bounding box to a text selection, mapping ValueError to ValidationError."""
        try:
            return resolve_bbox_selection(
                file_bytes,
                raw_text,
                raw_text_pages,
                page=page,
                bbox=bbox,
            )
        except ValueError as exc:
            raise ValidationError(str(exc)) from exc

    def _load_raw_text(self, extraction) -> str:
        """Extract text from the stored PDF, cache it on ``raw_data`` and commit.

        Returns ``""`` when there is no storage path or the file is missing.
        """
        from app.utils.file_storage import get_file_storage, read_storage_bytes

        if not extraction.storage_path:
            return ""
        if not get_file_storage().exists(extraction.storage_path):
            return ""
        text_pack = pack_raw_text_fields(read_storage_bytes(extraction.storage_path))
        # Assign a *new* dict. Mutating the loaded JSON value in place and
        # re-assigning the same object compares equal to the committed state,
        # so SQLAlchemy would emit no UPDATE (unless the column is wrapped with
        # sqlalchemy.ext.mutable) and the cache would never be persisted.
        raw = dict(extraction.raw_data or {})
        raw["raw_text"] = text_pack["raw_text"]
        raw["raw_text_pages"] = text_pack["raw_text_pages"]
        raw["page_count"] = text_pack["page_count"]
        extraction.raw_data = raw
        self.db.commit()
        return text_pack["raw_text"]

    def _read_pdf_bytes(self, extraction) -> bytes:
        """Return the stored PDF bytes, raising ValidationError if the file is unavailable."""
        from app.utils.file_storage import get_file_storage, read_storage_bytes

        if not extraction.storage_path:
            raise ValidationError("PDF file not found for this extraction")
        if not get_file_storage().exists(extraction.storage_path):
            raise ValidationError("PDF file not found on disk")
        return read_storage_bytes(extraction.storage_path)

    def _get_pattern(self, pattern_id: int, agency_id: int) -> ParserPattern:
        """Fetch one of the agency's patterns or raise NotFoundError."""
        record = (
            self.db.query(ParserPattern)
            .filter(ParserPattern.id == pattern_id, ParserPattern.agency_id == agency_id)
            .first()
        )
        if not record:
            raise NotFoundError("Pattern not found")
        return record

    def _get_example(self, example_id: int, agency_id: int | None) -> ParserTrainingExample:
        """Fetch a training example (agency-scoped unless ``agency_id`` is None) or raise NotFoundError."""
        query = self.db.query(ParserTrainingExample).filter(ParserTrainingExample.id == example_id)
        if agency_id is not None:
            query = query.filter(ParserTrainingExample.agency_id == agency_id)
        row = query.first()
        if not row:
            raise NotFoundError("Training example not found")
        return row

    @staticmethod
    def _example_to_dict(row: ParserTrainingExample) -> dict:
        """Serialize a training example; ``created_at`` as ISO-8601 or None."""
        return {
            "id": row.id,
            "extraction_id": row.extraction_id,
            "company_code": row.company_code,
            "field_name": row.field_name,
            "page": row.page,
            "value": row.value,
            "anchor_before": row.anchor_before,
            "anchor_after": row.anchor_after,
            "snippet": row.snippet,
            "selection_start": row.selection_start,
            "selection_end": row.selection_end,
            "bbox": row.bbox,
            "notes": row.notes,
            "created_at": row.created_at.isoformat() if row.created_at else None,
        }

    @staticmethod
    def _pattern_to_dict(row: ParserPattern) -> dict:
        """Serialize a pattern; ``created_at`` as ISO-8601 or None."""
        return {
            "id": row.id,
            "company_code": row.company_code,
            "field_name": row.field_name,
            "regex": row.regex,
            "page_hint": row.page_hint,
            "priority": row.priority,
            "active": row.active,
            "source_training_example_id": row.source_training_example_id,
            "notes": row.notes,
            "created_at": row.created_at.isoformat() if row.created_at else None,
        }
