#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CardioAI - cardiovascular disease prediction API service.

Features:
1. Load a pre-trained machine-learning model
2. Expose a RESTful API
3. Accept raw feature values and return a prediction
4. Serve the web front-end

How to start:
    conda activate cardioenv
    python app.py
    or
    flask run
"""

import logging
import os
import sys
import traceback
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify, render_template, send_from_directory

# Add the project root directory to the Python import path.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

# Configure logging.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create the Flask application.
app = Flask(__name__)
# Emit non-ASCII (Chinese) text verbatim in JSON responses.  The config key
# is honoured by older Flask releases; newer releases ignore it and expose
# the same switch on the JSON provider instead, so set both when possible.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, "json"):
    app.json.ensure_ascii = False

# Module-level state holding the loaded model and its metadata.
model_data = None
feature_names = None
pipeline = None


def load_model():
    """Load the pre-trained model into the module-level globals.

    Reads ``models/cardio_predictor_model.pkl`` located next to this file
    and populates ``model_data``, ``pipeline`` and ``feature_names``.  The
    globals are only assigned once the whole payload has been read
    successfully, so a failed load never leaves them half-initialised.

    Returns:
        bool: True when the model was loaded, False otherwise (the error
        is logged together with its traceback).
    """
    global model_data, feature_names, pipeline

    try:
        # Path of the model file.
        model_dir = Path(__file__).parent / "models"
        model_path = model_dir / "cardio_predictor_model.pkl"

        if not model_path.exists():
            logger.error(f"模型文件不存在: {model_path}")
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        logger.info(f"正在加载模型: {model_path}")
        # SECURITY NOTE: joblib.load unpickles the file and can therefore
        # execute arbitrary code; only ever load model files from a
        # trusted source.
        loaded = joblib.load(model_path)

        # Extract the pipeline and feature metadata before touching the
        # globals; a missing 'pipeline' key raises KeyError here.
        loaded_pipeline = loaded['pipeline']
        loaded_features = loaded.get('feature_names', [])

        model_data = loaded
        pipeline = loaded_pipeline
        feature_names = loaded_features

        logger.info(f"模型加载成功!版本: {model_data.get('model_version', '未知')}")
        logger.info(f"特征数量: {len(feature_names)}")
        logger.info(f"特征列表: {feature_names}")

        return True

    except Exception as e:
        logger.error(f"模型加载失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False


def preprocess_input(input_data):
    """Turn raw request features into the frame passed to the pipeline.

    Args:
        input_data: dict of raw features.  Must contain either ``age``
            (in days) or ``age_years``, and either ``height`` (cm) plus
            ``weight`` (kg) or a ready-made ``bmi``.

    Returns:
        pd.DataFrame: a single-row frame with the columns listed in
        ``required_features`` below, in that order.

    Raises:
        ValueError: when the age or BMI inputs, or any other required
            feature, are missing.  Any exception is logged and re-raised.
    """
    try:
        # Build a single-row data frame.
        df = pd.DataFrame([input_data])

        # 1. Age: convert from days to whole years (rounded).
        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)
        elif 'age_years' in df.columns:
            # Already converted by the caller; use it directly.
            df['age_years'] = df['age_years'].astype(int)
        else:
            raise ValueError("输入数据中必须包含'age'或'age_years'字段")

        # 2. BMI = weight(kg) / height(m)^2
        if 'height' in df.columns and 'weight' in df.columns:
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)
            df['bmi'] = df['bmi'].round(2)
        elif 'bmi' in df.columns:
            # Already provided by the caller; use it directly.
            df['bmi'] = df['bmi'].astype(float)
        else:
            raise ValueError("输入数据中必须包含'height'和'weight'字段或'bmi'字段")

        # 3. Make sure every required feature is present.
        required_features = [
            'age_years', 'bmi', 'ap_hi', 'ap_lo', 'gender',
            'cholesterol', 'gluc', 'smoke', 'alco', 'active'
        ]

        missing_features = [f for f in required_features if f not in df.columns]
        if missing_features:
            raise ValueError(f"缺少必要特征: {missing_features}")

        # 4. Select the model features (in the order used for training).
        processed_df = df[required_features].copy()

        logger.debug(f"预处理后的特征数据框:\n{processed_df}")
        return processed_df

    except Exception as e:
        logger.error(f"数据预处理失败: {str(e)}")
        raise


def validate_input(input_data):
    """Validate the raw feature dict of a prediction request.

    Non-numeric values that can be parsed as numbers (e.g. ``"140"``) are
    converted to ``float`` and written back into ``input_data``; the
    caller relies on this in-place normalisation.

    Args:
        input_data: dict of raw input features.

    Returns:
        tuple: ``(is_valid, message)``; ``message`` explains the first
        problem found, or confirms validity.
    """
    try:
        # Required fields.
        required_fields = [
            'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
            'cholesterol', 'gluc', 'smoke', 'alco', 'active'
        ]

        missing_fields = [f for f in required_fields if f not in input_data]
        if missing_fields:
            return False, f"缺少必需字段: {missing_fields}"

        # Type check: everything must be numeric or convertible to it.
        for field in required_fields:
            value = input_data[field]
            if not isinstance(value, (int, float)):
                try:
                    input_data[field] = float(value)
                except (TypeError, ValueError):
                    # TypeError covers None / lists / dicts, which float()
                    # rejects with a different exception than bad strings.
                    return False, f"字段'{field}'必须为数值类型,当前值: {value}"

        # Inclusive value ranges.
        validations = [
            ('age', 0, 365 * 150),   # age in days: 0-150 years
            ('gender', 1, 2),        # gender: 1 or 2
            ('height', 100, 250),    # height in cm
            ('weight', 20, 300),     # weight in kg
            ('ap_hi', 50, 300),      # systolic blood pressure
            ('ap_lo', 30, 200),      # diastolic blood pressure
            ('cholesterol', 1, 3),   # cholesterol level: 1-3
            ('gluc', 1, 3),          # glucose level: 1-3
            ('smoke', 0, 1),         # smoking: 0 or 1
            ('alco', 0, 1),          # alcohol: 0 or 1
            ('active', 0, 1)         # physical activity: 0 or 1
        ]

        for field, min_val, max_val in validations:
            value = input_data[field]
            # NaN fails both comparisons and is therefore rejected here.
            if not (min_val <= value <= max_val):
                return False, f"字段'{field}'的值{value}超出有效范围[{min_val}, {max_val}]"

        # Categorical / binary codes must be whole numbers; a value such
        # as 1.5 is inside the range above but is not a valid code.
        integer_fields = ('gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active')
        for field in integer_fields:
            value = input_data[field]
            if not float(value).is_integer():
                return False, f"字段'{field}'必须为整数,当前值: {value}"

        # Blood-pressure plausibility.
        if input_data['ap_lo'] >= input_data['ap_hi']:
            return False, "舒张压不能高于或等于收缩压"

        return True, "输入数据有效"

    except Exception as e:
        return False, f"输入数据验证失败: {str(e)}"


@app.route('/')
def index():
    """Home page - render the front-end template."""
    return render_template('index.html')


@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """Cardiovascular disease prediction endpoint.

    Request body (JSON object):
    {
        "age": 20228,           # age in days
        "gender": 1,            # gender code (1 or 2)
        "height": 156,          # height (cm)
        "weight": 85,           # weight (kg)
        "ap_hi": 140,           # systolic blood pressure
        "ap_lo": 90,            # diastolic blood pressure
        "cholesterol": 1,       # cholesterol level (1-3)
        "gluc": 1,              # glucose level (1-3)
        "smoke": 0,             # smoking (0 or 1)
        "alco": 0,              # alcohol (0 or 1)
        "active": 1             # physical activity (0 or 1)
    }

    Response body (JSON):
    {
        "success": true,
        "prediction": 1,
        "probability": 0.85,
        "risk_level": "高危",
        "message": "预测成功",
        "features": {
            "age_years": 55,
            "bmi": 34.9,
            ...                 # the remaining processed features
        }
    }

    Status codes: 200 on success, 400 for a malformed or invalid request,
    503 when the model is not loaded, 500 for unexpected errors.
    """
    try:
        # The model must be loaded before anything else can happen.
        if pipeline is None:
            return jsonify({
                "success": False,
                "message": "模型未加载,请等待或联系管理员"
            }), 503

        # The request must declare a JSON content type ...
        if not request.is_json:
            return jsonify({
                "success": False,
                "message": "请求必须是JSON格式"
            }), 400

        # ... and carry a body that parses to a JSON object.  silent=True
        # returns None on a parse error instead of raising, so a broken
        # body is reported as a client error (400) rather than a 500.
        input_data = request.get_json(silent=True)
        if not isinstance(input_data, dict):
            return jsonify({
                "success": False,
                "message": "请求体必须是有效的JSON对象"
            }), 400

        logger.info(f"收到预测请求: {input_data}")

        # Validate the input (also normalises numeric strings in place).
        is_valid, error_message = validate_input(input_data)
        if not is_valid:
            return jsonify({
                "success": False,
                "message": error_message
            }), 400

        # Pre-process the input.
        processed_df = preprocess_input(input_data)

        # Run the prediction.
        prediction = pipeline.predict(processed_df)[0]
        probability = pipeline.predict_proba(processed_df)[0][1]  # probability of class 1

        # Map the probability onto a risk level.
        if probability < 0.3:
            risk_level = "低危"
        elif probability < 0.6:
            risk_level = "中危"
        else:
            risk_level = "高危"

        # Build the response; cast to built-in types so they serialise.
        response_data = {
            "success": True,
            "prediction": int(prediction),
            "probability": float(round(probability, 4)),
            "risk_level": risk_level,
            "message": "预测成功",
            "features": {
                "age_years": int(processed_df['age_years'].iloc[0]),
                "bmi": float(round(processed_df['bmi'].iloc[0], 2)),
                "ap_hi": int(processed_df['ap_hi'].iloc[0]),
                "ap_lo": int(processed_df['ap_lo'].iloc[0]),
                "gender": int(processed_df['gender'].iloc[0]),
                "cholesterol": int(processed_df['cholesterol'].iloc[0]),
                "gluc": int(processed_df['gluc'].iloc[0]),
                "smoke": int(processed_df['smoke'].iloc[0]),
                "alco": int(processed_df['alco'].iloc[0]),
                "active": int(processed_df['active'].iloc[0])
            }
        }

        logger.info(f"预测结果: {response_data}")
        return jsonify(response_data), 200

    except Exception as e:
        error_msg = f"预测过程中发生错误: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        return jsonify({
            "success": False,
            "message": error_msg
        }), 500


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint.

    Returns 503 when no model is loaded, otherwise runs one prediction on
    a fixed sample and returns 200, or 500 if that smoke test raises.
    """
    try:
        if pipeline is None:
            return jsonify({
                "status": "unhealthy",
                "message": "模型未加载"
            }), 503

        # Simple smoke test of the model with a fixed sample.
        test_data = {
            "age": 20228,
            "gender": 1,
            "height": 156,
            "weight": 85,
            "ap_hi": 140,
            "ap_lo": 90,
            "cholesterol": 1,
            "gluc": 1,
            "smoke": 0,
            "alco": 0,
            "active": 1
        }

        processed_df = preprocess_input(test_data)
        _ = pipeline.predict(processed_df)

        return jsonify({
            "status": "healthy",
            "model_version": model_data.get('model_version', '未知'),
            "features": len(feature_names) if feature_names else 0,
            "message": "模型服务运行正常"
        }), 200

    except Exception as e:
        return jsonify({
            "status": "unhealthy",
            "message": f"健康检查失败: {str(e)}"
        }), 500


@app.route('/model_info', methods=['GET'])
def model_info():
    """Return the metadata stored alongside the loaded model."""
    if model_data is None:
        return jsonify({
            "success": False,
            "message": "模型未加载"
        }), 503

    return jsonify({
        "success": True,
        "model_version": model_data.get('model_version', '未知'),
        "description": model_data.get('description', 'CardioAI心血管疾病预测模型'),
        "feature_count": len(feature_names) if feature_names else 0,
        "features": feature_names if feature_names else []
    }), 200


# Flag recording that the model has been loaded.
_model_loaded = False


@app.before_request
def ensure_model_loaded():
    """Load the model lazily before a request if it is not loaded yet.

    Needed when the app is started via ``flask run`` (the ``__main__``
    block below does not execute in that case).  A failed load is retried
    on the next request.
    """
    global _model_loaded

    if _model_loaded:
        return

    if pipeline is not None:
        # Already loaded eagerly (e.g. by the __main__ entry point);
        # do not read the model file a second time.
        _model_loaded = True
        return

    logger.info("首次请求,正在加载模型...")
    success = load_model()
    if success:
        _model_loaded = True
        logger.info("模型加载完成")
    else:
        logger.error("模型加载失败")


if __name__ == '__main__':
    # Load the model eagerly and refuse to start without it.
    success = load_model()

    if not success:
        logger.error("启动失败: 模型加载失败")
        sys.exit(1)

    # Start the Flask application.
    logger.info("启动CardioAI预测API服务...")
    logger.info("访问 http://localhost:5000 使用预测界面")
    logger.info("API文档:")
    logger.info("  GET  /              - 前端界面")
    logger.info("  POST /predict_cardio - 预测接口")
    logger.info("  GET  /health        - 健康检查")
    logger.info("  GET  /model_info    - 模型信息")

    # The server listens on all interfaces, so the interactive debugger
    # (which allows arbitrary code execution) must never be on by default.
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in (
        "1", "true", "yes", "on"
    )

    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)