"""PDF page classifier — public factory with HuggingFace auto-download.

Standalone usage (model files downloaded from the HF repo into a local directory)::

    from classifiers import load_classifier
    clf = load_classifier(".")  # local directory with model files
    result = clf.predict("page.png")

HuggingFace usage::

    from classifiers import load_classifier
    clf = load_classifier("Wikit/pdf-pages-classifier")
    result = clf.predict("page.png")
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
# Files fetched from the Hub for each backend/precision. For both backends the
# INT8 variant is tried first and FP32 is the fallback, matching the order in
# which the classifiers look for model files.
# The model file(s) come before config.json: downloads happen in list order,
# so a missing INT8 model is detected before anything else is fetched.
_HF_CONFIG_FILE = "config.json"

_HF_ONNX_INT8_FILES = ["model_int8.onnx", _HF_CONFIG_FILE]
_HF_ONNX_FP32_FILES = ["model.onnx", _HF_CONFIG_FILE]

_HF_OV_INT8_FILES = [
    "openvino_model_int8.xml",
    "openvino_model_int8.bin",
    _HF_CONFIG_FILE,
]
_HF_OV_FP32_FILES = [
    "openvino_model.xml",
    "openvino_model.bin",
    _HF_CONFIG_FILE,
]
def _is_hf_repo_id(path: str) -> bool:
    """Return True if *path* looks like an ``owner/repo`` HuggingFace repo ID.

    Anything that exists on disk is treated as a local path. Otherwise the
    string, with backslashes normalised to ``/``, must:

    * not start like a filesystem path (``.``, ``/`` or ``~``) and not
      contain ``..``;
    * split into exactly two non-empty segments on ``/``;
    * use only ASCII letters, digits, ``-``, ``_`` and ``.`` in each segment,
      which are the characters the Hub accepts in repo IDs.

    The character check stops strings such as ``C:\\models`` (drive letter) or
    ``my models/v1`` (space) from being sent to the Hub as repo IDs when the
    local directory simply does not exist.
    """
    if os.path.exists(path):
        return False
    normalized = path.replace("\\", "/")
    # Filesystem-style prefixes and parent-directory references are never repo IDs.
    if normalized.startswith((".", "/", "~")) or ".." in normalized:
        return False
    parts = normalized.split("/")
    if len(parts) != 2:
        return False
    return all(
        part and all(ch.isascii() and (ch.isalnum() or ch in "-_.") for ch in part)
        for part in parts
    )
def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path:
    """Download specific files from a HF repo and return the local snapshot directory.

    Files are fetched one by one, in list order, with ``hf_hub_download``; the
    directory containing the last downloaded file is returned.

    Args:
        repo_id: HuggingFace repo ID (``owner/repo``).
        filenames: Files to download; must not be empty.
        cache_dir: Custom HuggingFace cache directory, or ``None`` for the default.

    Returns:
        Parent directory of the last downloaded file.

    Raises:
        ValueError: If *filenames* is empty.
        ImportError: If ``huggingface_hub`` is not installed.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if not filenames:
        raise ValueError("filenames must contain at least one file to download")
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to load from a HuggingFace repo.\n"
            "Install with: pip install huggingface-hub"
        ) from e
    local_paths = [
        Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir))
        for filename in filenames
    ]
    return local_paths[-1].parent
def _download_with_int8_fallback(
    repo_id: str,
    int8_files: list[str],
    fp32_files: list[str],
    cache_dir: str | None,
) -> Path:
    """Download a model from the Hub, preferring the INT8 variant over FP32.

    Args:
        repo_id: HuggingFace repo ID (``owner/repo``).
        int8_files: Files to fetch for the INT8 variant.
        fp32_files: Files to fetch instead if any INT8 file is missing.
        cache_dir: Custom HuggingFace cache directory, or ``None`` for the default.

    Returns:
        Local directory containing the downloaded files.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
        EntryNotFoundError: (huggingface_hub) if a file of the FP32 variant is
            missing as well.
    """
    # EntryNotFoundError is defined in ``huggingface_hub.errors`` in newer
    # releases and exposed via ``huggingface_hub.utils`` in older ones; try both
    # rather than relying on a top-level re-export, so that a missing re-export
    # is not misreported as "huggingface_hub is required".
    try:
        from huggingface_hub.errors import EntryNotFoundError
    except ImportError:
        try:
            from huggingface_hub.utils import EntryNotFoundError  # type: ignore[no-redef]
        except ImportError as e:
            raise ImportError(
                "huggingface_hub is required to load from a HuggingFace repo.\n"
                "Install with: pip install huggingface-hub"
            ) from e
    try:
        return _download_from_hf(repo_id, int8_files, cache_dir)
    except EntryNotFoundError:
        # The INT8 variant is not published in this repo: use the FP32 files.
        return _download_from_hf(repo_id, fp32_files, cache_dir)
def load_classifier(
    repo_or_dir: str = "Wikit/pdf-pages-classifier",
    backend: str = "auto",
    device: str = "CPU",
    cache_dir: str | None = None,
) -> Any:
    """Load a PDF page classifier with automatic backend selection.

    Args:
        repo_or_dir: HuggingFace repo ID (e.g. ``"Wikit/pdf-pages-classifier"``)
            or local directory containing ``config.json`` and model files.
        backend: ``"auto"`` tries OpenVINO first, falls back to ONNX.
            Pass ``"openvino"`` or ``"onnx"`` to force a specific backend.
        device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
            Ignored for ONNX.
        cache_dir: Custom cache directory for HuggingFace downloads.

    Returns:
        A classifier instance exposing ``predict(images)``.

    Raises:
        ValueError: If *backend* is not ``"auto"``, ``"onnx"`` or ``"openvino"``.
        ImportError, OSError: Propagated from the selected backend loader. In
            ``"auto"`` mode an OpenVINO failure of either kind triggers the ONNX
            fallback, and the OpenVINO error stays chained if ONNX fails too.

    Example::

        clf = load_classifier("Wikit/pdf-pages-classifier")
        result = clf.predict("page.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """
    if backend not in ("auto", "onnx", "openvino"):
        raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.")
    is_hf = _is_hf_repo_id(repo_or_dir)
    if backend == "onnx":
        return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf)
    try:
        return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=is_hf)
    except (ImportError, OSError):
        # OSError covers FileNotFoundError for missing local files. It also
        # covers Hub download errors such as EntryNotFoundError, raised when a
        # repo ships no OpenVINO model; those are not FileNotFoundError
        # subclasses, so catching FileNotFoundError alone would miss them.
        if backend == "openvino":
            raise
        # Loading ONNX inside the handler keeps the OpenVINO error as __context__
        # if ONNX fails too.
        return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf)
def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the ONNX classifier from a HuggingFace repo or a local directory.

    Args:
        repo_or_dir: HuggingFace repo ID, or path to a local model directory.
        cache_dir: Custom HuggingFace cache directory (only used when *is_hf*).
        is_hf: True when *repo_or_dir* should be downloaded from the Hub.

    Returns:
        The object returned by ``PDFPageClassifierONNX.from_pretrained``.

    Raises:
        ImportError: If the ONNX classifier module cannot be imported.
    """
    try:
        from .classifier_onnx import PDFPageClassifierONNX
    except ImportError:
        # Not imported as part of a package (standalone usage): import the
        # sibling module by its top-level name instead.
        from classifier_onnx import PDFPageClassifierONNX  # type: ignore[no-redef]
    if is_hf:
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir
        )
    else:
        # Expand a leading "~" so home-relative model directories resolve.
        model_dir = Path(repo_or_dir).expanduser()
    return PDFPageClassifierONNX.from_pretrained(str(model_dir))
def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Build the OpenVINO classifier from a HuggingFace repo or a local directory.

    Args:
        repo_or_dir: HuggingFace repo ID, or path to a local model directory.
        device: OpenVINO device string passed through to ``from_pretrained``.
        cache_dir: Custom HuggingFace cache directory (only used when *is_hf*).
        is_hf: True when *repo_or_dir* should be downloaded from the Hub.

    Returns:
        The object returned by ``PDFPageClassifierOV.from_pretrained``.

    Raises:
        ImportError: If the OpenVINO classifier module cannot be imported.
    """
    try:
        from .classifier_ov import PDFPageClassifierOV
    except ImportError:
        # Not imported as part of a package (standalone usage): import the
        # sibling module by its top-level name instead.
        from classifier_ov import PDFPageClassifierOV  # type: ignore[no-redef]
    if is_hf:
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir
        )
    else:
        # Expand a leading "~" so home-relative model directories resolve.
        model_dir = Path(repo_or_dir).expanduser()
    return PDFPageClassifierOV.from_pretrained(str(model_dir), device=device)
|