| | """ |
| | Shared Hugging Face Space runtime for streaming chat inference. |
| | |
| | This module provides: |
| | - one-time global model loading |
| | - async request queue |
| | - worker pool with semaphore-based concurrency limits |
| | - per-request streamer/thread isolation |
| | - SSE streaming responses |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import asyncio |
| | import json |
| | import logging |
| | import os |
| | import time |
| | import uuid |
| | from contextlib import asynccontextmanager |
| | from dataclasses import dataclass, field |
| | from queue import Empty as QueueEmpty |
| | from threading import Event as ThreadEvent |
| | from threading import Thread |
| | from typing import Any, Dict, List, Optional |
| |
|
| | import torch |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import FileResponse, StreamingResponse |
| | from pydantic import BaseModel |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | StoppingCriteria, |
| | StoppingCriteriaList, |
| | TextIteratorStreamer, |
| | ) |
| |
|
| |
|
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    stream: bool = True
    max_tokens: int = 8192
    temperature: Optional[float] = None
    tools: Optional[List[Dict[str, Any]]] = None


@dataclass(frozen=True)
class RuntimeConfig:
    model_name: str
    title: str
    description: str
    version: str = "1.0.0"
    max_input_tokens: int = 32768
    max_new_tokens: int = 131072
    top_p: float = 0.95
    top_k: Optional[int] = None
    repetition_penalty: float = 1.0
    eos_token_id: Optional[int] = None
    default_temperature: float = 0.6
    tokenizer_use_fast: Optional[bool] = None
    logger_name: str = "hf_space"


@dataclass
class GenerationTask:
    request_id: str
    prompt: str
    max_tokens: int
    temperature: float
    output_queue: asyncio.Queue[Optional[Dict[str, Any]]]
    created_at: float = field(default_factory=time.time)
    cancel_event: ThreadEvent = field(default_factory=ThreadEvent)
    prompt_tokens: int = 0
    generated_tokens: int = 0
    first_token_latency: Optional[float] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None


class CancelAwareStoppingCriteria(StoppingCriteria):
    """Stops generation when the request is cancelled/disconnected."""

    def __init__(self, cancel_event: ThreadEvent):
        self.cancel_event = cancel_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.cancel_event.is_set()


def _is_truthy(value: str) -> bool:
    return value.strip().lower() in {"1", "true", "yes", "on"}


def _format_sse_event(payload: Dict[str, Any]) -> str:
    event_type = str(payload.get("type", "token"))
    return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"


def _read_stream_item(stream_iter) -> tuple[bool, Optional[str]]:
    """Read one item from streamer iterator without leaking StopIteration across threads."""
    try:
        return False, next(stream_iter)
    except StopIteration:
        return True, None


def _detect_concurrency(device: str) -> int:
    # An explicit HF_MAX_WORKERS override takes precedence over the hardware heuristic.
    override = os.getenv("HF_MAX_WORKERS", "").strip()
    if override:
        try:
            parsed = int(override)
            if parsed > 0:
                return parsed
        except ValueError:
            pass

    if device == "cuda" and torch.cuda.is_available():
        total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        if total_vram_gb >= 20:
            return 5
        if total_vram_gb >= 10:
            return 4
        return 3

    cpu_count = os.cpu_count() or 1
    # CPU fallback: roughly one worker per six cores, clamped to the range 1-4.
    return max(1, min(4, max(1, cpu_count // 6)))


def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
    logger = logging.getLogger(config.logger_name)
    logging.basicConfig(level=logging.INFO)

    debug_token_logs = _is_truthy(os.getenv("HF_DEBUG_TOKEN_LOGS", "0"))
    queue_max_size = int(os.getenv("HF_QUEUE_MAX_SIZE", "512"))
    streamer_timeout = float(os.getenv("HF_STREAMER_TIMEOUT_SECONDS", "8"))
    join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
    max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
    max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))

    base_dir = os.path.dirname(os.path.abspath(__file__))

    model = None
    tokenizer = None
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_workers = _detect_concurrency(device)

    request_queue: asyncio.Queue[Optional[GenerationTask]] = asyncio.Queue(maxsize=queue_max_size)
    worker_tasks: List[asyncio.Task] = []
    worker_semaphore = asyncio.Semaphore(max_workers)

    active_workers = 0
    active_workers_lock = asyncio.Lock()

    async def set_active_workers(delta: int) -> int:
        nonlocal active_workers
        async with active_workers_lock:
            active_workers += delta
            if active_workers < 0:
                active_workers = 0
            return active_workers

    def format_messages_proper(messages: List[Message], tools: Optional[List[Dict[str, Any]]] = None) -> str:
        message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages]
        if tools:
            return tokenizer.apply_chat_template(
                message_dicts,
                tools=tools,
                add_generation_prompt=True,
                tokenize=False,
            )
        return tokenizer.apply_chat_template(
            message_dicts,
            add_generation_prompt=True,
            tokenize=False,
        )

    async def run_generation(task: GenerationTask, worker_id: int) -> None:
        request_start = time.time()
        task.start_time = request_start
        await set_active_workers(+1)

        try:
            logger.info(
                "[%s] worker=%d start queue_size=%d active_workers=%d",
                task.request_id,
                worker_id,
                request_queue.qsize(),
                active_workers,
            )

            inputs = tokenizer(
                task.prompt,
                return_tensors="pt",
                truncation=True,
                max_length=max_input_tokens,
                add_special_tokens=False,
            )

            task.prompt_tokens = int(inputs.input_ids.shape[1])

            if device == "cuda":
                inputs = inputs.to("cuda")

            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
                timeout=streamer_timeout,
            )

            stopping_criteria = StoppingCriteriaList(
                [CancelAwareStoppingCriteria(task.cancel_event)]
            )

            generation_kwargs: Dict[str, Any] = dict(
                **inputs,
                streamer=streamer,
                max_new_tokens=min(task.max_tokens, max_new_tokens_limit),
                temperature=task.temperature,
                top_p=config.top_p,
                repetition_penalty=config.repetition_penalty,
                do_sample=task.temperature > 0,
                eos_token_id=config.eos_token_id if config.eos_token_id is not None else tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                stopping_criteria=stopping_criteria,
            )
            if config.top_k is not None:
                generation_kwargs["top_k"] = config.top_k

            generation_error: Dict[str, Exception] = {}
            generation_done = ThreadEvent()

            def generate_target() -> None:
                try:
                    with torch.inference_mode():
                        model.generate(**generation_kwargs)
                except Exception as exc:
                    generation_error["error"] = exc
                    logger.error("[%s] generation thread error: %s", task.request_id, exc, exc_info=True)
                finally:
                    generation_done.set()
                    try:
                        streamer.end()
                    except Exception:
                        # The streamer may already be finished or closed; ignore.
                        pass

            generation_thread = Thread(
                target=generate_target,
                name=f"gen-{task.request_id[:8]}",
                daemon=True,
            )
            generation_thread.start()

            stream_iter = iter(streamer)
            while True:
                if task.cancel_event.is_set():
                    logger.info("[%s] cancellation requested", task.request_id)
                    break

                try:
                    stream_finished, new_text = await asyncio.to_thread(_read_stream_item, stream_iter)
                    if stream_finished:
                        break
                except QueueEmpty:
                    if generation_done.is_set():
                        break
                    continue
                except Exception as exc:
                    if generation_done.is_set():
                        break
                    logger.error("[%s] streamer read error: %s", task.request_id, exc, exc_info=True)
                    generation_error["error"] = exc
                    break

                if not new_text:
                    continue

                task.generated_tokens += 1
                if task.first_token_latency is None:
                    task.first_token_latency = time.time() - request_start
                    logger.info(
                        "[%s] first_token=%.2fs worker=%d",
                        task.request_id,
                        task.first_token_latency,
                        worker_id,
                    )

                if debug_token_logs:
                    logger.info("[%s] token#%d: %r", task.request_id, task.generated_tokens, new_text)

                await task.output_queue.put({"type": "token", "content": new_text})
                await asyncio.sleep(0)

            # Wait for the generation thread to exit before emitting the terminal event.
            try:
                await asyncio.wait_for(asyncio.to_thread(generation_thread.join), timeout=join_timeout)
            except asyncio.TimeoutError:
                logger.error(
                    "[%s] generation thread still alive after %.1fs join timeout",
                    task.request_id,
                    join_timeout,
                )

            if task.cancel_event.is_set():
                await task.output_queue.put({"type": "error", "content": "Generation interrupted. You can continue."})
            elif "error" in generation_error:
                await task.output_queue.put({"type": "error", "content": str(generation_error["error"])})
            else:
                await task.output_queue.put({"type": "done", "content": ""})

        except Exception as exc:
            logger.error("[%s] worker failure: %s", task.request_id, exc, exc_info=True)
            await task.output_queue.put({"type": "error", "content": str(exc)})
        finally:
            task.end_time = time.time()
            duration = max(1e-6, task.end_time - request_start)
            tps = task.generated_tokens / duration
            logger.info(
                "[%s] worker=%d end tokens=%d duration=%.2fs tok_s=%.2f active_workers=%d queue_size=%d",
                task.request_id,
                worker_id,
                task.generated_tokens,
                duration,
                tps,
                active_workers,
                request_queue.qsize(),
            )

            await task.output_queue.put(None)
            await set_active_workers(-1)

    async def worker_loop(worker_id: int) -> None:
        logger.info("Worker-%d started", worker_id)
        while True:
            task = await request_queue.get()
            if task is None:
                request_queue.task_done()
                logger.info("Worker-%d received shutdown signal", worker_id)
                break

            try:
                if task.cancel_event.is_set():
                    await task.output_queue.put({"type": "error", "content": "Request cancelled before execution."})
                    await task.output_queue.put(None)
                    continue

                async with worker_semaphore:
                    await run_generation(task, worker_id)
            finally:
                request_queue.task_done()

        logger.info("Worker-%d stopped", worker_id)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        nonlocal model, tokenizer, worker_tasks, max_workers, device

        logger.info("Loading model %s on %s", config.model_name, device)
        tokenizer_kwargs: Dict[str, Any] = {"trust_remote_code": True}
        if config.tokenizer_use_fast is not None:
            tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
        tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
        model_load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "device_map": "auto" if device == "cuda" else None,
        }
        if device == "cuda":
            model_load_kwargs["dtype"] = "auto"
        else:
            model_load_kwargs["torch_dtype"] = torch.float32

        try:
            model = AutoModelForCausalLM.from_pretrained(
                config.model_name,
                **model_load_kwargs,
            )
        except TypeError:
            # Fall back to the legacy `torch_dtype` argument if this transformers version rejects `dtype`.
            if "dtype" in model_load_kwargs:
                model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype")
            model = AutoModelForCausalLM.from_pretrained(
                config.model_name,
                **model_load_kwargs,
            )

        if device != "cuda":
            model = model.to("cpu")

        logger.info(
            "Model loaded: %s | device=%s | max_workers=%d | queue_max_size=%d",
            config.model_name,
            device,
            max_workers,
            queue_max_size,
        )
        logger.info(
            "Runtime config: max_input_tokens=%d max_new_tokens_limit=%d top_p=%.3f top_k=%s rep_penalty=%.3f",
            max_input_tokens,
            max_new_tokens_limit,
            config.top_p,
            str(config.top_k),
            config.repetition_penalty,
        )

        worker_tasks = [
            asyncio.create_task(worker_loop(i + 1), name=f"generation-worker-{i + 1}")
            for i in range(max_workers)
        ]

        try:
            yield
        finally:
            logger.info("Shutting down workers...")
            for _ in worker_tasks:
                await request_queue.put(None)
            await asyncio.gather(*worker_tasks, return_exceptions=True)

            logger.info("Releasing model resources...")
            # Drop the references so later /health calls report model_loaded=False instead of failing.
            model = None
            tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    app = FastAPI(
        title=config.title,
        description=config.description,
        version=config.version,
        lifespan=lifespan,
    )

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/")
    async def root():
        return {
            "name": config.title,
            "version": config.version,
            "model": config.model_name,
            "status": "running",
            "device": device,
            "max_workers": max_workers,
        }

    @app.get("/index", response_class=FileResponse)
    async def serve_chat():
        return FileResponse(os.path.join(base_dir, "index.html"))

    @app.get("/health")
    async def health():
        return {
            "status": "healthy",
            "model_loaded": model is not None and tokenizer is not None,
            "device": device,
            "active_workers": active_workers,
            "queue_size": request_queue.qsize(),
            "max_workers": max_workers,
        }

    @app.post("/chat")
    async def chat(request: ChatRequest):
        if model is None or tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded yet")

        prompt = format_messages_proper(request.messages, request.tools)
        task = GenerationTask(
            request_id=uuid.uuid4().hex,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature if request.temperature is not None else config.default_temperature,
            output_queue=asyncio.Queue(maxsize=2048),
        )

        logger.info(
            "[%s] queued request prompt_len=%d queue_size=%d",
            task.request_id,
            len(prompt),
            request_queue.qsize(),
        )
        await request_queue.put(task)

        if request.stream:
            async def stream_events():
                try:
                    while True:
                        event = await task.output_queue.get()
                        if event is None:
                            break
                        yield _format_sse_event(event)
                except asyncio.CancelledError:
                    task.cancel_event.set()
                    raise
                finally:
                    task.cancel_event.set()

            return StreamingResponse(
                stream_events(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-store, must-revalidate",
                    "Pragma": "no-cache",
                    "Expires": "0",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Transfer-Encoding": "chunked",
                },
            )

        chunks: List[str] = []
        error_message: Optional[str] = None
        while True:
            event = await task.output_queue.get()
            if event is None:
                break
            event_type = event.get("type")
            if event_type == "token":
                chunks.append(str(event.get("content", "")))
            elif event_type == "error":
                error_message = str(event.get("content", "Generation failed"))

        if error_message:
            raise HTTPException(status_code=500, detail=error_message)

        response_text = "".join(chunks).strip()
        return {
            "content": response_text,
            "usage": {
                "prompt_tokens": task.prompt_tokens,
                "completion_tokens": task.generated_tokens,
            },
        }

    return app
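

# Usage sketch (not part of the runtime): a minimal example of how a Space's
# entry point might wire this module up, assuming `uvicorn` is installed and
# `index.html` sits next to this file. The model id, title, description, and
# port below are placeholders, not values taken from any existing Space.
if __name__ == "__main__":
    import uvicorn

    demo_app = create_hf_space_app(
        RuntimeConfig(
            model_name="your-org/your-model",  # placeholder model id
            title="Example Streaming Chat Space",
            description="Demo wiring for create_hf_space_app().",
        )
    )
    # Clients POST /chat with {"messages": [{"role": "user", "content": "Hi"}]}
    # and read the "event: token" SSE frames; sending "stream": false returns a
    # single JSON payload with the full completion and token usage instead.
    uvicorn.run(demo_app, host="0.0.0.0", port=7860)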