Files

336 lines
11 KiB
Python
Raw Permalink Normal View History

#!/opt/anaconda3/envs/cardioenv/bin/python
"""
CardioAI - 心血管疾病预测API
Flask应用程序提供心血管疾病预测API接口
"""
from flask import Flask, request, jsonify, render_template
import pandas as pd
import numpy as np
import joblib
import json
import os
# Flask application object; the route decorators in this file register on it.
app = Flask(__name__)
# Primary location of the serialized model: a file named
# "cardio_predictor_model.pkl" in the same directory as this source file.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "cardio_predictor_model.pkl")
def load_model():
    """Load the pre-trained prediction model from disk.

    Candidate locations are tried in order: ``MODEL_PATH`` (next to this
    file), then a file with the same name one directory above it. A
    candidate that does not exist is skipped; a candidate that exists but
    cannot be deserialized is reported and the next one is tried.

    Returns:
        The object returned by ``joblib.load`` for the first candidate that
        loads successfully, or ``None`` when no candidate could be loaded.
    """
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only deploy model files that come from a trusted source.
    fallback_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "cardio_predictor_model.pkl",
    )
    for candidate in (MODEL_PATH, fallback_path):
        if not os.path.exists(candidate):
            print(f"⚠️ 模型文件不存在: {candidate}")
            continue
        try:
            loaded = joblib.load(candidate)
        except Exception as e:
            # Loading is best-effort: a corrupt or incompatible file must not
            # prevent the next candidate from being tried.
            print(f"❌ 模型加载失败: {candidate}: {e}")
            continue
        print(f"✅ 模型加载成功: {candidate}")
        return loaded
    print("❌ 未找到可用的模型文件")
    return None
# Module-level model instance, loaded once when this module is imported.
# It is None when load_model() could not load a model.
model = load_model()
def preprocess_input(data):
    """Turn one raw request payload into a single-row feature DataFrame.

    The payload is copied before any coercion, so the caller's mapping is
    never modified. Fields that are present are coerced (measurements to
    ``float``, coded fields to ``int``), and the derived columns are added
    when their source columns exist:

    * ``age_years``: ``age`` (in days) divided by 365.25, rounded to an int.
    * ``bmi``: ``weight / (height / 100) ** 2``.
    * ``cholesterol_cat`` / ``gluc_cat``: text label for the codes 1, 2, 3.
    * ``bmi_category``: underweight / normal / overweight / obese.

    Args:
        data: Mapping of raw feature names to values.

    Returns:
        A one-row ``pandas.DataFrame`` holding the original and derived
        columns, or ``None`` if any step fails (the error is printed).
    """
    float_fields = ('age', 'height', 'weight', 'ap_hi', 'ap_lo')
    int_fields = ('gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active')
    # cholesterol and gluc share the same 1/2/3 -> label encoding.
    level_labels = {
        1: 'normal',
        2: 'above_normal',
        3: 'well_above_normal',
    }

    def categorize_bmi(bmi):
        """Map a BMI value to its category label."""
        if bmi < 18.5:
            return 'underweight'
        if bmi < 25:
            return 'normal'
        if bmi < 30:
            return 'overweight'
        return 'obese'

    try:
        # Work on a shallow copy: type coercion below must not leak into the
        # caller's dict.
        record = dict(data)
        for field in float_fields:
            if field in record:
                record[field] = float(record[field])
        for field in int_fields:
            if field in record:
                record[field] = int(record[field])

        df = pd.DataFrame([record])

        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)
        if 'height' in df.columns and 'weight' in df.columns:
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)
        if 'cholesterol' in df.columns:
            df['cholesterol_cat'] = df['cholesterol'].map(level_labels)
        if 'gluc' in df.columns:
            df['gluc_cat'] = df['gluc'].map(level_labels)
        if 'bmi' in df.columns:
            df['bmi_category'] = df['bmi'].apply(categorize_bmi)
        return df
    except Exception as e:
        # Any failure is reported to the caller as None, not as an exception.
        print(f"数据预处理失败: {e}")
        return None
def validate_input(data):
    """Check that a request payload has every required field in range.

    Range checks are written as ``not lo <= value <= hi`` so that NaN, for
    which every comparison is false, is rejected instead of slipping through.

    Args:
        data: The decoded request payload, expected to be a mapping of
            feature names to values.

    Returns:
        A ``(is_valid, message)`` tuple. ``is_valid`` is ``False`` with an
        explanatory message when the payload is not a container, a field is
        missing, a value cannot be converted to a number, or a value is out
        of range; otherwise ``(True, "输入数据有效")``.
    """
    required_fields = (
        'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
        'cholesterol', 'gluc', 'smoke', 'alco', 'active',
    )

    try:
        missing_fields = [field for field in required_fields if field not in data]
    except TypeError:
        # e.g. None or a bare number: not something fields can be looked up in.
        return False, "请求数据应为包含各字段的JSON对象"
    if missing_fields:
        return False, f"缺少必需字段: {', '.join(missing_fields)}"

    try:
        # Age is expressed in days; the upper bound corresponds to 120 years.
        age = float(data['age'])
        if not 0 < age <= 365.25 * 120:
            return False, "年龄应在合理范围内(0-43830天)"

        if int(data['gender']) not in (1, 2):
            return False, "性别应为1(女性)或2(男性)"

        height = float(data['height'])
        if not 100 <= height <= 250:
            return False, "身高应在100-250cm之间"

        weight = float(data['weight'])
        if not 30 <= weight <= 300:
            return False, "体重应在30-300kg之间"

        ap_hi = float(data['ap_hi'])
        ap_lo = float(data['ap_lo'])
        if ap_lo >= ap_hi:
            return False, "舒张压应小于收缩压"
        if not 90 <= ap_hi <= 250:
            return False, "收缩压应在90-250mmHg之间"
        if not 60 <= ap_lo <= 150:
            return False, "舒张压应在60-150mmHg之间"

        if int(data['cholesterol']) not in (1, 2, 3):
            return False, "胆固醇应为1,2,3之一"

        if int(data['gluc']) not in (1, 2, 3):
            return False, "血糖应为1,2,3之一"

        binary_flags = [int(data[name]) for name in ('smoke', 'alco', 'active')]
        if any(flag not in (0, 1) for flag in binary_flags):
            return False, "吸烟、饮酒、活动应为0或1"
    except (TypeError, ValueError, OverflowError) as e:
        # TypeError: None / list / dict values; ValueError: unparsable text or
        # int(nan); OverflowError: int(inf).
        return False, f"数据类型错误: {str(e)}"

    return True, "输入数据有效"
@app.route('/')
def index():
    """Serve the HTML front-end by rendering the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
def _error_response(message, status_code):
    """Build the JSON error payload shared by every failure path of the predict API."""
    payload = {
        'error': message,
        'prediction': None,
        'probability': None,
    }
    return jsonify(payload), status_code


@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """Cardiovascular disease prediction endpoint.

    Expects a JSON object carrying the 11 raw feature values. The payload is
    validated, preprocessed, stripped of the columns the model is not given
    (``id``, ``age``, ``cardio``) and passed to the loaded model.

    Returns:
        On success, a JSON object with ``prediction``, ``probability``,
        ``risk_level``, ``message`` and ``confidence``. On failure, a JSON
        object with ``error`` plus null ``prediction``/``probability`` and
        status 400 (bad request content) or 500 (model missing, prediction
        failure, unexpected error).
    """
    try:
        if model is None:
            return _error_response('模型未加载', 500)

        if not request.is_json:
            return _error_response('请求内容类型应为application/json', 400)

        # silent=True makes a malformed body yield None instead of raising
        # BadRequest, which the blanket handler below would report as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _error_response('请求体应为有效的JSON对象', 400)

        is_valid, message = validate_input(data)
        if not is_valid:
            return _error_response(message, 400)

        processed_data = preprocess_input(data)
        if processed_data is None:
            return _error_response('数据预处理失败', 400)

        # Drop the columns the model is not fed; absent ones are ignored.
        processed_data = processed_data.drop(
            columns=['id', 'age', 'cardio'], errors='ignore'
        )

        try:
            prediction = model.predict(processed_data)[0]
            # Second column of predict_proba is taken as the positive-class
            # probability.
            probability = model.predict_proba(processed_data)[0][1]
        except Exception as e:
            print(f"预测失败: {e}")
            return _error_response(f'预测失败: {str(e)}', 500)

        result = {
            'prediction': int(prediction),
            'probability': float(probability),
            'risk_level': '高风险' if probability >= 0.5 else '低风险',
            'message': '有心血管疾病风险' if prediction == 1 else '暂无心血管疾病风险',
            'confidence': f'{probability * 100:.1f}%',
        }
        print(f"预测结果: {result}")
        return jsonify(result)
    except Exception as e:
        # Last-resort boundary so the client always receives a JSON body.
        print(f"API处理异常: {e}")
        return _error_response(f'服务器内部错误: {str(e)}', 500)
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint: report whether the model is loaded.

    Returns a JSON object with ``status`` ('healthy' or 'unhealthy'),
    ``model_loaded``, ``api_version`` and ``service``.
    """
    is_loaded = model is not None
    if is_loaded:
        status_label = 'healthy'
    else:
        status_label = 'unhealthy'
    return jsonify({
        'status': status_label,
        'model_loaded': is_loaded,
        'api_version': '1.0.0',
        'service': 'CardioAI Cardiovascular Disease Predictor',
    })
@app.route('/features', methods=['GET'])
def get_features_info():
    """Describe the features used by the prediction API.

    Returns a JSON object with two lists: ``required_features`` (the 11 raw
    inputs the client must send) and ``derived_features`` (the columns that
    preprocess_input adds before the data reaches the model).
    """
    features_info = {
        'required_features': [
            {'name': 'age', 'type': 'numeric', 'unit': 'days', 'description': '年龄(天)'},
            {'name': 'gender', 'type': 'categorical', 'values': [1, 2], 'description': '性别:1=女性,2=男性'},
            {'name': 'height', 'type': 'numeric', 'unit': 'cm', 'description': '身高(厘米)'},
            {'name': 'weight', 'type': 'numeric', 'unit': 'kg', 'description': '体重(千克)'},
            {'name': 'ap_hi', 'type': 'numeric', 'unit': 'mmHg', 'description': '收缩压'},
            {'name': 'ap_lo', 'type': 'numeric', 'unit': 'mmHg', 'description': '舒张压'},
            {'name': 'cholesterol', 'type': 'categorical', 'values': [1, 2, 3], 'description': '胆固醇水平:1=正常,2=高于正常,3=很高'},
            {'name': 'gluc', 'type': 'categorical', 'values': [1, 2, 3], 'description': '血糖水平:1=正常,2=高于正常,3=很高'},
            {'name': 'smoke', 'type': 'categorical', 'values': [0, 1], 'description': '是否吸烟:0=否,1=是'},
            {'name': 'alco', 'type': 'categorical', 'values': [0, 1], 'description': '是否饮酒:0=否,1=是'},
            {'name': 'active', 'type': 'categorical', 'values': [0, 1], 'description': '是否积极运动:0=否,1=是'},
        ],
        'derived_features': [
            {'name': 'age_years', 'description': '年龄(年),由age转换而来'},
            {'name': 'bmi', 'description': '身体质量指数,由height和weight计算而来'},
            {'name': 'bmi_category', 'description': 'BMI分类:偏瘦/正常/超重/肥胖'},
            # These two are also produced by preprocess_input and were
            # previously missing from this listing.
            {'name': 'cholesterol_cat', 'description': '胆固醇分类标签,由cholesterol映射而来'},
            {'name': 'gluc_cat', 'description': '血糖分类标签,由gluc映射而来'},
        ],
    }
    return jsonify(features_info)
if __name__ == '__main__':
    # Single source of truth for the server settings, so the banner printed
    # below cannot drift from what app.run() actually uses (it previously
    # announced port 5000 while binding 9011).
    host = '0.0.0.0'
    port = 9011
    debug = True

    print("=" * 60)
    print("CardioAI - 心血管疾病预测API")
    print("=" * 60)
    print(f"模型路径: {MODEL_PATH}")
    print(f"模型加载状态: {'成功' if model is not None else '失败'}")
    print("\n可用端点:")
    print(" GET / - 前端界面")
    print(" POST /predict_cardio - 预测API")
    print(" GET /health - 健康检查")
    print(" GET /features - 特征信息")
    print("\n启动信息:")
    print(f" 主机: {host}")
    print(f" 端口: {port}")
    print(f" 调试模式: {'开启' if debug else '关闭'}")
    print("=" * 60)
    # NOTE(security): debug=True enables the interactive debugger, and binding
    # to 0.0.0.0 exposes it on every interface. Use this entry point for local
    # development only; turn debug off before exposing the service.
    app.run(host=host, port=port, debug=debug)