添加Module2心血管疾病预测模型和Flask API
Part A: 模型训练与保存 - train_and_save.py: 一次性脚本,训练XGBoost模型并保存完整Pipeline - cardio_predictor_model.pkl: 包含预处理器和分类器的完整Pipeline Part B: Flask API部署 - app.py: 提供/predict_cardio接口,接收11个特征值并返回预测结果 - 包含输入验证、数据处理和模型加载功能 Part C: 前端交互界面 - templates/index.html: 响应式HTML表单,集成JavaScript Fetch API - 提供示例数据填充和实时预测结果显示 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
336
module2_predictor/app.py
Normal file
336
module2_predictor/app.py
Normal file
@@ -0,0 +1,336 @@
|
||||
#!/opt/anaconda3/envs/cardioenv/bin/python
|
||||
"""
|
||||
CardioAI - 心血管疾病预测API
|
||||
Flask应用程序,提供心血管疾病预测API接口
|
||||
"""
|
||||
|
||||
from flask import Flask, request, jsonify, render_template
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import joblib
|
||||
import json
|
||||
import os
|
||||
|
||||
# Create the Flask application.
app = Flask(__name__)

# Default location of the serialized model: a file next to this module.
# abspath() keeps the path stable even when __file__ is relative
# (e.g. the script is started from its own directory).
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "cardio_predictor_model.pkl"
)
|
||||
|
||||
def load_model():
    """
    Load the pre-trained model from disk.

    Candidate locations are tried in order: MODEL_PATH first, then a file
    named ``cardio_predictor_model.pkl`` one directory above this module.
    A candidate that does not exist, or that fails to load, is skipped.

    Returns:
        The object returned by ``joblib.load``, or None when no candidate
        could be loaded.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    fallback_path = os.path.join(
        os.path.dirname(module_dir), "cardio_predictor_model.pkl"
    )

    # NOTE(security): joblib.load unpickles arbitrary objects. Only load
    # model files that come from a trusted source.
    for candidate in (MODEL_PATH, fallback_path):
        if not os.path.exists(candidate):
            continue
        if candidate != MODEL_PATH:
            print(f"⚠️ 使用备用模型路径: {candidate}")
        try:
            loaded = joblib.load(candidate)
        except Exception as e:
            # Best effort: report the failure and try the next candidate.
            print(f"❌ 模型加载失败: {e}")
            continue
        print(f"✅ 模型加载成功: {candidate}")
        return loaded

    print(f"❌ 模型加载失败: 未找到可用的模型文件 ({MODEL_PATH})")
    return None
|
||||
|
||||
# Module-level model object, loaded once at import time.
# It is None when load_model() could not load any model file.
model = load_model()
|
||||
|
||||
def _categorize_bmi(bmi):
    """Map a BMI value to 'underweight', 'normal', 'overweight' or 'obese'."""
    if bmi < 18.5:
        return 'underweight'
    if bmi < 25:
        return 'normal'
    if bmi < 30:
        return 'overweight'
    return 'obese'


def preprocess_input(data):
    """
    Turn one raw input record into a single-row DataFrame with derived features.

    The caller's mapping is not modified; all conversions happen on a copy.

    Derived columns (each added only when its source columns are present):
        age_years        - ``age`` (days) divided by 365.25, rounded to an int
        bmi              - weight / (height / 100) ** 2
        cholesterol_cat  - 1/2/3 mapped to normal/above_normal/well_above_normal
        gluc_cat         - same mapping applied to ``gluc``
        bmi_category     - underweight/normal/overweight/obese

    Args:
        data: mapping of raw feature names to values (numbers or numeric strings).

    Returns:
        A one-row pandas DataFrame, or None when any conversion fails.
    """
    numeric_fields = ('age', 'height', 'weight', 'ap_hi', 'ap_lo')
    integer_fields = ('gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active')
    # cholesterol and gluc share the same three-level encoding.
    level_names = {
        1: 'normal',
        2: 'above_normal',
        3: 'well_above_normal',
    }

    try:
        # Work on a copy so the caller's payload keeps its original values.
        record = dict(data)

        for field in numeric_fields:
            if field in record:
                record[field] = float(record[field])

        for field in integer_fields:
            if field in record:
                record[field] = int(record[field])

        df = pd.DataFrame([record])

        # Feature engineering: age in days -> whole years.
        if 'age' in df.columns:
            df['age_years'] = (df['age'] / 365.25).round().astype(int)

        # BMI = weight(kg) / height(m)^2; height arrives in centimetres.
        if 'height' in df.columns and 'weight' in df.columns:
            df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)

        if 'cholesterol' in df.columns:
            df['cholesterol_cat'] = df['cholesterol'].map(level_names)

        if 'gluc' in df.columns:
            df['gluc_cat'] = df['gluc'].map(level_names)

        if 'bmi' in df.columns:
            df['bmi_category'] = df['bmi'].apply(_categorize_bmi)

        return df

    except Exception as e:
        # Callers rely on None to signal any preprocessing failure.
        print(f"数据预处理失败: {e}")
        return None
|
||||
|
||||
def validate_input(data):
    """
    Check that a request payload is complete and within accepted ranges.

    Range checks are written as ``not (low <= value <= high)`` so that NaN,
    which compares False against everything, is rejected instead of slipping
    through.

    Args:
        data: the decoded JSON payload; expected to be a dict holding the
            11 raw feature values.

    Returns:
        A ``(is_valid, message)`` tuple. ``message`` describes the first
        problem found, or confirms that the input is valid.
    """
    # A JSON body can legally be null, a list, a number, ...; only an
    # object can carry the named fields.
    if not isinstance(data, dict):
        return False, "请求数据应为JSON对象"

    required_fields = [
        'age', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
        'cholesterol', 'gluc', 'smoke', 'alco', 'active'
    ]

    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return False, f"缺少必需字段: {', '.join(missing_fields)}"

    try:
        # Age is given in days; the upper bound is 120 years.
        age = float(data['age'])
        if not (0 < age <= 365.25 * 120):
            return False, "年龄应在合理范围内(0-43830天)"

        gender = int(data['gender'])
        if gender not in (1, 2):
            return False, "性别应为1(女性)或2(男性)"

        # Height in centimetres.
        height = float(data['height'])
        if not (100 <= height <= 250):
            return False, "身高应在100-250cm之间"

        # Weight in kilograms.
        weight = float(data['weight'])
        if not (30 <= weight <= 300):
            return False, "体重应在30-300kg之间"

        # Blood pressure: diastolic must be strictly below systolic.
        ap_hi = float(data['ap_hi'])
        ap_lo = float(data['ap_lo'])
        if not (ap_lo < ap_hi):
            return False, "舒张压应小于收缩压"
        if not (90 <= ap_hi <= 250):
            return False, "收缩压应在90-250mmHg之间"
        if not (60 <= ap_lo <= 150):
            return False, "舒张压应在60-150mmHg之间"

        cholesterol = int(data['cholesterol'])
        if cholesterol not in (1, 2, 3):
            return False, "胆固醇应为1,2,3之一"

        gluc = int(data['gluc'])
        if gluc not in (1, 2, 3):
            return False, "血糖应为1,2,3之一"

        # Binary lifestyle flags.
        smoke = int(data['smoke'])
        alco = int(data['alco'])
        active = int(data['active'])
        if smoke not in (0, 1) or alco not in (0, 1) or active not in (0, 1):
            return False, "吸烟、饮酒、活动应为0或1"

    except (TypeError, ValueError, OverflowError) as e:
        # TypeError: None / list / dict values; ValueError: non-numeric text
        # or int(NaN); OverflowError: int(inf).
        return False, f"数据类型错误: {str(e)}"

    return True, "输入数据有效"
|
||||
|
||||
@app.route('/')
def index():
    """Home page: render the ``index.html`` template."""
    return render_template('index.html')
|
||||
|
||||
def _error_response(message, status):
    """Build the JSON error payload shared by every failure path of the API."""
    return jsonify({
        'error': message,
        'prediction': None,
        'probability': None
    }), status


@app.route('/predict_cardio', methods=['POST'])
def predict_cardio():
    """
    Cardiovascular disease prediction endpoint.

    Expects a JSON POST body holding the 11 raw feature values.

    Responses:
        200 - JSON with ``prediction``, ``probability``, ``risk_level``,
              ``message`` and ``confidence``.
        400 - wrong content type, malformed JSON, failed validation or
              failed preprocessing.
        500 - model not loaded, prediction failure or unexpected error.
    """
    try:
        if model is None:
            return _error_response('模型未加载', 500)

        if not request.is_json:
            return _error_response('请求内容类型应为application/json', 400)

        # silent=True returns None on a malformed body instead of raising,
        # so a client error is reported as 400 rather than as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _error_response('请求体必须是有效的JSON对象', 400)

        is_valid, message = validate_input(data)
        if not is_valid:
            return _error_response(message, 400)

        processed_data = preprocess_input(data)
        if processed_data is None:
            return _error_response('数据预处理失败', 400)

        # Drop the columns that must not reach the model; 'age' has been
        # replaced by the derived 'age_years' column during preprocessing.
        unwanted = [
            col for col in ('id', 'age', 'cardio')
            if col in processed_data.columns
        ]
        processed_data = processed_data.drop(columns=unwanted)

        try:
            prediction = model.predict(processed_data)[0]
            # Column 1 of predict_proba is the probability of class 1.
            probability = model.predict_proba(processed_data)[0][1]
        except Exception as e:
            print(f"预测失败: {e}")
            return _error_response(f'预测失败: {str(e)}', 500)

        result = {
            'prediction': int(prediction),
            'probability': float(probability),
            'risk_level': '高风险' if probability >= 0.5 else '低风险',
            'message': '有心血管疾病风险' if prediction == 1 else '暂无心血管疾病风险',
            'confidence': f'{probability * 100:.1f}%'
        }

        print(f"预测结果: {result}")
        return jsonify(result)

    except Exception as e:
        # Last-resort boundary: report anything unexpected as a 500.
        print(f"API处理异常: {e}")
        return _error_response(f'服务器内部错误: {str(e)}', 500)
|
||||
|
||||
@app.route('/health', methods=['GET'])
def health_check():
    """
    Health-check endpoint.

    Reports 'healthy' when the module-level model is loaded and 'unhealthy'
    otherwise, together with static service metadata.
    """
    model_loaded = model is not None
    health_status = {
        'status': 'healthy' if model_loaded else 'unhealthy',
        'model_loaded': model_loaded,
        'api_version': '1.0.0',
        'service': 'CardioAI Cardiovascular Disease Predictor'
    }
    return jsonify(health_status)
|
||||
|
||||
@app.route('/features', methods=['GET'])
def get_features_info():
    """
    Describe the model's input features.

    Returns a JSON document with two lists: ``required_features`` (the raw
    fields a client must send) and ``derived_features`` (fields computed by
    the server from the raw ones).
    """
    features_info = {
        'required_features': [
            {'name': 'age', 'type': 'numeric', 'unit': 'days', 'description': '年龄(天)'},
            {'name': 'gender', 'type': 'categorical', 'values': [1, 2], 'description': '性别:1=女性,2=男性'},
            {'name': 'height', 'type': 'numeric', 'unit': 'cm', 'description': '身高(厘米)'},
            {'name': 'weight', 'type': 'numeric', 'unit': 'kg', 'description': '体重(千克)'},
            {'name': 'ap_hi', 'type': 'numeric', 'unit': 'mmHg', 'description': '收缩压'},
            {'name': 'ap_lo', 'type': 'numeric', 'unit': 'mmHg', 'description': '舒张压'},
            {'name': 'cholesterol', 'type': 'categorical', 'values': [1, 2, 3], 'description': '胆固醇水平:1=正常,2=高于正常,3=很高'},
            {'name': 'gluc', 'type': 'categorical', 'values': [1, 2, 3], 'description': '血糖水平:1=正常,2=高于正常,3=很高'},
            {'name': 'smoke', 'type': 'categorical', 'values': [0, 1], 'description': '是否吸烟:0=否,1=是'},
            {'name': 'alco', 'type': 'categorical', 'values': [0, 1], 'description': '是否饮酒:0=否,1=是'},
            {'name': 'active', 'type': 'categorical', 'values': [0, 1], 'description': '是否积极运动:0=否,1=是'}
        ],
        'derived_features': [
            {'name': 'age_years', 'description': '年龄(年),由age转换而来'},
            {'name': 'bmi', 'description': '身体质量指数,由height和weight计算而来'},
            {'name': 'bmi_category', 'description': 'BMI分类:偏瘦/正常/超重/肥胖'}
        ]
    }
    return jsonify(features_info)
|
||||
|
||||
if __name__ == '__main__':
    # Single source of truth for the startup settings, so the banner below
    # can never disagree with what app.run() actually uses.
    host = '0.0.0.0'
    port = 9011
    # NOTE(security): Flask's debug mode enables the interactive debugger,
    # which must not be reachable from other machines. Because the server
    # binds to all interfaces, debug is opt-in via CARDIO_DEBUG=1.
    debug = os.environ.get('CARDIO_DEBUG', '').strip().lower() in ('1', 'true', 'yes')

    print("=" * 60)
    print("CardioAI - 心血管疾病预测API")
    print("=" * 60)
    print(f"模型路径: {MODEL_PATH}")
    print(f"模型加载状态: {'成功' if model is not None else '失败'}")
    print("\n可用端点:")
    print(" GET / - 前端界面")
    print(" POST /predict_cardio - 预测API")
    print(" GET /health - 健康检查")
    print(" GET /features - 特征信息")
    print("\n启动信息:")
    print(f" 主机: {host}")
    print(f" 端口: {port}")
    print(f" 调试模式: {'开启' if debug else '关闭'}")
    print("=" * 60)

    # Start the Flask development server.
    app.run(host=host, port=port, debug=debug)
|
||||
BIN
module2_predictor/cardio_predictor_model.pkl
Normal file
BIN
module2_predictor/cardio_predictor_model.pkl
Normal file
Binary file not shown.
790
module2_predictor/templates/index.html
Normal file
790
module2_predictor/templates/index.html
Normal file
@@ -0,0 +1,790 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>CardioAI - 心血管疾病风险预测</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
|
||||
<style>
|
||||
:root {
|
||||
--primary-color: #2c3e50;
|
||||
--secondary-color: #3498db;
|
||||
--success-color: #27ae60;
|
||||
--danger-color: #e74c3c;
|
||||
--warning-color: #f39c12;
|
||||
--light-color: #ecf0f1;
|
||||
--dark-color: #2c3e50;
|
||||
}
|
||||
|
||||
body {
|
||||
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
|
||||
min-height: 100vh;
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
}
|
||||
|
||||
.navbar-brand {
|
||||
font-weight: bold;
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
|
||||
.card {
|
||||
border: none;
|
||||
border-radius: 15px;
|
||||
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
|
||||
transition: transform 0.3s ease;
|
||||
}
|
||||
|
||||
.card:hover {
|
||||
transform: translateY(-5px);
|
||||
}
|
||||
|
||||
.card-header {
|
||||
background: linear-gradient(135deg, var(--primary-color) 0%, var(--secondary-color) 100%);
|
||||
color: white;
|
||||
border-radius: 15px 15px 0 0 !important;
|
||||
padding: 1.5rem;
|
||||
}
|
||||
|
||||
.form-control:focus {
|
||||
border-color: var(--secondary-color);
|
||||
box-shadow: 0 0 0 0.25rem rgba(52, 152, 219, 0.25);
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: linear-gradient(135deg, var(--secondary-color) 0%, #2980b9 100%);
|
||||
border: none;
|
||||
padding: 12px 30px;
|
||||
font-weight: 600;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: linear-gradient(135deg, #2980b9 0%, var(--secondary-color) 100%);
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(41, 128, 185, 0.4);
|
||||
}
|
||||
|
||||
.result-card {
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
padding: 2rem;
|
||||
margin-top: 2rem;
|
||||
display: none;
|
||||
animation: fadeIn 0.5s ease;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
.risk-low {
|
||||
color: var(--success-color);
|
||||
border-left: 5px solid var(--success-color);
|
||||
}
|
||||
|
||||
.risk-high {
|
||||
color: var(--danger-color);
|
||||
border-left: 5px solid var(--danger-color);
|
||||
}
|
||||
|
||||
.probability-bar {
|
||||
height: 20px;
|
||||
background: linear-gradient(90deg, var(--success-color) 0%, var(--warning-color) 50%, var(--danger-color) 100%);
|
||||
border-radius: 10px;
|
||||
margin: 10px 0;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.probability-indicator {
|
||||
position: absolute;
|
||||
top: -5px;
|
||||
width: 30px;
|
||||
height: 30px;
|
||||
background: white;
|
||||
border: 3px solid var(--dark-color);
|
||||
border-radius: 50%;
|
||||
transform: translateX(-50%);
|
||||
transition: left 1s ease;
|
||||
}
|
||||
|
||||
.feature-info {
|
||||
font-size: 0.85rem;
|
||||
color: #666;
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
.loading {
|
||||
display: none;
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: 4px solid #f3f3f3;
|
||||
border-top: 4px solid var(--secondary-color);
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 10px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.feature-example {
|
||||
font-size: 0.9rem;
|
||||
color: #777;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.info-icon {
|
||||
color: var(--secondary-color);
|
||||
cursor: help;
|
||||
margin-left: 5px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<!-- 导航栏 -->
|
||||
<nav class="navbar navbar-expand-lg navbar-dark" style="background: linear-gradient(135deg, var(--primary-color) 0%, var(--secondary-color) 100%);">
|
||||
<div class="container">
|
||||
<a class="navbar-brand" href="#">
|
||||
<i class="fas fa-heartbeat me-2"></i>CardioAI
|
||||
</a>
|
||||
<span class="navbar-text">
|
||||
心血管疾病智能预测系统
|
||||
</span>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<div class="container my-5">
|
||||
<!-- 标题和介绍 -->
|
||||
<div class="row mb-5">
|
||||
<div class="col-12 text-center">
|
||||
<h1 class="display-4 mb-3" style="color: var(--primary-color);">
|
||||
<i class="fas fa-stethoscope me-2"></i>心血管疾病风险预测
|
||||
</h1>
|
||||
<p class="lead text-muted">
|
||||
基于机器学习模型,通过11项健康指标预测心血管疾病风险
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="row">
|
||||
<!-- 左侧:输入表单 -->
|
||||
<div class="col-lg-8">
|
||||
<div class="card mb-4">
|
||||
<div class="card-header">
|
||||
<h3 class="mb-0"><i class="fas fa-edit me-2"></i>健康信息输入</h3>
|
||||
<p class="mb-0 mt-2" style="font-size: 0.9rem; opacity: 0.9;">
|
||||
请填写以下11项健康指标,获取精准的风险评估
|
||||
</p>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<form id="predictionForm">
|
||||
<div class="row">
|
||||
<!-- 第1行:年龄和性别 -->
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="age" class="form-label">
|
||||
<i class="fas fa-birthday-cake me-1"></i>年龄 (天)
|
||||
<i class="fas fa-info-circle info-icon" title="请输入年龄,单位为天。例如:30岁 = 30 × 365 = 10950天"></i>
|
||||
</label>
|
||||
<input type="number" class="form-control" id="age" name="age"
|
||||
min="0" max="43830" step="1" required
|
||||
placeholder="例如:10950 (30岁)">
|
||||
<div class="feature-info">范围:0-43830天 (约0-120岁)</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="gender" class="form-label">
|
||||
<i class="fas fa-venus-mars me-1"></i>性别
|
||||
</label>
|
||||
<select class="form-select" id="gender" name="gender" required>
|
||||
<option value="">请选择性别</option>
|
||||
<option value="1">女性</option>
|
||||
<option value="2">男性</option>
|
||||
</select>
|
||||
<div class="feature-info">1=女性,2=男性</div>
|
||||
</div>
|
||||
|
||||
<!-- 第2行:身高和体重 -->
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="height" class="form-label">
|
||||
<i class="fas fa-ruler-vertical me-1"></i>身高 (cm)
|
||||
</label>
|
||||
<input type="number" class="form-control" id="height" name="height"
|
||||
min="100" max="250" step="0.1" required
|
||||
placeholder="例如:170.5">
|
||||
<div class="feature-info">范围:100-250厘米</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="weight" class="form-label">
|
||||
<i class="fas fa-weight me-1"></i>体重 (kg)
|
||||
</label>
|
||||
<input type="number" class="form-control" id="weight" name="weight"
|
||||
min="30" max="300" step="0.1" required
|
||||
placeholder="例如:65.2">
|
||||
<div class="feature-info">范围:30-300千克</div>
|
||||
</div>
|
||||
|
||||
<!-- 第3行:收缩压和舒张压 -->
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="ap_hi" class="form-label">
|
||||
<i class="fas fa-heartbeat me-1"></i>收缩压 (mmHg)
|
||||
</label>
|
||||
<input type="number" class="form-control" id="ap_hi" name="ap_hi"
|
||||
min="90" max="250" step="1" required
|
||||
placeholder="例如:120">
|
||||
<div class="feature-info">范围:90-250 mmHg</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="ap_lo" class="form-label">
|
||||
<i class="fas fa-heart me-1"></i>舒张压 (mmHg)
|
||||
</label>
|
||||
<input type="number" class="form-control" id="ap_lo" name="ap_lo"
|
||||
min="60" max="150" step="1" required
|
||||
placeholder="例如:80">
|
||||
<div class="feature-info">范围:60-150 mmHg</div>
|
||||
</div>
|
||||
|
||||
<!-- 第4行:胆固醇和血糖 -->
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="cholesterol" class="form-label">
|
||||
<i class="fas fa-vial me-1"></i>胆固醇水平
|
||||
</label>
|
||||
<select class="form-select" id="cholesterol" name="cholesterol" required>
|
||||
<option value="">请选择胆固醇水平</option>
|
||||
<option value="1">正常</option>
|
||||
<option value="2">高于正常</option>
|
||||
<option value="3">很高</option>
|
||||
</select>
|
||||
<div class="feature-info">1=正常,2=高于正常,3=很高</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-6 mb-3">
|
||||
<label for="gluc" class="form-label">
|
||||
<i class="fas fa-tint me-1"></i>血糖水平
|
||||
</label>
|
||||
<select class="form-select" id="gluc" name="gluc" required>
|
||||
<option value="">请选择血糖水平</option>
|
||||
<option value="1">正常</option>
|
||||
<option value="2">高于正常</option>
|
||||
<option value="3">很高</option>
|
||||
</select>
|
||||
<div class="feature-info">1=正常,2=高于正常,3=很高</div>
|
||||
</div>
|
||||
|
||||
<!-- 第5行:生活习惯 -->
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="smoke" class="form-label">
|
||||
<i class="fas fa-smoking me-1"></i>是否吸烟
|
||||
</label>
|
||||
<select class="form-select" id="smoke" name="smoke" required>
|
||||
<option value="">请选择</option>
|
||||
<option value="0">否</option>
|
||||
<option value="1">是</option>
|
||||
</select>
|
||||
<div class="feature-info">0=否,1=是</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="alco" class="form-label">
|
||||
<i class="fas fa-wine-glass-alt me-1"></i>是否饮酒
|
||||
</label>
|
||||
<select class="form-select" id="alco" name="alco" required>
|
||||
<option value="">请选择</option>
|
||||
<option value="0">否</option>
|
||||
<option value="1">是</option>
|
||||
</select>
|
||||
<div class="feature-info">0=否,1=是</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-4 mb-3">
|
||||
<label for="active" class="form-label">
|
||||
<i class="fas fa-running me-1"></i>是否积极运动
|
||||
</label>
|
||||
<select class="form-select" id="active" name="active" required>
|
||||
<option value="">请选择</option>
|
||||
<option value="0">否</option>
|
||||
<option value="1">是</option>
|
||||
</select>
|
||||
<div class="feature-info">0=否,1=是</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 示例数据按钮 -->
|
||||
<div class="mb-4">
|
||||
<p class="text-muted">快速填充示例数据:</p>
|
||||
<div class="btn-group" role="group">
|
||||
<button type="button" class="btn btn-outline-secondary btn-sm" onclick="fillExampleData('lowRisk')">
|
||||
低风险示例
|
||||
</button>
|
||||
<button type="button" class="btn btn-outline-secondary btn-sm" onclick="fillExampleData('highRisk')">
|
||||
高风险示例
|
||||
</button>
|
||||
<button type="button" class="btn btn-outline-secondary btn-sm" onclick="fillExampleData('random')">
|
||||
随机示例
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 提交按钮 -->
|
||||
<div class="text-center">
|
||||
<button type="submit" class="btn btn-primary btn-lg">
|
||||
<i class="fas fa-calculator me-2"></i>开始预测
|
||||
</button>
|
||||
<button type="button" class="btn btn-outline-secondary btn-lg ms-3" onclick="resetForm()">
|
||||
<i class="fas fa-redo me-2"></i>重置
|
||||
</button>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<!-- 加载动画 -->
|
||||
<div class="loading" id="loading">
|
||||
<div class="spinner"></div>
|
||||
<p class="mt-3">正在分析数据,请稍候...</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 右侧:说明和提示 -->
|
||||
<div class="col-lg-4">
|
||||
<div class="card mb-4">
|
||||
<div class="card-header">
|
||||
<h5 class="mb-0"><i class="fas fa-info-circle me-2"></i>使用说明</h5>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="alert alert-info">
|
||||
<h6><i class="fas fa-lightbulb me-2"></i>注意事项:</h6>
|
||||
<ul class="mb-0">
|
||||
<li>所有字段均为必填项</li>
|
||||
<li>年龄请输入天数(1年≈365天)</li>
|
||||
<li>血压值应为近期测量结果</li>
|
||||
<li>胆固醇和血糖请参考体检报告</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<h6><i class="fas fa-shield-alt me-2"></i>数据安全</h6>
|
||||
<p class="small">您的健康数据仅用于本次预测,不会被存储或用于其他用途。</p>
|
||||
|
||||
<h6><i class="fas fa-exclamation-triangle me-2"></i>免责声明</h6>
|
||||
<p class="small">本预测结果仅供参考,不能替代专业医疗诊断。如有健康问题,请及时咨询医生。</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<h5 class="mb-0"><i class="fas fa-chart-line me-2"></i>模型信息</h5>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p class="small">
|
||||
<i class="fas fa-brain me-2"></i>基于XGBoost机器学习算法
|
||||
</p>
|
||||
<p class="small">
|
||||
<i class="fas fa-database me-2"></i>训练数据:68,492条医疗记录
|
||||
</p>
|
||||
<p class="small">
|
||||
<i class="fas fa-check-circle me-2"></i>模型准确率:> 85%
|
||||
</p>
|
||||
<p class="small">
|
||||
<i class="fas fa-cogs me-2"></i>更新日期:2026年2月
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 预测结果区域 -->
|
||||
<div class="result-card" id="resultCard">
|
||||
<div class="row">
|
||||
<div class="col-md-8">
|
||||
<h3 class="mb-4" id="resultTitle">
|
||||
<i class="fas fa-clipboard-check me-2"></i>预测结果
|
||||
</h3>
|
||||
|
||||
<div class="row mb-4">
|
||||
<div class="col-md-6">
|
||||
<div class="card">
|
||||
<div class="card-body text-center">
|
||||
<h5 class="card-title">预测结果</h5>
|
||||
<h2 class="display-4 mb-0" id="predictionValue">-</h2>
|
||||
<p class="card-text" id="predictionText"></p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-6">
|
||||
<div class="card">
|
||||
<div class="card-body text-center">
|
||||
<h5 class="card-title">风险等级</h5>
|
||||
<h2 class="display-4 mb-0" id="riskLevel">-</h2>
|
||||
<p class="card-text" id="riskText"></p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mb-4">
|
||||
<h5>疾病概率</h5>
|
||||
<div class="probability-bar" id="probabilityBar">
|
||||
<div class="probability-indicator" id="probabilityIndicator"></div>
|
||||
</div>
|
||||
<div class="d-flex justify-content-between">
|
||||
<span>0%</span>
|
||||
<span id="probabilityValue">50%</span>
|
||||
<span>100%</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="alert" id="resultAlert">
|
||||
<h5><i class="fas fa-comment-medical me-2"></i>健康建议</h5>
|
||||
<p id="healthAdvice"></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="col-md-4">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<h5 class="mb-0"><i class="fas fa-history me-2"></i>预测详情</h5>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<p><strong>置信度:</strong><span id="confidenceValue">-</span></p>
|
||||
<p><strong>处理时间:</strong><span id="processingTime">-</span> 毫秒</p>
|
||||
<p><strong>请求ID:</strong><span id="requestId">-</span></p>
|
||||
<p><strong>时间戳:</strong><span id="timestamp">-</span></p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mt-3 text-center">
|
||||
<button class="btn btn-outline-primary me-2" onclick="shareResult()">
|
||||
<i class="fas fa-share-alt me-1"></i>分享结果
|
||||
</button>
|
||||
<button class="btn btn-outline-success" onclick="saveResult()">
|
||||
<i class="fas fa-save me-1"></i>保存结果
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 页脚 -->
|
||||
<footer class="footer mt-5 py-4" style="background-color: var(--primary-color); color: white;">
|
||||
<div class="container text-center">
|
||||
<p class="mb-2">© 2026 CardioAI - 心血管疾病智能辅助系统</p>
|
||||
<p class="mb-0 small">Module 2: 机器学习预测模块 | 基于Flask和XGBoost</p>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
<!-- Bootstrap JS -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
|
||||
|
||||
<script>
|
||||
// Preset example records used by the "fill example" buttons.
const exampleData = {
    lowRisk: {
        age: 10950,        // 30 years
        gender: 1,         // female
        height: 165.5,
        weight: 58.2,
        ap_hi: 115,
        ap_lo: 75,
        cholesterol: 1,    // normal
        gluc: 1,           // normal
        smoke: 0,          // non-smoker
        alco: 0,           // no alcohol
        active: 1          // physically active
    },
    highRisk: {
        age: 21900,        // 60 years
        gender: 2,         // male
        height: 170.2,
        weight: 92.5,
        ap_hi: 165,
        ap_lo: 105,
        cholesterol: 3,    // well above normal
        gluc: 2,           // above normal
        smoke: 1,          // smoker
        alco: 1,           // drinks alcohol
        active: 0          // not active
    }
};
|
||||
|
||||
// Fill the form with a preset example ('lowRisk' / 'highRisk') or with
// randomly generated values ('random'). Unknown types are ignored.
function fillExampleData(type) {
    const labels = { lowRisk: '低风险', highRisk: '高风险', random: '随机' };

    let data;
    if (type === 'random') {
        data = {
            age: Math.floor(Math.random() * 21900 + 7300),    // 20-80 years, in days
            gender: Math.random() > 0.5 ? 1 : 2,
            height: Math.floor(Math.random() * 70 + 150),     // 150-219 cm
            weight: Math.floor(Math.random() * 70 + 50),      // 50-119 kg
            ap_hi: Math.floor(Math.random() * 60 + 100),      // 100-159 mmHg
            ap_lo: Math.floor(Math.random() * 30 + 70),       // 70-99 mmHg
            cholesterol: Math.floor(Math.random() * 3) + 1,
            gluc: Math.floor(Math.random() * 3) + 1,
            smoke: Math.random() > 0.7 ? 1 : 0,
            alco: Math.random() > 0.6 ? 1 : 0,
            active: Math.random() > 0.5 ? 1 : 0
        };
    } else if (Object.prototype.hasOwnProperty.call(exampleData, type)) {
        data = exampleData[type];
    }

    // Guard: without this an unknown type would make Object.entries throw.
    if (!data) {
        console.warn('fillExampleData: unknown example type', type);
        return;
    }

    // Copy each value into the form control whose id matches the key.
    Object.entries(data).forEach(([key, value]) => {
        const element = document.getElementById(key);
        if (element) {
            element.value = value;
        }
    });

    showAlert(`已填充${labels[type]}示例数据`, 'info');
}
|
||||
|
||||
// Reset the form, hide any previous result and notify the user.
function resetForm() {
    document.getElementById('predictionForm').reset();
    document.getElementById('resultCard').style.display = 'none';
    showAlert('表单已重置', 'info');
}
|
||||
|
||||
// Show a floating, dismissible Bootstrap alert that removes itself after 3 s.
// The message is inserted as text (not HTML): it can contain server error
// text that echoes user input, so it must never be parsed as markup.
function showAlert(message, type = 'info') {
    const alertDiv = document.createElement('div');
    alertDiv.className = `alert alert-${type} alert-dismissible fade show position-fixed`;
    alertDiv.style.cssText = 'top: 20px; right: 20px; z-index: 9999; min-width: 300px;';

    const text = document.createElement('span');
    text.textContent = message;

    const closeButton = document.createElement('button');
    closeButton.type = 'button';
    closeButton.className = 'btn-close';
    closeButton.setAttribute('data-bs-dismiss', 'alert');

    alertDiv.appendChild(text);
    alertDiv.appendChild(closeButton);
    document.body.appendChild(alertDiv);

    setTimeout(() => alertDiv.remove(), 3000);
}
|
||||
|
||||
// Format a Date using the zh-CN locale conventions.
function formatDateTime(date) {
    return date.toLocaleString('zh-CN');
}
|
||||
|
||||
// Build a client-side request id: "REQ_<epoch millis>_<up to 9 base-36 chars>".
// Uses slice() instead of the deprecated substr(); the result is identical.
function generateRequestId() {
    const randomPart = Math.random().toString(36).slice(2, 11);
    return 'REQ_' + Date.now() + '_' + randomPart;
}
|
||||
|
||||
// Submit handler: send the form values to /predict_cardio as JSON and
// render the prediction, or show an error alert on failure.
document.getElementById('predictionForm').addEventListener('submit', async function (e) {
    e.preventDefault();

    const loading = document.getElementById('loading');
    loading.style.display = 'block';
    document.getElementById('resultCard').style.display = 'none';

    // Collect the form values. Numeric text becomes a number; the old
    // `parseFloat(v) || parseInt(v) || v` chain treated 0 as falsy and
    // therefore sent "0" as a string.
    const data = {};
    new FormData(this).forEach((value, key) => {
        const numeric = Number(value);
        const isNumeric = typeof value === 'string' && value.trim() !== '' && Number.isFinite(numeric);
        data[key] = isNumeric ? numeric : value;
    });

    const startTime = Date.now();
    const requestId = generateRequestId();

    try {
        const response = await fetch('/predict_cardio', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify(data)
        });

        const result = await response.json();
        const processingTime = Date.now() - startTime;

        if (!response.ok || result.error) {
            throw new Error(result.error || '预测失败');
        }

        displayResult(result, processingTime, requestId);
        showAlert('预测成功!', 'success');
    } catch (error) {
        showAlert(`预测失败: ${error.message}`, 'danger');
        console.error('预测错误:', error);
    } finally {
        // Hide the spinner on both the success and the failure path.
        loading.style.display = 'none';
    }
});
|
||||
|
||||
// Render a successful prediction response into the result card.
//   result         - JSON returned by /predict_cardio
//   processingTime - round-trip time in milliseconds
//   requestId      - client-generated id shown in the details panel
function displayResult(result, processingTime, requestId) {
    const resultCard = document.getElementById('resultCard');
    const prediction = result.prediction;
    const probability = result.probability;
    const confidence = result.confidence;
    const highRisk = probability >= 0.5;

    // Prediction value and its description.
    const predictionText = document.getElementById('predictionText');
    document.getElementById('predictionValue').textContent = prediction;
    predictionText.textContent = prediction === 1 ? '有心血管疾病风险' : '暂无心血管疾病风险';
    predictionText.className = prediction === 1 ? 'text-danger' : 'text-success';

    // Risk level.
    const riskLevel = document.getElementById('riskLevel');
    riskLevel.textContent = result.risk_level;
    riskLevel.className = highRisk ? 'text-danger' : 'text-success';
    document.getElementById('riskText').textContent = result.message;

    // Probability label and indicator; clamp so the marker stays on the bar.
    document.getElementById('probabilityValue').textContent = confidence;
    const percentage = Math.min(100, Math.max(0, probability * 100));
    document.getElementById('probabilityIndicator').style.left = `${percentage}%`;

    // Health advice, graded by probability.
    const healthAdvice = document.getElementById('healthAdvice');
    const resultAlert = document.getElementById('resultAlert');
    if (probability >= 0.7) {
        healthAdvice.textContent = '建议立即咨询心血管专科医生,进行全面检查。注意控制血压、血脂和血糖,改善生活习惯。';
        resultAlert.className = 'alert alert-danger';
    } else if (highRisk) {
        healthAdvice.textContent = '存在一定风险,建议定期进行心血管健康检查。注意饮食均衡,适当运动,控制体重。';
        resultAlert.className = 'alert alert-warning';
    } else {
        healthAdvice.textContent = '当前风险较低,继续保持健康生活方式。建议每年进行一次全面体检。';
        resultAlert.className = 'alert alert-success';
    }

    // Details panel.
    document.getElementById('confidenceValue').textContent = confidence;
    document.getElementById('processingTime').textContent = processingTime;
    document.getElementById('requestId').textContent = requestId;
    document.getElementById('timestamp').textContent = formatDateTime(new Date());

    // Reveal the card and bring it into view.
    resultCard.style.display = 'block';
    resultCard.scrollIntoView({ behavior: 'smooth' });
}
|
||||
|
||||
// 分享结果
|
||||
/**
 * Share the currently displayed prediction.
 *
 * Uses the Web Share API when the browser exposes it; otherwise copies the
 * summary text to the clipboard. Failures of either path are reported instead
 * of surfacing as unhandled promise rejections.
 */
function shareResult() {
    const prediction = document.getElementById('predictionText').textContent;
    const riskLevel = document.getElementById('riskLevel').textContent;
    const confidence = document.getElementById('confidenceValue').textContent;

    const shareText = `CardioAI心血管疾病预测结果:\n` +
        `预测:${prediction}\n` +
        `风险等级:${riskLevel}\n` +
        `置信度:${confidence}\n` +
        `时间:${formatDateTime(new Date())}\n` +
        `#CardioAI #心血管健康`;

    if (navigator.share) {
        navigator.share({
            title: 'CardioAI预测结果',
            text: shareText,
            url: window.location.href
        }).catch((error) => {
            // Dismissing the share sheet rejects with AbortError; that is a
            // user choice, not a failure worth alerting on.
            if (error && error.name === 'AbortError') {
                return;
            }
            showAlert(`分享失败: ${error.message}`, 'danger');
            console.error('分享错误:', error);
        });
    } else if (navigator.clipboard && navigator.clipboard.writeText) {
        navigator.clipboard.writeText(shareText).then(() => {
            showAlert('结果已复制到剪贴板', 'success');
        }).catch((error) => {
            showAlert(`复制失败: ${error.message}`, 'danger');
            console.error('复制错误:', error);
        });
    } else {
        // navigator.clipboard is only defined in secure contexts (HTTPS or
        // localhost), so plain-HTTP deployments land here.
        showAlert('当前浏览器不支持分享或复制到剪贴板', 'warning');
    }
}
|
||||
|
||||
// 保存结果
|
||||
/**
 * Download the currently displayed prediction as a plain-text report.
 *
 * Builds the report from the text of the result elements, wraps it in a Blob
 * and triggers a download through a temporary anchor element.
 */
function saveResult() {
    // Format the time once so the filename and the report body always agree.
    const generatedAt = formatDateTime(new Date());
    const timestamp = generatedAt.replace(/[:\s]/g, '-');
    const filename = `cardioai-result-${timestamp}.txt`;

    const textOf = (id) => document.getElementById(id).textContent;

    const content = [
        `CardioAI 心血管疾病预测报告\n`,
        `生成时间:${generatedAt}\n`,
        `========================================\n\n`,
        `预测结果:${textOf('predictionText')}\n`,
        `风险等级:${textOf('riskLevel')}\n`,
        `疾病概率:${textOf('confidenceValue')}\n`,
        `处理时间:${textOf('processingTime')} 毫秒\n`,
        `请求ID:${textOf('requestId')}\n\n`,
        `健康建议:\n${textOf('healthAdvice')}\n\n`,
        `免责声明:本预测结果仅供参考,不能替代专业医疗诊断。\n`,
        `如有健康问题,请及时咨询专业医生。\n`
    ].join('');

    // Blob string parts are encoded as UTF-8; declare it so the non-ASCII
    // report text is labelled correctly.
    const blob = new Blob([content], { type: 'text/plain;charset=utf-8' });
    const url = URL.createObjectURL(blob);

    // A detached anchor cannot be clicked reliably in every browser, so attach
    // it for the duration of the click and clean up afterwards.
    const a = document.createElement('a');
    a.href = url;
    a.download = filename;
    document.body.appendChild(a);
    a.click();
    document.body.removeChild(a);
    URL.revokeObjectURL(url);

    showAlert(`报告已保存为:${filename}`, 'success');
}
|
||||
|
||||
// 页面加载完成后的初始化
|
||||
// One-time page setup: attach range validation to every numeric input, then
// probe the API health endpoint.
document.addEventListener('DOMContentLoaded', () => {
    const numericInputs = document.querySelectorAll('input[type="number"]');

    for (const field of numericInputs) {
        field.addEventListener('change', () => {
            const lower = parseFloat(field.min);
            const upper = parseFloat(field.max);
            const entered = parseFloat(field.value);

            // Any NaN operand (missing min/max, empty value) makes both
            // comparisons false, so such fields are treated as valid.
            const outOfRange = entered < lower || entered > upper;
            field.classList.toggle('is-invalid', outOfRange);

            if (outOfRange) {
                showAlert(`${field.name} 应在 ${lower}-${upper} 范围内`, 'warning');
            }
        });
    }

    // Report backend availability in the console.
    checkApiStatus();
});
|
||||
|
||||
// 检查API状态
|
||||
/**
 * Query the `/health` endpoint and log the API status to the console.
 *
 * Never throws: network and parsing failures are logged and swallowed because
 * this check is informational only.
 */
async function checkApiStatus() {
    try {
        const response = await fetch('/health');

        // A reachable server answering with an error status is "abnormal",
        // not "unreachable"; don't try to parse a possibly non-JSON error body.
        if (!response.ok) {
            console.warn('API状态:异常');
            return;
        }

        const data = await response.json();
        if (data.status === 'healthy') {
            console.log('API状态:健康');
        } else {
            console.warn('API状态:异常');
        }
    } catch (error) {
        // Keep the underlying error so the failure can be diagnosed.
        console.error('无法连接到API服务器', error);
    }
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
260
module2_predictor/train_and_save.py
Normal file
260
module2_predictor/train_and_save.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/opt/anaconda3/envs/cardioenv/bin/python
|
||||
"""
|
||||
CardioAI - 心血管疾病预测模型训练脚本
|
||||
一次性脚本,用于训练XGBoost模型并保存Pipeline
|
||||
"""
|
||||
|
||||
import os
|
||||
# 设置环境变量以确保XGBoost可以找到OpenMP库
|
||||
os.environ['DYLD_LIBRARY_PATH'] = '/opt/homebrew/opt/libomp/lib:' + os.environ.get('DYLD_LIBRARY_PATH', '')
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.pipeline import Pipeline
|
||||
from xgboost import XGBClassifier
|
||||
import joblib
|
||||
|
||||
# 数据路径
|
||||
DATA_PATH = "../data/心血管疾病.xlsx"
|
||||
|
||||
def load_and_process_data():
    """Load the cardiovascular dataset and derive the model's input columns.

    The workbook is searched for in several candidate locations built from
    ``DATA_PATH`` and this script's directory; the first existing one is read.
    Derived columns: ``age_years``, ``bmi``, ``cholesterol_cat``, ``gluc_cat``
    and ``bmi_category``. Rows with implausible blood pressure are removed.
    (The original note says this mirrors Module 1's preprocessing; that cannot
    be confirmed from this file.)

    Returns:
        The processed DataFrame, or an empty DataFrame when the file cannot be
        found or any step raises.
    """
    try:
        script_dir = os.path.dirname(__file__)
        workbook = "心血管疾病.xlsx"

        # Candidate locations, in priority order.
        possible_paths = [
            DATA_PATH,
            os.path.abspath(DATA_PATH),
            os.path.abspath(os.path.join(script_dir, DATA_PATH)),
            os.path.abspath(os.path.join(script_dir, "..", "data", workbook)),
            os.path.abspath(os.path.join(script_dir, "..", "..", "data", workbook)),
        ]

        data_path = next((p for p in possible_paths if os.path.exists(p)), None)
        if data_path is None:
            print(f"未找到数据文件,尝试过的路径: {possible_paths}")
            return pd.DataFrame()
        print(f"找到数据文件: {data_path}")

        frame = pd.read_excel(data_path)

        # --- Feature engineering -------------------------------------------
        # age is stored in days; convert to whole years (rounded).
        frame['age_years'] = (frame['age'] / 365.25).round().astype(int)

        # BMI = weight / (height / 100) ** 2
        frame['bmi'] = frame['weight'] / ((frame['height'] / 100) ** 2)

        # --- Outlier removal -----------------------------------------------
        # Keep rows where diastolic < systolic, systolic is within [90, 250]
        # and diastolic is within [60, 150].
        systolic, diastolic = frame['ap_hi'], frame['ap_lo']
        plausible = (
            (diastolic < systolic)
            & systolic.between(90, 250)
            & diastolic.between(60, 150)
        )
        frame = frame[plausible].copy()

        # --- Categorical recoding ------------------------------------------
        # cholesterol and gluc share the same 1/2/3 level coding.
        level_names = {
            1: 'normal',
            2: 'above_normal',
            3: 'well_above_normal',
        }
        frame['cholesterol_cat'] = frame['cholesterol'].map(level_names)
        frame['gluc_cat'] = frame['gluc'].map(level_names)

        def categorize_bmi(bmi):
            # Cut-offs at 18.5, 25 and 30; anything else falls through to 'obese'.
            if bmi < 18.5:
                return 'underweight'
            if bmi < 25:
                return 'normal'
            if bmi < 30:
                return 'overweight'
            return 'obese'

        frame['bmi_category'] = frame['bmi'].apply(categorize_bmi)

        return frame

    except Exception as e:
        print(f"数据加载失败: {e}")
        return pd.DataFrame()
|
||||
|
||||
def prepare_features_target(df):
    """Split the processed frame into a feature matrix and the label column.

    Args:
        df: DataFrame containing at least ``id``, ``age`` and ``cardio``.

    Returns:
        ``(features, target)`` where ``features`` is ``df`` without the
        ``id``, raw ``age`` and ``cardio`` columns, and ``target`` is
        ``df['cardio']``.
    """
    excluded_columns = ['id', 'age', 'cardio']
    target = df['cardio']
    features = df.drop(columns=excluded_columns)
    return features, target
|
||||
|
||||
def build_pipeline():
    """Build the unfitted preprocessing + XGBoost classification pipeline.

    Returns:
        A scikit-learn ``Pipeline`` with two steps:

        * ``preprocessor`` - a ``ColumnTransformer`` that standard-scales the
          numeric columns and one-hot encodes the categorical columns.
        * ``classifier`` - an ``XGBClassifier`` with fixed hyperparameters.
    """
    # Columns the model consumes, grouped by how they are preprocessed.
    numeric_features = ['height', 'weight', 'ap_hi', 'ap_lo', 'bmi', 'age_years']
    categorical_features = ['gender', 'cholesterol_cat', 'gluc_cat', 'smoke', 'alco', 'active', 'bmi_category']

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numeric_features),
            # drop='if_binary' keeps a single indicator for two-level columns;
            # handle_unknown='ignore' avoids failing on categories unseen
            # during fitting.
            ('cat', OneHotEncoder(drop='if_binary', sparse_output=False, handle_unknown='ignore'), categorical_features)
        ])

    # NOTE: `use_label_encoder=False` is intentionally not passed. The
    # parameter is deprecated in XGBoost and newer releases report it as an
    # unused parameter; the labels here are already 0/1 so no label encoding
    # is involved either way.
    pipeline = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', XGBClassifier(
            n_estimators=100,
            max_depth=5,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=42,
            eval_metric='logloss'
        ))
    ])

    return pipeline
|
||||
|
||||
def evaluate_model(model, X_test, y_test):
    """Score a fitted classifier on held-out data and print the metrics.

    Args:
        model: Fitted estimator exposing ``predict`` and ``predict_proba``.
        X_test: Held-out features.
        y_test: Held-out true labels.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``roc_auc``.
    """
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

    predicted_labels = model.predict(X_test)
    # Column 1 of predict_proba is used as the score for ROC AUC.
    positive_scores = model.predict_proba(X_test)[:, 1]

    results = {
        'accuracy': accuracy_score(y_test, predicted_labels),
        'precision': precision_score(y_test, predicted_labels),
        'recall': recall_score(y_test, predicted_labels),
        'f1': f1_score(y_test, predicted_labels),
        'roc_auc': roc_auc_score(y_test, positive_scores),
    }

    # Display label for each metric, in print order.
    display_names = {
        'accuracy': '准确率',
        'precision': '精确率',
        'recall': '召回率',
        'f1': 'F1分数',
        'roc_auc': 'ROC AUC',
    }

    print("模型评估结果:")
    for key, label in display_names.items():
        print(f" {label}: {results[key]:.4f}")

    return results
|
||||
|
||||
def main():
    """Run the full training workflow.

    Steps: load and process the data, split features/target, make a stratified
    80/20 train/test split, fit the pipeline, evaluate it, save it next to this
    script as ``cardio_predictor_model.pkl``, then reload it as a sanity check.
    Returns early (after printing a message) when no data could be loaded.
    """
    banner = "=" * 60
    print(banner)
    print("CardioAI - 心血管疾病预测模型训练")
    print(banner)

    # Step 1: load and process the data.
    print("\n1. 加载和处理数据...")
    dataset = load_and_process_data()
    if dataset.empty:
        print("❌ 数据加载失败,请检查数据文件路径")
        return

    print(f" 处理后的数据形状: {dataset.shape}")
    print(f" 阳性样本比例: {dataset['cardio'].mean():.2%}")

    # Step 2: separate features from the label.
    print("\n2. 准备特征和目标变量...")
    X, y = prepare_features_target(dataset)
    n_samples, n_features = X.shape[0], X.shape[1]
    print(f" 特征数量: {n_features}")
    print(f" 样本数量: {n_samples}")

    # Step 3: stratified hold-out split.
    print("\n3. 划分训练集和测试集...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f" 训练集大小: {X_train.shape[0]}")
    print(f" 测试集大小: {X_test.shape[0]}")

    # Step 4: build and fit the pipeline.
    print("\n4. 构建和训练Pipeline...")
    pipeline = build_pipeline()
    print(" 开始训练模型...")
    pipeline.fit(X_train, y_train)
    print(" 模型训练完成!")

    # Step 5: evaluate on the hold-out set (metrics are printed by the helper).
    print("\n5. 评估模型性能...")
    evaluate_model(pipeline, X_test, y_test)

    # Step 6: persist the fitted pipeline next to this script.
    print("\n6. 保存模型...")
    model_path = os.path.join(os.path.dirname(__file__), "cardio_predictor_model.pkl")
    target_dir = os.path.dirname(model_path)
    if target_dir:
        # Empty when the script path has no directory component.
        os.makedirs(target_dir, exist_ok=True)
    joblib.dump(pipeline, model_path)
    print(f" 模型已保存到: {model_path}")

    # Step 7: summarize the feature groups.
    print("\n7. 特征信息:")
    print(" 连续特征: height, weight, ap_hi, ap_lo, bmi, age_years")
    print(" 分类特征: gender, cholesterol_cat, gluc_cat, smoke, alco, active, bmi_category")
    print(" 总特征数: 13个原始特征 → 预处理后更多")

    # Step 8: reload the saved file and predict one row as a smoke test.
    print("\n8. 验证模型加载...")
    try:
        reloaded = joblib.load(model_path)
        sample_prediction = reloaded.predict(X_test.iloc[:1])
        print(f" 模型加载成功! 测试预测: {sample_prediction[0]}")
    except Exception as e:
        print(f" 模型加载失败: {e}")

    print("\n" + banner)
    print("✅ 模型训练和保存完成!")
    print(banner)
|
||||
|
||||
# Run the training workflow only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user