Files
SmartVoyage/app_streamlit/main.py
2026-03-20 22:56:24 +08:00

251 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import asyncio
import uuid
import streamlit as st
from python_a2a import AgentNetwork, Message, TextContent, MessageRole, Task
from langchain_openai import ChatOpenAI
import json
from datetime import datetime
import pytz
import re # 用于清理响应
from create_logger import logger
from app.prompts import SmartVoyagePrompts
from conf import settings
# 启动命令streamlit run main.py
# 设置页面配置
st.set_page_config(page_title="基于A2A的SmartVoyage旅行助手系统", layout="wide", page_icon="🤖")
# 自定义 CSS 打造高端大气科技感,优化对比度
st.markdown("""
<style>
/* 聊天消息框样式 */
.stChatMessage {
background-color: #2c3e50 !important;
border-radius: 12px !important;
padding: 15px !important;
margin-bottom: 15px !important;
box-shadow: 0 3px 6px rgba(0,0,0,0.2) !important;
}
/* 用户消息框稍亮 */
.stChatMessage.user {
background-color: #34495e !important;
}
/* ✅ 核心:强制所有文字变为白色(包括 markdown 内部) */
.stChatMessage .stMarkdown,
.stChatMessage .stMarkdown p,
.stChatMessage .stMarkdown span,
.stChatMessage .stMarkdown div,
.stChatMessage .stMarkdown strong,
.stChatMessage .stMarkdown em,
.stChatMessage .stMarkdown code {
color: #ffffff !important;
}
/* 如果你想让 emoji 图标更亮一点 */
.stChatMessage [data-testid="stChatMessageAvatarIcon"] {
filter: brightness(1.2);
}
</style>
""", unsafe_allow_html=True)
# 初始化会话状态
if "messages" not in st.session_state:
st.session_state.messages = []
if "agent_network" not in st.session_state:
# 存储代理URL信息便于查看
st.session_state.agent_urls = {
"WeatherQueryAssistant": "http://localhost:5005",
"TicketQueryAssistant": "http://localhost:5006",
"TicketOrderAssistant": "http://localhost:5007"
}
# 初始化网络
network = AgentNetwork(name="Travel Assistant Network")
network.add("WeatherQueryAssistant", "http://localhost:5005")
network.add("TicketQueryAssistant", "http://localhost:5006")
network.add("TicketOrderAssistant", "http://localhost:5007")
st.session_state.agent_network = network
# 加载配置并创建LLM
st.session_state.llm = ChatOpenAI(
model=settings.model_name,
api_key=settings.api_key,
base_url=settings.base_url,
temperature=0.1
)
# 存储对话历史用于意图识别
st.session_state.conversation_history = ""
# 意图识别agent
def intent_agent(user_input):
# 创建意图识别链:提示模板 + LLM
chain = SmartVoyagePrompts.intent_prompt() | st.session_state.llm
# 调用LLM进行意图识别
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d') # 获取当前日期Asia/Shanghai时区
intent_response = chain.invoke(
{"conversation_history": '\n'.join(st.session_state.conversation_history.split("\n")[-6:]), "query": user_input,
"current_date": current_date}).content.strip()
logger.info(f"意图识别原始响应: {intent_response}")
# 清理响应移除可能的Markdown代码块标记
intent_response = re.sub(r'^```json\s*|\s*```$', '', intent_response).strip()
logger.info(f"清理后响应: {intent_response}")
intent_output = json.loads(intent_response)
# 提取意图、改写问题和追问消息
intents = intent_output.get("intents", [])
user_queries = intent_output.get("user_queries", {})
follow_up_message = intent_output.get("follow_up_message", "")
logger.info(f"intents: {intents}||user_queries: {user_queries}||follow_up_message: {follow_up_message} ")
return intents, user_queries, follow_up_message
# 主界面布局
st.title("🤖 基于A2A的SmartVoyage旅行智能助手")
st.markdown("欢迎体验智能对话!输入问题,系统将精准识别意图并提供服务。")
# 两栏布局:左侧对话,右侧 Agent Card
col1, col2 = st.columns([2, 1])
# 左侧对话区域
with col1:
st.subheader("💬 对话")
# 对话历史
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# 输入框
if prompt := st.chat_input("请输入您的问题..."):
# 显示用户消息
with st.chat_message("user"):
st.markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
st.session_state.conversation_history += f"\nUser: {prompt}"
# 获取 LLM 和当前日期
llm = st.session_state.llm
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d')
# 意图识别
with st.spinner("正在分析您的意图..."):
try:
# 意图识别过程
intents, user_queries, follow_up_message = intent_agent(prompt)
# 根据意图输出生成响应
if "out_of_scope" in intents:
# 如果意图超出范围,返回大模型直接回复
response = follow_up_message
st.session_state.conversation_history += f"\nAssistant: {response}"
elif follow_up_message != "":
# 如果有追问消息,则直接返回
response = follow_up_message
st.session_state.conversation_history += f"\nAssistant: {response}" # 更新历史
else: # 处理有效意图
responses = [] # 存储每个意图的响应列表
routed_agents = [] # 记录路由到的代理列表
for intent in intents:
logger.info(f"处理意图:{intent}")
# 根据意图确定代理名称
if intent == "weather":
agent_name = "WeatherQueryAssistant"
elif intent in ["flight", "train", "concert"]:
agent_name = "TicketQueryAssistant"
elif intent == "order":
agent_name = "TicketOrderAssistant"
else:
agent_name = None
# 不同意图处理方式
if intent == "attraction":
# 对于景点推荐直接使用LLM生成
chain = SmartVoyagePrompts.attraction_prompt() | llm
rec_response = chain.invoke({"query": prompt}).content.strip()
responses.append(rec_response)
elif agent_name:
# 对于代理意图,则调用代理
# 1获取问题
query_str = user_queries.get(intent, {})
logger.info(f"{agent_name} 查询:{query_str}")
# 2获取代理实例
agent = st.session_state.agent_network.get_agent(agent_name)
# 3构建历史对话信息+新查询,然后调用代理
chat_history = '\n'.join(st.session_state.conversation_history.split("\n")[-7:-1]) + f'\nUser: {query_str}'
message = Message(content=TextContent(text=chat_history), role=MessageRole.USER)
task = Task(id="task-" + str(uuid.uuid4()), message=message.to_dict())
raw_response = asyncio.run(agent.send_task_async(task))
logger.info(f"{agent_name} 原始响应: {raw_response}") # 记录原始响应日志
# 4处理结果
if raw_response.status.state == 'completed': # 正常结果
agent_result = raw_response.artifacts[0]['parts'][0]['text']
else: # 异常结果
agent_result = raw_response.status.message['content']['text']
# 根据代理类型总结响应
if agent_name == "WeatherQueryAssistant":
chain = SmartVoyagePrompts.summarize_weather_prompt() | llm
final_response = chain.invoke(
{"query": query_str, "raw_response": agent_result}).content.strip()
elif agent_name == "TicketQueryAssistant":
chain = SmartVoyagePrompts.summarize_ticket_prompt() | llm
final_response = chain.invoke(
{"query": query_str, "raw_response": agent_result}).content.strip()
else:
final_response = agent_result
# 5添加到历史
responses.append(final_response) # 添加到响应列表
routed_agents.append(agent_name) # 记录路由代理
else:
# 不支持的意图
responses.append("暂不支持此意图。")
response = "\n\n".join(responses)
if routed_agents:
logger.info(f"路由到代理:{routed_agents}")
st.session_state.conversation_history += f"\nAssistant: {response}"
# 显示助手消息
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
except json.JSONDecodeError as json_err:
logger.error(f"意图识别JSON解析失败")
error_message = f"意图识别JSON解析失败{str(json_err)}。请重试。"
with st.chat_message("assistant"):
st.markdown(error_message)
st.session_state.messages.append({"role": "assistant", "content": error_message})
except Exception as e:
logger.error(f"处理异常: {str(e)}")
error_message = f"处理失败:{str(e)}。请重试。"
with st.chat_message("assistant"):
st.markdown(error_message)
st.session_state.messages.append({"role": "assistant", "content": error_message})
# 右侧 Agent Card 区域
with col2:
st.subheader("🛠️ AgentCard")
for agent_name in st.session_state.agent_network.agents.keys():
agent_card = st.session_state.agent_network.get_agent_card(agent_name)
agent_url = st.session_state.agent_urls.get(agent_name, "未知地址")
with st.expander(f"Agent: {agent_name}", expanded=False):
st.markdown(f"<div class='card-title'>技能</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-content'>{agent_card.skills}</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-title'>描述</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-content'>{agent_card.description}</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-title'>地址</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-content'>{agent_url}</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-title'>状态</div>", unsafe_allow_html=True)
st.markdown(f"<div class='card-content'>在线</div>", unsafe_allow_html=True)
# 页脚
st.markdown("---")
st.markdown('<div class="footer">Powered by 黑马程序员 | 基于Agent2Agent的旅行助手系统 v2.0</div>', unsafe_allow_html=True)