336 lines
11 KiB
Python
336 lines
11 KiB
Python
|
|
#!/opt/anaconda3/envs/cardioenv/bin/python
|
|||
|
|
"""
|
|||
|
|
CardioAI - 心血管疾病预测API
|
|||
|
|
Flask应用程序,提供心血管疾病预测API接口
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from flask import Flask, request, jsonify, render_template
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
import joblib
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# Flask application object that the route decorators in this module attach to.
app = Flask(__name__)

# Primary location of the serialized model: next to this source file.
MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    "cardio_predictor_model.pkl",
)
|
|||
|
|
|
|||
|
|
def load_model():
    """Load the pre-trained prediction model from disk.

    Candidate locations are tried in order: ``MODEL_PATH`` (next to this
    module), then the same file name one directory above it. A candidate that
    does not exist is skipped. A candidate that exists but cannot be
    deserialized is reported and the next candidate is tried, so the same
    file is never loaded twice.

    Returns:
        The object produced by ``joblib.load`` for the first candidate that
        loads successfully, or ``None`` when no candidate could be loaded.
    """
    fallback_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "cardio_predictor_model.pkl",
    )

    # Both paths collapse to the same value when __file__ has no directory
    # component; keep only distinct candidates.
    candidates = [MODEL_PATH]
    if fallback_path != MODEL_PATH:
        candidates.append(fallback_path)

    found_any = False
    for position, candidate in enumerate(candidates):
        if not os.path.exists(candidate):
            continue
        found_any = True
        if position > 0:
            print(f"⚠️ 使用备用模型路径: {candidate}")
        try:
            # NOTE: joblib.load is pickle-based; only trusted model files
            # should ever be placed at these locations.
            loaded = joblib.load(candidate)
        except Exception as exc:
            # Start-up boundary: report the failure and try the next location
            # instead of crashing the import of this module.
            print(f"❌ 模型加载失败: {exc}")
            continue
        print(f"✅ 模型加载成功: {candidate}")
        return loaded

    if not found_any:
        print(f"❌ 模型加载失败: 未找到模型文件 {MODEL_PATH}")
    return None
|
|||
|
|
|
|||
|
|
# Module-wide model instance, loaded once at import time.
# It is None when load_model() could not load a model file.
model = load_model()
|
|||
|
|
|
|||
|
|
def preprocess_input(data):
    """Convert one raw input record into a single-row feature DataFrame.

    The record is copied before any type coercion, so the caller's mapping is
    never modified. Fields that are present are coerced (measurements to
    ``float``, coded categories to ``int``) and the derived columns are then
    appended in this order: ``age_years``, ``bmi``, ``cholesterol_cat``,
    ``gluc_cat``, ``bmi_category``.

    Args:
        data: Mapping of raw feature names to values. Values may be numbers
            or numeric strings.

    Returns:
        A one-row ``pandas.DataFrame`` holding the original fields followed
        by the derived columns, or ``None`` if any step raised (the error is
        printed).
    """
    measurement_fields = ('age', 'height', 'weight', 'ap_hi', 'ap_lo')
    coded_fields = ('gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active')
    # cholesterol and gluc share the same 1/2/3 coding.
    level_labels = {
        1: 'normal',
        2: 'above_normal',
        3: 'well_above_normal',
    }

    def categorize_bmi(bmi):
        """Map a BMI value to its weight-class label."""
        if bmi < 18.5:
            return 'underweight'
        if bmi < 25:
            return 'normal'
        if bmi < 30:
            return 'overweight'
        return 'obese'

    try:
        # Copy first: coercing in place would silently rewrite the caller's dict.
        record = dict(data)

        for field in measurement_fields:
            if field in record:
                record[field] = float(record[field])

        for field in coded_fields:
            if field in record:
                record[field] = int(record[field])

        df = pd.DataFrame([record])

        # age arrives in days; the derived column is whole years, rounded.
        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)

        # BMI = weight(kg) / height(m)^2, with height supplied in cm.
        if 'height' in df.columns and 'weight' in df.columns:
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)

        if 'cholesterol' in df.columns:
            df['cholesterol_cat'] = df['cholesterol'].map(level_labels)

        if 'gluc' in df.columns:
            df['gluc_cat'] = df['gluc'].map(level_labels)

        if 'bmi' in df.columns:
            df['bmi_category'] = df['bmi'].apply(categorize_bmi)

        return df

    except Exception as e:
        print(f"数据预处理失败: {e}")
        return None
|
|||
|
|
|
|||
|
|
def validate_input(data):
    """Check that a raw input record is complete and within accepted ranges.

    Checks run in a fixed order and stop at the first failure: payload type,
    required fields, age, gender, height, weight, blood pressure,
    cholesterol, glucose, then the smoke/alco/active flags.

    Args:
        data: The decoded request payload; expected to be a dict of raw
            feature values (numbers or numeric strings).

    Returns:
        A ``(is_valid, message)`` tuple. ``is_valid`` is ``True`` only when
        every check passes; ``message`` describes the first failed check or
        confirms validity.
    """
    required_fields = [
        'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
        'cholesterol', 'gluc', 'smoke', 'alco', 'active'
    ]

    # A JSON body can decode to null, a list, a number, ...; only an object
    # can carry named fields.
    if not isinstance(data, dict):
        return False, "请求数据应为JSON对象"

    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"缺少必需字段: {', '.join(missing_fields)}"

    try:
        # Range checks are written as negated chained comparisons so that NaN,
        # which fails every comparison, is rejected instead of slipping through.

        # Age is expressed in days; upper bound is 120 years.
        age = float(data['age'])
        if not (0 < age <= 365.25 * 120):
            return False, "年龄应在合理范围内(0-43830天)"

        gender = int(data['gender'])
        if gender not in (1, 2):
            return False, "性别应为1(女性)或2(男性)"

        # Height in cm.
        height = float(data['height'])
        if not (100 <= height <= 250):
            return False, "身高应在100-250cm之间"

        # Weight in kg.
        weight = float(data['weight'])
        if not (30 <= weight <= 300):
            return False, "体重应在30-300kg之间"

        # Blood pressure: the ordering check comes first, then each range.
        ap_hi = float(data['ap_hi'])
        ap_lo = float(data['ap_lo'])
        if ap_lo >= ap_hi:
            return False, "舒张压应小于收缩压"
        if not (90 <= ap_hi <= 250):
            return False, "收缩压应在90-250mmHg之间"
        if not (60 <= ap_lo <= 150):
            return False, "舒张压应在60-150mmHg之间"

        cholesterol = int(data['cholesterol'])
        if cholesterol not in (1, 2, 3):
            return False, "胆固醇应为1,2,3之一"

        gluc = int(data['gluc'])
        if gluc not in (1, 2, 3):
            return False, "血糖应为1,2,3之一"

        smoke = int(data['smoke'])
        alco = int(data['alco'])
        active = int(data['active'])
        if smoke not in (0, 1) or alco not in (0, 1) or active not in (0, 1):
            return False, "吸烟、饮酒、活动应为0或1"

    except (TypeError, ValueError, OverflowError) as e:
        # TypeError: None / list / dict values; ValueError: non-numeric text
        # or int(NaN); OverflowError: int(infinity).
        return False, f"数据类型错误: {str(e)}"

    return True, "输入数据有效"
|
|||
|
|
|
|||
|
|
@app.route('/')
def index():
    """Serve the home page by rendering the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
|||
|
|
|
|||
|
|
def _error_response(message, status):
    """Build the JSON error payload and status shared by every failure path of /predict_cardio."""
    return jsonify({
        'error': message,
        'prediction': None,
        'probability': None
    }), status


@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """Cardiovascular disease prediction endpoint.

    Expects a JSON POST body that is an object carrying the 11 raw feature
    values checked by ``validate_input``.

    Returns:
        On success, a JSON object with ``prediction``, ``probability``,
        ``risk_level``, ``message`` and ``confidence``. On failure, a JSON
        object with ``error`` set and ``prediction`` / ``probability`` null,
        with status 400 for client-side problems (wrong content type,
        unparseable or non-object body, validation or preprocessing failure)
        and 500 when the model is unavailable or prediction raises.
    """
    try:
        if model is None:
            return _error_response('模型未加载', 500)

        if not request.is_json:
            return _error_response('请求内容类型应为application/json', 400)

        # silent=True returns None for a malformed body instead of raising, so
        # a client mistake is answered with 400 rather than falling through to
        # the generic 500 handler below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _error_response('请求体应为有效的JSON对象', 400)

        is_valid, message = validate_input(data)
        if not is_valid:
            return _error_response(message, 400)

        processed_data = preprocess_input(data)
        if processed_data is None:
            return _error_response('数据预处理失败', 400)

        # Drop columns that must not reach the model; absent ones are ignored.
        processed_data = processed_data.drop(
            columns=['id', 'age', 'cardio'], errors='ignore'
        )

        try:
            prediction = model.predict(processed_data)[0]
            # Index 1 of the probability row is used as the positive-class score.
            probability = model.predict_proba(processed_data)[0][1]
        except Exception as e:
            print(f"预测失败: {e}")
            return _error_response(f'预测失败: {str(e)}', 500)

        result = {
            'prediction': int(prediction),
            'probability': float(probability),
            'risk_level': '高风险' if probability >= 0.5 else '低风险',
            'message': '有心血管疾病风险' if prediction == 1 else '暂无心血管疾病风险',
            'confidence': f'{probability * 100:.1f}%'
        }

        print(f"预测结果: {result}")
        return jsonify(result)

    except Exception as e:
        # Last-resort boundary so the endpoint always answers with JSON.
        print(f"API处理异常: {e}")
        return _error_response(f'服务器内部错误: {str(e)}', 500)
|
|||
|
|
|
|||
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint reporting whether a model is loaded."""
    model_ready = model is not None
    if model_ready:
        status = 'healthy'
    else:
        status = 'unhealthy'

    return jsonify({
        'status': status,
        'model_loaded': model_ready,
        'api_version': '1.0.0',
        'service': 'CardioAI Cardiovascular Disease Predictor'
    })
|
|||
|
|
|
|||
|
|
@app.route('/features', methods=['GET'])
def get_features_info():
    """Describe the features the prediction endpoint works with.

    Returns:
        A JSON object with two lists: ``required_features`` (the raw fields a
        client must send, with type, unit or allowed values) and
        ``derived_features`` (the columns that ``preprocess_input`` computes
        from them).
    """
    required_features = [
        {'name': 'age', 'type': 'numeric', 'unit': 'days', 'description': '年龄(天)'},
        {'name': 'gender', 'type': 'categorical', 'values': [1, 2], 'description': '性别:1=女性,2=男性'},
        {'name': 'height', 'type': 'numeric', 'unit': 'cm', 'description': '身高(厘米)'},
        {'name': 'weight', 'type': 'numeric', 'unit': 'kg', 'description': '体重(千克)'},
        {'name': 'ap_hi', 'type': 'numeric', 'unit': 'mmHg', 'description': '收缩压'},
        {'name': 'ap_lo', 'type': 'numeric', 'unit': 'mmHg', 'description': '舒张压'},
        {'name': 'cholesterol', 'type': 'categorical', 'values': [1, 2, 3], 'description': '胆固醇水平:1=正常,2=高于正常,3=很高'},
        {'name': 'gluc', 'type': 'categorical', 'values': [1, 2, 3], 'description': '血糖水平:1=正常,2=高于正常,3=很高'},
        {'name': 'smoke', 'type': 'categorical', 'values': [0, 1], 'description': '是否吸烟:0=否,1=是'},
        {'name': 'alco', 'type': 'categorical', 'values': [0, 1], 'description': '是否饮酒:0=否,1=是'},
        {'name': 'active', 'type': 'categorical', 'values': [0, 1], 'description': '是否积极运动:0=否,1=是'}
    ]

    # Listed in the order preprocess_input appends them, including the two
    # label columns mapped from cholesterol and gluc.
    derived_features = [
        {'name': 'age_years', 'description': '年龄(年),由age转换而来'},
        {'name': 'bmi', 'description': '身体质量指数,由height和weight计算而来'},
        {'name': 'cholesterol_cat', 'description': '胆固醇分类标签,由cholesterol映射而来'},
        {'name': 'gluc_cat', 'description': '血糖分类标签,由gluc映射而来'},
        {'name': 'bmi_category', 'description': 'BMI分类:偏瘦/正常/超重/肥胖'}
    ]

    features_info = {
        'required_features': required_features,
        'derived_features': derived_features
    }
    return jsonify(features_info)
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Server settings are defined once and used for both the banner and
    # app.run(), so the printed values cannot drift from the real ones
    # (the banner previously announced port 5000 while the server bound 9011).
    host = '0.0.0.0'
    port = 9011
    # NOTE(review): debug=True turns on the development debugger while the
    # server listens on all interfaces; confirm this script is never used for
    # a deployment reachable by untrusted clients.
    debug = True

    print("=" * 60)
    print("CardioAI - 心血管疾病预测API")
    print("=" * 60)
    print(f"模型路径: {MODEL_PATH}")
    print(f"模型加载状态: {'成功' if model is not None else '失败'}")
    print("\n可用端点:")
    print(" GET / - 前端界面")
    print(" POST /predict_cardio - 预测API")
    print(" GET /health - 健康检查")
    print(" GET /features - 特征信息")
    print("\n启动信息:")
    print(f" 主机: {host}")
    print(f" 端口: {port}")
    print(f" 调试模式: {'开启' if debug else '关闭'}")
    print("=" * 60)

    # Start the Flask development server.
    app.run(host=host, port=port, debug=debug)
|