"""Chat manager: retrieval-augmented chat via LangChain, Pinecone, and Redis-backed memory."""
import pinecone
|
|
|
|
from langchain.vectorstores import Pinecone as PineconeStore
|
|
from langchain.memory import ConversationBufferMemory
|
|
from langchain.chains import RetrievalQA
|
|
from langchain.memory.chat_message_histories import RedisChatMessageHistory
|
|
# from langchain.vectorstores.redis import Redis
|
|
from langchain.agents import initialize_agent, Tool
|
|
from langchain.agents import AgentType
|
|
|
|
from langchain.tools.render import render_text_description
|
|
from langchain.agents.output_parsers import ReActSingleInputOutputParser
|
|
from langchain.agents.format_scratchpad import format_log_to_str
|
|
from langchain import hub
|
|
from langchain.agents import AgentExecutor
|
|
|
|
from src.config import REDIS_HOST, PINECONE_API_ENV, PINECONE_API_KEY, PINECONE_INDEX_NAME
|
|
from src.chat.document_loader import document_loader
|
|
from src.chat.model_manager import ModelManager
|
|
from src.bot.service import bot_service
|
|
|
|
class BaseMessage:
    """Value object pairing a chat question with the answer produced for it.

    Attributes:
        question: The user's question, stored verbatim.
        result: The answer/result produced for that question.
    """

    def __init__(self, question, result):
        self.question = question
        self.result = result

    def __repr__(self):
        # Added for debuggability; the original class had no useful repr.
        return f"{type(self).__name__}(question={self.question!r}, result={self.result!r})"
|
|
|
|
class ChatManager:
    """Coordinates one retrieval-augmented chat turn.

    Wires together an LLM (chosen via ModelManager from the bot's settings),
    a Pinecone vector store built from the bot's documents, Redis-backed chat
    history, and a ReAct agent that answers through a RetrievalQA tool.
    """

    def __init__(self):
        # ConversationBufferMemory; built per request in nlp_process().
        self.chat_history = None
        self.model_manager = ModelManager()
        # LLM instance; populated by get_llm() before any agent is built.
        self.llm = None
        # RedisChatMessageHistory for the current chat session.
        self.redis_message_history = None

    async def initialize(self, data):
        """Top-level entry point: run a full chat turn for `data`.

        Returns the agent's answer string, or an error string on failure.
        """
        try:
            return await self.chat_response(data)
        except Exception as e:
            return f"--------------------There's an error occurred in initializing chat manager: {e}"

    async def get_llm(self, data):
        """Resolve the bot's LLM ecosystem, set self.llm, and return the name.

        Returns the ecosystem key (defaulting to "openai" when the bot has no
        `llm_model` setting), or an error string if settings could not load.
        NOTE(review): callers never check for the error-string case, so a
        failure here flows into nlp_process() as the embedding/retriever
        selector key — confirm this is intended.
        """
        try:
            bot_details = await bot_service.get_bot_settings(data["bot_id"])
            bot_details = bot_details[0].advance_settings
            llm_ecosystem = bot_details.llm_model if hasattr(bot_details, 'llm_model') else "openai"
        except Exception as e:
            return f"--------------------There's an error occurred in get llm method: {e}"

        self.llm = self.model_manager.llm_selector(llm_ecosystem, data['stream_handler'])
        return llm_ecosystem

    def nlp_process(self, data):
        """Build the RetrievalQA tool for this request.

        Steps: attach Redis chat history, wrap it in conversation memory,
        load the bot's documents, index them in Pinecone, then return a
        RetrievalQA chain. Each step returns {"message": ...} on failure
        instead of raising.
        """
        embeddings = self.model_manager.embedding_selector(data["embedding"])

        try:
            self.redis_message_history = RedisChatMessageHistory(url=REDIS_HOST, ttl=600, session_id=data["chat_session"])
            print('--------------------redis message_history: ', self.redis_message_history.messages)
        except Exception as e:
            print('error redischathistory: ', e)
            return {"message": f"An error occurred : {e}"}

        try:
            # "memory" is the key chat_response()'s agent reads the history from.
            memory = ConversationBufferMemory(memory_key="memory", chat_memory=self.redis_message_history, return_messages=True)
            self.chat_history = memory
            print('--------------------memory: ', memory)
        except Exception as e:
            print('error conversation buffer memory: ', e)
            return {"message": f"An error occurred : {e}"}

        try:
            docs = document_loader(data["bot_id"])
            print('--------------------docs: ', docs)
        except Exception as e:
            print('error document loader: ', e)
            return {"message": f"An error occurred : {e}"}

        try:
            print('--------------------Store embeddings to vector store: Pinecone')
            pinecone.init(
                api_key=PINECONE_API_KEY,  # find at app.pinecone.io
                environment=PINECONE_API_ENV  # next to api key in console
            )
            # PineconeStore.from_texts embeds the documents itself, so the
            # previous explicit embed_query()/embed_documents() calls (and the
            # unused pinecone.Index handle) were duplicate work and duplicate
            # embedding-API cost — removed.
            docsearch = PineconeStore.from_texts(
                [t.page_content for t in docs], embeddings, index_name=PINECONE_INDEX_NAME
            )
        except Exception as e:
            # Fail like the earlier steps; the original only logged here and
            # then crashed on the next line with a half-initialized store.
            print(f"--------------------Error on Pinecone process: {e}")
            return {"message": f"An error occurred : {e}"}

        retriever = self.model_manager.retriever_selector(data["retriever"], docsearch)
        qa_tool = RetrievalQA.from_chain_type(
            llm=self.llm, retriever=retriever, memory=memory,
            verbose=True
        )
        return qa_tool

    def custom_agent(self, tools):
        """Build a ReAct-docstore agent over `tools` (currently unused by chat_response).

        Bug fix: initialize_agent's signature is
        (tools, llm, agent=None, callback_manager=None, agent_path=None, ...),
        so the original positional max_iterations / handle_parsing_errors /
        early_stopping_method / verbose arguments landed in the wrong
        parameters; they are now passed by keyword.
        """
        try:
            return initialize_agent(
                tools,
                self.llm,
                agent=AgentType.REACT_DOCSTORE,
                max_iterations=3,
                handle_parsing_errors=True,
                early_stopping_method="generate",
                verbose=True,
                memory=self.chat_history,
            )
        except Exception as e:
            return f"Error occured in custom agent function: {e}"

    async def chat_response(self, data):
        """Run one chat turn: build the QA tool, assemble a ReAct agent, invoke it.

        Returns the agent's final answer string, or an error string.
        """
        print(f"--------------------chat history: {self.chat_history}")
        llm_ecosystem = await self.get_llm(data)
        qa_tool = self.nlp_process({
            "question": data["question"],
            "chat_history": self.chat_history,
            # The ecosystem name doubles as the embedding and retriever
            # selector key for ModelManager.
            "embedding": llm_ecosystem,
            "retriever": llm_ecosystem,
            "bot_id": data["bot_id"],
            "chat_session": data["chat_session"]
        })

        tools = [
            Tool(
                name='Document Store',
                func=qa_tool.run,
                description="Use it to lookup information from the document store. You must return a final answer, not an action. If you can't return a final answer, return you don't know the answer.",
            ),
        ]

        # ReAct chat prompt from LangChain Hub, specialised with our tool list.
        prompt = hub.pull("hwchase17/react-chat")
        prompt = prompt.partial(
            tools=render_text_description(tools),
            tool_names=", ".join([t.name for t in tools]),
        )

        # Stop generation at the Observation marker so the output parser sees
        # one Thought/Action step at a time.
        llm_with_stop = self.llm.bind(stop=["\nObservation"])
        try:
            agent = {
                "input": lambda x: x["input"],
                "agent_scratchpad": lambda x: format_log_to_str(x['intermediate_steps']),
                # "memory" is the memory_key configured in nlp_process().
                "chat_history": lambda x: x["memory"]
            } | prompt | llm_with_stop | ReActSingleInputOutputParser()

            agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, memory=self.chat_history, max_iterations=3, early_stopping_method='generate')
            result = agent_executor.invoke({"input": data["question"]})['output']

            # NOTE(review): AgentExecutor already persists the turn through
            # `memory`; adding the messages again below may duplicate history
            # in Redis — confirm against the stored transcript.
            self.redis_message_history.add_user_message(data["question"])
        except Exception as e:
            return f"There's an error occured in chat response function: {e}"

        print(f"Result: {result}")
        self.redis_message_history.add_ai_message(result)
        return result
|
|
|
|
# Shared module-level instance; constructed at import time.
chat_manager = ChatManager()