diff --git a/data_chain/apps/service/router_service.py b/data_chain/apps/service/router_service.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4ca626afaca6eceda938a70fcd7a8c91991fa6 --- /dev/null +++ b/data_chain/apps/service/router_service.py @@ -0,0 +1,12 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from fastapi import FastAPI, Request, Depends +from fastapi.routing import APIRoute + + +def get_route_info(request: Request): + route = request.scope.get("route") + if isinstance(route, APIRoute): + request_method = request.method + route_path = route.path + return request_method+' '+route_path + return '' diff --git a/data_chain/apps/service/task_service.py b/data_chain/apps/service/task_service.py index 05c068c5cea8c751d908ab9929c33f7df853290d..e8219dffb9e480f24dac5a7391a3dc8963b4eb33 100644 --- a/data_chain/apps/service/task_service.py +++ b/data_chain/apps/service/task_service.py @@ -1,107 +1,92 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from data_chain.logger.logger import logger as logging -import traceback +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. import uuid -import pytz -from datetime import datetime -from data_chain.apps.base.task.document_task_handler import DocumentTaskHandler -from data_chain.apps.base.task.knowledge_base_task_handler import KnowledgeBaseTaskHandler -from data_chain.apps.base.task.task_handler import TaskRedisHandler, TaskHandler +from data_chain.apps.service.task_queue_service import TaskQueueService +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ( + ListTaskRequest +) +from data_chain.entities.response_data import ( + ListTaskMsg) +from data_chain.entities.enum import TaskType, TaskStatus +from data_chain.entities.common import default_roles +from data_chain.stores.database.database import TeamEntity +from data_chain.apps.base.convertor import Convertor from data_chain.manager.task_manager import TaskManager -from data_chain.models.constant import TaskConstant, TaskActionEnum -from data_chain.exceptions.exception import TaskException -from data_chain.config.config import config - +from data_chain.manager.task_report_manager import TaskReportManager +from data_chain.manager.role_manager import RoleManager -async def _validate_task_belong_to_user(user_id: uuid.UUID, task_id: str) -> bool: - task_entity = await TaskManager.select_by_id(task_id) - if task_entity is None: - raise TaskException("Task not exist") - if task_entity.user_id != user_id: - raise TaskException("Task not belong to user") +class TaskService: + """任务服务""" + @staticmethod + async def validate_user_action_to_task( + user_sub: str, task_id: uuid.UUID, action: str) -> bool: + """验证用户在任务中的操作权限""" + try: + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + logging.exception("[TaskService] 任务不存在") + raise Exception("Task not exist") + action_entity = await RoleManager.get_action_by_team_id_user_sub_and_action( + user_sub, task_entity.team_id, action) + if action_entity is None: + return False + return True + except Exception as e: + err = "验证用户在任务中的操作权限失败" + logging.exception("[TaskService] %s", err) + raise e -async def task_queue_handler(): - # 处理成功的任务 - success_task_ids = TaskRedisHandler.select_all_task(config['REDIS_SUCCESS_TASK_QUEUE_NAME']) - for task_id in success_task_ids: - TaskRedisHandler.remove_task_by_task_id(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], 
task_id) - TaskHandler.remove_task(uuid.UUID(task_id)) - task_entity = await TaskManager.select_by_id(uuid.UUID(task_id)) - if task_entity is None or task_entity.status != TaskConstant.TASK_STATUS_SUCCESS: - TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) - # 处理需要重新开始的任务 - restart_task_ids = TaskRedisHandler.select_all_task(config['REDIS_RESTART_TASK_QUEUE_NAME']) - for task_id in restart_task_ids: - TaskRedisHandler.remove_task_by_task_id(config['REDIS_RESTART_TASK_QUEUE_NAME'], task_id) - task_entity = await TaskManager.select_by_id(uuid.UUID(task_id)) - if task_entity is None or task_entity.status != TaskConstant.TASK_STATUS_RUNNING: - TaskHandler.remove_task(uuid.UUID(task_id)) - TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) - else: - await TaskHandler.restart_or_clear_task(uuid.UUID(task_id)) - # 静默处理状态错误的任务 - silent_error_task_ids = TaskRedisHandler.select_all_task(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME']) - for task_id in silent_error_task_ids: - TaskRedisHandler.remove_task_by_task_id(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) - task_entity = await TaskManager.select_by_id(task_id) - if task_entity is None: - continue - elif task_entity.status == TaskConstant.TASK_STATUS_DELETED: - await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.DELETE) - elif task_entity.status == TaskConstant.TASK_STATUS_CANCELED: - await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL) - # 处理等待的任务 - task_start_cnt = 0 - while True: - task_id = TaskRedisHandler.get_task_by_head(config['REDIS_PENDING_TASK_QUEUE_NAME']) - if task_id is None: - break - task_id = uuid.UUID(task_id) - task_entity = await TaskManager.select_by_id(task_id) - if task_entity is None or task_entity.status != TaskConstant.TASK_STATUS_PENDING: - TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) - continue + async def list_task(user_sub: str, req: ListTaskRequest) -> ListTaskMsg: + """列出任务""" try: - func = None - # 处理上传资产库任务 - if task_entity.type == TaskConstant.IMPORT_KNOWLEDGE_BASE: - func = KnowledgeBaseTaskHandler.handle_import_knowledge_base_task - # 处理打包资产库任务 - elif task_entity.type == TaskConstant.EXPORT_KNOWLEDGE_BASE: - func = KnowledgeBaseTaskHandler.handle_export_knowledge_base_task - # 处理临时文件解析任务 - elif task_entity.type == TaskConstant.PARSE_TEMPORARY_DOCUMENT: - func = DocumentTaskHandler.handle_parser_temporary_document_task - # 处理文档解析任务 - elif task_entity.type == TaskConstant.PARSE_DOCUMENT: - func = DocumentTaskHandler.handle_parser_document_task - if not TaskHandler.add_task(task_id, target=func, t_id=task_entity.id): - TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) - break - await TaskManager.update(task_id, {'status': TaskConstant.TASK_STATUS_RUNNING}) + total, task_entities = await TaskManager.list_task(user_sub, req) + tasks = [] + task_ids = [task_entity.id for task_entity in task_entities] + task_report_entities = await TaskReportManager.list_current_task_report_by_task_ids(task_ids) + task_report_dict = {task_report_entity.task_id: task_report_entity + for task_report_entity in task_report_entities} + for task_entity in task_entities: + task_report = task_report_dict.get(task_entity.id, None) + task = await Convertor.convert_task_entity_to_task(task_entity, task_report) + tasks.append(task) + return ListTaskMsg(total=total, tasks=tasks) except Exception as e: - logging.error(f"Handle task {task_id} error: 
{e}") - await TaskHandler.restart_or_clear_task(task_id) - break - task_start_cnt += 1 - if task_start_cnt == 8: - break + err = "列出任务失败" + logging.exception("[TaskService] %s", err) + raise e + @staticmethod + async def delete_task_by_task_id( + task_id: uuid.UUID) -> uuid.UUID: + """根据任务ID删除任务""" + try: + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = "任务不存在" + logging.exception("[TaskService] %s", err) + raise Exception(err) + task_id = await TaskQueueService.delete_task(task_id) + return task_id + except Exception as e: + err = "删除任务失败" + logging.exception("[TaskService] %s", err) + raise e -async def monitor_tasks(): - task_id_list = TaskHandler.list_tasks() - for task_id in task_id_list: - task_entity = await TaskManager.select_by_id(task_id) - if task_entity is None: - continue - current_time_utc = datetime.now(pytz.utc) - time_difference = current_time_utc - task_entity.created_time - seconds = time_difference.total_seconds() - if seconds > 12*3600: - if task_entity.status==TaskConstant.TASK_STATUS_DELETED: - await TaskHandler.restart_or_clear_task(task_id,TaskActionEnum.DELETE) - elif task_entity.status==TaskConstant.TASK_STATUS_CANCELED: - await TaskHandler.restart_or_clear_task(task_id,TaskActionEnum.CANCEL) - else: - await TaskHandler.restart_or_clear_task(task_id) + @staticmethod + async def delete_task_by_type( + user_sub: str, team_id: uuid.UUID, task_type: TaskType) -> list[uuid.UUID]: + """根据任务类型删除任务""" + try: + task_entities = await TaskManager.list_task_by_user_sub_and_team_id_and_type( + user_sub, team_id, task_type) + task_ids = [] + for task_entity in task_entities: + task_id = await TaskQueueService.delete_task(task_entity.id) + if task_id is not None: + task_ids.append(task_id) + return task_ids + except Exception as e: + err = "删除任务失败" + logging.exception("[TaskService] %s", err) + raise e diff --git a/data_chain/apps/service/team_service.py b/data_chain/apps/service/team_service.py new file mode 100644 index 0000000000000000000000000000000000000000..9404c6a55033689771caf02e5de0346ee80afabb --- /dev/null +++ b/data_chain/apps/service/team_service.py @@ -0,0 +1,102 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+import uuid
+from data_chain.logger.logger import logger as logging
+from data_chain.entities.request_data import ListTeamRequest, CreateTeamRequest
+from data_chain.entities.response_data import ListTeamMsg
+from data_chain.entities.enum import TeamType, TeamStatus
+from data_chain.entities.common import default_roles
+from data_chain.stores.database.database import TeamEntity
+from data_chain.apps.base.convertor import Convertor
+from data_chain.manager.team_manager import TeamManager
+from data_chain.manager.role_manager import RoleManager
+
+
+class TeamService:
+    """团队服务"""
+    @staticmethod
+    async def validate_user_action_in_team(
+            user_sub: str, team_id: uuid.UUID, action: str) -> bool:
+        """验证用户在团队中的操作权限"""
+        try:
+            action_entity = await RoleManager.get_action_by_team_id_user_sub_and_action(user_sub, team_id, action)
+            if action_entity is None:
+                return False
+            return True
+        except Exception as e:
+            err = "验证用户在团队中的操作权限失败"
+            logging.exception("[TeamService] %s", err)
+            raise e
+
+    @staticmethod
+    async def list_teams(user_sub: str, req: ListTeamRequest) -> ListTeamMsg:
+        """列出团队"""
+        if req.team_type == TeamType.MYCREATED:
+            total, team_entities = await TeamManager.list_team_mycreated_user_sub(user_sub, req)
+        elif req.team_type == TeamType.MYJOINED:
+            total, team_entities = await TeamManager.list_team_myjoined_by_user_sub(user_sub, req)
+        else:
+            total, team_entities = await TeamManager.list_pulic_team(user_sub, req)
+        teams = []
+        for team_entity in team_entities:
+            team = await Convertor.convert_team_entity_to_team(team_entity)
+            teams.append(team)
+        return ListTeamMsg(total=total, teams=teams)
+
+    @staticmethod
+    async def create_team(user_sub: str, req: CreateTeamRequest) -> uuid.UUID:
+        """创建团队"""
+        try:
+            team_entity = await Convertor.convert_create_team_request_to_team_entity(user_sub, req)
+            team_entity = await TeamManager.add_team(team_entity)
+            team_user_entity = await Convertor.convert_user_sub_and_team_id_to_team_user_entity(user_sub, team_entity.id)
+            await TeamManager.add_team_user(team_user_entity)
+            become_creator_flag = False
+            for role_dict in default_roles:
+                role_entity = await Convertor.convert_default_role_dict_to_role_entity(team_entity.id, role_dict)
+                role_entity = await RoleManager.add_role(role_entity)
+                if not become_creator_flag:
+                    user_role_entity = await Convertor.convert_user_sub_role_id_and_team_id_to_user_role_entity(
+                        user_sub, role_entity.id, team_entity.id)
+                    await RoleManager.add_user_role(user_role_entity)
+                    become_creator_flag = True
+                role_action_entities = await Convertor.convert_default_role_action_dicts_to_role_action_entities(role_entity.id, role_dict['actions'])
+                await RoleManager.add_role_actions(role_action_entities)
+            return team_entity.id
+        except Exception as e:
+            err = "创建团队失败"
+            logging.exception("[TeamService] %s", err)
+            raise e
+
+    @staticmethod
+    async def update_team_by_team_id(
+            user_sub: str, team_id: uuid.UUID, req: CreateTeamRequest) -> uuid.UUID:
+        """更新团队"""
+        try:
+            team_dict = await Convertor.convert_update_team_request_to_dict(req)
+            team_entity = await TeamManager.update_team_by_id(team_id, team_dict)
+            if team_entity is None:
+                err = "更新团队失败, 团队不存在"
+                logging.exception("[TeamService] %s", err)
+                raise Exception(err)
+            return team_entity.id
+        except Exception as e:
+            err = "更新团队失败"
+            logging.exception("[TeamService] %s", err)
+            raise e
+
+    @staticmethod
+    async def soft_delete_team_by_team_id(
+            team_id: uuid.UUID) -> uuid.UUID:
+        """软删除团队"""
+        try:
+            team_entity = await TeamManager.update_team_by_id(
+                team_id, {"status": 
TeamStatus.DELETED.value}) + if team_entity is None: + err = "软删除团队失败" + logging.exception("[TeamService] %s", err) + raise "软删除团队失败, 团队不存在" + return team_entity.id + except Exception as e: + err = "软删除团队失败" + logging.exception("[TeamService] %s", err) + raise e diff --git a/data_chain/parser/tools/image_tool.py b/data_chain/parser/tools/image_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..4c566826a3baf25ef4fa6f87fe2d308234752da9 --- /dev/null +++ b/data_chain/parser/tools/image_tool.py @@ -0,0 +1,16 @@ +from data_chain.logger.logger import logger as logging + + +class ImageTool: + @staticmethod + def get_image_type(b): + hex_str = bytes.hex(b).upper() + if "FFD8FF" in hex_str: + return "jpg" + elif "89504E47" in hex_str: + return "png" + elif "47494638" in hex_str: + return "gif" + elif "424D" in hex_str: + return "bmp" + return "jpeg" diff --git a/data_chain/parser/tools/ocr_tool.py b/data_chain/parser/tools/ocr_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..11739632af28721abdc0521a951cb06f2a3b2633 --- /dev/null +++ b/data_chain/parser/tools/ocr_tool.py @@ -0,0 +1,91 @@ +from PIL import Image +import asyncio +import yaml +from paddleocr import PaddleOCR +import numpy as np +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.logger.logger import logger as logging +from data_chain.config.config import config +from data_chain.llm.llm import LLM + + +class OcrTool: + det_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_det_infer' + rec_model_dir = 'data_chain/parser/model/ocr/ch_PP-OCRv4_rec_infer' + cls_model_dir = 'data_chain/parser/model/ocr/ch_ppocr_mobile_v2.0_cls_infer' + model = PaddleOCR( + det_model_dir=det_model_dir, + rec_model_dir=rec_model_dir, + cls_model_dir=cls_model_dir, + use_angle_cls=True, # 是否使用角度分类模型 + use_space_char=True # 是否使用空格字符 + ) + + @staticmethod + async def ocr_from_image(image: np.ndarray) -> list: + try: + ocr_result = OcrTool.model.ocr(image) + if ocr_result is None or ocr_result[0] is None: + return None + return ocr_result + except Exception as e: + err = f"[OCRTool] OCR识别失败 {e}" + logging.exception(err) + return None + + @staticmethod + async def merge_text_from_ocr_result(ocr_result: list) -> str: + text = '' + try: + for _ in ocr_result[0]: + text += str(_[1][0]) + return text + except Exception as e: + err = f"[OCRTool] OCR结果合并失败 {e}" + logging.exception(err) + return '' + + @staticmethod + async def enhance_ocr_result(ocr_result, image_related_text='', llm: LLM = None) -> str: + try: + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + prompt_template = prompt_dict.get('OCR_ENHANCED_PROMPT', '') + pre_part_description = "" + token_limit = llm.max_tokens//2 + image_related_text = TokenTool.get_k_tokens_words_from_content(image_related_text, token_limit) + ocr_result_parts = TokenTool.split_str_with_slide_window(str(ocr_result), token_limit) + user_call = '请详细输出图片的摘要,不要输出其他内容' + for part in ocr_result_parts: + pre_part_description_cp = pre_part_description + try: + prompt = prompt_template.format( + image_related_text=image_related_text, + pre_part_description=pre_part_description, + part=part) + pre_part_description = await llm.nostream([], prompt, user_call) + except Exception as e: + err = f"[OCRTool] OCR增强失败 {e}" + logging.exception(err) + pre_part_description = pre_part_description_cp + return pre_part_description + except Exception as e: + err = f"[OCRTool] OCR增强失败 {e}" + logging.exception(err) + 
return await OcrTool.merge_text_from_ocr_result(ocr_result)
+
+    @staticmethod
+    async def image_to_text(image: np.ndarray, image_related_text: str = '', llm: LLM = None) -> str:
+        try:
+            ocr_result = await OcrTool.ocr_from_image(image)
+            if ocr_result is None:
+                return ''
+            if llm is None:
+                text = await OcrTool.merge_text_from_ocr_result(ocr_result)
+            else:
+                text = await OcrTool.enhance_ocr_result(ocr_result, image_related_text, llm)
+            return text
+        except Exception as e:
+            err = f"[OCRTool] 图片转文本失败 {e}"
+            logging.exception(err)
+            return ''
diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..348feb5145ec6e30bb3d536d79e8ab7627f57bfa
--- /dev/null
+++ b/data_chain/parser/tools/token_tool.py
@@ -0,0 +1,526 @@
+import asyncio
+import tiktoken
+import jieba
+from jieba.analyse import extract_tags
+import yaml
+import json
+import re
+import uuid
+import numpy as np
+from pydantic import BaseModel, Field
+from data_chain.llm.llm import LLM
+from data_chain.embedding.embedding import Embedding
+from data_chain.config.config import config
+from data_chain.logger.logger import logger as logging
+
+
+class Grade(BaseModel):
+    content_len: int = Field(..., description="内容长度")
+    tokens: int = Field(..., description="token数")
+
+
+class TokenTool:
+    stop_words_path = config['STOP_WORDS_PATH']
+    with open(stop_words_path, 'r', encoding='utf-8') as f:
+        stopwords = set(line.strip() for line in f)
+
+    @staticmethod
+    def get_leave_tokens_from_content_len(content: str) -> int:
+        """
+        根据内容长度获取留存的token数
+        """
+        grades = [
+            Grade(content_len=0, tokens=0),
+            Grade(content_len=10, tokens=8),
+            Grade(content_len=50, tokens=16),
+            Grade(content_len=250, tokens=32),
+            Grade(content_len=1250, tokens=64),
+            Grade(content_len=6250, tokens=128),
+            Grade(content_len=31250, tokens=256),
+            Grade(content_len=156250, tokens=512),
+            Grade(content_len=781250, tokens=1024),
+        ]
+        tokens = TokenTool.get_tokens(content)
+        if tokens >= grades[-1].content_len:
+            return grades[-1].tokens
+        index = 0
+        for i in range(len(grades)-1):
+            if grades[i].content_len <= tokens < grades[i+1].content_len:
+                index = i
+                break
+        leave_tokens = grades[index].tokens+(grades[index+1].tokens-grades[index].tokens)*(
+            tokens-grades[index].content_len)/(grades[index+1].content_len-grades[index].content_len)
+        return int(leave_tokens)
+
+    @staticmethod
+    def get_leave_setences_from_content_len(content: str) -> int:
+        """
+        根据内容长度获取留存的句子数量
+        """
+        grades = [
+            Grade(content_len=0, tokens=0),
+            Grade(content_len=10, tokens=4),
+            Grade(content_len=50, tokens=8),
+            Grade(content_len=250, tokens=16),
+            Grade(content_len=1250, tokens=32),
+            Grade(content_len=6250, tokens=64),
+            Grade(content_len=31250, tokens=128),
+            Grade(content_len=156250, tokens=256),
+            Grade(content_len=781250, tokens=512),
+        ]
+        sentences = TokenTool.content_to_sentences(content)
+        if len(sentences) >= grades[-1].content_len:
+            return grades[-1].tokens
+        index = 0
+        for i in range(len(grades)-1):
+            if grades[i].content_len <= len(sentences) < grades[i+1].content_len:
+                index = i
+                break
+        leave_sentences = grades[index].tokens+(grades[index+1].tokens-grades[index].tokens)*(
+            len(sentences)-grades[index].content_len)/(grades[index+1].content_len-grades[index].content_len)
+        return int(leave_sentences)
+
+    @staticmethod
+    def get_tokens(content: str) -> int:
+        try:
+            enc = tiktoken.encoding_for_model("gpt-4")
+            return len(enc.encode(str(content)))
+        except Exception as e:
+            err = f"[TokenTool] 获取token失败 {e}"
+            
logging.exception("[TokenTool] %s", err) + return 0 + + @staticmethod + def get_k_tokens_words_from_content(content: str, k: int = 16) -> list: + try: + if (TokenTool.get_tokens(content) <= k): + return content + l = 0 + r = len(content) + while l+1 < r: + mid = (l+r)//2 + if (TokenTool.get_tokens(content[:mid]) <= k): + l = mid + else: + r = mid + return content[:l] + except Exception as e: + err = f"[TokenTool] 获取k个token的词失败 {e}" + logging.exception("[TokenTool] %s", err) + return "" + + @staticmethod + def split_str_with_slide_window(content: str, slide_window_size: int) -> list: + """ + 将字符串按滑动窗口切割 + """ + result = [] + try: + while len(content) > 0: + sub_content = TokenTool.get_k_tokens_words_from_content(content, slide_window_size) + result.append(sub_content) + content = content[len(sub_content):] + return result + except Exception as e: + err = f"[TokenTool] 滑动窗口切割失败 {e}" + logging.exception("[TokenTool] %s", err) + return [] + + @staticmethod + def split_words(content: str) -> list: + try: + return list(jieba.cut(str(content))) + except Exception as e: + err = f"[TokenTool] 分词失败 {e}" + logging.exception("[TokenTool] %s", err) + return [] + + @staticmethod + def get_top_k_keywords(content: str, k=10) -> list: + try: + # 使用jieba提取关键词 + keywords = extract_tags(content, topK=k, withWeight=True) + return [keyword for keyword, weight in keywords] + except Exception as e: + err = f"[TokenTool] 获取关键词失败 {e}" + logging.exception("[TokenTool] %s", err) + return [] + + @staticmethod + def compress_tokens(content: str, k: int = None) -> str: + try: + words = TokenTool.split_words(content) + # 过滤掉停用词 + filtered_words = [ + word for word in words if word not in TokenTool.stopwords + ] + filtered_content = ''.join(filtered_words) + if k is not None: + # 如果k不为None,则获取k个token的词 + filtered_content = TokenTool.get_k_tokens_words_from_content(filtered_content, k) + return filtered_content + except Exception as e: + err = f"[TokenTool] 压缩token失败 {e}" + logging.exception("[TokenTool] %s", err) + return content + + @staticmethod + def content_to_sentences(content: str) -> list: + """ + 基于特殊符号例如句号 感叹号等将段落分割为句子 + """ + # 常见缩写或不应切断的结构 + protected_phrases = [ + 'e.g.', 'i.e.', 'U.S.', 'U.K.', 'A.M.', 'P.M.', 'a.m.', 'p.m.', 'FY25Q2', + 'KPI', 'CI/CD', 'A/B test', 'PRD', 'PDF', 'API', 'OMG', 'TBD', 'EOM', + 'Inc.', 'Ltd.', 'No.', 'vs.', 'approx.', 'Dr.', 'Mr.', 'Ms.', 'Prof.', + ] + + # 替换为唯一占位符 + placeholder_map = {} + for phrase in protected_phrases: + placeholder = f"__PROTECTED_{uuid.uuid4().hex}__" + placeholder_map[placeholder] = phrase + content = content.replace(phrase, placeholder) + + # 分句正则模式 + pattern = pattern = re.compile( + r'(?<=[。!?!?;;][”’"\')】】》〕〉)\]])' # 标点+右引号/右括号后切 + r'|(?<=[。!?!?;;])(?=[^”’"\')】》〕〉)\]])' # 单个标点后,未跟右引号也可切 + r'|(?<=[\.\?!;])(?=\s|$)' # 英文标点后空格或结尾 + ) + + # 分割并还原 + sentences = [] + for segment in pattern.split(content): + segment = segment.strip() + if not segment: + continue + for placeholder, original in placeholder_map.items(): + segment = segment.replace(placeholder, original) + sentences.append(segment) + + return sentences + + @staticmethod + def get_top_k_keysentence(content: str, k: int = None) -> list: + """ + 获取前k个关键句子 + """ + if k is None: + k = TokenTool.get_leave_setences_from_content_len(content) + leave_tokens = TokenTool.get_leave_tokens_from_content_len(content) + words = TokenTool.split_words(content) + # 过滤掉停用词 + filtered_words = [ + word for word in words if word not in TokenTool.stopwords + ] + keywords = 
TokenTool.get_top_k_keywords(''.join(filtered_words), leave_tokens)
+        keywords = set(keywords)
+        sentences = TokenTool.content_to_sentences(content)
+        sentence_and_score_list = []
+        index = 0
+        for sentence in sentences:
+            score = 0
+            words = TokenTool.split_words(sentence)
+            for word in words:
+                if word in keywords:
+                    score += 1
+            sentence_and_score_list.append((index, sentence, score))
+            index += 1
+        sentence_and_score_list.sort(key=lambda x: x[2], reverse=True)
+        top_k_sentence_and_score_list = sentence_and_score_list[:k]
+        top_k_sentence_and_score_list.sort(key=lambda x: x[0])
+        return [sentence for index, sentence, score in top_k_sentence_and_score_list]
+
+    @staticmethod
+    async def get_abstract_by_llm(content: str, llm: LLM) -> str:
+        """
+        使用llm进行内容摘要
+        """
+        try:
+            with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
+                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+                prompt_template = prompt_dict.get('CONTENT_TO_ABSTRACT_PROMPT', '')
+            sentences = TokenTool.split_str_with_slide_window(content, llm.max_tokens//3*2)
+            abstract = ''
+            for sentence in sentences:
+                abstract = TokenTool.get_k_tokens_words_from_content(abstract, llm.max_tokens//3)
+                sys_call = prompt_template.format(content=sentence, abstract=abstract)
+                user_call = '请结合文本和摘要输出新的摘要'
+                abstract = await llm.nostream([], sys_call, user_call)
+            return abstract
+        except Exception as e:
+            err = f"[TokenTool] 获取摘要失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+
+    @staticmethod
+    async def get_title_by_llm(content: str, llm: LLM) -> str:
+        """
+        使用llm进行标题生成
+        """
+        try:
+            with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
+                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+                prompt_template = prompt_dict.get('CONTENT_TO_TITLE_PROMPT', '')
+            content = TokenTool.get_k_tokens_words_from_content(content, llm.max_tokens)
+            sys_call = prompt_template.format(content=content)
+            user_call = '请结合文本输出标题'
+            title = await llm.nostream([], sys_call, user_call)
+            return title
+        except Exception as e:
+            err = f"[TokenTool] 获取标题失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+
+    @staticmethod
+    async def cal_recall(answer_1: str, answer_2: str, llm: LLM) -> float:
+        """
+        计算recall
+        参数:
+        answer_1:答案1
+        answer_2:答案2
+        llm:大模型
+        """
+        try:
+            with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
+                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+                prompt_template = prompt_dict.get('ANSWER_TO_ANSWER_PROMPT', '')
+            answer_1 = TokenTool.get_k_tokens_words_from_content(answer_1, llm.max_tokens//2)
+            answer_2 = TokenTool.get_k_tokens_words_from_content(answer_2, llm.max_tokens//2)
+            prompt = prompt_template.format(text_1=answer_1, text_2=answer_2)
+            sys_call = prompt
+            user_call = '请输出相似度'
+            similarity = await llm.nostream([], sys_call, user_call)
+            return eval(similarity)
+        except Exception as e:
+            err = f"[TokenTool] 计算recall失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+            return -1
+
+    @staticmethod
+    async def cal_precision(question: str, content: str, llm: LLM) -> float:
+        """
+        计算precision
+        参数:
+        question:问题
+        content:内容
+        """
+        try:
+            with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f:
+                prompt_dict = yaml.load(f, Loader=yaml.SafeLoader)
+                prompt_template = prompt_dict.get('CONTENT_TO_STATEMENTS_PROMPT', '')
+            content = TokenTool.compress_tokens(content, llm.max_tokens)
+            sys_call = prompt_template.format(content=content)
+            user_call = '请结合文本输出陈述列表'
+            statements = await llm.nostream([], sys_call, user_call)
+            statements = json.loads(statements)
+            if len(statements) == 0:
+                return 0
+            score 
= 0 + prompt_template = prompt_dict.get('STATEMENTS_TO_QUESTION_PROMPT', '') + for statement in statements: + statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens) + prompt = prompt_template.format(statement=statement, question=question) + sys_call = prompt + user_call = '请结合文本输出YES或NO' + yn = await llm.nostream([], sys_call, user_call) + yn = yn.lower() + if yn == 'yes': + score += 1 + return score/len(statements)*100 + except Exception as e: + err = f"[TokenTool] 计算precision失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 + + @staticmethod + async def cal_faithfulness(question: str, answer: str, content: str, llm: LLM) -> float: + """ + 计算faithfulness + 参数: + question:问题 + answer:答案 + """ + try: + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + prompt_template = prompt_dict.get('QA_TO_STATEMENTS_PROMPT', '') + question = TokenTool.get_k_tokens_words_from_content(question, llm.max_tokens//8) + answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//8*7) + prompt = prompt_template.format(question=question, answer=answer) + sys_call = prompt + user_call = '请结合问题和答案输出陈诉' + statements = await llm.nostream([], sys_call, user_call) + prompt_template = prompt_dict.get('STATEMENTS_TO_FRAGMENT_PROMPT', '') + statements = json.loads(statements) + if len(statements) == 0: + return 0 + score = 0 + content = TokenTool.compress_tokens(content, llm.max_tokens//8*7) + for statement in statements: + statement = TokenTool.get_k_tokens_words_from_content(statement, llm.max_tokens//8) + prompt = prompt_template.format(statement=statement, fragment=content) + sys_call = prompt + user_call = '请输出YES或NO' + user_call = user_call + yn = await llm.nostream([], sys_call, user_call) + yn = yn.lower() + if yn == 'yes': + score += 1 + return score/len(statements)*100 + except Exception as e: + err = f"[TokenTool] 计算faithfulness失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 + + @staticmethod + def cosine_distance_numpy(vector1, vector2): + # 计算向量的点积 + dot_product = np.dot(vector1, vector2) + # 计算向量的 L2 范数 + norm_vector1 = np.linalg.norm(vector1) + norm_vector2 = np.linalg.norm(vector2) + # 计算余弦相似度 + cosine_similarity = dot_product / (norm_vector1 * norm_vector2) + # 计算余弦距离 + cosine_dist = 1 - cosine_similarity + return cosine_dist + + @staticmethod + async def cal_relevance(question: str, answer: str, llm: LLM) -> float: + """ + 计算relevance + 参数: + question:问题 + answer:答案 + """ + try: + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + prompt_template = prompt_dict.get('GENREATE_QUESTION_FROM_CONTENT_PROMPT', '') + answer = TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens) + sys_call = prompt_template.format(k=5, content=answer) + user_call = '请结合文本输出问题列表' + question_vector = await Embedding.vectorize_embedding(question) + qs = await llm.nostream([], sys_call, user_call) + qs = json.loads(qs) + if len(qs) == 0: + return 0 + score = 0 + for q in qs: + q_vector = await Embedding.vectorize_embedding(q) + score += TokenTool.cosine_distance_numpy(question_vector, q_vector) + return (score/len(qs)+1)/2*100 + except Exception as e: + err = f"[TokenTool] 计算relevance失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 + + @staticmethod + def cal_lcs(str1: str, str2: str) -> float: + """ + 计算两个字符串的最长公共子序列长度得分 + """ + try: + words1 = TokenTool.split_words(str1) + words2 = 
TokenTool.split_words(str2) + new_words1 = [] + new_words2 = [] + for word in words1: + if word not in TokenTool.stopwords: + new_words1.append(word) + for word in words2: + if word not in TokenTool.stopwords: + new_words2.append(word) + if len(new_words1) == 0 and len(new_words2) == 0: + return 100 + if len(new_words1) == 0 or len(new_words2) == 0: + return 0 + m = len(new_words1) + n = len(new_words2) + dp = np.zeros((m+1, n+1)) + for i in range(1, m+1): + for j in range(1, n+1): + if new_words1[i-1] == new_words2[j-1]: + dp[i][j] = dp[i-1][j-1] + 1 + else: + dp[i][j] = max(dp[i-1][j], dp[i][j-1]) + lcs_length = dp[m][n] + score = lcs_length / min(len(new_words1), len(new_words2)) * 100 + return score + except Exception as e: + err = f"[TokenTool] 计算lcs失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 + + @staticmethod + def cal_leve(str1: str, str2: str) -> float: + """ + 计算两个字符串的编辑距离 + """ + try: + words1 = TokenTool.split_words(str1) + words2 = TokenTool.split_words(str2) + new_words1 = [] + new_words2 = [] + for word in words1: + if word not in TokenTool.stopwords: + new_words1.append(word) + for word in words2: + if word not in TokenTool.stopwords: + new_words2.append(word) + if len(new_words1) == 0 and len(new_words2) == 0: + return 100 + if len(new_words1) == 0 or len(new_words2) == 0: + return 0 + m = len(new_words1) + n = len(new_words2) + dp = np.zeros((m+1, n+1)) + for i in range(m+1): + dp[i][0] = i + for j in range(n+1): + dp[0][j] = j + for i in range(1, m+1): + for j in range(1, n+1): + if new_words1[i-1] == new_words2[j-1]: + dp[i][j] = dp[i-1][j-1] + else: + dp[i][j] = min(dp[i-1][j]+1, dp[i][j-1]+1, dp[i-1][j-1]+1) + edit_distance = dp[m][n] + score = (1 - edit_distance / max(len(new_words1), len(new_words2))) * 100 + return score + except Exception as e: + err = f"[TokenTool] 计算leve失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 + + @staticmethod + def cal_jac(str1: str, str2: str) -> float: + """ + 计算两个字符串的Jaccard相似度 + """ + try: + if len(str1) == 0 and len(str2) == 0: + return 100 + words1 = TokenTool.split_words(str1) + words2 = TokenTool.split_words(str2) + new_words1 = [] + new_words2 = [] + for word in words1: + if word not in TokenTool.stopwords: + new_words1.append(word) + for word in words2: + if word not in TokenTool.stopwords: + new_words2.append(word) + if len(new_words1) == 0 or len(new_words2) == 0: + return 0 + set1 = set(new_words1) + set2 = set(new_words2) + intersection = len(set1.intersection(set2)) + union = len(set1.union(set2)) + score = intersection / union * 100 + return score + except Exception as e: + err = f"[TokenTool] 计算jac失败 {e}" + logging.exception("[TokenTool] %s", err) + return -1 \ No newline at end of file
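
A minimal sketch of how the new `get_route_info` helper could be wired in as a FastAPI dependency; the app and endpoint below are illustrative only and not part of this change:

```python
from fastapi import Depends, FastAPI
from data_chain.apps.service.router_service import get_route_info

app = FastAPI()


@app.get("/teams/{team_id}")
async def read_team(team_id: str, route_info: str = Depends(get_route_info)):
    # After routing, request.scope["route"] is an APIRoute, so route_info
    # resolves to something like "GET /teams/{team_id}", which can serve as
    # a stable action key for permission checks.
    return {"route": route_info, "team_id": team_id}
```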
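Likewise, a small usage sketch of the new `TokenTool` text utilities, assuming the repository config provides `STOP_WORDS_PATH` and the tiktoken/jieba resources are available; the sample text is made up:

```python
from data_chain.parser.tools.token_tool import TokenTool

text = "检索增强生成会先对文档分词、去停用词,再按 token 预算切片后送入大模型。"

print(TokenTool.get_tokens(text))                       # token count via tiktoken
print(TokenTool.split_str_with_slide_window(text, 16))  # chunks of at most 16 tokens each
print(TokenTool.get_top_k_keywords(text, k=3))          # jieba TF-IDF keywords
```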