# chat_services.py
import re
import time
from langchain_core.caches import InMemoryCache
from langchain_core.globals import set_llm_cache
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from app.core.config import GROQ_API_KEY
from app.mcp.client import mcp_client
# Enable in-memory cache for performance
set_llm_cache(InMemoryCache())
# System behavior template
system_message = (
    "You are a helpful grocery assistant for an online store in Bangladesh. "
    "Use the context below to answer all product-related queries. "
    "Be polite, accurate, and never answer beyond the context."
)
# Prompt for the LLM: system instructions plus retrieved context, then the
# running conversation, then the current question
prompt = ChatPromptTemplate.from_messages([
    ("system", "{system_message}\n\nContext:\n{context}"),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{question}"),
]).partial(system_message=system_message)
# Set up the Groq LLM
llm = ChatGroq(api_key=GROQ_API_KEY, model="llama3-70b-8192", temperature=0.3)
# Session-based memory storage
message_histories = {}
chat_sessions = {}
# Returns memory history per session
def get_history(session_id):
    if session_id not in message_histories:
        message_histories[session_id] = InMemoryChatMessageHistory()
    return message_histories[session_id]
# Returns QA chain per session
def get_chain(session_id):
    if session_id not in chat_sessions:
        chain = prompt | llm | StrOutputParser()
        wrapped = RunnableWithMessageHistory(
            chain,
            get_session_history=get_history,
            input_messages_key="question",
            history_messages_key="chat_history",
        )
        chat_sessions[session_id] = wrapped
    return chat_sessions[session_id]
# Format product metadata into readable context
def format_product_context(products):
    lines = []
    for p in products:
        for k, v in p.items():
            if k == "_id":  # skip the internal document id without mutating the dict
                continue
            lines.append(f"{k}: {v}")
        lines.append("---")
    return "\n".join(lines)
# Main async handler
async def async_get_answer_for_session(session_id, question):
    start = time.time()
    lower = question.lower()
    # Simple keyword-based product count
    if "how many" in lower and "product" in lower:
        count = await mcp_client.count_products()
        return f"We currently have {count} products in the shop."
    # Try to detect a product ID such as "P-123" (case-insensitive)
    match = re.search(r'\bP-\d{3}\b', question, re.IGNORECASE)
    if match:
        product = await mcp_client.get_product_by_id(match.group().upper())
        context = format_product_context([product]) if product else "No info found."
    else:
        # Fall back to a text search over the catalogue
        products = await mcp_client.search_products(query=question)
        context = format_product_context(products)
    # Get the per-session chain; RunnableWithMessageHistory injects the stored
    # chat history into the prompt and appends this exchange afterwards, so the
    # history is not formatted or added manually here.
    chain = get_chain(session_id)
    inputs = {"context": context, "question": question}
    # Call the LLM asynchronously so the event loop is not blocked
    response = await chain.ainvoke(
        inputs, config={"configurable": {"session_id": session_id}}
    )
    print(f"[DEBUG] Responded in {time.time() - start:.2f}s")
    return response
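# Minimal manual-test sketch; the session id and question are illustrative
# only, and it assumes GROQ_API_KEY is set and the MCP backend is reachable.
if __name__ == "__main__":
    import asyncio
    async def _demo():
        answer = await async_get_answer_for_session(
            "demo-session", "How many products do you have?"
        )
        print(answer)
    asyncio.run(_demo())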