| | from langchain.document_loaders import PyPDFLoader, DirectoryLoader |
| | from langchain import PromptTemplate |
| | from langchain.embeddings import HuggingFaceEmbeddings |
| | from langchain.vectorstores import FAISS |
| | from langchain.llms import CTransformers |
| | from langchain.chains import RetrievalQA |
| | import chainlit as cl |
| |
|
# Location of the saved FAISS index that qa_bot() loads.
DB_FAISS_PATH = "vectorstore/db_faiss"

# Prompt text for the QA chain; {context} and {question} are the two
# placeholders filled in at query time. Each line is spelled out as its own
# literal so the newlines (including the trailing one) are explicit.
custom_prompt_template = (
    "Use the following pieces of information to answer the user's questions.\n"
    "If you don't know the answer, just say that you don't know, "
    "but don't try to make up an answer.\n"
    "\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
| |
|
def set_custom_prompt():
    """Build the prompt used for QA retrieval.

    Wraps the module-level ``custom_prompt_template`` in a ``PromptTemplate``
    that declares ``context`` and ``question`` as its input variables.
    """
    template_variables = ['context', 'question']
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=template_variables,
    )
| |
|
| | |
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a ``RetrievalQA`` chain on top of a vector store.

    Args:
        llm: Language model handed to the chain as ``llm``.
        prompt: Prompt forwarded to the chain through ``chain_type_kwargs``.
        db: Vector store; its ``as_retriever`` method supplies the retriever.

    Returns:
        The chain created by ``RetrievalQA.from_chain_type`` with the
        ``'stuff'`` chain type and ``return_source_documents=True``.
    """
    # search_kwargs 'k' is fixed at 2 for every query.
    retriever = db.as_retriever(search_kwargs={'k': 2})
    prompt_kwargs = {'prompt': prompt}
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs=prompt_kwargs,
    )
| |
|
| | |
def load_llm():
    """Load the local Llama 2 chat model through ``CTransformers``.

    Returns:
        A ``CTransformers`` LLM for the GGML model file
        ``llama-2-7b-chat.ggmlv3.q8_0.bin`` with ``max_new_tokens=512`` and
        ``temperature=0.5`` supplied as generation config.
    """
    # Generation settings are passed via ``config``, which is the documented
    # way to hand them to the LangChain CTransformers wrapper.
    # NOTE(review): as bare keyword arguments these values appear to be
    # silently ignored by the wrapper, leaving library defaults in effect --
    # confirm against the installed langchain version.
    generation_config = {'max_new_tokens': 512, 'temperature': 0.5}
    return CTransformers(
        model="llama-2-7b-chat.ggmlv3.q8_0.bin",
        model_type="llama",
        config=generation_config,
    )
| |
|
| | |
def qa_bot():
    """Wire the embeddings, FAISS index, LLM and prompt into one QA chain.

    Returns:
        The chain produced by ``retrieval_qa_chain``.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'},
    )
    # The index is read from DB_FAISS_PATH using the embedding model above.
    vector_store = FAISS.load_local(DB_FAISS_PATH, embedding_model)
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), vector_store)
| |
|
| | |
def final_result(query):
    """Answer a single query with a freshly built QA chain.

    Args:
        query: Value passed to the chain under the ``'query'`` key.

    Returns:
        Whatever the chain returns when called with ``{'query': query}``.
    """
    # qa_bot() is invoked on every call, so a new chain is built each time.
    return qa_bot()({'query': query})
| |
|
| | |
@cl.on_chat_start
async def start():
    """Build the QA chain for a new chat, greet the user, and store the chain.

    The chain is saved in the Chainlit user session under the key ``"chain"``.
    """
    qa_chain = qa_bot()

    # Send a placeholder message first, then replace its text with the
    # welcome prompt.
    greeting = cl.Message(content="Starting the bot...")
    await greeting.send()
    greeting.content = "Hi, Welcome to AstroBot. What is your query?"
    await greeting.update()

    cl.user_session.set("chain", qa_chain)
| |
|
@cl.on_message
async def main(message):
    """Answer an incoming chat message with the session's QA chain.

    Args:
        message: The incoming message, either the query text itself or an
            object that exposes the text as ``.content``.

    The reply consists of the chain's ``"result"`` followed by the string
    form of its ``"source_documents"`` (or a note that none were found).
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): presumably set so streaming starts immediately instead of
    # waiting for the "FINAL ANSWER" prefix tokens -- confirm against the
    # Chainlit callback handler.
    cb.answer_reached = True

    # Accept both a plain string and a message object carrying ``.content``;
    # a plain str has no ``content`` attribute and is passed through as-is.
    query = getattr(message, "content", message)
    res = await chain.acall(query, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]

    if sources:
        answer += "\nSources:" + str(sources)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()