import re
import logging
import time
from typing import List, Dict, Any, Tuple

from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.pydantic_v1 import Field, BaseModel as V1BaseModel

from .config import settings
from .graph_client import neo4j_client
from .llm_interface import get_llm, invoke_llm
from .prompts import (
    CYPHER_GENERATION_PROMPT, CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT, SCORE_GRADER_PROMPT,
)

logger = logging.getLogger(__name__)


def extract_cypher(text: str) -> str:
    """Extracts the first fenced Cypher code block, or returns the trimmed text itself."""
    pattern = r"```(?:cypher)?\s*(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()
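
# Illustrative behaviour (hypothetical inputs, not part of the module):
#   extract_cypher("```cypher\nMATCH (n) RETURN n\n```")  -> "MATCH (n) RETURN n"
#   extract_cypher("MATCH (n) RETURN n")                  -> "MATCH (n) RETURN n"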


def format_doc_for_llm(doc: Dict[str, Any]) -> str:
    """Formats a document dictionary into a string for LLM context."""
    return "\n".join(f"**{key}**: {value}" for key, value in doc.items() if value)
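
# Illustrative formatting (hypothetical document; falsy values are skipped):
#   format_doc_for_llm({"title": "TS 23.501", "abstract": "", "scope": "5G system"})
#   -> "**title**: TS 23.501\n**scope**: 5G system"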


def generate_cypher_auto(question: str) -> str:
    """Generates Cypher using the 'auto' method."""
    logger.info("Generating Cypher using 'auto' method.")

    # Placeholder: the live graph schema is not fetched yet, so the prompt
    # receives a fixed notice instead of a real schema description.
    schema_info = "Schema not available."

    cypher_llm = get_llm(settings.main_llm_model)
    chain = (
        {"question": RunnablePassthrough(), "schema": lambda x: schema_info}
        | CYPHER_GENERATION_PROMPT
        | cypher_llm
        | StrOutputParser()
        | extract_cypher
    )
    return invoke_llm(chain, question)
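
# Illustrative call (assumes configured LLM credentials; output will vary):
#   generate_cypher_auto("Which specifications define handover procedures?")
#   -> e.g. "MATCH (ts:TechnicalSpecification) ... RETURN ts.title"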


def generate_cypher_guided(question: str, plan_step: int) -> str:
    """Generates Cypher using the 'guided' method based on concepts."""
    logger.info(f"Generating Cypher using 'guided' method for plan step {plan_step}.")
    try:
        concepts = neo4j_client.get_concepts()
        if not concepts:
            logger.warning("No concepts found in Neo4j for guided cypher generation.")
            return ""

        concept_llm = get_llm(settings.main_llm_model)
        concept_chain = (
            CONCEPT_SELECTION_PROMPT
            | concept_llm
            | StrOutputParser()
        )
        selected_concept = invoke_llm(concept_chain, {
            "question": question,
            "concepts": "\n".join(concepts),
        }).strip()

        logger.info(f"Concept selected by LLM: {selected_concept}")

        # Validate the LLM's choice; fall back to a case-insensitive
        # substring match before giving up.
        if selected_concept not in concepts:
            logger.warning(
                f"LLM selected concept '{selected_concept}' not in the known list. "
                "Attempting a fuzzy fallback."
            )
            found_match = None
            for c in concepts:
                if selected_concept.lower() in c.lower():
                    found_match = c
                    logger.info(f"Found potential match: '{found_match}'")
                    break
            if not found_match:
                logger.error(f"Could not validate selected concept: {selected_concept}")
                return ""
            selected_concept = found_match

        # Pick the target node label and returned fields for this plan step.
        if plan_step <= 1:
            target = "(ts:TechnicalSpecification)"
            fields = "ts.title, ts.scope, ts.description"
        elif plan_step == 2:
            target = "(rp:ResearchPaper)"
            fields = "rp.title, rp.abstract"
        else:
            target = "(n)"
            fields = "n.title, n.description"

        # The concept name is inlined because downstream retrieval takes a plain
        # query string; quotes are escaped to reduce injection risk. A
        # parameterized query ($conceptName) would be preferable if the client
        # supported passing parameters here.
        escaped_concept = selected_concept.replace("'", "\\'")
        cypher = (
            f"MATCH (c:Concept {{name: '{escaped_concept}'}})"
            f"-[:RELATED_TO]-{target} RETURN {fields}"
        )

        logger.info(f"Generated guided Cypher: {cypher}")
        return cypher

    except Exception as e:
        logger.error(f"Error during guided cypher generation: {e}", exc_info=True)
        # Back off before the caller retries; this path is most often hit when
        # the upstream LLM API is rate-limited.
        time.sleep(60)
        return ""
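
# Shape of a guided query for plan_step == 2 with the hypothetical concept
# "Network Slicing":
#   MATCH (c:Concept {name: 'Network Slicing'})-[:RELATED_TO]-(rp:ResearchPaper)
#   RETURN rp.title, rp.abstract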


def retrieve_documents(cypher_query: str) -> List[Dict[str, Any]]:
    """Retrieves documents from Neo4j using a Cypher query."""
    if not cypher_query:
        logger.warning("Received empty Cypher query, skipping retrieval.")
        return []
    # Cap the result set; this assumes the incoming query carries no LIMIT of its own.
    limited_query = cypher_query + " LIMIT 10"
    logger.info(f"Retrieving documents with Cypher: {limited_query}")
    try:
        raw_results = neo4j_client.query(limited_query)

        # De-duplicate records. Values are stringified so unhashable values
        # (e.g. lists) do not break the set-membership test.
        processed_results = []
        seen = set()
        for doc in raw_results:
            doc_key = frozenset((key, str(value)) for key, value in doc.items())
            if doc_key not in seen:
                processed_results.append(doc)
                seen.add(doc_key)
        logger.info(f"Retrieved {len(processed_results)} unique documents.")
        return processed_results
    except (ConnectionError, ValueError, RuntimeError) as e:
        logger.error(f"Document retrieval failed: {e}")
        return []
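
# Illustrative result shape (keys mirror the query's RETURN expressions):
#   retrieve_documents("MATCH ... RETURN ts.title, ts.scope")
#   -> [{"ts.title": "TS 23.501", "ts.scope": "5G system architecture"}, ...]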


class GradeDocumentsBinary(V1BaseModel):
    """Binary score for relevance check."""
    binary_score: str = Field(description="Relevant? 'yes' or 'no'")


class GradeDocumentsScore(V1BaseModel):
    """Score for relevance check."""
    rationale: str = Field(description="Rationale for the score.")
    score: float = Field(description="Relevance score (0.0 to 1.0)")
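
# The score grader is expected to emit JSON parseable into GradeDocumentsScore,
# e.g. (illustrative values):
#   {"rationale": "The abstract directly addresses the queried concept.", "score": 0.85}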


def evaluate_documents(
    docs: List[Dict[str, Any]],
    query: str,
) -> List[Dict[str, Any]]:
    """Evaluates document relevance to a query using the configured method."""
    if not docs:
        return []

    logger.info(
        f"Evaluating {len(docs)} documents for relevance to query: '{query}' "
        f"using method: {settings.eval_method}"
    )
    eval_llm = get_llm(settings.eval_llm_model)
    valid_docs_with_scores: List[Tuple[Dict[str, Any], float]] = []

    if settings.eval_method == "binary":
        binary_grader = BINARY_GRADER_PROMPT | eval_llm | StrOutputParser()
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                result = invoke_llm(binary_grader, {"question": query, "document": formatted_doc})
                logger.debug(f"Binary grader result for doc '{doc.get('title', 'N/A')}': {result}")
                if result and "yes" in result.lower():
                    valid_docs_with_scores.append((doc, 1.0))
            except Exception as e:
                logger.warning(f"Binary grading failed for a document: {e}", exc_info=True)

    elif settings.eval_method == "score":
        score_grader = SCORE_GRADER_PROMPT | eval_llm | JsonOutputParser(pydantic_object=GradeDocumentsScore)
        for doc in docs:
            formatted_doc = format_doc_for_llm(doc)
            if not formatted_doc.strip():
                continue
            try:
                # JsonOutputParser returns a plain dict, not a pydantic object;
                # validate it through the model before using its fields.
                raw = invoke_llm(score_grader, {"query": query, "document": formatted_doc})
                result = GradeDocumentsScore(**raw)
                logger.debug(
                    f"Score grader result for doc '{doc.get('title', 'N/A')}': "
                    f"Score={result.score}, Rationale={result.rationale}"
                )
                if result.score >= settings.eval_threshold:
                    valid_docs_with_scores.append((doc, result.score))
            except Exception as e:
                logger.warning(f"Score grading failed for a document: {e}", exc_info=True)
    else:
        logger.warning(f"Unknown eval_method '{settings.eval_method}'; no documents will pass.")

    # Only the 'score' method produces varying scores worth sorting on.
    if settings.eval_method == "score":
        valid_docs_with_scores.sort(key=lambda item: item[1], reverse=True)

    final_docs = [doc for doc, score in valid_docs_with_scores[:settings.max_docs]]
    logger.info(f"Found {len(final_docs)} relevant documents after evaluation and filtering.")

    return final_docs
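

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the pipeline proper. It assumes a
    # reachable Neo4j instance and configured LLM credentials in `settings`;
    # the question below is illustrative only.
    logging.basicConfig(level=logging.INFO)
    question = "Which technical specifications cover network slicing?"
    cypher = generate_cypher_guided(question, plan_step=1)
    documents = retrieve_documents(cypher)
    relevant = evaluate_documents(documents, question)
    for doc in relevant:
        print(format_doc_for_llm(doc))
        print("---")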