113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
# 先安装: pip install openai
|
||
import os
|
||
import dotenv
|
||
from sklearn.metrics import accuracy_score
|
||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||
import pandas as pd
|
||
import json
|
||
from langchain_openai import ChatOpenAI
|
||
|
||
|
||
dotenv.load_dotenv()
|
||
|
||
# 获取环境变量
|
||
BASE_URL = os.getenv("BASE_URL")
|
||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
|
||
DEV2FILE = os.getenv("DEV2FILE")
|
||
CLASS_FILE = os.getenv("CLASS_FILE")
|
||
|
||
|
||
# 提前定义提示词
|
||
SYSTEM_PROMPT = """
|
||
你是一个专门用于中文新闻标题分类的助手,能够将给定的新闻标题准确地分类到以下十个预定义的类别之一:
|
||
finance(财经)、realty(房产)、stocks(股市)、education(教育)、science(科技)、society(社会)、politics(政治)、sports(体育)、game(游戏)和entertainment(娱乐)。
|
||
请根据标题内容和以下关键词与示例,匹配最相关的类别。如果标题涉及教育机构但核心是社会贡献,优先归为 society。
|
||
返回 JSON 格式:{"category": "类别", "reason": "分类原因"}
|
||
|
||
例如:
|
||
- 输入:“同步A股首秀:港股缩量回调”,应返回{"category": "stocks", "reason": "股市"}
|
||
- 输入:“布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击”,应返回“sports”。{"category": "sports", "reason": "体育"}
|
||
|
||
请注意,只从上述提供的十个类别中选择一个最合适的进行返回。如果标题与多个类别相关,请选择最相关的一个。请直接返回类别名称,无需额外解释!!!
|
||
"""
|
||
|
||
# 加载配置和类别映射
|
||
id2name = {i: line.strip() for i, line in enumerate(open(CLASS_FILE, encoding="utf-8"))}
|
||
# name2id{标签:索引}
|
||
name2id = {v: k for k, v in id2name.items()}
|
||
|
||
# 创建 LLM 对象
|
||
llm = ChatOpenAI(
|
||
base_url=BASE_URL,
|
||
api_key=DEEPSEEK_API_KEY,
|
||
model="deepseek-chat",
|
||
model_kwargs={"response_format": {"type": "json_object"}}
|
||
)
|
||
|
||
|
||
# 定义带重试机制的 LLM 调用函数
|
||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
||
def invoke_llm(prompt):
|
||
"""调用 LLM
|
||
# stop_after_attempt:若失败则重试最多3次
|
||
# wait_fixed:每次间隔2秒"""
|
||
return llm.invoke(prompt)
|
||
|
||
|
||
# 调用模型进行新闻标题分类获取响应
|
||
def get_deepseek_res(title: str) -> dict:
|
||
"""将新闻标题分类到指定类别,返回分类结果"""
|
||
# 定义提示词,包含类别说明和示例
|
||
prompt = [
|
||
{"role": "system", "content": SYSTEM_PROMPT},
|
||
{"role": "user", "content": f"新闻标题:'{title}',请分类并说明原因。"}
|
||
]
|
||
# print(prompt) # [{"role":'system',"content":提示词},{"role":'user',"content":''体验2D巅峰 倚天屠龙记十大创新概览''}]
|
||
|
||
# 调用 LLM 获取分类结果
|
||
response = invoke_llm(prompt)
|
||
result = json.loads(response.content)
|
||
|
||
# 返回分类结果,包含类别和原因
|
||
return {
|
||
"category": result.get("category", "society"), # 默认 社会society
|
||
"reason": result.get("reason", "未明确分类,归为社会类别")
|
||
}
|
||
|
||
|
||
# 分类并进行评估
|
||
def evaluate_classification():
|
||
"""对新闻标题分类进行评估"""
|
||
# 读取 tab 分隔的新闻标题和标签数据
|
||
df = pd.read_csv(DEV2FILE, sep='\t', header=None, names=['title', 'label'], encoding='utf-8')
|
||
# 拆包
|
||
titles = df['title'].tolist()
|
||
labels = df['label'].tolist()
|
||
print("titles-->", titles)
|
||
print("labels-->", labels)
|
||
# 如果多的话,只测试前 2 个
|
||
# titles = titles[:2]
|
||
# labels = labels[:2]
|
||
# 定义列表用于存储预测结果
|
||
pred_labels = []
|
||
# 遍历所有的标题,依次调用get_deepseek_res
|
||
for title in titles:
|
||
# 调用大模型获取结果
|
||
category_dict = get_deepseek_res(title)
|
||
print('category_dict-->', category_dict) # {'category': 'game', 'reason': '游戏'}...{'category': 'society', 'reason': '社会'}
|
||
# 获取预测类别,如果没有默认设置,则默认为 society社会
|
||
pred_category = category_dict.get("category", "society")
|
||
print('pred_category-->', pred_category) # game...society
|
||
# 添加到预测结果列表中 注意: name2id格式{标签:索引}
|
||
pred_labels.append(name2id[pred_category])
|
||
# 最终打印或者返回预测结果
|
||
print('pred_labels-->', pred_labels) # [8,5]
|
||
# 准确率
|
||
print('accuracy_score-->', accuracy_score(pred_labels, labels))
|
||
|
||
|
||
# 测试
|
||
if __name__ == '__main__':
|
||
# 测试评估
|
||
evaluate_classification()
|