Files
itcast_01/module2_predictor/app.py
Zane Xu 88021e4a4b 添加Module2心血管疾病预测模型和Flask API
Part A: 模型训练与保存
- train_and_save.py: 一次性脚本,训练XGBoost模型并保存完整Pipeline
- cardio_predictor_model.pkl: 包含预处理器和分类器的完整Pipeline

Part B: Flask API部署
- app.py: 提供/predict_cardio接口,接收11个特征值并返回预测结果
- 包含输入验证、数据处理和模型加载功能

Part C: 前端交互界面
- templates/index.html: 响应式HTML表单,集成JavaScript Fetch API
- 提供示例数据填充和实时预测结果显示

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-26 22:53:47 +08:00

336 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/opt/anaconda3/envs/cardioenv/bin/python
"""
CardioAI - 心血管疾病预测API
Flask应用程序提供心血管疾病预测API接口
"""
from flask import Flask, request, jsonify, render_template
import pandas as pd
import numpy as np
import joblib
import json
import os
# Create the Flask application object that the route decorators below attach to.
app = Flask(__name__)
# Default location of the serialized model: the directory containing this file.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "cardio_predictor_model.pkl")
def load_model():
    """Load the pre-trained model from disk.

    Two locations are tried in order: ``MODEL_PATH`` (next to this file) and
    a file of the same name one directory above it. The first file that
    exists and deserializes successfully wins.

    Returns:
        The object returned by ``joblib.load``, or ``None`` when no candidate
        file exists or every existing candidate fails to load.
    """
    fallback_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "cardio_predictor_model.pkl",
    )
    # De-duplicate so the same file is never loaded twice when both
    # locations resolve to an identical path.
    candidates = [MODEL_PATH]
    if fallback_path != MODEL_PATH:
        candidates.append(fallback_path)

    for path in candidates:
        if not os.path.exists(path):
            continue
        if path != MODEL_PATH:
            print(f"⚠️ 使用备用模型路径: {path}")
        try:
            # NOTE(security): joblib.load unpickles the file, which can run
            # arbitrary code. Only point this at trusted model files.
            loaded = joblib.load(path)
        except Exception as e:
            # Best-effort: a corrupt/incompatible file must not prevent the
            # next candidate from being tried.
            print(f"❌ 模型加载失败: {e}")
            continue
        print(f"✅ 模型加载成功: {path}")
        return loaded

    print(f"❌ 模型加载失败: 未找到可用的模型文件 ({MODEL_PATH})")
    return None
# Module-level model instance, loaded once when this module is imported.
# It is None when load_model() could not load a model.
model = load_model()
# Label mapping shared by the ordinal 1/2/3 codes of ``cholesterol`` and ``gluc``.
_LEVEL_LABELS = {
    1: 'normal',
    2: 'above_normal',
    3: 'well_above_normal',
}


def _categorize_bmi(bmi):
    """Return the category label for a single BMI value."""
    if bmi < 18.5:
        return 'underweight'
    if bmi < 25:
        return 'normal'
    if bmi < 30:
        return 'overweight'
    return 'obese'


def preprocess_input(data):
    """Turn one raw input record into a single-row feature DataFrame.

    The raw fields are coerced to numeric types and the following derived
    columns are added when their source columns are present:

    * ``age_years``  - ``age`` divided by 365.25, rounded, as int
    * ``bmi``        - ``weight / (height / 100) ** 2``
    * ``cholesterol_cat`` / ``gluc_cat`` - text label for the 1/2/3 code
    * ``bmi_category`` - underweight / normal / overweight / obese

    Args:
        data: Mapping of field name to raw value. It is NOT modified; type
            coercion happens on a copy.

    Returns:
        A one-row ``pandas.DataFrame`` holding the raw and derived columns,
        or ``None`` if any step fails (the error is printed).
    """
    try:
        # Copy first so the caller's dict keeps its original values/types.
        record = dict(data)

        # Continuous measurements become floats.
        for field in ('age', 'height', 'weight', 'ap_hi', 'ap_lo'):
            if field in record:
                record[field] = float(record[field])
        # Coded / binary fields become ints.
        for field in ('gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active'):
            if field in record:
                record[field] = int(record[field])

        df = pd.DataFrame([record])

        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)
        if 'height' in df.columns and 'weight' in df.columns:
            # height is divided by 100 before squaring (cm -> m).
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)
        if 'cholesterol' in df.columns:
            df['cholesterol_cat'] = df['cholesterol'].map(_LEVEL_LABELS)
        if 'gluc' in df.columns:
            df['gluc_cat'] = df['gluc'].map(_LEVEL_LABELS)
        if 'bmi' in df.columns:
            df['bmi_category'] = df['bmi'].apply(_categorize_bmi)
        return df
    except Exception as e:
        # Callers rely on None (not an exception) to signal failure.
        print(f"数据预处理失败: {e}")
        return None
def validate_input(data):
    """Check that a raw input record is complete and within allowed ranges.

    Args:
        data: Mapping with the 11 raw fields (``age``, ``gender``,
            ``height``, ``weight``, ``ap_hi``, ``ap_lo``, ``cholesterol``,
            ``gluc``, ``smoke``, ``alco``, ``active``).

    Returns:
        ``(True, message)`` when the record is acceptable, otherwise
        ``(False, reason)``. This function never raises for bad input:
        non-mapping payloads, ``None`` values, non-numeric strings, NaN and
        infinities are all reported as invalid.
    """
    from collections.abc import Mapping

    # A JSON body may decode to null, a list, a number... - only an object
    # can carry named fields.
    if not isinstance(data, Mapping):
        return False, "输入数据应为JSON对象"

    required_fields = [
        'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
        'cholesterol', 'gluc', 'smoke', 'alco', 'active'
    ]
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"缺少必需字段: {', '.join(missing_fields)}"

    # Range checks are written as ``not (lo <= x <= hi)`` on purpose: any
    # comparison with NaN is False, so NaN is rejected instead of slipping
    # through a pair of ``x < lo or x > hi`` tests.
    try:
        # Age, with an upper bound of 120 * 365.25
        age = float(data['age'])
        if not (0 < age <= 365.25 * 120):
            return False, "年龄应在合理范围内(0-43830天)"

        gender = int(data['gender'])
        if gender not in [1, 2]:
            return False, "性别应为1(女性)或2(男性)"

        # Height in cm
        height = float(data['height'])
        if not (100 <= height <= 250):
            return False, "身高应在100-250cm之间"

        # Weight in kg
        weight = float(data['weight'])
        if not (30 <= weight <= 300):
            return False, "体重应在30-300kg之间"

        # Blood pressure: relative order first, then individual ranges
        ap_hi = float(data['ap_hi'])
        ap_lo = float(data['ap_lo'])
        if ap_lo >= ap_hi:
            return False, "舒张压应小于收缩压"
        if not (90 <= ap_hi <= 250):
            return False, "收缩压应在90-250mmHg之间"
        if not (60 <= ap_lo <= 150):
            return False, "舒张压应在60-150mmHg之间"

        cholesterol = int(data['cholesterol'])
        if cholesterol not in [1, 2, 3]:
            return False, "胆固醇应为1,2,3之一"

        gluc = int(data['gluc'])
        if gluc not in [1, 2, 3]:
            return False, "血糖应为1,2,3之一"

        smoke = int(data['smoke'])
        alco = int(data['alco'])
        active = int(data['active'])
        if smoke not in [0, 1] or alco not in [0, 1] or active not in [0, 1]:
            return False, "吸烟、饮酒、活动应为0或1"
    except (TypeError, ValueError, OverflowError) as e:
        # TypeError: None / list / dict values; ValueError: non-numeric text
        # or int(NaN); OverflowError: int(inf).
        return False, f"数据类型错误: {str(e)}"

    return True, "输入数据有效"
@app.route('/')
def index():
    """Serve the HTML front-end page rendered from ``index.html``."""
    page = render_template('index.html')
    return page
@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """Prediction endpoint: accept a JSON POST with the 11 raw feature values.

    Returns:
        On success, a JSON object with ``prediction`` (int), ``probability``
        (float, taken from column 1 of ``predict_proba``), ``risk_level``,
        ``message`` and ``confidence``.
        On failure, a JSON object ``{'error', 'prediction': None,
        'probability': None}`` with status 400 (bad request content: wrong
        content type, malformed/non-object JSON, validation or preprocessing
        failure) or 500 (model not loaded, prediction failure, unexpected
        error).
    """

    def failure(message, status):
        # Every error response shares the same three-key shape.
        body = {'error': message, 'prediction': None, 'probability': None}
        return jsonify(body), status

    try:
        if model is None:
            return failure('模型未加载', 500)

        if not request.is_json:
            return failure('请求内容类型应为application/json', 400)

        # silent=True: a malformed body yields None instead of raising
        # BadRequest, which the generic handler below would otherwise turn
        # into a misleading 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return failure('请求体应为有效的JSON对象', 400)

        is_valid, message = validate_input(data)
        if not is_valid:
            return failure(message, 400)

        processed_data = preprocess_input(data)
        if processed_data is None:
            return failure('数据预处理失败', 400)

        # Drop columns the model should not see; per the original comment
        # this mirrors the training script - TODO confirm against it.
        processed_data = processed_data.drop(
            columns=['id', 'age', 'cardio'], errors='ignore'
        )

        try:
            prediction = model.predict(processed_data)[0]
            probability = model.predict_proba(processed_data)[0][1]
        except Exception as e:
            print(f"预测失败: {e}")
            return failure(f'预测失败: {str(e)}', 500)

        result = {
            'prediction': int(prediction),
            'probability': float(probability),
            'risk_level': '高风险' if probability >= 0.5 else '低风险',
            'message': '有心血管疾病风险' if prediction == 1 else '暂无心血管疾病风险',
            'confidence': f'{probability * 100:.1f}%'
        }
        print(f"预测结果: {result}")
        return jsonify(result)
    except Exception as e:
        # Last-resort boundary so the client always receives JSON.
        print(f"API处理异常: {e}")
        return failure(f'服务器内部错误: {str(e)}', 500)
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint reporting whether a model is loaded."""
    model_ready = model is not None
    if model_ready:
        status_label = 'healthy'
    else:
        status_label = 'unhealthy'

    payload = {}
    payload['status'] = status_label
    payload['model_loaded'] = model_ready
    payload['api_version'] = '1.0.0'
    payload['service'] = 'CardioAI Cardiovascular Disease Predictor'
    return jsonify(payload)
@app.route('/features', methods=['GET'])
def get_features_info():
    """Describe the features the prediction endpoint works with.

    Returns:
        JSON with two lists: ``required_features`` (the 11 raw inputs the
        client must send) and ``derived_features`` (every column that
        ``preprocess_input`` adds before the data reaches the model).
    """
    features_info = {
        'required_features': [
            {'name': 'age', 'type': 'numeric', 'unit': 'days', 'description': '年龄(天)'},
            {'name': 'gender', 'type': 'categorical', 'values': [1, 2], 'description': '性别:1=女性,2=男性'},
            {'name': 'height', 'type': 'numeric', 'unit': 'cm', 'description': '身高(厘米)'},
            {'name': 'weight', 'type': 'numeric', 'unit': 'kg', 'description': '体重(千克)'},
            {'name': 'ap_hi', 'type': 'numeric', 'unit': 'mmHg', 'description': '收缩压'},
            {'name': 'ap_lo', 'type': 'numeric', 'unit': 'mmHg', 'description': '舒张压'},
            {'name': 'cholesterol', 'type': 'categorical', 'values': [1, 2, 3], 'description': '胆固醇水平:1=正常,2=高于正常,3=很高'},
            {'name': 'gluc', 'type': 'categorical', 'values': [1, 2, 3], 'description': '血糖水平:1=正常,2=高于正常,3=很高'},
            {'name': 'smoke', 'type': 'categorical', 'values': [0, 1], 'description': '是否吸烟:0=否,1=是'},
            {'name': 'alco', 'type': 'categorical', 'values': [0, 1], 'description': '是否饮酒:0=否,1=是'},
            {'name': 'active', 'type': 'categorical', 'values': [0, 1], 'description': '是否积极运动:0=否,1=是'}
        ],
        'derived_features': [
            {'name': 'age_years', 'description': '年龄(年),由age转换而来'},
            {'name': 'bmi', 'description': '身体质量指数,由height和weight计算而来'},
            # cholesterol_cat / gluc_cat are also produced by
            # preprocess_input, so they are listed here to keep this
            # endpoint in sync with what is actually fed to the model.
            {'name': 'cholesterol_cat', 'description': '胆固醇分类:normal/above_normal/well_above_normal,由cholesterol映射而来'},
            {'name': 'gluc_cat', 'description': '血糖分类:normal/above_normal/well_above_normal,由gluc映射而来'},
            {'name': 'bmi_category', 'description': 'BMI分类:偏瘦/正常/超重/肥胖'}
        ]
    }
    return jsonify(features_info)
if __name__ == '__main__':
    # Single source of truth for the server settings, so the banner printed
    # below cannot drift from what app.run() actually uses (the banner used
    # to announce port 5000 while the server listened on 9011).
    host = '0.0.0.0'
    port = 9011
    # NOTE(security): debug=True turns on the interactive debugger, and
    # host 0.0.0.0 exposes it on every network interface. Keep this for
    # local development only; disable debug for any shared deployment.
    debug = True

    print("=" * 60)
    print("CardioAI - 心血管疾病预测API")
    print("=" * 60)
    print(f"模型路径: {MODEL_PATH}")
    print(f"模型加载状态: {'成功' if model is not None else '失败'}")
    print("\n可用端点:")
    print(" GET / - 前端界面")
    print(" POST /predict_cardio - 预测API")
    print(" GET /health - 健康检查")
    print(" GET /features - 特征信息")
    print("\n启动信息:")
    print(f" 主机: {host}")
    print(f" 端口: {port}")
    print(f" 调试模式: {'开启' if debug else '关闭'}")
    print("=" * 60)
    # Start the Flask development server.
    app.run(host=host, port=port, debug=debug)