Files
AIcode/test/module2_predictor/app.py

396 lines
13 KiB
Python
Raw Normal View History

2026-04-02 19:52:38 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CardioAI - 心血管疾病预测API服务
功能
1. 加载预训练的机器学习模型
2. 提供RESTful API接口
3. 接收原始特征值并返回预测结果
4. 提供Web前端界面
启动方式
conda activate cardioenv
python app.py
flask run
"""
from flask import Flask, request, jsonify, render_template, send_from_directory
import pandas as pd
import numpy as np
import joblib
import logging
from pathlib import Path
import sys
import os
import traceback
# Make the project root (two levels above this file) importable.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Configure root logging once for the whole service.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the Flask application
app = Flask(__name__)
# Emit non-ASCII (Chinese) text as-is in JSON responses instead of \uXXXX escapes.
# Flask < 2.2 reads the JSON_AS_ASCII config key; Flask 2.2+ uses the JSON
# provider's ``ensure_ascii`` attribute instead (the config key was removed in 2.3),
# so set both to stay correct across versions.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
    app.json.ensure_ascii = False

# Module-level model state, populated by load_model()
model_data = None  # dict-like bundle returned by joblib.load
feature_names = None  # feature names stored in the bundle (if any)
pipeline = None  # fitted model used for predict / predict_proba
def load_model():
    """
    Load the pre-trained model bundle from ``models/cardio_predictor_model.pkl``.

    The bundle must be dict-like with a ``'pipeline'`` entry; ``'feature_names'``
    and ``'model_version'`` are optional. The module globals ``model_data``,
    ``pipeline`` and ``feature_names`` are only updated after everything has been
    read successfully, so a failed load never leaves them half-initialised.

    Returns:
        bool: True on success, False on any error (details are logged).
    """
    global model_data, feature_names, pipeline
    try:
        # Model file lives in a "models" directory next to this script
        model_dir = Path(__file__).parent / "models"
        model_path = model_dir / "cardio_predictor_model.pkl"
        if not model_path.exists():
            logger.error(f"模型文件不存在: {model_path}")
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        logger.info(f"正在加载模型: {model_path}")
        # NOTE(security): joblib.load unpickles arbitrary objects; only ever
        # point this at trusted, locally produced model files.
        loaded_data = joblib.load(model_path)
        loaded_pipeline = loaded_data['pipeline']
        # Normalise to a plain list: a None value or an ndarray/Index would
        # break len(), truthiness checks and JSON serialisation.
        raw_feature_names = loaded_data.get('feature_names')
        loaded_feature_names = (
            list(raw_feature_names) if raw_feature_names is not None else []
        )
    except Exception as e:
        logger.error(f"模型加载失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False

    # Commit all three globals together, only after a fully successful load.
    model_data, pipeline, feature_names = (
        loaded_data, loaded_pipeline, loaded_feature_names
    )
    logger.info(f"模型加载成功!版本: {model_data.get('model_version', '未知')}")
    logger.info(f"特征数量: {len(feature_names)}")
    logger.info(f"特征列表: {feature_names}")
    return True
def preprocess_input(input_data):
    """
    Turn a raw feature dict into the one-row frame the model consumes,
    applying the same feature engineering as at training time.

    Args:
        input_data: dict of raw features. Needs either ``age`` (days) or
            ``age_years``; either ``height`` (cm) and ``weight`` (kg) or
            ``bmi``; plus ``ap_hi``, ``ap_lo``, ``gender``, ``cholesterol``,
            ``gluc``, ``smoke``, ``alco`` and ``active``.

    Returns:
        pd.DataFrame: a single row whose columns follow ``required_features``.

    Raises:
        ValueError: if a required input is missing or ``height`` is not positive.
    """
    try:
        df = pd.DataFrame([input_data])

        # 1. Age: convert days to whole years.
        #    (pandas .round() rounds halves to even, not half-up.)
        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)
        elif 'age_years' in df.columns:
            # Already converted by the caller; just coerce to int
            df['age_years'] = df['age_years'].astype(int)
        else:
            raise ValueError("输入数据中必须包含'age'或'age_years'字段")

        # 2. BMI = weight(kg) / height(m)^2
        if 'height' in df.columns and 'weight' in df.columns:
            # A zero/negative height would otherwise yield an inf/meaningless BMI
            # that reaches the model silently.
            if (df['height'] <= 0).any():
                raise ValueError("身高必须为正数")
            df['bmi'] = (df['weight'] / ((df['height'] / 100) ** 2)).round(2)
        elif 'bmi' in df.columns:
            # BMI supplied directly by the caller
            df['bmi'] = df['bmi'].astype(float)
        else:
            raise ValueError("输入数据中必须包含'height'和'weight'字段或'bmi'字段")

        # 3. Make sure every model feature is present
        required_features = ['age_years', 'bmi', 'ap_hi', 'ap_lo',
                             'gender', 'cholesterol', 'gluc',
                             'smoke', 'alco', 'active']
        missing_features = [f for f in required_features if f not in df.columns]
        if missing_features:
            raise ValueError(f"缺少必要特征: {missing_features}")

        # 4. Select the model features in their training-time order
        processed_df = df[required_features].copy()
        logger.debug(f"预处理后的特征数据框:\n{processed_df}")
        return processed_df
    except Exception as e:
        logger.error(f"数据预处理失败: {str(e)}")
        raise
def validate_input(input_data):
    """
    Validate, and normalise in place, a raw prediction request payload.

    Values that are not already int/float but parse as numbers (e.g. the
    string "170") are replaced by ``float`` values *in the given dict*.
    The prediction endpoint relies on this, because it passes the same dict
    on to preprocessing.

    Args:
        input_data: the decoded JSON body; must be a dict of raw features.

    Returns:
        tuple: ``(is_valid, message)``. ``message`` describes the first problem
        found, or is "输入数据有效" when the payload is valid.
    """
    try:
        # request.get_json() can yield a list, string, number or None; the
        # membership tests below are only meaningful for a mapping.
        if not isinstance(input_data, dict):
            return False, "输入数据必须是JSON对象"

        # Required fields
        required_fields = ['age', 'gender', 'height', 'weight',
                           'ap_hi', 'ap_lo', 'cholesterol', 'gluc',
                           'smoke', 'alco', 'active']
        missing_fields = [f for f in required_fields if f not in input_data]
        if missing_fields:
            return False, f"缺少必需字段: {missing_fields}"

        # Types: coerce numeric-looking values to float in place
        for field in required_fields:
            value = input_data[field]
            if not isinstance(value, (int, float)):
                try:
                    input_data[field] = float(value)
                except (TypeError, ValueError):
                    # TypeError covers null / list / object values
                    return False, f"字段'{field}'必须为数值类型,当前值: {value}"

        # Value ranges (inclusive)
        validations = [
            ('age', 0, 365*150),    # age in days, 0-150 years
            ('gender', 1, 2),       # gender: 1 or 2
            ('height', 100, 250),   # height (cm): 100-250
            ('weight', 20, 300),    # weight (kg): 20-300
            ('ap_hi', 50, 300),     # systolic pressure: 50-300
            ('ap_lo', 30, 200),     # diastolic pressure: 30-200
            ('cholesterol', 1, 3),  # cholesterol level: 1-3
            ('gluc', 1, 3),         # glucose level: 1-3
            ('smoke', 0, 1),        # smoker: 0 or 1
            ('alco', 0, 1),         # alcohol intake: 0 or 1
            ('active', 0, 1)        # physically active: 0 or 1
        ]
        for field, min_val, max_val in validations:
            value = input_data[field]
            if not (min_val <= value <= max_val):
                return False, f"字段'{field}'的值{value}超出有效范围[{min_val}, {max_val}]"

        # Coded / binary fields must be whole numbers; a value such as 0.5
        # would pass the range check but is not a valid category.
        categorical_fields = ['gender', 'cholesterol', 'gluc',
                              'smoke', 'alco', 'active']
        for field in categorical_fields:
            value = input_data[field]
            if not float(value).is_integer():
                return False, f"字段'{field}'必须为整数,当前值: {value}"

        # Blood-pressure sanity check
        if input_data['ap_lo'] >= input_data['ap_hi']:
            return False, "舒张压不能高于或等于收缩压"

        return True, "输入数据有效"
    except Exception as e:
        return False, f"输入数据验证失败: {str(e)}"
@app.route('/')
def index():
    """Serve the web front end by rendering the ``index.html`` template."""
    html = render_template('index.html')
    return html
@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """
    Cardiovascular-disease prediction API endpoint.

    Request body (JSON):
    {
        "age": 20228,        # age in days
        "gender": 1,         # 1 = female, 2 = male
        "height": 156,       # height (cm)
        "weight": 85,        # weight (kg)
        "ap_hi": 140,        # systolic blood pressure (mmHg)
        "ap_lo": 90,         # diastolic blood pressure (mmHg)
        "cholesterol": 1,    # 1 = normal, 2 = above normal, 3 = well above normal
        "gluc": 1,           # 1 = normal, 2 = above normal, 3 = well above normal
        "smoke": 0,          # smoker (0 = no, 1 = yes)
        "alco": 0,           # alcohol intake (0 = no, 1 = yes)
        "active": 1          # physically active (0 = no, 1 = yes)
    }

    Response body (JSON):
    {
        "success": true,
        "prediction": 1,
        "probability": 0.85,
        "risk_level": "高危",
        "message": "预测成功",
        "features": {
            "age_years": 55,
            "bmi": 34.9,
            ...  // remaining processed features
        }
    }

    Status codes: 200 on success; 400 for a non-JSON, malformed or invalid
    body; 503 when the model is not loaded; 500 for unexpected errors.
    """
    try:
        # The model must be loaded before we can predict
        if pipeline is None:
            return jsonify({
                "success": False,
                "message": "模型未加载,请等待或联系管理员"
            }), 503

        if not request.is_json:
            return jsonify({
                "success": False,
                "message": "请求必须是JSON格式"
            }), 400

        # silent=True returns None for an unparseable body instead of raising
        # BadRequest, which the generic handler below would turn into a 500.
        input_data = request.get_json(silent=True)
        if input_data is None:
            return jsonify({
                "success": False,
                "message": "请求体不是有效的JSON"
            }), 400
        logger.info(f"收到预测请求: {input_data}")

        # Validate (also coerces numeric strings in place)
        is_valid, error_message = validate_input(input_data)
        if not is_valid:
            return jsonify({
                "success": False,
                "message": error_message
            }), 400

        processed_df = preprocess_input(input_data)

        prediction = pipeline.predict(processed_df)[0]
        probability = pipeline.predict_proba(processed_df)[0][1]  # probability of class 1

        # Map probability to a risk band
        if probability < 0.3:
            risk_level = "低危"
        elif probability < 0.6:
            risk_level = "中危"
        else:
            risk_level = "高危"

        # Echo the processed features back to the client
        features = {
            "age_years": int(processed_df['age_years'].iloc[0]),
            "bmi": float(round(processed_df['bmi'].iloc[0], 2)),
        }
        for column in ('ap_hi', 'ap_lo', 'gender', 'cholesterol',
                       'gluc', 'smoke', 'alco', 'active'):
            features[column] = int(processed_df[column].iloc[0])

        response_data = {
            "success": True,
            "prediction": int(prediction),
            "probability": float(round(probability, 4)),
            "risk_level": risk_level,
            "message": "预测成功",
            "features": features
        }
        logger.info(f"预测结果: {response_data}")
        return jsonify(response_data), 200
    except Exception as e:
        error_msg = f"预测过程中发生错误: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            "success": False,
            "message": error_msg
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """
    Health-check endpoint.

    Runs preprocessing plus both model calls used by the prediction endpoint
    (``predict`` and ``predict_proba``) on a fixed sample. A broken model is
    therefore reported here rather than on the first real request.

    Returns:
        (Response, int): JSON status with HTTP 200 when healthy, 503 when no
        model is loaded, 500 when the smoke test fails.
    """
    try:
        if pipeline is None:
            return jsonify({
                "status": "unhealthy",
                "message": "模型未加载"
            }), 503

        # Fixed, in-range sample input
        test_data = {
            "age": 20228,
            "gender": 1,
            "height": 156,
            "weight": 85,
            "ap_hi": 140,
            "ap_lo": 90,
            "cholesterol": 1,
            "gluc": 1,
            "smoke": 0,
            "alco": 0,
            "active": 1
        }
        processed_df = preprocess_input(test_data)
        pipeline.predict(processed_df)
        # Prediction responses also need class probabilities
        pipeline.predict_proba(processed_df)

        # `is not None` instead of truthiness: array-like names have no
        # unambiguous truth value.
        feature_count = len(feature_names) if feature_names is not None else 0
        return jsonify({
            "status": "healthy",
            "model_version": model_data.get('model_version', '未知'),
            "features": feature_count,
            "message": "模型服务运行正常"
        }), 200
    except Exception as e:
        # Log the traceback so operators can diagnose a failing probe.
        logger.error(f"健康检查失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "status": "unhealthy",
            "message": f"健康检查失败: {str(e)}"
        }), 500
@app.route('/model_info', methods=['GET'])
def model_info():
    """
    Return metadata about the loaded model as JSON.

    Returns:
        (Response, int): 200 with version, description and feature list, or
        503 when no model has been loaded yet.
    """
    if model_data is None:
        return jsonify({
            "success": False,
            "message": "模型未加载"
        }), 503

    # Convert to a plain list: an ndarray / pandas Index has no unambiguous
    # truth value and is not serialisable by jsonify's default provider.
    features = list(feature_names) if feature_names is not None else []
    return jsonify({
        "success": True,
        "model_version": model_data.get('model_version', '未知'),
        "description": model_data.get('description', 'CardioAI心血管疾病预测模型'),
        "feature_count": len(features),
        "features": features
    }), 200
# Set once a model is known to be loaded; until then every request retries loading
_model_loaded = False


@app.before_request
def ensure_model_loaded():
    """
    Lazily load the model before the first request is handled.

    If a model has already been loaded, this only records that fact. That
    covers the ``__main__`` entry point, which calls ``load_model()`` before
    serving. If loading fails, the attempt is repeated on the next request.
    """
    global _model_loaded
    if _model_loaded:
        return
    if pipeline is not None:
        # Loaded elsewhere already; don't load the model a second time.
        _model_loaded = True
        return

    logger.info("首次请求,正在加载模型...")
    if load_model():
        _model_loaded = True
        logger.info("模型加载完成")
    else:
        logger.error("模型加载失败")
if __name__ == '__main__':
    # Load the model eagerly so a missing or broken model fails fast at startup
    success = load_model()
    if not success:
        logger.error("启动失败: 模型加载失败")
        sys.exit(1)

    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be on by default while listening on all interfaces.
    # Opt in with FLASK_DEBUG=1, using the same parsing rule as Flask itself.
    flask_debug = os.environ.get('FLASK_DEBUG', '')
    debug_mode = bool(flask_debug) and flask_debug.lower() not in ('0', 'false', 'no')

    # Start the Flask application
    logger.info("启动CardioAI预测API服务...")
    logger.info("访问 http://localhost:5000 使用预测界面")
    logger.info("API文档:")
    logger.info(" GET / - 前端界面")
    logger.info(" POST /predict_cardio - 预测接口")
    logger.info(" GET /health - 健康检查")
    logger.info(" GET /model_info - 模型信息")
    app.run(host='0.0.0.0', port=5000, debug=debug_mode)