113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
|
|
# 先安装 (install first): pip install langchain-openai python-dotenv tenacity pandas scikit-learn
|
|||
|
|
import os
|
|||
|
|
import dotenv
|
|||
|
|
from sklearn.metrics import accuracy_score
|
|||
|
|
from tenacity import retry, stop_after_attempt, wait_fixed
|
|||
|
|
import pandas as pd
|
|||
|
|
import json
|
|||
|
|
from langchain_openai import ChatOpenAI
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Populate os.environ from a local .env file (when one is found) before reading settings.
dotenv.load_dotenv()

# Runtime settings read from the environment; each one is None when the variable is unset.
BASE_URL = os.environ.get("BASE_URL")  # passed to ChatOpenAI as base_url
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")  # passed to ChatOpenAI as api_key
DEV2FILE = os.environ.get("DEV2FILE")  # tab-separated "title<TAB>label" evaluation file
CLASS_FILE = os.environ.get("CLASS_FILE")  # one category name per line; line index is the label id
|
|||
|
|
|
|||
|
|
|
|||
|
|
# System prompt, defined once up front. It must mention "JSON" and show the exact
# object shape, because the client below requests a JSON-object response, and
# get_deepseek_res parses the reply with json.loads. Earlier wording also told the
# model to "return only the category name", which contradicted the JSON format;
# that instruction has been removed.
SYSTEM_PROMPT = """
你是一个专门用于中文新闻标题分类的助手,能够将给定的新闻标题准确地分类到以下十个预定义的类别之一:
finance(财经)、realty(房产)、stocks(股市)、education(教育)、science(科技)、society(社会)、politics(政治)、sports(体育)、game(游戏)和entertainment(娱乐)。
请根据标题内容和以下示例,匹配最相关的类别。如果标题涉及教育机构但核心是社会贡献,优先归为 society。

只返回一个 JSON 对象,格式为:{"category": "类别", "reason": "分类原因"}
其中 category 必须是上述十个英文类别名之一(全部小写,不带中文括注),reason 用一句话简要说明分类原因。

例如:
- 输入:“同步A股首秀:港股缩量回调”,应返回{"category": "stocks", "reason": "股市"}
- 输入:“布拉特:放球员一条生路吧 FIFA能消化俱乐部的攻击”,应返回{"category": "sports", "reason": "体育"}

请注意,只从上述提供的十个类别中选择一个最合适的进行返回。如果标题与多个类别相关,请选择最相关的一个。除该 JSON 对象外不要输出任何其他内容。
"""
|
|||
|
|
|
|||
|
|
# Load the category list and build the id <-> name lookups.
def _load_id2name(path):
    """Read category names from *path*, one per line, keyed by 0-based line number.

    Blank lines are skipped, but they still consume a line number, so the ids of
    the remaining names stay equal to their line positions in the file.
    """
    mapping = {}
    # "utf-8-sig" decodes exactly like "utf-8", but also drops a leading BOM.
    # Otherwise the BOM would stick to the first category name and break name2id lookups.
    with open(path, encoding="utf-8-sig") as class_file:
        for line_no, line in enumerate(class_file):
            name = line.strip()
            if name:
                mapping[line_no] = name
    return mapping


# id2name: {label id: category name}
id2name = _load_id2name(CLASS_FILE)

# name2id: {category name: label id}
name2id = {name: idx for idx, name in id2name.items()}
|
|||
|
|
|
|||
|
|
# Shared chat-model client used by invoke_llm. The response_format entry in
# model_kwargs requests the OpenAI-style JSON-object output mode from the endpoint.
llm = ChatOpenAI(
    model="deepseek-chat",
    api_key=DEEPSEEK_API_KEY,
    base_url=BASE_URL,
    model_kwargs={"response_format": {"type": "json_object"}},
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# LLM call wrapped with tenacity's retry decorator.
@retry(wait=wait_fixed(2), stop=stop_after_attempt(3))
def invoke_llm(prompt):
    """Send *prompt* to the module-level ``llm`` and return its reply.

    If the call raises, tenacity runs it again after a fixed 2-second pause.
    It makes at most 3 attempts in total (``stop_after_attempt(3)``). Because
    ``reraise`` is not set, exhausting the attempts surfaces as tenacity's
    ``RetryError``.
    """
    reply = llm.invoke(prompt)
    return reply
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Ask the model to classify one headline and return its parsed answer.
def get_deepseek_res(title: str) -> dict:
    """Classify a news headline with the LLM and return the parsed result.

    Args:
        title: The news headline to classify.

    Returns:
        A dict ``{"category": str, "reason": str}``.

        ``category`` is the model's answer with surrounding whitespace removed.
        It falls back to ``"society"`` when the reply is not a JSON object, or
        when the category is missing, empty or not a string.

        ``reason`` falls back to a fixed explanation when it is missing or not a
        string.

    Raises:
        Whatever ``invoke_llm`` raises once its retries are exhausted.
    """
    default_category = "society"
    default_reason = "未明确分类,归为社会类别"

    # The system turn carries the category list and examples; the user turn carries the title.
    prompt = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"新闻标题:'{title}',请分类并说明原因。"},
    ]

    response = invoke_llm(prompt)

    # JSON mode makes a JSON reply likely but not guaranteed (e.g. empty or
    # truncated content). One bad reply should degrade to the default category
    # rather than abort a whole evaluation run.
    try:
        result = json.loads(response.content)
    except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
        result = None
    if not isinstance(result, dict):
        result = {}

    category = result.get("category")
    if isinstance(category, str) and category.strip():
        category = category.strip()
    else:
        category = default_category

    reason = result.get("reason")
    if not isinstance(reason, str):
        reason = default_reason

    return {"category": category, "reason": reason}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Classify every headline in the dev file and report accuracy.
def evaluate_classification(limit=None):
    """Classify the headlines in DEV2FILE with the LLM and print the accuracy.

    DEV2FILE is read as a header-less, tab-separated file with two columns:
    the headline and its label. Labels are compared against ``name2id`` values,
    so they are expected to be CLASS_FILE line indices.

    Args:
        limit: If given, only the first ``limit`` rows are evaluated, which is
            useful for a cheap smoke test. ``None`` (the default) evaluates
            every row.

    Returns:
        The accuracy, or ``None`` when there is nothing to evaluate.
    """
    df = pd.read_csv(DEV2FILE, sep='\t', header=None, names=['title', 'label'], encoding='utf-8')
    if limit is not None:
        df = df.head(limit)

    titles = df['title'].tolist()
    labels = df['label'].tolist()
    print("titles-->", titles)
    print("labels-->", labels)

    if not titles:
        print("no titles to evaluate")
        return None

    # Id recorded when the model answers with a name that CLASS_FILE does not
    # contain. This mirrors get_deepseek_res's "society" default; -1 matches
    # no real label if "society" is absent from the class file.
    fallback_id = name2id.get("society", -1)

    pred_labels = []
    for title in titles:
        category_dict = get_deepseek_res(title)
        print('category_dict-->', category_dict)

        pred_category = category_dict.get("category", "society")
        print('pred_category-->', pred_category)

        # A plain name2id[...] lookup would raise KeyError on an unrecognised
        # name and abort the run midway, discarding all earlier API results.
        pred_id = name2id.get(pred_category)
        if pred_id is None:
            print(f"unknown category {pred_category!r}, using fallback id {fallback_id}")
            pred_id = fallback_id
        pred_labels.append(pred_id)

    print('pred_labels-->', pred_labels)

    # sklearn's signature is accuracy_score(y_true, y_pred).
    accuracy = accuracy_score(labels, pred_labels)
    print('accuracy_score-->', accuracy)
    return accuracy
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: run the evaluation only when executed directly, never on import.
if __name__ == "__main__":
    evaluate_classification()
|