feat: mcp
This commit is contained in:
45
services/sql_service.py
Normal file
45
services/sql_service.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import mysql.connector
|
||||
import json
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from create_logger import logger
|
||||
from utils.format import DateEncoder, default_encoder
|
||||
from conf import settings
|
||||
|
||||
|
||||
class SqlService:
|
||||
def __init__(self, service_name="数据库查询"):
|
||||
# 连接数据库
|
||||
self.conn = mysql.connector.connect(
|
||||
host=settings.mysql_host,
|
||||
port=settings.mysql_port,
|
||||
user=settings.mysql_user,
|
||||
password=settings.mysql_password,
|
||||
database=settings.mysql_database
|
||||
)
|
||||
self.service_name = service_name
|
||||
|
||||
# 定义执行SQL查询方法,输入SQL字符串,返回JSON字符串
|
||||
def execute_query(self, sql: str, no_data_message="未找数据") -> str:
|
||||
try:
|
||||
cursor = self.conn.cursor(dictionary=True)
|
||||
cursor.execute(sql)
|
||||
results = cursor.fetchall()
|
||||
cursor.close()
|
||||
# 格式化结果
|
||||
for result in results: # 遍历每个结果字典
|
||||
for key, value in result.items():
|
||||
if isinstance(value, (date, datetime, timedelta, Decimal)): # 检查值是否为特殊类型
|
||||
result[key] = default_encoder(value) # 使用自定义编码器格式化该值
|
||||
# 序列化为JSON,如果有结果返回success,否则no_data;使用DateEncoder,非ASCII不转义
|
||||
return json.dumps({"status": "success", "data": results} if results else {"status": "no_data", "message": no_data_message}, cls=DateEncoder, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.service_name}错误: {str(e)}")
|
||||
# 返回错误JSON响应
|
||||
return json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
service = SqlService()
|
||||
sql = "SELECT * FROM flight_tickets WHERE departure_city = '上海' AND arrival_city = '北京' AND DATE(departure_time) = '2025-10-28' AND cabin_type = '公务舱'"
|
||||
print(service.execute_query(sql))
|
||||
Reference in New Issue
Block a user