File size: 5,188 Bytes
bd27421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""PDF page classifier — public factory with HuggingFace auto-download.

Standalone usage (files downloaded from HF repo):
    from classifiers import load_classifier
    clf = load_classifier(".")          # local directory with model files
    result = clf.predict("page.png")

HuggingFace usage:
    from classifiers import load_classifier
    clf = load_classifier("Wikit/pdf-pages-classifier")
    result = clf.predict("page.png")
"""

from __future__ import annotations

import os
import re
from pathlib import Path
from typing import Any


# INT8 preferred over FP32 for both backends — matches classifier lookup order
_HF_ONNX_INT8_FILES = ["model_int8.onnx", "config.json"]
_HF_ONNX_FP32_FILES = ["model.onnx", "config.json"]

_HF_OV_INT8_FILES = ["openvino_model_int8.xml", "openvino_model_int8.bin", "config.json"]
_HF_OV_FP32_FILES = ["openvino_model.xml", "openvino_model.bin", "config.json"]


def _is_hf_repo_id(path: str) -> bool:
    """Return True if path looks like 'owner/repo' rather than a local path."""
    if os.path.exists(path):
        return False
    # HF repo IDs have exactly one '/' and no OS path separators or leading dots
    normalized = path.replace("\\", "/")
    if normalized.startswith((".", "/", "~")):
        return False
    parts = normalized.split("/")
    return len(parts) == 2 and all(p.strip() for p in parts)


def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path:
    """Download specific files from a HF repo and return the local snapshot directory."""
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to load from a HuggingFace repo.\n"
            "Install with: pip install huggingface-hub"
        ) from e

    last: Path | None = None
    for filename in filenames:
        last = Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir))

    assert last is not None
    return last.parent


def _download_with_int8_fallback(
    repo_id: str,
    int8_files: list[str],
    fp32_files: list[str],
    cache_dir: str | None,
) -> Path:
    """Download files from HF, preferring INT8 over FP32 when available.

    Args:
        repo_id: HuggingFace repo ID (``owner/repo``).
        int8_files: File set tried first.
        fp32_files: File set tried if any INT8 file is missing from the repo.
        cache_dir: Custom HF cache directory, or None for the default.

    Returns:
        The local directory containing the downloaded file set.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
        FileNotFoundError: If the repo has neither the INT8 nor the FP32 file set.
            Raising ``FileNotFoundError`` here, instead of letting ``EntryNotFoundError``
            escape, lets ``load_classifier``'s ``except (ImportError, FileNotFoundError)``
            fall back from OpenVINO to ONNX when a repo ships no OpenVINO files.
    """
    # Import from the submodules that define/re-export the exception. The top-level
    # package may not re-export it in every huggingface_hub release.
    try:
        from huggingface_hub.errors import EntryNotFoundError
    except ImportError:
        try:
            from huggingface_hub.utils import EntryNotFoundError  # type: ignore[no-redef]
        except ImportError as e:
            raise ImportError(
                "huggingface_hub is required to load from a HuggingFace repo.\n"
                "Install with: pip install huggingface-hub"
            ) from e

    try:
        return _download_from_hf(repo_id, int8_files, cache_dir)
    except EntryNotFoundError:
        pass  # No INT8 variant in this repo; try the FP32 files next.

    try:
        return _download_from_hf(repo_id, fp32_files, cache_dir)
    except EntryNotFoundError as err:
        raise FileNotFoundError(
            f"HuggingFace repo {repo_id!r} contains neither the INT8 files "
            f"{int8_files} nor the FP32 files {fp32_files}"
        ) from err


def load_classifier(
    repo_or_dir: str = "Wikit/pdf-pages-classifier",
    backend: str = "auto",
    device: str = "CPU",
    cache_dir: str | None = None,
) -> Any:
    """Load a PDF page classifier with automatic backend selection.

    Args:
        repo_or_dir: HuggingFace repo ID (e.g. ``"Wikit/pdf-pages-classifier"``)
            or local directory containing ``config.json`` and model files.
        backend: ``"auto"`` tries OpenVINO first, falls back to ONNX.
            Pass ``"openvino"`` or ``"onnx"`` to force a specific backend.
        device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
            Ignored for ONNX.
        cache_dir: Custom cache directory for HuggingFace downloads.

    Returns:
        A classifier instance exposing ``predict(images)``.

    Example::

        clf = load_classifier("Wikit/pdf-pages-classifier")
        result = clf.predict("page.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """
    if backend not in ("auto", "onnx", "openvino"):
        raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.")

    is_hf = _is_hf_repo_id(repo_or_dir)

    if backend in ("auto", "openvino"):
        try:
            return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=is_hf)
        except (ImportError, FileNotFoundError):
            if backend == "openvino":
                raise

    return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf)


def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the ONNX-backed classifier from a HF repo ID or a local directory.

    Args:
        repo_or_dir: HF repo ID (when *is_hf*) or a local model directory.
        cache_dir: Custom HF cache directory; only used when *is_hf*.
        is_hf: Whether *repo_or_dir* should be downloaded from HuggingFace.

    Returns:
        The result of ``PDFPageClassifierONNX.from_pretrained`` on the model directory.
    """
    try:
        from .classifier_onnx import PDFPageClassifierONNX
    except ImportError:
        # Relative import fails when this file is used standalone (outside a package);
        # fall back to a top-level import of the sibling module.
        from classifier_onnx import PDFPageClassifierONNX  # type: ignore[no-redef]

    if is_hf:
        weights_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir
        )
    else:
        weights_dir = Path(repo_or_dir)
    return PDFPageClassifierONNX.from_pretrained(str(weights_dir))


def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the OpenVINO-backed classifier from a HF repo ID or a local directory.

    Args:
        repo_or_dir: HF repo ID (when *is_hf*) or a local model directory.
        device: OpenVINO device string forwarded to ``from_pretrained``.
        cache_dir: Custom HF cache directory; only used when *is_hf*.
        is_hf: Whether *repo_or_dir* should be downloaded from HuggingFace.

    Returns:
        The result of ``PDFPageClassifierOV.from_pretrained`` on the model directory.
    """
    try:
        from .classifier_ov import PDFPageClassifierOV
    except ImportError:
        # Relative import fails when this file is used standalone (outside a package);
        # fall back to a top-level import of the sibling module.
        from classifier_ov import PDFPageClassifierOV  # type: ignore[no-redef]

    if is_hf:
        weights_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir
        )
    else:
        weights_dir = Path(repo_or_dir)
    return PDFPageClassifierOV.from_pretrained(str(weights_dir), device=device)