- Add comprehensive README.md with setup and usage instructions - Add .env.example template (sanitized, no real API keys) - Add root-level .gitignore to exclude .env and generated files - Add all project modules (dashboard, predictor) - Add data file and requirements.txt
188 lines
5.2 KiB
Python
188 lines
5.2 KiB
Python
"""
|
|
CardioAI 模块2: Flask API服务
|
|
心血管疾病风险预测 - 后端接口
|
|
"""
|
|
|
|
import joblib
|
|
import numpy as np
|
|
import pandas as pd
|
|
from flask import Flask, request, jsonify, render_template
|
|
from pathlib import Path
|
|
|
|
# ==================== 常量定义 ====================
|
|
CODE_ROOT = Path(r"F:\My_Git_Project\CardioAI")
|
|
MODEL_PATH = CODE_ROOT / "module2_predictor" / "cardio_predictor_model.pkl"
|
|
|
|
# ==================== Flask应用 ====================
|
|
app = Flask(__name__,
|
|
template_folder='templates',
|
|
static_folder='static')
|
|
|
|
# 全局变量存储模型
|
|
model = None
|
|
|
|
|
|
def load_model():
|
|
"""加载模型"""
|
|
global model
|
|
if model is None:
|
|
print("📂 正在加载模型...")
|
|
model = joblib.load(MODEL_PATH)
|
|
print("✅ 模型加载成功!")
|
|
return model
|
|
|
|
|
|
# ==================== 路由定义 ====================
|
|
@app.route('/')
|
|
def index():
|
|
"""渲染前端页面"""
|
|
return render_template('index.html')
|
|
|
|
|
|
@app.route('/predict_cardio', methods=['POST'])
|
|
def predict_cardio():
|
|
"""
|
|
心血管疾病风险预测接口
|
|
接收11个原始特征值的JSON POST请求
|
|
返回预测概率和结果
|
|
"""
|
|
try:
|
|
# 获取JSON数据
|
|
data = request.get_json()
|
|
|
|
if not data:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': '未收到数据'
|
|
}), 400
|
|
|
|
# 定义特征列顺序(与训练时一致)
|
|
feature_names = [
|
|
'age_years', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
|
|
'cholesterol', 'gluc', 'smoke', 'alco', 'active'
|
|
]
|
|
|
|
# 从请求中提取特征值
|
|
features = []
|
|
missing_fields = []
|
|
|
|
for col in feature_names:
|
|
if col in data:
|
|
features.append(float(data[col]))
|
|
else:
|
|
missing_fields.append(col)
|
|
features.append(0.0) # 默认值
|
|
|
|
# 计算BMI: weight / (height/100)^2
|
|
weight = float(data.get('weight', 0))
|
|
height = float(data.get('height', 0))
|
|
if height > 0:
|
|
bmi = weight / ((height / 100) ** 2)
|
|
features.append(bmi)
|
|
else:
|
|
features.append(0.0)
|
|
|
|
if missing_fields:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': f'缺少必要字段: {", ".join(missing_fields)}'
|
|
}), 400
|
|
|
|
# 定义特征列名(与训练时一致)
|
|
feature_columns = [
|
|
'age_years', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
|
|
'cholesterol', 'gluc', 'smoke', 'alco', 'active', 'bmi'
|
|
]
|
|
|
|
# 转换为DataFrame格式
|
|
X_input = pd.DataFrame([features], columns=feature_columns)
|
|
|
|
# 加载模型(如果尚未加载)
|
|
predictor = load_model()
|
|
|
|
# 预测
|
|
prediction = int(predictor.predict(X_input)[0])
|
|
prob_risk = float(predictor.predict_proba(X_input)[0][1])
|
|
prob_healthy = float(predictor.predict_proba(X_input)[0][0])
|
|
|
|
# 构建响应
|
|
result = {
|
|
'success': True,
|
|
'prediction': prediction,
|
|
'prediction_label': '有风险' if prediction == 1 else '健康',
|
|
'probability': {
|
|
'健康': round(prob_healthy * 100, 2),
|
|
'有风险': round(prob_risk * 100, 2)
|
|
},
|
|
'risk_level': get_risk_level(prob_risk),
|
|
'recommendation': get_recommendation(prob_risk, data)
|
|
}
|
|
|
|
return jsonify(result)
|
|
|
|
except ValueError as e:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': f'数据格式错误: {str(e)}'
|
|
}), 400
|
|
|
|
except Exception as e:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': f'预测失败: {str(e)}'
|
|
}), 500
|
|
|
|
|
|
def get_risk_level(probability: float) -> str:
|
|
"""根据概率返回风险等级"""
|
|
if probability < 0.3:
|
|
return '🟢 低风险'
|
|
elif probability < 0.5:
|
|
return '🟡 中低风险'
|
|
elif probability < 0.7:
|
|
return '🟠 中高风险'
|
|
else:
|
|
return '🔴 高风险'
|
|
|
|
|
|
def get_recommendation(probability: float, data: dict) -> str:
|
|
"""根据预测结果给出建议"""
|
|
if probability < 0.3:
|
|
return '继续保持健康的生活方式,定期体检。'
|
|
elif probability < 0.5:
|
|
return '建议适当增加运动,注意饮食均衡。'
|
|
elif probability < 0.7:
|
|
return '建议咨询医生,制定健康管理计划。'
|
|
else:
|
|
return '⚠️ 风险较高,请尽快就医检查。'
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
|
|
def health_check():
|
|
"""健康检查接口"""
|
|
return jsonify({
|
|
'status': 'healthy',
|
|
'service': 'CardioAI Cardiovascular Prediction API',
|
|
'version': '1.0.0'
|
|
})
|
|
|
|
|
|
# ==================== 启动应用 ====================
|
|
if __name__ == '__main__':
|
|
print("\n" + "="*60)
|
|
print("❤️ CardioAI 心血管疾病风险预测 API")
|
|
print("="*60)
|
|
print(f"📂 模型路径: {MODEL_PATH}")
|
|
print(f"🌐 启动地址: http://localhost:5001")
|
|
print("="*60 + "\n")
|
|
|
|
# 预加载模型
|
|
load_model()
|
|
|
|
# 启动Flask应用
|
|
app.run(
|
|
host='0.0.0.0',
|
|
port=5001,
|
|
debug=True
|
|
)
|