Part A: 模型训练与保存 - train_and_save.py: 一次性脚本,训练XGBoost模型并保存完整Pipeline - cardio_predictor_model.pkl: 包含预处理器和分类器的完整Pipeline Part B: Flask API部署 - app.py: 提供/predict_cardio接口,接收11个特征值并返回预测结果 - 包含输入验证、数据处理和模型加载功能 Part C: 前端交互界面 - templates/index.html: 响应式HTML表单,集成JavaScript Fetch API - 提供示例数据填充和实时预测结果显示 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
336 lines
11 KiB
Python
336 lines
11 KiB
Python
#!/opt/anaconda3/envs/cardioenv/bin/python
|
||
"""
|
||
CardioAI - 心血管疾病预测API
|
||
Flask应用程序,提供心血管疾病预测API接口
|
||
"""
|
||
|
||
from flask import Flask, request, jsonify, render_template
|
||
import pandas as pd
|
||
import numpy as np
|
||
import joblib
|
||
import json
|
||
import os
|
||
|
||
# Create the Flask application.
app = Flask(__name__)

# Default location of the serialized pipeline: the directory containing this
# file. abspath() matters when the script is launched as ``python app.py``:
# ``__file__`` is then a bare relative name and dirname() would return "".
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "cardio_predictor_model.pkl")


def load_model(model_path=None):
    """Load the pre-trained prediction pipeline from disk.

    Candidate locations are tried in order: ``model_path`` (defaults to
    ``MODEL_PATH``), then ``cardio_predictor_model.pkl`` in the parent
    directory of this file (the project root). Candidates that do not exist
    are skipped; one that exists but fails to deserialize is reported and the
    next candidate is tried.

    Security: ``joblib.load`` unpickles arbitrary objects, so only point this
    at model files produced by a trusted training run.

    Args:
        model_path: Optional override for the primary model location.

    Returns:
        The deserialized model object, or ``None`` when no candidate could be
        loaded (request handlers check ``model is None``).
    """
    primary = MODEL_PATH if model_path is None else model_path
    fallback = os.path.join(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        "cardio_predictor_model.pkl",
    )

    candidates = [primary]
    # Avoid trying the same file twice when both locations resolve identically.
    if os.path.abspath(fallback) != os.path.abspath(primary):
        candidates.append(fallback)

    for index, path in enumerate(candidates):
        if not os.path.exists(path):
            continue
        if index > 0:
            print(f"⚠️ 使用备用模型路径: {path}")
        try:
            loaded = joblib.load(path)
        except Exception as e:
            # Best-effort startup: report and fall through to the next candidate
            # instead of crashing the whole service at import time.
            print(f"❌ 模型加载失败: {e}")
            continue
        print(f"✅ 模型加载成功: {path}")
        return loaded

    print(f"❌ 模型加载失败: 未找到可用的模型文件 ({', '.join(candidates)})")
    return None


# Load the model once at import time; the request handlers below only read it.
model = load_model()
|
||
|
||
def preprocess_input(data):
    """Turn one raw request payload into a single-row feature DataFrame.

    Applies the same feature engineering as the training script:

    * numeric fields (age, height, weight, ap_hi, ap_lo) -> float;
    * categorical fields (gender, cholesterol, gluc, smoke, alco, active) -> int;
    * ``age_years``: age in days / 365.25, rounded;
    * ``bmi``: weight / (height / 100) ** 2;
    * ``cholesterol_cat`` / ``gluc_cat``: level 1/2/3 mapped to
      'normal' / 'above_normal' / 'well_above_normal';
    * ``bmi_category``: underweight / normal / overweight / obese.

    Each derived column is only added when its source columns are present.
    The caller's ``data`` is NOT modified; conversions happen on a copy.

    Args:
        data: Mapping of raw feature names to values (numbers or numeric strings).

    Returns:
        A one-row ``pandas.DataFrame`` with the original and derived columns,
        or ``None`` if any conversion fails (the error is printed).
    """
    # Shared 1/2/3 level encoding used for both cholesterol and glucose.
    level_map = {
        1: 'normal',
        2: 'above_normal',
        3: 'well_above_normal',
    }

    def categorize_bmi(bmi):
        # Thresholds are checked in ascending order, so each branch only needs
        # its upper bound. NaN fails every comparison and lands in 'obese'.
        if bmi < 18.5:
            return 'underweight'
        if bmi < 25:
            return 'normal'
        if bmi < 30:
            return 'overweight'
        return 'obese'

    try:
        # Work on a shallow copy so the caller's payload is left untouched.
        record = dict(data)

        for field in ('age', 'height', 'weight', 'ap_hi', 'ap_lo'):
            if field in record:
                record[field] = float(record[field])

        for field in ('gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active'):
            if field in record:
                record[field] = int(record[field])

        df = pd.DataFrame([record])

        if 'age' in df.columns:
            # NOTE: Series.round() rounds exact halves to even (numpy semantics),
            # not half-up; kept as-is to match the training pipeline.
            df['age_years'] = (df['age'] / 365.25).round().astype(int)

        if 'height' in df.columns and 'weight' in df.columns:
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)

        if 'cholesterol' in df.columns:
            df['cholesterol_cat'] = df['cholesterol'].map(level_map)

        if 'gluc' in df.columns:
            df['gluc_cat'] = df['gluc'].map(level_map)

        if 'bmi' in df.columns:
            df['bmi_category'] = df['bmi'].apply(categorize_bmi)

        return df

    except Exception as e:
        print(f"数据预处理失败: {e}")
        return None
|
||
|
||
def validate_input(data):
    """Check that a prediction payload is complete and within plausible ranges.

    Args:
        data: The decoded JSON request body; expected to be a dict keyed by
            the 11 raw feature names.

    Returns:
        ``(True, "输入数据有效")`` when every check passes, otherwise
        ``(False, message)`` with a user-facing error message. Bad values
        (non-numeric, ``None``, NaN, infinity, non-dict body) are reported
        through the return value rather than raised.
    """
    # A JSON body can be null, a list or a scalar; membership tests on those
    # either crash or (for strings) silently do substring matching.
    if not isinstance(data, dict):
        return False, "请求体应为JSON对象"

    required_fields = [
        'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
        'cholesterol', 'gluc', 'smoke', 'alco', 'active',
    ]

    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"缺少必需字段: {', '.join(missing_fields)}"

    # Range checks are written as ``not (lo <= x <= hi)`` so that NaN, which
    # compares False with everything, is rejected instead of slipping through.
    try:
        # Age in days; assume a maximum of 120 years.
        age = float(data['age'])
        if not (0 < age <= 365.25 * 120):
            return False, "年龄应在合理范围内(0-43830天)"

        gender = int(data['gender'])
        if gender not in (1, 2):
            return False, "性别应为1(女性)或2(男性)"

        # Height in cm.
        height = float(data['height'])
        if not (100 <= height <= 250):
            return False, "身高应在100-250cm之间"

        # Weight in kg.
        weight = float(data['weight'])
        if not (30 <= weight <= 300):
            return False, "体重应在30-300kg之间"

        # Blood pressure (mmHg): diastolic must be below systolic.
        ap_hi = float(data['ap_hi'])
        ap_lo = float(data['ap_lo'])
        if ap_lo >= ap_hi:
            return False, "舒张压应小于收缩压"
        if not (90 <= ap_hi <= 250):
            return False, "收缩压应在90-250mmHg之间"
        if not (60 <= ap_lo <= 150):
            return False, "舒张压应在60-150mmHg之间"

        cholesterol = int(data['cholesterol'])
        if cholesterol not in (1, 2, 3):
            return False, "胆固醇应为1,2,3之一"

        gluc = int(data['gluc'])
        if gluc not in (1, 2, 3):
            return False, "血糖应为1,2,3之一"

        # Binary lifestyle flags: smoking, alcohol, physical activity.
        smoke = int(data['smoke'])
        alco = int(data['alco'])
        active = int(data['active'])
        if smoke not in (0, 1) or alco not in (0, 1) or active not in (0, 1):
            return False, "吸烟、饮酒、活动应为0或1"

    # TypeError: value is None / a list / a dict; ValueError: unparsable text
    # or int(NaN); OverflowError: int(float('inf')).
    except (TypeError, ValueError, OverflowError) as e:
        return False, f"数据类型错误: {str(e)}"

    return True, "输入数据有效"
|
||
|
||
@app.route('/')
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page_template = 'index.html'
    return render_template(page_template)
|
||
|
||
def _error_response(message, status):
    """Build the JSON error body (and HTTP status) used by the prediction endpoint."""
    return jsonify({
        'error': message,
        'prediction': None,
        'probability': None,
    }), status


@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """Cardiovascular-disease prediction endpoint.

    Accepts a JSON POST whose body is an object with the 11 raw features
    (see ``validate_input``).

    Returns:
        On success, HTTP 200 with ``prediction`` (0/1), ``probability``
        (model probability of class 1), ``risk_level``, ``message`` and
        ``confidence`` (that probability formatted as a percentage).
        On failure, ``{'error': ..., 'prediction': None, 'probability': None}``
        with 400 for client errors or 500 for model/server errors.
    """
    try:
        if model is None:
            return _error_response('模型未加载', 500)

        if not request.is_json:
            return _error_response('请求内容类型应为application/json', 400)

        # silent=True: malformed JSON yields None instead of raising, so the
        # client gets a 400 rather than falling into the generic 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _error_response('请求体应为JSON对象', 400)

        is_valid, message = validate_input(data)
        if not is_valid:
            return _error_response(message, 400)

        processed_data = preprocess_input(data)
        if processed_data is None:
            return _error_response('数据预处理失败', 400)

        # Drop columns that were removed before training (must stay in sync
        # with the training script); absent columns are ignored.
        processed_data = processed_data.drop(columns=['id', 'age', 'cardio'], errors='ignore')

        try:
            prediction = model.predict(processed_data)[0]
            probability = model.predict_proba(processed_data)[0][1]
        except Exception as e:
            print(f"预测失败: {e}")
            # NOTE(review): exception text is returned to the client; consider
            # hiding internals outside development.
            return _error_response(f'预测失败: {str(e)}', 500)

        # NOTE(review): 'confidence' is the class-1 probability even when the
        # prediction is 0 — confirm this is what the front end expects.
        result = {
            'prediction': int(prediction),
            'probability': float(probability),
            'risk_level': '高风险' if probability >= 0.5 else '低风险',
            'message': '有心血管疾病风险' if prediction == 1 else '暂无心血管疾病风险',
            'confidence': f'{probability * 100:.1f}%',
        }

        print(f"预测结果: {result}")
        return jsonify(result)

    except Exception as e:
        print(f"API处理异常: {e}")
        return _error_response(f'服务器内部错误: {str(e)}', 500)
|
||
|
||
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: report model load status plus static service info.

    Always answers with HTTP 200; the ``status`` field carries the verdict.
    """
    loaded = model is not None
    return jsonify({
        'status': 'healthy' if loaded else 'unhealthy',
        'model_loaded': loaded,
        'api_version': '1.0.0',
        'service': 'CardioAI Cardiovascular Disease Predictor',
    })
|
||
|
||
@app.route('/features', methods=['GET'])
def get_features_info():
    """Describe the raw input features and the features derived from them.

    ``required_features`` lists the 11 fields the /predict_cardio payload must
    contain. ``derived_features`` lists every column ``preprocess_input`` adds
    (age_years, bmi, cholesterol_cat, gluc_cat, bmi_category).
    """
    features_info = {
        'required_features': [
            {'name': 'age', 'type': 'numeric', 'unit': 'days', 'description': '年龄(天)'},
            {'name': 'gender', 'type': 'categorical', 'values': [1, 2], 'description': '性别:1=女性,2=男性'},
            {'name': 'height', 'type': 'numeric', 'unit': 'cm', 'description': '身高(厘米)'},
            {'name': 'weight', 'type': 'numeric', 'unit': 'kg', 'description': '体重(千克)'},
            {'name': 'ap_hi', 'type': 'numeric', 'unit': 'mmHg', 'description': '收缩压'},
            {'name': 'ap_lo', 'type': 'numeric', 'unit': 'mmHg', 'description': '舒张压'},
            {'name': 'cholesterol', 'type': 'categorical', 'values': [1, 2, 3], 'description': '胆固醇水平:1=正常,2=高于正常,3=很高'},
            {'name': 'gluc', 'type': 'categorical', 'values': [1, 2, 3], 'description': '血糖水平:1=正常,2=高于正常,3=很高'},
            {'name': 'smoke', 'type': 'categorical', 'values': [0, 1], 'description': '是否吸烟:0=否,1=是'},
            {'name': 'alco', 'type': 'categorical', 'values': [0, 1], 'description': '是否饮酒:0=否,1=是'},
            {'name': 'active', 'type': 'categorical', 'values': [0, 1], 'description': '是否积极运动:0=否,1=是'},
        ],
        # Kept in the same order preprocess_input creates these columns.
        'derived_features': [
            {'name': 'age_years', 'description': '年龄(年),由age转换而来'},
            {'name': 'bmi', 'description': '身体质量指数,由height和weight计算而来'},
            {'name': 'cholesterol_cat', 'description': '胆固醇类别:normal/above_normal/well_above_normal,由cholesterol映射而来'},
            {'name': 'gluc_cat', 'description': '血糖类别:normal/above_normal/well_above_normal,由gluc映射而来'},
            {'name': 'bmi_category', 'description': 'BMI分类:偏瘦/正常/超重/肥胖'},
        ],
    }

    return jsonify(features_info)
|
||
|
||
if __name__ == '__main__':
    # Server settings. The startup banner and app.run() read the same values,
    # so the advertised host/port/debug state always matches the real server.
    # Defaults: 0.0.0.0:9011 with debug on; override via environment variables.
    host = os.environ.get('CARDIO_HOST', '0.0.0.0')
    port = int(os.environ.get('CARDIO_PORT', '9011'))
    # WARNING: the Werkzeug interactive debugger permits arbitrary code
    # execution; bound to 0.0.0.0 it is reachable from the network. Set
    # CARDIO_DEBUG=0 anywhere other than a trusted local machine.
    debug = os.environ.get('CARDIO_DEBUG', '1').strip().lower() not in ('0', 'false', 'no', 'off')

    print("=" * 60)
    print("CardioAI - 心血管疾病预测API")
    print("=" * 60)
    print(f"模型路径: {MODEL_PATH}")
    print(f"模型加载状态: {'成功' if model is not None else '失败'}")
    print("\n可用端点:")
    print(" GET / - 前端界面")
    print(" POST /predict_cardio - 预测API")
    print(" GET /health - 健康检查")
    print(" GET /features - 特征信息")
    print("\n启动信息:")
    print(f" 主机: {host}")
    print(f" 端口: {port}")
    print(f" 调试模式: {'开启' if debug else '关闭'}")
    print("=" * 60)

    # Start the Flask development server.
    app.run(host=host, port=port, debug=debug)