import logging
import os
import sys

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from jinja2 import Environment, FileSystemLoader, StrictUndefined, TemplateNotFound
from litellm import acompletion

from schemas import (
    CreateSearchPlanRequest,
    CreateSearchPlanResponse,
    ExtractEntitiesRequest,
    ExtractEntitiesResponse,
    ExtractedRelationsResponse,
)
from utils import build_visjs_graph, fmt_prompt
|
|
# Load variables from a .env file (when one is found) into os.environ so the
# settings read further down can come from it.
load_dotenv()

# One-time root logger setup for the whole process: INFO and above, each
# record prefixed with a timestamp, the level, and the source file:line.
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
    level=logging.INFO,
)
|
|
|
|
# LLM connection settings, read from the environment (possibly populated from
# .env). Each is None when unset.
LLM_MODEL = os.environ.get('LLM_MODEL', default=None)
LLM_TOKEN = os.environ.get('LLM_TOKEN', default=None)
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', default=None)


# Every endpoint passes LLM_MODEL as the model argument of acompletion(), so
# the server cannot serve any request without it: refuse to start. LLM_TOKEN
# stays optional (presumably for key-less backends reached through
# LLM_BASE_URL -- confirm). Checking only "both missing" would let the server
# start with a token but no model and then fail on every request.
if not LLM_MODEL:
    logging.error("No LLM_MODEL was provided.")
    sys.exit(1)
|
|
# Jinja2 environment for the prompt templates in the "prompts" directory
# (a relative path, so it is resolved against the current working directory).
# StrictUndefined makes rendering fail when a template references a variable
# that was not supplied, rather than silently emitting an empty string;
# enable_async allows the templates to be rendered from async code.
prompt_env = Environment(
    enable_async=True,
    undefined=StrictUndefined,
    loader=FileSystemLoader("prompts"),
)


# ASGI application object that the HTTP handlers and the static-file mount
# are registered on.
api = FastAPI()
|
|
|
|
def _parse_llm_reply(completion, response_model):
    """Validate the first choice of an LLM completion against a Pydantic model.

    Args:
        completion: The object returned by litellm's ``acompletion()``.
        response_model: The model class (with ``model_validate_json``) that the
            reply text must conform to.

    Returns:
        An instance of ``response_model`` parsed from the reply text.

    Raises:
        HTTPException: 502 when the reply is empty or fails validation, so the
            client sees an upstream failure instead of an opaque 500.
    """
    content = completion.choices[0].message.content
    if not content:
        logging.warning("LLM returned an empty reply for %s", response_model.__name__)
        raise HTTPException(
            status_code=502,
            detail="The language model returned an empty response.",
        )
    try:
        return response_model.model_validate_json(content)
    except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
        logging.warning("LLM reply did not match %s: %s", response_model.__name__, exc)
        raise HTTPException(
            status_code=502,
            detail=f"The language model returned output that does not match {response_model.__name__}.",
        ) from exc


@api.post("/extract_entities")
async def extract_entities(body: ExtractEntitiesRequest):
    """Extract entities and the relations between them from ``body.content``.

    Two LLM passes run in sequence:
    1. "ner/extract_entities" extracts entities from the input text.
    2. "ner/extract_relations" extracts relations, given the text and those
       entities.
    Both results are passed to ``build_visjs_graph`` and its output is returned
    as-is (presumably vis.js node/edge lists -- see utils).

    Raises:
        HTTPException: 502 if either LLM reply is empty or fails schema validation.
    """
    entities_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "ner/extract_entities",
                    response_format=ExtractEntitiesResponse.model_json_schema(),
                    input_text=body.content,
                ),
            }
        ],
        response_format=ExtractEntitiesResponse,
        # Same retry budget as the relations pass, so a transient provider
        # error on the first call does not abort the whole request.
        num_retries=5,
    )
    extracted_entities = _parse_llm_reply(entities_completion, ExtractEntitiesResponse)

    relations_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "ner/extract_relations",
                    response_format=ExtractedRelationsResponse.model_json_schema(),
                    input_text=body.content,
                    entities=extracted_entities.entities,
                ),
            }
        ],
        response_format=ExtractedRelationsResponse,
        num_retries=5,
    )
    relation_model = _parse_llm_reply(relations_completion, ExtractedRelationsResponse)

    return build_visjs_graph(extracted_entities.entities, relation_model.relations)
|
|
|
|
@api.post("/create_search_plan")
async def create_search_plan(body: CreateSearchPlanRequest):
    """Ask the LLM to produce a search plan for ``body.query``.

    Renders the "search/create_search_plan" prompt with the user's query and
    the JSON schema of ``CreateSearchPlanResponse``. The model's reply is then
    validated against that schema.

    Returns:
        The validated ``CreateSearchPlanResponse``.

    Raises:
        HTTPException: 502 if the reply is empty or fails schema validation.
    """
    plan_completion = await acompletion(
        LLM_MODEL,
        api_key=LLM_TOKEN,
        base_url=LLM_BASE_URL,
        messages=[
            {
                "role": "user",
                "content": await fmt_prompt(
                    prompt_env,
                    "search/create_search_plan",
                    response_format=CreateSearchPlanResponse.model_json_schema(),
                    user_query=body.query,
                ),
            }
        ],
        response_format=CreateSearchPlanResponse,
        # Retry transient provider failures, matching the retry budget used
        # by the relation-extraction call in extract_entities.
        num_retries=5,
    )

    content = plan_completion.choices[0].message.content
    if not content:
        logging.warning("LLM returned an empty reply for CreateSearchPlanResponse")
        raise HTTPException(
            status_code=502,
            detail="The language model returned an empty response.",
        )
    try:
        return CreateSearchPlanResponse.model_validate_json(content)
    except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
        logging.warning("LLM reply did not match CreateSearchPlanResponse: %s", exc)
        raise HTTPException(
            status_code=502,
            detail="The language model returned output that does not match CreateSearchPlanResponse.",
        ) from exc
|
|
|
|
# Serve the front-end from the "static" directory at the site root; with
# html=True a directory request is answered with its index.html. Starlette
# matches routes in registration order, so mounting this catch-all "/" last
# keeps the POST endpoints registered above reachable.
api.mount(path="/", app=StaticFiles(directory="static", html=True), name="static")
|
|