| | import logging |
| | import re |
| | import time |
| | from typing import List, Dict, Any, Optional |
| | from langgraph.graph import StateGraph, END |
| | from langgraph.checkpoint.memory import MemorySaver |
| |
|
| | from pydantic import BaseModel, Field |
| |
|
| | from langchain_core.messages import HumanMessage, AIMessage, SystemMessage |
| | from langchain_core.output_parsers import StrOutputParser, JsonOutputParser |
| |
|
| | from .config import settings |
| | from .schemas import PlannerState, KeyIssue, GraphConfig |
| | from .prompts import get_initial_planner_prompt, KEY_ISSUE_STRUCTURING_PROMPT |
| | from .llm_interface import get_llm, invoke_llm |
| | from .graph_operations import ( |
| | generate_cypher_auto, generate_cypher_guided, |
| | retrieve_documents, evaluate_documents |
| | ) |
| | from .processing import process_documents |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| |
|
def start_planning(state: PlannerState) -> Dict[str, Any]:
    """Generate the initial multi-step plan from the user query.

    Runs the planner prompt through the main LLM, then extracts the
    numbered steps found between "Plan:" and "<END_OF_PLAN>" in the reply.

    Returns a partial state update:
      - on success: ``plan`` (list of step strings), ``current_plan_step_index``
        reset to 0, the raw LLM reply in ``messages``, and an empty
        ``step_outputs`` map;
      - on failure: an ``error`` entry describing what went wrong.
    """
    logger.info("Node: start_planning")
    user_query = state['user_query']
    if not user_query:
        return {"error": "User query is empty."}

    initial_prompt = get_initial_planner_prompt(settings.plan_method, user_query)
    llm = get_llm(settings.main_llm_model)
    chain = initial_prompt | llm | StrOutputParser()

    try:
        plan_text = invoke_llm(chain, {})
        logger.debug(f"Raw plan text: {plan_text}")

        plan_match = re.search(r"Plan:(.*?)<END_OF_PLAN>", plan_text, re.DOTALL | re.IGNORECASE)
        if plan_match:
            # Split on numbered markers ("1.", "2.", ...). Anchoring on the
            # start of the captured text as well as on newlines keeps the
            # first step clean even when the plan begins on the same line as
            # "Plan:" (the old pattern required a preceding newline and left
            # the "1. " prefix on the first step in that case).
            plan_steps = [
                step.strip()
                for step in re.split(r"(?:^|\n)\s*\d+\.\s*", plan_match.group(1))
                if step.strip()
            ]
            logger.info(f"Extracted plan: {plan_steps}")
            return {
                "plan": plan_steps,
                "current_plan_step_index": 0,
                "messages": [AIMessage(content=plan_text)],
                "step_outputs": {}
            }
        else:
            logger.error("Could not parse plan from LLM response.")
            return {"error": "Failed to parse plan from LLM response.", "messages": [AIMessage(content=plan_text)]}
    except Exception as e:
        logger.error(f"Error during plan generation: {e}", exc_info=True)
        return {"error": f"LLM error during plan generation: {e}"}
| |
|
| |
|
def execute_plan_step(state: PlannerState) -> Dict[str, Any]:
    """Execute the current plan step: retrieve, evaluate and process documents.

    Builds a step-focused retrieval query, generates a Cypher query with the
    configured method ('auto' or 'guided'), retrieves and grades documents,
    processes them, and stores the resulting context under the step index.

    Returns a partial state update advancing ``current_plan_step_index`` by
    one and carrying the accumulated ``step_outputs``.
    """
    current_index = state['current_plan_step_index']
    plan = state['plan']
    user_query = state['user_query']

    if current_index >= len(plan):
        # Should not normally happen: the conditional edge routes to the
        # finalize node before the index runs past the plan.
        logger.warning("Plan step index out of bounds, attempting to finalize.")
        return {"error": "Plan execution finished unexpectedly."}

    step_description = plan[current_index]
    logger.info(f"Node: execute_plan_step - Step {current_index + 1}/{len(plan)}: {step_description}")

    # Focus retrieval on the current step while preserving the original intent.
    query_for_retrieval = f"Regarding the query '{user_query}', focus on: {step_description}"
    logger.info(f"Query for retrieval: {query_for_retrieval}")

    # Any other cypher_gen_method value falls through with an empty query,
    # matching the original behavior.
    cypher_query = ""
    if settings.cypher_gen_method == 'auto':
        cypher_query = generate_cypher_auto(query_for_retrieval)
    elif settings.cypher_gen_method == 'guided':
        cypher_query = generate_cypher_guided(query_for_retrieval, current_index)

    retrieved_docs = retrieve_documents(cypher_query)
    evaluated_docs = evaluate_documents(retrieved_docs, query_for_retrieval)
    processed_docs_content = process_documents(evaluated_docs, settings.process_steps)

    step_output = "\n\n".join(processed_docs_content) if processed_docs_content else "No relevant information found for this step."
    # Copy before mutating: the previous code keyed into the dict held in
    # graph state in place, which can confuse LangGraph's state/checkpoint
    # diffing. Returning a fresh dict keeps the update explicit.
    current_step_outputs = dict(state.get('step_outputs') or {})
    current_step_outputs[current_index] = step_output

    logger.info(f"Finished executing plan step {current_index + 1}. Stored output.")

    return {
        "current_plan_step_index": current_index + 1,
        "messages": [SystemMessage(content=f"Completed plan step {current_index + 1}. Context gathered:\n{step_output[:500]}...")],
        "step_outputs": current_step_outputs
    }
| |
|
class KeyIssue(BaseModel):
    """Minimal key-issue shape used for the JSON output parser's schema.

    NOTE(review): this local model shadows the ``KeyIssue`` imported from
    ``.schemas`` at the top of the file — confirm the shadowing is intentional.
    """

    id: int  # numeric identifier of the issue
    description: str  # free-text description of the issue
| |
|
class KeyIssueList(BaseModel):
    """Wrapper model: the JSON parser's format instructions ask the LLM for
    an object of the form ``{"key_issues": [...]}``."""

    key_issues: List[KeyIssue] = Field(description="List of key issues")
| |
|
class KeyIssueInvoke(BaseModel):
    """Validated key-issue record built from the parsed LLM output dicts.

    Richer than the parser schema (adds title, challenges, impact); the raw
    dictionaries returned by the LLM are expected to carry these fields.
    """

    id: int  # re-numbered sequentially (1-based) after generation
    title: str  # short issue title
    description: str  # detailed issue description
    challenges: List[str]  # challenges associated with the issue
    potential_impact: Optional[str] = None  # may be absent in the LLM output
| |
|
def generate_structured_issues(state: PlannerState) -> Dict[str, Any]:
    """Generate the final structured Key Issues from all gathered context.

    Concatenates the per-step outputs (in step order) with the original
    query, asks the main LLM for JSON matching ``KeyIssueList``, validates
    each entry as a ``KeyIssueInvoke`` and renumbers ids sequentially.

    Returns ``key_issues`` plus a summary message on success, or an
    ``error`` entry (with a raw-output hint when retrievable) on failure.
    """
    logger.info("Node: generate_structured_issues")

    user_query = state['user_query']
    step_outputs = state.get('step_outputs', {})

    # Assemble the combined context, keeping steps in execution order.
    full_context = f"Original User Query: {user_query}\n\n"
    full_context += "Context gathered during planning:\n"
    for i, output in sorted(step_outputs.items()):
        full_context += f"--- Context from Step {i+1} ---\n{output}\n\n"

    if not step_outputs:
        full_context += "No context was gathered during the planning steps.\n"

    logger.info(f"Generating key issues using combined context (length: {len(full_context)} chars).")

    issue_llm = get_llm(settings.main_llm_model)
    output_parser = JsonOutputParser(pydantic_object=KeyIssueList)

    prompt = KEY_ISSUE_STRUCTURING_PROMPT.partial(
        schema=output_parser.get_format_instructions(),
    )

    chain = prompt | issue_llm | output_parser

    try:
        structured_issues_obj = invoke_llm(chain, {
            "user_query": user_query,
            "context": full_context
        })
        # Use the logger (was a bare print) so debug output respects the
        # application's logging configuration; lazy %-args avoid formatting
        # the full object unless DEBUG is enabled.
        logger.debug("structured_issues_obj => type : %s, value : %s",
                     type(structured_issues_obj), structured_issues_obj)

        # The parser may hand back either the wrapper object or a bare list.
        if isinstance(structured_issues_obj, dict) and 'key_issues' in structured_issues_obj:
            issues_data = structured_issues_obj['key_issues']
        else:
            issues_data = structured_issues_obj

        key_issues_list = [KeyIssueInvoke(**issue_dict) for issue_dict in issues_data]

        # Renumber sequentially so ids are stable regardless of what the LLM emitted.
        for i, issue in enumerate(key_issues_list):
            issue.id = i + 1

        logger.info(f"Successfully generated {len(key_issues_list)} structured key issues.")
        final_message = f"Generated {len(key_issues_list)} Key Issues based on the query '{user_query}'."
        return {
            "key_issues": key_issues_list,
            "messages": [AIMessage(content=final_message)],
            "error": None
        }
    except Exception as e:
        logger.error(f"Failed to generate or parse structured key issues: {e}", exc_info=True)
        # Best effort: re-run the chain with a plain string parser so the
        # error message can hint at what the LLM actually returned.
        raw_output = "Could not retrieve raw output."
        try:
            raw_chain = prompt | issue_llm | StrOutputParser()
            raw_output = invoke_llm(raw_chain, {"user_query": user_query, "context": full_context})
            logger.debug(f"Raw output from failed JSON parsing:\n{raw_output}")
        except Exception as raw_e:
            logger.error(f"Could not even get raw output: {raw_e}")

        return {"error": f"Failed to generate structured key issues: {e}. Raw output hint: {raw_output[:500]}..."}
| |
|
| |
|
| | |
| |
|
def should_continue_planning(state: PlannerState) -> str:
    """Routing edge: choose the next node after a plan step.

    Returns "error_state" when an error is recorded in the state,
    "continue_execution" while plan steps remain, and "finalize" once
    the plan is exhausted.
    """
    logger.debug("Edge: should_continue_planning")

    error = state.get("error")
    if error:
        logger.error(f"Error state detected: {error}. Ending execution.")
        return "error_state"

    next_step = state['current_plan_step_index']
    remaining = len(state.get('plan', [])) - next_step

    if remaining <= 0:
        logger.debug("Plan finished. Proceeding to final generation.")
        return "finalize"

    logger.debug(f"Continuing plan execution. Next step index: {next_step}")
    return "continue_execution"
| |
|
| |
|
| | |
def build_graph():
    """Assemble and compile the planner LangGraph workflow.

    Topology: start_planning -> execute_plan_step, which loops on itself
    until the plan is exhausted, then hands off to generate_issues; a
    recorded error routes to error_node instead. Both terminal nodes
    lead to END.
    """
    graph = StateGraph(PlannerState)

    # Register the worker nodes.
    graph.add_node("start_planning", start_planning)
    graph.add_node("execute_plan_step", execute_plan_step)
    graph.add_node("generate_issues", generate_structured_issues)

    # Terminal node that surfaces the recorded error as a system message.
    def _report_error(state):
        detail = state.get('error', 'Unknown error')
        return {"messages": [SystemMessage(content=f"Execution failed: {detail}")]}

    graph.add_node("error_node", _report_error)

    # Wire the edges.
    graph.set_entry_point("start_planning")
    graph.add_edge("start_planning", "execute_plan_step")
    graph.add_conditional_edges(
        "execute_plan_step",
        should_continue_planning,
        {
            "continue_execution": "execute_plan_step",
            "finalize": "generate_issues",
            "error_state": "error_node",
        },
    )
    graph.add_edge("generate_issues", END)
    graph.add_edge("error_node", END)

    return graph.compile()