feat: 添加心血管疾病预测模型和语音助手模块
module2_predictor: - 逻辑回归、随机森林、梯度提升三种模型 - 模型性能对比 (准确率、精确率、召回率、F1、ROC-AUC) - 交互式预测界面 - 混淆矩阵、ROC曲线、特征重要性可视化 module3_voice_assistant: - 语音交互式心血管疾病风险评估 - 患者信息管理 - 风险仪表盘可视化 - 快速指令和自定义问答 - 对话历史记录 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
500
ai_code/aicodes/module2_predictor/cardio_predictor.py
Normal file
500
ai_code/aicodes/module2_predictor/cardio_predictor.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
CardioAI 心血管疾病预测模型
|
||||
使用机器学习模型进行心血管疾病风险预测
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import (
|
||||
accuracy_score, precision_score, recall_score,
|
||||
f1_score, roc_auc_score, confusion_matrix,
|
||||
classification_report, roc_curve
|
||||
)
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
import joblib
|
||||
import os
|
||||
|
||||
|
||||
# ============================================
|
||||
# 数据加载函数 (带缓存)
|
||||
# ============================================
|
||||
@st.cache_data
|
||||
def load_data(file_path: str) -> pd.DataFrame:
|
||||
"""
|
||||
加载 Excel 数据文件
|
||||
|
||||
Args:
|
||||
file_path: Excel 文件路径
|
||||
|
||||
Returns:
|
||||
加载的 DataFrame
|
||||
"""
|
||||
df = pd.read_excel(file_path)
|
||||
return df
|
||||
|
||||
|
||||
# ============================================
|
||||
# 数据预处理函数
|
||||
# ============================================
|
||||
@st.cache_data
|
||||
def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
数据清洗和特征工程
|
||||
|
||||
Args:
|
||||
df: 原始 DataFrame
|
||||
|
||||
Returns:
|
||||
预处理后的 DataFrame
|
||||
"""
|
||||
df_clean = df.copy()
|
||||
|
||||
# 年龄转换: 天 -> 年
|
||||
df_clean['age_years'] = (df_clean['age'] / 365).round().astype(int)
|
||||
|
||||
# 计算 BMI
|
||||
df_clean['bmi'] = df_clean['weight'] / ((df_clean['height'] / 100) ** 2)
|
||||
|
||||
# 删除舒张压 >= 收缩压的记录
|
||||
df_clean = df_clean[df_clean['ap_hi'] > df_clean['ap_lo']]
|
||||
|
||||
# 删除血压极端异常值
|
||||
df_clean = df_clean[(df_clean['ap_hi'] >= 90) & (df_clean['ap_hi'] <= 250)]
|
||||
df_clean = df_clean[(df_clean['ap_lo'] >= 60) & (df_clean['ap_lo'] <= 150)]
|
||||
|
||||
# BMI 分类
|
||||
def categorize_bmi(bmi):
|
||||
if bmi < 18.5:
|
||||
return 0 # 偏瘦
|
||||
elif bmi < 24:
|
||||
return 1 # 正常
|
||||
elif bmi < 28:
|
||||
return 2 # 超重
|
||||
else:
|
||||
return 3 # 肥胖
|
||||
|
||||
df_clean['bmi_category'] = df_clean['bmi'].apply(categorize_bmi)
|
||||
|
||||
# 特征工程: 添加更多有用的特征
|
||||
# 血压差值
|
||||
df_clean['bp_diff'] = df_clean['ap_hi'] - df_clean['ap_lo']
|
||||
# 平均动脉压
|
||||
df_clean['map'] = df_clean['ap_lo'] + (df_clean['bp_diff'] / 3)
|
||||
|
||||
return df_clean
|
||||
|
||||
|
||||
# ============================================
|
||||
# 模型训练函数
|
||||
# ============================================
|
||||
def train_models(X_train, y_train, X_test, y_test):
|
||||
"""
|
||||
训练多个模型并返回结果
|
||||
|
||||
Args:
|
||||
X_train, y_train: 训练数据
|
||||
X_test, y_test: 测试数据
|
||||
|
||||
Returns:
|
||||
训练好的模型和评估结果
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# 1. 逻辑回归
|
||||
lr = LogisticRegression(max_iter=1000, random_state=42)
|
||||
lr.fit(X_train, y_train)
|
||||
y_pred_lr = lr.predict(X_test)
|
||||
y_prob_lr = lr.predict_proba(X_test)[:, 1]
|
||||
|
||||
results['Logistic Regression'] = {
|
||||
'model': lr,
|
||||
'predictions': y_pred_lr,
|
||||
'probabilities': y_prob_lr,
|
||||
'accuracy': accuracy_score(y_test, y_pred_lr),
|
||||
'precision': precision_score(y_test, y_pred_lr),
|
||||
'recall': recall_score(y_test, y_pred_lr),
|
||||
'f1': f1_score(y_test, y_pred_lr),
|
||||
'roc_auc': roc_auc_score(y_test, y_prob_lr)
|
||||
}
|
||||
|
||||
# 2. 随机森林
|
||||
rf = RandomForestClassifier(
|
||||
n_estimators=100,
|
||||
max_depth=10,
|
||||
random_state=42,
|
||||
n_jobs=-1
|
||||
)
|
||||
rf.fit(X_train, y_train)
|
||||
y_pred_rf = rf.predict(X_test)
|
||||
y_prob_rf = rf.predict_proba(X_test)[:, 1]
|
||||
|
||||
results['Random Forest'] = {
|
||||
'model': rf,
|
||||
'predictions': y_pred_rf,
|
||||
'probabilities': y_prob_rf,
|
||||
'accuracy': accuracy_score(y_test, y_pred_rf),
|
||||
'precision': precision_score(y_test, y_pred_rf),
|
||||
'recall': recall_score(y_test, y_pred_rf),
|
||||
'f1': f1_score(y_test, y_pred_rf),
|
||||
'roc_auc': roc_auc_score(y_test, y_prob_rf)
|
||||
}
|
||||
|
||||
# 3. 梯度提升
|
||||
gb = GradientBoostingClassifier(
|
||||
n_estimators=100,
|
||||
max_depth=5,
|
||||
learning_rate=0.1,
|
||||
random_state=42
|
||||
)
|
||||
gb.fit(X_train, y_train)
|
||||
y_pred_gb = gb.predict(X_test)
|
||||
y_prob_gb = gb.predict_proba(X_test)[:, 1]
|
||||
|
||||
results['Gradient Boosting'] = {
|
||||
'model': gb,
|
||||
'predictions': y_pred_gb,
|
||||
'probabilities': y_prob_gb,
|
||||
'accuracy': accuracy_score(y_test, y_pred_gb),
|
||||
'precision': precision_score(y_test, y_pred_gb),
|
||||
'recall': recall_score(y_test, y_pred_gb),
|
||||
'f1': f1_score(y_test, y_pred_gb),
|
||||
'roc_auc': roc_auc_score(y_test, y_prob_gb)
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ============================================
|
||||
# 预测函数
|
||||
# ============================================
|
||||
def predict_cardio(model, input_data, scaler):
|
||||
"""
|
||||
使用模型进行预测
|
||||
|
||||
Args:
|
||||
model: 训练好的模型
|
||||
input_data: 输入特征
|
||||
scaler: 数据标准化器
|
||||
|
||||
Returns:
|
||||
预测结果和概率
|
||||
"""
|
||||
# 标准化输入
|
||||
input_scaled = scaler.transform(input_data)
|
||||
# 预测
|
||||
prediction = model.predict(input_scaled)[0]
|
||||
probability = model.predict_proba(input_scaled)[0]
|
||||
|
||||
return prediction, probability
|
||||
|
||||
|
||||
# ============================================
|
||||
# Streamlit 页面配置
|
||||
# ============================================
|
||||
st.set_page_config(
|
||||
page_title="CardioAI 心血管疾病预测",
|
||||
page_icon="🔬",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
|
||||
# ============================================
|
||||
# 主程序
|
||||
# ============================================
|
||||
def main():
|
||||
"""主程序入口"""
|
||||
|
||||
# 页面标题
|
||||
st.title("🔬 CardioAI 心血管疾病预测模型")
|
||||
st.markdown("---")
|
||||
|
||||
# 数据路径
|
||||
DATA_PATH = "C:/Users/SAM/Desktop/sam_test/ai_code/aicodes/data/心血管疾病.xlsx"
|
||||
|
||||
# 加载数据
|
||||
try:
|
||||
df = load_data(DATA_PATH)
|
||||
except Exception as e:
|
||||
st.error(f"❌ 数据加载失败: {e}")
|
||||
return
|
||||
|
||||
# 数据预处理
|
||||
df_processed = preprocess_data(df)
|
||||
st.success(f"✅ 数据加载成功,共 {len(df_processed)} 条有效记录")
|
||||
|
||||
# ============================================
|
||||
# 侧边栏 - 模型选择
|
||||
# ============================================
|
||||
st.sidebar.header("⚙️ 模型设置")
|
||||
|
||||
# 特征选择
|
||||
feature_options = [
|
||||
'age_years', 'gender', 'height', 'weight', 'ap_hi', 'ap_lo',
|
||||
'cholesterol', 'gluc', 'smoke', 'alco', 'active', 'bmi', 'bmi_category',
|
||||
'bp_diff', 'map'
|
||||
]
|
||||
|
||||
selected_features = st.sidebar.multiselect(
|
||||
"选择特征",
|
||||
options=feature_options,
|
||||
default=['age_years', 'gender', 'ap_hi', 'ap_lo', 'cholesterol', 'gluc', 'bmi']
|
||||
)
|
||||
|
||||
if not selected_features:
|
||||
st.warning("请至少选择一个特征")
|
||||
return
|
||||
|
||||
# 测试集比例
|
||||
test_size = st.sidebar.slider("测试集比例", 0.1, 0.4, 0.2)
|
||||
|
||||
# 分割数据
|
||||
X = df_processed[selected_features]
|
||||
y = df_processed['cardio']
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=test_size, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
# 标准化
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# ============================================
|
||||
# Tab 布局
|
||||
# ============================================
|
||||
tab1, tab2, tab3 = st.tabs(["📊 模型训练", "🔮 疾病预测", "📈 模型评估"])
|
||||
|
||||
# ==================== Tab 1: 模型训练 ====================
|
||||
with tab1:
|
||||
st.header("模型训练")
|
||||
|
||||
# 训练模型
|
||||
with st.spinner("正在训练模型..."):
|
||||
results = train_models(X_train_scaled, y_train, X_test_scaled, y_test)
|
||||
|
||||
st.success("✅ 模型训练完成")
|
||||
|
||||
# 显示模型性能对比
|
||||
st.subheader("模型性能对比")
|
||||
|
||||
# 创建性能对比表
|
||||
performance_data = []
|
||||
for model_name, result in results.items():
|
||||
performance_data.append({
|
||||
'模型': model_name,
|
||||
'准确率': f"{result['accuracy']:.4f}",
|
||||
'精确率': f"{result['precision']:.4f}",
|
||||
'召回率': f"{result['recall']:.4f}",
|
||||
'F1分数': f"{result['f1']:.4f}",
|
||||
'ROC-AUC': f"{result['roc_auc']:.4f}"
|
||||
})
|
||||
|
||||
performance_df = pd.DataFrame(performance_data)
|
||||
st.table(performance_df)
|
||||
|
||||
# 选择最佳模型
|
||||
best_model_name = max(results, key=lambda x: results[x]['roc_auc'])
|
||||
best_result = results[best_model_name]
|
||||
|
||||
st.info(f"🏆 最佳模型: {best_model_name} (ROC-AUC: {best_result['roc_auc']:.4f})")
|
||||
|
||||
# ==================== Tab 2: 疾病预测 ====================
|
||||
with tab2:
|
||||
st.header("心血管疾病风险预测")
|
||||
|
||||
st.markdown("### 输入患者信息")
|
||||
|
||||
# 创建输入表单
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
age = st.number_input("年龄 (岁)", 20, 100, 50)
|
||||
gender = st.selectbox("性别", ["女性", "男性"])
|
||||
height = st.number_input("身高 (cm)", 100, 220, 170)
|
||||
weight = st.number_input("体重 (kg)", 30, 200, 70)
|
||||
|
||||
with col2:
|
||||
ap_hi = st.number_input("收缩压 (mmHg)", 90, 250, 120)
|
||||
ap_lo = st.number_input("舒张压 (mmHg)", 60, 150, 80)
|
||||
cholesterol = st.selectbox("胆固醇", ["正常", "高于正常", "远高于正常"])
|
||||
gluc = st.selectbox("血糖", ["正常", "高于正常", "远高于正常"])
|
||||
|
||||
# 生活方式
|
||||
col3, col4, col5 = st.columns(3)
|
||||
with col3:
|
||||
smoke = st.checkbox("吸烟")
|
||||
with col4:
|
||||
alco = st.checkbox("饮酒")
|
||||
with col5:
|
||||
active = st.checkbox("运动")
|
||||
|
||||
# 转换输入数据
|
||||
gender_val = 1 if gender == "女性" else 2
|
||||
cholesterol_map = {"正常": 1, "高于正常": 2, "远高于正常": 3}
|
||||
gluc_map = {"正常": 1, "高于正常": 2, "远高于正常": 3}
|
||||
cholesterol_val = cholesterol_map[cholesterol]
|
||||
gluc_val = gluc_map[gluc]
|
||||
smoke_val = 1 if smoke else 0
|
||||
alco_val = 1 if alco else 0
|
||||
active_val = 1 if active else 0
|
||||
|
||||
# 计算 BMI
|
||||
bmi = weight / ((height / 100) ** 2)
|
||||
|
||||
# 计算派生特征
|
||||
bp_diff = ap_hi - ap_lo
|
||||
map_val = ap_lo + (bp_diff / 3)
|
||||
|
||||
# BMI 分类
|
||||
if bmi < 18.5:
|
||||
bmi_cat = 0
|
||||
elif bmi < 24:
|
||||
bmi_cat = 1
|
||||
elif bmi < 28:
|
||||
bmi_cat = 2
|
||||
else:
|
||||
bmi_cat = 3
|
||||
|
||||
# 构建输入特征
|
||||
input_data_dict = {
|
||||
'age_years': age,
|
||||
'gender': gender_val,
|
||||
'height': height,
|
||||
'weight': weight,
|
||||
'ap_hi': ap_hi,
|
||||
'ap_lo': ap_lo,
|
||||
'cholesterol': cholesterol_val,
|
||||
'gluc': gluc_val,
|
||||
'smoke': smoke_val,
|
||||
'alco': alco_val,
|
||||
'active': active_val,
|
||||
'bmi': bmi,
|
||||
'bmi_category': bmi_cat,
|
||||
'bp_diff': bp_diff,
|
||||
'map': map_val
|
||||
}
|
||||
|
||||
# 只使用选中的特征
|
||||
input_features = {k: input_data_dict[k] for k in selected_features}
|
||||
input_df = pd.DataFrame([input_features])
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# 预测按钮
|
||||
if st.button("🔍 进行预测", type="primary"):
|
||||
with st.spinner("预测中..."):
|
||||
# 使用最佳模型预测
|
||||
model = best_result['model']
|
||||
prediction, probability = predict_cardio(model, input_df, scaler)
|
||||
|
||||
st.markdown("### 预测结果")
|
||||
|
||||
col_r1, col_r2 = st.columns(2)
|
||||
|
||||
with col_r1:
|
||||
if prediction == 1:
|
||||
st.error("⚠️ 存在心血管疾病风险")
|
||||
else:
|
||||
st.success("✅ 无心血管疾病风险")
|
||||
|
||||
with col_r2:
|
||||
risk_prob = probability[1] * 100
|
||||
st.metric(
|
||||
"风险概率",
|
||||
f"{risk_prob:.1f}%",
|
||||
delta=f"患病概率" if risk_prob > 50 else "健康概率"
|
||||
)
|
||||
|
||||
# 风险等级
|
||||
st.markdown("### 风险评估")
|
||||
if risk_prob < 30:
|
||||
st.info("🟢 低风险 - 保持健康生活方式")
|
||||
elif risk_prob < 60:
|
||||
st.warning("🟡 中风险 - 建议定期体检")
|
||||
else:
|
||||
st.error("🔴 高风险 - 建议及时就医")
|
||||
|
||||
# ==================== Tab 3: 模型评估 ====================
|
||||
with tab3:
|
||||
st.header("模型评估分析")
|
||||
|
||||
# 选择要评估的模型
|
||||
model_to_eval = st.selectbox(
|
||||
"选择模型",
|
||||
list(results.keys())
|
||||
)
|
||||
|
||||
result = results[model_to_eval]
|
||||
|
||||
# 1. 混淆矩阵
|
||||
st.subheader("混淆矩阵")
|
||||
cm = confusion_matrix(y_test, result['predictions'])
|
||||
|
||||
fig_cm = px.imshow(
|
||||
cm,
|
||||
labels=dict(x="预测值", y="实际值"),
|
||||
x=['无心血管疾病', '有心血管疾病'],
|
||||
y=['无心血管疾病', '有心血管疾病'],
|
||||
text_auto=True,
|
||||
color_continuous_scale='Blues'
|
||||
)
|
||||
st.plotly_chart(fig_cm, use_container_width=True)
|
||||
|
||||
# 2. ROC 曲线
|
||||
st.subheader("ROC 曲线")
|
||||
fpr, tpr, _ = roc_curve(y_test, result['probabilities'])
|
||||
|
||||
fig_roc = go.Figure()
|
||||
fig_roc.add_trace(go.Scatter(
|
||||
x=fpr, y=tpr,
|
||||
mode='lines',
|
||||
name=f'ROC (AUC = {result["roc_auc"]:.4f})',
|
||||
line=dict(color='#e74c3c', width=2)
|
||||
))
|
||||
fig_roc.add_trace(go.Scatter(
|
||||
x=[0, 1], y=[0, 1],
|
||||
mode='lines',
|
||||
name='随机猜测',
|
||||
line=dict(color='gray', width=1, dash='dash')
|
||||
))
|
||||
fig_roc.update_layout(
|
||||
xaxis_title='假阳性率 (FPR)',
|
||||
yaxis_title='真阳性率 (TPR)',
|
||||
title='ROC 曲线',
|
||||
showlegend=True
|
||||
)
|
||||
st.plotly_chart(fig_roc, use_container_width=True)
|
||||
|
||||
# 3. 特征重要性 (仅对树模型)
|
||||
if model_to_eval in ['Random Forest', 'Gradient Boosting']:
|
||||
st.subheader("特征重要性")
|
||||
model = result['model']
|
||||
importance = model.feature_importances_
|
||||
importance_df = pd.DataFrame({
|
||||
'特征': selected_features,
|
||||
'重要性': importance
|
||||
}).sort_values('重要性', ascending=True)
|
||||
|
||||
fig_imp = px.bar(
|
||||
importance_df,
|
||||
x='重要性',
|
||||
y='特征',
|
||||
orientation='h',
|
||||
title='特征重要性',
|
||||
color='重要性',
|
||||
color_continuous_scale='Viridis'
|
||||
)
|
||||
st.plotly_chart(fig_imp, use_container_width=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
467
ai_code/aicodes/module3_voice_assistant/voice_assistant.py
Normal file
467
ai_code/aicodes/module3_voice_assistant/voice_assistant.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
CardioAI 语音助手
|
||||
提供语音交互式心血管疾病风险评估
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
from sklearn.metrics import roc_auc_score
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
# ============================================
|
||||
# 数据加载函数 (带缓存)
|
||||
# ============================================
|
||||
@st.cache_data
|
||||
def load_data(file_path: str) -> pd.DataFrame:
|
||||
"""加载 Excel 数据文件"""
|
||||
df = pd.read_excel(file_path)
|
||||
return df
|
||||
|
||||
|
||||
# ============================================
|
||||
# 数据预处理函数
|
||||
# ============================================
|
||||
@st.cache_data
|
||||
def preprocess_data(df: pd.DataFrame) -> tuple:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Returns:
|
||||
(model, scaler, feature_columns)
|
||||
"""
|
||||
df_clean = df.copy()
|
||||
|
||||
# 年龄转换
|
||||
df_clean['age_years'] = (df_clean['age'] / 365).round().astype(int)
|
||||
|
||||
# 计算 BMI
|
||||
df_clean['bmi'] = df_clean['weight'] / ((df_clean['height'] / 100) ** 2)
|
||||
|
||||
# 删除异常值
|
||||
df_clean = df_clean[df_clean['ap_hi'] > df_clean['ap_lo']]
|
||||
df_clean = df_clean[(df_clean['ap_hi'] >= 90) & (df_clean['ap_hi'] <= 250)]
|
||||
df_clean = df_clean[(df_clean['ap_lo'] >= 60) & (df_clean['ap_lo'] <= 150)]
|
||||
|
||||
# 特征工程
|
||||
df_clean['bp_diff'] = df_clean['ap_hi'] - df_clean['ap_lo']
|
||||
df_clean['map'] = df_clean['ap_lo'] + (df_clean['bp_diff'] / 3)
|
||||
|
||||
# 特征列
|
||||
feature_columns = [
|
||||
'age_years', 'gender', 'ap_hi', 'ap_lo',
|
||||
'cholesterol', 'gluc', 'bmi', 'smoke', 'alco', 'active'
|
||||
]
|
||||
|
||||
X = df_clean[feature_columns]
|
||||
y = df_clean['cardio']
|
||||
|
||||
# 分割数据
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
# 标准化
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
# 训练模型
|
||||
model = GradientBoostingClassifier(
|
||||
n_estimators=100,
|
||||
max_depth=5,
|
||||
learning_rate=0.1,
|
||||
random_state=42
|
||||
)
|
||||
model.fit(X_train_scaled, y_train)
|
||||
|
||||
# 验证模型
|
||||
y_prob = model.predict_proba(X_test_scaled)[:, 1]
|
||||
roc_auc = roc_auc_score(y_test, y_prob)
|
||||
|
||||
return model, scaler, feature_columns, roc_auc
|
||||
|
||||
|
||||
# ============================================
|
||||
# 预测函数
|
||||
# ============================================
|
||||
def predict_risk(model, scaler, input_data, feature_columns):
|
||||
"""预测心血管疾病风险"""
|
||||
input_df = pd.DataFrame([input_data])[feature_columns]
|
||||
input_scaled = scaler.transform(input_df)
|
||||
probability = model.predict_proba(input_scaled)[0]
|
||||
return probability
|
||||
|
||||
|
||||
# ============================================
|
||||
# 语音指令处理
|
||||
# ============================================
|
||||
def process_voice_command(command: str, patient_data: dict, model, scaler, feature_columns) -> str:
|
||||
"""
|
||||
处理语音指令
|
||||
|
||||
Args:
|
||||
command: 用户语音指令
|
||||
patient_data: 患者数据
|
||||
model: 训练好的模型
|
||||
scaler: 标准化器
|
||||
feature_columns: 特征列
|
||||
|
||||
Returns:
|
||||
响应文本
|
||||
"""
|
||||
command = command.lower()
|
||||
|
||||
# 预测风险
|
||||
probability = predict_risk(model, scaler, patient_data, feature_columns)
|
||||
risk_percent = probability[1] * 100
|
||||
no_risk_percent = probability[0] * 100
|
||||
|
||||
if "风险" in command or "患病" in command or "probability" in command:
|
||||
if risk_percent > 50:
|
||||
return f"根据您提供的信息,您患有心血管疾病的风险为 {risk_percent:.1f}%,风险较高,建议尽快就医检查。"
|
||||
else:
|
||||
return f"根据您提供的信息,您患有心血管疾病的风险为 {risk_percent:.1f}%,风险较低,请继续保持健康的生活方式。"
|
||||
|
||||
elif "正常" in command or "健康" in command or "healthy" in command:
|
||||
return f"您的健康概率为 {no_risk_percent:.1f}%,患病风险为 {risk_percent:.1f}%。"
|
||||
|
||||
elif "年龄" in command or "age" in command:
|
||||
return f"您的年龄是 {patient_data['age_years']} 岁。"
|
||||
|
||||
elif "血压" in command or "blood pressure" in command:
|
||||
return f"您的血压信息:收缩压 {patient_data['ap_hi']} mmHg,舒张压 {patient_data['ap_lo']} mmHg。"
|
||||
|
||||
elif "bmi" in command or "体重" in command:
|
||||
return f"您的 BMI 为 {patient_data['bmi']:.1f}。"
|
||||
|
||||
elif "胆固醇" in command or "cholesterol" in command:
|
||||
chol_map = {1: "正常", 2: "高于正常", 3: "远高于正常"}
|
||||
return f"您的胆固醇水平:{chol_map.get(patient_data['cholesterol'], '未知')}。"
|
||||
|
||||
elif "血糖" in command or "glucose" in command:
|
||||
gluc_map = {1: "正常", 2: "高于正常", 3: "远高于正常"}
|
||||
return f"您的血糖水平:{gluc_map.get(patient_data['gluc'], '未知')}。"
|
||||
|
||||
elif "建议" in command or "advice" in command or "recommend" in command:
|
||||
if risk_percent > 60:
|
||||
return "建议您:1. 立即就医进行全面检查 2. 控制饮食,减少高盐高脂食物 3. 适度运动 4. 戒烟限酒 5. 保持规律作息。"
|
||||
elif risk_percent > 30:
|
||||
return "建议您:1. 定期体检监测 2. 保持健康饮食 3. 坚持适度运动 4. 控制体重 5. 避免吸烟饮酒。"
|
||||
else:
|
||||
return "建议您:1. 保持当前健康生活方式 2. 均衡饮食 3. 适量运动 4. 定期体检。"
|
||||
|
||||
elif "帮助" in command or "help" in command:
|
||||
return """您可以询问以下信息:
|
||||
- 患病风险
|
||||
- 血压情况
|
||||
- BMI 数值
|
||||
- 胆固醇水平
|
||||
- 血糖水平
|
||||
- 健康建议
|
||||
- 我的年龄
|
||||
|
||||
请告诉我您想了解什么?"""
|
||||
|
||||
else:
|
||||
return f"抱歉,我没有理解您的指令。您可以问我关于患病风险、血压、BMI等问题。或者直接点击「开始语音评估」按钮进行心血管疾病风险评估。"
|
||||
|
||||
|
||||
# ============================================
|
||||
# Streamlit 页面配置
|
||||
# ============================================
|
||||
st.set_page_config(
|
||||
page_title="CardioAI 语音助手",
|
||||
page_icon="🎤",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
|
||||
# ============================================
|
||||
# 初始化会话状态
|
||||
# ============================================
|
||||
if 'patient_data' not in st.session_state:
|
||||
st.session_state.patient_data = None
|
||||
if 'voice_history' not in st.session_state:
|
||||
st.session_state.voice_history = []
|
||||
if 'model_trained' not in st.session_state:
|
||||
st.session_state.model_trained = False
|
||||
|
||||
|
||||
# ============================================
|
||||
# 主程序
|
||||
# ============================================
|
||||
def main():
|
||||
"""主程序入口"""
|
||||
|
||||
# 页面标题
|
||||
st.title("🎤 CardioAI 智能语音助手")
|
||||
st.markdown("---")
|
||||
st.markdown("**您的私人心血管健康顾问**")
|
||||
|
||||
# 数据路径
|
||||
DATA_PATH = "C:/Users/SAM/Desktop/sam_test/ai_code/aicodes/data/心血管疾病.xlsx"
|
||||
|
||||
# 加载和训练模型
|
||||
try:
|
||||
df = load_data(DATA_PATH)
|
||||
model, scaler, feature_columns, roc_auc = preprocess_data(df)
|
||||
st.session_state.model_trained = True
|
||||
st.session_state.model = model
|
||||
st.session_state.scaler = scaler
|
||||
st.session_state.feature_columns = feature_columns
|
||||
except Exception as e:
|
||||
st.error(f"❌ 系统初始化失败: {e}")
|
||||
return
|
||||
|
||||
# ============================================
|
||||
# 布局:侧边栏 + 主内容
|
||||
# ============================================
|
||||
st.sidebar.header("👤 患者信息录入")
|
||||
|
||||
# 患者信息表单
|
||||
with st.sidebar.form("patient_form"):
|
||||
st.subheader("基本信息")
|
||||
|
||||
age = st.number_input("年龄 (岁)", 20, 100, 50)
|
||||
gender = st.selectbox("性别", ["女性", "男性"])
|
||||
|
||||
st.subheader("身体指标")
|
||||
|
||||
ap_hi = st.number_input("收缩压 (mmHg)", 90, 250, 120)
|
||||
ap_lo = st.number_input("舒张压 (mmHg)", 60, 150, 80)
|
||||
height = st.number_input("身高 (cm)", 100, 220, 170)
|
||||
weight = st.number_input("体重 (kg)", 30, 200, 70)
|
||||
|
||||
st.subheader("生化指标")
|
||||
|
||||
cholesterol = st.selectbox("胆固醇水平", ["正常", "高于正常", "远高于正常"])
|
||||
gluc = st.selectbox("血糖水平", ["正常", "高于正常", "远高于正常"])
|
||||
|
||||
st.subheader("生活方式")
|
||||
|
||||
smoke = st.checkbox("吸烟")
|
||||
alco = st.checkbox("饮酒")
|
||||
active = st.checkbox("运动")
|
||||
|
||||
submit = st.form_submit_button("确认信息", type="primary")
|
||||
|
||||
if submit:
|
||||
# 构建患者数据
|
||||
gender_val = 1 if gender == "女性" else 2
|
||||
cholesterol_map = {"正常": 1, "高于正常": 2, "远高于正常": 3}
|
||||
gluc_map = {"正常": 1, "高于正常": 2, "远高于正常": 3}
|
||||
|
||||
bmi = weight / ((height / 100) ** 2)
|
||||
|
||||
st.session_state.patient_data = {
|
||||
'age_years': age,
|
||||
'gender': gender_val,
|
||||
'ap_hi': ap_hi,
|
||||
'ap_lo': ap_lo,
|
||||
'cholesterol': cholesterol_map[cholesterol],
|
||||
'gluc': gluc_map[gluc],
|
||||
'bmi': bmi,
|
||||
'smoke': 1 if smoke else 0,
|
||||
'alco': 1 if alco else 0,
|
||||
'active': 1 if active else 0
|
||||
}
|
||||
|
||||
# 初始风险评估
|
||||
probability = predict_risk(
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
|
||||
risk_msg = f"您好!我是 CardioAI 语音助手。根据您提供的信息,您的心血管疾病风险评估如下:患病风险为 {probability[1]*100:.1f}%,健康概率为 {probability[0]*100:.1f}%。您可以询问我任何关于您心血管健康的问题。"
|
||||
st.session_state.voice_history = [
|
||||
{"role": "assistant", "content": risk_msg}
|
||||
]
|
||||
|
||||
st.success("✅ 患者信息已录入,可以开始语音交互了!")
|
||||
|
||||
# ============================================
|
||||
# 主内容区域
|
||||
# ============================================
|
||||
|
||||
# 欢迎信息和风险评估展示
|
||||
st.markdown("### 🎯 心血管疾病风险评估")
|
||||
|
||||
if st.session_state.patient_data:
|
||||
# 预测风险
|
||||
probability = predict_risk(
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
|
||||
risk_percent = probability[1] * 100
|
||||
|
||||
# 风险仪表盘
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
# 风险仪表图
|
||||
fig_gauge = go.Figure(go.Indicator(
|
||||
mode="gauge+number",
|
||||
value=risk_percent,
|
||||
domain={'x': [0, 1], 'y': [0, 1]},
|
||||
title={'text': "心血管疾病风险率"},
|
||||
gauge={
|
||||
'axis': {'range': [0, 100]},
|
||||
'bar': {'color': "#e74c3c" if risk_percent > 50 else "#2ecc71"},
|
||||
'steps': [
|
||||
{'range': [0, 30], 'color': "#2ecc71"},
|
||||
{'range': [30, 60], 'color': "#f1c40f"},
|
||||
{'range': [60, 100], 'color': "#e74c3c"}
|
||||
],
|
||||
'threshold': {
|
||||
'line': {'color': "black", 'width': 4},
|
||||
'thickness': 0.75,
|
||||
'value': risk_percent
|
||||
}
|
||||
}
|
||||
))
|
||||
fig_gauge.update_layout(height=200)
|
||||
st.plotly_chart(fig_gauge, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
# 风险等级和建议
|
||||
if risk_percent < 30:
|
||||
st.success("🟢 **低风险** - 请继续保持健康的生活方式")
|
||||
st.info("建议:均衡饮食、适量运动、定期体检")
|
||||
elif risk_percent < 60:
|
||||
st.warning("🟡 **中风险** - 建议关注心血管健康")
|
||||
st.info("建议:控制饮食、适度运动、定期监测血压")
|
||||
else:
|
||||
st.error("🔴 **高风险** - 建议及时就医")
|
||||
st.info("建议:尽快就医、严格控制饮食、戒烟限酒")
|
||||
|
||||
st.markdown(f"**模型 ROC-AUC**: {roc_auc:.4f}")
|
||||
else:
|
||||
st.info("👈 请在左侧填写患者信息后开始评估")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# ============================================
|
||||
# 语音交互区域
|
||||
# ============================================
|
||||
st.markdown("### 🎤 语音交互")
|
||||
|
||||
# 快速指令按钮
|
||||
st.markdown("#### 快速指令")
|
||||
quick_commands = [
|
||||
("查询风险", "我的患病风险是多少?"),
|
||||
("查询血压", "我的血压情况如何?"),
|
||||
("查询BMI", "我的BMI是多少?"),
|
||||
("健康建议", "请给我一些健康建议"),
|
||||
("帮助", "帮助")
|
||||
]
|
||||
|
||||
cols = st.columns(5)
|
||||
for i, (label, cmd) in enumerate(quick_commands):
|
||||
if cols[i].button(label):
|
||||
if st.session_state.patient_data:
|
||||
response = process_voice_command(
|
||||
cmd,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "user", "content": cmd}
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "assistant", "content": response}
|
||||
)
|
||||
else:
|
||||
st.warning("请先录入患者信息")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# 自定义语音输入
|
||||
st.markdown("#### 自定义语音输入")
|
||||
|
||||
# 模拟语音输入 (实际项目中可接入语音识别API)
|
||||
voice_input = st.text_input(
|
||||
"输入您的问题:",
|
||||
placeholder="例如:我的胆固醇水平如何?",
|
||||
key="voice_input"
|
||||
)
|
||||
|
||||
if st.button("发送", type="primary") and voice_input:
|
||||
if st.session_state.patient_data:
|
||||
response = process_voice_command(
|
||||
voice_input,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "user", "content": voice_input}
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "assistant", "content": response}
|
||||
)
|
||||
else:
|
||||
st.warning("请先录入患者信息")
|
||||
|
||||
# ============================================
|
||||
# 对话历史
|
||||
# ============================================
|
||||
st.markdown("---")
|
||||
st.markdown("### 💬 对话历史")
|
||||
|
||||
for msg in st.session_state.voice_history:
|
||||
if msg["role"] == "user":
|
||||
st.markdown(f"**👤 您**: {msg['content']}")
|
||||
else:
|
||||
st.markdown(f"**🤖 CardioAI**: {msg['content']}")
|
||||
st.markdown("")
|
||||
|
||||
# 清空对话
|
||||
if st.button("清空对话历史"):
|
||||
st.session_state.voice_history = []
|
||||
st.rerun()
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# ============================================
|
||||
# 使用说明
|
||||
# ============================================
|
||||
with st.expander("📖 使用说明"):
|
||||
st.markdown("""
|
||||
### CardioAI 语音助手使用指南
|
||||
|
||||
1. **录入患者信息**: 在左侧边栏填写患者的基本信息、身体指标、生化指标和生活方式。
|
||||
|
||||
2. **开始评估**: 点击「确认信息」按钮后,系统会自动进行心血管疾病风险评估。
|
||||
|
||||
3. **语音交互**: 您可以:
|
||||
- 使用快速指令按钮快速查询
|
||||
- 输入自定义问题
|
||||
|
||||
4. **支持的查询**:
|
||||
- 患病风险和健康概率
|
||||
- 血压、BMI、胆固醇、血糖等指标
|
||||
- 健康建议
|
||||
|
||||
### 风险等级说明
|
||||
- 🟢 **低风险** (0-30%): 继续保持健康生活方式
|
||||
- 🟡 **中风险** (30-60%): 建议关注并改善生活方式
|
||||
- 🔴 **高风险** (>60%): 建议及时就医检查
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user