Add CardioAI project with usage instructions
- Add comprehensive README.md with setup and usage instructions - Add .env.example template (sanitized, no real API keys) - Add root-level .gitignore to exclude .env and generated files - Add all project modules (dashboard, predictor) - Add data file and requirements.txt
This commit is contained in:
11
CardioAI/.env.example
Normal file
11
CardioAI/.env.example
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# DeepSeek API 配置
|
||||||
|
BASE_URL=https://api.deepseek.com/v1
|
||||||
|
DEEPSEEK_API_KEY=your_api_key_here
|
||||||
|
MODEL_NAME=deepseek-reasoner
|
||||||
|
|
||||||
|
# Flask 配置
|
||||||
|
FLASK_ENV=development
|
||||||
|
FLASK_DEBUG=True
|
||||||
|
|
||||||
|
# 数据文件路径
|
||||||
|
DATA_PATH=./data/心血管疾病.xlsx
|
||||||
13
CardioAI/.gitignore
vendored
Normal file
13
CardioAI/.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
*.pyc
|
||||||
|
__pycache__/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Model files (generated)
|
||||||
|
*.pkl
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
102
CardioAI/README.md
Normal file
102
CardioAI/README.md
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# CardioAI - 心血管疾病智能辅助系统
|
||||||
|
|
||||||
|
多模块应用,集成了数据可视化、机器学习预测和AI问答功能。
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
CardioAI/
|
||||||
|
├── data/ # 数据文件目录
|
||||||
|
│ └── 心血管疾病.xlsx # 心血管疾病数据集
|
||||||
|
├── module1_dashboard/ # 模块1: Streamlit 交互式仪表盘
|
||||||
|
│ └── cardio_dashboard.py
|
||||||
|
├── module2_predictor/ # 模块2: XGBoost 风险预测模型
|
||||||
|
│ ├── train_and_save.py # 模型训练脚本
|
||||||
|
│ ├── cardio_predictor_model.pkl # 训练好的模型
|
||||||
|
│ ├── app.py # Flask API 服务
|
||||||
|
│ └── templates/
|
||||||
|
│ └── index.html # 预测前端页面
|
||||||
|
├── requirements.txt # 项目依赖
|
||||||
|
├── .env.example # 环境变量模板
|
||||||
|
└── .gitignore # Git 忽略文件
|
||||||
|
```
|
||||||
|
|
||||||
|
## 环境配置
|
||||||
|
|
||||||
|
### 1. 创建 conda 环境
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n cardioenv python=3.10
|
||||||
|
conda activate cardioenv
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd F:\My_Git_Project\CardioAI
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 配置环境变量
|
||||||
|
|
||||||
|
复制 `.env.example` 为 `.env`,并填入您的 API Key:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
copy .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
编辑 `.env` 文件,填入您的 DeepSeek API Key。
|
||||||
|
|
||||||
|
## 模块说明
|
||||||
|
|
||||||
|
### 模块1: 交互式仪表盘 (Streamlit)
|
||||||
|
|
||||||
|
心血管数据的交互式可视化界面。
|
||||||
|
|
||||||
|
**启动命令:**
|
||||||
|
```bash
|
||||||
|
cd F:\My_Git_Project\CardioAI
|
||||||
|
streamlit run module1_dashboard/cardio_dashboard.py
|
||||||
|
```
|
||||||
|
|
||||||
|
**功能特性:**
|
||||||
|
- 年龄范围筛选
|
||||||
|
- 性别和心血管疾病状态筛选
|
||||||
|
- 统计数据展示
|
||||||
|
- BMI分布可视化
|
||||||
|
|
||||||
|
### 模块2: 心血管风险预测模型 (Flask + XGBoost)
|
||||||
|
|
||||||
|
基于 XGBoost 的心血管疾病风险预测 API。
|
||||||
|
|
||||||
|
**训练模型:**
|
||||||
|
```bash
|
||||||
|
cd F:\My_Git_Project\CardioAI
|
||||||
|
python module2_predictor/train_and_save.py
|
||||||
|
```
|
||||||
|
|
||||||
|
**启动预测服务:**
|
||||||
|
```bash
|
||||||
|
cd F:\My_Git_Project\CardioAI\module2_predictor
|
||||||
|
set FLASK_APP=app.py
|
||||||
|
flask run --host=0.0.0.0 --port=5000
|
||||||
|
```
|
||||||
|
|
||||||
|
**API 接口:**
|
||||||
|
- `POST /predict_cardio` - 提交11个特征值,返回预测概率和结果
|
||||||
|
|
||||||
|
## 依赖说明
|
||||||
|
|
||||||
|
- pandas, openpyxl - 数据处理
|
||||||
|
- numpy, scikit-learn - 数值计算
|
||||||
|
- xgboost, joblib - 机器学习
|
||||||
|
- streamlit, plotly - 数据可视化
|
||||||
|
- Flask - Web 服务
|
||||||
|
- python-dotenv - 环境变量
|
||||||
|
- langchain-openai, dashscope, requests - AI 集成
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. 数据文件路径可通过 `.env` 中的 `DATA_PATH` 配置
|
||||||
|
2. 确保 `.env` 文件不提交到版本库(已加入 .gitignore)
|
||||||
|
3. 使用前请确保已正确配置 DeepSeek API Key
|
||||||
BIN
CardioAI/data/心血管疾病.xlsx
Normal file
BIN
CardioAI/data/心血管疾病.xlsx
Normal file
Binary file not shown.
684
CardioAI/module1_dashboard/cardio_dashboard.py
Normal file
684
CardioAI/module1_dashboard/cardio_dashboard.py
Normal file
@@ -0,0 +1,684 @@
|
|||||||
|
"""
|
||||||
|
CardioAI 模块1: 交互式仪表盘
|
||||||
|
心血管疾病数据可视化系统 - 美化版
|
||||||
|
"""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import plotly.express as px
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# ==================== 页面配置 ====================
|
||||||
|
st.set_page_config(
|
||||||
|
page_title="CardioAI 心血管疾病分析",
|
||||||
|
page_icon="❤️",
|
||||||
|
layout="wide",
|
||||||
|
initial_sidebar_state="expanded"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 自定义CSS样式 ====================
|
||||||
|
st.markdown("""
|
||||||
|
<style>
|
||||||
|
/* 全局样式 */
|
||||||
|
.main {
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
padding: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 标题样式 */
|
||||||
|
.main-title {
|
||||||
|
text-align: center;
|
||||||
|
background: linear-gradient(90deg, #ff6b6b, #feca57);
|
||||||
|
-webkit-background-clip: text;
|
||||||
|
-webkit-text-fill-color: transparent;
|
||||||
|
font-size: 3rem !important;
|
||||||
|
font-weight: 800 !important;
|
||||||
|
margin-bottom: 0.5rem !important;
|
||||||
|
text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.sub-title {
|
||||||
|
text-align: center;
|
||||||
|
color: #666;
|
||||||
|
font-size: 1.1rem;
|
||||||
|
margin-bottom: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 侧边栏样式 */
|
||||||
|
.sidebar .sidebar-content {
|
||||||
|
background: linear-gradient(180deg, #1e3c72 0%, #2a5298 100%);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 指标卡片样式 */
|
||||||
|
.metric-card {
|
||||||
|
background: white;
|
||||||
|
border-radius: 16px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
box-shadow: 0 10px 40px rgba(0,0,0,0.1);
|
||||||
|
transition: transform 0.3s ease, box-shadow 0.3s ease;
|
||||||
|
border: 1px solid rgba(255,255,255,0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-card:hover {
|
||||||
|
transform: translateY(-5px);
|
||||||
|
box-shadow: 0 20px 60px rgba(0,0,0,0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-value {
|
||||||
|
font-size: 2.5rem;
|
||||||
|
font-weight: 700;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
-webkit-background-clip: text;
|
||||||
|
-webkit-text-fill-color: transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
.metric-label {
|
||||||
|
color: #666;
|
||||||
|
font-size: 0.9rem;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 图表容器样式 */
|
||||||
|
.chart-container {
|
||||||
|
background: white;
|
||||||
|
border-radius: 16px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 区块标题样式 */
|
||||||
|
.section-title {
|
||||||
|
background: linear-gradient(90deg, #667eea, #764ba2);
|
||||||
|
color: white !important;
|
||||||
|
padding: 0.8rem 1.5rem;
|
||||||
|
border-radius: 12px;
|
||||||
|
font-size: 1.2rem;
|
||||||
|
font-weight: 600;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 数据表格样式 */
|
||||||
|
.dataframe {
|
||||||
|
font-size: 0.85rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dataframe th {
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
font-weight: 600;
|
||||||
|
padding: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dataframe td {
|
||||||
|
padding: 10px;
|
||||||
|
border-bottom: 1px solid #eee;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dataframe tr:hover {
|
||||||
|
background-color: #f5f7fa;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 状态标签 */
|
||||||
|
.status-safe {
|
||||||
|
background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%);
|
||||||
|
color: white;
|
||||||
|
padding: 0.3rem 0.8rem;
|
||||||
|
border-radius: 20px;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-risk {
|
||||||
|
background: linear-gradient(135deg, #ff416c 0%, #ff4b2b 100%);
|
||||||
|
color: white;
|
||||||
|
padding: 0.3rem 0.8rem;
|
||||||
|
border-radius: 20px;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 进度条样式 */
|
||||||
|
.stProgress > div > div {
|
||||||
|
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
|
||||||
|
border-radius: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 侧边栏筛选器样式 */
|
||||||
|
.sidebar-filter {
|
||||||
|
background: rgba(255,255,255,0.1);
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 12px;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 分隔线样式 */
|
||||||
|
hr {
|
||||||
|
border: none;
|
||||||
|
height: 2px;
|
||||||
|
background: linear-gradient(90deg, transparent, #667eea, transparent);
|
||||||
|
margin: 2rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 信息卡片 */
|
||||||
|
.info-box {
|
||||||
|
background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%);
|
||||||
|
border-left: 4px solid #667eea;
|
||||||
|
padding: 1rem 1.5rem;
|
||||||
|
border-radius: 0 12px 12px 0;
|
||||||
|
margin: 1rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 按钮样式 */
|
||||||
|
.stButton > button {
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
border-radius: 25px;
|
||||||
|
padding: 0.6rem 2rem;
|
||||||
|
font-weight: 600;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stButton > button:hover {
|
||||||
|
transform: scale(1.05);
|
||||||
|
box-shadow: 0 8px 25px rgba(102, 126, 234, 0.4);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# ==================== 常量定义 ====================
|
||||||
|
CODE_ROOT = Path(r"F:\My_Git_Project\CardioAI")
|
||||||
|
DATA_PATH = CODE_ROOT / "data" / "心血管疾病.xlsx"
|
||||||
|
|
||||||
|
# 配色方案
|
||||||
|
COLORS = {
|
||||||
|
'primary': ['#667eea', '#764ba2', '#f093fb', '#f5576c'],
|
||||||
|
'safe': '#2ecc71',
|
||||||
|
'risk': '#e74c3c',
|
||||||
|
'gradient': ['#667eea', '#764ba2'],
|
||||||
|
'bmi': ['#3498db', '#2ecc71', '#f39c12', '#e74c3c']
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 数据加载函数 ====================
|
||||||
|
@st.cache_data(show_spinner=False)
|
||||||
|
def load_data(file_path: Path) -> pd.DataFrame:
|
||||||
|
"""加载数据,支持Excel格式"""
|
||||||
|
try:
|
||||||
|
df = pd.read_excel(file_path, engine='openpyxl')
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"数据加载失败: {e}")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_data(show_spinner=False)
|
||||||
|
def clean_and_engineer_features(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""数据清洗和特征工程"""
|
||||||
|
df = df.copy()
|
||||||
|
|
||||||
|
# 1. 特征工程
|
||||||
|
df['age_years'] = (df['age'] / 365).round().astype(int)
|
||||||
|
df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)
|
||||||
|
|
||||||
|
# 2. 异常值处理
|
||||||
|
df = df[df['ap_lo'] < df['ap_hi']]
|
||||||
|
df = df[(df['ap_hi'] >= 90) & (df['ap_hi'] <= 250)]
|
||||||
|
df = df[(df['ap_lo'] >= 60) & (df['ap_lo'] <= 150)]
|
||||||
|
|
||||||
|
# 3. 类别转换
|
||||||
|
cholesterol_map = {1: '正常', 2: '偏高', 3: '非常高'}
|
||||||
|
gluc_map = {1: '正常', 2: '偏高', 3: '非常高'}
|
||||||
|
df['cholesterol_cat'] = df['cholesterol'].map(cholesterol_map)
|
||||||
|
df['gluc_cat'] = df['gluc'].map(gluc_map)
|
||||||
|
|
||||||
|
# 4. BMI分类
|
||||||
|
def categorize_bmi(bmi):
|
||||||
|
if bmi < 18.5:
|
||||||
|
return '体重过低'
|
||||||
|
elif bmi < 25:
|
||||||
|
return '体重正常'
|
||||||
|
elif bmi < 30:
|
||||||
|
return '超重'
|
||||||
|
else:
|
||||||
|
return '肥胖'
|
||||||
|
|
||||||
|
df['bmi_category'] = df['bmi'].apply(categorize_bmi)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== UI组件 ====================
|
||||||
|
def render_header():
|
||||||
|
"""渲染页面头部"""
|
||||||
|
st.markdown('<h1 class="main-title">❤️ CardioAI</h1>', unsafe_allow_html=True)
|
||||||
|
st.markdown('<p class="sub-title">心血管疾病智能分析系统 | 数据驱动的健康洞察</p>', unsafe_allow_html=True)
|
||||||
|
st.markdown("<hr>", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sidebar(df: pd.DataFrame) -> dict:
|
||||||
|
"""创建美观的侧边栏"""
|
||||||
|
with st.sidebar:
|
||||||
|
st.markdown("### 🎛️ 数据筛选器")
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# 年龄范围
|
||||||
|
with st.container():
|
||||||
|
st.markdown("**📅 年龄范围**")
|
||||||
|
age_range = st.slider(
|
||||||
|
"",
|
||||||
|
min_value=int(df['age_years'].min()),
|
||||||
|
max_value=int(df['age_years'].max()),
|
||||||
|
value=(int(df['age_years'].min()), int(df['age_years'].max())),
|
||||||
|
key="age_slider"
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 性别选择
|
||||||
|
with st.container():
|
||||||
|
st.markdown("**👤 性别**")
|
||||||
|
gender_options = st.multiselect(
|
||||||
|
"",
|
||||||
|
options=[1, 2],
|
||||||
|
default=[1, 2],
|
||||||
|
format_func=lambda x: "👩 女性" if x == 1 else "👨 男性",
|
||||||
|
key="gender_select"
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 心血管疾病状态
|
||||||
|
with st.container():
|
||||||
|
st.markdown("**🏥 心血管健康状态**")
|
||||||
|
cardio_options = st.multiselect(
|
||||||
|
"",
|
||||||
|
options=[0, 1],
|
||||||
|
default=[0, 1],
|
||||||
|
format_func=lambda x: "✅ 健康" if x == 0 else "⚠️ 有风险",
|
||||||
|
key="cardio_select"
|
||||||
|
)
|
||||||
|
|
||||||
|
st.markdown("---")
|
||||||
|
|
||||||
|
# 数据统计
|
||||||
|
st.markdown("### 📊 数据概览")
|
||||||
|
st.metric("总记录数", f"{len(df):,}")
|
||||||
|
st.metric("平均BMI", f"{df['bmi'].mean():.1f}")
|
||||||
|
st.metric("平均年龄", f"{df['age_years'].mean():.1f} 岁")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'age_range': age_range,
|
||||||
|
'gender': gender_options,
|
||||||
|
'cardio': cardio_options
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def apply_filters(df: pd.DataFrame, filters: dict) -> pd.DataFrame:
|
||||||
|
"""应用筛选条件"""
|
||||||
|
return df[
|
||||||
|
(df['age_years'] >= filters['age_range'][0]) &
|
||||||
|
(df['age_years'] <= filters['age_range'][1]) &
|
||||||
|
(df['gender'].isin(filters['gender'])) &
|
||||||
|
(df['cardio'].isin(filters['cardio']))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def render_metrics(filtered_df: pd.DataFrame, total_count: int):
|
||||||
|
"""渲染指标卡片"""
|
||||||
|
st.markdown('<div class="section-title">📊 关键指标</div>', unsafe_allow_html=True)
|
||||||
|
|
||||||
|
col1, col2, col3, col4 = st.columns(4)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
st.markdown(f"""
|
||||||
|
<div class="metric-card">
|
||||||
|
<div class="metric-label">📋 筛选记录数</div>
|
||||||
|
<div class="metric-value">{len(filtered_df):,}</div>
|
||||||
|
<div style="color: #999; font-size: 0.8rem;">占比 {(len(filtered_df)/total_count*100):.1f}%</div>
|
||||||
|
</div>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
risk_rate = (filtered_df['cardio'].sum() / len(filtered_df) * 100) if len(filtered_df) > 0 else 0
|
||||||
|
st.markdown(f"""
|
||||||
|
<div class="metric-card">
|
||||||
|
<div class="metric-label">⚠️ 风险率</div>
|
||||||
|
<div class="metric-value">{risk_rate:.1f}%</div>
|
||||||
|
<div style="color: #999; font-size: 0.8rem;">心血管疾病患者占比</div>
|
||||||
|
</div>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
with col3:
|
||||||
|
avg_age = filtered_df['age_years'].mean() if len(filtered_df) > 0 else 0
|
||||||
|
st.markdown(f"""
|
||||||
|
<div class="metric-card">
|
||||||
|
<div class="metric-label">📅 平均年龄</div>
|
||||||
|
<div class="metric-value">{avg_age:.1f}</div>
|
||||||
|
<div style="color: #999; font-size: 0.8rem;">岁</div>
|
||||||
|
</div>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
avg_bmi = filtered_df['bmi'].mean() if len(filtered_df) > 0 else 0
|
||||||
|
st.markdown(f"""
|
||||||
|
<div class="metric-card">
|
||||||
|
<div class="metric-label">⚖️ 平均BMI</div>
|
||||||
|
<div class="metric-value">{avg_bmi:.1f}</div>
|
||||||
|
<div style="color: #999; font-size: 0.8rem;">{get_bmi_status(avg_bmi)}</div>
|
||||||
|
</div>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bmi_status(bmi: float) -> str:
|
||||||
|
"""获取BMI状态"""
|
||||||
|
if bmi < 18.5:
|
||||||
|
return "体重过低"
|
||||||
|
elif bmi < 25:
|
||||||
|
return "体重正常"
|
||||||
|
elif bmi < 30:
|
||||||
|
return "超重"
|
||||||
|
return "肥胖"
|
||||||
|
|
||||||
|
|
||||||
|
def plot_age_distribution(df: pd.DataFrame):
|
||||||
|
"""年龄分布图 - 美化版"""
|
||||||
|
fig = px.histogram(
|
||||||
|
df,
|
||||||
|
x='age_years',
|
||||||
|
color='cardio',
|
||||||
|
nbins=30,
|
||||||
|
title="年龄分布趋势",
|
||||||
|
labels={'age_years': '年龄', 'count': '人数'},
|
||||||
|
color_discrete_map={0: '#2ecc71', 1: '#e74c3c'},
|
||||||
|
barmode='overlay',
|
||||||
|
opacity=0.8
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
template='plotly_white',
|
||||||
|
title_font_size=18,
|
||||||
|
title_x=0.5,
|
||||||
|
legend_title_text="",
|
||||||
|
legend=dict(
|
||||||
|
orientation="h",
|
||||||
|
yanchor="bottom",
|
||||||
|
y=1.02,
|
||||||
|
xanchor="right",
|
||||||
|
x=1
|
||||||
|
),
|
||||||
|
plot_bgcolor='rgba(0,0,0,0)',
|
||||||
|
paper_bgcolor='rgba(0,0,0,0)',
|
||||||
|
font=dict(family="Arial, sans-serif", size=12),
|
||||||
|
margin=dict(t=60, b=40, l=40, r=40)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.for_each_trace(lambda t: t.update(
|
||||||
|
name="✅ 健康" if t.name == "0" else "⚠️ 有风险",
|
||||||
|
marker_line_width=1,
|
||||||
|
marker_line_color='white'
|
||||||
|
))
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_bmi_distribution(df: pd.DataFrame):
|
||||||
|
"""BMI分布饼图"""
|
||||||
|
bmi_counts = df['bmi_category'].value_counts().reindex(['体重过低', '体重正常', '超重', '肥胖'])
|
||||||
|
|
||||||
|
fig = go.Figure(data=[go.Pie(
|
||||||
|
labels=bmi_counts.index,
|
||||||
|
values=bmi_counts.values,
|
||||||
|
hole=0.5,
|
||||||
|
marker=dict(
|
||||||
|
colors=COLORS['bmi'],
|
||||||
|
line=dict(color='white', width=2)
|
||||||
|
),
|
||||||
|
textinfo='label+percent',
|
||||||
|
textposition='outside',
|
||||||
|
textfont=dict(size=12)
|
||||||
|
)])
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=dict(text="BMI分布", font=dict(size=18), x=0.5),
|
||||||
|
template='plotly_white',
|
||||||
|
showlegend=False,
|
||||||
|
plot_bgcolor='rgba(0,0,0,0)',
|
||||||
|
paper_bgcolor='rgba(0,0,0,0)',
|
||||||
|
margin=dict(t=60, b=40, l=40, r=40),
|
||||||
|
annotations=[dict(text='BMI', x=0.5, y=0.5, font_size=20, showarrow=False)]
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_bmi_cardio_relation(df: pd.DataFrame):
|
||||||
|
"""BMI与心血管疾病关系 - 堆叠柱状图"""
|
||||||
|
bmi_cardio = df.groupby(['bmi_category', 'cardio']).size().unstack(fill_value=0)
|
||||||
|
bmi_order = ['体重过低', '体重正常', '超重', '肥胖']
|
||||||
|
bmi_cardio = bmi_cardio.reindex(bmi_order)
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
fig.add_trace(go.Bar(
|
||||||
|
name='✅ 健康',
|
||||||
|
x=bmi_cardio.index,
|
||||||
|
y=bmi_cardio[0],
|
||||||
|
marker_color='#2ecc71',
|
||||||
|
marker_line=dict(color='white', width=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
fig.add_trace(go.Bar(
|
||||||
|
name='⚠️ 有风险',
|
||||||
|
x=bmi_cardio.index,
|
||||||
|
y=bmi_cardio[1],
|
||||||
|
marker_color='#e74c3c',
|
||||||
|
marker_line=dict(color='white', width=1)
|
||||||
|
))
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=dict(text="BMI与心血管疾病关联分析", font=dict(size=18), x=0.5),
|
||||||
|
xaxis_title="BMI类别",
|
||||||
|
yaxis_title="人数",
|
||||||
|
barmode='stack',
|
||||||
|
template='plotly_white',
|
||||||
|
legend=dict(
|
||||||
|
orientation="h",
|
||||||
|
yanchor="bottom",
|
||||||
|
y=1.02,
|
||||||
|
xanchor="right",
|
||||||
|
x=1
|
||||||
|
),
|
||||||
|
plot_bgcolor='rgba(0,0,0,0)',
|
||||||
|
paper_bgcolor='rgba(0,0,0,0)',
|
||||||
|
margin=dict(t=60, b=40, l=40, r=40)
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_blood_pressure_scatter(df: pd.DataFrame):
|
||||||
|
"""血压散点图"""
|
||||||
|
sample_df = df.sample(min(2000, len(df))) # 采样避免过多点
|
||||||
|
|
||||||
|
fig = px.scatter(
|
||||||
|
sample_df,
|
||||||
|
x='ap_lo',
|
||||||
|
y='ap_hi',
|
||||||
|
color='cardio',
|
||||||
|
color_discrete_map={0: '#2ecc71', 1: '#e74c3c'},
|
||||||
|
opacity=0.6,
|
||||||
|
title="血压分布散点图",
|
||||||
|
labels={'ap_lo': '舒张压 (mmHg)', 'ap_hi': '收缩压 (mmHg)'}
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
template='plotly_white',
|
||||||
|
title_font_size=18,
|
||||||
|
title_x=0.5,
|
||||||
|
legend_title_text="",
|
||||||
|
legend=dict(
|
||||||
|
orientation="h",
|
||||||
|
yanchor="bottom",
|
||||||
|
y=1.02,
|
||||||
|
xanchor="right",
|
||||||
|
x=1
|
||||||
|
),
|
||||||
|
plot_bgcolor='rgba(0,0,0,0)',
|
||||||
|
paper_bgcolor='rgba(0,0,0,0)',
|
||||||
|
margin=dict(t=60, b=40, l=40, r=40)
|
||||||
|
)
|
||||||
|
|
||||||
|
fig.for_each_trace(lambda t: t.update(
|
||||||
|
name="✅ 健康" if t.name == "0" else "⚠️ 有风险"
|
||||||
|
))
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_cholesterol_analysis(df: pd.DataFrame):
|
||||||
|
"""胆固醇与心血管疾病关系"""
|
||||||
|
chol_cardio = df.groupby(['cholesterol_cat', 'cardio']).size().unstack(fill_value=0)
|
||||||
|
chol_order = ['正常', '偏高', '非常高']
|
||||||
|
chol_cardio = chol_cardio.reindex(chol_order)
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
fig.add_trace(go.Bar(
|
||||||
|
name='✅ 健康',
|
||||||
|
x=chol_cardio.index,
|
||||||
|
y=chol_cardio[0] if 0 in chol_cardio.columns else [0, 0, 0],
|
||||||
|
marker_color='#2ecc71'
|
||||||
|
))
|
||||||
|
|
||||||
|
fig.add_trace(go.Bar(
|
||||||
|
name='⚠️ 有风险',
|
||||||
|
x=chol_cardio.index,
|
||||||
|
y=chol_cardio[1] if 1 in chol_cardio.columns else [0, 0, 0],
|
||||||
|
marker_color='#e74c3c'
|
||||||
|
))
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=dict(text="胆固醇水平与心血管疾病", font=dict(size=18), x=0.5),
|
||||||
|
xaxis_title="胆固醇水平",
|
||||||
|
yaxis_title="人数",
|
||||||
|
barmode='group',
|
||||||
|
template='plotly_white',
|
||||||
|
legend=dict(
|
||||||
|
orientation="h",
|
||||||
|
yanchor="bottom",
|
||||||
|
y=1.02,
|
||||||
|
xanchor="right",
|
||||||
|
x=1
|
||||||
|
),
|
||||||
|
plot_bgcolor='rgba(0,0,0,0)',
|
||||||
|
paper_bgcolor='rgba(0,0,0,0)',
|
||||||
|
margin=dict(t=60, b=40, l=40, r=40)
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def render_data_table(df: pd.DataFrame):
|
||||||
|
"""渲染数据表格"""
|
||||||
|
st.markdown('<div class="section-title">📋 数据明细</div>', unsafe_allow_html=True)
|
||||||
|
|
||||||
|
display_cols = ['id', 'age_years', 'gender', 'height', 'weight', 'bmi', 'bmi_category',
|
||||||
|
'ap_hi', 'ap_lo', 'cholesterol_cat', 'gluc_cat', 'cardio']
|
||||||
|
|
||||||
|
display_df = df[display_cols].copy()
|
||||||
|
display_df['gender'] = display_df['gender'].map({1: '女性', 2: '男性'})
|
||||||
|
display_df['cardio'] = display_df['cardio'].map({0: '✅ 健康', 1: '⚠️ 有风险'})
|
||||||
|
|
||||||
|
display_df.columns = ['ID', '年龄', '性别', '身高(cm)', '体重(kg)', 'BMI', 'BMI分类',
|
||||||
|
'收缩压', '舒张压', '胆固醇', '血糖', '心血管状态']
|
||||||
|
|
||||||
|
st.dataframe(
|
||||||
|
display_df,
|
||||||
|
use_container_width=True,
|
||||||
|
height=400,
|
||||||
|
column_config={
|
||||||
|
"ID": st.column_config.NumberColumn(width="small"),
|
||||||
|
"年龄": st.column_config.NumberColumn(width="small"),
|
||||||
|
"心血管状态": st.column_config.TextColumn(width="medium")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
render_header()
|
||||||
|
|
||||||
|
# 加载数据
|
||||||
|
with st.spinner("正在加载数据..."):
|
||||||
|
raw_df = load_data(DATA_PATH)
|
||||||
|
|
||||||
|
if raw_df.empty:
|
||||||
|
st.error("❌ 数据加载失败,请检查数据文件路径!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 数据处理
|
||||||
|
with st.spinner("正在进行数据清洗..."):
|
||||||
|
df = clean_and_engineer_features(raw_df)
|
||||||
|
|
||||||
|
total_count = len(df)
|
||||||
|
|
||||||
|
# 侧边栏筛选
|
||||||
|
filters = create_sidebar(df)
|
||||||
|
|
||||||
|
# 应用筛选
|
||||||
|
filtered_df = apply_filters(df, filters)
|
||||||
|
|
||||||
|
# 指标卡片
|
||||||
|
render_metrics(filtered_df, total_count)
|
||||||
|
|
||||||
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 图表区域
|
||||||
|
st.markdown('<div class="section-title">📈 可视化分析</div>', unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 第一行图表
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
with col1:
|
||||||
|
fig1 = plot_age_distribution(filtered_df)
|
||||||
|
st.plotly_chart(fig1, use_container_width=True, key="age_chart")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
fig2 = plot_bmi_distribution(filtered_df)
|
||||||
|
st.plotly_chart(fig2, use_container_width=True, key="bmi_pie")
|
||||||
|
|
||||||
|
# 第二行图表
|
||||||
|
col3, col4 = st.columns(2)
|
||||||
|
with col3:
|
||||||
|
fig3 = plot_bmi_cardio_relation(filtered_df)
|
||||||
|
st.plotly_chart(fig3, use_container_width=True, key="bmi_cardio")
|
||||||
|
|
||||||
|
with col4:
|
||||||
|
fig4 = plot_cholesterol_analysis(filtered_df)
|
||||||
|
st.plotly_chart(fig4, use_container_width=True, key="chol_chart")
|
||||||
|
|
||||||
|
# 第三行图表
|
||||||
|
col5, _ = st.columns([1, 1])
|
||||||
|
with col5:
|
||||||
|
fig5 = plot_blood_pressure_scatter(filtered_df)
|
||||||
|
st.plotly_chart(fig5, use_container_width=True, key="bp_scatter")
|
||||||
|
|
||||||
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# 数据表格
|
||||||
|
render_data_table(filtered_df)
|
||||||
|
|
||||||
|
# 页脚
|
||||||
|
st.markdown("<hr>", unsafe_allow_html=True)
|
||||||
|
st.markdown(
|
||||||
|
"<p style='text-align: center; color: #999; font-size: 0.9rem;'>❤️ CardioAI © 2024 | 心血管疾病智能分析系统</p>",
|
||||||
|
unsafe_allow_html=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
187
CardioAI/module2_predictor/app.py
Normal file
187
CardioAI/module2_predictor/app.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""
|
||||||
|
CardioAI 模块2: Flask API服务
|
||||||
|
心血管疾病风险预测 - 后端接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from flask import Flask, request, jsonify, render_template
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# ==================== 常量定义 ====================
|
||||||
|
CODE_ROOT = Path(r"F:\My_Git_Project\CardioAI")
|
||||||
|
MODEL_PATH = CODE_ROOT / "module2_predictor" / "cardio_predictor_model.pkl"
|
||||||
|
|
||||||
|
# ==================== Flask应用 ====================
|
||||||
|
app = Flask(__name__,
|
||||||
|
template_folder='templates',
|
||||||
|
static_folder='static')
|
||||||
|
|
||||||
|
# 全局变量存储模型
|
||||||
|
model = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""加载模型"""
|
||||||
|
global model
|
||||||
|
if model is None:
|
||||||
|
print("📂 正在加载模型...")
|
||||||
|
model = joblib.load(MODEL_PATH)
|
||||||
|
print("✅ 模型加载成功!")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 路由定义 ====================
|
||||||
|
@app.route('/')
|
||||||
|
def index():
|
||||||
|
"""渲染前端页面"""
|
||||||
|
return render_template('index.html')
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/predict_cardio', methods=['POST'])
|
||||||
|
def predict_cardio():
|
||||||
|
"""
|
||||||
|
心血管疾病风险预测接口
|
||||||
|
接收11个原始特征值的JSON POST请求
|
||||||
|
返回预测概率和结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取JSON数据
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': '未收到数据'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 定义特征列顺序(与训练时一致)
|
||||||
|
feature_names = [
|
||||||
|
'age_years', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
|
||||||
|
'cholesterol', 'gluc', 'smoke', 'alco', 'active'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 从请求中提取特征值
|
||||||
|
features = []
|
||||||
|
missing_fields = []
|
||||||
|
|
||||||
|
for col in feature_names:
|
||||||
|
if col in data:
|
||||||
|
features.append(float(data[col]))
|
||||||
|
else:
|
||||||
|
missing_fields.append(col)
|
||||||
|
features.append(0.0) # 默认值
|
||||||
|
|
||||||
|
# 计算BMI: weight / (height/100)^2
|
||||||
|
weight = float(data.get('weight', 0))
|
||||||
|
height = float(data.get('height', 0))
|
||||||
|
if height > 0:
|
||||||
|
bmi = weight / ((height / 100) ** 2)
|
||||||
|
features.append(bmi)
|
||||||
|
else:
|
||||||
|
features.append(0.0)
|
||||||
|
|
||||||
|
if missing_fields:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': f'缺少必要字段: {", ".join(missing_fields)}'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 定义特征列名(与训练时一致)
|
||||||
|
feature_columns = [
|
||||||
|
'age_years', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
|
||||||
|
'cholesterol', 'gluc', 'smoke', 'alco', 'active', 'bmi'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 转换为DataFrame格式
|
||||||
|
X_input = pd.DataFrame([features], columns=feature_columns)
|
||||||
|
|
||||||
|
# 加载模型(如果尚未加载)
|
||||||
|
predictor = load_model()
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
prediction = int(predictor.predict(X_input)[0])
|
||||||
|
prob_risk = float(predictor.predict_proba(X_input)[0][1])
|
||||||
|
prob_healthy = float(predictor.predict_proba(X_input)[0][0])
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
result = {
|
||||||
|
'success': True,
|
||||||
|
'prediction': prediction,
|
||||||
|
'prediction_label': '有风险' if prediction == 1 else '健康',
|
||||||
|
'probability': {
|
||||||
|
'健康': round(prob_healthy * 100, 2),
|
||||||
|
'有风险': round(prob_risk * 100, 2)
|
||||||
|
},
|
||||||
|
'risk_level': get_risk_level(prob_risk),
|
||||||
|
'recommendation': get_recommendation(prob_risk, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return jsonify(result)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': f'数据格式错误: {str(e)}'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': f'预测失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
|
def get_risk_level(probability: float) -> str:
|
||||||
|
"""根据概率返回风险等级"""
|
||||||
|
if probability < 0.3:
|
||||||
|
return '🟢 低风险'
|
||||||
|
elif probability < 0.5:
|
||||||
|
return '🟡 中低风险'
|
||||||
|
elif probability < 0.7:
|
||||||
|
return '🟠 中高风险'
|
||||||
|
else:
|
||||||
|
return '🔴 高风险'
|
||||||
|
|
||||||
|
|
||||||
|
def get_recommendation(probability: float, data: dict) -> str:
|
||||||
|
"""根据预测结果给出建议"""
|
||||||
|
if probability < 0.3:
|
||||||
|
return '继续保持健康的生活方式,定期体检。'
|
||||||
|
elif probability < 0.5:
|
||||||
|
return '建议适当增加运动,注意饮食均衡。'
|
||||||
|
elif probability < 0.7:
|
||||||
|
return '建议咨询医生,制定健康管理计划。'
|
||||||
|
else:
|
||||||
|
return '⚠️ 风险较高,请尽快就医检查。'
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
|
||||||
|
def health_check():
|
||||||
|
"""健康检查接口"""
|
||||||
|
return jsonify({
|
||||||
|
'status': 'healthy',
|
||||||
|
'service': 'CardioAI Cardiovascular Prediction API',
|
||||||
|
'version': '1.0.0'
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 启动应用 ====================
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("❤️ CardioAI 心血管疾病风险预测 API")
|
||||||
|
print("="*60)
|
||||||
|
print(f"📂 模型路径: {MODEL_PATH}")
|
||||||
|
print(f"🌐 启动地址: http://localhost:5001")
|
||||||
|
print("="*60 + "\n")
|
||||||
|
|
||||||
|
# 预加载模型
|
||||||
|
load_model()
|
||||||
|
|
||||||
|
# 启动Flask应用
|
||||||
|
app.run(
|
||||||
|
host='0.0.0.0',
|
||||||
|
port=5001,
|
||||||
|
debug=True
|
||||||
|
)
|
||||||
1060
CardioAI/module2_predictor/templates/index.html
Normal file
1060
CardioAI/module2_predictor/templates/index.html
Normal file
File diff suppressed because it is too large
Load Diff
199
CardioAI/module2_predictor/train_and_save.py
Normal file
199
CardioAI/module2_predictor/train_and_save.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
CardioAI 模块2: 模型训练脚本
|
||||||
|
心血管疾病风险预测模型 - 训练与保存
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import joblib
|
||||||
|
from pathlib import Path
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.compose import ColumnTransformer
|
||||||
|
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score
|
||||||
|
from xgboost import XGBClassifier
|
||||||
|
|
||||||
|
# ==================== 常量定义 ====================
|
||||||
|
CODE_ROOT = Path(r"F:\My_Git_Project\CardioAI")
|
||||||
|
DATA_PATH = CODE_ROOT / "data" / "心血管疾病.xlsx"
|
||||||
|
MODEL_PATH = CODE_ROOT / "module2_predictor" / "cardio_predictor_model.pkl"
|
||||||
|
|
||||||
|
# 特征列定义
|
||||||
|
CONTINUOUS_FEATURES = ['age', 'height', 'weight', 'ap_hi', 'ap_lo', 'bmi']
|
||||||
|
CATEGORICAL_FEATURES = ['gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active']
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_clean_data(file_path: Path) -> pd.DataFrame:
|
||||||
|
"""加载并清洗数据"""
|
||||||
|
print(f"📂 正在加载数据: {file_path}")
|
||||||
|
|
||||||
|
# 加载Excel数据
|
||||||
|
df = pd.read_excel(file_path, engine='openpyxl')
|
||||||
|
print(f"✅ 数据加载成功,共 {len(df)} 条记录")
|
||||||
|
|
||||||
|
# 复制数据
|
||||||
|
df = df.copy()
|
||||||
|
|
||||||
|
# 特征工程: age(天) -> age_years
|
||||||
|
df['age_years'] = (df['age'] / 365).round().astype(int)
|
||||||
|
|
||||||
|
# 计算BMI
|
||||||
|
df['bmi'] = df['weight'] / ((df['height'] / 100) ** 2)
|
||||||
|
|
||||||
|
# 异常值处理: 删除舒张压 >= 收缩压的记录
|
||||||
|
initial_count = len(df)
|
||||||
|
df = df[df['ap_lo'] < df['ap_hi']]
|
||||||
|
print(f"🗑️ 删除舒张压>=收缩压的记录: {initial_count - len(df)} 条")
|
||||||
|
|
||||||
|
# 删除血压极端异常值
|
||||||
|
# 收缩压 ∈ [90, 250]
|
||||||
|
initial_count = len(df)
|
||||||
|
df = df[(df['ap_hi'] >= 90) & (df['ap_hi'] <= 250)]
|
||||||
|
removed_hy = initial_count - len(df)
|
||||||
|
|
||||||
|
# 舒张压 ∈ [60, 150]
|
||||||
|
initial_count = len(df)
|
||||||
|
df = df[(df['ap_lo'] >= 60) & (df['ap_lo'] <= 150)]
|
||||||
|
removed_lo = initial_count - len(df)
|
||||||
|
print(f"🗑️ 删除血压异常值: 收缩压 {removed_hy} 条, 舒张压 {removed_lo} 条")
|
||||||
|
|
||||||
|
print(f"✅ 数据清洗完成,剩余 {len(df)} 条记录")
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_features(df: pd.DataFrame) -> tuple:
|
||||||
|
"""
|
||||||
|
准备特征和标签
|
||||||
|
删除id和原始age字段,保留处理后的特征
|
||||||
|
"""
|
||||||
|
# 定义要使用的特征(删除id和原始age,保留age_years)
|
||||||
|
feature_columns = ['age_years', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
|
||||||
|
'cholesterol', 'gluc', 'smoke', 'alco', 'active', 'bmi']
|
||||||
|
|
||||||
|
X = df[feature_columns].copy()
|
||||||
|
y = df['cardio'].copy()
|
||||||
|
|
||||||
|
print(f"📊 特征数量: {len(feature_columns)}")
|
||||||
|
print(f"📊 特征列: {feature_columns}")
|
||||||
|
|
||||||
|
return X, y, feature_columns
|
||||||
|
|
||||||
|
|
||||||
|
def build_pipeline() -> Pipeline:
|
||||||
|
"""构建包含预处理器和分类器的Pipeline"""
|
||||||
|
print("🔧 正在构建Pipeline...")
|
||||||
|
|
||||||
|
# 连续特征列
|
||||||
|
continuous_cols = ['age_years', 'height', 'weight', 'ap_hi', 'ap_lo', 'bmi']
|
||||||
|
|
||||||
|
# 分类特征列
|
||||||
|
categorical_cols = ['gender', 'cholesterol', 'gluc', 'smoke', 'alco', 'active']
|
||||||
|
|
||||||
|
# 预处理器
|
||||||
|
preprocessor = ColumnTransformer(
|
||||||
|
transformers=[
|
||||||
|
('num', StandardScaler(), continuous_cols),
|
||||||
|
('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), categorical_cols)
|
||||||
|
],
|
||||||
|
remainder='drop'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 完整Pipeline: 预处理 + XGBoost分类器
|
||||||
|
pipeline = Pipeline([
|
||||||
|
('preprocessor', preprocessor),
|
||||||
|
('classifier', XGBClassifier(
|
||||||
|
n_estimators=100,
|
||||||
|
max_depth=6,
|
||||||
|
learning_rate=0.1,
|
||||||
|
random_state=42,
|
||||||
|
use_label_encoder=False,
|
||||||
|
eval_metric='logloss',
|
||||||
|
n_jobs=-1
|
||||||
|
))
|
||||||
|
])
|
||||||
|
|
||||||
|
print("✅ Pipeline构建完成")
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def train_and_evaluate(X: pd.DataFrame, y: pd.Series, pipeline: Pipeline):
|
||||||
|
"""训练模型并评估"""
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("📈 开始模型训练...")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
# 划分训练集和测试集
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
|
X, y, test_size=0.2, random_state=42, stratify=y
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"📊 训练集大小: {len(X_train)}")
|
||||||
|
print(f"📊 测试集大小: {len(X_test)}")
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
print("🏋️ 正在训练XGBoost模型...")
|
||||||
|
pipeline.fit(X_train, y_train)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
y_pred = pipeline.predict(X_test)
|
||||||
|
y_pred_proba = pipeline.predict_proba(X_test)[:, 1]
|
||||||
|
|
||||||
|
# 评估指标
|
||||||
|
accuracy = accuracy_score(y_test, y_pred)
|
||||||
|
roc_auc = roc_auc_score(y_test, y_pred_proba)
|
||||||
|
|
||||||
|
print("\n" + "="*50)
|
||||||
|
print("📊 模型评估结果:")
|
||||||
|
print("="*50)
|
||||||
|
print(f"✅ 准确率 (Accuracy): {accuracy:.4f}")
|
||||||
|
print(f"✅ ROC-AUC 分数: {roc_auc:.4f}")
|
||||||
|
print("\n📋 分类报告:")
|
||||||
|
print(classification_report(y_test, y_pred, target_names=['健康', '有风险']))
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(pipeline: Pipeline, model_path: Path):
|
||||||
|
"""保存模型"""
|
||||||
|
print(f"\n💾 正在保存模型到: {model_path}")
|
||||||
|
joblib.dump(pipeline, model_path)
|
||||||
|
print(f"✅ 模型保存成功!")
|
||||||
|
|
||||||
|
# 验证模型文件
|
||||||
|
file_size = model_path.stat().st_size / (1024 * 1024)
|
||||||
|
print(f"📦 模型文件大小: {file_size:.2f} MB")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("❤️ CardioAI 模块2: 心血管疾病风险预测模型训练")
|
||||||
|
print("="*60 + "\n")
|
||||||
|
|
||||||
|
# 1. 加载并清洗数据
|
||||||
|
df = load_and_clean_data(DATA_PATH)
|
||||||
|
|
||||||
|
# 2. 准备特征
|
||||||
|
X, y, feature_columns = prepare_features(df)
|
||||||
|
|
||||||
|
# 3. 构建Pipeline
|
||||||
|
pipeline = build_pipeline()
|
||||||
|
|
||||||
|
# 4. 训练并评估模型
|
||||||
|
trained_pipeline = train_and_evaluate(X, y, pipeline)
|
||||||
|
|
||||||
|
# 5. 保存模型
|
||||||
|
save_model(trained_pipeline, MODEL_PATH)
|
||||||
|
|
||||||
|
print("\n" + "="*60)
|
||||||
|
print("🎉 模型训练完成!")
|
||||||
|
print("="*60)
|
||||||
|
print(f"\n📌 模型使用说明:")
|
||||||
|
print(f" 1. 启动Flask API: python app.py")
|
||||||
|
print(f" 2. 访问 http://localhost:5001 查看预测界面")
|
||||||
|
print(f" 3. 输入11个特征值进行预测")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
21
CardioAI/requirements.txt
Normal file
21
CardioAI/requirements.txt
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# CardioAI 项目依赖
|
||||||
|
# 创建并激活 conda 环境:
|
||||||
|
# conda create -n cardioenv python=3.10
|
||||||
|
# conda activate cardioenv
|
||||||
|
|
||||||
|
# 然后安装依赖:
|
||||||
|
# pip install -r requirements.txt
|
||||||
|
|
||||||
|
pandas
|
||||||
|
openpyxl
|
||||||
|
numpy
|
||||||
|
scikit-learn
|
||||||
|
xgboost
|
||||||
|
joblib
|
||||||
|
streamlit
|
||||||
|
plotly
|
||||||
|
Flask
|
||||||
|
python-dotenv
|
||||||
|
langchain-openai
|
||||||
|
dashscope
|
||||||
|
requests
|
||||||
Reference in New Issue
Block a user