diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..2dc0be7be245949f71f62fa15409d0401957ecca 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -0,0 +1,318 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +import uuid +import os +import shutil +import yaml +import json +import random +from pydantic import BaseModel, Field +from data_chain.apps.base.zip_handler import ZipHandler +from data_chain.config.config import config +from data_chain.logger.logger import logger as logging +from data_chain.apps.base.task.worker.base_worker import BaseWorker +from data_chain.llm.llm import LLM +from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus, DataSetStatus, QAStatus +from data_chain.entities.common import DEFAULt_DOC_TYPE_ID +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.manager.task_manager import TaskManager +from data_chain.manager.document_manager import DocumentManager +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.manager.dataset_manager import DatasetManager +from data_chain.manager.qa_manager import QAManager +from data_chain.manager.task_queue_mamanger import TaskQueueManager +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, QAEntity, DataSetEntity, DataSetDocEntity +from data_chain.stores.minio.minio import MinIO +from data_chain.stores.mongodb.mongodb import Task + + +class Chunk(BaseModel): + text: str + type: str + + +class DocChunk(BaseModel): + doc_id: uuid.UUID + doc_name: str + chunks: list[Chunk] + + +class GenerateDataSetWorker(BaseWorker): + """ + GenerateDataSetWorker + """ + name = TaskType.DATASET_GENERATE.value + + @staticmethod + async def init(dataset_id: uuid.UUID) -> uuid.UUID: + '''初始化任务''' + dataset_entity = await DatasetManager.get_dataset_by_dataset_id(dataset_id) + if dataset_entity is None: + err = f"[GenerateDataSetWorker] 数据集不存在,数据集ID: {dataset_id}" + logging.exception(err) + return None + await DatasetManager.update_dataset_by_dataset_id(dataset_id, {"score": -1}) + dataset_entity = await DatasetManager.update_dataset_by_dataset_id(dataset_id, {"status": DataSetStatus.PENDING.value}) + await QAManager.update_qa_by_dataset_id(dataset_id, {"status": QAStatus.DELETED.value}) + task_entity = TaskEntity( + team_id=dataset_entity.team_id, + user_id=dataset_entity.author_id, + op_id=dataset_entity.id, + op_name=dataset_entity.name, + type=TaskType.DATASET_GENERATE.value, + retry=0, + status=TaskStatus.PENDING.value) + task_entity = await TaskManager.add_task(task_entity) + return task_entity.id + + @staticmethod + async def reinit(task_id: uuid.UUID) -> bool: + '''重新初始化任务''' + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[GenerateDataSetWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return False + await DatasetManager.update_dataset_by_dataset_id(task_entity.op_id, {"score": -1}) + await QAManager.update_qa_by_dataset_id(task_entity.op_id, {"status": QAStatus.DELETED.value}) + if task_entity.retry < config['TASK_RETRY_TIME_LIMIT']: + await DatasetManager.update_dataset_by_dataset_id(task_entity.op_id, {"status": DataSetStatus.PENDING.value}) + return True + else: + await 
DatasetManager.update_dataset_by_dataset_id(task_entity.op_id, {"status": DataSetStatus.IDLE.value}) + return False + + @staticmethod + async def deinit(task_id: uuid.UUID) -> uuid.UUID: + '''析构任务''' + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[GenerateDataSetWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + await DatasetManager.update_dataset_by_dataset_id(task_entity.op_id, {"status": DataSetStatus.IDLE.value}) + return task_id + + @staticmethod + async def get_chunks(dataset_entity: DataSetEntity) -> list[DocChunk]: + '''获取文档的分块信息''' + dataset_doc_entities = await DatasetManager.list_dataset_document_by_dataset_id(dataset_entity.id) + doc_chunks = [] + for dataset_doc_entity in dataset_doc_entities: + doc_entity = await DocumentManager.get_document_by_doc_id(dataset_doc_entity.doc_id) + chunk_entities = await ChunkManager.list_all_chunk_by_doc_id(dataset_doc_entity.doc_id) + chunks = [] + for chunk_entity in chunk_entities: + chunks.append(Chunk( + text=chunk_entity.text, + type=chunk_entity.type + )) + doc_chunk = DocChunk( + doc_id=doc_entity.id, + doc_name=doc_entity.name, + chunks=[] + ) + doc_chunk.chunks = chunks + doc_chunks.append(doc_chunk) + return doc_chunks + + @staticmethod + async def generate_qa(dataset_entity: DataSetEntity, doc_chunks: list[DocChunk], llm: LLM) -> list[QAEntity]: + chunk_cnt = 0 + for doc_chunk in doc_chunks: + chunk_cnt += len(doc_chunk.chunks) + if chunk_cnt == 0: + return [] + qa_entities = [] + data_cnt = dataset_entity.data_cnt + division = data_cnt // chunk_cnt + remainder = data_cnt % chunk_cnt + index = 0 + random.shuffle(doc_chunks) + with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: + prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) + q_generate_prompt_template = prompt_dict.get('GENREATE_QUESTION_FROM_CONTENT_PROMPT', '') + answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', '') + cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '') + dataset_score = 0 + for i in range(len(doc_chunks)): + doc_chunk = doc_chunks[i] + for j in range(len(doc_chunk.chunks)): + chunk = doc_chunk.chunks[j].text + if dataset_entity.is_chunk_related: + l = i-1 + r = i+1 + tokens_sub = 0 + while TokenTool.get_tokens(chunk) < llm.max_tokens: + if l < 0 and r >= len(doc_chunks): + break + if tokens_sub > 0: + if l >= 0: + tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[l].text) + chunk += doc_chunks[i].chunks[l].text + l -= 1 + else: + tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[r].text) + chunk += doc_chunks[i].chunks[r].text + r += 1 + else: + if r < len(doc_chunks): + tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[r].text) + chunk += doc_chunks[i].chunks[r].text + r += 1 + else: + tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[l].text) + chunk += doc_chunks[i].chunks[l].text + l -= 1 + if i > 0: + chunk = doc_chunk.chunks[i-1].text + chunk + if i < len(doc_chunk.chunks) - 1: + chunk = chunk + doc_chunk.chunks[i+1].text + qa_cnt = random.randint(0, 2*(division+(index <= remainder))) + if i == len(doc_chunks)-1 and j == len(doc_chunk.chunks)-1: + qa_cnt = data_cnt + qa_cnt = min(qa_cnt, data_cnt) + if qa_cnt == 0: + continue + data_cnt -= qa_cnt + data_cnt = max(data_cnt, 0) + qs = [] + answers = [] + rd = 5 + while len(qs) < qa_cnt and rd > 0: + rd -= 1 + try: + sys_call = q_generate_prompt_template.format( + k=qa_cnt, + 
content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens) + ) + usr_call = '请输出问题的列表' + sub_qs = await llm.nostream([], sys_call, usr_call) + sub_qs = json.loads(sub_qs) + except Exception as e: + err = f"[GenerateDataSetWorker] 生成问题失败,错误信息: {e}" + logging.exception(err) + continue + sub_answers = [] + try: + for q in sub_qs: + sys_call = answer_generate_prompt_template.format( + content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens//8*7), + question=TokenTool.get_k_tokens_words_from_content(q, llm.max_tokens//8) + ) + usr_call = '请输出答案' + sub_answer = await llm.nostream([], sys_call, usr_call) + sub_answers.append(sub_answer) + except Exception as e: + err = f"[GenerateDataSetWorker] 生成答案失败,错误信息: {e}" + logging.exception(err) + continue + for q, answer in zip(sub_qs, sub_answers): + if len(qa_entities) + len(qs) >= dataset_entity.data_cnt: + break + try: + if dataset_entity.is_data_cleared: + sys_call = cal_qa_score_prompt_template.format( + fragment=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens//9*4), + question=TokenTool.get_k_tokens_words_from_content(q, llm.max_tokens//9), + answer=TokenTool.get_k_tokens_words_from_content(answer, llm.max_tokens//9*4) + ) + usr_call = '请输出分数' + score = await llm.nostream([], sys_call, usr_call) + score = eval(score) + score = max(0, min(100, score)) + else: + score = 100 + if score > 60: + qs.append(q) + answers.append(answer) + dataset_score += score + except Exception as e: + err = f"[GenerateDataSetWorker] 计算分数失败,错误信息: {e}" + logging.exception(err) + continue + for q, ans in zip(qs, answers): + qa_entity = QAEntity( + dataset_id=dataset_entity.id, + doc_id=doc_chunk.doc_id, + doc_name=doc_chunk.doc_name, + question=q, + answer=ans, + chunk=chunk, + chunk_type=doc_chunk.chunks[i].type) + qa_entities.append(qa_entity) + if len(qa_entities) >= dataset_entity.data_cnt: + break + index += 1 + if len(qa_entities) >= dataset_entity.data_cnt: + break + if len(qa_entities) > 0: + dataset_score = dataset_score / len(qa_entities) + await DatasetManager.update_dataset_by_dataset_id( + dataset_entity.id, {'score': dataset_score}) + return qa_entities + + @staticmethod + async def add_qa_to_db(qa_entities: list[QAEntity]) -> None: + '''添加QA到数据库''' + index = 0 + while index < len(qa_entities): + try: + await QAManager.add_qas(qa_entities[index:index+1024]) + except Exception as e: + err = f"[GenerateDataSetWorker] 添加QA到数据库失败,错误信息: {e}" + logging.exception(err) + index += 1024 + + @staticmethod + async def run(task_id: uuid.UUID) -> None: + '''运行任务''' + try: + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[GenerateDataSetWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + raise Exception(err) + llm = LLM( + openai_api_key=config['OPENAI_API_KEY'], + openai_api_base=config['OPENAI_API_BASE'], + model_name=config['MODEL_NAME'], + max_tokens=config['MAX_TOKENS'], + ) + dataset_entity = await DatasetManager.get_dataset_by_dataset_id(task_entity.op_id) + if dataset_entity is None: + err = f"[GenerateDataSetWorker] 数据集不存在,数据集ID: {task_entity.op_id}" + logging.exception(err) + raise Exception(err) + await DatasetManager.update_dataset_by_dataset_id(dataset_entity.id, {"status": DataSetStatus.GENERATING.value}) + current_stage = 0 + stage_cnt = 3 + doc_chunks = await GenerateDataSetWorker.get_chunks(dataset_entity) + current_stage += 1 + await GenerateDataSetWorker.report(task_id, "获取文档分块信息", current_stage, stage_cnt) + qa_entities = await 
GenerateDataSetWorker.generate_qa( + dataset_entity, doc_chunks, llm) + current_stage += 1 + await GenerateDataSetWorker.report(task_id, "生成QA", current_stage, stage_cnt) + await GenerateDataSetWorker.add_qa_to_db(qa_entities) + current_stage += 1 + await GenerateDataSetWorker.report(task_id, "添加QA到数据库", current_stage, stage_cnt) + await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + except Exception as e: + err = f"[GenerateDataSetWorker] 任务失败,task_id: {task_id},错误信息: {e}" + logging.exception(err) + await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await GenerateDataSetWorker.report(task_id, err, 0, 1) + + @staticmethod + async def stop(task_id: uuid.UUID) -> uuid.UUID: + '''停止任务''' + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + await DatasetManager.update_dataset_by_dataset_id(task_entity.op_id, {"status": DataSetStatus.IDLE.value}) + if task_entity.status == TaskStatus.PENDING.value or task_entity.status == TaskStatus.RUNNING.value or task_entity.status == TaskStatus.FAILED.value: + await DatasetManager.update_dataset_by_dataset_id(task_entity.op_id, {"score": -1}) + await QAManager.update_qa_by_dataset_id(task_entity.op_id, {"status": QAStatus.DELETED.value}) + return task_id diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 98f41601a7df9f4b9294615eb3819ac45d3ffe9b..e6cad5b9a8c83baf09a224048e5d813dc4a3bc19 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -1,120 +1,578 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
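A note before the second file: the quota logic inside GenerateDataSetWorker.generate_qa above (the division/remainder split plus a random draw per chunk) is compact but easy to misread. Below is a minimal, self-contained sketch of that distribution scheme using the same names (data_cnt, chunk_cnt, division, remainder); plan_quota is a hypothetical helper introduced only for illustration, and the real worker additionally runs each drawn quota through LLM question/answer generation and scoring before anything is persisted.

```python
# Standalone sketch (not part of this PR) of how generate_qa spreads the requested
# data_cnt QA pairs across chunk_cnt chunks: each chunk draws a random count around
# the per-chunk average, the last chunk absorbs the leftovers, and the running total
# is clamped so it never exceeds the target.
import random


def plan_quota(data_cnt: int, chunk_cnt: int) -> list[int]:
    division, remainder = divmod(data_cnt, chunk_cnt)
    remaining = data_cnt
    quotas = []
    for index in range(chunk_cnt):
        # mirrors qa_cnt = random.randint(0, 2 * (division + (index <= remainder)))
        qa_cnt = random.randint(0, 2 * (division + (index <= remainder)))
        if index == chunk_cnt - 1:
            qa_cnt = remaining           # last chunk takes whatever is still missing
        qa_cnt = min(qa_cnt, remaining)  # never overshoot the requested total
        quotas.append(qa_cnt)
        remaining -= qa_cnt
    return quotas


if __name__ == "__main__":
    quotas = plan_quota(data_cnt=50, chunk_cnt=7)
    print(quotas, sum(quotas))  # the quotas always sum to exactly data_cnt
```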
+from typing import Any import uuid import os import shutil import yaml +import random +import io +import numpy as np +from PIL import Image +from data_chain.parser.tools.ocr_tool import OcrTool +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.parser.tools.image_tool import ImageTool +from data_chain.parser.handler.base_parser import BaseParser from data_chain.apps.base.zip_handler import ZipHandler +from data_chain.parser.parse_result import ParseNode, ParseResult +from data_chain.llm.llm import LLM +from data_chain.embedding.embedding import Embedding from data_chain.config.config import config from data_chain.logger.logger import logger as logging from data_chain.apps.base.task.worker.base_worker import BaseWorker -from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, IMPORT_KB_PATH_IN_OS, DOC_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO +from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, ParseMethod, DocumentStatus, ChunkStatus, ImageStatus, DocParseRelutTopology, ChunkParseTopology, ChunkType +from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, DOC_PATH_IN_MINIO, DOC_PATH_IN_OS, IMAGE_PATH_IN_MINIO from data_chain.manager.task_manager import TaskManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager from data_chain.manager.document_manager import DocumentManager +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.manager.image_manager import ImageManager from data_chain.manager.task_queue_mamanger import TaskQueueManager -from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity, ChunkEntity, ImageEntity from data_chain.stores.minio.minio import MinIO from data_chain.stores.mongodb.mongodb import Task class ParseDocumentWorker(BaseWorker): - name = TaskType.DOC_PARSE + name = TaskType.DOC_PARSE.value @staticmethod async def init(doc_id: uuid.UUID) -> uuid.UUID: '''初始化任务''' - pass + doc_entity = await DocumentManager.get_document_by_doc_id(doc_id) + if doc_entity is None: + err = f"[ParseDocumentWorker] 文档不存在,doc_id: {doc_id}" + logging.exception(err) + return None + await DocumentManager.update_document_by_doc_id(doc_id, {"status": DocumentStatus.PENDING.value, "abstract": "", "abstract_vector": None}) + await ImageManager.update_images_by_doc_id(doc_id, {"status": ImageStatus.DELETED.value}) + await ChunkManager.update_chunk_by_doc_id(doc_id, {"status": ChunkStatus.DELETED.value}) + task_entity = TaskEntity( + team_id=doc_entity.team_id, + user_id=doc_entity.author_id, + op_id=doc_entity.id, + op_name=doc_entity.name, + type=TaskType.DOC_PARSE.value, + retry=0, + status=TaskStatus.PENDING.value) + task_entity = await TaskManager.add_task(task_entity) + return task_entity.id @staticmethod async def reinit(task_id: uuid.UUID) -> bool: '''重新初始化任务''' - pass + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[ParseDocumentWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return False + doc_id = task_entity.op_id + await DocumentManager.update_document_by_doc_id(doc_id, {"abstract": "", "abstract_vector": None}) + await ImageManager.update_images_by_doc_id(doc_id, {"status": ImageStatus.DELETED.value}) + await 
ChunkManager.update_chunk_by_doc_id(doc_id, {"status": ChunkStatus.DELETED.value}) + tmp_path = os.path.join(DOC_PATH_IN_OS, str(task_id)) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + if task_entity.retry < config['TASK_RETRY_TIME_LIMIT']: + await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"status": DocumentStatus.PENDING.value}) + return True + else: + await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"status": DocumentStatus.IDLE.value}) + return False @staticmethod async def deinit(task_id: uuid.UUID) -> uuid.UUID: '''析构任务''' - pass + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[ParseDocumentWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"status": DocumentStatus.IDLE.value}) + tmp_path = os.path.join(DOC_PATH_IN_OS, str(task_id)) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + return task_id @staticmethod async def init_path(task_id: uuid.UUID) -> tuple: '''初始化存放配置文件和文档的路径''' - pass + tmp_path = os.path.join(DOC_PATH_IN_OS, str(task_id)) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + os.makedirs(tmp_path) + image_path = os.path.join(tmp_path, "images") + os.makedirs(image_path) + return tmp_path, image_path @staticmethod - async def download_doc_from_minio(doc_id: uuid.UUID) -> str: + async def download_doc_from_minio(doc_id: uuid.UUID, tmp_path: str) -> str: '''下载文档''' - pass + doc_entity = await DocumentManager.get_document_by_doc_id(doc_id) + if doc_entity is None: + err = f"[ParseDocumentWorker] 文档不存在,doc_id: {doc_id}" + logging.exception(err) + raise Exception(err) + file_path = os.path.join(tmp_path, str(doc_id)+'.'+doc_entity.extension) + await MinIO.download_object( + DOC_PATH_IN_MINIO, + str(doc_entity.id), + file_path, + ) + return file_path @staticmethod - async def parse_doc(download_path: str): + async def parse_doc(doc_entity: DocumentEntity, file_path: str) -> ParseResult: '''解析文档''' - pass + parse_result = await BaseParser.parser(doc_entity.extension, file_path) + return parse_result @staticmethod - async def get_doc_abstrace(): - '''获取文档摘要''' - pass + async def get_content_from_json(js: Any) -> str: + '''获取json内容''' + if isinstance(js, dict): + content = '' + for key, value in js.items(): + content += str(key) + ': ' + if isinstance(value, (dict, list)): + content += await ParseDocumentWorker.get_content_from_json(value) + else: + content += str(value) + '\n' + return content + elif isinstance(js, list): + content = '' + for item in js: + content += await ParseDocumentWorker.get_content_from_json(item) + content += ' ' + content += '\n' + return content + else: + return str(js) @staticmethod - async def ocr_from_parse_image(): - '''从解析图片中获取ocr''' - pass + async def handle_parse_result(parse_result: ParseResult, doc_entity: DocumentEntity, llm: LLM = None) -> None: + '''处理解析结果''' + if doc_entity.parse_method != ParseMethod.OCR.value and doc_entity.parse_method != ParseMethod.EHANCED: + nodes = [] + for node in parse_result.nodes: + if node.type != ChunkType.IMAGE: + nodes.append(node) + parse_result.nodes = nodes + if doc_entity.parse_method == ParseMethod.QA: + if doc_entity.extension == 'xlsx' or doc_entity.extension == 'csv': + for node in parse_result.nodes: + node.type = ChunkType.QA + try: + question = '' + if len(node.content) > 0: + question = str(node.content[0]) + answer = '' + if len(node.content) > 1: + answer = str(node.content[1]) + except Exception as 
e: + question = '' + answer = '' + warning = f"[ParseDocumentWorker] 解析问题和答案失败,doc_id: {doc_entity.id}, error: {e}" + logging.warning(warning) + node.text_feature = question + node.content = 'question: ' + question + '\n' + 'answer: ' + answer + elif doc_entity.extension == 'json' or doc_entity.extension == 'yaml': + qa_list = parse_result.nodes[0].content + parse_result.nodes = [] + for qa in qa_list: + question = qa.get('question') + answer = qa.get('answer') + if question is None or answer is None: + warning = f"[ParseDocumentWorker] 解析问题和答案失败,doc_id: {doc_entity.id}, error: {e}" + logging.warning(warning) + continue + node = ParseNode( + id=uuid.uuid4(), + lv=0, + parse_topology_type=ChunkParseTopology.GERNERAL, + text_feature=question, + content='question: ' + question + '\n' + 'answer: ' + answer, + type=ChunkType.QA, + link_nodes=[] + ) + parse_result.nodes.append(node) + else: + if doc_entity.extension == 'xlsx' or doc_entity.extension == 'csv': + for node in parse_result.nodes: + node.content = '|'.join(node.content) + node.text_feature = node.content + elif doc_entity.extension == 'json' or doc_entity.extension == 'yaml': + parse_result.nodes[0].content = await ParseDocumentWorker.get_content_from_json(parse_result.nodes[0].content) + parse_result.nodes[0].text_feature = parse_result.nodes[0].content + else: + for node in parse_result.nodes: + if node.type == ChunkType.TEXT or node.type == ChunkType.LINK: + node.text_feature = node.content + elif node.type == ChunkType.CODE: + if llm is not None: + node.text_feature = await TokenTool.get_abstract_by_llm(node.content, llm) + if node.text_feature is None: + node.text_feature = TokenTool.get_top_k_keywords(node.content) + elif node.type == ChunkType.TABLE: + node.content = '|'.join(node.content) + node.text_feature = node.content @staticmethod - async def upload_parse_image_to_minio(): + async def upload_parse_image_to_minio_and_postgres( + parse_result: ParseResult, doc_entity: DocumentEntity, image_path: str) -> None: '''上传解析图片到minio''' - pass + image_entities = [] + for node in parse_result.nodes: + if node.type == ChunkType.IMAGE: + try: + extension = ImageTool.get_image_type(node.content) + image_entity = ImageEntity( + id=uuid.uuid4(), + team_id=doc_entity.team_id, + doc_id=doc_entity.id, + chunk_id=node.id, + extension=extension, + ) + image_entities.append(image_entity) + image_blob = node.content + image = Image.open(io.BytesIO(image_blob)) + image_file_path = os.path.join(image_path, str(node.id) + '.' 
+ extension) + with open(image_file_path, 'wb') as f: + f.write(image_blob) + await MinIO.upload_object( + bucket_name=IMAGE_PATH_IN_MINIO, + object_name=str(node.id), + file_path=image_file_path + ) + except Exception as e: + err = f"[ParseDocumentWorker] 上传解析图片到minio失败,doc_id: {doc_entity.id}, image_path: {image_path}, error: {e}" + logging.exception(err) + continue + index = 0 + while index < len(image_entities): + try: + await ImageManager.add_images(image_entities[index:index+1024]) + except Exception as e: + err = f"[ParseDocumentWorker] 上传解析图片到postgres失败,doc_id: {doc_entity.id}, image_path: {image_path}, error: {e}" + logging.exception(err) + index += 1024 @staticmethod - async def push_down_words_feature(): - '''下推words特征''' - pass + async def ocr_from_parse_image(parse_result: ParseResult, llm: LLM = None) -> list: + '''从解析图片中获取ocr''' + for node in parse_result.nodes: + if node.type == ChunkType.IMAGE: + try: + image_blob = node.content + image = Image.open(io.BytesIO(image_blob)) + img_np = np.array(image) + image_related_text = '' + for related_node in node.link_nodes: + if related_node.type != ChunkType.IMAGE: + image_related_text += related_node.content + node.content = OcrTool.image_to_text(img_np, image_related_text, llm) + node.text_feature = node.content + except Exception as e: + err = f"[ParseDocumentWorker] OCR失败,doc_id: {node.doc_id}, error: {e}" + logging.exception(err) + continue @staticmethod - async def handle_doc_parse_result(): - '''处理解析结果''' - pass + async def merge_and_split_text(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: + '''合并和拆分内容''' + if doc_entity.parse_method == ParseMethod.QA or doc_entity.parse_relut_topology == DocParseRelutTopology.TREE: + return + nodes = [] + for node in parse_result.nodes: + if node.type == ChunkType.TEXT: + tokens = TokenTool.get_tokens(node.content) + if len(nodes) == 0 or ( + len(nodes) + and ( + nodes[-1].type != ChunkType.TEXT or TokenTool.get_tokens(nodes[-1].content) + tokens > doc_entity. 
+ chunk_size)): + nodes.append(node) + else: + nodes[-1].content += node.content + else: + nodes.append(node) + parse_result.nodes = nodes + nodes = [] + tmp = '' + for node in parse_result.nodes: + if node.type == ChunkType.TEXT: + sentences = TokenTool.content_to_sentences(node.content) + new_sentences = [] + for sentence in sentences: + if TokenTool.get_tokens(sentence) > doc_entity.chunk_size: + tmp = sentence + while len(tmp) > 0: + sub_sentence = TokenTool.get_k_tokens_words_from_content(tmp, doc_entity.chunk_size) + new_sentences.append(sub_sentence) + tmp = tmp[len(sub_sentence):] + else: + new_sentences.append(sentence) + sentences = new_sentences + for sentence in sentences: + if len(tmp) == 0: + tmp = sentence + else: + if TokenTool.get_tokens(tmp) + TokenTool.get_tokens(sentence) > doc_entity.chunk_size: + tmp_node = ParseNode( + id=uuid.uuid4(), + lv=node.lv, + parse_topology_type=ChunkParseTopology.GERNERAL, + text_feature=tmp, + content=tmp, + type=ChunkType.TEXT, + link_nodes=[] + ) + nodes.append(tmp_node) + tmp = '' + else: + tmp += sentence + else: + if len(tmp) > 0: + tmp_node = ParseNode( + id=uuid.uuid4(), + lv=node.lv, + parse_topology_type=ChunkParseTopology.GERNERAL, + text_feature=tmp, + content=tmp, + type=ChunkType.TEXT, + link_nodes=[] + ) + nodes.append(tmp_node) + tmp = '' + nodes.append(node) + if len(tmp) > 0: + tmp_node = ParseNode( + id=uuid.uuid4(), + lv=node.lv, + parse_topology_type=ChunkParseTopology.GERNERAL, + text_feature=tmp, + content=tmp, + type=ChunkType.TEXT, + link_nodes=[] + ) + nodes.append(tmp_node) + parse_result.nodes = nodes + + @staticmethod + async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None) -> None: + '''推送上层词特征''' + async def dfs(node: ParseNode, parent_node: ParseNode, llm) -> None: + if parent_node is None: + node.pre_id = parent_node.id + for child_node in node.link_nodes: + await dfs(child_node, node) + if node.title is not None: + if len(node.title) == 0: + if llm is not None: + content = '父标题\n' + if parent_node and parent_node.title: + if len(parent_node.title) > 0: + content += parent_node.title + '\n' + else: + sentences = TokenTool.get_top_k_keysentence(parent_node.content, 1) + if sentences: + content += sentences[0] + '\n' + index = 0 + for node in node.link_nodes: + content += '子标题'+str(index) + '\n' + if node.title: + content += node.title + '\n' + else: + sentences = TokenTool.get_top_k_keysentence(node.content, 1) + if sentences: + content += sentences[0] + '\n' + index += 1 + title = await TokenTool.get_title_by_llm(content, llm) + if not title: + sentences = TokenTool.get_top_k_keysentence(content, 1) + if sentences: + title = sentences[0] + node.text_feature = title + node.content = node.text_feature + else: + if parent_node and parent_node.title: + if len(parent_node.title) > 0: + content += parent_node.title + '\n' + else: + sentences = TokenTool.get_top_k_keysentence(parent_node.content, 1) + if sentences: + content += sentences[0] + '\n' + for node in node.link_nodes: + if node.title: + content += node.title + '\n' + else: + sentences = TokenTool.get_top_k_keysentence(node.content, 1) + if sentences: + content += sentences[0] + '\n' + sentences = TokenTool.get_top_k_keysentence(content, 1) + if sentences: + node.text_feature = sentences[0] + else: + node.text_feature = '' + node.content = node.text_feature + else: + node.text_feature = node.title + node.content = node.text_feature + if parse_result.parse_topology_type == DocParseRelutTopology.TREE: + await dfs(parse_result.nodes, 
None, llm) @staticmethod - async def embedding_chunk(): + async def update_doc_abstract(doc_id: uuid.UUID, parse_result: ParseResult, llm: LLM = None) -> str: + '''获取文档摘要''' + abstract = "" + for node in parse_result.nodes: + abstract += node.content + if llm is not None: + abstract = await TokenTool.get_abstract_by_llm(abstract, llm) + else: + sentences = TokenTool.get_top_k_keysentence(abstract, 1) + if sentences: + abstract = sentences[0] + else: + abstract = '' + abstract_vector = await Embedding.vectorize_embedding(abstract) + await DocumentManager.update_document_by_doc_id( + doc_id, + { + "abstract": abstract, + "abstract_vector": abstract_vector + } + ) + return abstract + + @staticmethod + async def embedding_chunk(parse_result: ParseResult) -> None: '''嵌入chunk''' - pass + for node in parse_result.nodes: + node.vector = await Embedding.vectorize_embedding(node.text_feature) @staticmethod - async def add_parse_result_to_db(): + async def add_parse_result_to_db(parse_result: ParseResult, doc_entity: DocumentEntity) -> None: '''添加解析结果到数据库''' - pass + chunk_entities = [] + global_offset = 0 + local_offset = 0 + for node in parse_result.nodes: + if not node.content: + continue + chunk_entity = ChunkEntity( + id=node.id, + team_id=doc_entity.team_id, + kb_id=doc_entity.kb_id, + doc_id=doc_entity.id, + doc_name=doc_entity.name, + text=node.content, + text_vector=node.vector, + tokens=TokenTool.get_tokens(node.content), + type=node.type, + pre_id_in_parse_topology=node.pre_id, + parse_topology_type=node.parse_topology_type, + global_offset=global_offset, + local_offset=local_offset, + enabled=True, + status=ChunkStatus.EXISTED.value + ) + chunk_entities.append(chunk_entity) + if global_offset and parse_result.nodes[global_offset].type != parse_result.nodes[global_offset-1].type: + local_offset = 0 + local_offset += 1 + global_offset += 1 + index = 0 + while index < len(chunk_entities): + try: + await ChunkManager.add_chunks(chunk_entities[index:index+1024]) + except Exception as e: + err = f"[ParseDocumentWorker] 添加解析结果到数据库失败,doc_id: {doc_entity.id}, error: {e}" + logging.exception(err) + index += 1024 @staticmethod async def run(task_id: uuid.UUID) -> None: '''运行任务''' - pass + task_entity = await TaskManager.get_task_by_task_id(task_id) + if task_entity is None: + err = f"[ParseDocumentWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + raise Exception(err) + doc_entity = await DocumentManager.get_document_by_doc_id(task_entity.op_id) + if doc_entity is None: + err = f"[ParseDocumentWorker] 文档不存在,doc_id: {task_entity.op_id}" + logging.exception(err) + raise Exception(err) + await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"status": DocumentStatus.RUNNING.value}) + try: + if doc_entity.parse_method == ParseMethod.EHANCED: + llm = LLM( + openai_api_key=config['OPENAI_API_KEY'], + openai_api_base=config['OPENAI_API_BASE'], + model_name=config['MODEL_NAME'], + max_tokens=config['MAX_TOKENS'], + ) + else: + llm = None + tmp_path, image_path = await ParseDocumentWorker.init_path(task_id) + current_stage = 0 + stage_cnt = 10 + await ParseDocumentWorker.download_doc_from_minio(task_entity.op_id, tmp_path) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '下载文档', current_stage, stage_cnt) + file_path = os.path.join(tmp_path, str(task_entity.op_id)+'.'+doc_entity.extension) + parse_result = await ParseDocumentWorker.parse_doc(doc_entity, file_path) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '解析文档', current_stage, stage_cnt) + await 
ParseDocumentWorker.handle_parse_result(parse_result, doc_entity, llm) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '处理解析结果', current_stage, stage_cnt) + await ParseDocumentWorker.upload_parse_image_to_minio_and_postgres(parse_result, doc_entity, image_path) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '上传解析图片', current_stage, stage_cnt) + await ParseDocumentWorker.ocr_from_parse_image(parse_result, llm) + current_stage += 1 + await ParseDocumentWorker.report(task_id, 'OCR图片', current_stage, stage_cnt) + await ParseDocumentWorker.merge_and_split_text(parse_result, doc_entity) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '合并和拆分文本', current_stage, stage_cnt) + await ParseDocumentWorker.push_up_words_feature(parse_result, llm) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '推送上层词特征', current_stage, stage_cnt) + await ParseDocumentWorker.embedding_chunk(parse_result) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '嵌入chunk', current_stage, stage_cnt) + await ParseDocumentWorker.update_doc_abstract(doc_entity.id, parse_result, llm) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '更新文档摘要', current_stage, stage_cnt) + await ParseDocumentWorker.add_parse_result_to_db(parse_result, doc_entity) + current_stage += 1 + await ParseDocumentWorker.report(task_id, '添加解析结果到数据库', current_stage, stage_cnt) + await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + except Exception as e: + err = f"[DocParseWorker] 任务失败,task_id: {task_id},错误信息: {e}" + logging.exception(err) + await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + await ParseDocumentWorker.report(task_id, err, 0, 1) + return None @staticmethod async def stop(task_id: uuid.UUID) -> uuid.UUID: '''停止任务''' - task_entity = await TaskManager.get_task_by_id(task_id) + task_entity = await TaskManager.get_task_by_task_id(task_id) if task_entity is None: - err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" + err = f"[ParseDocumentWorker] 任务不存在,task_id: {task_id}" logging.exception(err) return None - await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value}) - tmp_path = os.path.join(IMPORT_KB_PATH_IN_OS, str(task_id)) + await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"status": DocumentStatus.IDLE.value}) + if task_entity.status == TaskStatus.PENDING.value or task_entity.status == TaskStatus.RUNNING.value or task_entity.status == TaskStatus.FAILED.value: + if task_entity.status == TaskStatus.RUNNING.value or task_entity.status == TaskStatus.FAILED.value: + await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.CANCLED.value}) + await DocumentManager.update_document_by_doc_id(task_entity.op_id, {"abstract": "", "abstract_vector": None}) + await ImageManager.update_images_by_doc_id(task_entity.op_id, {"status": ImageStatus.DELETED.value}) + await ChunkManager.update_chunk_by_doc_id(task_entity.op_id, {"status": ChunkStatus.DELETED.value}) + tmp_path = os.path.join(DOC_PATH_IN_OS, str(task_id)) if os.path.exists(tmp_path): shutil.rmtree(tmp_path) return task_id - - @staticmethod - async def delete(task_id) -> uuid.UUID: - '''删除任务''' - task_entity = await TaskManager.get_task_by_id(task_id) - if task_entity is None: - err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" - logging.exception(err) - return None - if task_entity.status == TaskStatus.CANCLED or 
TaskStatus.FAILED.value: - await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.DELTED.value}) - await MinIO.delete_object(IMPORT_KB_PATH_IN_OS, str(task_entity.op_id)) - return task_id
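For reference, the text chunking performed by ParseDocumentWorker.merge_and_split_text follows a two-pass merge-and-split pattern: sentences longer than doc_entity.chunk_size tokens are hard-split first, then consecutive short pieces are greedily merged up to that budget. The sketch below illustrates the idea in isolation; count_tokens, split_sentences and chunk_text are hypothetical stand-ins (the worker relies on TokenTool.get_tokens, TokenTool.content_to_sentences and TokenTool.get_k_tokens_words_from_content), and the whitespace tokenizer is deliberately naive.

```python
# Simplified, self-contained sketch of the merge-and-split chunking idea used by
# merge_and_split_text. All helpers here are illustrative stand-ins, not the
# project's TokenTool API; tokens are approximated by whitespace splitting.
import re


def count_tokens(text: str) -> int:
    return len(text.split())


def split_sentences(text: str) -> list[str]:
    # naive sentence split on common ASCII/CJK sentence enders
    return [s for s in re.split(r"(?<=[。!?.!?])\s*", text) if s]


def chunk_text(text: str, chunk_size: int) -> list[str]:
    # pass 1: hard-split any sentence that exceeds the token budget on its own
    sentences: list[str] = []
    for sentence in split_sentences(text):
        words = sentence.split()
        while len(words) > chunk_size:
            sentences.append(" ".join(words[:chunk_size]))
            words = words[chunk_size:]
        if words:
            sentences.append(" ".join(words))
    # pass 2: greedily merge consecutive sentences until the budget would overflow
    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        if current and count_tokens(current) + count_tokens(sentence) > chunk_size:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks


if __name__ == "__main__":
    sample = "Short sentence one. Short sentence two. " + "word " * 40 + "end."
    for chunk in chunk_text(sample, chunk_size=16):
        print(count_tokens(chunk), chunk[:60])
```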