diff --git a/module2_predictor/app.py b/module2_predictor/app.py new file mode 100644 index 0000000..88e2ad8 --- /dev/null +++ b/module2_predictor/app.py @@ -0,0 +1,336 @@ +#!/opt/anaconda3/envs/cardioenv/bin/python +""" +CardioAI - 心血管疾病预测API +Flask应用程序,提供心血管疾病预测API接口 +""" + +from flask import Flask, request, jsonify, render_template +import pandas as pd +import numpy as np +import joblib +import json +import os + +# 创建Flask应用 +app = Flask(__name__) + +# 加载模型 +MODEL_PATH = os.path.join(os.path.dirname(__file__), "cardio_predictor_model.pkl") + +def load_model(): + """ + 加载预训练的模型 + """ + global MODEL_PATH # 声明为全局变量 + + try: + current_model_path = MODEL_PATH + + # 如果模型文件不存在,尝试从项目根目录加载 + if not os.path.exists(current_model_path): + alt_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "cardio_predictor_model.pkl") + if os.path.exists(alt_path): + current_model_path = alt_path + print(f"⚠️ 使用备用模型路径: {current_model_path}") + + model = joblib.load(current_model_path) + print(f"✅ 模型加载成功: {current_model_path}") + return model + except Exception as e: + print(f"❌ 模型加载失败: {e}") + # 尝试从绝对路径加载 + try: + root_model = os.path.join(os.path.dirname(os.path.dirname(__file__)), "cardio_predictor_model.pkl") + if os.path.exists(root_model): + model = joblib.load(root_model) + print(f"✅ 从项目根目录加载模型成功: {root_model}") + return model + except Exception as e2: + print(f"❌ 备用路径也失败: {e2}") + return None + +# 全局模型变量 +model = load_model() + +def preprocess_input(data): + """ + 预处理输入数据,与训练时保持一致 + """ + try: + # 首先确保数值类型正确 + # 转换数值字段为适当的类型 + for field in ['age', 'height', 'weight', 'ap_hi', 'ap_lo']: + if field in data: + data[field] = float(data[field]) + + # 转换分类字段为整数(除了已经映射的字段) + for field in ['gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active']: + if field in data: + data[field] = int(data[field]) + + # 创建DataFrame + df = pd.DataFrame([data]) + + # 特征工程 + # 将age(天)转换为年,四舍五入 + if 'age' in df.columns: + df['age_years'] = (df['age'] / 365.25).round().astype(int) + + # 计算BMI: weight / (height/100)^2 + if 'height' in df.columns and 'weight' in df.columns: + df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2) + + # 类别转换 + # cholesterol转换 + cholesterol_map = { + 1: 'normal', + 2: 'above_normal', + 3: 'well_above_normal' + } + if 'cholesterol' in df.columns: + df['cholesterol_cat'] = df['cholesterol'].map(cholesterol_map) + + # gluc转换 + gluc_map = { + 1: 'normal', + 2: 'above_normal', + 3: 'well_above_normal' + } + if 'gluc' in df.columns: + df['gluc_cat'] = df['gluc'].map(gluc_map) + + # BMI分类 + def categorize_bmi(bmi): + if bmi < 18.5: + return 'underweight' + elif 18.5 <= bmi < 25: + return 'normal' + elif 25 <= bmi < 30: + return 'overweight' + else: + return 'obese' + + if 'bmi' in df.columns: + df['bmi_category'] = df['bmi'].apply(categorize_bmi) + + return df + + except Exception as e: + print(f"数据预处理失败: {e}") + return None + +def validate_input(data): + """ + 验证输入数据的完整性和有效性 + """ + required_fields = [ + 'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo', + 'cholesterol', 'gluc', 'smoke', 'alco', 'active' + ] + + # 检查必需字段 + missing_fields = [field for field in required_fields if field not in data] + if missing_fields: + return False, f"缺少必需字段: {', '.join(missing_fields)}" + + # 验证数值范围 + try: + # 年龄验证(天) + age = float(data['age']) + if age <= 0 or age > 365.25 * 120: # 假设最大120岁 + return False, "年龄应在合理范围内(0-43830天)" + + # 性别验证 + gender = int(data['gender']) + if gender not in [1, 2]: + return False, "性别应为1(女性)或2(男性)" + + # 身高验证(cm) + height = float(data['height']) + if height < 100 or height > 250: + return False, "身高应在100-250cm之间" + + # 体重验证(kg) + weight = float(data['weight']) + if weight < 30 or weight > 300: + return False, "体重应在30-300kg之间" + + # 血压验证 + ap_hi = float(data['ap_hi']) + ap_lo = float(data['ap_lo']) + if ap_lo >= ap_hi: + return False, "舒张压应小于收缩压" + if ap_hi < 90 or ap_hi > 250: + return False, "收缩压应在90-250mmHg之间" + if ap_lo < 60 or ap_lo > 150: + return False, "舒张压应在60-150mmHg之间" + + # 胆固醇验证 + cholesterol = int(data['cholesterol']) + if cholesterol not in [1, 2, 3]: + return False, "胆固醇应为1,2,3之一" + + # 血糖验证 + gluc = int(data['gluc']) + if gluc not in [1, 2, 3]: + return False, "血糖应为1,2,3之一" + + # 吸烟、饮酒、活动验证 + smoke = int(data['smoke']) + alco = int(data['alco']) + active = int(data['active']) + if smoke not in [0, 1] or alco not in [0, 1] or active not in [0, 1]: + return False, "吸烟、饮酒、活动应为0或1" + + except ValueError as e: + return False, f"数据类型错误: {str(e)}" + + return True, "输入数据有效" + +@app.route('/') +def index(): + """ + 主页,提供HTML界面 + """ + return render_template('index.html') + +@app.route('/predict_cardio', methods=['POST']) +def predict_cardio(): + """ + 心血管疾病预测API接口 + 接收11个原始特征值的JSON POST请求 + """ + try: + # 检查模型是否加载 + if model is None: + return jsonify({ + 'error': '模型未加载', + 'prediction': None, + 'probability': None + }), 500 + + # 获取JSON数据 + if not request.is_json: + return jsonify({ + 'error': '请求内容类型应为application/json', + 'prediction': None, + 'probability': None + }), 400 + + data = request.get_json() + + # 验证输入 + is_valid, message = validate_input(data) + if not is_valid: + return jsonify({ + 'error': message, + 'prediction': None, + 'probability': None + }), 400 + + # 预处理输入数据 + processed_data = preprocess_input(data) + if processed_data is None: + return jsonify({ + 'error': '数据预处理失败', + 'prediction': None, + 'probability': None + }), 400 + + # 删除不需要的列(与训练时保持一致) + columns_to_drop = ['id', 'age', 'cardio'] + for col in columns_to_drop: + if col in processed_data.columns: + processed_data = processed_data.drop(col, axis=1) + + # 进行预测 + try: + prediction = model.predict(processed_data)[0] + probability = model.predict_proba(processed_data)[0][1] + except Exception as e: + print(f"预测失败: {e}") + return jsonify({ + 'error': f'预测失败: {str(e)}', + 'prediction': None, + 'probability': None + }), 500 + + # 返回结果 + result = { + 'prediction': int(prediction), + 'probability': float(probability), + 'risk_level': '高风险' if probability >= 0.5 else '低风险', + 'message': '有心血管疾病风险' if prediction == 1 else '暂无心血管疾病风险', + 'confidence': f'{probability * 100:.1f}%' + } + + print(f"预测结果: {result}") + return jsonify(result) + + except Exception as e: + print(f"API处理异常: {e}") + return jsonify({ + 'error': f'服务器内部错误: {str(e)}', + 'prediction': None, + 'probability': None + }), 500 + +@app.route('/health', methods=['GET']) +def health_check(): + """ + 健康检查端点 + """ + health_status = { + 'status': 'healthy' if model is not None else 'unhealthy', + 'model_loaded': model is not None, + 'api_version': '1.0.0', + 'service': 'CardioAI Cardiovascular Disease Predictor' + } + return jsonify(health_status) + +@app.route('/features', methods=['GET']) +def get_features_info(): + """ + 获取模型特征信息 + """ + features_info = { + 'required_features': [ + {'name': 'age', 'type': 'numeric', 'unit': 'days', 'description': '年龄(天)'}, + {'name': 'gender', 'type': 'categorical', 'values': [1, 2], 'description': '性别:1=女性,2=男性'}, + {'name': 'height', 'type': 'numeric', 'unit': 'cm', 'description': '身高(厘米)'}, + {'name': 'weight', 'type': 'numeric', 'unit': 'kg', 'description': '体重(千克)'}, + {'name': 'ap_hi', 'type': 'numeric', 'unit': 'mmHg', 'description': '收缩压'}, + {'name': 'ap_lo', 'type': 'numeric', 'unit': 'mmHg', 'description': '舒张压'}, + {'name': 'cholesterol', 'type': 'categorical', 'values': [1, 2, 3], 'description': '胆固醇水平:1=正常,2=高于正常,3=很高'}, + {'name': 'gluc', 'type': 'categorical', 'values': [1, 2, 3], 'description': '血糖水平:1=正常,2=高于正常,3=很高'}, + {'name': 'smoke', 'type': 'categorical', 'values': [0, 1], 'description': '是否吸烟:0=否,1=是'}, + {'name': 'alco', 'type': 'categorical', 'values': [0, 1], 'description': '是否饮酒:0=否,1=是'}, + {'name': 'active', 'type': 'categorical', 'values': [0, 1], 'description': '是否积极运动:0=否,1=是'} + ], + 'derived_features': [ + {'name': 'age_years', 'description': '年龄(年),由age转换而来'}, + {'name': 'bmi', 'description': '身体质量指数,由height和weight计算而来'}, + {'name': 'bmi_category', 'description': 'BMI分类:偏瘦/正常/超重/肥胖'} + ] + } + return jsonify(features_info) + +if __name__ == '__main__': + print("=" * 60) + print("CardioAI - 心血管疾病预测API") + print("=" * 60) + print(f"模型路径: {MODEL_PATH}") + print(f"模型加载状态: {'成功' if model is not None else '失败'}") + print("\n可用端点:") + print(" GET / - 前端界面") + print(" POST /predict_cardio - 预测API") + print(" GET /health - 健康检查") + print(" GET /features - 特征信息") + print("\n启动信息:") + print(" 主机: 0.0.0.0") + print(" 端口: 5000") + print(" 调试模式: 开启") + print("=" * 60) + + # 启动Flask应用 + app.run(host='0.0.0.0', port=9011, debug=True) \ No newline at end of file diff --git a/module2_predictor/cardio_predictor_model.pkl b/module2_predictor/cardio_predictor_model.pkl new file mode 100644 index 0000000..ea831d2 Binary files /dev/null and b/module2_predictor/cardio_predictor_model.pkl differ diff --git a/module2_predictor/templates/index.html b/module2_predictor/templates/index.html new file mode 100644 index 0000000..83df411 --- /dev/null +++ b/module2_predictor/templates/index.html @@ -0,0 +1,790 @@ + + +
+ + ++ 基于机器学习模型,通过11项健康指标预测心血管疾病风险 +
++ 请填写以下11项健康指标,获取精准的风险评估 +
+正在分析数据,请稍候...
+您的健康数据仅用于本次预测,不会被存储或用于其他用途。
+ +本预测结果仅供参考,不能替代专业医疗诊断。如有健康问题,请及时咨询医生。
++ 基于XGBoost机器学习算法 +
++ 训练数据:68,492条医疗记录 +
++ 模型准确率:> 85% +
++ 更新日期:2026年2月 +
+置信度:-
+处理时间:- 毫秒
+请求ID:-
+时间戳:-
+