143 lines
7.0 KiB
Python
143 lines
7.0 KiB
Python
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
from langchain.agents import create_tool_calling_agent, AgentExecutor
|
|||
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|||
|
|
from langchain_mcp_adapters.tools import load_mcp_tools
|
|||
|
|
from langchain_openai import ChatOpenAI
|
|||
|
|
from mcp import ClientSession
|
|||
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|||
|
|
|
|||
|
|
from conf import settings
|
|||
|
|
from create_logger import logger
|
|||
|
|
|
|||
|
|
# 初始化LLM
|
|||
|
|
llm = ChatOpenAI(
|
|||
|
|
model=settings.model_name,
|
|||
|
|
base_url=settings.base_url,
|
|||
|
|
api_key=settings.api_key,
|
|||
|
|
temperature=0.1
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_ticket_mcp():
|
|||
|
|
try:
|
|||
|
|
# 启动 MCP server,通过streamable建立连接
|
|||
|
|
async with streamablehttp_client("http://127.0.0.1:8001/mcp") as (read, write, _):
|
|||
|
|
# 使用读写通道创建 MCP 会话
|
|||
|
|
async with ClientSession(read, write) as session:
|
|||
|
|
try:
|
|||
|
|
await session.initialize()
|
|||
|
|
print("会话初始化成功,可以开始调用工具。")
|
|||
|
|
|
|||
|
|
# 从 session 自动获取 MCP server 提供的工具列表。
|
|||
|
|
tools = await load_mcp_tools(session)
|
|||
|
|
print(f"tools-->{tools}")
|
|||
|
|
|
|||
|
|
# 调用远程工具
|
|||
|
|
# 测试1: 查询机票
|
|||
|
|
sql_flights = "SELECT * FROM flight_tickets WHERE departure_city = '上海' AND arrival_city = '北京' AND DATE(departure_time) = '2025-10-28' AND cabin_type = '公务舱'"
|
|||
|
|
result_flights = await session.call_tool("query_tickets", {"sql": sql_flights})
|
|||
|
|
result_flights_data = json.loads(result_flights) if isinstance(result_flights, str) else result_flights
|
|||
|
|
print(f"机票查询结果:{result_flights_data}")
|
|||
|
|
|
|||
|
|
# 测试2: 查询火车票
|
|||
|
|
sql_trains = "SELECT * FROM train_tickets WHERE departure_city = '北京' AND arrival_city = '上海' AND DATE(departure_time) = '2025-10-22' AND seat_type = '二等座'"
|
|||
|
|
result_trains = await session.call_tool("query_tickets", {"sql": sql_trains})
|
|||
|
|
result_trains_data = json.loads(result_trains) if isinstance(result_trains, str) else result_trains
|
|||
|
|
print(f"火车票查询结果:{result_trains_data}")
|
|||
|
|
|
|||
|
|
# 测试3: 查询演唱会票
|
|||
|
|
sql_concerts = "SELECT * FROM concert_tickets WHERE city = '北京' AND artist = '刀郎' AND DATE(start_time) = '2025-10-31' AND ticket_type = '看台'"
|
|||
|
|
result_concerts = await session.call_tool("query_tickets", {"sql": sql_concerts})
|
|||
|
|
result_concerts_data = json.loads(result_concerts) if isinstance(result_concerts, str) else result_concerts
|
|||
|
|
print(f"演唱会票查询结果:{result_concerts_data}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"票务 MCP 测试出错:{str(e)}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"连接或会话初始化时发生错误: {e}")
|
|||
|
|
print("请确认服务端脚本已启动并运行在 http://127.0.0.1:8001/mcp")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_weather_mcp():
|
|||
|
|
try:
|
|||
|
|
# 启动 MCP server,通过streamable建立连接
|
|||
|
|
async with streamablehttp_client("http://127.0.0.1:8002/mcp") as (read, write, _):
|
|||
|
|
# 使用读写通道创建 MCP 会话
|
|||
|
|
async with ClientSession(read, write) as session:
|
|||
|
|
try:
|
|||
|
|
await session.initialize()
|
|||
|
|
print("会话初始化成功,可以开始调用工具。")
|
|||
|
|
|
|||
|
|
# 从 session 自动获取 MCP server 提供的工具列表。
|
|||
|
|
tools = await load_mcp_tools(session)
|
|||
|
|
print(f"tools-->{tools}")
|
|||
|
|
|
|||
|
|
# 测试1: 查询指定日期天气
|
|||
|
|
sql = "SELECT * FROM weather_data WHERE city = '北京' AND fx_date = '2025-10-28'"
|
|||
|
|
result = await session.call_tool("query_weather", {"sql": sql})
|
|||
|
|
result_data = json.loads(result) if isinstance(result, str) else result
|
|||
|
|
print(f"指定日期天气结果:{result_data}")
|
|||
|
|
|
|||
|
|
# 测试2: 查询未来3天天气
|
|||
|
|
sql_range = "SELECT * FROM weather_data WHERE city = '北京' AND fx_date BETWEEN '2025-10-28' AND '2025-10-30'"
|
|||
|
|
result_range = await session.call_tool("query_weather", {"sql": sql_range})
|
|||
|
|
result_range_data = json.loads(result_range) if isinstance(result_range, str) else result_range
|
|||
|
|
print(f"天气范围查询结果:{result_range_data}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"天气 MCP 测试出错:{str(e)}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"连接或会话初始化时发生错误: {e}")
|
|||
|
|
print("请确认服务端脚本已启动并运行在 http://127.0.0.1:8002/mcp")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def order_tickets(query):
|
|||
|
|
try:
|
|||
|
|
# 启动 MCP server,通过streamable建立连接
|
|||
|
|
async with streamablehttp_client("http://127.0.0.1:8003/mcp") as (read, write, _):
|
|||
|
|
# 使用读写通道创建 MCP 会话
|
|||
|
|
async with ClientSession(read, write) as session:
|
|||
|
|
try:
|
|||
|
|
await session.initialize()
|
|||
|
|
|
|||
|
|
# 从 session 自动获取 MCP server 提供的工具列表。
|
|||
|
|
tools = await load_mcp_tools(session)
|
|||
|
|
# print(f"tools-->{tools}")
|
|||
|
|
|
|||
|
|
# 创建 agent 的提示模板
|
|||
|
|
prompt = ChatPromptTemplate.from_messages([
|
|||
|
|
("system",
|
|||
|
|
"你是一个票务预定助手,能够调用工具来完成火车票、飞机票或演出票的预定。你需要仔细分析工具需要的参数,然后从用户提供的信息中提取信息。如果用户提供的信息不足以提取到调用工具所有必要参数,则向用户追问,以获取该信息。不能自己编撰参数。"),
|
|||
|
|
("human", "{input}"),
|
|||
|
|
("placeholder", "{agent_scratchpad}"),
|
|||
|
|
])
|
|||
|
|
|
|||
|
|
# 构建工具调用代理
|
|||
|
|
agent = create_tool_calling_agent(llm, tools, prompt)
|
|||
|
|
|
|||
|
|
# 创建代理执行器
|
|||
|
|
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|||
|
|
|
|||
|
|
# 代理调用
|
|||
|
|
response = await agent_executor.ainvoke({"input": query})
|
|||
|
|
|
|||
|
|
return response['output']
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.info(f"票务 MCP 测试出错:{str(e)}")
|
|||
|
|
return f"票务 MCP 查询出错:{str(e)}"
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"连接或会话初始化时发生错误: {e}")
|
|||
|
|
return "连接或会话初始化时发生错误"
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
|
|||
|
|
asyncio.run(test_ticket_mcp())
|
|||
|
|
|
|||
|
|
asyncio.run(test_weather_mcp())
|
|||
|
|
|
|||
|
|
while True:
|
|||
|
|
query = input("请输入查询:")
|
|||
|
|
if query == "exit":
|
|||
|
|
break
|
|||
|
|
print(asyncio.run(order_tickets(query)))
|