Files
111/aicodes/module2_predictor/app.py
2026-01-30 20:40:57 +08:00

187 lines
4.9 KiB
Python

# -*- coding: utf-8 -*-
"""
CardioAI - Module 2: Flask API 部署
心血管疾病预测服务后端
"""
import os
import json
import joblib
import numpy as np
from pathlib import Path
from flask import Flask, render_template, request, jsonify
from dotenv import load_dotenv
# ============================================
# 配置与初始化
# ============================================
# 加载环境变量
load_dotenv()
CODE_ROOT = Path(r"E:\project_ai\claude_project1\aicodes")
MODEL_PATH = CODE_ROOT / "module2_predictor" / "cardio_predictor_model.pkl"
TEMPLATE_DIR = CODE_ROOT / "module2_predictor" / "templates"
# 创建 Flask 应用
app = Flask(__name__,
template_folder=str(TEMPLATE_DIR),
static_folder=str(TEMPLATE_DIR.parent / 'static'))
# 加载模型
print("正在加载模型...")
try:
model_pipeline = joblib.load(MODEL_PATH)
print(f"模型加载成功: {MODEL_PATH}")
except Exception as e:
print(f"模型加载失败: {e}")
model_pipeline = None
# ============================================
# 辅助函数
# ============================================
def preprocess_input(data):
"""
预处理输入数据
Args:
data (dict): 包含11个特征的字典
Returns:
pd.DataFrame: 格式化后的输入数据
"""
import pandas as pd
# 构建输入DataFrame
input_data = {
'age_years': int(data['age_years']),
'gender': int(data['gender']),
'height': float(data['height']),
'weight': float(data['weight']),
'ap_hi': int(data['ap_hi']),
'ap_lo': int(data['ap_lo']),
'cholesterol': int(data['cholesterol']),
'gluc': int(data['gluc']),
'smoke': int(data['smoke']),
'alco': int(data['alco']),
'active': int(data['active'])
}
# 计算BMI
input_data['bmi'] = input_data['weight'] / ((input_data['height'] / 100) ** 2)
df = pd.DataFrame([input_data])
# 按照模型训练时的特征顺序排列
feature_order = ['age_years', 'height', 'weight', 'bmi',
'ap_hi', 'ap_lo', 'gender', 'cholesterol',
'gluc', 'smoke', 'alco', 'active']
return df[feature_order]
# ============================================
# 路由定义
# ============================================
@app.route('/')
def index():
"""主页 - 渲染预测表单"""
return render_template('index.html')
@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
"""
心血管疾病预测API接口
接收11个原始特征值的JSON POST请求
返回预测结果和概率
"""
if model_pipeline is None:
return jsonify({
'error': '模型未加载,请先运行训练脚本'
}), 500
try:
# 获取JSON数据
data = request.get_json()
# 验证必填字段
required_fields = [
'age_years', 'gender', 'height', 'weight',
'ap_hi', 'ap_lo', 'cholesterol', 'gluc',
'smoke', 'alco', 'active'
]
missing_fields = [field for field in required_fields if field not in data]
if missing_fields:
return jsonify({
'error': f'缺少必填字段: {", ".join(missing_fields)}'
}), 400
# 预处理输入数据
input_df = preprocess_input(data)
# 模型预测
prediction = model_pipeline.predict(input_df)[0]
probability = model_pipeline.predict_proba(input_df)[0, 1]
# 返回结果
return jsonify({
'prediction': int(prediction),
'probability': float(probability),
'risk_level': get_risk_level(probability),
'message': '预测成功'
})
except Exception as e:
return jsonify({
'error': f'预测失败: {str(e)}'
}), 500
def get_risk_level(probability):
"""
根据概率返回风险等级描述
Args:
probability (float): 疾病概率
Returns:
str: 风险等级描述
"""
if probability < 0.3:
return "低风险"
elif probability < 0.5:
return "中等风险"
elif probability < 0.7:
return "高风险"
else:
return "极高风险"
@app.route('/health', methods=['GET'])
def health():
"""健康检查接口"""
return jsonify({
'status': 'healthy',
'model_loaded': model_pipeline is not None
})
# ============================================
# 主程序入口
# ============================================
if __name__ == '__main__':
# 从环境变量读取配置,使用默认值
host = os.getenv('FLASK_HOST', '127.0.0.1')
port = int(os.getenv('FLASK_PORT', 5000))
print("\n" + "=" * 50)
print("CardioAI - Module 2: 心血管疾病预测服务")
print("=" * 50)
print(f"服务地址: http://{host}:{port}")
print("=" * 50 + "\n")
app.run(host=host, port=port, debug=True)