Files
AIcode/test/module2_predictor/app.py
2026-04-02 19:52:38 +08:00

396 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
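"""
CardioAI - cardiovascular disease prediction API service.

Features:
1. Loads a pre-trained machine-learning model.
2. Exposes RESTful API endpoints.
3. Accepts raw feature values and returns prediction results.
4. Serves a web front-end.

How to start (the two commands below appear to be alternatives):
    conda activate cardioenv
    python app.py
    flask run
"""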
from flask import Flask, request, jsonify, render_template, send_from_directory
import pandas as pd
import numpy as np
import joblib
import logging
from pathlib import Path
import sys
import os
import traceback
# Append the directory two levels above this file (treated as the project
# root) to the Python import path.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Configure process-wide logging.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create the Flask application.
app = Flask(__name__)
# Emit non-ASCII (Chinese) text verbatim in JSON responses.
# The JSON_AS_ASCII config key is only honoured by older Flask releases; newer
# releases take the setting from the app's JSON provider, so set both.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json'):
    app.json.ensure_ascii = False

# Module-level state holding the loaded model bundle and its metadata.
# All three stay None until load_model() succeeds.
model_data = None
feature_names = None
pipeline = None
def load_model():
    """Load the pre-trained model bundle from disk into the module globals.

    Reads ``models/cardio_predictor_model.pkl`` (relative to this file) with
    ``joblib.load``. The loaded object must support ``['pipeline']`` and
    ``.get('feature_names')`` / ``.get('model_version')``.

    The globals ``model_data``, ``pipeline`` and ``feature_names`` are rebound
    only after every piece has been read successfully, so a failed load never
    leaves them in a half-updated state.

    Returns:
        bool: True if the model was loaded; False on any failure. Failures
        are logged together with the traceback and are never raised.
    """
    global model_data, feature_names, pipeline
    try:
        model_path = Path(__file__).parent / "models" / "cardio_predictor_model.pkl"
        if not model_path.exists():
            logger.error(f"模型文件不存在: {model_path}")
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        logger.info(f"正在加载模型: {model_path}")
        # SECURITY: joblib.load unpickles the file and can therefore execute
        # arbitrary code; the model file must come from a trusted source.
        loaded = joblib.load(model_path)

        # Read everything into locals first; any error here is handled below
        # before the globals are touched.
        loaded_pipeline = loaded['pipeline']
        loaded_features = loaded.get('feature_names', [])
        version = loaded.get('model_version', '未知')
        feature_count = len(loaded_features)

        # Commit the fully loaded bundle in one step.
        model_data, pipeline, feature_names = loaded, loaded_pipeline, loaded_features

        logger.info(f"模型加载成功!版本: {version}")
        logger.info(f"特征数量: {feature_count}")
        logger.info(f"特征列表: {feature_names}")
        return True
    except Exception as e:
        # Top-level boundary: report the failure to the caller as False.
        logger.error(f"模型加载失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
def preprocess_input(input_data):
    """Turn one raw input record into the feature frame passed to the model.

    Steps:
    1. Derive ``age_years`` from ``age`` (divided by 365.25 and rounded), or
       use a supplied ``age_years`` directly.
    2. Derive ``bmi`` as weight / (height / 100) ** 2 rounded to 2 decimals,
       or use a supplied ``bmi`` directly.
    3. Check that all required features are present.
    4. Return only the required columns, in a fixed order.

    Args:
        input_data: dict of raw feature values for a single record.

    Returns:
        pd.DataFrame: a one-row frame with the columns ``age_years``, ``bmi``,
        ``ap_hi``, ``ap_lo``, ``gender``, ``cholesterol``, ``gluc``,
        ``smoke``, ``alco``, ``active``.

    Raises:
        ValueError: if age or BMI cannot be derived, if ``height`` is not
            positive, or if a required feature is missing. Any exception is
            logged before being re-raised.
    """
    # Column order of the returned frame (the original comment states this is
    # the order used at training time).
    required_features = ['age_years', 'bmi', 'ap_hi', 'ap_lo',
                         'gender', 'cholesterol', 'gluc',
                         'smoke', 'alco', 'active']
    try:
        df = pd.DataFrame([input_data])

        # 1. Age: convert 'age' to whole years, or accept 'age_years' as given.
        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)
        elif 'age_years' in df.columns:
            df['age_years'] = df['age_years'].astype(int)
        else:
            raise ValueError("输入数据中必须包含'age'或'age_years'字段")

        # 2. BMI = weight / (height / 100) ** 2
        if 'height' in df.columns and 'weight' in df.columns:
            # A zero/negative/NaN height would silently yield an inf or NaN
            # BMI instead of an error, so reject it explicitly.
            if not (df['height'] > 0).all():
                raise ValueError("字段'height'必须大于0")
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)
            df['bmi'] = df['bmi'].round(2)
        elif 'bmi' in df.columns:
            df['bmi'] = df['bmi'].astype(float)
        else:
            raise ValueError("输入数据中必须包含'height'和'weight'字段或'bmi'字段")

        # 3. Every required feature must be present.
        missing_features = [f for f in required_features if f not in df.columns]
        if missing_features:
            raise ValueError(f"缺少必要特征: {missing_features}")

        # 4. Keep only the model features, in the fixed order.
        processed_df = df[required_features].copy()
        # Lazy %-formatting: the frame is only rendered when DEBUG is enabled.
        logger.debug("预处理后的特征数据框:\n%s", processed_df)
        return processed_df
    except Exception as e:
        logger.error(f"数据预处理失败: {str(e)}")
        raise
def validate_input(input_data):
    """Validate a raw prediction request.

    Checks, in order: that the input is a dict, that all required fields are
    present, that every value is numeric (non-numeric values are converted
    with ``float()`` and written back into ``input_data``), that each value
    lies in its allowed range, that categorical fields hold whole numbers,
    and that ``ap_lo`` is strictly lower than ``ap_hi``.

    Args:
        input_data: dict of raw feature values. It is modified in place when
            a value has to be converted to ``float``.

    Returns:
        tuple: ``(is_valid, message)``. ``message`` describes the first
        problem found, or confirms validity. This function does not raise.
    """
    required_fields = ['age', 'gender', 'height', 'weight',
                       'ap_hi', 'ap_lo', 'cholesterol', 'gluc',
                       'smoke', 'alco', 'active']
    # (field, minimum, maximum, must_be_whole_number)
    validations = [
        ('age', 0, 365 * 150, False),   # age, 0 to 365*150
        ('gender', 1, 2, True),         # gender: 1 or 2
        ('height', 100, 250, False),    # height, 100-250
        ('weight', 20, 300, False),     # weight, 20-300
        ('ap_hi', 50, 300, False),      # systolic pressure, 50-300
        ('ap_lo', 30, 200, False),      # diastolic pressure, 30-200
        ('cholesterol', 1, 3, True),    # cholesterol level: 1-3
        ('gluc', 1, 3, True),           # glucose level: 1-3
        ('smoke', 0, 1, True),          # smoking: 0 or 1
        ('alco', 0, 1, True),           # alcohol: 0 or 1
        ('active', 0, 1, True),         # physical activity: 0 or 1
    ]
    try:
        if not isinstance(input_data, dict):
            return False, "输入数据必须是JSON对象"

        missing_fields = [f for f in required_fields if f not in input_data]
        if missing_fields:
            return False, f"缺少必需字段: {missing_fields}"

        # Coerce non-numeric values (e.g. numeric strings) to float in place.
        for field in required_fields:
            value = input_data[field]
            if not isinstance(value, (int, float)):
                try:
                    input_data[field] = float(value)
                except (TypeError, ValueError):
                    # TypeError covers None / lists / dicts, which float()
                    # rejects with a different exception than bad strings.
                    return False, f"字段'{field}'必须为数值类型,当前值: {value}"

        for field, min_val, max_val, whole_number in validations:
            value = input_data[field]
            # NaN fails this comparison and is therefore rejected here too.
            if not (min_val <= value <= max_val):
                return False, f"字段'{field}'的值{value}超出有效范围[{min_val}, {max_val}]"
            # Categorical codes must be whole numbers: 1.5 is not a gender.
            if whole_number and not float(value).is_integer():
                return False, f"字段'{field}'必须为整数,当前值: {value}"

        # Diastolic pressure must be strictly below systolic pressure.
        if input_data['ap_lo'] >= input_data['ap_hi']:
            return False, "舒张压不能高于或等于收缩压"

        return True, "输入数据有效"
    except Exception as e:
        # Last-resort guard so callers always get a (bool, message) tuple.
        return False, f"输入数据验证失败: {str(e)}"
@app.route('/')
def index():
    """Home page: render and return the ``index.html`` front-end template."""
    page = render_template('index.html')
    return page
@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """Cardiovascular disease prediction endpoint.

    Request body (JSON object), as documented by the original author:
        {
            "age": 20228,        # age in days
            "gender": 1,         # 1 = female, 2 = male
            "height": 156,       # height in cm
            "weight": 85,        # weight in kg
            "ap_hi": 140,        # systolic blood pressure (mmHg)
            "ap_lo": 90,         # diastolic blood pressure (mmHg)
            "cholesterol": 1,    # 1 = normal, 2 = above normal, 3 = well above
            "gluc": 1,           # 1 = normal, 2 = above normal, 3 = well above
            "smoke": 0,          # 0 = no, 1 = yes
            "alco": 0,           # 0 = no, 1 = yes
            "active": 1          # 0 = no, 1 = yes
        }

    Successful response (JSON, HTTP 200):
        {
            "success": true,
            "prediction": 1,
            "probability": 0.85,
            "risk_level": "高危",
            "message": "预测成功",
            "features": {"age_years": 55, "bmi": 34.9, ...}
        }

    Error responses carry ``{"success": false, "message": ...}`` with:
        503 - the model is not loaded;
        400 - the body is not JSON, is not a valid JSON object, or fails
              ``validate_input``;
        500 - any unexpected error during preprocessing or prediction.
    """
    try:
        # The model must be available before anything else is attempted.
        if pipeline is None:
            return jsonify({
                "success": False,
                "message": "模型未加载,请等待或联系管理员"
            }), 503

        if not request.is_json:
            return jsonify({
                "success": False,
                "message": "请求必须是JSON格式"
            }), 400

        # silent=True makes malformed JSON yield None instead of raising
        # BadRequest, which the broad handler below would report as a 500.
        input_data = request.get_json(silent=True)
        if not isinstance(input_data, dict):
            return jsonify({
                "success": False,
                "message": "请求体必须是有效的JSON对象"
            }), 400
        logger.info(f"收到预测请求: {input_data}")

        # validate_input may convert values in input_data to float in place.
        is_valid, error_message = validate_input(input_data)
        if not is_valid:
            return jsonify({
                "success": False,
                "message": error_message
            }), 400

        processed_df = preprocess_input(input_data)

        prediction = pipeline.predict(processed_df)[0]
        # Column 1 of predict_proba is the probability of class 1.
        probability = pipeline.predict_proba(processed_df)[0][1]

        # Map the probability to a risk level: <0.3 low, <0.6 medium, else high.
        if probability < 0.3:
            risk_level = "低危"
        elif probability < 0.6:
            risk_level = "中危"
        else:
            risk_level = "高危"

        # Echo the processed features as plain Python numbers so that they
        # are JSON-serialisable.
        features = {
            "age_years": int(processed_df['age_years'].iloc[0]),
            "bmi": float(round(processed_df['bmi'].iloc[0], 2)),
        }
        for column in ('ap_hi', 'ap_lo', 'gender', 'cholesterol',
                       'gluc', 'smoke', 'alco', 'active'):
            features[column] = int(processed_df[column].iloc[0])

        response_data = {
            "success": True,
            "prediction": int(prediction),
            "probability": float(round(probability, 4)),
            "risk_level": risk_level,
            "message": "预测成功",
            "features": features
        }
        logger.info(f"预测结果: {response_data}")
        return jsonify(response_data), 200
    except Exception as e:
        # Top-level boundary for this endpoint: log and answer with a 500.
        error_msg = f"预测过程中发生错误: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            "success": False,
            "message": error_msg
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint.

    Responds 503 when no model is loaded. Otherwise pushes one fixed sample
    record through ``preprocess_input`` and ``pipeline.predict`` and responds
    200 with the model version and feature count; any exception along the way
    produces a 500 with the error text.
    """
    try:
        if pipeline is None:
            not_loaded = {
                "status": "unhealthy",
                "message": "模型未加载"
            }
            return jsonify(not_loaded), 503

        # Smoke test: run a fixed sample through the whole prediction path.
        sample = dict(
            age=20228,
            gender=1,
            height=156,
            weight=85,
            ap_hi=140,
            ap_lo=90,
            cholesterol=1,
            gluc=1,
            smoke=0,
            alco=0,
            active=1,
        )
        sample_frame = preprocess_input(sample)
        pipeline.predict(sample_frame)  # result is irrelevant, only success matters

        report = {
            "status": "healthy",
            "model_version": model_data.get('model_version', '未知'),
            "features": len(feature_names) if feature_names else 0,
            "message": "模型服务运行正常"
        }
        return jsonify(report), 200
    except Exception as exc:
        failure = {
            "status": "unhealthy",
            "message": f"健康检查失败: {str(exc)}"
        }
        return jsonify(failure), 500
@app.route('/model_info', methods=['GET'])
def model_info():
    """Return metadata about the loaded model.

    Responds 503 when no model is loaded; otherwise 200 with the model
    version, description, feature count and feature list.
    """
    if model_data is not None:
        # Fall back to an empty list when no feature names are available.
        names = feature_names if feature_names else []
        info = {
            "success": True,
            "model_version": model_data.get('model_version', '未知'),
            "description": model_data.get('description', 'CardioAI心血管疾病预测模型'),
            "feature_count": len(names),
            "features": names
        }
        return jsonify(info), 200

    not_loaded = {
        "success": False,
        "message": "模型未加载"
    }
    return jsonify(not_loaded), 503
# Set to True once the model is known to be loaded in this process.
_model_loaded = False


@app.before_request
def ensure_model_loaded():
    """Load the model before the first request if it is not loaded yet.

    Runs before every request. Once the model is available this is a cheap
    flag check. If loading fails, the flag stays False, so the next request
    tries again. Returns None in all cases, so the request always proceeds
    to its view.
    """
    global _model_loaded
    if _model_loaded:
        return
    if pipeline is not None:
        # The model was already loaded elsewhere (e.g. by the __main__ block
        # at startup); record that instead of loading it a second time.
        _model_loaded = True
        return

    logger.info("首次请求,正在加载模型...")
    if load_model():
        _model_loaded = True
        logger.info("模型加载完成")
    else:
        logger.error("模型加载失败")
if __name__ == '__main__':
    # Load the model eagerly so a missing/broken model file aborts startup.
    if not load_model():
        logger.error("启动失败: 模型加载失败")
        sys.exit(1)

    logger.info("启动CardioAI预测API服务...")
    logger.info("访问 http://localhost:5000 使用预测界面")
    logger.info("API文档:")
    logger.info(" GET / - 前端界面")
    logger.info(" POST /predict_cardio - 预测接口")
    logger.info(" GET /health - 健康检查")
    logger.info(" GET /model_info - 模型信息")

    # SECURITY: the server listens on all interfaces (0.0.0.0). With debug
    # mode on, the interactive debugger lets anyone who can reach the port
    # execute arbitrary code, so debug is opt-in via FLASK_DEBUG instead of
    # being hard-coded to True.
    debug_mode = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(host='0.0.0.0', port=5000, debug=debug_mode)