feat: 项目完成
This commit is contained in:
158
a2a_server/order_server.py
Normal file
158
a2a_server/order_server.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from langchain.agents import create_tool_calling_agent, AgentExecutor
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from python_a2a import AgentCard, AgentSkill, run_server, TaskStatus, TaskState, A2AServer, A2AClient, Message, \
|
||||
TextContent, MessageRole, Task
|
||||
|
||||
from create_logger import logger
|
||||
from conf import settings
|
||||
|
||||
|
||||
# 初始化LLM
|
||||
llm = ChatOpenAI(
|
||||
model=settings.model_name,
|
||||
base_url=settings.base_url,
|
||||
api_key=settings.api_key,
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
|
||||
# 定义订票函数
|
||||
async def order_tickets(query):
|
||||
try:
|
||||
# 启动 MCP server,通过streamable建立连接
|
||||
async with streamablehttp_client("http://127.0.0.1:8003/mcp") as (read, write, _):
|
||||
# 使用读写通道创建 MCP 会话
|
||||
async with ClientSession(read, write) as session:
|
||||
try:
|
||||
await session.initialize()
|
||||
|
||||
# 从 session 自动获取 MCP server 提供的工具列表。
|
||||
tools = await load_mcp_tools(session)
|
||||
# print(f"tools-->{tools}")
|
||||
|
||||
# 创建 agent 的提示模板
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system",
|
||||
"你是一个票务预定助手,能够调用工具来完成火车票、飞机票或演出票的预定。你需要仔细分析工具需要的参数,然后从用户提供的信息中提取信息。如果用户提供的信息不足以提取到调用工具所有必要参数,则向用户追问,以获取该信息。不能自己编撰参数。"),
|
||||
("human", "{input}"),
|
||||
("placeholder", "{agent_scratchpad}"),
|
||||
])
|
||||
|
||||
# 构建工具调用代理
|
||||
agent = create_tool_calling_agent(llm, tools, prompt)
|
||||
|
||||
# 创建代理执行器
|
||||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
# 代理调用
|
||||
response = await agent_executor.ainvoke({"input": query})
|
||||
|
||||
return {"status": "success", "message": f"{response['output']}"}
|
||||
except Exception as e:
|
||||
logger.error(f"票务 MCP 测试出错:{str(e)}")
|
||||
return {"status": "error", "message": f"票务 MCP 查询出错:{str(e)}"}
|
||||
except Exception as e:
|
||||
logger.error(f"连接或会话初始化时发生错误: {e}")
|
||||
return {"status": "error", "message": "连接或会话初始化时发生错误"}
|
||||
|
||||
|
||||
# Agent 卡片定义
|
||||
agent_card = AgentCard(
|
||||
name="TicketOrderAssistant",
|
||||
description="通过MCP提供票务预定服务的助手",
|
||||
url="http://localhost:5007",
|
||||
version="1.0.4",
|
||||
capabilities={"streaming": True, "memory": True},
|
||||
skills=[
|
||||
AgentSkill(
|
||||
name="execute ticket order",
|
||||
description="根据客户端提供的输入执行票务预定,返回执行结果",
|
||||
examples=["北京 到 上海 2025-11-15 火车票 二等座 1张",
|
||||
"上海 到 北京 2025-12-11 飞机票 公务舱 2张"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 票务预定服务器类
|
||||
class TicketOrderServer(A2AServer):
|
||||
def __init__(self):
|
||||
super().__init__(agent_card=agent_card)
|
||||
self.llm = llm
|
||||
self.ticket_client = A2AClient("http://localhost:5006")
|
||||
|
||||
# 处理任务:提取输入,查询余票,调用MCP,结果输出
|
||||
def handle_task(self, task):
|
||||
# 1 提取输入
|
||||
content = (task.message or {}).get("content", {}) # 从消息中获取内容
|
||||
# 提取conversation,即客户端发起的任务中的query语句
|
||||
conversation = content.get("text", "") if isinstance(content, dict) else ""
|
||||
logger.info(f"对话历史及用户问题: {conversation}")
|
||||
|
||||
try:
|
||||
# 2 调用票务查询agent查询余票
|
||||
message_ticket = Message(content=TextContent(text=conversation), role=MessageRole.USER)
|
||||
task_ticket = Task(id="task-" + str(uuid.uuid4()), message=message_ticket.to_dict())
|
||||
|
||||
# 发送任务并获取最终结果
|
||||
ticket_result_task = asyncio.run(self.ticket_client.send_task_async(task_ticket))
|
||||
logger.info(f"原始响应: {ticket_result_task}")
|
||||
|
||||
# 处理结果:未查到余票信息时,则返回提示信息
|
||||
if ticket_result_task.status.state != 'completed':
|
||||
required_message = ticket_result_task.status.message['content']['text']
|
||||
logger.info(f'余票未查到:{required_message}')
|
||||
task.status = TaskStatus(state=TaskState.INPUT_REQUIRED,
|
||||
message={"role": "agent", "content": {"text": required_message}})
|
||||
return task
|
||||
# 处理结果:查到余票信息时,进行订票
|
||||
ticket_result = ticket_result_task.artifacts[0]["parts"][0]["text"]
|
||||
logger.info(f"余票信息: {ticket_result}")
|
||||
|
||||
# 3 调用MCP订票
|
||||
order_result = asyncio.run(order_tickets(conversation + '\n余票信息:' + ticket_result))
|
||||
logger.info(f"MCP 返回: {order_result}")
|
||||
|
||||
# 4 结果输出
|
||||
data = order_result.get("message", '')
|
||||
logger.info(f"订票结果: {data}")
|
||||
# 检查响应状态
|
||||
if order_result.get("status") == "success":
|
||||
result = '余票信息:' + ticket_result + '\n订票结果:' + data
|
||||
# 设置任务产物为文本部分,并设置任务状态为完成
|
||||
task.artifacts = [{"parts": [{"type": "text", "text": result}]}]
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
else:
|
||||
# 设置任务状态为失败,添加错误信息
|
||||
task.status = TaskStatus(state=TaskState.FAILED,
|
||||
message={"role": "agent", "content": {"text": data}})
|
||||
return task
|
||||
except Exception as e: # 捕获异常
|
||||
logger.error(f"查询失败: {str(e)}")
|
||||
|
||||
# 设置任务状态为失败,添加错误信息
|
||||
task.status = TaskStatus(state=TaskState.FAILED,
|
||||
message={"role": "agent", "content": {"text": f"查询失败: {str(e)} 请重试或提供更多细节。"}})
|
||||
return task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建并运行服务器
|
||||
# 实例化票务查询服务器
|
||||
ticket_server = TicketOrderServer()
|
||||
# 打印服务器信息
|
||||
print("\n=== 服务器信息 ===")
|
||||
print(f"名称: {ticket_server.agent_card.name}")
|
||||
print(f"描述: {ticket_server.agent_card.description}")
|
||||
print("\n技能:")
|
||||
for skill in ticket_server.agent_card.skills:
|
||||
print(f"- {skill.name}: {skill.description}")
|
||||
# 运行服务器
|
||||
run_server(ticket_server, host="127.0.0.1", port=5007)
|
||||
283
a2a_server/ticket_server.py
Normal file
283
a2a_server/ticket_server.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from python_a2a import A2AServer, run_server, AgentCard, AgentSkill, TaskStatus, TaskState
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from create_logger import logger
|
||||
|
||||
from conf import settings
|
||||
|
||||
|
||||
# 初始化LLM
|
||||
llm = ChatOpenAI(
|
||||
model=settings.model_name,
|
||||
base_url=settings.base_url,
|
||||
api_key=settings.api_key,
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
|
||||
# 数据表 schema
|
||||
table_schema_string = """ # 定义票务表SQL schema字符串,用于Prompt上下文
|
||||
CREATE TABLE train_tickets (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY COMMENT '主键,自增,唯一标识每条记录',
|
||||
departure_city VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '出发城市(如“北京”)',
|
||||
arrival_city VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '到达城市(如“上海”)',
|
||||
departure_time DATETIME NOT NULL COMMENT '出发时间(如“2025-08-12 07:00:00”)',
|
||||
arrival_time DATETIME NOT NULL COMMENT '到达时间(如“2025-08-12 11:30:00”)',
|
||||
train_number VARCHAR(20) NOT NULL COMMENT '火车车次(如“G1001”)',
|
||||
seat_type VARCHAR(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '座位类型(如“二等座”)',
|
||||
total_seats INT NOT NULL COMMENT '总座位数(如 1000)',
|
||||
remaining_seats INT NOT NULL COMMENT '剩余座位数(如 50)',
|
||||
price DECIMAL(10, 2) NOT NULL COMMENT '票价(如 553.50)',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间,自动记录插入时间',
|
||||
UNIQUE KEY unique_train (departure_time, train_number)
|
||||
) COMMENT='火车票信息表';
|
||||
|
||||
-- 机票表
|
||||
CREATE TABLE flight_tickets (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY COMMENT '主键,自增,唯一标识每条记录',
|
||||
departure_city VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '出发城市(如“北京”)',
|
||||
arrival_city VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '到达城市(如“上海”)',
|
||||
departure_time DATETIME NOT NULL COMMENT '出发时间(如“2025-08-12 08:00:00”)',
|
||||
arrival_time DATETIME NOT NULL COMMENT '到达时间(如“2025-08-12 10:30:00”)',
|
||||
flight_number VARCHAR(20) NOT NULL COMMENT '航班号(如“CA1234”)',
|
||||
cabin_type VARCHAR(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '舱位类型(如“经济舱”)',
|
||||
total_seats INT NOT NULL COMMENT '总座位数(如 200)',
|
||||
remaining_seats INT NOT NULL COMMENT '剩余座位数(如 10)',
|
||||
price DECIMAL(10, 2) NOT NULL COMMENT '票价(如 1200.00)',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间,自动记录插入时间',
|
||||
UNIQUE KEY unique_flight (departure_time, flight_number)
|
||||
) COMMENT='航班机票信息表';
|
||||
|
||||
-- 演唱会票表
|
||||
CREATE TABLE concert_tickets (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY COMMENT '主键,自增,唯一标识每条记录',
|
||||
artist VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '艺人名称(如“周杰伦”)',
|
||||
city VARCHAR(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '举办城市(如“上海”)',
|
||||
venue VARCHAR(100) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '场馆(如“上海体育场”)',
|
||||
start_time DATETIME NOT NULL COMMENT '开始时间(如“2025-08-12 19:00:00”)',
|
||||
end_time DATETIME NOT NULL COMMENT '结束时间(如“2025-08-12 22:00:00”)',
|
||||
ticket_type VARCHAR(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci NOT NULL COMMENT '票类型(如“VIP”)',
|
||||
total_seats INT NOT NULL COMMENT '总座位数(如 5000)',
|
||||
remaining_seats INT NOT NULL COMMENT '剩余座位数(如 100)',
|
||||
price DECIMAL(10, 2) NOT NULL COMMENT '票价(如 880.00)',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间,自动记录插入时间',
|
||||
UNIQUE KEY unique_concert (start_time, artist, ticket_type)
|
||||
) COMMENT='演唱会门票信息表';
|
||||
"""
|
||||
|
||||
# 生成SQL的提示词
|
||||
sql_prompt = ChatPromptTemplate.from_template("""
|
||||
系统提示:你是一个专业的票务SQL生成器,需要从对话历史(含用户的问题)中提取用户的意图以及关键信息,然后基于train_tickets、flight_tickets、concert_tickets表生成SELECT语句。
|
||||
根据对话历史:
|
||||
1. 提取用户的意图,意图有3种(train: 火车/高铁, flight: 机票, concert: 演唱会),输出:{{"type": "train/flight/concert"}};如果无法识别意图,或者意图不在这3种内,则模仿最后1个示例回复即可。
|
||||
2. 根据用户的意图,生成对应表的 SELECT 语句,仅查询指定字段:
|
||||
- train_tickets: id, departure_city, arrival_city, departure_time, arrival_time, train_number, seat_type, price, remaining_seats
|
||||
- flight_tickets: id, departure_city, arrival_city, departure_time, arrival_time, flight_number, cabin_type, price, remaining_seats
|
||||
- concert_tickets: id, artist, city, venue, start_time, end_time, ticket_type, price, remaining_seats
|
||||
3. 如果用户在查询票务信息时,缺少必要信息,则输出:{{"status": "input_required", "message": "请提供票务类型(如火车票、机票、演唱会)和必要信息(如城市、日期)。"}} ,如示例所示;如果对话历史中信息齐全,则输出纯SQL即可。
|
||||
其中,每种意图必要的信息有:
|
||||
- flight/train: 【departure_city (出发城市), arrival_city (到达城市), date (日期)】 或 【train_number/flight_number (车次)】
|
||||
- concert: city (城市), artist (艺人), date (日期)。
|
||||
4. 按要求输出两行数据或一行数据即可,不需要输出其他内容。
|
||||
|
||||
|
||||
示例:
|
||||
- 对话: user: 火车票 北京 上海 2025-07-31 硬卧
|
||||
输出:
|
||||
{{"type": "train"}}
|
||||
SELECT id, departure_city, arrival_city, departure_time, arrival_time, train_number, seat_type, price, remaining_seats FROM train_tickets WHERE departure_city = '北京' AND arrival_city = '上海' AND DATE(departure_time) = '2025-07-31' AND seat_type = '硬卧'
|
||||
|
||||
- 对话: user: 机票 上海 广州 2025-09-11 头等舱
|
||||
输出:
|
||||
{{"type": "flight"}}
|
||||
SELECT id, departure_city, arrival_city, departure_time, arrival_time, flight_number, cabin_type, price, remaining_seats FROM flight_tickets WHERE departure_city = '上海' AND arrival_city = '广州' AND DATE(departure_time) = '2025-09-11' AND cabin_type = '头等舱'
|
||||
|
||||
- 对话: user: 演唱会 北京 刀郎 2025-08-23 看台
|
||||
输出:
|
||||
{{"type": "concert"}}
|
||||
SELECT id, artist, city, venue, start_time, end_time, ticket_type, price, remaining_seats FROM concert_tickets WHERE city = '北京' AND artist = '刀郎' AND DATE(start_time) = '2025-08-23' AND ticket_type = '看台'
|
||||
|
||||
- 对话: user: 火车票
|
||||
输出:
|
||||
{{"status": "input_required", "message": "请提供票务类型(如火车票、机票、演唱会)和必要信息(如城市、日期)。"}}
|
||||
|
||||
- 对话: user: 你好
|
||||
输出:
|
||||
{{"status": "input_required", "message": "请提供票务类型(如火车票、机票、演唱会)和必要信息(如城市、日期)。"}}
|
||||
|
||||
表结构:{table_schema_string}
|
||||
对话历史: {conversation}
|
||||
当前日期: {current_date} (Asia/Shanghai)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# 定义查询函数
|
||||
async def get_ticket_info(sql):
|
||||
try:
|
||||
# 启动 MCP server,通过streamable建立连接
|
||||
async with streamablehttp_client("http://127.0.0.1:8001/mcp") as (read, write, _):
|
||||
# 使用读写通道创建 MCP 会话
|
||||
async with ClientSession(read, write) as session:
|
||||
try:
|
||||
await session.initialize()
|
||||
# 工具调用
|
||||
result = await session.call_tool("query_tickets", {"sql": sql})
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
logger.info(f"票务查询结果:{result_data}")
|
||||
return result_data.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"票务 MCP 测试出错:{str(e)}")
|
||||
return {"status": "error", "message": f"票务 MCP 查询出错:{str(e)}"}
|
||||
except Exception as e:
|
||||
logger.error(f"连接或会话初始化时发生错误: {e}")
|
||||
return {"status": "error", "message": "连接或会话初始化时发生错误"}
|
||||
|
||||
|
||||
# Agent 卡片定义
|
||||
agent_card = AgentCard(
|
||||
name="TicketQueryAssistant",
|
||||
description="基于 LangChain 提供票务查询服务的助手",
|
||||
url="http://localhost:5006",
|
||||
version="1.0.4",
|
||||
capabilities={"streaming": True, "memory": True},
|
||||
skills=[
|
||||
AgentSkill(
|
||||
name="execute ticket query",
|
||||
description="根据客户端提供的输入执行票务查询,返回数据库结果,支持自然语言输入",
|
||||
examples=["火车票 北京 上海 2025-07-31 硬卧", "机票 北京 上海 2025-07-31 经济舱",
|
||||
"演唱会 北京 刀郎 2025-08-23 看台"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 票务查询服务器类
|
||||
class TicketQueryServer(A2AServer):
|
||||
def __init__(self):
|
||||
super().__init__(agent_card=agent_card)
|
||||
self.llm = llm
|
||||
self.sql_prompt = sql_prompt
|
||||
self.schema = table_schema_string
|
||||
|
||||
# 定义生成SQL查询方法,输入对话历史,返回SQL或追问JSON
|
||||
def generate_sql_query(self, conversation: str) -> dict:
|
||||
try:
|
||||
# 组装链
|
||||
chain = self.sql_prompt | self.llm
|
||||
# 调用链
|
||||
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d') # 获取当前日期,格式化为字符串
|
||||
output = chain.invoke({"conversation": conversation, "current_date": current_date, "table_schema_string": self.schema}).content.strip()
|
||||
logger.info(f"原始 LLM 输出: {output}")
|
||||
|
||||
# 处理结果,返回字典
|
||||
lines = output.split('\n')
|
||||
type_line = lines[0].strip()
|
||||
if type_line.startswith('```json'): # 检查是否以```json开头
|
||||
type_line = lines[1].strip() # 取下一行为类型行
|
||||
sql_lines = lines[3:-1] if lines[-1].strip() == '```' else lines[3:] # 提取SQL行,跳过代码块标记
|
||||
else:
|
||||
sql_lines = lines[1:] if len(lines) > 1 else [] # 取剩余行为SQL行
|
||||
|
||||
# 提取 type 和 SQL
|
||||
if type_line.startswith('{"type":'): # 如果以{"type":开头
|
||||
query_type = json.loads(type_line)["type"] # 解析并提取类型
|
||||
sql_query = ' '.join([line.strip() for line in sql_lines if line.strip() and not line.startswith('```')]) # 连接SQL行,过滤空行和代码块
|
||||
logger.info(f"分类类型: {query_type}, 生成的 SQL: {sql_query}")
|
||||
return {"status": "sql", "type": query_type, "sql": sql_query} # 返回SQL状态字典,包括类型
|
||||
elif type_line.startswith('{"status": "input_required"'): # 检查是否为追问JSON
|
||||
return json.loads(type_line)
|
||||
else: # 无效格式
|
||||
logger.error(f"无效的 LLM 输出格式: {output}")
|
||||
return {"status": "input_required", "message": "无法解析查询类型或SQL,请提供更明确的信息。"} # 返回默认追问
|
||||
except Exception as e:
|
||||
logger.error(f"SQL 生成失败: {str(e)}")
|
||||
return {"status": "input_required", "message": "查询无效,请提供查询票务的相关信息。"} # 返回追问JSON
|
||||
|
||||
# 处理任务:提取输入,生成SQL,调用MCP,格式化结果
|
||||
def handle_task(self, task):
|
||||
# 1 提取输入
|
||||
content = (task.message or {}).get("content", {}) # 从消息中获取内容
|
||||
# 提取conversation,即客户端发起的任务中的query语句
|
||||
conversation = content.get("text", "") if isinstance(content, dict) else ""
|
||||
logger.info(f"对话历史及用户问题: {conversation}")
|
||||
|
||||
try:
|
||||
# 2 基于用户问题生成SQL查询
|
||||
gen_result = self.generate_sql_query(conversation)
|
||||
# 检查是否需要追问,如果是则添加追问消息后返回任务
|
||||
if gen_result["status"] == "input_required":
|
||||
task.status = TaskStatus(state=TaskState.INPUT_REQUIRED,
|
||||
message={"role": "agent", "content": {"text": gen_result["message"]}})
|
||||
return task
|
||||
|
||||
# 否则则提取SQL查询,并进行MCP调用
|
||||
sql_query = gen_result["sql"]
|
||||
query_type = gen_result["type"]
|
||||
logger.info(f"执行 SQL 查询: {sql_query} (类型: {query_type})")
|
||||
|
||||
# 3 调用MCP
|
||||
ticket_result = asyncio.run(get_ticket_info(sql_query))
|
||||
|
||||
# 4 格式化结果
|
||||
response = json.loads(ticket_result) if isinstance(ticket_result, str) else ticket_result
|
||||
logger.info(f"MCP 返回: {response}")
|
||||
# 检查响应状态
|
||||
if response.get("status") == "success":
|
||||
data = response.get("data", []) # 提取数据列表
|
||||
response_text = "" # 初始化响应文本
|
||||
for d in data: # 遍历每个数据项
|
||||
if query_type == "train": # 火车票类型
|
||||
response_text += f"{d['departure_city']} 到 {d['arrival_city']} {d['departure_time']}: 车次 {d['train_number']},{d['seat_type']},票价 {d['price']}元,剩余 {d['remaining_seats']} 张\n" # 格式化火车票文本
|
||||
elif query_type == "flight": # 机票类型
|
||||
response_text += f"{d['departure_city']} 到 {d['arrival_city']} {d['departure_time']}: 航班 {d['flight_number']},{d['cabin_type']},票价 {d['price']}元,剩余 {d['remaining_seats']} 张\n" # 格式化机票文本
|
||||
elif query_type == "concert": # 演唱会类型
|
||||
response_text += f"{d['city']} {d['start_time']}: {d['artist']} 演唱会,{d['ticket_type']},场地 {d['venue']},票价 {d['price']}元,剩余 {d['remaining_seats']} 张\n" # 格式化演唱会文本
|
||||
if not response_text: # 检查文本是否为空
|
||||
response_text = "无结果。如果需要其他日期,请补充。"
|
||||
|
||||
# 设置任务产物为文本部分,并设置任务状态为完成
|
||||
task.artifacts = [{"parts": [{"type": "text", "text": response_text}]}]
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
elif response.get("status") == "no_data":
|
||||
response_text = response.get("message", "请输出查询票务的详细信息。")
|
||||
|
||||
# 设置任务状态为输入所需,添加追问消息
|
||||
task.status = TaskStatus(state=TaskState.INPUT_REQUIRED,
|
||||
message={"role": "agent", "content": {"text": response_text}})
|
||||
else:
|
||||
response_text = response.get("message", "查询失败,请重试或提供更多细节。")
|
||||
|
||||
# 设置任务状态为失败,添加错误信息
|
||||
task.status = TaskStatus(state=TaskState.FAILED,
|
||||
message={"role": "agent", "content": {"text": response_text}})
|
||||
return task
|
||||
except Exception as e: # 捕获异常
|
||||
logger.error(f"查询失败: {str(e)}")
|
||||
|
||||
# 设置任务状态为失败,添加错误信息
|
||||
task.status = TaskStatus(state=TaskState.FAILED,
|
||||
message={"role": "agent", "content": {"text": f"查询失败: {str(e)} 请重试或提供更多细节。"}})
|
||||
return task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建并运行服务器
|
||||
# 实例化票务查询服务器
|
||||
ticket_server = TicketQueryServer()
|
||||
# 打印服务器信息
|
||||
print("\n=== 服务器信息 ===")
|
||||
print(f"名称: {ticket_server.agent_card.name}")
|
||||
print(f"描述: {ticket_server.agent_card.description}")
|
||||
print("\n技能:")
|
||||
for skill in ticket_server.agent_card.skills:
|
||||
print(f"- {skill.name}: {skill.description}")
|
||||
# 运行服务器
|
||||
run_server(ticket_server, host="127.0.0.1", port=5006)
|
||||
227
a2a_server/weather_server.py
Normal file
227
a2a_server/weather_server.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import json
|
||||
import asyncio
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from python_a2a import A2AServer, run_server, AgentCard, AgentSkill, TaskStatus, TaskState
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
from create_logger import logger
|
||||
from conf import settings
|
||||
|
||||
|
||||
# 初始化LLM
|
||||
llm = ChatOpenAI(
|
||||
model=settings.model_name,
|
||||
base_url=settings.base_url,
|
||||
api_key=settings.api_key,
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
# 数据表 schema
|
||||
table_schema_string = """ # 定义天气数据表的SQL schema字符串,用于Prompt上下文
|
||||
CREATE TABLE IF NOT EXISTS weather_data (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
city VARCHAR(50) NOT NULL COMMENT '城市名称',
|
||||
fx_date DATE NOT NULL COMMENT '预报日期',
|
||||
sunrise TIME COMMENT '日出时间',
|
||||
sunset TIME COMMENT '日落时间',
|
||||
moonrise TIME COMMENT '月升时间',
|
||||
moonset TIME COMMENT '月落时间',
|
||||
moon_phase VARCHAR(20) COMMENT '月相名称',
|
||||
moon_phase_icon VARCHAR(10) COMMENT '月相图标代码',
|
||||
temp_max INT COMMENT '最高温度',
|
||||
temp_min INT COMMENT '最低温度',
|
||||
icon_day VARCHAR(10) COMMENT '白天天气图标代码',
|
||||
text_day VARCHAR(20) COMMENT '白天天气描述',
|
||||
icon_night VARCHAR(10) COMMENT '夜间天气图标代码',
|
||||
text_night VARCHAR(20) COMMENT '夜间天气描述',
|
||||
wind360_day INT COMMENT '白天风向360角度',
|
||||
wind_dir_day VARCHAR(20) COMMENT '白天风向',
|
||||
wind_scale_day VARCHAR(10) COMMENT '白天风力等级',
|
||||
wind_speed_day INT COMMENT '白天风速 (km/h)',
|
||||
wind360_night INT COMMENT '夜间风向360角度',
|
||||
wind_dir_night VARCHAR(20) COMMENT '夜间风向',
|
||||
wind_scale_night VARCHAR(10) COMMENT '夜间风力等级',
|
||||
wind_speed_night INT COMMENT '夜间风速 (km/h)',
|
||||
precip DECIMAL(5,1) COMMENT '降水量 (mm)',
|
||||
uv_index INT COMMENT '紫外线指数',
|
||||
humidity INT COMMENT '相对湿度 (%)',
|
||||
pressure INT COMMENT '大气压强 (hPa)',
|
||||
vis INT COMMENT '能见度 (km)',
|
||||
cloud INT COMMENT '云量 (%)',
|
||||
update_time DATETIME COMMENT '数据更新时间',
|
||||
UNIQUE KEY unique_city_date (city, fx_date)
|
||||
) ENGINE=INNODB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='天气数据表';
|
||||
"""
|
||||
|
||||
# 生成SQL的提示词
|
||||
sql_prompt = ChatPromptTemplate.from_template(
|
||||
"""
|
||||
系统提示:你是一个专业的天气SQL生成器,需要从对话历史(含用户的问题)中提取关键信息,然后基于weather_data表生成SELECT语句。
|
||||
- 如果用户需要查天气,则至少需要城市和时间信息。如果对话历史中缺乏必要的信息,可以向其追问,输出格式为json格式,如示例所示;
|
||||
- 如果对话历史中信息齐全,则输出纯SQL即可。
|
||||
- 如果用户问与天气无关的问题,则模仿最后2个示例回复即可。
|
||||
|
||||
|
||||
示例:
|
||||
- 对话: user: 北京 2025-07-30
|
||||
输出: SELECT city, fx_date, temp_max, temp_min, text_day, text_night, humidity, wind_dir_day, precip FROM weather_data WHERE city = '北京' AND fx_date = '2025-07-30'
|
||||
- 对话: user: 上海未来3天的天气
|
||||
输出: SELECT city, fx_date, temp_max, temp_min, text_day, text_night, humidity, wind_dir_day, precip FROM weather_data WHERE city = '上海' AND fx_date BETWEEN '2025-07-30' AND '2025-08-01' ORDER BY fx_date
|
||||
- 对话: user: 北京的天气
|
||||
输出: {{"status": "input_required", "message": "请提供具体的需要查询的日期,例如 '2025-07-30'。"}}
|
||||
- 对话: user: 今天\nassistant: 请提供城市。\nuser: 北京
|
||||
输出: SELECT city, fx_date, temp_max, temp_min, text_day, text_night, humidity, wind_dir_day, precip FROM weather_data WHERE city = '北京' AND fx_date = '2025-07-30'
|
||||
- 对话: user: 北京明天的天气\nassistant: 多云。\nuser: 后天呢
|
||||
输出: SELECT city, fx_date, temp_max, temp_min, text_day, text_night, humidity, wind_dir_day, precip FROM weather_data WHERE city = '北京' AND fx_date = '2025-08-01'
|
||||
- 对话: user: 你好
|
||||
输出: {{"status": "input_required", "message": "请提供城市和日期,例如 '北京 2025-07-30'。"}}
|
||||
- 对话: user: 今天有什么好吃的
|
||||
输出: {{"status": "input_required", "message": "请提供天气相关查询,包括城市和日期。"}}
|
||||
|
||||
weather_data表结构:{table_schema_string}
|
||||
对话历史: {conversation}
|
||||
当前日期: {current_date} (Asia/Shanghai)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# 定义查询函数
|
||||
async def get_weather(sql):
|
||||
try:
|
||||
# 启动 MCP server,通过streamable建立连接
|
||||
async with streamablehttp_client("http://127.0.0.1:8002/mcp") as (read, write, _):
|
||||
# 使用读写通道创建 MCP 会话
|
||||
async with ClientSession(read, write) as session:
|
||||
try:
|
||||
await session.initialize()
|
||||
# 工具调用
|
||||
result = await session.call_tool("query_weather", {"sql": sql})
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
logger.info(f"天气查询结果:{result_data}")
|
||||
return result_data.content[0].text
|
||||
except Exception as e:
|
||||
logger.error(f"天气 MCP 测试出错:{str(e)}")
|
||||
return {"status": "error", "message": f"天气 MCP 查询出错:{str(e)}"}
|
||||
except Exception as e:
|
||||
logger.error(f"连接或会话初始化时发生错误: {e}")
|
||||
return {"status": "error", "message": "连接或会话初始化时发生错误"}
|
||||
|
||||
|
||||
# Agent卡片定义
|
||||
agent_card = AgentCard(
|
||||
name="WeatherQueryAssistant",
|
||||
description="基于LangChain提供天气查询服务的助手",
|
||||
url="http://localhost:5005",
|
||||
version="1.0.0",
|
||||
capabilities={"streaming": True, "memory": True}, # 设置能力:支持流式和内存
|
||||
skills=[ # 定义技能列表
|
||||
AgentSkill(
|
||||
name="execute weather query",
|
||||
description="执行天气查询,返回天气数据库结果,支持自然语言输入",
|
||||
examples=["北京 2025-07-30 天气", "上海未来5天", "今天天气如何"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# 天气查询服务器类
|
||||
class WeatherQueryServer(A2AServer):
|
||||
def __init__(self):
|
||||
super().__init__(agent_card=agent_card)
|
||||
self.llm = llm
|
||||
self.sql_prompt = sql_prompt
|
||||
self.schema = table_schema_string
|
||||
|
||||
# 定义生成SQL查询方法,输入对话历史,返回SQL或追问JSON
|
||||
def generate_sql_query(self, conversation: str) -> dict:
|
||||
try:
|
||||
# 组装链
|
||||
chain = self.sql_prompt | self.llm
|
||||
# 调用链
|
||||
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d') # 获取当前日期,格式化为字符串
|
||||
output = chain.invoke({"conversation": conversation, "current_date": current_date, "table_schema_string": self.schema}).content.strip()
|
||||
logger.info(f"原始 LLM 输出: {output}")
|
||||
# 处理结果,返回字典
|
||||
if output.startswith('{'): # 检查输出是否以JSON开头
|
||||
return json.loads(output)
|
||||
return {"status": "sql", "sql": output}
|
||||
except Exception as e:
|
||||
logger.error(f"SQL生成失败: {str(e)}")
|
||||
return {"status": "input_required", "message": "查询无效,请提供城市和日期。"} # 返回追问JSON
|
||||
|
||||
# 处理任务:提取输入,生成SQL,调用MCP,格式化结果
|
||||
def handle_task(self, task):
|
||||
# 1 提取输入
|
||||
content = (task.message or {}).get("content", {}) # 从消息中获取内容
|
||||
# 提取conversation,即客户端发起的任务中的query语句
|
||||
conversation = content.get("text", "") if isinstance(content, dict) else ""
|
||||
logger.info(f"对话历史及用户问题: {conversation}")
|
||||
|
||||
try:
|
||||
# 2 基于用户问题生成SQL查询
|
||||
gen_result = self.generate_sql_query(conversation)
|
||||
# 检查是否需要追问,如果是则添加追问消息后返回任务
|
||||
if gen_result["status"] == "input_required":
|
||||
# 追问逻辑,这里是指在无法正常生成sql时,设置任务状态为输入所需,添加追问消息
|
||||
task.status = TaskStatus(state=TaskState.INPUT_REQUIRED,
|
||||
message={"role": "agent", "content": {"text": gen_result["message"]}})
|
||||
return task
|
||||
|
||||
# 否则则提取SQL查询,并进行MCP调用
|
||||
sql_query = gen_result["sql"] #
|
||||
logger.info(f"生成的SQL查询: {sql_query}")
|
||||
|
||||
# 3 调用MCP
|
||||
weather_result = asyncio.run(get_weather(sql_query))
|
||||
|
||||
# 4 格式化结果
|
||||
response = json.loads(weather_result) if isinstance(weather_result, str) else weather_result
|
||||
logger.info(f"MCP 返回: {response}")
|
||||
# 检查响应状态
|
||||
if response.get("status") == "success":
|
||||
data = response.get("data", []) # 提取数据列表
|
||||
response_text = "\n".join([f"{d['city']} {d['fx_date']}: {d['text_day']}(夜间 {d['text_night']}),温度 {d['temp_min']}-{d['temp_max']}°C,湿度 {d['humidity']}%,风向 {d['wind_dir_day']},降水 {d['precip']}mm" for d in data]) # 格式化每个数据项为友好文本,连接成多行
|
||||
|
||||
# 设置任务产物为文本部分,并设置任务状态为完成
|
||||
task.artifacts = [{"parts": [{"type": "text", "text": response_text}]}]
|
||||
task.status = TaskStatus(state=TaskState.COMPLETED)
|
||||
elif response.get("status") == "no_data":
|
||||
response_text = response.get("message", "请重新输入查询的城市和日期。")
|
||||
|
||||
# 设置任务状态为输入所需,添加追问消息
|
||||
task.status = TaskStatus(state=TaskState.INPUT_REQUIRED,
|
||||
message={"role": "agent", "content": {"text": response_text}})
|
||||
else:
|
||||
response_text = response.get("message", "查询失败,请重试或提供更多细节。")
|
||||
|
||||
# 设置任务状态为失败,添加错误信息
|
||||
task.status = TaskStatus(state=TaskState.FAILED,
|
||||
message={"role": "agent", "content": {"text": response_text}})
|
||||
|
||||
return task
|
||||
except Exception as e: # 捕获异常
|
||||
logger.error(f"查询失败: {str(e)}")
|
||||
|
||||
# 设置任务状态为失败,添加错误信息
|
||||
task.status = TaskStatus(state=TaskState.FAILED,
|
||||
message={"role": "agent",
|
||||
"content": {"text": f"查询失败: {str(e)} 请重试或提供更多细节。"}})
|
||||
return task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建并运行服务器
|
||||
# 实例化天气查询服务器
|
||||
weather_server = WeatherQueryServer()
|
||||
# 打印服务器信息
|
||||
print("\n=== 服务器信息 ===")
|
||||
print(f"名称: {weather_server.agent_card.name}")
|
||||
print(f"描述: {weather_server.agent_card.description}")
|
||||
print("\n技能:")
|
||||
for skill in weather_server.agent_card.skills:
|
||||
print(f"- {skill.name}: {skill.description}")
|
||||
# 运行服务器
|
||||
run_server(weather_server, host="127.0.0.1", port=5005)
|
||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
236
app/mian.py
Normal file
236
app/mian.py
Normal file
@@ -0,0 +1,236 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
import re
|
||||
from python_a2a import AgentNetwork, TextContent, Message, MessageRole, Task
|
||||
from langchain_openai import ChatOpenAI
|
||||
from create_logger import logger
|
||||
from app.prompts import SmartVoyagePrompts
|
||||
from conf import settings
|
||||
|
||||
|
||||
# 初始化全局变量,用于模拟会话状态 这些变量替换了Streamlit的session_state
|
||||
messages = [] # 存储对话历史消息列表,每个元素为字典{"role": "user/assistant", "content": "消息内容"}
|
||||
agent_network = None # 代理网络实例
|
||||
llm = None # 大语言模型实例
|
||||
agent_urls = {} # 存储代理的URL信息字典
|
||||
conversation_history = "" # 存储整个对话历史字符串,用于意图识别
|
||||
|
||||
|
||||
# 初始化代理网络和相关组件 此部分在脚本启动时执行一次,模拟Streamlit的初始化
|
||||
def initialize_system():
|
||||
"""
|
||||
初始化系统组件,包括代理网络、路由器、LLM和会话状态
|
||||
核心逻辑:构建AgentNetwork,添加代理,创建路由器和LLM
|
||||
"""
|
||||
global agent_network, llm, agent_urls, conversation_history
|
||||
# 存储代理URL信息,便于查看
|
||||
agent_urls = {
|
||||
"WeatherQueryAssistant": "http://localhost:5005", # 天气代理URL
|
||||
"TicketQueryAssistant": "http://localhost:5006", # 票务代理URL
|
||||
"TicketOrderAssistant": "http://localhost:5007" # 票务预定URL
|
||||
}
|
||||
# 创建代理网络
|
||||
network = AgentNetwork(name="旅行助手网络")
|
||||
network.add("WeatherQueryAssistant", "http://localhost:5005")
|
||||
network.add("TicketQueryAssistant", "http://localhost:5006")
|
||||
network.add("TicketOrderAssistant", "http://localhost:5007")
|
||||
agent_network = network
|
||||
|
||||
# 加载配置并创建LLM
|
||||
llm = ChatOpenAI(
|
||||
model=settings.model_name,
|
||||
api_key=settings.api_key,
|
||||
base_url=settings.base_url,
|
||||
temperature=0.1
|
||||
)
|
||||
|
||||
# 初始化对话历史为空字符串
|
||||
conversation_history = ""
|
||||
|
||||
|
||||
# 意图识别agent
|
||||
def intent_agent(user_input):
|
||||
global conversation_history, llm
|
||||
|
||||
# 创建意图识别链:提示模板 + LLM
|
||||
chain = SmartVoyagePrompts.intent_prompt() | llm
|
||||
|
||||
# 调用LLM进行意图识别
|
||||
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d') # 获取当前日期(Asia/Shanghai时区)
|
||||
intent_response = chain.invoke(
|
||||
{"conversation_history": '\n'.join(conversation_history.split("\n")[-6:]), "query": user_input,
|
||||
"current_date": current_date}).content.strip()
|
||||
logger.info(f"意图识别原始响应: {intent_response}")
|
||||
|
||||
# 清理响应:移除可能的Markdown代码块标记
|
||||
intent_response = re.sub(r'^```json\s*|\s*```$', '', intent_response).strip()
|
||||
logger.info(f"清理后响应: {intent_response}")
|
||||
intent_output = json.loads(intent_response)
|
||||
# 提取意图、改写问题和追问消息
|
||||
intents = intent_output.get("intents", [])
|
||||
user_queries = intent_output.get("user_queries", {})
|
||||
follow_up_message = intent_output.get("follow_up_message", "")
|
||||
logger.info(f"intents: {intents}||user_queries: {user_queries}||follow_up_message: {follow_up_message} ")
|
||||
|
||||
return intents, user_queries, follow_up_message
|
||||
|
||||
|
||||
# 处理用户输入的核心函数
|
||||
# 此函数模拟Streamlit的输入处理逻辑,包括意图识别、路由和响应生成
|
||||
def process_user_input(prompt):
|
||||
"""
|
||||
处理用户输入:识别意图、调用代理、生成响应
|
||||
核心逻辑:使用LLM进行意图识别,根据意图路由到相应代理或直接生成内容
|
||||
"""
|
||||
global messages, conversation_history, llm
|
||||
# 添加用户消息到历史
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
conversation_history += f"\nUser: {prompt}"
|
||||
|
||||
print("正在分析您的意图...")
|
||||
try:
|
||||
# 意图识别过程
|
||||
intents, user_queries, follow_up_message = intent_agent(prompt)
|
||||
|
||||
# 根据意图输出生成响应
|
||||
if "out_of_scope" in intents:
|
||||
# 如果意图超出范围,返回大模型直接回复
|
||||
response = follow_up_message
|
||||
conversation_history += f"\nAssistant: {response}"
|
||||
elif follow_up_message != "":
|
||||
# 如果有追问消息,则直接返回
|
||||
response = follow_up_message
|
||||
conversation_history += f"\nAssistant: {response}" # 更新历史
|
||||
else: # 处理有效意图
|
||||
responses = [] # 存储每个意图的响应列表
|
||||
routed_agents = [] # 记录路由到的代理列表
|
||||
for intent in intents:
|
||||
logger.info(f"处理意图:{intent}")
|
||||
# 根据意图确定代理名称
|
||||
if intent == "weather":
|
||||
agent_name = "WeatherQueryAssistant"
|
||||
elif intent in ["flight", "train", "concert"]:
|
||||
agent_name = "TicketQueryAssistant"
|
||||
elif intent == "order":
|
||||
agent_name = "TicketOrderAssistant"
|
||||
else:
|
||||
agent_name = None
|
||||
|
||||
# 不同意图处理方式
|
||||
if intent == "attraction":
|
||||
# 对于景点推荐,直接使用LLM生成
|
||||
chain = SmartVoyagePrompts.attraction_prompt() | llm
|
||||
rec_response = chain.invoke({"query": prompt}).content.strip()
|
||||
responses.append(rec_response)
|
||||
elif agent_name:
|
||||
# 对于代理意图,则调用代理
|
||||
# 1)获取问题
|
||||
query_str = user_queries.get(intent, {})
|
||||
logger.info(f"{agent_name} 查询:{query_str}")
|
||||
# 2)获取代理实例
|
||||
agent = agent_network.get_agent(agent_name)
|
||||
# 3)构建历史对话信息+新查询,然后调用代理
|
||||
chat_history = '\n'.join(conversation_history.split("\n")[-7:-1]) + f'\nUser: {query_str}'
|
||||
message = Message(content=TextContent(text=chat_history), role=MessageRole.USER)
|
||||
task = Task(id="task-" + str(uuid.uuid4()), message=message.to_dict())
|
||||
raw_response = asyncio.run(agent.send_task_async(task))
|
||||
logger.info(f"{agent_name} 原始响应: {raw_response}") # 记录原始响应日志
|
||||
# 4)处理结果
|
||||
if raw_response.status.state == 'completed': # 正常结果
|
||||
agent_result = raw_response.artifacts[0]['parts'][0]['text']
|
||||
else: # 异常结果
|
||||
agent_result = raw_response.status.message['content']['text']
|
||||
|
||||
# 根据代理类型总结响应
|
||||
if agent_name == "WeatherQueryAssistant":
|
||||
chain = SmartVoyagePrompts.summarize_weather_prompt() | llm
|
||||
final_response = chain.invoke({"query": query_str, "raw_response": agent_result}).content.strip()
|
||||
elif agent_name == "TicketQueryAssistant":
|
||||
chain = SmartVoyagePrompts.summarize_ticket_prompt() | llm
|
||||
final_response = chain.invoke({"query": query_str, "raw_response": agent_result}).content.strip()
|
||||
else :
|
||||
final_response = agent_result
|
||||
|
||||
# 5)添加到历史
|
||||
responses.append(final_response) # 添加到响应列表
|
||||
routed_agents.append(agent_name) # 记录路由代理
|
||||
else:
|
||||
# 不支持的意图
|
||||
responses.append("暂不支持此意图。")
|
||||
|
||||
# 组合所有响应
|
||||
response = "\n\n".join(responses)
|
||||
if routed_agents:
|
||||
logger.info(f"路由到代理:{routed_agents}")
|
||||
conversation_history += f"\nAssistant: {response}" # 更新历史
|
||||
|
||||
# 输出助手响应(模拟Streamlit的显示)
|
||||
print(f"\n助手回复:\n{response}\n") # 打印响应
|
||||
# 添加到消息历史
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
except json.JSONDecodeError as json_err:
|
||||
# 处理JSON解析错误
|
||||
logger.error(f"意图识别JSON解析失败")
|
||||
error_message = f"意图识别JSON解析失败:{str(json_err)}。请重试。"
|
||||
print(f"\n助手回复:\n{error_message}\n") # 打印错误
|
||||
messages.append({"role": "assistant", "content": error_message})
|
||||
except Exception as e:
|
||||
# 处理其他异常
|
||||
logger.error(f"处理异常: {str(e)}")
|
||||
error_message = f"处理失败:{str(e)}。请重试。"
|
||||
print(f"\n助手回复:\n{error_message}\n") # 打印错误
|
||||
messages.append({"role": "assistant", "content": error_message})
|
||||
|
||||
|
||||
# 显示代理卡片信息
|
||||
# 此函数模拟Streamlit的右侧Agent Card,打印代理详情
|
||||
def display_agent_cards():
|
||||
"""
|
||||
显示所有代理的卡片信息,包括技能、描述、地址和状态
|
||||
核心逻辑:遍历代理网络,获取并打印卡片内容
|
||||
"""
|
||||
print("\n🛠️ Agent Cards:")
|
||||
for agent_name in agent_network.agents.keys():
|
||||
# 获取代理卡片
|
||||
agent_card = agent_network.get_agent_card(agent_name)
|
||||
agent_url = agent_urls.get(agent_name, "未知地址")
|
||||
print(f"\n--- Agent: {agent_name} ---")
|
||||
print(f"技能: {agent_card.skills}")
|
||||
print(f"描述: {agent_card.description}")
|
||||
print(f"地址: {agent_url}")
|
||||
print(f"状态: 在线") # 固定状态为在线
|
||||
|
||||
# 主函数:脚本入口
|
||||
# 初始化系统并进入交互循环
|
||||
if __name__ == "__main__":
|
||||
# 初始化系统
|
||||
initialize_system()
|
||||
print("🤖 基于A2A的SmartVoyage旅行智能助手")
|
||||
print("欢迎体验智能对话!输入问题,按回车提交;输入'quit'退出;输入'cards'查看代理卡片。")
|
||||
|
||||
# 显示初始代理卡片
|
||||
display_agent_cards()
|
||||
|
||||
# 交互循环:模拟Streamlit的连续输入
|
||||
while True:
|
||||
# 获取用户输入
|
||||
prompt = input("\n请输入您的问题: ").strip()
|
||||
if prompt.lower() == 'quit':
|
||||
print("感谢使用SmartVoyage!再见!")
|
||||
break
|
||||
elif prompt.lower() == 'cards': # 查看卡片条件
|
||||
display_agent_cards() # 重新显示卡片
|
||||
continue
|
||||
elif not prompt: # 空输入跳过
|
||||
continue
|
||||
else:
|
||||
# 处理输入
|
||||
process_user_input(prompt) # 调用核心处理函数
|
||||
|
||||
# 脚本结束时打印页脚信息
|
||||
print("\n---")
|
||||
print("Powered by 黑马程序员 | 基于Agent2Agent的旅行助手系统 v2.0")
|
||||
80
app/prompts.py
Normal file
80
app/prompts.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
|
||||
|
||||
class SmartVoyagePrompts:
|
||||
|
||||
# 定义意图识别提示模板
|
||||
@staticmethod
|
||||
def intent_prompt():
|
||||
return ChatPromptTemplate.from_template(
|
||||
"""
|
||||
系统提示:您是一个专业的旅行意图识别专家,基于用户查询和对话历史,识别其意图,用于调用专门的agent server来执行;为方便后续的agent server处理,可以基于对话历史对用户查询进行改写,使问题更明确。严格遵守规则:
|
||||
- 支持意图:['weather' (天气查询), 'flight' (机票查询), 'train' (高铁/火车票查询), 'concert' (演唱会票查询), 'order' (票务预定), 'attraction' (景点推荐)] 或其组合(如 ['weather', 'flight'])。如果意图超出范围,返回意图 'out_of_scope'。
|
||||
- 注意票务预定和票务查询要区分开,涉及到订票时则为order,只是查询则为flight、train或concert。
|
||||
- 如果意图为 'out_of_scope'时,此时不需要再进行查询改写,你可以直接根据用户问题进行回复,将回复答案写到follow_up_message中即可。
|
||||
- 在进行用户查询改写时,不要回答其问题,也不要修改其原意,只需要将对话历史中跟该查询相关的上下文信息取出来,然后整合到一起,使用户查询更明确即可,要仔细分析上下文信息,不要进行过度整合。如果用户查询跟对话历史无关,则输出原始查询。
|
||||
- 如果用户的意图很不明确或者有歧义,可以向其进行追问,将追问问题填充到follow_up_message中。
|
||||
- 输出严格为JSON:{{"intents": ["intent1", "intent2"], "user_queries": {{"intent1": "user_query1", "intent2": "user_query2"}}, "follow_up_message": "追问消息"}}。不要添加额外文本!
|
||||
|
||||
输出示例:
|
||||
{{"intents": ["weather"], "user_queries": {{"weather": "今天北京天气如何"}}, "follow_up_message": ""}}
|
||||
{{"intents": ["weather"], "user_queries": {{}}, "follow_up_message": "你问的是今天北京天气状况吗"}}
|
||||
{{"intents": ["weather", "flight"], "user_queries": {{"weather": "今天北京天气如何", "flight": "查询一下10月28日,从北京飞往杭州的机票"}}, "follow_up_message": ""}}
|
||||
{{"intents": ["out_of_scope"], "user_queries": {{}}, "follow_up_message": "你好,我是智能旅行助手,欢迎您向我提问"}}
|
||||
|
||||
当前日期:{current_date} (Asia/Shanghai)。
|
||||
对话历史:{conversation_history}
|
||||
用户查询:{query}
|
||||
""")
|
||||
|
||||
# 定义天气结果总结提示模板,用于LLM总结天气查询的原始响应
|
||||
@staticmethod
|
||||
def summarize_weather_prompt():
|
||||
return ChatPromptTemplate.from_template(
|
||||
"""
|
||||
系统提示:您是一位专业的天气预报员,以生动、准确的风格总结天气信息。基于查询和结果:
|
||||
- 核心描述点:城市、日期、温度范围、天气描述、湿度、风向、降水等。
|
||||
- 如果结果为空或者意思为需要补充数据,则委婉提示“未找到数据,请确认城市/日期”
|
||||
- 语气:专业预报,如“根据最新数据,北京2025-07-31的天气预报为...”。
|
||||
- 保持中文,100-150字。
|
||||
- 如果查询无关,返回“请提供天气相关查询。”
|
||||
|
||||
查询:{query}
|
||||
结果:{raw_response}
|
||||
""")
|
||||
|
||||
# 定义票务结果总结提示模板,用于LLM总结票务查询的原始响应
|
||||
@staticmethod
|
||||
def summarize_ticket_prompt():
|
||||
return ChatPromptTemplate.from_template(
|
||||
"""
|
||||
系统提示:您是一位专业的旅行顾问,以热情、精确的风格总结票务信息。基于查询和结果:
|
||||
- 核心描述点:出发/到达、时间、类型、价格、剩余座位等。
|
||||
- 如果结果为空或者意思为需要补充数据,则委婉提示“未找到数据,请确认或修改条件”
|
||||
- 语气:顾问式,如“为您推荐北京到上海的机票选项...”。
|
||||
- 保持中文,100-150字。
|
||||
- 如果查询无关,返回“请提供票务相关查询。”
|
||||
|
||||
|
||||
查询:{query}
|
||||
结果:{raw_response}
|
||||
""")
|
||||
|
||||
# 定义景点推荐提示模板,用于LLM直接生成景点推荐内容
|
||||
@staticmethod
|
||||
def attraction_prompt():
|
||||
return ChatPromptTemplate.from_template(
|
||||
"""
|
||||
系统提示:您是一位旅行专家,基于用户查询生成景点推荐。规则:
|
||||
- 推荐3-5个景点,包含描述、理由、注意事项。
|
||||
- 基于槽位:城市、偏好。
|
||||
- 语气:热情推荐,如“推荐您在北京探索故宫...”。
|
||||
- 备注:内容生成,仅供参考。
|
||||
- 保持中文,150-250字。
|
||||
|
||||
查询:{query}
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(SmartVoyagePrompts.intent_prompt())
|
||||
250
app_streamlit/main.py
Normal file
250
app_streamlit/main.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
import asyncio
|
||||
import uuid
|
||||
import streamlit as st
|
||||
from python_a2a import AgentNetwork, Message, TextContent, MessageRole, Task
|
||||
from langchain_openai import ChatOpenAI
|
||||
import json
|
||||
from datetime import datetime
|
||||
import pytz
|
||||
import re # 用于清理响应
|
||||
from create_logger import logger
|
||||
from app.prompts import SmartVoyagePrompts
|
||||
from conf import settings
|
||||
|
||||
# 启动命令:streamlit run main.py
|
||||
|
||||
# 设置页面配置
|
||||
st.set_page_config(page_title="基于A2A的SmartVoyage旅行助手系统", layout="wide", page_icon="🤖")
|
||||
|
||||
# 自定义 CSS 打造高端大气科技感,优化对比度
|
||||
st.markdown("""
|
||||
<style>
|
||||
/* 聊天消息框样式 */
|
||||
.stChatMessage {
|
||||
background-color: #2c3e50 !important;
|
||||
border-radius: 12px !important;
|
||||
padding: 15px !important;
|
||||
margin-bottom: 15px !important;
|
||||
box-shadow: 0 3px 6px rgba(0,0,0,0.2) !important;
|
||||
}
|
||||
|
||||
/* 用户消息框稍亮 */
|
||||
.stChatMessage.user {
|
||||
background-color: #34495e !important;
|
||||
}
|
||||
|
||||
/* ✅ 核心:强制所有文字变为白色(包括 markdown 内部) */
|
||||
.stChatMessage .stMarkdown,
|
||||
.stChatMessage .stMarkdown p,
|
||||
.stChatMessage .stMarkdown span,
|
||||
.stChatMessage .stMarkdown div,
|
||||
.stChatMessage .stMarkdown strong,
|
||||
.stChatMessage .stMarkdown em,
|
||||
.stChatMessage .stMarkdown code {
|
||||
color: #ffffff !important;
|
||||
}
|
||||
|
||||
/* 如果你想让 emoji 图标更亮一点 */
|
||||
.stChatMessage [data-testid="stChatMessageAvatarIcon"] {
|
||||
filter: brightness(1.2);
|
||||
}
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
# 初始化会话状态
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "agent_network" not in st.session_state:
|
||||
# 存储代理URL信息,便于查看
|
||||
st.session_state.agent_urls = {
|
||||
"WeatherQueryAssistant": "http://localhost:5005",
|
||||
"TicketQueryAssistant": "http://localhost:5006",
|
||||
"TicketOrderAssistant": "http://localhost:5007"
|
||||
}
|
||||
# 初始化网络
|
||||
network = AgentNetwork(name="Travel Assistant Network")
|
||||
network.add("WeatherQueryAssistant", "http://localhost:5005")
|
||||
network.add("TicketQueryAssistant", "http://localhost:5006")
|
||||
network.add("TicketOrderAssistant", "http://localhost:5007")
|
||||
st.session_state.agent_network = network
|
||||
# 加载配置并创建LLM
|
||||
st.session_state.llm = ChatOpenAI(
|
||||
model=settings.model_name,
|
||||
api_key=settings.api_key,
|
||||
base_url=settings.base_url,
|
||||
temperature=0.1
|
||||
)
|
||||
# 存储对话历史用于意图识别
|
||||
st.session_state.conversation_history = ""
|
||||
|
||||
# 意图识别agent
|
||||
def intent_agent(user_input):
|
||||
# 创建意图识别链:提示模板 + LLM
|
||||
chain = SmartVoyagePrompts.intent_prompt() | st.session_state.llm
|
||||
|
||||
# 调用LLM进行意图识别
|
||||
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d') # 获取当前日期(Asia/Shanghai时区)
|
||||
intent_response = chain.invoke(
|
||||
{"conversation_history": '\n'.join(st.session_state.conversation_history.split("\n")[-6:]), "query": user_input,
|
||||
"current_date": current_date}).content.strip()
|
||||
logger.info(f"意图识别原始响应: {intent_response}")
|
||||
|
||||
# 清理响应:移除可能的Markdown代码块标记
|
||||
intent_response = re.sub(r'^```json\s*|\s*```$', '', intent_response).strip()
|
||||
logger.info(f"清理后响应: {intent_response}")
|
||||
intent_output = json.loads(intent_response)
|
||||
# 提取意图、改写问题和追问消息
|
||||
intents = intent_output.get("intents", [])
|
||||
user_queries = intent_output.get("user_queries", {})
|
||||
follow_up_message = intent_output.get("follow_up_message", "")
|
||||
logger.info(f"intents: {intents}||user_queries: {user_queries}||follow_up_message: {follow_up_message} ")
|
||||
|
||||
return intents, user_queries, follow_up_message
|
||||
|
||||
|
||||
# 主界面布局
|
||||
st.title("🤖 基于A2A的SmartVoyage旅行智能助手")
|
||||
st.markdown("欢迎体验智能对话!输入问题,系统将精准识别意图并提供服务。")
|
||||
|
||||
# 两栏布局:左侧对话,右侧 Agent Card
|
||||
col1, col2 = st.columns([2, 1])
|
||||
|
||||
# 左侧对话区域
|
||||
with col1:
|
||||
st.subheader("💬 对话")
|
||||
# 对话历史
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
# 输入框
|
||||
if prompt := st.chat_input("请输入您的问题..."):
|
||||
# 显示用户消息
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
st.session_state.conversation_history += f"\nUser: {prompt}"
|
||||
|
||||
# 获取 LLM 和当前日期
|
||||
llm = st.session_state.llm
|
||||
current_date = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y-%m-%d')
|
||||
|
||||
# 意图识别
|
||||
with st.spinner("正在分析您的意图..."):
|
||||
try:
|
||||
# 意图识别过程
|
||||
intents, user_queries, follow_up_message = intent_agent(prompt)
|
||||
|
||||
# 根据意图输出生成响应
|
||||
if "out_of_scope" in intents:
|
||||
# 如果意图超出范围,返回大模型直接回复
|
||||
response = follow_up_message
|
||||
st.session_state.conversation_history += f"\nAssistant: {response}"
|
||||
elif follow_up_message != "":
|
||||
# 如果有追问消息,则直接返回
|
||||
response = follow_up_message
|
||||
st.session_state.conversation_history += f"\nAssistant: {response}" # 更新历史
|
||||
else: # 处理有效意图
|
||||
responses = [] # 存储每个意图的响应列表
|
||||
routed_agents = [] # 记录路由到的代理列表
|
||||
for intent in intents:
|
||||
logger.info(f"处理意图:{intent}")
|
||||
# 根据意图确定代理名称
|
||||
if intent == "weather":
|
||||
agent_name = "WeatherQueryAssistant"
|
||||
elif intent in ["flight", "train", "concert"]:
|
||||
agent_name = "TicketQueryAssistant"
|
||||
elif intent == "order":
|
||||
agent_name = "TicketOrderAssistant"
|
||||
else:
|
||||
agent_name = None
|
||||
|
||||
# 不同意图处理方式
|
||||
if intent == "attraction":
|
||||
# 对于景点推荐,直接使用LLM生成
|
||||
chain = SmartVoyagePrompts.attraction_prompt() | llm
|
||||
rec_response = chain.invoke({"query": prompt}).content.strip()
|
||||
responses.append(rec_response)
|
||||
elif agent_name:
|
||||
# 对于代理意图,则调用代理
|
||||
# 1)获取问题
|
||||
query_str = user_queries.get(intent, {})
|
||||
logger.info(f"{agent_name} 查询:{query_str}")
|
||||
# 2)获取代理实例
|
||||
agent = st.session_state.agent_network.get_agent(agent_name)
|
||||
# 3)构建历史对话信息+新查询,然后调用代理
|
||||
chat_history = '\n'.join(st.session_state.conversation_history.split("\n")[-7:-1]) + f'\nUser: {query_str}'
|
||||
message = Message(content=TextContent(text=chat_history), role=MessageRole.USER)
|
||||
task = Task(id="task-" + str(uuid.uuid4()), message=message.to_dict())
|
||||
raw_response = asyncio.run(agent.send_task_async(task))
|
||||
logger.info(f"{agent_name} 原始响应: {raw_response}") # 记录原始响应日志
|
||||
# 4)处理结果
|
||||
if raw_response.status.state == 'completed': # 正常结果
|
||||
agent_result = raw_response.artifacts[0]['parts'][0]['text']
|
||||
else: # 异常结果
|
||||
agent_result = raw_response.status.message['content']['text']
|
||||
|
||||
# 根据代理类型总结响应
|
||||
if agent_name == "WeatherQueryAssistant":
|
||||
chain = SmartVoyagePrompts.summarize_weather_prompt() | llm
|
||||
final_response = chain.invoke(
|
||||
{"query": query_str, "raw_response": agent_result}).content.strip()
|
||||
elif agent_name == "TicketQueryAssistant":
|
||||
chain = SmartVoyagePrompts.summarize_ticket_prompt() | llm
|
||||
final_response = chain.invoke(
|
||||
{"query": query_str, "raw_response": agent_result}).content.strip()
|
||||
else:
|
||||
final_response = agent_result
|
||||
|
||||
# 5)添加到历史
|
||||
responses.append(final_response) # 添加到响应列表
|
||||
routed_agents.append(agent_name) # 记录路由代理
|
||||
else:
|
||||
# 不支持的意图
|
||||
responses.append("暂不支持此意图。")
|
||||
|
||||
response = "\n\n".join(responses)
|
||||
if routed_agents:
|
||||
logger.info(f"路由到代理:{routed_agents}")
|
||||
st.session_state.conversation_history += f"\nAssistant: {response}"
|
||||
|
||||
# 显示助手消息
|
||||
with st.chat_message("assistant"):
|
||||
st.markdown(response)
|
||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"意图识别JSON解析失败")
|
||||
error_message = f"意图识别JSON解析失败:{str(json_err)}。请重试。"
|
||||
with st.chat_message("assistant"):
|
||||
st.markdown(error_message)
|
||||
st.session_state.messages.append({"role": "assistant", "content": error_message})
|
||||
except Exception as e:
|
||||
logger.error(f"处理异常: {str(e)}")
|
||||
error_message = f"处理失败:{str(e)}。请重试。"
|
||||
with st.chat_message("assistant"):
|
||||
st.markdown(error_message)
|
||||
st.session_state.messages.append({"role": "assistant", "content": error_message})
|
||||
|
||||
# 右侧 Agent Card 区域
|
||||
with col2:
|
||||
st.subheader("🛠️ AgentCard")
|
||||
for agent_name in st.session_state.agent_network.agents.keys():
|
||||
agent_card = st.session_state.agent_network.get_agent_card(agent_name)
|
||||
agent_url = st.session_state.agent_urls.get(agent_name, "未知地址")
|
||||
with st.expander(f"Agent: {agent_name}", expanded=False):
|
||||
st.markdown(f"<div class='card-title'>技能</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-content'>{agent_card.skills}</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-title'>描述</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-content'>{agent_card.description}</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-title'>地址</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-content'>{agent_url}</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-title'>状态</div>", unsafe_allow_html=True)
|
||||
st.markdown(f"<div class='card-content'>在线</div>", unsafe_allow_html=True)
|
||||
|
||||
# 页脚
|
||||
st.markdown("---")
|
||||
st.markdown('<div class="footer">Powered by 黑马程序员 | 基于Agent2Agent的旅行助手系统 v2.0</div>', unsafe_allow_html=True)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from conf import settings
|
||||
from create_logger import logger
|
||||
|
||||
# 创建FastMCP实例
|
||||
|
||||
Reference in New Issue
Block a user