143 lines
7.0 KiB
Python
143 lines
7.0 KiB
Python
import asyncio
|
||
import json
|
||
|
||
from langchain.agents import create_tool_calling_agent, AgentExecutor
|
||
from langchain_core.prompts import ChatPromptTemplate
|
||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||
from langchain_openai import ChatOpenAI
|
||
from mcp import ClientSession
|
||
from mcp.client.streamable_http import streamablehttp_client
|
||
|
||
from conf import settings
|
||
from create_logger import logger
|
||
|
||
# 初始化LLM
|
||
llm = ChatOpenAI(
|
||
model=settings.model_name,
|
||
base_url=settings.base_url,
|
||
api_key=settings.api_key,
|
||
temperature=0.1
|
||
)
|
||
|
||
|
||
async def test_ticket_mcp():
|
||
try:
|
||
# 启动 MCP server,通过streamable建立连接
|
||
async with streamablehttp_client("http://127.0.0.1:8001/mcp") as (read, write, _):
|
||
# 使用读写通道创建 MCP 会话
|
||
async with ClientSession(read, write) as session:
|
||
try:
|
||
await session.initialize()
|
||
print("会话初始化成功,可以开始调用工具。")
|
||
|
||
# 从 session 自动获取 MCP server 提供的工具列表。
|
||
tools = await load_mcp_tools(session)
|
||
print(f"tools-->{tools}")
|
||
|
||
# 调用远程工具
|
||
# 测试1: 查询机票
|
||
sql_flights = "SELECT * FROM flight_tickets WHERE departure_city = '上海' AND arrival_city = '北京' AND DATE(departure_time) = '2025-10-28' AND cabin_type = '公务舱'"
|
||
result_flights = await session.call_tool("query_tickets", {"sql": sql_flights})
|
||
result_flights_data = json.loads(result_flights) if isinstance(result_flights, str) else result_flights
|
||
print(f"机票查询结果:{result_flights_data}")
|
||
|
||
# 测试2: 查询火车票
|
||
sql_trains = "SELECT * FROM train_tickets WHERE departure_city = '北京' AND arrival_city = '上海' AND DATE(departure_time) = '2025-10-22' AND seat_type = '二等座'"
|
||
result_trains = await session.call_tool("query_tickets", {"sql": sql_trains})
|
||
result_trains_data = json.loads(result_trains) if isinstance(result_trains, str) else result_trains
|
||
print(f"火车票查询结果:{result_trains_data}")
|
||
|
||
# 测试3: 查询演唱会票
|
||
sql_concerts = "SELECT * FROM concert_tickets WHERE city = '北京' AND artist = '刀郎' AND DATE(start_time) = '2025-10-31' AND ticket_type = '看台'"
|
||
result_concerts = await session.call_tool("query_tickets", {"sql": sql_concerts})
|
||
result_concerts_data = json.loads(result_concerts) if isinstance(result_concerts, str) else result_concerts
|
||
print(f"演唱会票查询结果:{result_concerts_data}")
|
||
except Exception as e:
|
||
print(f"票务 MCP 测试出错:{str(e)}")
|
||
except Exception as e:
|
||
print(f"连接或会话初始化时发生错误: {e}")
|
||
print("请确认服务端脚本已启动并运行在 http://127.0.0.1:8001/mcp")
|
||
|
||
|
||
async def test_weather_mcp():
|
||
try:
|
||
# 启动 MCP server,通过streamable建立连接
|
||
async with streamablehttp_client("http://127.0.0.1:8002/mcp") as (read, write, _):
|
||
# 使用读写通道创建 MCP 会话
|
||
async with ClientSession(read, write) as session:
|
||
try:
|
||
await session.initialize()
|
||
print("会话初始化成功,可以开始调用工具。")
|
||
|
||
# 从 session 自动获取 MCP server 提供的工具列表。
|
||
tools = await load_mcp_tools(session)
|
||
print(f"tools-->{tools}")
|
||
|
||
# 测试1: 查询指定日期天气
|
||
sql = "SELECT * FROM weather_data WHERE city = '北京' AND fx_date = '2025-10-28'"
|
||
result = await session.call_tool("query_weather", {"sql": sql})
|
||
result_data = json.loads(result) if isinstance(result, str) else result
|
||
print(f"指定日期天气结果:{result_data}")
|
||
|
||
# 测试2: 查询未来3天天气
|
||
sql_range = "SELECT * FROM weather_data WHERE city = '北京' AND fx_date BETWEEN '2025-10-28' AND '2025-10-30'"
|
||
result_range = await session.call_tool("query_weather", {"sql": sql_range})
|
||
result_range_data = json.loads(result_range) if isinstance(result_range, str) else result_range
|
||
print(f"天气范围查询结果:{result_range_data}")
|
||
except Exception as e:
|
||
print(f"天气 MCP 测试出错:{str(e)}")
|
||
except Exception as e:
|
||
print(f"连接或会话初始化时发生错误: {e}")
|
||
print("请确认服务端脚本已启动并运行在 http://127.0.0.1:8002/mcp")
|
||
|
||
|
||
async def order_tickets(query):
|
||
try:
|
||
# 启动 MCP server,通过streamable建立连接
|
||
async with streamablehttp_client("http://127.0.0.1:8003/mcp") as (read, write, _):
|
||
# 使用读写通道创建 MCP 会话
|
||
async with ClientSession(read, write) as session:
|
||
try:
|
||
await session.initialize()
|
||
|
||
# 从 session 自动获取 MCP server 提供的工具列表。
|
||
tools = await load_mcp_tools(session)
|
||
# print(f"tools-->{tools}")
|
||
|
||
# 创建 agent 的提示模板
|
||
prompt = ChatPromptTemplate.from_messages([
|
||
("system",
|
||
"你是一个票务预定助手,能够调用工具来完成火车票、飞机票或演出票的预定。你需要仔细分析工具需要的参数,然后从用户提供的信息中提取信息。如果用户提供的信息不足以提取到调用工具所有必要参数,则向用户追问,以获取该信息。不能自己编撰参数。"),
|
||
("human", "{input}"),
|
||
("placeholder", "{agent_scratchpad}"),
|
||
])
|
||
|
||
# 构建工具调用代理
|
||
agent = create_tool_calling_agent(llm, tools, prompt)
|
||
|
||
# 创建代理执行器
|
||
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
||
|
||
# 代理调用
|
||
response = await agent_executor.ainvoke({"input": query})
|
||
|
||
return response['output']
|
||
except Exception as e:
|
||
logger.info(f"票务 MCP 测试出错:{str(e)}")
|
||
return f"票务 MCP 查询出错:{str(e)}"
|
||
except Exception as e:
|
||
logger.error(f"连接或会话初始化时发生错误: {e}")
|
||
return "连接或会话初始化时发生错误"
|
||
|
||
|
||
if __name__ == "__main__":
|
||
|
||
asyncio.run(test_ticket_mcp())
|
||
|
||
asyncio.run(test_weather_mcp())
|
||
|
||
while True:
|
||
query = input("请输入查询:")
|
||
if query == "exit":
|
||
break
|
||
print(asyncio.run(order_tickets(query))) |