Files
111/aicodes/module2_predictor/train_and_save.py
2026-01-30 20:40:57 +08:00

212 lines
6.0 KiB
Python

# -*- coding: utf-8 -*-
"""
CardioAI - Module 2: 模型训练与保存
一次性脚本 - 训练XGBoost模型并保存Pipeline
"""
import pandas as pd
import numpy as np
import joblib
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import xgboost as xgb
# ============================================
# 配置与常量
# ============================================
# Absolute, hard-coded locations of the code tree, the input workbook and
# the serialized model file.
CODE_ROOT = Path(r"E:\project_ai\claude_project1\aicodes")
DATA_PATH = Path(r"E:\project_ai\claude_project1\data\心血管疾病.xlsx")
MODEL_SAVE_PATH = CODE_ROOT.joinpath(
    "module2_predictor", "cardio_predictor_model.pkl"
)

# Lookup tables from the integer category codes to display labels.
CHOLESTEROL_MAP = {
    1: "正常",
    2: "高于正常",
    3: "远高于正常",
}
GLUC_MAP = {
    1: "正常",
    2: "高于正常",
    3: "远高于正常",
}
GENDER_MAP = {
    1: "女性",
    2: "男性",
}
# ============================================
# 数据加载与清洗函数
# ============================================
def load_data():
    """Read the cardiovascular workbook at ``DATA_PATH`` into a DataFrame.

    The sheet is parsed with the ``openpyxl`` engine, and the number of rows
    and columns is printed once loading has finished.

    Returns:
        pd.DataFrame: The data exactly as returned by ``pd.read_excel``.
    """
    print("正在加载数据...")
    frame = pd.read_excel(DATA_PATH, engine='openpyxl')
    n_rows, n_cols = frame.shape
    print(f"数据加载完成: {n_rows} 条记录, {n_cols} 个特征")
    return frame
def clean_and_process_data(df):
    """Clean the raw records and derive the engineered features.

    Steps:
        1. Derive ``age_years`` from ``age`` (days / 365.25, rounded to int).
        2. Derive ``bmi`` as ``weight / (height / 100) ** 2``.
        3. Drop anomalous rows: non-positive height (BMI undefined),
           diastolic pressure not below systolic pressure, systolic pressure
           outside [90, 250], or diastolic pressure outside [60, 150].
        4. Drop the ``id`` and original ``age`` columns when present.

    The category codes (gender, cholesterol, gluc, ...) are left untouched.

    Args:
        df (pd.DataFrame): Raw data containing at least the columns ``age``,
            ``height``, ``weight``, ``ap_hi`` and ``ap_lo``.

    Returns:
        pd.DataFrame: A cleaned copy; the input frame is not modified.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    print("\n开始数据清洗与特征工程...")
    df_clean = df.copy()

    # Derived columns are added before filtering so they are written to the
    # private copy rather than to a filtered slice.
    df_clean['age_years'] = (df_clean['age'] / 365.25).round().astype(int)
    df_clean['bmi'] = df_clean['weight'] / ((df_clean['height'] / 100) ** 2)

    original_count = len(df_clean)
    systolic = df_clean['ap_hi']
    diastolic = df_clean['ap_lo']

    # A height of zero makes the BMI division yield inf, which the scaler in
    # the training pipeline cannot handle. Written as ~(<= 0) so that rows
    # with a missing height are kept, as they were before this check existed.
    height_usable = ~(df_clean['height'] <= 0)

    keep = (
        height_usable
        & (diastolic < systolic)
        & (systolic >= 90) & (systolic <= 250)
        & (diastolic >= 60) & (diastolic <= 150)
    )
    df_clean = df_clean[keep]

    removed_count = original_count - len(df_clean)
    print(f" - 删除异常值: {removed_count} 条记录")

    # errors='ignore' keeps this safe when the input has no id column.
    df_clean = df_clean.drop(columns=['id', 'age'], errors='ignore')
    print(f"数据清洗完成: {len(df_clean)} 条有效记录")
    return df_clean
# ============================================
# 模型训练函数
# ============================================
def train_model(df):
    """Train an XGBoost classifier inside a preprocessing Pipeline.

    The data is split 80/20 (stratified on the target, fixed seed), the
    pipeline is fitted on the training part, and a classification report,
    confusion matrix and ROC-AUC for the held-out part are printed.

    Args:
        df (pd.DataFrame): Cleaned data containing the feature columns
            listed below plus the binary target column ``cardio``.

    Returns:
        Pipeline: The fitted pipeline (preprocessor + classifier).

    Raises:
        KeyError: If a feature column or ``cardio`` is missing from ``df``.
    """
    print("\n开始模型训练...")

    # Continuous features are standardized; categorical ones are one-hot encoded.
    numeric_features = ['age_years', 'height', 'weight', 'bmi', 'ap_hi', 'ap_lo']
    categorical_features = ['gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active']

    X = df[numeric_features + categorical_features]
    y = df['cardio']

    # Stratify so both splits keep the same class balance as the full data.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print(f" - 训练集: {X_train.shape[0]} 条记录")
    print(f" - 测试集: {X_test.shape[0]} 条记录")

    # NOTE(review): with drop='first' and handle_unknown='ignore', a category
    # unseen during fit is encoded as all zeros, i.e. the same encoding as the
    # dropped first category. Confirm that this is acceptable at inference time.
    # remainder='passthrough' has no effect here because X only holds the
    # columns listed in the two transformers.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numeric_features),
            ('cat', OneHotEncoder(drop='first', handle_unknown='ignore'), categorical_features),
        ],
        remainder='passthrough',
    )

    # use_label_encoder is intentionally not passed: it is deprecated, unused
    # by current XGBoost releases, and only triggers a warning there.
    classifier = xgb.XGBClassifier(
        n_estimators=100,
        max_depth=6,
        learning_rate=0.1,
        subsample=0.8,
        colsample_bytree=0.8,
        random_state=42,
        eval_metric='logloss',
    )

    pipeline = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', classifier),
    ])

    print(" - 正在训练 XGBoost 模型...")
    pipeline.fit(X_train, y_train)

    # Evaluate on the held-out split only.
    print("\n模型评估:")
    y_pred = pipeline.predict(X_test)
    # Column 1 of predict_proba is the probability of the positive class.
    y_pred_proba = pipeline.predict_proba(X_test)[:, 1]

    print("\n分类报告:")
    print(classification_report(y_test, y_pred, target_names=['无疾病', '有疾病']))

    print("\n混淆矩阵:")
    cm = confusion_matrix(y_test, y_pred)
    print(f" 真阴性(TN): {cm[0,0]} 假阳性(FP): {cm[0,1]}")
    print(f" 假阴性(FN): {cm[1,0]} 真阳性(TP): {cm[1,1]}")

    auc_score = roc_auc_score(y_test, y_pred_proba)
    print(f"\nAUC-ROC: {auc_score:.4f}")
    return pipeline
# ============================================
# 模型保存函数
# ============================================
def save_model(pipeline, save_path):
    """Serialize the trained pipeline to disk with joblib.

    Args:
        pipeline (Pipeline): The fitted pipeline to persist.
        save_path (Path | str): Destination file for the serialized model.

    Raises:
        OSError: If the destination directory cannot be created or the file
            cannot be written.
    """
    # The destination directory must exist before the file can be opened for
    # writing, so create it (and any missing parents) up front.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    print(f"\n正在保存模型到: {save_path}")
    joblib.dump(pipeline, save_path)
    print("模型保存完成!")
# ============================================
# 主程序
# ============================================
def main():
    """Run the full workflow: load, clean, train, then persist the model."""
    rule = "=" * 60
    print(rule)
    print("CardioAI - Module 2: XGBoost模型训练与保存")
    print(rule)

    # Each stage feeds its result straight into the next one.
    raw_df = load_data()
    cleaned_df = clean_and_process_data(raw_df)
    fitted_pipeline = train_model(cleaned_df)
    save_model(fitted_pipeline, MODEL_SAVE_PATH)

    print("\n" + rule)
    print("训练完成! 模型已保存至:")
    print(f" {MODEL_SAVE_PATH}")
    print(rule)


if __name__ == "__main__":
    main()