396 lines
13 KiB
Python
396 lines
13 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
CardioAI - 心血管疾病预测API服务
|
||
|
||
功能:
|
||
1. 加载预训练的机器学习模型
|
||
2. 提供RESTful API接口
|
||
3. 接收原始特征值并返回预测结果
|
||
4. 提供Web前端界面
|
||
|
||
启动方式:
|
||
conda activate cardioenv
|
||
python app.py
|
||
或
|
||
flask run
|
||
"""
|
||
|
||
from flask import Flask, request, jsonify, render_template, send_from_directory
|
||
import pandas as pd
|
||
import numpy as np
|
||
import joblib
|
||
import logging
|
||
from pathlib import Path
|
||
import sys
|
||
import os
|
||
import traceback
|
||
|
||
# Make the parent directory of this file's folder (the project root) importable.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create the Flask application
app = Flask(__name__)
# Emit non-ASCII (Chinese) text as-is in JSON responses instead of \uXXXX
# escapes. Flask < 2.3 reads the JSON_AS_ASCII config key; Flask 2.2
# deprecated it and 2.3 removed it in favour of the JSON provider's
# ``ensure_ascii`` attribute, so set both to stay correct across versions.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
    app.json.ensure_ascii = False

# Module-level model state, filled in by load_model()
model_data = None     # whole object returned by joblib.load()
feature_names = None  # value stored under 'feature_names' in the model file
pipeline = None       # model_data['pipeline'], used for predict/predict_proba
|
||
|
||
def load_model():
    """
    Load the pretrained model bundle from ``models/cardio_predictor_model.pkl``.

    The file lives in a ``models`` folder next to this script and is read with
    joblib. It is expected to be a dict containing ``'pipeline'``;
    ``'feature_names'`` and ``'model_version'`` are optional. The module
    globals ``model_data``, ``pipeline`` and ``feature_names`` are only
    updated (all together) once the bundle has been read successfully, so a
    malformed file never leaves them half-populated.

    Security note: joblib files are pickles and loading one can execute
    arbitrary code - only point this at trusted model files.

    Returns:
        bool: True if the model was loaded, False otherwise (the error and
        traceback are logged).
    """
    global model_data, feature_names, pipeline

    model_path = Path(__file__).parent / "models" / "cardio_predictor_model.pkl"

    try:
        if not model_path.exists():
            logger.error(f"模型文件不存在: {model_path}")
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        logger.info(f"正在加载模型: {model_path}")
        loaded = joblib.load(model_path)

        # Extract into locals first; globals are published below only if
        # everything here succeeded.
        loaded_pipeline = loaded['pipeline']
        loaded_features = loaded.get('feature_names')
        if loaded_features is None:
            # A bundle that stores feature_names=None would otherwise crash len().
            loaded_features = []

        logger.info(f"模型加载成功!版本: {loaded.get('model_version', '未知')}")
        logger.info(f"特征数量: {len(loaded_features)}")
        logger.info(f"特征列表: {loaded_features}")

    except Exception as e:
        # logger.exception records the message and the traceback together.
        logger.exception(f"模型加载失败: {str(e)}")
        return False

    model_data, pipeline, feature_names = loaded, loaded_pipeline, loaded_features
    return True
|
||
|
||
def preprocess_input(input_data):
    """
    Turn one raw feature dict into the single-row frame fed to the pipeline.

    Derived features:

    * ``age_years`` - ``age`` (days) / 365.25 rounded to a whole number; if
      only ``age_years`` is supplied it is used directly (cast to int).
    * ``bmi`` - weight (kg) / (height (cm) / 100)^2 rounded to 2 decimals; if
      height/weight are absent a supplied ``bmi`` is used (cast to float).

    Args:
        input_data: dict of raw feature values for one sample.

    Returns:
        pd.DataFrame: one row holding exactly the columns age_years, bmi,
        ap_hi, ap_lo, gender, cholesterol, gluc, smoke, alco, active, in that
        order (the order the original code labels as the training order).

    Raises:
        ValueError: when no age source, no height/weight or bmi, or another
        required feature is missing. Any error is logged and re-raised.
    """
    feature_order = ['age_years', 'bmi', 'ap_hi', 'ap_lo',
                     'gender', 'cholesterol', 'gluc',
                     'smoke', 'alco', 'active']

    try:
        frame = pd.DataFrame([input_data])

        # Age: days -> whole years, unless years were provided directly.
        if 'age' in frame.columns:
            frame['age_years'] = (frame['age'] / 365.25).round().astype(int)
        elif 'age_years' in frame.columns:
            frame['age_years'] = frame['age_years'].astype(int)
        else:
            raise ValueError("输入数据中必须包含'age'或'age_years'字段")

        # BMI: computed from height/weight when both exist, else taken as given.
        if all(name in frame.columns for name in ('height', 'weight')):
            height_m = frame['height'] / 100
            frame['bmi'] = (frame['weight'] / height_m ** 2).round(2)
        elif 'bmi' in frame.columns:
            frame['bmi'] = frame['bmi'].astype(float)
        else:
            raise ValueError("输入数据中必须包含'height'和'weight'字段或'bmi'字段")

        absent = [name for name in feature_order if name not in frame.columns]
        if absent:
            raise ValueError(f"缺少必要特征: {absent}")

        result = frame[feature_order].copy()
        logger.debug(f"预处理后的特征数据框:\n{result}")
        return result

    except Exception as exc:
        logger.error(f"数据预处理失败: {str(exc)}")
        raise
|
||
|
||
def validate_input(input_data):
    """
    Validate the raw prediction input and normalize it in place.

    Numeric strings such as ``"140"`` are converted to ``float`` directly in
    ``input_data``; ``predict_cardio`` relies on this because it passes the
    same dict on to ``preprocess_input`` afterwards.

    Args:
        input_data: dict of raw feature values decoded from the request JSON.

    Returns:
        tuple: ``(is_valid, message)`` - a bool and a human-readable message
        describing the first problem found, or success.
    """
    try:
        # get_json() can decode to a list/str/number; only an object can carry
        # named fields (a str would even pass the substring-based membership
        # test below).
        if not isinstance(input_data, dict):
            return False, "请求体必须是JSON对象"

        # Required raw fields
        required_fields = ['age', 'gender', 'height', 'weight',
                           'ap_hi', 'ap_lo', 'cholesterol', 'gluc',
                           'smoke', 'alco', 'active']

        missing_fields = [f for f in required_fields if f not in input_data]
        if missing_fields:
            return False, f"缺少必需字段: {missing_fields}"

        # Type check: coerce non-numeric values (e.g. "140") to float in place.
        for field in required_fields:
            value = input_data[field]
            if not isinstance(value, (int, float)):
                try:
                    input_data[field] = float(value)
                except (TypeError, ValueError):
                    # float(None) / float([...]) raise TypeError, not ValueError.
                    return False, f"字段'{field}'必须为数值类型,当前值: {value}"

        # Range check: (field, min, max), both bounds inclusive.
        validations = [
            ('age', 0, 365*150),      # age in days: 0-150 years
            ('gender', 1, 2),         # gender: 1 or 2
            ('height', 100, 250),     # height (cm): 100-250
            ('weight', 20, 300),      # weight (kg): 20-300
            ('ap_hi', 50, 300),       # systolic pressure: 50-300
            ('ap_lo', 30, 200),       # diastolic pressure: 30-200
            ('cholesterol', 1, 3),    # cholesterol level: 1-3
            ('gluc', 1, 3),           # glucose level: 1-3
            ('smoke', 0, 1),          # smoking: 0 or 1
            ('alco', 0, 1),           # alcohol: 0 or 1
            ('active', 0, 1)          # physically active: 0 or 1
        ]

        for field, min_val, max_val in validations:
            value = input_data[field]
            if not (min_val <= value <= max_val):
                return False, f"字段'{field}'的值{value}超出有效范围[{min_val}, {max_val}]"

        # Categorical codes must be whole numbers; the ranges above alone would
        # accept e.g. gender=1.5 or smoke=0.3.
        categorical_fields = ['gender', 'cholesterol', 'gluc',
                              'smoke', 'alco', 'active']
        for field in categorical_fields:
            value = input_data[field]
            if not float(value).is_integer():
                return False, f"字段'{field}'必须为整数,当前值: {value}"

        # Diastolic pressure must be below systolic pressure.
        if input_data['ap_lo'] >= input_data['ap_hi']:
            return False, "舒张压不能高于或等于收缩压"

        return True, "输入数据有效"

    except Exception as e:
        return False, f"输入数据验证失败: {str(e)}"
|
||
|
||
@app.route('/')
def index():
    """Serve the web front end by rendering the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)
|
||
|
||
def _risk_level(probability):
    """Bucket a class-1 probability: < 0.3 -> 低危, < 0.6 -> 中危, otherwise 高危."""
    if probability < 0.3:
        return "低危"
    if probability < 0.6:
        return "中危"
    return "高危"


def _positive_class_probability(processed_df):
    """
    Return the predicted probability of class 1 for the first row of processed_df.

    The probability column is located via ``pipeline.classes_`` when the
    pipeline exposes it, rather than assuming class 1 is always column 1 of
    ``predict_proba``; column 1 is used as a fallback.
    """
    proba_row = pipeline.predict_proba(processed_df)[0]
    classes = list(getattr(pipeline, 'classes_', []))
    column = classes.index(1) if 1 in classes else 1
    return proba_row[column]


@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """
    Cardiovascular disease prediction API endpoint.

    Request body (JSON):
        {
            "age": 20228,       # age in days
            "gender": 1,        # 1 = female, 2 = male
            "height": 156,      # height (cm)
            "weight": 85,       # weight (kg)
            "ap_hi": 140,       # systolic blood pressure (mmHg)
            "ap_lo": 90,        # diastolic blood pressure (mmHg)
            "cholesterol": 1,   # 1 = normal, 2 = above normal, 3 = very high
            "gluc": 1,          # glucose: 1 = normal, 2 = above normal, 3 = very high
            "smoke": 0,         # smoker: 0 = no, 1 = yes
            "alco": 0,          # alcohol intake: 0 = no, 1 = yes
            "active": 1         # physically active: 0 = no, 1 = yes
        }

    Response (JSON), HTTP 200:
        {
            "success": true,
            "prediction": 1,
            "probability": 0.85,      # P(class 1), rounded to 4 decimals
            "risk_level": "高危",      # < 0.3 低危, < 0.6 中危, else 高危
            "message": "预测成功",
            "features": {"age_years": 55, "bmi": 34.9, ...}  # processed features
        }

    Error responses: 503 if no model is loaded; 400 for a non-JSON or
    malformed body or invalid input; 500 (including the exception text) for
    any other failure.
    """
    try:
        if pipeline is None:
            return jsonify({
                "success": False,
                "message": "模型未加载,请等待或联系管理员"
            }), 503

        if not request.is_json:
            return jsonify({
                "success": False,
                "message": "请求必须是JSON格式"
            }), 400

        # silent=True returns None for an unparsable body instead of raising
        # BadRequest, which the generic handler below would turn into a 500.
        input_data = request.get_json(silent=True)
        if input_data is None:
            return jsonify({
                "success": False,
                "message": "请求体不是有效的JSON"
            }), 400
        logger.info(f"收到预测请求: {input_data}")

        # Validate (this also converts numeric strings to float in place).
        is_valid, error_message = validate_input(input_data)
        if not is_valid:
            return jsonify({
                "success": False,
                "message": error_message
            }), 400

        processed_df = preprocess_input(input_data)

        prediction = pipeline.predict(processed_df)[0]
        probability = _positive_class_probability(processed_df)

        response_data = {
            "success": True,
            "prediction": int(prediction),
            "probability": float(round(probability, 4)),
            "risk_level": _risk_level(probability),
            "message": "预测成功",
            "features": {
                "age_years": int(processed_df['age_years'].iloc[0]),
                "bmi": float(round(processed_df['bmi'].iloc[0], 2)),
                "ap_hi": int(processed_df['ap_hi'].iloc[0]),
                "ap_lo": int(processed_df['ap_lo'].iloc[0]),
                "gender": int(processed_df['gender'].iloc[0]),
                "cholesterol": int(processed_df['cholesterol'].iloc[0]),
                "gluc": int(processed_df['gluc'].iloc[0]),
                "smoke": int(processed_df['smoke'].iloc[0]),
                "alco": int(processed_df['alco'].iloc[0]),
                "active": int(processed_df['active'].iloc[0])
            }
        }

        logger.info(f"预测结果: {response_data}")
        return jsonify(response_data), 200

    except Exception as e:
        # NOTE(review): str(e) is returned to the client; consider a generic
        # message if internal details should not be exposed.
        error_msg = f"预测过程中发生错误: {str(e)}"
        logger.exception(error_msg)
        return jsonify({
            "success": False,
            "message": error_msg
        }), 500
|
||
|
||
@app.route('/health', methods=['GET'])
def health_check():
    """
    Health-check endpoint.

    Pushes one fixed sample through ``preprocess_input`` and both pipeline
    methods that ``predict_cardio`` calls (``predict`` and
    ``predict_proba``), so "healthy" means a real prediction can be served.

    Returns:
        200 with ``status: healthy``, the model version and feature count;
        503 if no model is loaded; 500 if the smoke test raises.
    """
    try:
        if pipeline is None:
            return jsonify({
                "status": "unhealthy",
                "message": "模型未加载"
            }), 503

        # Known-valid sample input
        test_data = {
            "age": 20228,
            "gender": 1,
            "height": 156,
            "weight": 85,
            "ap_hi": 140,
            "ap_lo": 90,
            "cholesterol": 1,
            "gluc": 1,
            "smoke": 0,
            "alco": 0,
            "active": 1
        }

        processed_df = preprocess_input(test_data)
        pipeline.predict(processed_df)
        # The prediction endpoint also depends on predict_proba; check it too.
        pipeline.predict_proba(processed_df)

        return jsonify({
            "status": "healthy",
            "model_version": model_data.get('model_version', '未知'),
            "features": len(feature_names) if feature_names else 0,
            "message": "模型服务运行正常"
        }), 200

    except Exception as e:
        # The response only carries str(e); log the traceback server-side.
        logger.exception("健康检查失败")
        return jsonify({
            "status": "unhealthy",
            "message": f"健康检查失败: {str(e)}"
        }), 500
|
||
|
||
@app.route('/model_info', methods=['GET'])
def model_info():
    """
    Describe the loaded model: version, description and feature names.

    Returns 503 with ``success: False`` while no model has been loaded;
    otherwise 200 with the metadata (defaults are used for missing keys).
    """
    if model_data is not None:
        names = feature_names or []
        return jsonify({
            "success": True,
            "model_version": model_data.get('model_version', '未知'),
            "description": model_data.get('description', 'CardioAI心血管疾病预测模型'),
            "feature_count": len(names),
            "features": names
        }), 200

    return jsonify({
        "success": False,
        "message": "模型未加载"
    }), 503
|
||
|
||
# Set to True once the model is in memory, so the hook below stops trying.
_model_loaded = False


@app.before_request
def ensure_model_loaded():
    """
    Lazily load the model before a request is handled.

    This covers start-up paths where the ``__main__`` block never runs (e.g.
    ``flask run``). If loading fails, it is attempted again on the next
    request (there is no back-off).
    """
    global _model_loaded

    if _model_loaded:
        return

    # The model may already be in memory (e.g. load_model() was called from
    # the __main__ block); don't read the file a second time.
    if pipeline is not None:
        _model_loaded = True
        return

    logger.info("首次请求,正在加载模型...")
    if load_model():
        _model_loaded = True
        logger.info("模型加载完成")
    else:
        logger.error("模型加载失败")
|
||
|
||
if __name__ == '__main__':
    # Load the model up front so the service refuses to start without it.
    success = load_model()
    if not success:
        logger.error("启动失败: 模型加载失败")
        sys.exit(1)
    # Record that the model is loaded so the before_request hook does not load
    # it again on the first request.
    _model_loaded = True

    # The Werkzeug interactive debugger lets a browser execute arbitrary Python
    # in this process, so it must not be enabled by default while listening on
    # all interfaces. Opt in for local development with FLASK_DEBUG=1.
    debug_mode = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')

    # Start the Flask application
    logger.info("启动CardioAI预测API服务...")
    logger.info("访问 http://localhost:5000 使用预测界面")
    logger.info("API文档:")
    logger.info(" GET / - 前端界面")
    logger.info(" POST /predict_cardio - 预测接口")
    logger.info(" GET /health - 健康检查")
    logger.info(" GET /model_info - 模型信息")

    app.run(host='0.0.0.0', port=5000, debug=debug_mode)