187 lines
4.9 KiB
Python
187 lines
4.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
CardioAI - Module 2: Flask API 部署
|
|
心血管疾病预测服务后端
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import joblib
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from flask import Flask, render_template, request, jsonify
|
|
from dotenv import load_dotenv
|
|
|
|
# ============================================
|
|
# 配置与初始化
|
|
# ============================================
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
CODE_ROOT = Path(r"E:\project_ai\claude_project1\aicodes")
|
|
MODEL_PATH = CODE_ROOT / "module2_predictor" / "cardio_predictor_model.pkl"
|
|
TEMPLATE_DIR = CODE_ROOT / "module2_predictor" / "templates"
|
|
|
|
# 创建 Flask 应用
|
|
app = Flask(__name__,
|
|
template_folder=str(TEMPLATE_DIR),
|
|
static_folder=str(TEMPLATE_DIR.parent / 'static'))
|
|
|
|
# 加载模型
|
|
print("正在加载模型...")
|
|
try:
|
|
model_pipeline = joblib.load(MODEL_PATH)
|
|
print(f"模型加载成功: {MODEL_PATH}")
|
|
except Exception as e:
|
|
print(f"模型加载失败: {e}")
|
|
model_pipeline = None
|
|
|
|
|
|
# ============================================
|
|
# 辅助函数
|
|
# ============================================
|
|
def preprocess_input(data):
|
|
"""
|
|
预处理输入数据
|
|
|
|
Args:
|
|
data (dict): 包含11个特征的字典
|
|
|
|
Returns:
|
|
pd.DataFrame: 格式化后的输入数据
|
|
"""
|
|
import pandas as pd
|
|
|
|
# 构建输入DataFrame
|
|
input_data = {
|
|
'age_years': int(data['age_years']),
|
|
'gender': int(data['gender']),
|
|
'height': float(data['height']),
|
|
'weight': float(data['weight']),
|
|
'ap_hi': int(data['ap_hi']),
|
|
'ap_lo': int(data['ap_lo']),
|
|
'cholesterol': int(data['cholesterol']),
|
|
'gluc': int(data['gluc']),
|
|
'smoke': int(data['smoke']),
|
|
'alco': int(data['alco']),
|
|
'active': int(data['active'])
|
|
}
|
|
|
|
# 计算BMI
|
|
input_data['bmi'] = input_data['weight'] / ((input_data['height'] / 100) ** 2)
|
|
|
|
df = pd.DataFrame([input_data])
|
|
|
|
# 按照模型训练时的特征顺序排列
|
|
feature_order = ['age_years', 'height', 'weight', 'bmi',
|
|
'ap_hi', 'ap_lo', 'gender', 'cholesterol',
|
|
'gluc', 'smoke', 'alco', 'active']
|
|
|
|
return df[feature_order]
|
|
|
|
|
|
# ============================================
|
|
# 路由定义
|
|
# ============================================
|
|
@app.route('/')
|
|
def index():
|
|
"""主页 - 渲染预测表单"""
|
|
return render_template('index.html')
|
|
|
|
|
|
@app.route('/predict_cardio', methods=['POST'])
|
|
def predict_cardio():
|
|
"""
|
|
心血管疾病预测API接口
|
|
|
|
接收11个原始特征值的JSON POST请求
|
|
返回预测结果和概率
|
|
"""
|
|
if model_pipeline is None:
|
|
return jsonify({
|
|
'error': '模型未加载,请先运行训练脚本'
|
|
}), 500
|
|
|
|
try:
|
|
# 获取JSON数据
|
|
data = request.get_json()
|
|
|
|
# 验证必填字段
|
|
required_fields = [
|
|
'age_years', 'gender', 'height', 'weight',
|
|
'ap_hi', 'ap_lo', 'cholesterol', 'gluc',
|
|
'smoke', 'alco', 'active'
|
|
]
|
|
|
|
missing_fields = [field for field in required_fields if field not in data]
|
|
if missing_fields:
|
|
return jsonify({
|
|
'error': f'缺少必填字段: {", ".join(missing_fields)}'
|
|
}), 400
|
|
|
|
# 预处理输入数据
|
|
input_df = preprocess_input(data)
|
|
|
|
# 模型预测
|
|
prediction = model_pipeline.predict(input_df)[0]
|
|
probability = model_pipeline.predict_proba(input_df)[0, 1]
|
|
|
|
# 返回结果
|
|
return jsonify({
|
|
'prediction': int(prediction),
|
|
'probability': float(probability),
|
|
'risk_level': get_risk_level(probability),
|
|
'message': '预测成功'
|
|
})
|
|
|
|
except Exception as e:
|
|
return jsonify({
|
|
'error': f'预测失败: {str(e)}'
|
|
}), 500
|
|
|
|
|
|
def get_risk_level(probability):
|
|
"""
|
|
根据概率返回风险等级描述
|
|
|
|
Args:
|
|
probability (float): 疾病概率
|
|
|
|
Returns:
|
|
str: 风险等级描述
|
|
"""
|
|
if probability < 0.3:
|
|
return "低风险"
|
|
elif probability < 0.5:
|
|
return "中等风险"
|
|
elif probability < 0.7:
|
|
return "高风险"
|
|
else:
|
|
return "极高风险"
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
|
|
def health():
|
|
"""健康检查接口"""
|
|
return jsonify({
|
|
'status': 'healthy',
|
|
'model_loaded': model_pipeline is not None
|
|
})
|
|
|
|
|
|
# ============================================
|
|
# 主程序入口
|
|
# ============================================
|
|
if __name__ == '__main__':
|
|
# 从环境变量读取配置,使用默认值
|
|
host = os.getenv('FLASK_HOST', '127.0.0.1')
|
|
port = int(os.getenv('FLASK_PORT', 5000))
|
|
|
|
print("\n" + "=" * 50)
|
|
print("CardioAI - Module 2: 心血管疾病预测服务")
|
|
print("=" * 50)
|
|
print(f"服务地址: http://{host}:{port}")
|
|
print("=" * 50 + "\n")
|
|
|
|
app.run(host=host, port=port, debug=True)
|