lfmreact / server_runtime.py
extraplus's picture
Upload server_runtime.py
0f37f6b verified
"""
Shared Hugging Face Space runtime for streaming chat inference.
This module provides:
- one-time global model loading
- async request queue
- worker pool with semaphore-based concurrency limits
- per-request streamer/thread isolation
- SSE streaming responses
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from queue import Empty as QueueEmpty
from threading import Event as ThreadEvent
from threading import Thread
from typing import Any, Dict, List, Optional
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
class Message(BaseModel):
    """A single chat turn supplied by the client: who spoke and what was said."""

    # Speaker role, copied verbatim into the chat-template input.
    role: str
    # Raw text of the turn.
    content: str
class ChatRequest(BaseModel):
    """JSON body accepted by the ``/chat`` endpoint.

    Fields:
        messages: the conversation so far, rendered with the tokenizer's chat template.
        stream: True for a Server-Sent Events response, False for a single JSON reply.
        max_tokens: requested generation budget; the server caps it further.
        temperature: sampling temperature, or None to use the server default.
        tools: optional tool schemas forwarded to the chat template.
    """

    messages: List[Message]
    stream: bool = True
    max_tokens: int = 8192
    temperature: Optional[float] = None
    tools: Optional[List[Dict[str, Any]]] = None
@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable, per-deployment settings for the Space runtime.

    Only ``model_name``, ``title`` and ``description`` are required; every
    other field has a default. ``max_input_tokens`` and ``max_new_tokens``
    act as defaults that the app factory lets environment variables override.
    """

    # Identity and API metadata.
    model_name: str
    title: str
    description: str
    version: str = field(default="1.0.0")
    # Prompt truncation length and upper bound on generated tokens.
    max_input_tokens: int = field(default=32768)
    max_new_tokens: int = field(default=131072)
    # Sampling parameters.
    top_p: float = field(default=0.95)
    top_k: Optional[int] = field(default=None)  # None: top_k is not passed to generate()
    repetition_penalty: float = field(default=1.0)
    eos_token_id: Optional[int] = field(default=None)  # None: fall back to tokenizer.eos_token_id
    default_temperature: float = field(default=0.6)
    # Loading and logging.
    tokenizer_use_fast: Optional[bool] = field(default=None)  # None: use_fast is not forced
    logger_name: str = field(default="hf_space")
@dataclass
class GenerationTask:
    """One queued chat request together with the bookkeeping its worker records.

    A worker takes it from the request queue, pushes event dicts (and finally
    ``None``) onto ``output_queue``, and fills in the token counts and timings.
    """

    # Request inputs.
    request_id: str
    prompt: str
    max_tokens: int
    temperature: float
    output_queue: asyncio.Queue[Optional[Dict[str, Any]]]
    created_at: float = field(default_factory=time.time)
    # Thread-safe flag used to ask the worker / generation thread to stop.
    cancel_event: ThreadEvent = field(default_factory=ThreadEvent)
    # Usage counters written by the worker.
    prompt_tokens: int = field(default=0)
    generated_tokens: int = field(default=0)
    # Timings written by the worker (time.time() values; latency in seconds).
    first_token_latency: Optional[float] = field(default=None)
    start_time: Optional[float] = field(default=None)
    end_time: Optional[float] = field(default=None)
class CancelAwareStoppingCriteria(StoppingCriteria):
    """Stopping criterion that reports "stop" as soon as a cancel event is set.

    The event is owned by the caller (for example the request task), so any
    code holding it can halt ``generate`` the next time the criteria run.
    """

    def __init__(self, cancel_event: ThreadEvent):
        self.cancel_event = cancel_event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # The generated ids/scores are deliberately ignored; only the flag matters.
        stop_requested = self.cancel_event.is_set()
        return stop_requested
def _is_truthy(value: str) -> bool:
return value.strip().lower() in {"1", "true", "yes", "on"}
def _format_sse_event(payload: Dict[str, Any]) -> str:
event_type = str(payload.get("type", "token"))
return f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
def _read_stream_item(stream_iter) -> tuple[bool, Optional[str]]:
"""Read one item from streamer iterator without leaking StopIteration across threads."""
try:
return False, next(stream_iter)
except StopIteration:
return True, None
def _detect_concurrency(device: str) -> int:
# Allow environment override if needed for debugging/tuning.
override = os.getenv("HF_MAX_WORKERS", "").strip()
if override:
try:
parsed = int(override)
if parsed > 0:
return parsed
except ValueError:
pass
if device == "cuda" and torch.cuda.is_available():
total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
if total_vram_gb >= 20:
return 5
if total_vram_gb >= 10:
return 4
return 3
cpu_count = os.cpu_count() or 1
# Conservative CPU default for large models; still within 1..4 range.
return max(1, min(4, max(1, cpu_count // 6)))
def create_hf_space_app(config: RuntimeConfig) -> FastAPI:
    """Build the FastAPI app that serves ``config.model_name`` with queued, streamed generation.

    The model and tokenizer are loaded once in the app lifespan. Each ``/chat``
    request becomes a ``GenerationTask`` on a bounded request queue, which is
    consumed by ``max_workers`` worker coroutines. Each task runs
    ``model.generate`` on its own thread with its own ``TextIteratorStreamer``.
    Events flow back through the task's ``output_queue``, which always ends
    with a ``None`` sentinel.

    Tunables read from the environment: ``HF_DEBUG_TOKEN_LOGS``,
    ``HF_QUEUE_MAX_SIZE``, ``HF_STREAMER_TIMEOUT_SECONDS``,
    ``HF_GENERATION_JOIN_TIMEOUT_SECONDS``, ``HF_CONSUMER_STALL_TIMEOUT_SECONDS``,
    ``HF_MAX_INPUT_TOKENS``, ``HF_MAX_NEW_TOKENS`` and ``HF_MAX_WORKERS``
    (through ``_detect_concurrency``).
    """
    logger = logging.getLogger(config.logger_name)
    logging.basicConfig(level=logging.INFO)
    debug_token_logs = _is_truthy(os.getenv("HF_DEBUG_TOKEN_LOGS", "0"))
    queue_max_size = int(os.getenv("HF_QUEUE_MAX_SIZE", "512"))
    streamer_timeout = float(os.getenv("HF_STREAMER_TIMEOUT_SECONDS", "8"))
    join_timeout = float(os.getenv("HF_GENERATION_JOIN_TIMEOUT_SECONDS", "180"))
    # How long a worker waits on a full per-request output queue before it
    # treats the reader as gone and cancels the request. Without this bound a
    # worker could block forever on an abandoned stream.
    consumer_stall_timeout = float(os.getenv("HF_CONSUMER_STALL_TIMEOUT_SECONDS", "120"))
    max_input_tokens = int(os.getenv("HF_MAX_INPUT_TOKENS", str(config.max_input_tokens)))
    max_new_tokens_limit = int(os.getenv("HF_MAX_NEW_TOKENS", str(config.max_new_tokens)))
    base_dir = os.path.dirname(os.path.abspath(__file__))

    # Assigned in `lifespan`; None before startup finishes and after shutdown.
    model = None
    tokenizer = None
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_workers = _detect_concurrency(device)
    request_queue: asyncio.Queue[Optional[GenerationTask]] = asyncio.Queue(maxsize=queue_max_size)
    worker_tasks: List[asyncio.Task] = []
    worker_semaphore = asyncio.Semaphore(max_workers)
    active_workers = 0
    active_workers_lock = asyncio.Lock()

    async def set_active_workers(delta: int) -> int:
        """Adjust the active-worker gauge by *delta* (floored at 0) and return the new value."""
        nonlocal active_workers
        async with active_workers_lock:
            active_workers += delta
            if active_workers < 0:
                active_workers = 0
            return active_workers

    async def emit(task: GenerationTask, event: Optional[Dict[str, Any]]) -> None:
        """Deliver *event* (or the ``None`` end sentinel) to the task's reader without blocking forever.

        While the request is live, this waits for queue space for at most
        ``consumer_stall_timeout`` seconds. After that the reader is presumed
        gone and the request is cancelled. Once cancelled, events are enqueued
        only if room exists; the end sentinel evicts the oldest buffered event
        when needed, so a late reader still terminates.
        """
        if not task.cancel_event.is_set():
            try:
                await asyncio.wait_for(task.output_queue.put(event), timeout=consumer_stall_timeout)
                return
            except asyncio.TimeoutError:
                logger.warning(
                    "[%s] output consumer stalled for %.1fs; cancelling request",
                    task.request_id,
                    consumer_stall_timeout,
                )
                task.cancel_event.set()
        if event is None and task.output_queue.full():
            try:
                task.output_queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
        try:
            task.output_queue.put_nowait(event)
        except asyncio.QueueFull:
            # Nobody is reading; dropping the event is the only non-blocking option.
            pass

    def format_messages_proper(messages: List[Message], tools: Optional[List[Dict[str, Any]]] = None) -> str:
        """Render *messages* (and optional *tools*) to a prompt string using the tokenizer's chat template."""
        message_dicts = [{"role": msg.role, "content": msg.content} for msg in messages]
        template_kwargs: Dict[str, Any] = {"add_generation_prompt": True, "tokenize": False}
        if tools:
            template_kwargs["tools"] = tools
        return tokenizer.apply_chat_template(message_dicts, **template_kwargs)

    async def run_generation(task: GenerationTask, worker_id: int) -> None:
        """Generate a reply for *task*, streaming text chunks into ``task.output_queue``.

        ``model.generate`` runs on a daemon thread that feeds a
        ``TextIteratorStreamer``. This coroutine drains the streamer through
        ``asyncio.to_thread`` so the event loop is never blocked. It emits
        "token" events, then one terminal "error" or "done" event, then the
        ``None`` sentinel. Errors are reported as events, not raised.
        """
        request_start = time.time()
        task.start_time = request_start
        await set_active_workers(+1)
        try:
            logger.info(
                "[%s] worker=%d start queue_size=%d active_workers=%d",
                task.request_id,
                worker_id,
                request_queue.qsize(),
                active_workers,
            )
            inputs = tokenizer(
                task.prompt,
                return_tensors="pt",
                truncation=True,
                max_length=max_input_tokens,
                add_special_tokens=False,
            )
            task.prompt_tokens = int(inputs.input_ids.shape[1])
            prompt_len = task.prompt_tokens
            if device == "cuda":
                inputs = inputs.to("cuda")
            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
                timeout=streamer_timeout,
            )
            stopping_criteria = StoppingCriteriaList(
                [CancelAwareStoppingCriteria(task.cancel_event)]
            )
            generation_kwargs: Dict[str, Any] = dict(
                **inputs,
                streamer=streamer,
                max_new_tokens=min(task.max_tokens, max_new_tokens_limit),
                temperature=task.temperature,
                top_p=config.top_p,
                repetition_penalty=config.repetition_penalty,
                do_sample=task.temperature > 0,
                eos_token_id=config.eos_token_id if config.eos_token_id is not None else tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                stopping_criteria=stopping_criteria,
            )
            if config.top_k is not None:
                generation_kwargs["top_k"] = config.top_k

            generation_error: Dict[str, Exception] = {}
            generation_result: Dict[str, int] = {}
            generation_done = ThreadEvent()

            def generate_target() -> None:
                try:
                    with torch.inference_mode():
                        output = model.generate(**generation_kwargs)
                    # generate() may return a tensor or an object with `.sequences`;
                    # for a single decoder-only prompt, row length minus prompt
                    # length is the exact number of new tokens.
                    sequences = getattr(output, "sequences", output)
                    if isinstance(sequences, torch.Tensor) and sequences.dim() == 2:
                        generation_result["new_tokens"] = max(0, int(sequences.shape[-1]) - prompt_len)
                except Exception as exc:  # pragma: no cover - defensive logging
                    generation_error["error"] = exc
                    logger.error("[%s] generation thread error: %s", task.request_id, exc, exc_info=True)
                finally:
                    # Close the streamer BEFORE flagging completion, so a reader that
                    # sees generation_done can never miss the final flushed text.
                    try:
                        streamer.end()
                    except Exception:
                        # Best-effort close of streamer queue.
                        pass
                    generation_done.set()

            generation_thread = Thread(
                target=generate_target,
                name=f"gen-{task.request_id[:8]}",
                daemon=True,
            )
            generation_thread.start()

            stream_iter = iter(streamer)
            while True:
                if task.cancel_event.is_set():
                    logger.info("[%s] cancellation requested", task.request_id)
                    break
                try:
                    stream_finished, new_text = await asyncio.to_thread(_read_stream_item, stream_iter)
                except QueueEmpty:
                    # No text within streamer_timeout: keep waiting unless generation ended.
                    if generation_done.is_set():
                        break
                    continue
                except Exception as exc:  # pragma: no cover - defensive logging
                    if generation_done.is_set():
                        break
                    logger.error("[%s] streamer read error: %s", task.request_id, exc, exc_info=True)
                    generation_error["error"] = exc
                    # Nobody will read the thread's output any more; make it stop.
                    task.cancel_event.set()
                    break
                if stream_finished:
                    break
                if not new_text:
                    continue
                # Counts streamed text chunks; replaced by the exact token count below when available.
                task.generated_tokens += 1
                if task.first_token_latency is None:
                    task.first_token_latency = time.time() - request_start
                    logger.info(
                        "[%s] first_token=%.2fs worker=%d",
                        task.request_id,
                        task.first_token_latency,
                        worker_id,
                    )
                if debug_token_logs:
                    logger.info("[%s] token#%d: %r", task.request_id, task.generated_tokens, new_text)
                await emit(task, {"type": "token", "content": new_text})
                await asyncio.sleep(0)

            # Ensure the generation thread is not left running in the background.
            # join(timeout) runs in the executor and always returns, unlike
            # wait_for(to_thread(join)), which leaves the executor thread blocked.
            await asyncio.to_thread(generation_thread.join, join_timeout)
            if generation_thread.is_alive():
                logger.error(
                    "[%s] generation thread still alive after %.1fs join timeout",
                    task.request_id,
                    join_timeout,
                )
            if "new_tokens" in generation_result:
                task.generated_tokens = generation_result["new_tokens"]

            # Report a real error in preference to "interrupted": the streamer-error
            # path above sets cancel_event only to stop the generation thread.
            if "error" in generation_error:
                await emit(task, {"type": "error", "content": str(generation_error["error"])})
            elif task.cancel_event.is_set():
                await emit(task, {"type": "error", "content": "Generation interrupted. You can continue."})
            else:
                await emit(task, {"type": "done", "content": ""})
        except Exception as exc:
            logger.error("[%s] worker failure: %s", task.request_id, exc, exc_info=True)
            await emit(task, {"type": "error", "content": str(exc)})
        finally:
            task.end_time = time.time()
            duration = max(1e-6, task.end_time - request_start)
            tps = task.generated_tokens / duration
            logger.info(
                "[%s] worker=%d end tokens=%d duration=%.2fs tok_s=%.2f active_workers=%d queue_size=%d",
                task.request_id,
                worker_id,
                task.generated_tokens,
                duration,
                tps,
                active_workers,
                request_queue.qsize(),
            )
            try:
                await emit(task, None)
            finally:
                await set_active_workers(-1)

    async def worker_loop(worker_id: int) -> None:
        """Consume tasks from the request queue until a ``None`` shutdown sentinel arrives."""
        logger.info("Worker-%d started", worker_id)
        while True:
            task = await request_queue.get()
            if task is None:
                request_queue.task_done()
                logger.info("Worker-%d received shutdown signal", worker_id)
                break
            try:
                if task.cancel_event.is_set():
                    await emit(task, {"type": "error", "content": "Request cancelled before execution."})
                    await emit(task, None)
                    continue
                async with worker_semaphore:
                    await run_generation(task, worker_id)
            except Exception as exc:
                # Keep the worker alive: an escaped exception would silently shrink the pool.
                logger.error("Worker-%d task %s failed: %s", worker_id, task.request_id, exc, exc_info=True)
            finally:
                request_queue.task_done()
        logger.info("Worker-%d stopped", worker_id)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Load the tokenizer and model, run the worker pool, and tear both down on shutdown."""
        nonlocal model, tokenizer, worker_tasks
        logger.info("Loading model %s on %s", config.model_name, device)
        tokenizer_kwargs: Dict[str, Any] = {"trust_remote_code": True}
        if config.tokenizer_use_fast is not None:
            tokenizer_kwargs["use_fast"] = config.tokenizer_use_fast
        tokenizer = AutoTokenizer.from_pretrained(config.model_name, **tokenizer_kwargs)
        model_load_kwargs: Dict[str, Any] = {
            "trust_remote_code": True,
            "device_map": "auto" if device == "cuda" else None,
        }
        if device == "cuda":
            model_load_kwargs["dtype"] = "auto"
        else:
            model_load_kwargs["torch_dtype"] = torch.float32
        try:
            model = AutoModelForCausalLM.from_pretrained(
                config.model_name,
                **model_load_kwargs,
            )
        except TypeError:
            # Backward compatibility for older transformers that do not accept `dtype`.
            # Without a `dtype` to translate, a retry would fail identically.
            if "dtype" not in model_load_kwargs:
                raise
            model_load_kwargs["torch_dtype"] = model_load_kwargs.pop("dtype")
            model = AutoModelForCausalLM.from_pretrained(
                config.model_name,
                **model_load_kwargs,
            )
        if device != "cuda":
            model = model.to("cpu")
        logger.info(
            "Model loaded: %s | device=%s | max_workers=%d | queue_max_size=%d",
            config.model_name,
            device,
            max_workers,
            queue_max_size,
        )
        logger.info(
            "Runtime config: max_input_tokens=%d max_new_tokens_limit=%d top_p=%.3f top_k=%s rep_penalty=%.3f",
            max_input_tokens,
            max_new_tokens_limit,
            config.top_p,
            str(config.top_k),
            config.repetition_penalty,
        )
        worker_tasks = [
            asyncio.create_task(worker_loop(i + 1), name=f"generation-worker-{i + 1}")
            for i in range(max_workers)
        ]
        try:
            yield
        finally:
            logger.info("Shutting down workers...")
            for _ in worker_tasks:
                await request_queue.put(None)
            await asyncio.gather(*worker_tasks, return_exceptions=True)
            logger.info("Releasing model resources...")
            # Rebind rather than `del`: deleting these closure variables would make
            # any later read (e.g. /health, /chat) raise NameError instead of seeing None.
            model = None
            tokenizer = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    app = FastAPI(
        title=config.title,
        description=config.description,
        version=config.version,
        lifespan=lifespan,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/")
    async def root():
        return {
            "name": config.title,
            "version": config.version,
            "model": config.model_name,
            "status": "running",
            "device": device,
            "max_workers": max_workers,
        }

    @app.get("/index", response_class=FileResponse)
    async def serve_chat():
        return FileResponse(os.path.join(base_dir, "index.html"))

    @app.get("/health")
    async def health():
        return {
            "status": "healthy",
            "model_loaded": model is not None and tokenizer is not None,
            "device": device,
            "active_workers": active_workers,
            "queue_size": request_queue.qsize(),
            "max_workers": max_workers,
        }

    @app.post("/chat")
    async def chat(request: ChatRequest):
        """Queue a chat request and answer with an SSE stream or a single JSON body."""
        if model is None or tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded yet")
        if request.max_tokens < 1:
            # A non-positive budget can never produce output; reject it up front.
            raise HTTPException(status_code=400, detail="max_tokens must be at least 1")
        prompt = format_messages_proper(request.messages, request.tools)
        task = GenerationTask(
            request_id=uuid.uuid4().hex,
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature if request.temperature is not None else config.default_temperature,
            output_queue=asyncio.Queue(maxsize=2048),
        )
        logger.info(
            "[%s] queued request prompt_len=%d queue_size=%d",
            task.request_id,
            len(prompt),
            request_queue.qsize(),
        )
        await request_queue.put(task)

        if request.stream:
            async def stream_events():
                try:
                    while True:
                        event = await task.output_queue.get()
                        if event is None:
                            break
                        yield _format_sse_event(event)
                finally:
                    # Runs on normal completion, on client disconnect (CancelledError)
                    # and on generator close: tell the worker to stop generating.
                    task.cancel_event.set()

            return StreamingResponse(
                stream_events(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-store, must-revalidate",
                    "Pragma": "no-cache",
                    "Expires": "0",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                    "Transfer-Encoding": "chunked",
                },
            )

        chunks: List[str] = []
        error_message: Optional[str] = None
        try:
            while True:
                event = await task.output_queue.get()
                if event is None:
                    break
                event_type = event.get("type")
                if event_type == "token":
                    chunks.append(str(event.get("content", "")))
                elif event_type == "error":
                    error_message = str(event.get("content", "Generation failed"))
        finally:
            # If this handler is abandoned mid-wait, stop the generation as well.
            task.cancel_event.set()
        if error_message:
            raise HTTPException(status_code=500, detail=error_message)
        response_text = "".join(chunks).strip()
        return {
            "content": response_text,
            "usage": {
                "prompt_tokens": task.prompt_tokens,
                "completion_tokens": task.generated_tokens,
            },
        }

    return app