feat: 添加心血管疾病预测模型和语音助手模块
module2_predictor: - 逻辑回归、随机森林、梯度提升三种模型 - 模型性能对比 (准确率、精确率、召回率、F1、ROC-AUC) - 交互式预测界面 - 混淆矩阵、ROC曲线、特征重要性可视化 module3_voice_assistant: - 语音交互式心血管疾病风险评估 - 患者信息管理 - 风险仪表盘可视化 - 快速指令和自定义问答 - 对话历史记录 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
467
ai_code/aicodes/module3_voice_assistant/voice_assistant.py
Normal file
467
ai_code/aicodes/module3_voice_assistant/voice_assistant.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
CardioAI 语音助手
|
||||
提供语音交互式心血管疾病风险评估
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
from sklearn.metrics import roc_auc_score
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
# ============================================
|
||||
# 数据加载函数 (带缓存)
|
||||
# ============================================
|
||||
@st.cache_data
|
||||
def load_data(file_path: str) -> pd.DataFrame:
|
||||
"""加载 Excel 数据文件"""
|
||||
df = pd.read_excel(file_path)
|
||||
return df
|
||||
|
||||
|
||||
# ============================================
|
||||
# 数据预处理函数
|
||||
# ============================================
|
||||
@st.cache_data
|
||||
def preprocess_data(df: pd.DataFrame) -> tuple:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Returns:
|
||||
(model, scaler, feature_columns)
|
||||
"""
|
||||
df_clean = df.copy()
|
||||
|
||||
# 年龄转换
|
||||
df_clean['age_years'] = (df_clean['age'] / 365).round().astype(int)
|
||||
|
||||
# 计算 BMI
|
||||
df_clean['bmi'] = df_clean['weight'] / ((df_clean['height'] / 100) ** 2)
|
||||
|
||||
# 删除异常值
|
||||
df_clean = df_clean[df_clean['ap_hi'] > df_clean['ap_lo']]
|
||||
df_clean = df_clean[(df_clean['ap_hi'] >= 90) & (df_clean['ap_hi'] <= 250)]
|
||||
df_clean = df_clean[(df_clean['ap_lo'] >= 60) & (df_clean['ap_lo'] <= 150)]
|
||||
|
||||
# 特征工程
|
||||
df_clean['bp_diff'] = df_clean['ap_hi'] - df_clean['ap_lo']
|
||||
df_clean['map'] = df_clean['ap_lo'] + (df_clean['bp_diff'] / 3)
|
||||
|
||||
# 特征列
|
||||
feature_columns = [
|
||||
'age_years', 'gender', 'ap_hi', 'ap_lo',
|
||||
'cholesterol', 'gluc', 'bmi', 'smoke', 'alco', 'active'
|
||||
]
|
||||
|
||||
X = df_clean[feature_columns]
|
||||
y = df_clean['cardio']
|
||||
|
||||
# 分割数据
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
|
||||
# 标准化
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
# 训练模型
|
||||
model = GradientBoostingClassifier(
|
||||
n_estimators=100,
|
||||
max_depth=5,
|
||||
learning_rate=0.1,
|
||||
random_state=42
|
||||
)
|
||||
model.fit(X_train_scaled, y_train)
|
||||
|
||||
# 验证模型
|
||||
y_prob = model.predict_proba(X_test_scaled)[:, 1]
|
||||
roc_auc = roc_auc_score(y_test, y_prob)
|
||||
|
||||
return model, scaler, feature_columns, roc_auc
|
||||
|
||||
|
||||
# ============================================
|
||||
# 预测函数
|
||||
# ============================================
|
||||
def predict_risk(model, scaler, input_data, feature_columns):
|
||||
"""预测心血管疾病风险"""
|
||||
input_df = pd.DataFrame([input_data])[feature_columns]
|
||||
input_scaled = scaler.transform(input_df)
|
||||
probability = model.predict_proba(input_scaled)[0]
|
||||
return probability
|
||||
|
||||
|
||||
# ============================================
|
||||
# 语音指令处理
|
||||
# ============================================
|
||||
def process_voice_command(command: str, patient_data: dict, model, scaler, feature_columns) -> str:
|
||||
"""
|
||||
处理语音指令
|
||||
|
||||
Args:
|
||||
command: 用户语音指令
|
||||
patient_data: 患者数据
|
||||
model: 训练好的模型
|
||||
scaler: 标准化器
|
||||
feature_columns: 特征列
|
||||
|
||||
Returns:
|
||||
响应文本
|
||||
"""
|
||||
command = command.lower()
|
||||
|
||||
# 预测风险
|
||||
probability = predict_risk(model, scaler, patient_data, feature_columns)
|
||||
risk_percent = probability[1] * 100
|
||||
no_risk_percent = probability[0] * 100
|
||||
|
||||
if "风险" in command or "患病" in command or "probability" in command:
|
||||
if risk_percent > 50:
|
||||
return f"根据您提供的信息,您患有心血管疾病的风险为 {risk_percent:.1f}%,风险较高,建议尽快就医检查。"
|
||||
else:
|
||||
return f"根据您提供的信息,您患有心血管疾病的风险为 {risk_percent:.1f}%,风险较低,请继续保持健康的生活方式。"
|
||||
|
||||
elif "正常" in command or "健康" in command or "healthy" in command:
|
||||
return f"您的健康概率为 {no_risk_percent:.1f}%,患病风险为 {risk_percent:.1f}%。"
|
||||
|
||||
elif "年龄" in command or "age" in command:
|
||||
return f"您的年龄是 {patient_data['age_years']} 岁。"
|
||||
|
||||
elif "血压" in command or "blood pressure" in command:
|
||||
return f"您的血压信息:收缩压 {patient_data['ap_hi']} mmHg,舒张压 {patient_data['ap_lo']} mmHg。"
|
||||
|
||||
elif "bmi" in command or "体重" in command:
|
||||
return f"您的 BMI 为 {patient_data['bmi']:.1f}。"
|
||||
|
||||
elif "胆固醇" in command or "cholesterol" in command:
|
||||
chol_map = {1: "正常", 2: "高于正常", 3: "远高于正常"}
|
||||
return f"您的胆固醇水平:{chol_map.get(patient_data['cholesterol'], '未知')}。"
|
||||
|
||||
elif "血糖" in command or "glucose" in command:
|
||||
gluc_map = {1: "正常", 2: "高于正常", 3: "远高于正常"}
|
||||
return f"您的血糖水平:{gluc_map.get(patient_data['gluc'], '未知')}。"
|
||||
|
||||
elif "建议" in command or "advice" in command or "recommend" in command:
|
||||
if risk_percent > 60:
|
||||
return "建议您:1. 立即就医进行全面检查 2. 控制饮食,减少高盐高脂食物 3. 适度运动 4. 戒烟限酒 5. 保持规律作息。"
|
||||
elif risk_percent > 30:
|
||||
return "建议您:1. 定期体检监测 2. 保持健康饮食 3. 坚持适度运动 4. 控制体重 5. 避免吸烟饮酒。"
|
||||
else:
|
||||
return "建议您:1. 保持当前健康生活方式 2. 均衡饮食 3. 适量运动 4. 定期体检。"
|
||||
|
||||
elif "帮助" in command or "help" in command:
|
||||
return """您可以询问以下信息:
|
||||
- 患病风险
|
||||
- 血压情况
|
||||
- BMI 数值
|
||||
- 胆固醇水平
|
||||
- 血糖水平
|
||||
- 健康建议
|
||||
- 我的年龄
|
||||
|
||||
请告诉我您想了解什么?"""
|
||||
|
||||
else:
|
||||
return f"抱歉,我没有理解您的指令。您可以问我关于患病风险、血压、BMI等问题。或者直接点击「开始语音评估」按钮进行心血管疾病风险评估。"
|
||||
|
||||
|
||||
# ============================================
|
||||
# Streamlit 页面配置
|
||||
# ============================================
|
||||
st.set_page_config(
|
||||
page_title="CardioAI 语音助手",
|
||||
page_icon="🎤",
|
||||
layout="wide"
|
||||
)
|
||||
|
||||
|
||||
# ============================================
|
||||
# 初始化会话状态
|
||||
# ============================================
|
||||
if 'patient_data' not in st.session_state:
|
||||
st.session_state.patient_data = None
|
||||
if 'voice_history' not in st.session_state:
|
||||
st.session_state.voice_history = []
|
||||
if 'model_trained' not in st.session_state:
|
||||
st.session_state.model_trained = False
|
||||
|
||||
|
||||
# ============================================
|
||||
# 主程序
|
||||
# ============================================
|
||||
def main():
|
||||
"""主程序入口"""
|
||||
|
||||
# 页面标题
|
||||
st.title("🎤 CardioAI 智能语音助手")
|
||||
st.markdown("---")
|
||||
st.markdown("**您的私人心血管健康顾问**")
|
||||
|
||||
# 数据路径
|
||||
DATA_PATH = "C:/Users/SAM/Desktop/sam_test/ai_code/aicodes/data/心血管疾病.xlsx"
|
||||
|
||||
# 加载和训练模型
|
||||
try:
|
||||
df = load_data(DATA_PATH)
|
||||
model, scaler, feature_columns, roc_auc = preprocess_data(df)
|
||||
st.session_state.model_trained = True
|
||||
st.session_state.model = model
|
||||
st.session_state.scaler = scaler
|
||||
st.session_state.feature_columns = feature_columns
|
||||
except Exception as e:
|
||||
st.error(f"❌ 系统初始化失败: {e}")
|
||||
return
|
||||
|
||||
# ============================================
|
||||
# 布局:侧边栏 + 主内容
|
||||
# ============================================
|
||||
st.sidebar.header("👤 患者信息录入")
|
||||
|
||||
# 患者信息表单
|
||||
with st.sidebar.form("patient_form"):
|
||||
st.subheader("基本信息")
|
||||
|
||||
age = st.number_input("年龄 (岁)", 20, 100, 50)
|
||||
gender = st.selectbox("性别", ["女性", "男性"])
|
||||
|
||||
st.subheader("身体指标")
|
||||
|
||||
ap_hi = st.number_input("收缩压 (mmHg)", 90, 250, 120)
|
||||
ap_lo = st.number_input("舒张压 (mmHg)", 60, 150, 80)
|
||||
height = st.number_input("身高 (cm)", 100, 220, 170)
|
||||
weight = st.number_input("体重 (kg)", 30, 200, 70)
|
||||
|
||||
st.subheader("生化指标")
|
||||
|
||||
cholesterol = st.selectbox("胆固醇水平", ["正常", "高于正常", "远高于正常"])
|
||||
gluc = st.selectbox("血糖水平", ["正常", "高于正常", "远高于正常"])
|
||||
|
||||
st.subheader("生活方式")
|
||||
|
||||
smoke = st.checkbox("吸烟")
|
||||
alco = st.checkbox("饮酒")
|
||||
active = st.checkbox("运动")
|
||||
|
||||
submit = st.form_submit_button("确认信息", type="primary")
|
||||
|
||||
if submit:
|
||||
# 构建患者数据
|
||||
gender_val = 1 if gender == "女性" else 2
|
||||
cholesterol_map = {"正常": 1, "高于正常": 2, "远高于正常": 3}
|
||||
gluc_map = {"正常": 1, "高于正常": 2, "远高于正常": 3}
|
||||
|
||||
bmi = weight / ((height / 100) ** 2)
|
||||
|
||||
st.session_state.patient_data = {
|
||||
'age_years': age,
|
||||
'gender': gender_val,
|
||||
'ap_hi': ap_hi,
|
||||
'ap_lo': ap_lo,
|
||||
'cholesterol': cholesterol_map[cholesterol],
|
||||
'gluc': gluc_map[gluc],
|
||||
'bmi': bmi,
|
||||
'smoke': 1 if smoke else 0,
|
||||
'alco': 1 if alco else 0,
|
||||
'active': 1 if active else 0
|
||||
}
|
||||
|
||||
# 初始风险评估
|
||||
probability = predict_risk(
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
|
||||
risk_msg = f"您好!我是 CardioAI 语音助手。根据您提供的信息,您的心血管疾病风险评估如下:患病风险为 {probability[1]*100:.1f}%,健康概率为 {probability[0]*100:.1f}%。您可以询问我任何关于您心血管健康的问题。"
|
||||
st.session_state.voice_history = [
|
||||
{"role": "assistant", "content": risk_msg}
|
||||
]
|
||||
|
||||
st.success("✅ 患者信息已录入,可以开始语音交互了!")
|
||||
|
||||
# ============================================
|
||||
# 主内容区域
|
||||
# ============================================
|
||||
|
||||
# 欢迎信息和风险评估展示
|
||||
st.markdown("### 🎯 心血管疾病风险评估")
|
||||
|
||||
if st.session_state.patient_data:
|
||||
# 预测风险
|
||||
probability = predict_risk(
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
|
||||
risk_percent = probability[1] * 100
|
||||
|
||||
# 风险仪表盘
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
# 风险仪表图
|
||||
fig_gauge = go.Figure(go.Indicator(
|
||||
mode="gauge+number",
|
||||
value=risk_percent,
|
||||
domain={'x': [0, 1], 'y': [0, 1]},
|
||||
title={'text': "心血管疾病风险率"},
|
||||
gauge={
|
||||
'axis': {'range': [0, 100]},
|
||||
'bar': {'color': "#e74c3c" if risk_percent > 50 else "#2ecc71"},
|
||||
'steps': [
|
||||
{'range': [0, 30], 'color': "#2ecc71"},
|
||||
{'range': [30, 60], 'color': "#f1c40f"},
|
||||
{'range': [60, 100], 'color': "#e74c3c"}
|
||||
],
|
||||
'threshold': {
|
||||
'line': {'color': "black", 'width': 4},
|
||||
'thickness': 0.75,
|
||||
'value': risk_percent
|
||||
}
|
||||
}
|
||||
))
|
||||
fig_gauge.update_layout(height=200)
|
||||
st.plotly_chart(fig_gauge, use_container_width=True)
|
||||
|
||||
with col2:
|
||||
# 风险等级和建议
|
||||
if risk_percent < 30:
|
||||
st.success("🟢 **低风险** - 请继续保持健康的生活方式")
|
||||
st.info("建议:均衡饮食、适量运动、定期体检")
|
||||
elif risk_percent < 60:
|
||||
st.warning("🟡 **中风险** - 建议关注心血管健康")
|
||||
st.info("建议:控制饮食、适度运动、定期监测血压")
|
||||
else:
|
||||
st.error("🔴 **高风险** - 建议及时就医")
|
||||
st.info("建议:尽快就医、严格控制饮食、戒烟限酒")
|
||||
|
||||
st.markdown(f"**模型 ROC-AUC**: {roc_auc:.4f}")
|
||||
else:
|
||||
st.info("👈 请在左侧填写患者信息后开始评估")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# ============================================
|
||||
# 语音交互区域
|
||||
# ============================================
|
||||
st.markdown("### 🎤 语音交互")
|
||||
|
||||
# 快速指令按钮
|
||||
st.markdown("#### 快速指令")
|
||||
quick_commands = [
|
||||
("查询风险", "我的患病风险是多少?"),
|
||||
("查询血压", "我的血压情况如何?"),
|
||||
("查询BMI", "我的BMI是多少?"),
|
||||
("健康建议", "请给我一些健康建议"),
|
||||
("帮助", "帮助")
|
||||
]
|
||||
|
||||
cols = st.columns(5)
|
||||
for i, (label, cmd) in enumerate(quick_commands):
|
||||
if cols[i].button(label):
|
||||
if st.session_state.patient_data:
|
||||
response = process_voice_command(
|
||||
cmd,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "user", "content": cmd}
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "assistant", "content": response}
|
||||
)
|
||||
else:
|
||||
st.warning("请先录入患者信息")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# 自定义语音输入
|
||||
st.markdown("#### 自定义语音输入")
|
||||
|
||||
# 模拟语音输入 (实际项目中可接入语音识别API)
|
||||
voice_input = st.text_input(
|
||||
"输入您的问题:",
|
||||
placeholder="例如:我的胆固醇水平如何?",
|
||||
key="voice_input"
|
||||
)
|
||||
|
||||
if st.button("发送", type="primary") and voice_input:
|
||||
if st.session_state.patient_data:
|
||||
response = process_voice_command(
|
||||
voice_input,
|
||||
st.session_state.patient_data,
|
||||
st.session_state.model,
|
||||
st.session_state.scaler,
|
||||
st.session_state.feature_columns
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "user", "content": voice_input}
|
||||
)
|
||||
st.session_state.voice_history.append(
|
||||
{"role": "assistant", "content": response}
|
||||
)
|
||||
else:
|
||||
st.warning("请先录入患者信息")
|
||||
|
||||
# ============================================
|
||||
# 对话历史
|
||||
# ============================================
|
||||
st.markdown("---")
|
||||
st.markdown("### 💬 对话历史")
|
||||
|
||||
for msg in st.session_state.voice_history:
|
||||
if msg["role"] == "user":
|
||||
st.markdown(f"**👤 您**: {msg['content']}")
|
||||
else:
|
||||
st.markdown(f"**🤖 CardioAI**: {msg['content']}")
|
||||
st.markdown("")
|
||||
|
||||
# 清空对话
|
||||
if st.button("清空对话历史"):
|
||||
st.session_state.voice_history = []
|
||||
st.rerun()
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# ============================================
|
||||
# 使用说明
|
||||
# ============================================
|
||||
with st.expander("📖 使用说明"):
|
||||
st.markdown("""
|
||||
### CardioAI 语音助手使用指南
|
||||
|
||||
1. **录入患者信息**: 在左侧边栏填写患者的基本信息、身体指标、生化指标和生活方式。
|
||||
|
||||
2. **开始评估**: 点击「确认信息」按钮后,系统会自动进行心血管疾病风险评估。
|
||||
|
||||
3. **语音交互**: 您可以:
|
||||
- 使用快速指令按钮快速查询
|
||||
- 输入自定义问题
|
||||
|
||||
4. **支持的查询**:
|
||||
- 患病风险和健康概率
|
||||
- 血压、BMI、胆固醇、血糖等指标
|
||||
- 健康建议
|
||||
|
||||
### 风险等级说明
|
||||
- 🟢 **低风险** (0-30%): 继续保持健康生活方式
|
||||
- 🟡 **中风险** (30-60%): 建议关注并改善生活方式
|
||||
- 🔴 **高风险** (>60%): 建议及时就医检查
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user