"""Extract and pack PDF text page-by-page."""

from __future__ import annotations

import io
from typing import Any

import pdfplumber

from app.parsers.text_normalize import normalize_pdf_text

# Separator placed between page texts when joining (pages_to_raw_text) and
# used to split them apart again (raw_text_pages_from_joined).
# Must stay in sync with frontend train-mode offset calculation.
PAGE_TEXT_SEPARATOR = "\n\n"
# Default upper bound on pages read by extract_text_pages.
MAX_PAGES = 30
# Cap on the joined raw_text returned by pack_raw_text_fields.
MAX_RAW_TEXT_CHARS = 20_000
# Cap on each individual page's text in extract_text_pages.
MAX_PAGE_TEXT_CHARS = 8_000


def extract_text_pages(file_bytes: bytes, *, max_pages: int = MAX_PAGES) -> list[dict[str, Any]]:
    """Extract normalized text from the first ``max_pages`` pages of a PDF.

    Each page's text (an empty string when pdfplumber finds none) is run
    through ``normalize_pdf_text`` and then cut to at most
    ``MAX_PAGE_TEXT_CHARS`` characters. Pages without text are kept, not
    dropped.

    Args:
        file_bytes: Raw PDF bytes.
        max_pages: Keyword-only; maximum number of leading pages to read.

    Returns:
        One ``{"page": <1-based page number>, "text": <str>}`` dict per page,
        in document order.
    """
    with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
        # The list is fully built before the context manager closes the PDF.
        return [
            {
                "page": page_number,
                "text": normalize_pdf_text(page.extract_text() or "")[:MAX_PAGE_TEXT_CHARS],
            }
            for page_number, page in enumerate(pdf.pages[:max_pages], start=1)
        ]


def pages_to_raw_text(pages: list[dict[str, Any]]) -> str:
    """Join the page texts into one string separated by ``PAGE_TEXT_SEPARATOR``.

    Pages whose ``"text"`` is missing or empty are skipped entirely. They
    contribute neither text nor a separator.
    """
    non_empty_texts = [page["text"] for page in pages if page.get("text")]
    return PAGE_TEXT_SEPARATOR.join(non_empty_texts)


def pack_raw_text_fields(file_bytes: bytes, *, allow_ocr: bool = True) -> dict[str, Any]:
    """Return raw_text + raw_text_pages for storage in extraction raw_data.

    Native text is extracted first with ``extract_text_pages``. When
    ``allow_ocr`` is true and ``is_likely_scanned`` flags the native pages,
    OCR is run. Its pages replace the native ones only if they yield strictly
    more characters.

    Args:
        file_bytes: Raw PDF bytes.
        allow_ocr: Keyword-only; pass False to skip the OCR fallback.

    Returns:
        A dict with keys:
          - ``raw_text``: the chosen pages joined by ``pages_to_raw_text``,
            truncated to ``MAX_RAW_TEXT_CHARS``.
          - ``raw_text_pages``: the page dicts actually used (these are not
            truncated to ``MAX_RAW_TEXT_CHARS``).
          - ``page_count``: number of entries in ``raw_text_pages``.
          - ``text_source``: ``"native"`` or ``"ocr"``.
          - ``ocr_engine``: ``ocr_engine_label(...)`` of the engine when OCR
            text was used, otherwise None.
          - ``is_scanned``: True when OCR text was used or the native pages
            look scanned.
    """
    from app.utils.pdf_ocr import is_likely_scanned, ocr_engine_label, ocr_text_pages

    pages = extract_text_pages(file_bytes)
    text_source = "native"
    ocr_engine = None
    # Evaluated once on the native pages. Both the OCR decision and the
    # final is_scanned flag depend on this value.
    native_looks_scanned = is_likely_scanned(pages)

    if allow_ocr and native_looks_scanned:
        ocr_pages, raw_engine = ocr_text_pages(file_bytes)
        native_chars = sum(len(p.get("text") or "") for p in pages)
        ocr_chars = sum(len(p.get("text") or "") for p in ocr_pages)
        # Prefer OCR only when it recovers more text than the native layer.
        if ocr_pages and ocr_chars > native_chars:
            pages = ocr_pages
            text_source = "ocr"
            # Report an engine only when its output is what gets stored,
            # and always in labelled form, so it agrees with text_source.
            ocr_engine = ocr_engine_label(raw_engine)
        # NOTE(review): OCR pages are not re-capped with MAX_PAGES /
        # MAX_PAGE_TEXT_CHARS here; presumably ocr_text_pages applies its
        # own limits — confirm.

    raw_text = pages_to_raw_text(pages)[:MAX_RAW_TEXT_CHARS]
    return {
        "raw_text": raw_text,
        "raw_text_pages": pages,
        "page_count": len(pages),
        "text_source": text_source,
        "ocr_engine": ocr_engine,
        "is_scanned": text_source == "ocr" or native_looks_scanned,
    }


def raw_text_pages_from_joined(raw_text: str) -> list[dict[str, Any]]:
    """Rebuild approximate page dicts from a combined ``raw_text`` string.

    This is the fallback for legacy extractions that stored only the joined
    text. The string is split on ``PAGE_TEXT_SEPARATOR``. If it contains no
    separator, the whole text becomes page 1. Otherwise each non-empty chunk
    becomes a page.

    Two caveats follow from the code:

    - Page numbers count empty chunks too, so the numbering can have gaps.
    - A page whose own text contained the separator is split into several
      entries.

    Returns:
        A list of ``{"page": <int>, "text": <str>}`` dicts, or ``[]`` for an
        empty input.
    """
    if not raw_text:
        return []
    chunks = raw_text.split(PAGE_TEXT_SEPARATOR)
    if len(chunks) < 2:
        return [{"page": 1, "text": raw_text}]
    rebuilt: list[dict[str, Any]] = []
    for page_number, chunk in enumerate(chunks, start=1):
        if chunk:
            rebuilt.append({"page": page_number, "text": chunk})
    return rebuilt
