From 7af1805d210e7c7bf85a2a022fa11c45623e1dad Mon Sep 17 00:00:00 2001 From: zxstty Date: Mon, 9 Dec 2024 20:43:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9C=A8=E5=85=B0=E5=8D=8F?= =?UTF-8?q?=E8=AE=AE=E5=92=8Cfix=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- LICENSE | 194 ++++++++++++ README.md | 37 +-- data_chain/apps/app.py | 4 +- .../apps/base/convertor/chunk_convertor.py | 15 + .../apps/base/convertor/document_convertor.py | 16 +- .../base/convertor/knowledge_convertor.py | 21 +- .../apps/base/convertor/task_convertor.py | 6 +- data_chain/apps/base/model/llm.py | 15 +- .../apps/base/task/document_task_handler.py | 175 ++++++----- .../base/task/knowledge_base_task_handler.py | 142 ++++++--- data_chain/apps/base/task/task_handler.py | 34 +-- data_chain/apps/router/document.py | 65 +++- data_chain/apps/router/knowledge_base.py | 82 +++-- data_chain/apps/router/model.py | 2 +- data_chain/apps/router/user.py | 6 +- data_chain/apps/service/chunk_service.py | 250 ++++++++++------ data_chain/apps/service/document_service.py | 280 ++++++++++++++++-- data_chain/apps/service/embedding_service.py | 2 - .../apps/service/knwoledge_base_service.py | 42 ++- data_chain/apps/service/task_service.py | 53 +++- data_chain/common/prompt.yaml | 3 +- data_chain/config/config.py | 10 +- data_chain/manager/chunk_manager.py | 223 ++++++++++++-- data_chain/manager/document_manager.py | 87 +++++- data_chain/manager/document_type_manager.py | 2 + data_chain/manager/image_manager.py | 22 +- data_chain/manager/task_manager.py | 15 +- data_chain/manager/vector_items_manager.py | 105 ++++++- data_chain/models/api.py | 119 ++++---- data_chain/models/constant.py | 67 ++--- data_chain/models/service.py | 7 +- data_chain/parser/handler/base_parser.py | 21 +- data_chain/parser/handler/doc_parser.py | 2 +- data_chain/parser/handler/docx_parser.py | 16 +- data_chain/parser/handler/pdf_parser.py | 2 +- data_chain/parser/service/parser_service.py | 249 +++++++++------- data_chain/parser/tools/ocr.py | 11 +- data_chain/stores/minio/minio.py | 3 - data_chain/stores/postgres/postgres.py | 77 +++-- requirements.txt | 5 +- 40 files changed, 1760 insertions(+), 727 deletions(-) create mode 100644 LICENSE create mode 100644 data_chain/apps/base/convertor/chunk_convertor.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f63f5a9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,194 @@ +木兰宽松许可证,第2版 + +木兰宽松许可证,第2版 + +2020年1月 http://license.coscl.org.cn/MulanPSL2 + +您对“软件”的复制、使用、修改及分发受木兰宽松许可证,第2版(“本许可证”)的如下条款的约束: + +0. 定义 + +“软件” 是指由“贡献”构成的许可在“本许可证”下的程序和相关文档的集合。 + +“贡献” 是指由任一“贡献者”许可在“本许可证”下的受版权法保护的作品。 + +“贡献者” 是指将受版权法保护的作品许可在“本许可证”下的自然人或“法人实体”。 + +“法人实体” 是指提交贡献的机构及其“关联实体”。 + +“关联实体” 是指,对“本许可证”下的行为方而言,控制、受控制或与其共同受控制的机构,此处的控制是 +指有受控方或共同受控方至少50%直接或间接的投票权、资金或其他有价证券。 + +1. 授予版权许可 + +每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的版权许可,您可 +以复制、使用、修改、分发其“贡献”,不论修改与否。 + +2. 授予专利许可 + +每个“贡献者”根据“本许可证”授予您永久性的、全球性的、免费的、非独占的、不可撤销的(根据本条规定 +撤销除外)专利许可,供您制造、委托制造、使用、许诺销售、销售、进口其“贡献”或以其他方式转移其“贡 +献”。前述专利许可仅限于“贡献者”现在或将来拥有或控制的其“贡献”本身或其“贡献”与许可“贡献”时的“软 +件”结合而将必然会侵犯的专利权利要求,不包括对“贡献”的修改或包含“贡献”的其他结合。如果您或您的“ +关联实体”直接或间接地,就“软件”或其中的“贡献”对任何人发起专利侵权诉讼(包括反诉或交叉诉讼)或 +其他专利维权行动,指控其侵犯专利权,则“本许可证”授予您对“软件”的专利许可自您提起诉讼或发起维权 +行动之日终止。 + +3. 无商标许可 + +“本许可证”不提供对“贡献者”的商品名称、商标、服务标志或产品名称的商标许可,但您为满足第4条规定 +的声明义务而必须使用除外。 + +4. 分发限制 + +您可以在任何媒介中将“软件”以源程序形式或可执行形式重新分发,不论修改与否,但您必须向接收者提供“ +本许可证”的副本,并保留“软件”中的版权、商标、专利及免责声明。 + +5. 
免责声明与责任限制 + +“软件”及其中的“贡献”在提供时不带任何明示或默示的担保。在任何情况下,“贡献者”或版权所有者不对 +任何人因使用“软件”或其中的“贡献”而引发的任何直接或间接损失承担责任,不论因何种原因导致或者基于 +何种法律理论,即使其曾被建议有此种损失的可能性。 + +6. 语言 + +“本许可证”以中英文双语表述,中英文版本具有同等法律效力。如果中英文版本存在任何冲突不一致,以中文 +版为准。 + +条款结束 + +如何将木兰宽松许可证,第2版,应用到您的软件 + +如果您希望将木兰宽松许可证,第2版,应用到您的新软件,为了方便接收者查阅,建议您完成如下三步: + +1, 请您补充如下声明中的空白,包括软件名、软件的首次发表年份以及您作为版权人的名字; + +2, 请您在软件包的一级目录下创建以“LICENSE”为名的文件,将整个许可证文本放入该文件中; + +3, 请将如下声明文本放入每个源文件的头部注释中。 + +Copyright (c) [Year] [name of copyright holder] +[Software Name] is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan +PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. + +Mulan Permissive Software License,Version 2 + +Mulan Permissive Software License,Version 2 (Mulan PSL v2) + +January 2020 http://license.coscl.org.cn/MulanPSL2 + +Your reproduction, use, modification and distribution of the Software shall +be subject to Mulan PSL v2 (this License) with the following terms and +conditions: + +0. Definition + +Software means the program and related documents which are licensed under +this License and comprise all Contribution(s). + +Contribution means the copyrightable work licensed by a particular +Contributor under this License. + +Contributor means the Individual or Legal Entity who licenses its +copyrightable work under this License. + +Legal Entity means the entity making a Contribution and all its +Affiliates. + +Affiliates means entities that control, are controlled by, or are under +common control with the acting entity under this License, ‘control’ means +direct or indirect ownership of at least fifty percent (50%) of the voting +power, capital or other securities of controlled or commonly controlled +entity. + +1. Grant of Copyright License + +Subject to the terms and conditions of this License, each Contributor hereby +grants to you a perpetual, worldwide, royalty-free, non-exclusive, +irrevocable copyright license to reproduce, use, modify, or distribute its +Contribution, with modification or not. + +2. Grant of Patent License + +Subject to the terms and conditions of this License, each Contributor hereby +grants to you a perpetual, worldwide, royalty-free, non-exclusive, +irrevocable (except for revocation under this Section) patent license to +make, have made, use, offer for sale, sell, import or otherwise transfer its +Contribution, where such patent license is only limited to the patent claims +owned or controlled by such Contributor now or in future which will be +necessarily infringed by its Contribution alone, or by combination of the +Contribution with the Software to which the Contribution was contributed. +The patent license shall not apply to any modification of the Contribution, +and any other combination which includes the Contribution. If you or your +Affiliates directly or indirectly institute patent litigation (including a +cross claim or counterclaim in a litigation) or other patent enforcement +activities against any individual or entity by alleging that the Software or +any Contribution in it infringes patents, then any patent license granted to +you under this License for the Software shall terminate as of the date such +litigation or activity is filed or taken. + +3. 
No Trademark License + +No trademark license is granted to use the trade names, trademarks, service +marks, or product names of Contributor, except as required to fulfill notice +requirements in section 4. + +4. Distribution Restriction + +You may distribute the Software in any medium with or without modification, +whether in source or executable forms, provided that you provide recipients +with a copy of this License and retain copyright, patent, trademark and +disclaimer statements in the Software. + +5. Disclaimer of Warranty and Limitation of Liability + +THE SOFTWARE AND CONTRIBUTION IN IT ARE PROVIDED WITHOUT WARRANTIES OF ANY +KIND, EITHER EXPRESS OR IMPLIED. IN NO EVENT SHALL ANY CONTRIBUTOR OR +COPYRIGHT HOLDER BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT +LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING +FROM YOUR USE OR INABILITY TO USE THE SOFTWARE OR THE CONTRIBUTION IN IT, NO +MATTER HOW IT’S CAUSED OR BASED ON WHICH LEGAL THEORY, EVEN IF ADVISED OF +THE POSSIBILITY OF SUCH DAMAGES. + +6. Language + +THIS LICENSE IS WRITTEN IN BOTH CHINESE AND ENGLISH, AND THE CHINESE VERSION +AND ENGLISH VERSION SHALL HAVE THE SAME LEGAL EFFECT. IN THE CASE OF +DIVERGENCE BETWEEN THE CHINESE AND ENGLISH VERSIONS, THE CHINESE VERSION +SHALL PREVAIL. + +END OF THE TERMS AND CONDITIONS + +How to Apply the Mulan Permissive Software License,Version 2 +(Mulan PSL v2) to Your Software + +To apply the Mulan PSL v2 to your work, for easy identification by +recipients, you are suggested to complete following three steps: + +i. Fill in the blanks in following statement, including insert your software +name, the year of the first publication of your software, and your name +identified as the copyright owner; + +ii. Create a file named "LICENSE" which contains the whole context of this +License in the first directory of your software package; + +iii. Attach the statement to the appropriate annotated syntax at the +beginning of each source file. + +Copyright (c) [Year] [name of copyright holder] +[Software Name] is licensed under Mulan PSL v2. +You can use this software according to the terms and conditions of the Mulan +PSL v2. +You may obtain a copy of Mulan PSL v2 at: + http://license.coscl.org.cn/MulanPSL2 +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY +KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +See the Mulan PSL v2 for more details. diff --git a/README.md b/README.md index 7af1a33..b2bcdb9 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,2 @@ -# euler-copilot-rag +# rag-service -#### 介绍 -A retrieval-augmented service designed for data management and enhancing the accuracy of intelligent question-answering systems. - -#### 软件架构 -软件架构说明 - - -#### 安装教程 - -1. xxxx -2. xxxx -3. xxxx - -#### 使用说明 - -1. xxxx -2. xxxx -3. xxxx - -#### 参与贡献 - -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request - - -#### 特技 - -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. 
Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index 1007771..7a610c2 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -10,7 +10,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from data_chain.config.config import config from data_chain.apps.router import chunk,document,health_check,knowledge_base,model,other,user -from data_chain.apps.service.task_service import redis_task, monitor_tasks +from data_chain.apps.service.task_service import task_queue_handler, monitor_tasks from data_chain.apps.service.user_service import UserHTTPException from data_chain.stores.postgres.postgres import PostgresDB, DocumentTypeEntity, DocumentEntity, KnowledgeBaseEntity, TaskEntity,User from data_chain.models.constant import DocumentEmbeddingConstant, KnowledgeStatusEnum, TaskConstant @@ -37,7 +37,7 @@ async def startup_event(): TaskRedisHandler.clear_all_task(config['REDIS_PENDING_TASK_QUEUE_NAME']) TaskRedisHandler.clear_all_task(config['REDIS_SUCCESS_TASK_QUEUE_NAME']) TaskRedisHandler.clear_all_task(config['REDIS_RESTART_TASK_QUEUE_NAME']) - scheduler.add_job(redis_task, 'interval', seconds=5) + scheduler.add_job(task_queue_handler, 'interval', seconds=5) scheduler.add_job(monitor_tasks, 'interval', seconds=25) scheduler.start() logging.info("Application startup complete.") diff --git a/data_chain/apps/base/convertor/chunk_convertor.py b/data_chain/apps/base/convertor/chunk_convertor.py new file mode 100644 index 0000000..7a4a4a6 --- /dev/null +++ b/data_chain/apps/base/convertor/chunk_convertor.py @@ -0,0 +1,15 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from typing import Dict, List +from data_chain.models.service import ChunkDTO +from data_chain.stores.postgres.postgres import ChunkEntity + + +class ChunkConvertor(): + @staticmethod + def convert_entity_to_dto(chunk_entity: ChunkEntity) -> ChunkDTO: + return ChunkDTO( + id=str(chunk_entity.id), + text=chunk_entity.text, + enabled=chunk_entity.enabled, + type=chunk_entity.type.split('.')[1] + ) diff --git a/data_chain/apps/base/convertor/document_convertor.py b/data_chain/apps/base/convertor/document_convertor.py index f0a5e10..a204a3f 100644 --- a/data_chain/apps/base/convertor/document_convertor.py +++ b/data_chain/apps/base/convertor/document_convertor.py @@ -9,8 +9,14 @@ class DocumentConvertor(): def convert_entity_to_dto(document_entity: DocumentEntity, document_type_entity: DocumentTypeEntity ) -> DocumentDTO: document_type_dto = DocumentTypeDTO(id=str(document_type_entity.id), type=document_type_entity.type) - return DocumentDTO(id=str(document_entity.id), name=document_entity.name, - extension=document_entity.extension, document_type=document_type_dto, - chunk_size=document_entity.chunk_size, status=document_entity.status, - enabled=document_entity.enabled, parser_method=document_entity.parser_method, - created_time=document_entity.created_time.strftime('%Y-%m-%d %H:%M')) + return DocumentDTO( + id=str(document_entity.id), + name=document_entity.name, + extension=document_entity.extension, + document_type=document_type_dto, + chunk_size=document_entity.chunk_size, + status=document_entity.status, + enabled=document_entity.enabled, + parser_method=document_entity.parser_method, + created_time=document_entity.created_time.strftime('%Y-%m-%d %H:%M') + ) diff --git a/data_chain/apps/base/convertor/knowledge_convertor.py 
b/data_chain/apps/base/convertor/knowledge_convertor.py index d814292..313b68f 100644 --- a/data_chain/apps/base/convertor/knowledge_convertor.py +++ b/data_chain/apps/base/convertor/knowledge_convertor.py @@ -8,12 +8,12 @@ from data_chain.models.constant import KnowledgeStatusEnum class KnowledgeConvertor(): @staticmethod - def convert_dict_to_entity(tmp_dict): + def convert_dict_to_entity(tmp_dict:dict): return KnowledgeBaseEntity( name=tmp_dict['name'], user_id=tmp_dict['user_id'], language=tmp_dict['language'], - description=tmp_dict['description'], + description=tmp_dict.get('description',''), embedding_model=tmp_dict['embedding_model'], document_number=0, document_size=0, @@ -30,8 +30,15 @@ class KnowledgeConvertor(): for document_type_entity in document_type_entity_list] return KnowledgeBaseDTO( id=str(entity.id), - name=entity.name, language=entity.language, description=entity.description, - embedding_model=entity.embedding_model, default_parser_method=entity.default_parser_method, - default_chunk_size=entity.default_chunk_size, document_count=entity.document_number,status=entity.status, - document_size=entity.document_size, document_type_list=document_type_dto_list, - created_time=entity.created_time.strftime('%Y-%m-%d %H:%M')) + name=entity.name, + language=entity.language, + description=entity.description, + embedding_model=entity.embedding_model, + default_parser_method=entity.default_parser_method, + default_chunk_size=entity.default_chunk_size, + document_count=entity.document_number, + status=entity.status, + document_size=entity.document_size, + document_type_list=document_type_dto_list, + created_time=entity.created_time.strftime('%Y-%m-%d %H:%M') + ) diff --git a/data_chain/apps/base/convertor/task_convertor.py b/data_chain/apps/base/convertor/task_convertor.py index f908f18..f1f16d2 100644 --- a/data_chain/apps/base/convertor/task_convertor.py +++ b/data_chain/apps/base/convertor/task_convertor.py @@ -6,11 +6,11 @@ from data_chain.stores.postgres.postgres import TaskEntity, TaskStatusReportEnti class TaskConvertor(): @staticmethod - def convert_entity_to_dto(task_entity: TaskEntity, TaskStatusReportEntityList: List[TaskStatusReportEntity] = None, - op_dict: Dict = None) -> TaskDTO: + def convert_entity_to_dto(task_entity: TaskEntity, TaskStatusReportEntityList: List[TaskStatusReportEntity] = None) -> TaskDTO: reports = [] for task_status_report_entity in TaskStatusReportEntityList: - reports.append(TaskReportDTO( + reports.append( + TaskReportDTO( id=task_status_report_entity.id, message=task_status_report_entity.message, current_stage=task_status_report_entity.current_stage, diff --git a/data_chain/apps/base/model/llm.py b/data_chain/apps/base/model/llm.py index bdc701a..a830da2 100644 --- a/data_chain/apps/base/model/llm.py +++ b/data_chain/apps/base/model/llm.py @@ -2,6 +2,7 @@ import asyncio import time import json +import tiktoken from langchain_openai import ChatOpenAI from langchain.schema import SystemMessage, HumanMessage from data_chain.logger.logger import logger as logging @@ -46,6 +47,9 @@ class LLM: # 启动生产者任务 producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) first_token_reach = False + enc = tiktoken.encoding_for_model("gpt-4") + input_tokens=len(enc.encode(system_call)) + output_tokens=0 while True: data = await q.get() if data is None: @@ -53,9 +57,14 @@ class LLM: if not first_token_reach: first_token_reach = True logging.info(f"大模型回复第一个字耗时 = {time.time() - st}") - for char in data: - yield "data: " + 
json.dumps({'content': char}, ensure_ascii=False) + '\n\n' - await asyncio.sleep(0.03) # 使用异步 sleep + output_tokens+=len(enc.encode(data)) + yield "data: " + json.dumps( + {'content': data, + 'input_tokens':input_tokens, + 'output_tokens':output_tokens + }, ensure_ascii=False + ) + '\n\n' + await asyncio.sleep(0.03) # 使用异步 sleep yield "data: [DONE]" logging.info(f"大模型回复耗时 = {time.time() - st}") diff --git a/data_chain/apps/base/task/document_task_handler.py b/data_chain/apps/base/task/document_task_handler.py index b0fc502..32425fb 100644 --- a/data_chain/apps/base/task/document_task_handler.py +++ b/data_chain/apps/base/task/document_task_handler.py @@ -8,7 +8,7 @@ from fastapi import File, UploadFile import aiofiles from typing import List from data_chain.apps.base.task.task_handler import TaskRedisHandler -from data_chain.manager.document_manager import DocumentManager +from data_chain.manager.document_manager import DocumentManager,TemporaryDocumentManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.task_manager import TaskManager, TaskStatusReportManager from data_chain.models.constant import DocumentEmbeddingConstant, OssConstant, TaskConstant @@ -18,70 +18,7 @@ from data_chain.stores.postgres.postgres import DocumentEntity, TaskEntity, Task from data_chain.config.config import config - - class DocumentTaskHandler(): - @staticmethod - async def submit_upload_document_task(user_id: uuid.UUID, kb_id: uuid.UUID, files: List[UploadFile] = File(...)): - target_dir=None - try: - # 创建目标目录 - file_upload_successfully_list = [] - target_dir = os.path.join(OssConstant.UPLOAD_DOCUMENT_SAVE_FOLDER, str(user_id), secrets.token_hex(16)) - if os.path.exists(target_dir): - shutil.rmtree(target_dir) - os.makedirs(target_dir) - knowledge_base_entity = await KnowledgeBaseManager.select_by_id(kb_id) - for file in files: - try: - # 1. 将文件写入本地stash目录 - document_file_path = await DocumentTaskHandler.save_document_file_to_local(target_dir, file) - except Exception as e: - logging.error(f"save_document_file_to_local error: {e}") - continue - kb_entity = await KnowledgeBaseManager.select_by_id(kb_id) - # 2. 更新document记录 - file_name = file.filename - if await DocumentManager.select_by_knowledge_base_id_and_file_name(kb_entity.id, file_name): - name = os.path.splitext(file_name)[0] - extension = os.path.splitext(file_name)[1] - file_name = name[:128]+secrets.token_hex(16)+extension - document_entity = await DocumentManager.insert( - DocumentEntity( - kb_id=kb_id, user_id=user_id, name=file_name, - extension=os.path.splitext(file.filename)[1], - size=os.path.getsize(document_file_path), - parser_method=kb_entity.default_parser_method, - type_id='00000000-0000-0000-0000-000000000000', - chunk_size=kb_entity.default_chunk_size, - enabled=True, - status=DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING) - ) - if not await MinIO.put_object(OssConstant.MINIO_BUCKET_DOCUMENT, str(document_entity.id), document_file_path): - logging.error(f"上传文件到minIO失败,文件名:{file.filename}") - await DocumentManager.delete_by_id(document_entity.id) - continue - # 3. 创建task表记录 - task_entity = await TaskManager.insert(TaskEntity(user_id=user_id, - op_id=document_entity.id, - type=TaskConstant.PARSE_DOCUMENT, - retry=0, - status=TaskConstant.TASK_STATUS_PENDING)) - # 4. 
提交redis任务队列 - TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_entity.id)) - file_upload_successfully_list.append(file.filename) - # 5.更新kb的文档数和文档总大小 - total_cnt, total_sz = await DocumentManager.select_cnt_and_sz_by_kb_id(kb_id) - update_dict = {'document_number': total_cnt, - 'document_size': total_sz} - await KnowledgeBaseManager.update(kb_id, update_dict) - except Exception as e: - raise e - return [] - finally: - if target_dir is not None and os.path.exists(target_dir): - shutil.rmtree(target_dir) - return file_upload_successfully_list @staticmethod async def save_document_file_to_local(target_dir: str, file: UploadFile): @@ -92,11 +29,14 @@ class DocumentTaskHandler(): return document_file_path @staticmethod - async def handle_parser_document_task(task_entity: TaskEntity): - target_dir=None + async def handle_parser_document_task(t_id: uuid.UUID): + target_dir = None try: + task_entity = await TaskManager.select_by_id(t_id) + if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: + TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], str(task_entity.id)) + return # 文件解析主入口 - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_RUNNING}) # 下载文件 document_entity = await DocumentManager.select_by_id(task_entity.op_id) @@ -119,13 +59,25 @@ class DocumentTaskHandler(): answer = await parser.parser(document_entity.id, file_path) chunk_list, chunk_link_list, images = answer['chunk_list'], answer['chunk_link_list'], answer[ 'image_chunks'] + chunk_id_set=set() + for chunk in chunk_list: + chunk_id_set.add(chunk['id']) + new_chunk_link_list = [] + for chunk_link in chunk_link_list: + if chunk_link['chunk_a'] in chunk_id_set: + new_chunk_link_list.append(chunk_link) + chunk_link_list = new_chunk_link_list + new_images = [] + for image in images: + if image['chunk_id'] in chunk_id_set: + new_images.append(image) + images = new_images await TaskStatusReportManager.insert(TaskStatusReportEntity( task_id=task_entity.id, message=f'Parse document {document_entity.name} completed, waiting for uploading', current_stage=2, stage_cnt=7 )) - await parser.upload_chunks_to_pg(chunk_list) await parser.upload_chunk_links_to_pg(chunk_link_list) await TaskStatusReportManager.insert(TaskStatusReportEntity( @@ -134,24 +86,23 @@ class DocumentTaskHandler(): current_stage=3, stage_cnt=7 )) + await parser.upload_images_to_minio(images) await TaskStatusReportManager.insert(TaskStatusReportEntity( task_id=task_entity.id, - message=f'Upload document {document_entity.name} images completed', + message=f'Upload document {document_entity.name} images to minio completed', current_stage=4, stage_cnt=7 )) await parser.upload_images_to_pg(images) await TaskStatusReportManager.insert(TaskStatusReportEntity( task_id=task_entity.id, - message=f'Upload document {document_entity.name} images completed', + message=f'Upload document {document_entity.name} images to pg completed', current_stage=5, stage_cnt=7 )) vectors = await parser.embedding_chunks(chunk_list) - for vector in vectors: - vector['kb_id'] = document_entity.kb_id await parser.insert_vectors_to_pg(vectors) await TaskStatusReportManager.insert(TaskStatusReportEntity( task_id=task_entity.id, @@ -180,3 +131,83 @@ class DocumentTaskHandler(): finally: if target_dir and os.path.exists(target_dir): shutil.rmtree(target_dir) # 清理文件 + @staticmethod + async def handle_parser_temporary_document_task(t_id: uuid.UUID): + target_dir = None + try: + task_entity = await 
TaskManager.select_by_id(t_id) + if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: + TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], str(task_entity.id)) + return + # 文件解析主入口 + + document_entity = await TemporaryDocumentManager.select_by_id(task_entity.op_id) + target_dir = os.path.join(OssConstant.PARSER_SAVE_FOLDER, str(document_entity.id), secrets.token_hex(16)) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + file_extension = document_entity.extension.lower() + file_path = target_dir + '/' + str(document_entity.id) + file_extension + await MinIO.download_object(document_entity.bucket_name, str(document_entity.id), + file_path) + + await TaskStatusReportManager.insert( + TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Temporary document {document_entity.name} begin to parse', + current_stage=1, + stage_cnt=6 + )) + + parser = ParserService() + parser_result = await parser.parser(document_entity.id, file_path,is_temporary_document=True) + chunk_list, chunk_link_list, images = parser_result['chunk_list'], parser_result['chunk_link_list'], parser_result[ + 'image_chunks'] + await TaskStatusReportManager.insert(TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Parse temporary document {document_entity.name} completed, waiting for uploading', + current_stage=2, + stage_cnt=6 + )) + + await parser.upload_chunks_to_pg(chunk_list,is_temporary_document=True) + await TaskStatusReportManager.insert(TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Upload temporary document {document_entity.name} chunks completed', + current_stage=3, + stage_cnt=6 + )) + await parser.upload_images_to_minio(images,is_temporary_document=True) + await TaskStatusReportManager.insert(TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Upload temporary document {document_entity.name} images completed', + current_stage=4, + stage_cnt=6 + )) + vectors = await parser.embedding_chunks(chunk_list,is_temporary_document=True) + await parser.insert_vectors_to_pg(vectors,is_temporary_document=True) + await TaskStatusReportManager.insert(TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Upload temporary document {document_entity.name} vectors completed', + current_stage=5, + stage_cnt=6 + )) + + await TaskStatusReportManager.insert(TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Parse temporary document {task_entity.id} succcessfully', + current_stage=6, + stage_cnt=6 + )) + await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) + TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) + except Exception as e: + TaskRedisHandler.put_task_by_tail(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_entity.id)) + await TaskStatusReportManager.insert(TaskStatusReportEntity( + task_id=task_entity.id, + message=f'Parse temporary document {task_entity.id} failed due to {e}', + current_stage=5, + stage_cnt=5 + )) + finally: + if target_dir and os.path.exists(target_dir): + shutil.rmtree(target_dir) # 清理文件 diff --git a/data_chain/apps/base/task/knowledge_base_task_handler.py b/data_chain/apps/base/task/knowledge_base_task_handler.py index 92fccaa..42bd77b 100644 --- a/data_chain/apps/base/task/knowledge_base_task_handler.py +++ b/data_chain/apps/base/task/knowledge_base_task_handler.py @@ -60,13 +60,23 @@ class KnowledgeBaseTaskHandler(): if parser_method not in parse_methods: parser_method=ParseMethodEnum.GENERAL document_entity = 
DocumentEntity( - kb_id=knowledge_base_entity.id, user_id=knowledge_base_entity.user_id, name=file_name, - extension=data.get('extension', ''), - size=file_size, parser_method=parser_method, - type_id=document_type_dict.get(data['type'], - uuid.UUID('00000000-0000-0000-0000-000000000000')), - chunk_size=data.get('chunk_size', knowledge_base_entity.default_chunk_size), - enabled=True, status=DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING) + kb_id=knowledge_base_entity.id,\ + user_id=knowledge_base_entity.user_id,\ + name=file_name,\ + extension=data.get('extension', ''),\ + size=file_size,\ + parser_method=parser_method,\ + type_id=document_type_dict.get( + data['type'], + uuid.UUID('00000000-0000-0000-0000-000000000000') + ),\ + chunk_size=data.get( + 'chunk_size', + knowledge_base_entity.default_chunk_size + ),\ + enabled=True,\ + status=DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING + ) document_entity_list.append(document_entity) file_path_list.append(file_path) await DocumentManager.insert_bulk(document_entity_list) @@ -81,9 +91,16 @@ class KnowledgeBaseTaskHandler(): return document_entity_list @staticmethod - async def handle_import_knowledge_base_task(task_entity: TaskEntity): + async def handle_import_knowledge_base_task(t_id: uuid.UUID): unzip_folder_path=None try: + task_entity=await TaskManager.select_by_id(t_id) + if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: + TaskRedisHandler.put_task_by_tail( + config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], + str(task_entity.id) + ) + return # 定义目标目录 kb_id = task_entity.op_id knowledge_base_entity = await KnowledgeBaseManager.select_by_id(kb_id) @@ -97,14 +114,16 @@ class KnowledgeBaseTaskHandler(): zip_file_path = os.path.join(unzip_folder_path, str(kb_id)+'.zip') # todo 下面两个子过程记录task report if not await MinIO.download_object(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(kb_id), zip_file_path): - await TaskStatusReportManager.insert(TaskStatusReportEntity( + await TaskStatusReportManager.insert( + TaskStatusReportEntity( task_id=task_entity.id, message=f'Download knowledge base {kb_id} zip file failed', current_stage=3, stage_cnt=3 - )) + ) + ) await KnowledgeBaseManager.update(kb_id, {'status': KnowledgeStatusEnum.IDLE}) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) + await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_FAILED}) TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) return await TaskStatusReportManager.insert(TaskStatusReportEntity( @@ -130,7 +149,11 @@ class KnowledgeBaseTaskHandler(): current_stage=2, stage_cnt=3 )) - document_entity_list = await KnowledgeBaseTaskHandler.parse_knowledge_base_document_yaml_file(knowledge_base_entity, unzip_folder_path) + document_entity_list = await KnowledgeBaseTaskHandler.\ + parse_knowledge_base_document_yaml_file( + knowledge_base_entity,\ + unzip_folder_path + ) await TaskStatusReportManager.insert(TaskStatusReportEntity( task_id=task_entity.id, message=f'Save document and parse yaml from knowledge base {kb_id} zip file succcessfully', @@ -162,9 +185,13 @@ class KnowledgeBaseTaskHandler(): shutil.rmtree(unzip_folder_path) @staticmethod - async def handle_export_knowledge_base_task(task_entity: TaskEntity): + async def handle_export_knowledge_base_task(t_id: uuid.UUID): knowledge_yaml_path=None try: + task_entity=await TaskManager.select_by_id(t_id) + if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: + 
TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], str(task_entity.id))
+                return
             knowledge_base_entity = await KnowledgeBaseManager.select_by_id(task_entity.op_id)
             user_entity = await UserManager.get_user_info_by_user_id(knowledge_base_entity.user_id)
             zip_file_path = os.path.join(OssConstant.ZIP_FILE_SAVE_FOLDER, str(user_entity.id))
@@ -192,17 +219,23 @@ class KnowledgeBaseTaskHandler():
                     'language': knowledge_base_entity.language,
                     'name': knowledge_base_entity.name
                 }
-                knowledge_dict['document_type_list'] = list(set([document_type_entity.type
-                                                                 for document_type_entity in document_type_entity_list]))
+                knowledge_dict['document_type_list'] = list(
+                    set([document_type_entity.type
+                         for document_type_entity in document_type_entity_list]
+                        )
+                )
                 yaml.dump(knowledge_dict, knowledge_yaml_file)
-            await TaskStatusReportManager.insert(TaskStatusReportEntity(
-                task_id=task_entity.id,
-                message=f'Save knowledge base yaml from knowledge base {knowledge_base_entity.id} succcessfully',
-                current_stage=1,
-                stage_cnt=4
-            ))
-            document_type_entity_map = {str(document_type_entity.id): document_type_entity.type
-                                        for document_type_entity in document_type_entity_list}
+            await TaskStatusReportManager.insert(
+                TaskStatusReportEntity(
+                    task_id=task_entity.id,\
+                    message=f'Save knowledge base yaml from knowledge base {knowledge_base_entity.id} successfully',\
+                    current_stage=1,\
+                    stage_cnt=4\
+                )
+            )
+            document_type_entity_map = {
+                str(document_type_entity.id): document_type_entity.type
+                for document_type_entity in document_type_entity_list}
             document_entity_list = await DocumentManager.select_by_knowledge_base_id(knowledge_base_entity.id)
             for document_entity in document_entity_list:
                 try:
@@ -214,57 +247,78 @@ class KnowledgeBaseTaskHandler():
                     if str(document_entity.type_id) in document_type_entity_map.keys():
                         doc_type = document_type_entity_map[str(document_entity.type_id)]
                     yaml.dump(
-                        {'name': document_entity.name, 'extension': document_entity.extension,
-                         'parser_method': document_entity.parser_method,'chunk_size':document_entity.chunk_size,
-                         'type': doc_type},
+                        {
+                            'name': document_entity.name,\
+                            'extension': document_entity.extension,\
+                            'parser_method': document_entity.parser_method,\
+                            'chunk_size':document_entity.chunk_size,\
+                            'type': doc_type},
                         yaml_file)
                 except Exception as e:
-                    logging.error(f"download document {document_entity.id} failed: {str(e)}")
+                    logging.error(f"Download document {document_entity.id} failed due to: {e}")
                     continue
-            await TaskStatusReportManager.insert(TaskStatusReportEntity(
+            await TaskStatusReportManager.insert(
+                TaskStatusReportEntity(
                 task_id=task_entity.id,
                 message=f'Save document and yaml from knowledge base {knowledge_base_entity.id} succcessfully',
                 current_stage=2,
                 stage_cnt=4
             ))
-            # 最后压缩export目录, 并放入minIO, 然后生成下载链接
+            # 最后压缩export目录, 并放入minIO
             save_knowledge_base_zip_file_name = str(task_entity.id)+".zip"
             await ZipHandler.zip_dir(knowledge_yaml_path, os.path.join(zip_file_path, save_knowledge_base_zip_file_name))
-            await TaskStatusReportManager.insert(TaskStatusReportEntity(
+            await TaskStatusReportManager.insert(
+                TaskStatusReportEntity(
                 task_id=task_entity.id,
                 message=f'Zip knowledge base {knowledge_base_entity.id} succcessfully',
                 current_stage=3,
                 stage_cnt=4
             ))
             # 上传到minIO exportzip桶
-            res = await MinIO.put_object(OssConstant.MINIO_BUCKET_EXPORTZIP, str(task_entity.id),
-                                         os.path.join(zip_file_path, save_knowledge_base_zip_file_name))
+            res = await MinIO.put_object(
+                OssConstant.MINIO_BUCKET_KNOWLEDGEBASE,\
+                str(task_entity.id),\
+                os.path.join(
+                    zip_file_path,\
+                    save_knowledge_base_zip_file_name\
+                )
+            )
             if res:
-                await TaskStatusReportManager.insert(TaskStatusReportEntity(
+                await TaskStatusReportManager.insert(
+                    TaskStatusReportEntity(
                     task_id=task_entity.id,
                     message=f'Update knowledge base {knowledge_base_entity.id} zip file {save_knowledge_base_zip_file_name} to Minio succcessfully',
                     current_stage=4,
                     stage_cnt=4
-                ))
+                    )
+                )
             else:
-                await TaskStatusReportManager.insert(TaskStatusReportEntity(
+                await TaskStatusReportManager.insert(
+                    TaskStatusReportEntity(
                     task_id=task_entity.id,
                     message=f'Update knowledge base {knowledge_base_entity.id} zip file {save_knowledge_base_zip_file_name} to Minio failed',
                     current_stage=4,
                     stage_cnt=4
-                ))
+                    )
+                )
+                raise Exception(f'Update knowledge base {knowledge_base_entity.id} zip file {save_knowledge_base_zip_file_name} to Minio failed')
             await KnowledgeBaseManager.update(knowledge_base_entity.id, {'status': KnowledgeStatusEnum.IDLE})
             await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS})
             TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id))
         except Exception as e:
-            await TaskStatusReportManager.insert(TaskStatusReportEntity(
-                task_id=task_entity.id,
-                message=f'Import knowledge base {task_entity.op_id} failed due to {e}',
-                current_stage=0,
-                stage_cnt=4
-            ))
-            TaskRedisHandler.put_task_by_tail(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_entity.id))
-            logging.error("Export knowledge base zip files error: {}".format(e))
+            await TaskStatusReportManager.insert(
+                TaskStatusReportEntity(
+                    task_id=task_entity.id,\
+                    message=f'Export knowledge base {task_entity.op_id} failed due to {e}',\
+                    current_stage=0,\
+                    stage_cnt=4\
+                )
+            )
+            TaskRedisHandler.put_task_by_tail(
+                config['REDIS_RESTART_TASK_QUEUE_NAME'],\
+                str(task_entity.id)
+            )
+            logging.error(f"Export knowledge base zip files error due to {e}")
         finally:
             if knowledge_yaml_path and os.path.exists(knowledge_yaml_path):
                 shutil.rmtree(knowledge_yaml_path)
diff --git a/data_chain/apps/base/task/task_handler.py b/data_chain/apps/base/task/task_handler.py
index 5c1da02..4ea0c21 100644
--- a/data_chain/apps/base/task/task_handler.py
+++ b/data_chain/apps/base/task/task_handler.py
@@ -12,27 +12,15 @@ from data_chain.stores.redis.redis import RedisConnectionPool
 from data_chain.manager.task_manager import TaskManager
 from data_chain.manager.document_manager import DocumentManager
 from data_chain.manager.knowledge_manager import KnowledgeBaseManager
-from data_chain.manager.chunk_manager import ChunkManager
+from data_chain.manager.chunk_manager import ChunkManager,TemporaryChunkManager
 from data_chain.models.constant import TaskConstant, DocumentEmbeddingConstant, KnowledgeStatusEnum, OssConstant, TaskActionEnum
 from data_chain.stores.minio.minio import MinIO
+
 multiprocessing = multiprocessing.get_context('spawn')


 class TaskHandler:
     tasks = {}  # 存储进程的字典
     lock = multiprocessing.Lock()  # 创建一个锁对象
-    max_processes = max((os.cpu_count() or 1)//2, 1)  # 获取CPU核心数作为最大进程数,默认为1
-
-    @staticmethod
-    async def start_subprocess(target, *args, **kwargs):
-        process = await asyncio.create_subprocess_exec(
-            sys.executable,
-            "-m", "multiprocessing.spawn",
-            "subprocess_target",
-            args=(target, *args),
-            kwargs=kwargs,
-            stdout=asyncio.subprocess.PIPE,
-            stderr=asyncio.subprocess.PIPE
-        )
-        stdout, stderr = await process.communicate()
+    max_processes = min(max((os.cpu_count() or 1)//2, 1),config['DOCUMENT_PARSE_USE_CPU_LIMIT'])  # 获取CPU核心数作为最大进程数,默认为1
@staticmethod def subprocess_target(target, *args, **kwargs): @@ -46,7 +34,7 @@ class TaskHandler: @staticmethod def add_task(task_id: uuid.UUID, target, *args, **kwargs): with TaskHandler.lock: - if len(TaskHandler.tasks) >= TaskHandler.max_processes: + if len(TaskHandler.tasks)>= TaskHandler.max_processes: logging.info("Reached maximum number of active processes.") return False @@ -60,7 +48,7 @@ class TaskHandler: return True @staticmethod - def remove_task(task_id): + def remove_task(task_id: uuid.UUID): with TaskHandler.lock: if task_id in TaskHandler.tasks.keys(): process = TaskHandler.tasks[task_id] @@ -115,22 +103,32 @@ class TaskHandler: if task_type == TaskConstant.PARSE_DOCUMENT: await DocumentManager.update(op_id, {'status': DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING}) await ChunkManager.delete_by_document_ids([op_id]) + elif task_type == TaskConstant.PARSE_TEMPORARY_DOCUMENT: + await TemporaryChunkManager.delete_by_temporary_document_ids([op_id]) elif task_type == TaskConstant.IMPORT_KNOWLEDGE_BASE: await KnowledgeBaseManager.delete(op_id) elif task_type == TaskConstant.EXPORT_KNOWLEDGE_BASE: await KnowledgeBaseManager.update(op_id, {'status': KnowledgeStatusEnum.EXPROTING}) + await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_PENDING}) TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) elif method == TaskActionEnum.RESTART or method == TaskActionEnum.CANCEL or method == TaskActionEnum.DELETE: + TaskRedisHandler.remove_task_by_task_id(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) + TaskRedisHandler.remove_task_by_task_id(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_id)) + TaskRedisHandler.remove_task_by_task_id(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_id)) if method == TaskActionEnum.CANCEL: await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_CANCELED}) + elif method == TaskActionEnum.DELETE: + await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_DELETED}) else: await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_FAILED}) if task_type == TaskConstant.PARSE_DOCUMENT: await DocumentManager.update(op_id, {'status': DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING}) await ChunkManager.delete_by_document_ids([op_id]) + elif task_type == TaskConstant.PARSE_TEMPORARY_DOCUMENT: + await TemporaryChunkManager.delete_by_temporary_document_ids([op_id]) elif task_type == TaskConstant.IMPORT_KNOWLEDGE_BASE: await KnowledgeBaseManager.delete(op_id) - await MinIO.delete_object(OssConstant.MINIO_BUCKET_EXPORTZIP, str(task_entity.op_id)) + await MinIO.delete_object(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(task_entity.op_id)) elif task_type == TaskConstant.EXPORT_KNOWLEDGE_BASE: await KnowledgeBaseManager.update(op_id, {'status': KnowledgeStatusEnum.IDLE}) await MinIO.delete_object(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(task_id)) diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index bbb09db..5c84cea 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -2,19 +2,21 @@ import urllib from typing import Dict, List import uuid -from fastapi import HTTPException -from data_chain.models.service import DocumentDTO +from fastapi import HTTPException,status +from data_chain.models.service import DocumentDTO, TemporaryDocumentDTO from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user from data_chain.exceptions.err_code import 
ErrorCode from data_chain.exceptions.exception import DocumentException from data_chain.models.api import BaseResponse, Page from data_chain.models.api import DeleteDocumentRequest, ListDocumentRequest, UpdateDocumentRequest, \ - RunDocumentEmbeddingRequest, SwitchDocumentRequest + RunDocumentRequest, SwitchDocumentRequest, ParserTemporaryDocumenRequest, GetTemporaryDocumentStatusRequest, \ + DeleteTemporaryDocumentRequest, RelatedTemporaryDocumenRequest from data_chain.apps.service.knwoledge_base_service import _validate_knowledge_base_belong_to_user from data_chain.apps.service.document_service import _validate_doucument_belong_to_user, delete_document, \ generate_document_download_link, \ list_documents_by_knowledgebase_id, run_document, submit_upload_document_task, switch_document, update_document, \ - get_file_name_and_extension + get_file_name_and_extension, init_temporary_document_parse_task, delete_temporary_document, get_temporary_document_parse_status, \ + get_related_document from httpx import AsyncClient from fastapi import Depends @@ -58,7 +60,7 @@ async def update(req: UpdateDocumentRequest, user_id=Depends(get_user_id)): @router.post('/run', response_model=BaseResponse[List[DocumentDTO]], dependencies=[Depends(verify_user), Depends(verify_csrf_token)]) -async def run(reqs: RunDocumentEmbeddingRequest, user_id=Depends(get_user_id)): +async def run(reqs: RunDocumentRequest, user_id=Depends(get_user_id)): try: run = reqs.run ids = reqs.ids @@ -149,3 +151,56 @@ async def download(id: uuid.UUID, user_id=Depends(get_user_id)): retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg="Failed to retrieve the file.", data=None) except DocumentException as e: return BaseResponse(retcode=ErrorCode.DOWNLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) + + +@router.post('/temporary/related', response_model=BaseResponse[List[uuid.UUID]]) +async def related_temporary_doc(req: RelatedTemporaryDocumenRequest): + try: + results = await get_related_document(req.content,req.top_k, req.document_ids, req.kb_sn) + return BaseResponse(data=results) + except Exception as e: + return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) + +@router.post('/temporary/parser', response_model=BaseResponse[List[uuid.UUID]]) +async def parser_temporary_doc(req: ParserTemporaryDocumenRequest): + try: + temporary_document_list = [] + for i in range(len(req.document_list)): + tmp_dict=dict(req.document_list[i]) + if tmp_dict['type']=='application/pdf': + tmp_dict['type']='.pdf' + elif tmp_dict['type']=='text/html': + tmp_dict['type']='.html' + elif tmp_dict['type']=='text/plain': + tmp_dict['type']='.txt' + elif tmp_dict['type']=='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': + tmp_dict['type']='.xlsx' + elif tmp_dict['type']=='text/x-markdown': + tmp_dict['type']='.md' + elif tmp_dict['type']=='application/vnd.openxmlformats-officedocument.wordprocessingml.document': + tmp_dict['type']='.docx' + elif tmp_dict['type']=='application/msword': + tmp_dict['type']='.doc' + temporary_document_list.append(tmp_dict) + result = await init_temporary_document_parse_task(temporary_document_list) + return BaseResponse(data=result) + except Exception as e: + return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) + + +@router.post('/temporary/status', response_model=BaseResponse[List[TemporaryDocumentDTO]]) +async def get_temporary_doc_parse_status(req: GetTemporaryDocumentStatusRequest): + try: + result = await 
get_temporary_document_parse_status(req.ids) + return BaseResponse(data=result) + except Exception as e: + return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) + + +@router.post('/temporary/delete', response_model=BaseResponse[List[uuid.UUID]]) +async def delete_temporary_doc(req: DeleteTemporaryDocumentRequest): + try: + result = await delete_temporary_document(req.ids) + return BaseResponse(data=result) + except Exception as e: + return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) diff --git a/data_chain/apps/router/knowledge_base.py b/data_chain/apps/router/knowledge_base.py index caf2d2f..9ae06b9 100644 --- a/data_chain/apps/router/knowledge_base.py +++ b/data_chain/apps/router/knowledge_base.py @@ -207,39 +207,31 @@ async def rm_kb_task(req: RmoveTaskRequest, user_id=Depends(get_user_id)): @router.post('/get_stream_answer', response_class=HTMLResponse) async def get_stream_answer(req: QueryRequest, response: Response): - def is_valid_uuid4(uuid_str): - try: - uuid.UUID(uuid_str, version=4) - return True - except Exception as e: - logging.error(f"Invalid UUID string: {e}") - return False try: question = await question_rewrite(req.history, req.question) max_tokens = config['MAX_TOKENS']//3 bac_info = '' - if req.kb_sn and is_valid_uuid4(req.kb_sn): - document_chunk_list = await get_similar_chunks(uuid.UUID(req.kb_sn), question, config['MAX_TOKENS']//2, req.top_k) + document_chunk_list = await get_similar_chunks(content=question,kb_id=req.kb_sn,temporary_document_ids=req.document_ids, max_tokens=config['MAX_TOKENS']//2, topk=req.top_k) + for i in range(len(document_chunk_list)): + document_name = document_chunk_list[i]['document_name'] + chunk_list = document_chunk_list[i]['chunk_list'] + bac_info += '文档名称:'+document_name+':\n\n' + for j in range(len(chunk_list)): + bac_info += '段落'+str(j)+':\n\n' + bac_info += chunk_list[j]+'\n\n' + bac_info = split_chunk(bac_info) + if len(bac_info) > max_tokens: + bac_info = '' for i in range(len(document_chunk_list)): document_name = document_chunk_list[i]['document_name'] chunk_list = document_chunk_list[i]['chunk_list'] bac_info += '文档名称:'+document_name+':\n\n' for j in range(len(chunk_list)): bac_info += '段落'+str(j)+':\n\n' - bac_info += chunk_list[j]+'\n\n' + bac_info += ''.join(filter_stopwords(chunk_list[j]))+'\n\n' bac_info = split_chunk(bac_info) - if len(bac_info) > max_tokens: - bac_info = '' - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - bac_info += '文档名称:'+document_name+':\n\n' - for j in range(len(chunk_list)): - bac_info += '段落'+str(j)+':\n\n' - bac_info += ''.join(filter_stopwords(chunk_list[j]))+'\n\n' - bac_info = split_chunk(bac_info) - bac_info = bac_info[:max_tokens] - bac_info = ' '.join(bac_info) + bac_info = bac_info[:max_tokens] + bac_info = ' '.join(bac_info) except Exception as e: bac_info = '' logging.error(f"get bac info failed due to: {e}") @@ -258,39 +250,31 @@ async def get_stream_answer(req: QueryRequest, response: Response): @router.post('/get_answer', response_model=BaseResponse[dict]) async def get_answer(req: QueryRequest): - def is_valid_uuid4(uuid_str): - try: - uuid.UUID(uuid_str, version=4) - return True - except Exception as e: - logging.error(f"Invalid UUID string: {e}") - return False try: question = await question_rewrite(req.history, req.question) max_tokens = config['MAX_TOKENS']//3 bac_info = '' - if req.kb_sn and 
is_valid_uuid4(req.kb_sn): - document_chunk_list = await get_similar_chunks(uuid.UUID(req.kb_sn), question, config['MAX_TOKENS']//2, req.top_k) + document_chunk_list = await get_similar_chunks(content=question,kb_id=req.kb_sn,temporary_document_ids=req.document_ids, max_tokens=config['MAX_TOKENS']//2, topk=req.top_k) + for i in range(len(document_chunk_list)): + document_name = document_chunk_list[i]['document_name'] + chunk_list = document_chunk_list[i]['chunk_list'] + bac_info += '文档名称:'+document_name+':\n\n' + for j in range(len(chunk_list)): + bac_info += '段落'+str(j)+':\n\n' + bac_info += chunk_list[j]+'\n\n' + bac_info = split_chunk(bac_info) + if len(bac_info) > max_tokens: + bac_info = '' for i in range(len(document_chunk_list)): document_name = document_chunk_list[i]['document_name'] chunk_list = document_chunk_list[i]['chunk_list'] bac_info += '文档名称:'+document_name+':\n\n' for j in range(len(chunk_list)): bac_info += '段落'+str(j)+':\n\n' - bac_info += chunk_list[j]+'\n\n' + bac_info += ''.join(filter_stopwords(chunk_list[j]))+'\n\n' bac_info = split_chunk(bac_info) - if len(bac_info) > max_tokens: - bac_info = '' - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - bac_info += '文档名称:'+document_name+':\n\n' - for j in range(len(chunk_list)): - bac_info += '段落'+str(j)+':\n\n' - bac_info += ''.join(filter_stopwords(chunk_list[j]))+'\n\n' - bac_info = split_chunk(bac_info) - bac_info = bac_info[:max_tokens] - bac_info = ' '.join(bac_info) + bac_info = bac_info[:max_tokens] + bac_info = ' '.join(bac_info) except Exception as e: bac_info = '' logging.error(f"get bac info failed due to: {e}") @@ -301,11 +285,11 @@ async def get_answer(req: QueryRequest): } if req.fetch_source: tmp_dict['source'] = [] - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - for j in range(len(chunk_list)): - tmp_dict['source'].append({'document_name': document_name, 'chunk': chunk_list[j]}) + for i in range(len(document_chunk_list)): + document_name = document_chunk_list[i]['document_name'] + chunk_list = document_chunk_list[i]['chunk_list'] + for j in range(len(chunk_list)): + tmp_dict['source'].append({'document_name': document_name, 'chunk': chunk_list[j]}) return BaseResponse(data=tmp_dict) except Exception as e: logging.error(f"Get stream answer failed due to: {e}") diff --git a/data_chain/apps/router/model.py b/data_chain/apps/router/model.py index 6c6af5b..c90eabb 100644 --- a/data_chain/apps/router/model.py +++ b/data_chain/apps/router/model.py @@ -7,7 +7,7 @@ from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, from data_chain.exceptions.err_code import ErrorCode from data_chain.exceptions.exception import DocumentException from data_chain.models.api import BaseResponse -from data_chain.models.api import CreateModelRequest, UpdateModelRequest +from data_chain.models.api import UpdateModelRequest from data_chain.apps.service.model_service import get_model, update_model diff --git a/data_chain/apps/router/user.py b/data_chain/apps/router/user.py index cde6b8a..39352fc 100644 --- a/data_chain/apps/router/user.py +++ b/data_chain/apps/router/user.py @@ -9,7 +9,7 @@ from data_chain.apps.service.user_service import verify_csrf_token, verify_passw from data_chain.apps.base.session.session import SessionManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager 
from data_chain.manager.user_manager import UserManager -from data_chain.models.api import BaseResponse, UserAddRequest, UserUpdateRequest +from data_chain.models.api import BaseResponse, AddUserRequest, UpdateUserRequest router = APIRouter( @@ -19,7 +19,7 @@ router = APIRouter( @router.post("/add", response_model=BaseResponse) -async def add_user(request: UserAddRequest): +async def add_user(request: AddUserRequest): name = request.name account = request.account passwd = request.passwd @@ -143,7 +143,7 @@ async def logout(request: Request, response: Response, user_id=Depends(get_user_ @router.post("/update", response_model=BaseResponse, dependencies=[Depends(verify_user), Depends(verify_csrf_token)]) -async def switch(req: UserUpdateRequest, user_id=Depends(get_user_id)): +async def switch(req: UpdateUserRequest, user_id=Depends(get_user_id)): user_info = UserManager.get_user_info_by_user_id(user_id) if user_info is None: return BaseResponse( diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 7a26d1e..a1c352d 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -3,18 +3,20 @@ import uuid import random import time import jieba +import traceback from data_chain.logger.logger import logger as logging -from data_chain.apps.service.llm_service import get_question_chunk_relation +from data_chain.apps.service.llm_service import get_question_chunk_relation from data_chain.models.constant import ChunkRelevance -from data_chain.manager.document_manager import DocumentManager +from data_chain.manager.document_manager import DocumentManager,TemporaryDocumentManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager -from data_chain.manager.chunk_manager import ChunkManager -from data_chain.manager.vector_items_manager import VectorItemsManager +from data_chain.manager.chunk_manager import ChunkManager, TemporaryChunkManager +from data_chain.manager.vector_items_manager import VectorItemsManager, TemporaryVectorItemsManager from data_chain.exceptions.exception import ChunkException from data_chain.stores.postgres.postgres import PostgresDB from data_chain.models.constant import embedding_model_out_dimensions, DocumentEmbeddingConstant from data_chain.apps.service.embedding_service import Vectorize from data_chain.config.config import config +from data_chain.apps.base.convertor.chunk_convertor import ChunkConvertor async def _validate_chunk_belong_to_user(user_id: uuid.UUID, chunk_id: uuid.UUID) -> bool: chunk_entity = await ChunkManager.select_by_chunk_id(chunk_id) @@ -28,8 +30,13 @@ async def list_chunk(params, page_number, page_size): doc_entity = await DocumentManager.select_by_id(params['document_id']) if doc_entity is None or doc_entity.status == DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING: return [], 0 - return await ChunkManager.select_by_page(params, page_number, page_size) - + + chunk_entity_list,total= await ChunkManager.select_by_page(params, page_number, page_size) + chunk_dto_list=[] + for chunk_entity in chunk_entity_list: + chunk_dto = ChunkConvertor.convert_entity_to_dto(chunk_entity) + chunk_dto_list.append(chunk_dto) + return (chunk_dto_list,total) async def switch_chunk(id, enabled): await ChunkManager.update(id, {'enabled': enabled}) @@ -47,112 +54,149 @@ async def switch_chunk(id, enabled): return await VectorItemsManager.update_by_chunk_id(VectorItems, id, {'enabled': enabled}) -async def 
expand_chunk(document_id,global_offset,expand_method='all',max_tokens=1024): - ex_chunk_tuple_list = await ChunkManager.fetch_surrounding_context(document_id, global_offset,expand_method=expand_method,max_tokens=max_tokens) + +async def expand_chunk(document_id, global_offset, expand_method='all', max_tokens=1024, is_temporary_document=False): + # + # 这里返回的ex_chunk_tuple_list是个n*5二维列表 + # 内部的每个列表内容: [id, document_id, global_offset, tokens, text] + # + if is_temporary_document: + ex_chunk_tuple_list = await TemporaryChunkManager.fetch_surrounding_temporary_context(document_id, global_offset, expand_method=expand_method, max_tokens=max_tokens) + else: + ex_chunk_tuple_list = await ChunkManager.fetch_surrounding_context(document_id, global_offset, expand_method=expand_method, max_tokens=max_tokens) return ex_chunk_tuple_list -async def filter_or_expand_chunk_by_llm(kb_id,content,document_para_dict,maxtokens): - exist_chunk_id_set=set() - new_document_para_dict={} + +async def filter_or_expand_chunk_by_llm(kb_id, content, document_para_dict, maxtokens): + exist_chunk_id_set = set() + new_document_para_dict = {} for document_id in document_para_dict.keys(): for chunk_tuple in document_para_dict[document_id]: - chunk_id=chunk_tuple[0] + chunk_id = chunk_tuple[0] exist_chunk_id_set.add(chunk_id) for document_id in document_para_dict.keys(): - chunk_tuple_list=document_para_dict[document_id] - new_document_para_dict[document_id]=[] - st=0 - en=0 - while st bool: document_entity = await DocumentManager.select_by_id(document_id) if document_entity is None: @@ -95,23 +100,14 @@ async def update_document(tmp_dict) -> DocumentDTO: async def stop_document_parse_task(doc_id: uuid.UUID) -> TaskDTO: try: - await DocumentManager.update(doc_id, { - 'status': DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING}) - task_entity = await TaskManager.select_by_op_id(doc_id) - if task_entity is None or task_entity.status == TaskConstant.TASK_STATUS_DELETED \ - or task_entity.status == TaskConstant.TASK_STATUS_FAILED \ - or task_entity.status == TaskConstant.TASK_STATUS_CANCELED \ - or task_entity.status == TaskConstant.TASK_STATUS_SUCCESS: + doc_entity = await DocumentManager.select_by_id(doc_id) + if doc_entity.status == DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING: return + task_entity = await TaskManager.select_by_op_id(doc_id) task_id = task_entity.id - if task_entity.status == TaskConstant.TASK_STATUS_RUNNING: - await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL) - TaskRedisHandler.remove_task_by_task_id(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) - TaskRedisHandler.remove_task_by_task_id(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_id)) - await TaskManager.update(task_id, {'status': TaskConstant.TASK_STATUS_CANCELED}) + await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL) except Exception as e: - logging.error("Stop knowledge base task={} error: {}".format(task_id, e)) - raise KnowledgeBaseException(f"Stop knowledge base task={task_id} error.") + logging.error("Stop docuemnt parse task={} error: {}".format(task_id, e)) async def init_document_parse_task(doc_id): @@ -145,10 +141,8 @@ async def run_document(update_dict) -> DocumentDTO: updated_document_entity = await DocumentManager.select_by_id(doc_id) document_type_entity = await DocumentTypeManager.select_by_id(updated_document_entity.type_id) return DocumentConvertor.convert_entity_to_dto(updated_document_entity, document_type_entity) - except DocumentException as e: - raise e - 
except Exception: - logging.error("Embedding document ({}) error: {}".format(update_dict['run'], traceback.format_exc())) + except Exception as e: + logging.error("Embedding document ({}) error: {}".format(update_dict['run'], e)) raise DocumentException(f"Embedding document ({update_dict['run']}) error.") @@ -159,8 +153,8 @@ async def switch_document(document_id, enabled) -> DocumentDTO: updated_document_entity = await DocumentManager.select_by_id(document_id) document_type_entity = await DocumentTypeManager.select_by_id(updated_document_entity.type_id) return DocumentConvertor.convert_entity_to_dto(updated_document_entity, document_type_entity) - except Exception: - logging.error("Switch document status ({}) error: {}".format(enabled, traceback.format_exc())) + except Exception as e: + logging.error("Switch document status ({}) error: {}".format(enabled, e)) raise DocumentException(f"Switch document status ({enabled}) error.") @@ -186,14 +180,70 @@ async def delete_document(ids: List[uuid.UUID]) -> int: 'document_size': total_sz} await KnowledgeBaseManager.update(knowledge_base_entity.id, update_dict) return deleted_cnt - except Exception: - logging.error("Delete document ({}) error: {}".format(ids, traceback.format_exc())) + except Exception as e: + logging.error(f"Delete document ({ids}) error: {e}") raise DocumentException(f"Delete document ({ids}) error.") async def submit_upload_document_task( user_id: uuid.UUID, kb_id: uuid.UUID, files: List[UploadFile] = File(...)) -> bool: - return await DocumentTaskHandler.submit_upload_document_task(user_id, kb_id, files) + target_dir = None + try: + # 创建目标目录 + file_upload_successfully_list = [] + target_dir = os.path.join(OssConstant.UPLOAD_DOCUMENT_SAVE_FOLDER, str(user_id), secrets.token_hex(16)) + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir) + for file in files: + try: + # 1. 将文件写入本地stash目录 + document_file_path = await DocumentTaskHandler.save_document_file_to_local(target_dir, file) + except Exception as e: + logging.error(f"save_document_file_to_local error: {e}") + continue + kb_entity = await KnowledgeBaseManager.select_by_id(kb_id) + # 2. 更新document记录 + file_name = file.filename + if await DocumentManager.select_by_knowledge_base_id_and_file_name(kb_entity.id, file_name): + name = os.path.splitext(file_name)[0] + extension = os.path.splitext(file_name)[1] + file_name = name[:128]+'_'+secrets.token_hex(16)+extension + document_entity = await DocumentManager.insert( + DocumentEntity( + kb_id=kb_id, user_id=user_id, name=file_name, + extension=os.path.splitext(file.filename)[1], + size=os.path.getsize(document_file_path), + parser_method=kb_entity.default_parser_method, + type_id='00000000-0000-0000-0000-000000000000', + chunk_size=kb_entity.default_chunk_size, + enabled=True, + status=DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING) + ) + if not await MinIO.put_object(OssConstant.MINIO_BUCKET_DOCUMENT, str(document_entity.id), document_file_path): + logging.error(f"上传文件到minIO失败,文件名:{file.filename}") + await DocumentManager.delete_by_id(document_entity.id) + continue + # 3. 创建task表记录 + task_entity = await TaskManager.insert(TaskEntity(user_id=user_id, + op_id=document_entity.id, + type=TaskConstant.PARSE_DOCUMENT, + retry=0, + status=TaskConstant.TASK_STATUS_PENDING)) + # 4. 
提交redis任务队列
+            TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_entity.id))
+            file_upload_successfully_list.append(file.filename)
+        # 5.更新kb的文档数和文档总大小
+        total_cnt, total_sz = await DocumentManager.select_cnt_and_sz_by_kb_id(kb_id)
+        update_dict = {'document_number': total_cnt,
+                       'document_size': total_sz}
+        await KnowledgeBaseManager.update(kb_id, update_dict)
+    except Exception as e:
+        raise e
+    finally:
+        if target_dir is not None and os.path.exists(target_dir):
+            shutil.rmtree(target_dir)
+    return file_upload_successfully_list
 async def generate_document_download_link(id) -> List[Dict]:
@@ -204,6 +254,172 @@ async def get_file_name_and_extension(document_id):
     try:
         document_entity = await DocumentManager.select_by_id(document_id)
         return document_entity.name, document_entity.extension.replace('.', '')
-    except Exception:
-        logging.error("Get ({}) file name and extension error: {}".format(document_id, traceback.format_exc()))
+    except Exception as e:
+        logging.error("Get ({}) file name and extension error: {}".format(document_id, e))
         raise DocumentException(f"Get ({document_id}) file name and extension error.")
+
+
+async def init_temporary_document_parse_task(
+        temporary_document_list: List[Dict]) -> List[uuid.UUID]:
+    try:
+        results = []
+        ids = []
+        for temporary_document in temporary_document_list:
+            ids.append(temporary_document['id'])
+        tmp_dict = await get_temporary_document_parse_status(ids)
+        doc_status_dict = {}
+        for i in range(len(tmp_dict)):
+            doc_status_dict[tmp_dict[i]['id']] = tmp_dict[i]['status']
+        for temporary_document in temporary_document_list:
+            if temporary_document['id'] in doc_status_dict and \
+                (
+                    doc_status_dict[temporary_document['id']] == TaskConstant.TASK_STATUS_PENDING or \
+                    doc_status_dict[temporary_document['id']] == TaskConstant.TASK_STATUS_RUNNING
+                ):
+                continue
+            temporary_entity = await TemporaryDocumentManager.select_by_id(temporary_document['id'])
+            if temporary_entity is None:
+                temporary_entity = await TemporaryDocumentManager.insert(
+                    TemporaryDocumentEntity(
+                        id=temporary_document['id'],
+                        name=temporary_document['name'],
+                        extension=temporary_document['type'],
+                        bucket_name=temporary_document['bucket_name'],
+                        parser_method=temporary_document['parser_method'],
+                        chunk_size=temporary_document['chunk_size'],
+                        status=TemporaryDocumentStatusEnum.EXIST
+                    )
+                )
+            else:
+                temporary_document['extension'] = temporary_document['type']
+                del temporary_document['type']
+                temporary_document['status'] = TemporaryDocumentStatusEnum.EXIST
+                temporary_entity = await TemporaryDocumentManager.update(
+                    temporary_document['id'],
+                    temporary_document
+                )
+            if temporary_entity is None:
+                continue
+            task_entity = await TaskManager.insert(
+                TaskEntity(
+                    op_id=temporary_document['id'],
+                    type=TaskConstant.PARSE_TEMPORARY_DOCUMENT,
+                    retry=0,
+                    status=TaskConstant.TASK_STATUS_PENDING
+                )
+            )
+            TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_entity.id))
+            results.append(temporary_document['id'])
+        return results
+    except Exception as e:
+        raise DocumentException("Init temporary document parse task={} error: {}".format(temporary_document_list, e))
+
+
+async def get_related_document(
+        content: str,
+        top_k: int,
+        temporary_document_ids: List[uuid.UUID] = None,
+        kb_id: uuid.UUID = None
+        ) -> List[uuid.UUID]:
+    if top_k == 0:
+        return []
+    results = []
+    try:
+        if temporary_document_ids:
+            chunk_tuple_list = await TemporaryChunkManager.find_top_k_similar_chunks(
+                temporary_document_ids, content, max(top_k // 2, 1))
+        elif kb_id:
+            chunk_tuple_list = await ChunkManager.find_top_k_similar_chunks(kb_id, content, max(top_k//2, 1))
+        else:
+            return []
+        for chunk_tuple in chunk_tuple_list:
+            results.append(chunk_tuple[0])
+    except Exception as e:
+        logging.error(f"Failed to find similar chunks by keywords due to: {e}")
+        return []
+    try:
+        target_vector = await Vectorize.vectorize_embedding(content)
+        if target_vector is not None:
+            chunk_entity_list = []
+            if temporary_document_ids:
+                chunk_id_list = await TemporaryVectorItemsManager.find_top_k_similar_temporary_vectors(target_vector, temporary_document_ids, top_k-len(chunk_tuple_list))
+                chunk_entity_list = await TemporaryChunkManager.select_by_temporary_chunk_ids(chunk_id_list)
+            elif kb_id:
+                kb_entity = await KnowledgeBaseManager.select_by_id(kb_id)
+                if kb_entity is None:
+                    return []
+                embedding_model = kb_entity.embedding_model
+                vector_items_id = kb_entity.vector_items_id
+                dim = embedding_model_out_dimensions[embedding_model]
+                vector_items_table = await PostgresDB.get_dynamic_vector_items_table(vector_items_id, dim)
+                chunk_id_list = await VectorItemsManager.find_top_k_similar_vectors(vector_items_table, target_vector, kb_id, top_k-len(chunk_tuple_list))
+                chunk_entity_list = await ChunkManager.select_by_chunk_ids(chunk_id_list)
+            for chunk_entity in chunk_entity_list:
+                results.append(chunk_entity.id)
+    except Exception as e:
+        logging.error(f"Failed to find similar chunks by vector due to: {e}")
+    return results
+
+
+async def stop_temporary_document_parse_task(doc_id):
+    try:
+        task_entity = await TaskManager.select_by_op_id(doc_id)
+        task_id = task_entity.id
+        await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL)
+    except Exception as e:
+        logging.error("Stop temporary document parse task={} error: {}".format(doc_id, e))
+
+
+async def delete_temporary_document(doc_ids) -> List[TemporaryDocumentDTO]:
+    if len(doc_ids) == 0:
+        return []
+    try:
+        # 删除document表的记录
+        for doc_id in doc_ids:
+            await stop_temporary_document_parse_task(doc_id)
+        tmp_list = await TemporaryDocumentManager.select_by_ids(doc_ids)
+        doc_ids = []
+        for tmp in tmp_list:
+            doc_ids.append(tmp.id)
+        await TemporaryDocumentManager.update_all(doc_ids, {"status": TemporaryDocumentStatusEnum.DELETED})
+        tmp_list = await TemporaryDocumentManager.select_by_ids(doc_ids)
+        tmp_set = set()
+        for tmp in tmp_list:
+            tmp_set.add(tmp.id)
+        results = []
+        for doc_id in doc_ids:
+            if doc_id not in tmp_set:
+                results.append(doc_id)
+        return results
+    except Exception as e:
+        logging.error("Delete temporary document ({}) error: {}".format(doc_ids, e))
+        raise DocumentException(f"Delete temporary document ({doc_ids}) error.")
+
+
+async def get_temporary_document_parse_status(doc_ids) -> List[TemporaryDocumentDTO]:
+    try:
+        results = []
+        temporary_document_list = await TemporaryDocumentManager.select_by_ids(doc_ids)
+        doc_ids = []
+        for temporary_document in temporary_document_list:
+            doc_ids.append(temporary_document.id)
+        task_entity_list = await TaskManager.select_latest_task_by_op_ids(doc_ids)
+        task_entity_dict = {}
+        for task_entity in task_entity_list:
+            task_entity_dict[task_entity.op_id] = task_entity
+        for i in range(len(temporary_document_list)):
+            task_entity = task_entity_dict.get(temporary_document_list[i].id, None)
+            if task_entity is None:
+                task_status = TaskConstant.TASK_STATUS_FAILED
+            else:
+                task_status = task_entity.status
+            results.append(
+                TemporaryDocumentDTO(
+                    id=temporary_document_list[i].id,
+                    status=task_status
+                )
+            )
+        return results
+    except Exception as e:
+        logging.error(f"Get temporary documents ({doc_ids}) parse status error due to: {e}")
+        return []
diff --git a/data_chain/apps/service/embedding_service.py b/data_chain/apps/service/embedding_service.py
index df9510f..5edd909 100644
--- a/data_chain/apps/service/embedding_service.py
+++ b/data_chain/apps/service/embedding_service.py
@@ -14,8 +14,6 @@ class Vectorize():
             "texts": [text]
         }
         try:
-            logging.info(f'向量化内容为:{data}')
-            logging.info(f'向量化请求地址为:{config["REMOTE_EMBEDDING_ENDPOINT"]}')
             res = requests.post(url=config["REMOTE_EMBEDDING_ENDPOINT"], json=data, verify=False)
             if res.status_code != 200:
                 return None
diff --git a/data_chain/apps/service/knwoledge_base_service.py b/data_chain/apps/service/knwoledge_base_service.py
index 6fe67ed..5eb603a 100644
--- a/data_chain/apps/service/knwoledge_base_service.py
+++ b/data_chain/apps/service/knwoledge_base_service.py
@@ -52,9 +52,9 @@ async def list_knowledge_base_task(page_number, page_size, params) -> List[TaskD
             document_type_entity_list = await DocumentTypeManager.select_by_knowledge_base_id(str(knowledge_base_entity.id))
             knowledge_base_dto = KnowledgeConvertor.convert_entity_to_dto(
                 knowledge_base_entity, document_type_entity_list)
-            latest_task_status_report_entity = await TaskStatusReportManager.select_latest_report_by_task_id(task_entity.id)
-            if latest_task_status_report_entity is not None:
-                task_dto = TaskConvertor.convert_entity_to_dto(task_entity, [latest_task_status_report_entity])
+            latest_task_status_report_entity_list = await TaskStatusReportManager.select_latest_report_by_task_ids([task_entity.id])
+            if len(latest_task_status_report_entity_list) > 0:
+                task_dto = TaskConvertor.convert_entity_to_dto(task_entity, latest_task_status_report_entity_list)
             else:
                 task_dto = TaskConvertor.convert_entity_to_dto(task_entity, [])
             knowledge_base_dto.task = task_dto
@@ -91,21 +91,8 @@ async def rm_knowledge_base_task(task_id: uuid.UUID) -> TaskDTO:
 async def stop_knowledge_base_task(task_id: uuid.UUID) -> TaskDTO:
     try:
         task_entity = await TaskManager.select_by_id(task_id)
-        if task_entity is None or task_entity.status == TaskConstant.TASK_STATUS_DELETED \
-                or task_entity.status == TaskConstant.TASK_STATUS_FAILED \
-                or task_entity.status == TaskConstant.TASK_STATUS_CANCELED \
-                or task_entity.status == TaskConstant.TASK_STATUS_SUCCESS:
-            return
         task_id = task_entity.id
-        TaskRedisHandler.remove_task_by_task_id(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id))
-        TaskRedisHandler.remove_task_by_task_id(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_id))
-        TaskRedisHandler.remove_task_by_task_id(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_id))
-        if task_entity.status == TaskConstant.TASK_STATUS_RUNNING:
-            await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL)
-        if task_entity.type == TaskConstant.IMPORT_KNOWLEDGE_BASE:
-            await KnowledgeBaseManager.delete(task_entity.op_id)
-        else:
-            await KnowledgeBaseManager.update(task_entity.op_id, {'status': KnowledgeStatusEnum.IDLE})
+        await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL)
     except Exception as e:
         logging.error("Stop knowledge base task={} error={}".format(task_id, e))
         raise KnowledgeBaseException(f"Stop knowledge base task={task_id} error.")
@@ -129,10 +116,13 @@ async def create_knowledge_base(tmp_dict) -> KnowledgeBaseDTO:
         raise KnowledgeBaseException("Create knowledge base error.")
 
 
-async def update_knowledge_base(update_dict) -> KnowledgeBaseDTO:
+async def update_knowledge_base(update_dict: dict) -> KnowledgeBaseDTO:
     kb_id =
update_dict['id'] - document_type_list = update_dict['document_type_list'] - del update_dict['document_type_list'] + document_type_list = update_dict.get('document_type_list',None) + if document_type_list is not None: + del update_dict['document_type_list'] + else: + document_type_list=[] knowledge_base_entity = await KnowledgeBaseManager.select_by_user_id_and_kb_name( update_dict['user_id'], update_dict['name']) if knowledge_base_entity and knowledge_base_entity.id != kb_id: @@ -152,8 +142,12 @@ async def list_knowledge_base(params, page_number=1, page_size=1) -> Tuple[List[ knowledge_base_dto_list = [] for knowledge_base_entity in knowledge_base_entity_list: document_type_entity_list = await DocumentTypeManager.select_by_knowledge_base_id(knowledge_base_entity.id) - knowledge_base_dto_list.append(KnowledgeConvertor.convert_entity_to_dto( - knowledge_base_entity, document_type_entity_list)) + knowledge_base_dto_list.append( + KnowledgeConvertor.convert_entity_to_dto( + knowledge_base_entity, + document_type_entity_list + ) + ) return (knowledge_base_dto_list, total) except Exception as e: logging.error("List knowledge base error: {}".format(e)) @@ -181,7 +175,7 @@ async def rm_knowledge_base(kb_id: str) -> bool: await KnowledgeBaseManager.delete(kb_id) return True except Exception as e: - logging.error("Delete knowledge base error: {}".format(e)) + logging.error(f"Delete knowledge base error: {e}") raise KnowledgeBaseException("Delete knowledge base error.") @@ -297,7 +291,7 @@ async def generate_knowledge_base_download_link(task_id) -> str: task_entity = await TaskManager.select_by_id(task_id) if task_entity.status != TaskConstant.TASK_STATUS_SUCCESS: return "" - return await MinIO.generate_download_link(OssConstant.MINIO_BUCKET_EXPORTZIP, str(task_id)) + return await MinIO.generate_download_link(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(task_id)) except Exception as e: logging.error("Export knowledge base zip files error: {}".format(e)) raise KnowledgeBaseException("Export knowledge base zip files error.") diff --git a/data_chain/apps/service/task_service.py b/data_chain/apps/service/task_service.py index 55c1ad4..05c068c 100644 --- a/data_chain/apps/service/task_service.py +++ b/data_chain/apps/service/task_service.py @@ -13,7 +13,6 @@ from data_chain.exceptions.exception import TaskException from data_chain.config.config import config - async def _validate_task_belong_to_user(user_id: uuid.UUID, task_id: str) -> bool: task_entity = await TaskManager.select_by_id(task_id) if task_entity is None: @@ -22,23 +21,46 @@ async def _validate_task_belong_to_user(user_id: uuid.UUID, task_id: str) -> boo raise TaskException("Task not belong to user") -async def redis_task(): - task_start_cnt = 0 +async def task_queue_handler(): + # 处理成功的任务 success_task_ids = TaskRedisHandler.select_all_task(config['REDIS_SUCCESS_TASK_QUEUE_NAME']) for task_id in success_task_ids: TaskRedisHandler.remove_task_by_task_id(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], task_id) TaskHandler.remove_task(uuid.UUID(task_id)) + task_entity = await TaskManager.select_by_id(uuid.UUID(task_id)) + if task_entity is None or task_entity.status != TaskConstant.TASK_STATUS_SUCCESS: + TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) + # 处理需要重新开始的任务 restart_task_ids = TaskRedisHandler.select_all_task(config['REDIS_RESTART_TASK_QUEUE_NAME']) for task_id in restart_task_ids: TaskRedisHandler.remove_task_by_task_id(config['REDIS_RESTART_TASK_QUEUE_NAME'], task_id) - await 
TaskHandler.restart_or_clear_task(uuid.UUID(task_id)) + task_entity = await TaskManager.select_by_id(uuid.UUID(task_id)) + if task_entity is None or task_entity.status != TaskConstant.TASK_STATUS_RUNNING: + TaskHandler.remove_task(uuid.UUID(task_id)) + TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) + else: + await TaskHandler.restart_or_clear_task(uuid.UUID(task_id)) + # 静默处理状态错误的任务 + silent_error_task_ids = TaskRedisHandler.select_all_task(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME']) + for task_id in silent_error_task_ids: + TaskRedisHandler.remove_task_by_task_id(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) + task_entity = await TaskManager.select_by_id(task_id) + if task_entity is None: + continue + elif task_entity.status == TaskConstant.TASK_STATUS_DELETED: + await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.DELETE) + elif task_entity.status == TaskConstant.TASK_STATUS_CANCELED: + await TaskHandler.restart_or_clear_task(task_id, TaskActionEnum.CANCEL) + # 处理等待的任务 + task_start_cnt = 0 while True: task_id = TaskRedisHandler.get_task_by_head(config['REDIS_PENDING_TASK_QUEUE_NAME']) if task_id is None: break task_id = uuid.UUID(task_id) task_entity = await TaskManager.select_by_id(task_id) - if task_entity is None: + if task_entity is None or task_entity.status != TaskConstant.TASK_STATUS_PENDING: + TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], task_id) continue try: func = None @@ -48,15 +70,18 @@ async def redis_task(): # 处理打包资产库任务 elif task_entity.type == TaskConstant.EXPORT_KNOWLEDGE_BASE: func = KnowledgeBaseTaskHandler.handle_export_knowledge_base_task - # 处理上传文档任务 + # 处理临时文件解析任务 + elif task_entity.type == TaskConstant.PARSE_TEMPORARY_DOCUMENT: + func = DocumentTaskHandler.handle_parser_temporary_document_task + # 处理文档解析任务 elif task_entity.type == TaskConstant.PARSE_DOCUMENT: func = DocumentTaskHandler.handle_parser_document_task - if not TaskHandler.add_task(task_id, target=func, task_entity=task_entity): + if not TaskHandler.add_task(task_id, target=func, t_id=task_entity.id): TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) break await TaskManager.update(task_id, {'status': TaskConstant.TASK_STATUS_RUNNING}) - except Exception: - logging.error("Handle redis task error: {}".format(traceback.format_exc())) + except Exception as e: + logging.error(f"Handle task {task_id} error: {e}") await TaskHandler.restart_or_clear_task(task_id) break task_start_cnt += 1 @@ -68,9 +93,15 @@ async def monitor_tasks(): task_id_list = TaskHandler.list_tasks() for task_id in task_id_list: task_entity = await TaskManager.select_by_id(task_id) + if task_entity is None: + continue current_time_utc = datetime.now(pytz.utc) time_difference = current_time_utc - task_entity.created_time seconds = time_difference.total_seconds() - if seconds > 12*3600 : - if task_entity.status == TaskConstant.TASK_STATUS_RUNNING: + if seconds > 12*3600: + if task_entity.status==TaskConstant.TASK_STATUS_DELETED: + await TaskHandler.restart_or_clear_task(task_id,TaskActionEnum.DELETE) + elif task_entity.status==TaskConstant.TASK_STATUS_CANCELED: + await TaskHandler.restart_or_clear_task(task_id,TaskActionEnum.CANCEL) + else: await TaskHandler.restart_or_clear_task(task_id) diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml index 2b34aeb..62a88a6 100644 --- a/data_chain/common/prompt.yaml +++ b/data_chain/common/prompt.yaml @@ -35,7 +35,8 @@ LLM_PROMPT_TEMPLATE: 
"你是由openEuler社区构建的大型语言AI助手。 3.请使用markdown格式输出回答。 4.仅输出回答即可,不要输出其他无关内容。 5.若非必要,请用中文回答。 -6.对于无法使用你认知中以及背景信息进行回答的问题,请回答“您好,换个问题试试,您这个问题难住我了”。 +6.对于背景信息缺失的内容(命令、文件路径、文件名和和后缀之间的分隔符),请补全再输出回答 +7.对于无法使用你认知中以及背景信息进行回答的问题,请回答“您好,换个问题试试,您这个问题难住我了”。 下面是一组背景信息: diff --git a/data_chain/config/config.py b/data_chain/config/config.py index 0f0f883..28d574c 100644 --- a/data_chain/config/config.py +++ b/data_chain/config/config.py @@ -16,7 +16,6 @@ class ConfigModel(BaseModel): LOG_METHOD:str = Field('stdout', description="日志记录方式") # Postgres DATABASE_URL: str = Field(None, description="Postgres数据库链接url") - # MinIO MINIO_ENDPOINT: str = Field(None, description="MinIO连接地址") MINIO_ACCESS_KEY: str = Field(None, description="Minio认证ak") @@ -29,23 +28,20 @@ class ConfigModel(BaseModel): REDIS_PENDING_TASK_QUEUE_NAME: str = Field(default='rag_pending_task_queue', description="redis等待开始任务队列名称") REDIS_SUCCESS_TASK_QUEUE_NAME: str = Field(default='rag_success_task_queue', description="redis已经完成任务队列名称") REDIS_RESTART_TASK_QUEUE_NAME: str = Field(default='rag_restart_task_queue', description="redis等待重启任务队列名称") + REDIS_SILENT_ERROR_TASK_QUEUE_NAME: str = Field(default='rag_silent_error_task_queue', description="redis等待重启任务队列名称") # Task TASK_RETRY_TIME: int = Field(None, description="任务重试次数") # Embedding REMOTE_EMBEDDING_ENDPOINT: str = Field(None, description="远程embedding服务url地址") - # Token SESSION_TTL: int = Field(None, description="用户session过期时间") CSRF_KEY: str = Field(None, description="csrf的密钥") - # Security HALF_KEY1: str = Field(None, description="两层密钥管理组件1") HALF_KEY2: str = Field(None, description="两层密钥管理组件2") HALF_KEY3: str = Field(None, description="两层密钥管理组件3") - # Prompt file PROMPT_PATH: str = Field(None, description="prompt路径") - # Stop Words PATH STOP_WORDS_PATH: str = Field(None, description="停用词表存放位置") # LLM config @@ -55,11 +51,13 @@ class ConfigModel(BaseModel): REQUEST_TIMEOUT: int = Field(None, description="大模型请求超时时间") MAX_TOKENS: int = Field(None, description="单次请求中允许的最大Token数") MODEL_ENH: bool = Field(None, description="是否使用大模型能力增强") - #DEFAULT USER + # DEFAULT USER DEFAULT_USER_ACCOUNT: str = Field(default='admin', description="默认用户账号") DEFAULT_USER_PASSWD: str = Field(default='123456', description="默认用户密码") DEFAULT_USER_NAME: str = Field(default='admin', description="默认用户名称") DEFAULT_USER_LANGUAGE: str = Field(default='zh', description="默认用户语言") + # DOCUMENT PARSER + DOCUMENT_PARSE_USE_CPU_LIMIT:int=Field(default=4,description="文档解析器使用CPU核数") class Config: config: ConfigModel diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index c289e26..65dc980 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -4,20 +4,30 @@ from typing import List, Tuple, Dict, Optional import uuid from data_chain.logger.logger import logger as logging -from data_chain.stores.postgres.postgres import PostgresDB, ChunkEntity, ChunkLinkEntity, DocumentEntity +from data_chain.stores.postgres.postgres import PostgresDB, ChunkEntity, ChunkLinkEntity, DocumentEntity, TemporaryChunkEntity from data_chain.models.service import ChunkDTO from data_chain.exceptions.exception import ChunkException - class ChunkManager(): @staticmethod - async def insert_chunk(chunk_entity: ChunkEntity) -> uuid.UUID: + async def insert_chunk(chunk_entity: ChunkEntity) -> ChunkEntity: async with await PostgresDB.get_session() as session: session.add(chunk_entity) await session.commit() - return chunk_entity.id - + await session.refresh(chunk_entity) + return chunk_entity + 
@staticmethod
+    async def insert_chunks(chunk_entity_list: List[ChunkEntity]) -> List[ChunkEntity]:
+        async with await PostgresDB.get_session() as session:
+            try:
+                session.add_all(chunk_entity_list)
+                await session.commit()
+                for chunk_entity in chunk_entity_list:
+                    await session.refresh(chunk_entity)
+                return chunk_entity_list
+            except Exception as e:
+                logging.error(f'Error saving chunk entities due to: {e}')
     @staticmethod
     async def select_by_chunk_id(chunk_id: uuid.UUID) -> Optional[ChunkEntity]:
         async with await PostgresDB.get_session() as session:
@@ -27,7 +37,7 @@ class ChunkManager():
             return chunk_entity
 
     @staticmethod
-    async def select_by_chunk_ids(chunk_ids: List[uuid.UUID]) -> Optional[ChunkEntity]:
+    async def select_by_chunk_ids(chunk_ids: List[uuid.UUID]) -> Optional[List[ChunkEntity]]:
         async with await PostgresDB.get_session() as session:
             stmt = select(ChunkEntity).where(ChunkEntity.id.in_(chunk_ids))
             result = await session.execute(stmt)
@@ -146,16 +156,8 @@ class ChunkManager():
                 stmt = stmt.order_by(ChunkEntity.global_offset).offset((page_number - 1) * page_size).limit(page_size)
 
                 results = await session.execute(stmt)
-                chunk_list = results.scalars().all()
-                return (
-                    [ChunkDTO(
-                        id=str(chunk_entity.id),
-                        text=chunk_entity.text,
-                        enabled=chunk_entity.enabled,
-                        type=chunk_entity.type.split('.')[1]
-                    ) for chunk_entity in chunk_list],
-                    total
-                )
+                chunk_entity_list = results.scalars().all()
+                return (chunk_entity_list, total)
         except Exception as e:
             logging.error(f"Select by page error: {e}")
             raise ChunkException(f"Select by page ({params}) error.")
@@ -175,10 +177,11 @@ class ChunkManager():
         except Exception as e:
             logging.error(f"Update chunk status ({update_dict}) error: {e}")
             raise ChunkException(f"Update chunk status ({update_dict}) error.")
+
     @staticmethod
-    async def find_top_k_similar_chunks(kb_id, content,topk=3, banned_ids=[]):
+    async def find_top_k_similar_chunks(kb_id, content, topk=3, banned_ids=[]):
         try:
-            if topk<=0:
+            if topk <= 0:
                 return []
             async with await PostgresDB.get_session() as session:
                 # 构建SQL查询语句
@@ -195,7 +198,7 @@ class ChunkManager():
                     JOIN
                        document d ON c.document_id = d.id
                     WHERE
-                        c.id NOT IN :banned_ids AND
+                        c.id != ALL(:banned_ids) AND
                         c.kb_id = :kb_id AND
                         c.enabled = true AND
                         d.enabled = true AND
@@ -243,8 +246,186 @@ class ChunkManager():
 
 class ChunkLinkManager():
     @staticmethod
-    async def insert_chunk_link(chunk_link_entity: ChunkLinkEntity) -> uuid.UUID:
+    async def insert_chunk_link(chunk_link_entity: ChunkLinkEntity) -> ChunkLinkEntity:
         async with await PostgresDB.get_session() as session:
             session.add(chunk_link_entity)
             await session.commit()
-            return chunk_link_entity.id
+            await session.refresh(chunk_link_entity)
+            return chunk_link_entity
+
+    @staticmethod
+    async def insert_chunk_links(chunk_link_entity_list: List[ChunkLinkEntity]) -> List[ChunkLinkEntity]:
+        async with await PostgresDB.get_session() as session:
+            try:
+                session.add_all(chunk_link_entity_list)
+                await session.commit()
+                for chunk_link_entity in chunk_link_entity_list:
+                    await session.refresh(chunk_link_entity)
+                return chunk_link_entity_list
+
+            except Exception as e:
+                logging.error(f'Error saving chunk link entities due to: {e}')
+class TemporaryChunkManager():
+    @staticmethod
+    async def insert_temprorary_chunk(temprorary_chunk_entity: TemporaryChunkEntity) -> TemporaryChunkEntity:
+        async with await PostgresDB.get_session() as session:
+            session.add(temprorary_chunk_entity)
+            await session.commit()
+            await session.refresh(temprorary_chunk_entity)
+            return temprorary_chunk_entity
+    @staticmethod
+    async def insert_temprorary_chunks(temprorary_chunk_entity_list: List[TemporaryChunkEntity]) -> List[TemporaryChunkEntity]:
+        async with await PostgresDB.get_session() as session:
+            session.add_all(temprorary_chunk_entity_list)
+            await session.commit()
+            for temprorary_chunk_entity in temprorary_chunk_entity_list:
+                await session.refresh(temprorary_chunk_entity)
+            return temprorary_chunk_entity_list
+    @staticmethod
+    async def delete_by_temporary_document_ids(document_ids: List[str]) -> None:
+        async with await PostgresDB.get_session() as session:
+            stmt = await session.execute(
+                select(TemporaryChunkEntity).where(TemporaryChunkEntity.document_id.in_(document_ids))
+            )
+            entities = stmt.scalars().all()
+            for entity in entities:
+                await session.delete(entity)
+            await session.commit()
+
+    @staticmethod
+    async def select_by_temporary_chunk_ids(temporary_chunk_ids: List[uuid.UUID]) -> Optional[List[TemporaryChunkEntity]]:
+        async with await PostgresDB.get_session() as session:
+            stmt = select(TemporaryChunkEntity).where(TemporaryChunkEntity.id.in_(temporary_chunk_ids))
+            result = await session.execute(stmt)
+            temporary_chunk_entity_list = result.scalars().all()
+            return temporary_chunk_entity_list
+
+    @staticmethod
+    async def find_top_k_similar_chunks(document_ids: List[uuid.UUID], content: str, topk=3):
+        try:
+            if topk <= 0:
+                return []
+            async with await PostgresDB.get_session() as session:
+                # 构建SQL查询语句
+                if document_ids:
+                    query = text("""
+                    SELECT
+                        c.id,
+                        c.document_id,
+                        c.global_offset,
+                        c.tokens,
+                        c.text
+                    FROM
+                        temporary_chunk c
+                    JOIN
+                        temporary_document d ON c.document_id = d.id
+                    WHERE
+                        c.document_id=ANY(:document_ids) AND
+                        d.status!='deleted' AND
+                        to_tsvector(:language, c.text) @@ plainto_tsquery(:language, :content)
+                    ORDER BY
+                        ts_rank_cd(to_tsvector(:language, c.text), plainto_tsquery(:language, :content)) DESC
+                    LIMIT :topk;
+                    """)
+                else:
+                    return []
+                # 安全地绑定参数
+                params = {
+                    'document_ids': document_ids,
+                    'language': 'zhparser',
+                    'content': content,
+                    'topk': topk,
+                }
+                result = await session.execute(query, params)
+                return result.all()
+        except Exception as e:
+            logging.error(f"Find top k similar temporary chunks failed due to: {e}")
+            return []
+
+    @staticmethod
+    async def fetch_surrounding_temporary_context(
+            document_id: uuid.UUID, global_offset: int, expand_method='all', max_tokens: int = 1024, max_rd_cnt: int = 50):
+        try:
+            if max_tokens <= 0:
+                return []
+            results = []
+            async with await PostgresDB.get_session() as session:
+                tokens = 0
+                para_cnt = 0
+                global_offset_set = set([global_offset])
+                result = await session.execute(
+                    select(func.min(TemporaryChunkEntity.global_offset), func.max(TemporaryChunkEntity.global_offset)).
+                    where(
+                        TemporaryChunkEntity.document_id == document_id
+                    )
+                )
+                min_global_offset, max_global_offset = result.one()
+                if expand_method == 'nex':
+                    min_global_offset = global_offset
+                if expand_method == 'pre':
+                    max_global_offset = global_offset
+                tokens_sub = 0
+                mv_flag = None
+                rd_it = 0
+                temporary_chunk_entity_list = (
+                    await session.execute(
+                        select(TemporaryChunkEntity).
+ where( + and_( + TemporaryChunkEntity.document_id == document_id, + TemporaryChunkEntity.global_offset>=global_offset-max_rd_cnt, + TemporaryChunkEntity.global_offset<=global_offset+max_rd_cnt + ) + ) + ) + ).scalars().all() + global_offset_set_dict={} + for temporary_chunk_entity in temporary_chunk_entity_list: + global_offset_set_dict[temporary_chunk_entity.global_offset]=( + temporary_chunk_entity.id, + temporary_chunk_entity.document_id, + temporary_chunk_entity.global_offset, + temporary_chunk_entity.tokens, + temporary_chunk_entity.text) + while tokens < max_tokens and (min(global_offset_set) > min_global_offset or max(global_offset_set) < + max_global_offset) and rd_it < max_rd_cnt: + result = None + new_global_offset = None + if tokens_sub <= 0 and min(global_offset_set) > min_global_offset: + mv_flag = True + new_global_offset = min(global_offset_set)-1 + elif tokens_sub > 0 and max(global_offset_set) < max_global_offset: + mv_flag = False + new_global_offset = max(global_offset_set)+1 + elif rd_it % 2 == 0: + if min(global_offset_set) > min_global_offset: + mv_flag = True + new_global_offset = min(global_offset_set)-1 + elif max(global_offset_set) < max_global_offset: + mv_flag = False + new_global_offset = max(global_offset_set)+1 + + elif rd_it % 2 != 0: + if max(global_offset_set) < max_global_offset: + mv_flag = False + new_global_offset = max(global_offset_set)+1 + elif min(global_offset_set) > min_global_offset: + mv_flag = True + new_global_offset = min(global_offset_set)-1 + else: + break + result = global_offset_set_dict.get(new_global_offset,None) + global_offset_set.add(new_global_offset) + if result: + tokens += result[3] + para_cnt += 1 + results.append(result) + if mv_flag: + tokens_sub += result[3] + else: + tokens_sub -= result[3] + rd_it += 1 + return results + except Exception as e: + logging.error(f"Fetch surrounding temporary context failed due to: {e}") + return [] diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 0b31085..7c4636b 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -5,8 +5,8 @@ import uuid from data_chain.logger.logger import logger as logging from typing import Dict, List, Tuple -from data_chain.stores.postgres.postgres import PostgresDB, DocumentEntity, TaskEntity -from data_chain.models.constant import DocumentEmbeddingConstant +from data_chain.stores.postgres.postgres import PostgresDB, DocumentEntity, TaskEntity,TemporaryDocumentEntity +from data_chain.models.constant import DocumentEmbeddingConstant,TemporaryDocumentStatusEnum @@ -19,7 +19,7 @@ class DocumentManager(): async with await PostgresDB.get_session() as session: session.add(entity) await session.commit() - await session.refresh(entity) # 刷新实体以确保获取到所有数据(如自增ID) + await session.refresh(entity) return entity except Exception as e: logging.error(f"Failed to insert entity: {e}") @@ -77,7 +77,8 @@ class DocumentManager(): try: async with await PostgresDB.get_session() as session: stmt = select(DocumentEntity).where( - (DocumentEntity.kb_id == kb_id) & (DocumentEntity.name == file_name)) + and_(DocumentEntity.kb_id == kb_id, + DocumentEntity.name == file_name)) result = await session.execute(stmt) return result.scalars().first() except Exception as e: @@ -216,3 +217,81 @@ class DocumentManager(): result = await session.execute(stmt) await session.commit() return result.rowcount + +class TemporaryDocumentManager(): + @staticmethod + async def insert(entity: TemporaryDocumentEntity) -> 
TemporaryDocumentEntity: + try: + async with await PostgresDB.get_session() as session: + session.add(entity) + await session.commit() + await session.refresh(entity) + return entity + except Exception as e: + logging.error(f"Failed to insert temporary document: {e}") + return None + @staticmethod + async def delete_by_ids(ids: List[uuid.UUID]) -> int: + try: + async with await PostgresDB.get_session() as session: + stmt = delete(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id.in_(ids)) + result = await session.execute(stmt) + await session.commit() + return result.rowcount + except Exception as e: + logging.error(f"Failed to delete temporary documents: {e}") + return 0 + @staticmethod + async def select_by_ids(ids: List[uuid.UUID]) -> List[TemporaryDocumentEntity]: + try: + async with await PostgresDB.get_session() as session: + stmt = select(TemporaryDocumentEntity).where( + and_( + TemporaryDocumentEntity.id.in_(ids), + TemporaryDocumentEntity.status !=TemporaryDocumentStatusEnum.DELETED + ) + ) + result = await session.execute(stmt) + tmp_list = result.scalars().all() + return tmp_list + except Exception as e: + logging.error(f"Failed to select temporary documents by page: {e}") + return (0, []) + @staticmethod + async def select_by_id(id: uuid.UUID) -> TemporaryDocumentEntity: + try: + async with await PostgresDB.get_session() as session: + stmt = select(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id==id) + result = await session.execute(stmt) + return result.scalars().first() + except Exception as e: + logging.error(f"Failed to select temporary document by id: {e}") + return None + @staticmethod + async def update(id: uuid.UUID, update_dict: dict): + try: + async with await PostgresDB.get_session() as session: + await session.execute( + update(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id == id).values(**update_dict) + ) + await session.commit() + entity=await session.execute(select(TemporaryDocumentEntity).where( + TemporaryDocumentEntity.id==id, + ) + ) + return entity + except Exception as e: + logging.error(f"Failed to update temporary document by id: {e}") + return None + @staticmethod + async def update_all(ids: List[uuid.UUID], update_dict: dict): + try: + async with await PostgresDB.get_session() as session: + await session.execute( + update(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id.in_(ids)).values(**update_dict) + ) + await session.commit() + return True + except Exception as e: + logging.error(f"Failed to update temporary document by ids: {e}") + return False \ No newline at end of file diff --git a/data_chain/manager/document_type_manager.py b/data_chain/manager/document_type_manager.py index 7bc5886..03e988a 100644 --- a/data_chain/manager/document_type_manager.py +++ b/data_chain/manager/document_type_manager.py @@ -33,6 +33,8 @@ class DocumentTypeManager(): @staticmethod async def insert_bulk(kb_id: uuid.UUID, types: List[str]) -> List[DocumentTypeEntity]: + if types is None or len(types) == 0: + return [] async with await PostgresDB.get_session()as session: document_type_entity_list = [ DocumentTypeEntity(kb_id=kb_id, type=type) for type in types diff --git a/data_chain/manager/image_manager.py b/data_chain/manager/image_manager.py index 38f5133..ed091aa 100644 --- a/data_chain/manager/image_manager.py +++ b/data_chain/manager/image_manager.py @@ -1,5 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
from sqlalchemy import select +from typing import List import uuid from data_chain.logger.logger import logger as logging @@ -11,17 +12,28 @@ from data_chain.stores.postgres.postgres import PostgresDB, ImageEntity class ImageManager: @staticmethod - async def add_image(image_slice: ImageEntity): + async def add_image(image_entity: ImageEntity): try: async with await PostgresDB.get_session() as session: - session.add(image_slice) + session.add(image_entity) await session.commit() - await session.refresh(image_slice) + await session.refresh(image_entity) except Exception as e: logging.error(f"Add image failed due to error: {e}") return None - return image_slice - + return image_entity + @staticmethod + async def add_images(image_entity_list: List[ImageEntity]): + try: + async with await PostgresDB.get_session() as session: + session.add_all(image_entity_list) + await session.commit() + for image_entity in image_entity_list: + await session.refresh(image_entity) + except Exception as e: + logging.error(f"Add image failed due to error: {e}") + return None + return image_entity_list @staticmethod async def del_image_by_image_id(image_id: uuid.UUID): try: diff --git a/data_chain/manager/task_manager.py b/data_chain/manager/task_manager.py index 8b3a41e..80060a5 100644 --- a/data_chain/manager/task_manager.py +++ b/data_chain/manager/task_manager.py @@ -160,7 +160,7 @@ class TaskManager(): result = await session.execute(stmt) return result.scalars().all() - + @staticmethod async def delete_by_op_id(op_id: uuid.UUID) -> TaskEntity: async with await PostgresDB.get_session() as session: @@ -265,16 +265,6 @@ class TaskStatusReportManager(): return result.scalars().all() @staticmethod - async def select_latest_report_by_task_id(task_id: uuid.UUID): - async with await PostgresDB.get_session() as session: - stmt = ( - select(TaskStatusReportEntity) - .where(TaskStatusReportEntity.task_id == task_id) - .order_by(desc(TaskStatusReportEntity.created_time)) - ).limit(1) - result = await session.execute(stmt) - return result.scalars().first() - @staticmethod async def select_latest_report_by_task_ids(task_ids: List[uuid.UUID]) -> List[TaskStatusReportEntity]: async with await PostgresDB.get_session() as session: # 创建一个别名用于子查询 @@ -292,7 +282,6 @@ class TaskStatusReportManager(): .where(report_alias.task_id.in_(task_ids)) .subquery() ) - # 主查询选择row_num为1的记录,即每个op_id的最新任务 stmt = ( select(TaskStatusReportEntity) @@ -301,4 +290,4 @@ class TaskStatusReportManager(): ) result = await session.execute(stmt) - return result.scalars().all() \ No newline at end of file + return result.scalars().all() diff --git a/data_chain/manager/vector_items_manager.py b/data_chain/manager/vector_items_manager.py index 368d976..22f4a4c 100644 --- a/data_chain/manager/vector_items_manager.py +++ b/data_chain/manager/vector_items_manager.py @@ -1,23 +1,25 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-from data_chain.logger.logger import logger as logging
+import uuid
+from typing import List
 from sqlalchemy import insert, delete, update, text
 import traceback
-from data_chain.stores.postgres.postgres import PostgresDB
+from data_chain.logger.logger import logger as logging
+from data_chain.stores.postgres.postgres import PostgresDB, TemporaryVectorItemstEntity
 class VectorItemsManager:
     @staticmethod
-    async def add(VectorItems, params):
+    async def add(VectorItems, vector):
         # 构建插入语句
         insert_stmt = insert(VectorItems).values(
-            user_id=params['user_id'],
-            chunk_id=params['chunk_id'],
-            kb_id=params['kb_id'],
-            document_id=params['doc_id'],
-            vector=params['vector'],
-            enabled=params['enabled']
+            user_id=vector['user_id'],
+            chunk_id=vector['chunk_id'],
+            kb_id=vector['kb_id'],
+            document_id=vector['doc_id'],
+            vector=vector['vector'],
+            enabled=vector['enabled']
         ).returning(VectorItems.c.id)
         # 获取会话
@@ -26,7 +28,31 @@ class VectorItemsManager:
             inserted_id = result.scalar()
             await session.commit()
             return inserted_id
+    @staticmethod
+    async def add_all(VectorItems, vector_list):
+        # 构建插入语句
+        insert_stmt = (
+            insert(VectorItems)
+            .values([
+                {
+                    "user_id": vector['user_id'],
+                    "chunk_id": vector['chunk_id'],
+                    "kb_id": vector['kb_id'],
+                    "document_id": vector['doc_id'],
+                    "vector": vector['vector'],
+                    "enabled": vector['enabled']
+                }
+                for vector in vector_list
+            ])
+            .returning(VectorItems.c.id)  # 假设VectorItems有id字段
+        )
+        # 获取会话
+        async with await PostgresDB.get_session() as session:
+            result = await session.execute(insert_stmt)
+            inserted_ids = result.scalars().all()
+            await session.commit()
+            return inserted_ids
 
     @staticmethod
     async def del_by_id(VectorItems, id):
         try:
@@ -88,7 +114,7 @@ class VectorItemsManager:
                     f"SELECT v.chunk_id "
                     f"FROM \"{VectorItems.name}\" AS v "
                     f"INNER JOIN document ON v.document_id = document.id "
-                    f"WHERE v.kb_id = :kb_id AND v.chunk_id not in :banned_ids AND v.enabled = true AND document.enabled = true "
+                    f"WHERE v.kb_id = :kb_id AND v.chunk_id != ALL(:banned_ids) AND v.enabled = true AND document.enabled = true "
                     f"ORDER BY v.vector <=> :target_vector "
                     f"LIMIT :topk")
                 else:
@@ -117,3 +143,62 @@ class VectorItemsManager:
             logging.error(f"Query for similar vectors failed due to error: {e}")
             logging.error(f"Error details: {traceback.format_exc()}")
             return []
+class TemporaryVectorItemsManager:
+    @staticmethod
+    async def add(temporary_vector_items_entity: TemporaryVectorItemstEntity) -> TemporaryVectorItemstEntity:
+        try:
+            async with await PostgresDB.get_session() as session:
+                session.add(temporary_vector_items_entity)
+                await session.commit()
+                await session.refresh(temporary_vector_items_entity)
+                return temporary_vector_items_entity
+        except Exception as e:
+            logging.error(f"Add temporary vector items failed due to error: {e}")
+            return None
+    @staticmethod
+    async def add_all(temporary_vector_items_entity_list: List[TemporaryVectorItemstEntity]) -> List[TemporaryVectorItemstEntity]:
+        try:
+            async with await PostgresDB.get_session() as session:
+                session.add_all(temporary_vector_items_entity_list)
+                await session.commit()
+                for temporary_vector_items_entity in temporary_vector_items_entity_list:
+                    await session.refresh(temporary_vector_items_entity)
+                return temporary_vector_items_entity_list
+        except Exception as e:
+            logging.error(f"Add temporary vector items failed due to error: {e}")
+            return None
+    @staticmethod
+    async def find_top_k_similar_temporary_vectors(target_vector, document_ids: List[uuid.UUID], topk=3) -> List[uuid.UUID]:
+        try:
+            if topk <= 0:
+                return
[] + # 构造查询 + if document_ids: + query_sql = ( + f"SELECT v.chunk_id " + f"FROM temporary_vector_items AS v " + f"INNER JOIN temporary_document ON v.document_id = temporary_document.id " + f"WHERE v.document_id= ANY(:document_ids) AND temporary_document.status!='deleted'" + f"ORDER BY v.vector <=> :target_vector " + f"LIMIT :topk" + ) + else: + return [] + # 获取会话并执行查询 + async with await PostgresDB.get_session() as session: + # 使用execute执行原始SQL语句,并传递参数 + result = await session.execute( + text(query_sql), + { + "document_ids": document_ids, + "target_vector": str(target_vector), + "topk": topk + } + ) + result = result.scalars().all() + print(result) + return result + except Exception as e: + logging.error(f"Query for similar temporary vectors failed due to error: {e}") + logging.error(f"Error details: {traceback.format_exc()}") + return [] \ No newline at end of file diff --git a/data_chain/models/api.py b/data_chain/models/api.py index 2b82ffe..8ad9550 100644 --- a/data_chain/models/api.py +++ b/data_chain/models/api.py @@ -1,12 +1,12 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. # TODO: 给其中某些属性加上参数约束, 例如page或者count之类的 -from enum import Enum +import re import uuid from typing import Dict, Generic, List, Optional, TypeVar from data_chain.models.service import DocumentTypeDTO -from pydantic import BaseModel, Field,validator +from pydantic import BaseModel, Field, validator,constr T = TypeVar('T') @@ -20,7 +20,7 @@ class DictionaryBaseModel(BaseModel): class BaseResponse(BaseModel, Generic[T]): - retcode: int = 0 + retcode: int = 200 retmsg: str = "ok" data: Optional[T] @@ -31,25 +31,23 @@ class Page(DictionaryBaseModel, Generic[T]): total: int data_list: Optional[List[T]] - class CreateKnowledgeBaseRequest(DictionaryBaseModel): - name: str - language: str - description: Optional[str] - embedding_model: str + name: str=Field(...,min_length=1, max_length=150) + language: str=Field(...,pattern=r"^(zh|en)$") + description: Optional[str]=Field(None, max_length=150) + embedding_model: str=Field(...,pattern=r"^(bge_large_zh|bge_large_en)$") default_parser_method: str - default_chunk_size: int= Field(..., gt=127, lt=1025) + default_chunk_size: int = Field(1024, ge=128, le=1024) document_type_list: Optional[List[str]] - class UpdateKnowledgeBaseRequest(DictionaryBaseModel): id: uuid.UUID - name: str - language: str - description: str - embedding_model: str - default_parser_method: str - default_chunk_size: int= Field(..., gt=127, lt=1025) + name: Optional[str]=Field(None,min_length=1, max_length=150) + language: Optional[str]=Field(None,pattern=r"^(zh|en)$") + description: Optional[str] + embedding_model: Optional[str]=Field(None,pattern=r"^(bge_large_zh|bge_large_en)$") + default_parser_method: Optional[str]=None + default_chunk_size: Optional[int] = Field(None, ge=128, le=1024) document_type_list: Optional[List[DocumentTypeDTO]] = None @@ -94,7 +92,6 @@ class ListTaskRequest(DictionaryBaseModel): page_size: int = 10 created_time_order: Optional[str] = 'desc' # 取值desc降序, asc升序 - class ListDocumentRequest(DictionaryBaseModel): kb_id: Optional[uuid.UUID] = None id: Optional[uuid.UUID] = None @@ -108,24 +105,26 @@ class ListDocumentRequest(DictionaryBaseModel): parser_method: Optional[List[str]] = None page_number: int = 1 page_size: int = 10 - + @validator('status', each_item=True) + def check_types(cls, v): + # 定义允许的类型正则表达式 + allowed_type_pattern = r"^(pending|success|failed|running|canceled)$" + if not re.match(allowed_type_pattern, v): + raise ValueError(f'Invalid 
type value "{v}". Must match pattern {allowed_type_pattern}.') + return v class UpdateDocumentRequest(DictionaryBaseModel): id: uuid.UUID - name: Optional[str] = None - parser_method: Optional[str] = None # TODO:可以编辑parser_method,和tantan对齐 + name: Optional[str] = Field(None,min_length=1, max_length=128) + parser_method: Optional[str] = Field(None,pattern=r"^(general|ocr|enhanced)$") type_id: Optional[uuid.UUID] = None - chunk_size: Optional[int] = Field(None, gt=127, lt=1025) - + chunk_size: Optional[int] = Field(None, gt=127, lt=1025) -class Action(str, Enum): - RUN = "run" - CANCEL = "cancel" -class RunDocumentEmbeddingRequest(DictionaryBaseModel): +class RunDocumentRequest(DictionaryBaseModel): ids: List[uuid.UUID] - run: Action # run运行或者cancel取消 + run: str=Field(...,pattern=r"^(run|cancel)$")# run运行或者cancel取消 class SwitchDocumentRequest(DictionaryBaseModel): @@ -140,16 +139,25 @@ class DeleteDocumentRequest(DictionaryBaseModel): class DownloadDocumentRequest(DictionaryBaseModel): ids: List[uuid.UUID] +class GetTemporaryDocumentStatusRequest(DictionaryBaseModel): + ids: List[uuid.UUID] -class CreateChunkRequest(DictionaryBaseModel): +class TemporaryDocumentInParserRequest(DictionaryBaseModel): id: uuid.UUID - text: str - type: str # 标明是段落text还是说图谱的节点亦或者其他数据类型 - enabled: bool - - -ALLOWED_TYPES = ["para", "table", "image"] - + name:str=Field(...,min_length=1, max_length=128) + type:str=Field(...,min_length=1, max_length=128) + bucket_name:str=Field(...,min_length=1, max_length=128) + parser_method:str=Field("ocr",pattern=r"^(general|ocr)$") + chunk_size:int=Field(1024,ge=128,le=1024) +class ParserTemporaryDocumenRequest(DictionaryBaseModel): + document_list:List[TemporaryDocumentInParserRequest] +class DeleteTemporaryDocumentRequest(DictionaryBaseModel): + ids: List[uuid.UUID] +class RelatedTemporaryDocumenRequest(DictionaryBaseModel): + content:str + top_k:int=Field(5, ge=0, le=10) + kb_sn: Optional[uuid.UUID] = None + document_ids: Optional[List[uuid.UUID]] = None class ListChunkRequest(DictionaryBaseModel): document_id: uuid.UUID @@ -157,27 +165,26 @@ class ListChunkRequest(DictionaryBaseModel): page_number: int = 1 types: Optional[List[str]] = None page_size: int = 10 - - # 定义一个验证器来检查 'types' 是否只包含允许的类型 - @validator('types', pre=True, always=False) - def validate_types(cls, v): - if not set(v).issubset(set(ALLOWED_TYPES)): - raise ValueError(f'Invalid type(s) found. Allowed types are {ALLOWED_TYPES}') + # 定义一个验证器来确保types中的每个元素都符合正则表达式 + @validator('types', each_item=True) + def check_types(cls, v): + # 定义允许的类型正则表达式 + allowed_type_pattern = r"^(para|table|image)$" # 替换为你需要的正则表达式 + if not re.match(allowed_type_pattern, v): + raise ValueError(f'Invalid type value "{v}". 
Must match pattern {allowed_type_pattern}.') return v - - class SwitchChunkRequest(DictionaryBaseModel): ids: List[uuid.UUID] # 支持批量操作 enabled: bool # True启用, False未启用 -class UserAddRequest(DictionaryBaseModel): +class AddUserRequest(DictionaryBaseModel): name: str account: str passwd: str -class UserUpdateRequest(DictionaryBaseModel): +class UpdateUserRequest(DictionaryBaseModel): name: Optional[str] = None account: Optional[str] = None passwd: Optional[str] = None @@ -185,23 +192,17 @@ class UserUpdateRequest(DictionaryBaseModel): status: Optional[str] = None language: Optional[str] = None - -class CreateModelRequest(DictionaryBaseModel): - model_name: str - openai_api_base: str - openai_api_key: str - max_tokens: int - - class UpdateModelRequest(DictionaryBaseModel): - model_name: str - openai_api_base: str - openai_api_key: str - max_tokens: int + model_name: str=Field(...,min_length=1, max_length=128) + openai_api_base: str=Field(...,min_length=1, max_length=128) + openai_api_key: str=Field(...,min_length=1, max_length=128) + max_tokens: int=Field(1024, ge=1024, le=8192) + class QueryRequest(BaseModel): question: str - kb_sn: Optional[str]=None + kb_sn: Optional[uuid.UUID] = None + document_ids : Optional[List[uuid.UUID]] = None top_k: int = Field(5, ge=0, le=10) fetch_source: bool = False - history: Optional[List] = [] \ No newline at end of file + history: Optional[List] = [] diff --git a/data_chain/models/constant.py b/data_chain/models/constant.py index 6ab4df7..80dc6d6 100644 --- a/data_chain/models/constant.py +++ b/data_chain/models/constant.py @@ -1,7 +1,11 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. - -class OssConstant(): +class BaseConstant(): + @classmethod + def get_all_values(cls): + return [value for name, value in cls.__dict__.items() + if not name.startswith("__") and isinstance(value, str)] +class OssConstant(BaseConstant): IMPORT_FILE_SAVE_FOLDER = "./stash" EXPORT_FILE_SAVE_FOLDER = "./export" UPLOAD_DOCUMENT_SAVE_FOLDER = "./document" @@ -10,18 +14,24 @@ class OssConstant(): MINIO_BUCKET_DOCUMENT = "document" MINIO_BUCKET_KNOWLEDGEBASE = "knowledgebase" - MINIO_BUCKET_EXPORTZIP = "exportzip" MINIO_BUCKET_PICTURE = "picture" -class DocumentEmbeddingConstant(): +class DocumentEmbeddingConstant(BaseConstant): DOCUMENT_EMBEDDING_RUN = 'run' DOCUMENT_EMBEDDING_CANCEL = 'cancel' DOCUMENT_EMBEDDING_STATUS_PENDING = "pending" DOCUMENT_EMBEDDING_STATUS_RUNNING = "running" +class DocumentStatusEnum(BaseConstant): + PENDIND='pending' + RUNNING='running' + DELETED='deleted' -class TaskConstant(): +class TemporaryDocumentStatusEnum(BaseConstant): + EXIST='exist' + DELETED='deleted' +class TaskConstant(BaseConstant): TASK_REDIS_QUEUE_KEY = "TASK_QUEUE" TASK_STATUS_PENDING = "pending" @@ -34,54 +44,33 @@ class TaskConstant(): IMPORT_KNOWLEDGE_BASE = "import_knowledge_base" EXPORT_KNOWLEDGE_BASE = "export_knowledge_base" PARSE_DOCUMENT = "parse_document" + PARSE_TEMPORARY_DOCUMENT = "parse_temporary_document" - -class KnowledgeStatusEnum(): +class KnowledgeStatusEnum(BaseConstant): IMPORTING = "importing" EXPROTING = "exporting" IDLE = 'idle' DELETE = 'delete' - -class TaskActionEnum(): +class TaskActionEnum(BaseConstant): CANCEL = "cancel" RESTART = "restart" DELETE = "delete" - -class KnowledgeLanguageEnum(): +class KnowledgeLanguageEnum(BaseConstant): ZH = "简体中文" EN = "English" - @classmethod - def get_all_values(cls): - return [value for name, value in cls.__dict__.items() - if not name.startswith("__") and isinstance(value, str)] -class 
EmbeddingModelEnum(): +class EmbeddingModelEnum(BaseConstant): BGE_LARGE_ZH = "bge_large_zh" BGE_LARGE_EN = "bge_large_en" - @classmethod - def get_all_values(cls): - return [value for name, value in cls.__dict__.items() - if not name.startswith("__") and isinstance(value, str)] - -embedding_model_out_dimensions = { - 'bge_large_zh': 1024, - 'bge_large_en': 1024 -} - - -class ParseMethodEnum(): +class ParseMethodEnum(BaseConstant): GENERAL = "general" OCR = "ocr" ENHANCED = "enhanced" - @classmethod - def get_all_values(cls): - return [value for name, value in cls.__dict__.items() - if not name.startswith("__") and isinstance(value, str)] -class ParseExtensionEnum(): +class ParseExtensionEnum(BaseConstant): PDF = ".pdf" DOCX = ".docx" DOC = ".doc" @@ -89,11 +78,7 @@ class ParseExtensionEnum(): XLSX = ".xlsx" HTML = ".html" MD = ".md" - @classmethod - def get_all_values(cls): - return [value for name, value in cls.__dict__.items() - if not name.startswith("__") and isinstance(value, str)] -class ChunkRelevance(): +class ChunkRelevance(BaseConstant): IRRELEVANT = 1 WEAKLY_RELEVANT = 2 RELEVANT_BUT_LACKS_PREVIOUS_CONTEXT = 3 @@ -101,4 +86,8 @@ class ChunkRelevance(): RELEVANT_BUT_LACKS_BOTH_CONTEXTS = 5 RELEVANT_AND_COMPLETE = 6 -default_document_type_id = '00000000-0000-0000-0000-000000000000' \ No newline at end of file +default_document_type_id = '00000000-0000-0000-0000-000000000000' +embedding_model_out_dimensions = { + 'bge_large_zh': 1024, + 'bge_large_en': 1024 +} \ No newline at end of file diff --git a/data_chain/models/service.py b/data_chain/models/service.py index 3389e5c..8cf437b 100644 --- a/data_chain/models/service.py +++ b/data_chain/models/service.py @@ -66,6 +66,11 @@ class DocumentDTO(DictionaryBaseModelDTO): parser_method: str +class TemporaryDocumentDTO(DictionaryBaseModelDTO): + id: uuid.UUID + status: str + + class ChunkDTO(DictionaryBaseModelDTO): id: str text: str @@ -77,4 +82,4 @@ class ModelDTO(DictionaryBaseModelDTO): id: Optional[str] = None model_name: Optional[str] = None openai_api_base: Optional[str] = None - max_tokens : Optional[int] =None + max_tokens: Optional[int] = None diff --git a/data_chain/parser/handler/base_parser.py b/data_chain/parser/handler/base_parser.py index b8da432..6de4b39 100644 --- a/data_chain/parser/handler/base_parser.py +++ b/data_chain/parser/handler/base_parser.py @@ -31,14 +31,16 @@ class BaseService: async def init_service(self, llm_entity, tokens, parser_method): self.parser_method = parser_method - if llm_entity is None: - self.llm = None - self.llm_max_tokens = None - else: - self.llm = LLM(model_name=llm_entity.model_name, openai_api_base=llm_entity.openai_api_base, - openai_api_key=Security.decrypt(llm_entity.encrypted_openai_api_key, - json.loads(llm_entity.encrypted_config)), - max_tokens=llm_entity.max_tokens, ) + if llm_entity is not None: + self.llm = LLM( + model_name=llm_entity.model_name, + openai_api_base=llm_entity.openai_api_base, + openai_api_key=Security.decrypt( + llm_entity.encrypted_openai_api_key, + json.loads(llm_entity.encrypted_config) + ), + max_tokens=llm_entity.max_tokens, + ) self.llm_max_tokens = llm_entity.max_tokens self.tokens = tokens self.vectorizer = TfidfVectorizer() @@ -66,8 +68,7 @@ class BaseService: if cosine_sim > 0.85: return True except Exception as e: - logging.error(f"Check similarity error due to: {e}") - return False + logging.error(f'Check_similarity error due to: {e}') return False def merge_texts(self, texts): diff --git a/data_chain/parser/handler/doc_parser.py 
b/data_chain/parser/handler/doc_parser.py index cabc3b4..fdb1de6 100644 --- a/data_chain/parser/handler/doc_parser.py +++ b/data_chain/parser/handler/doc_parser.py @@ -27,7 +27,7 @@ class DocService(BaseService): except Exception as e: logging.error(f"Error opening file {file_path} :{e}") raise e - content=js['content'] + content=js.get('content','') paragraphs = content.split('\n') sentences = [] for paragraph in paragraphs: diff --git a/data_chain/parser/handler/docx_parser.py b/data_chain/parser/handler/docx_parser.py index fe7a18f..efff7ed 100644 --- a/data_chain/parser/handler/docx_parser.py +++ b/data_chain/parser/handler/docx_parser.py @@ -78,12 +78,8 @@ class DocxService(BaseService): text_part = '' # 处理图片 for image_part in image_parts: - try: - image_blob = image_part.image.blob - content_type = image_part.content_type - except Exception as e: - logging.error(f"Error getting image blob and content type due to: {e}") - continue + image_blob = image_part.image.blob + content_type = image_part.content_type extension = mimetypes.guess_extension(content_type).replace('.', '') lines.append(([(Image.open(BytesIO(image_blob)), extension)], 'image')) else: @@ -104,8 +100,12 @@ class DocxService(BaseService): img_id = child.xpath('.//a:blip/@r:embed')[0] part = parent.part.related_parts[img_id] if isinstance(part, ImagePart): - image_blob = part.image.blob - content_type = part.content_type + try: + image_blob = part.image.blob + content_type = part.content_type + except: + logging.error(f'Error get image blob and content type due to: {img_id}') + continue extension = mimetypes.guess_extension(content_type).replace('.', '') lines.append(([(Image.open(BytesIO(image_blob)), extension)], 'image')) new_lines = [] diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index 9bc79eb..a08f059 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -94,7 +94,7 @@ class PdfService(BaseService): try: img_np = np.array(image) except Exception as e: - logging.error(f"Error converting image to numpy array due to {e}") + logging.error(f"Error converting image to numpy array: {e}") continue ocr_results = await self.image_model.run(img_np, text=near) diff --git a/data_chain/parser/service/parser_service.py b/data_chain/parser/service/parser_service.py index 091fc26..24c88bc 100644 --- a/data_chain/parser/service/parser_service.py +++ b/data_chain/parser/service/parser_service.py @@ -1,11 +1,11 @@ import shutil import os +from typing import List, Dict import traceback -from data_chain.logger.logger import logger as logging from data_chain.apps.service.embedding_service import Vectorize from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.model_manager import ModelManager -from data_chain.models.constant import OssConstant, embedding_model_out_dimensions +from data_chain.models.constant import OssConstant, embedding_model_out_dimensions,ParseMethodEnum from data_chain.parser.handler.docx_parser import DocxService from data_chain.parser.handler.html_parser import HtmlService from data_chain.parser.handler.xlsx_parser import XlsxService @@ -13,12 +13,13 @@ from data_chain.parser.handler.txt_parser import TxtService from data_chain.parser.handler.pdf_parser import PdfService from data_chain.parser.handler.md_parser import MdService from data_chain.parser.handler.doc_parser import DocService -from data_chain.stores.postgres.postgres import DocumentEntity, ChunkEntity, ChunkLinkEntity, 
PostgresDB, ImageEntity -from data_chain.manager.document_manager import DocumentManager -from data_chain.manager.chunk_manager import ChunkManager, ChunkLinkManager +from data_chain.stores.postgres.postgres import ChunkEntity, TemporaryChunkEntity, ChunkLinkEntity, PostgresDB, ImageEntity, TemporaryVectorItemstEntity +from data_chain.manager.document_manager import DocumentManager, TemporaryDocumentManager +from data_chain.manager.chunk_manager import ChunkManager, ChunkLinkManager, TemporaryChunkManager from data_chain.manager.image_manager import ImageManager from data_chain.stores.minio.minio import MinIO -from data_chain.manager.vector_items_manager import VectorItemsManager +from data_chain.manager.vector_items_manager import VectorItemsManager, TemporaryVectorItemsManager +from data_chain.logger.logger import logger as logging class ParserService: @@ -26,7 +27,7 @@ class ParserService: def __init__(self): self.doc = None - async def parser(self, doc_id, file_path): + async def parser(self, doc_id, file_path, is_temporary_document=False): model_map = { ".docx": DocxService, ".doc": DocService, @@ -36,149 +37,197 @@ class ParserService: ".md": MdService, ".html": HtmlService, } - self.doc = await DocumentManager.select_by_id(doc_id) + if not is_temporary_document: + self.doc = await DocumentManager.select_by_id(doc_id) + else: + self.doc = await TemporaryDocumentManager.select_by_id(doc_id) file_extension = self.doc.extension try: if file_extension in model_map: model = model_map[file_extension]() # 判断文件类型 - llm_entity = await ModelManager.select_by_user_id(self.doc.user_id) - await model.init_service(llm_entity=llm_entity, - tokens=self.doc.chunk_size, - parser_method=self.doc.parser_method) + if not is_temporary_document: + llm_entity = await ModelManager.select_by_user_id(self.doc.user_id) + await model.init_service(llm_entity=llm_entity, + tokens=self.doc.chunk_size, + parser_method=self.doc.parser_method) + else: + await model.init_service(llm_entity=None, + tokens=self.doc.chunk_size, + parser_method=self.doc.parser_method) chunk_list, chunk_link_list, image_chunks = await model.parser(file_path) - for chunk in chunk_list: - chunk['doc_id'] = doc_id - chunk['user_id'] = self.doc.user_id - chunk['kb_id'] = self.doc.kb_id - for image_chunk in image_chunks: - image_chunk['doc_id'] = doc_id - image_chunk['user_id'] = self.doc.user_id + if not is_temporary_document: + for chunk in chunk_list: + chunk['doc_id'] = doc_id + chunk['user_id'] = self.doc.user_id + chunk['kb_id'] = self.doc.kb_id + for image_chunk in image_chunks: + image_chunk['doc_id'] = doc_id + image_chunk['user_id'] = self.doc.user_id + else: + for chunk in chunk_list: + chunk['doc_id'] = doc_id + for image_chunk in image_chunks: + image_chunk['doc_id'] = doc_id else: logging.error(f"No service available for file type: {file_extension}") return {"chunk_list": [], "chunk_link_list": [], "image_chunks": []} except Exception as e: - logging.error(f'Fail with exception:{e}') - logging.error(f'Fail with exception traceback:{traceback.format_exc()}') + logging.error(f'fail with exception:{e}') + logging.error(f'fail with exception:{traceback.format_exc()}') raise e return {"chunk_list": chunk_list, "chunk_link_list": chunk_link_list, "image_chunks": image_chunks} @staticmethod - async def upload_chunks_to_pg(chunks): - if len(chunks)==0: + async def upload_chunks_to_pg(chunks, is_temporary_document=False): + if len(chunks) == 0: return try: - image_entity_list=await ImageManager.query_image_by_doc_id(chunks[0]['doc_id']) - 
for image_entity in image_entity_list: - MinIO.delete_object(OssConstant.MINIO_BUCKET_PICTURE, str(image_entity.id)) - await ChunkManager.delete_by_document_ids([chunks[0]['doc_id']]) + if not is_temporary_document: + image_entity_list=await ImageManager.query_image_by_doc_id(chunks[0]['doc_id']) + for image_entity in image_entity_list: + MinIO.delete_object(OssConstant.MINIO_BUCKET_PICTURE, str(image_entity.id)) + await ChunkManager.delete_by_document_ids([chunks[0]['doc_id']]) + else: + await TemporaryChunkManager.delete_by_temporary_document_ids([chunks[0]['doc_id']]) except Exception as e: logging.error(f"Failed to delete chunk: {e}") raise e try: - for chunk in chunks: - chunk_entity = ChunkEntity( - id=chunk['id'], - kb_id=chunk['kb_id'], - user_id=chunk['user_id'], - document_id=chunk['doc_id'], - text=chunk['text'], - tokens=chunk['tokens'], - type=chunk['type'], - global_offset=chunk['global_offset'], - local_offset=chunk['local_offset'], - enabled=chunk['enabled'], - status=chunk['status'] - ) - try: - await ChunkManager.insert_chunk(chunk_entity) - except Exception as e: - logging.error(f"Failed to upload chunk: {e}") - raise e + if not is_temporary_document: + chunk_entity_list = [] + for chunk in chunks: + chunk_entity = ChunkEntity( + id=chunk['id'], + kb_id=chunk['kb_id'], + user_id=chunk['user_id'], + document_id=chunk['doc_id'], + text=chunk['text'], + tokens=chunk['tokens'], + type=chunk['type'], + global_offset=chunk['global_offset'], + local_offset=chunk['local_offset'], + enabled=chunk['enabled'], + status=chunk['status'] + ) + chunk_entity_list.append(chunk_entity) + await ChunkManager.insert_chunks(chunk_entity_list) + else: + chunk_entity_list = [] + for chunk in chunks: + chunk_entity = TemporaryChunkEntity( + id=chunk['id'], + document_id=chunk['doc_id'], + text=chunk['text'], + tokens=chunk['tokens'], + type=chunk['type'], + global_offset=chunk['global_offset'], + ) + chunk_entity_list.append(chunk_entity) + await TemporaryChunkManager.insert_temprorary_chunks(chunk_entity_list) except Exception as e: logging.error(f"Failed to upload chunk: {e}") + raise e @staticmethod - async def upload_chunk_links_to_pg(chunk_links: dict): + async def upload_chunk_links_to_pg(chunk_links: List[Dict], is_temporary_document: bool = False): try: - for chunk_link in chunk_links: - chunk_link_entity = ChunkLinkEntity( - id=chunk_link['id'], - chunk_a_id=chunk_link['chunk_a'], - chunk_b_id=chunk_link['chunk_b'], - type=chunk_link['type'], - ) - try: - # 插入数据库 - await ChunkLinkManager.insert_chunk_link(chunk_link_entity) - except Exception as e: - logging.error(f"Failed to upload chunk: {e}") - raise e + chunk_link_entity_list=[] + if not is_temporary_document: + for chunk_link in chunk_links: + chunk_link_entity = ChunkLinkEntity( + id=chunk_link['id'], + chunk_a_id=chunk_link['chunk_a'], + chunk_b_id=chunk_link['chunk_b'], + type=chunk_link['type'], + ) + chunk_link_entity_list.append(chunk_link_entity) + await ChunkLinkManager.insert_chunk_links(chunk_link_entity_list) except Exception as e: logging.error(f"Failed to upload chunk: {e}") + raise e @staticmethod - async def upload_images_to_minio(images): + async def upload_images_to_minio(images, is_temporary_document: bool = False): output_dir = None try: - for image in images: - output_dir = os.path.join(OssConstant.PARSER_SAVE_FOLDER, str(image['id'])) - output_path = os.path.join(output_dir, str(image['id'])+'.'+image['extension']) - await MinIO.put_object(OssConstant.MINIO_BUCKET_PICTURE, str(image['id']), output_path) + if 
not is_temporary_document: + for image in images: + output_dir = os.path.join(OssConstant.PARSER_SAVE_FOLDER, str(image['id'])) + output_path = os.path.join(output_dir, str(image['id'])+'.'+image['extension']) + await MinIO.put_object(OssConstant.MINIO_BUCKET_PICTURE, str(image['id']), output_path) except Exception as e: logging.error(f"Failed to upload image: {e}") finally: - if output_dir and os.path.exists(output_dir): - shutil.rmtree(output_dir) + for image in images: + output_dir = os.path.join(OssConstant.PARSER_SAVE_FOLDER, str(image['id'])) + if output_dir and os.path.exists(output_dir): + shutil.rmtree(output_dir) @staticmethod - async def upload_images_to_pg(images): + async def upload_images_to_pg(images, is_temporary_document: bool = False): try: - for image in images: - image_entity = ImageEntity(id=image['id'], - chunk_id=image['chunk_id'], - document_id=image['doc_id'], - user_id=image['user_id'], - ) - try: - await ImageManager.add_image(image_entity) - except Exception as e: - logging.error(f"Failed to upload image: {e}") - raise e + image_entity_list=[] + if not is_temporary_document: + for image in images: + image_entity = ImageEntity(id=image['id'], + chunk_id=image['chunk_id'], + document_id=image['doc_id'], + user_id=image['user_id'], + ) + image_entity_list.append(image_entity) + await ImageManager.add_images(image_entity_list) except Exception as e: logging.error(f"Failed to upload image: {e}") + raise e @staticmethod - async def embedding_chunks(chunks): + async def embedding_chunks(chunks, is_temporary_document: bool = False): try: vectors = [] - for chunk in chunks: - vectors.append({'chunk_id': chunk['id'], - 'doc_id': chunk['doc_id'], - 'kb_id': chunk['kb_id'], - 'user_id': chunk['user_id'], - 'vector': await Vectorize.vectorize_embedding(chunk['text']), - 'enabled': chunk['enabled'], - }) + if not is_temporary_document: + for chunk in chunks: + vectors.append({'chunk_id': chunk['id'], + 'doc_id': chunk['doc_id'], + 'kb_id': chunk['kb_id'], + 'user_id': chunk['user_id'], + 'vector': await Vectorize.vectorize_embedding(chunk['text']), + 'enabled': chunk['enabled'], + }) + else: + for chunk in chunks: + vectors.append({'chunk_id': chunk['id'], + 'doc_id': chunk['doc_id'], + 'vector': await Vectorize.vectorize_embedding(chunk['text']), + }) return vectors except Exception as e: logging.error(f"Failed to embedding chunk: {e}") raise e @staticmethod - async def insert_vectors_to_pg(vectors): - if len(vectors)==0: + async def insert_vectors_to_pg(vectors, is_temporary_document=False): + if len(vectors) == 0: return try: - # if True: - doc = await DocumentManager.select_by_id(vectors[0]['doc_id']) - kb = await KnowledgeBaseManager.select_by_id(doc.kb_id) - vector_items_table = await PostgresDB.get_dynamic_vector_items_table( - str(kb.vector_items_id), - embedding_model_out_dimensions[kb.embedding_model] - ) - await PostgresDB.create_table(vector_items_table) - for vect in vectors: - await VectorItemsManager.add(vector_items_table, vect) + if not is_temporary_document: + doc = await DocumentManager.select_by_id(vectors[0]['doc_id']) + kb = await KnowledgeBaseManager.select_by_id(doc.kb_id) + vector_items_table = await PostgresDB.get_dynamic_vector_items_table( + str(kb.vector_items_id), + embedding_model_out_dimensions[kb.embedding_model] + ) + await PostgresDB.create_table(vector_items_table) + await VectorItemsManager.add_all(vector_items_table, vectors) + else: + vector_entity_list=[] + for vector in vectors: + vector_entity_list.append( + 
TemporaryVectorItemstEntity( + document_id=vector['doc_id'], + chunk_id=vector['chunk_id'], + vector=vector['vector']) + ) + await TemporaryVectorItemsManager.add_all(vector_entity_list) except Exception as e: logging.error(f"Failed to upload chunk: {e}") raise e diff --git a/data_chain/parser/tools/ocr.py b/data_chain/parser/tools/ocr.py index ddf5111..f553008 100644 --- a/data_chain/parser/tools/ocr.py +++ b/data_chain/parser/tools/ocr.py @@ -40,18 +40,17 @@ class BaseOCR: try: # get config results = self.model.ocr(image) - logging.info(f"OCR job down {results}") return results except Exception as e: logging.error(f"OCR job error {e}") - raise [[None]] + return [[None]] @staticmethod def get_text_from_ocr_results(ocr_results): results = '' if ocr_results[0] is None or len(ocr_results)==0: return '' - if ocr_results[0][0] is None: + if ocr_results[0][0] is None or len(ocr_results[0][0])==0: return '' try: for result in ocr_results[0][0]: @@ -101,9 +100,9 @@ class BaseOCR: - text:图片组对应的前后文 """ try: - user_call = '请详细输出图片的总结,不要输出其他内容' + user_call = '请详细输出图片的摘要,不要输出其他内容' split_images = [] - max_tokens = self.max_tokens // 2 + max_tokens = self.max_tokens // 5*2 for image in image_results: split_result = self.split_list(image, max_tokens) split_images.append(split_result) @@ -127,7 +126,7 @@ class BaseOCR: answer = front_image_description return answer except Exception as e: - logging.info(f"Imnage ocr result improve error due to: {e}") + logging.error(f'OCR result improve error due to: {e}') return "" async def run(self, image, text): diff --git a/data_chain/stores/minio/minio.py b/data_chain/stores/minio/minio.py index 6cd805c..ab35459 100644 --- a/data_chain/stores/minio/minio.py +++ b/data_chain/stores/minio/minio.py @@ -19,9 +19,6 @@ class MinIO(): found = client.bucket_exists(OssConstant.MINIO_BUCKET_DOCUMENT) if not found: client.make_bucket(OssConstant.MINIO_BUCKET_DOCUMENT) - found = client.bucket_exists(OssConstant.MINIO_BUCKET_EXPORTZIP) - if not found: - client.make_bucket(OssConstant.MINIO_BUCKET_EXPORTZIP) found = client.bucket_exists(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE) if not found: client.make_bucket(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE) diff --git a/data_chain/stores/postgres/postgres.py b/data_chain/stores/postgres/postgres.py index fd8b4f9..16afba1 100644 --- a/data_chain/stores/postgres/postgres.py +++ b/data_chain/stores/postgres/postgres.py @@ -12,7 +12,7 @@ from sqlalchemy.orm import declarative_base, relationship from data_chain.config.config import config from data_chain.models.api import CreateKnowledgeBaseRequest - +from data_chain.models.constant import KnowledgeStatusEnum,ParseMethodEnum Base = declarative_base() @@ -46,8 +46,8 @@ class ModelEntity(Base): encrypted_openai_api_key = Column(String) encrypted_config = Column(String) max_tokens = Column(Integer) - create_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) - update_time = Column( + created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) + updated_time = Column( TIMESTAMP(timezone=True), server_default=func.current_timestamp(), onupdate=func.current_timestamp()) @@ -58,16 +58,16 @@ class KnowledgeBaseEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) user_id = Column(UUID, ForeignKey('users.id', ondelete="CASCADE")) # 用户id - name = Column(String) # 资产名 + name = Column(String,default='') # 知识库名资产名 language = Column(String, default='zh') # 资产文档语言 - description = Column(String) # 资产描述 + 
description = Column(String,default='') # 资产描述 embedding_model = Column(String) # 资产向量化模型 - document_number = Column(Integer) # 资产文档个数 - document_size = Column(Integer) # 资产下所有文档大小(TODO: 单位kb或者字节) - default_parser_method = Column(String) # 默认解析方法 - default_chunk_size = Column(Integer) # 默认分块大小 + document_number = Column(Integer,default=0) # 资产文档个数 + document_size = Column(Integer,default=0) # 资产下所有文档大小(TODO: 单位kb或者字节) + default_parser_method = Column(String,default=ParseMethodEnum.GENERAL) # 默认解析方法 + default_chunk_size = Column(Integer,default=1024) # 默认分块大小 vector_items_id = Column(UUID, default=uuid4) # 向量表id - status = Column(String) + status = Column(String,default=KnowledgeStatusEnum.IDLE) created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) updated_time = Column( TIMESTAMP(timezone=True), @@ -138,10 +138,10 @@ class ImageEntity(Base): id = Column(UUID, default=uuid4, primary_key=True) user_id = Column(UUID) document_id = Column(UUID) - chunk_id = Column(UUID,ForeignKey('chunk.id', ondelete="CASCADE")) + chunk_id = Column(UUID, ForeignKey('chunk.id', ondelete="CASCADE")) extension = Column(String) # 文件后缀 - create_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) - update_time = Column( + created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) + updated_time = Column( TIMESTAMP(timezone=True), server_default=func.current_timestamp(), onupdate=func.current_timestamp()) @@ -213,6 +213,37 @@ class TaskStatusReportEntity(Base): server_default=func.current_timestamp(), onupdate=func.current_timestamp()) + +class TemporaryDocumentEntity(Base): + __tablename__ = 'temporary_document' + id = Column(UUID, default=uuid4, primary_key=True) + name = Column(String) + extension = Column(String) + bucket_name=Column(String) + parser_method=Column(String) + chunk_size = Column(Integer) # 文档分块大小 + status=Column(String) + created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) + + +class TemporaryChunkEntity(Base): + __tablename__ = 'temporary_chunk' + id = Column(UUID, default=uuid4, primary_key=True) + document_id = Column(UUID, ForeignKey('temporary_document.id', ondelete="CASCADE")) + text = Column(String) + tokens = Column(Integer) + type = Column(String) + global_offset = Column(Integer) + created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) + + +class TemporaryVectorItemstEntity(Base): + __tablename__ = 'temporary_vector_items' + id = Column(UUID, default=uuid4, primary_key=True) + document_id = Column(UUID) + chunk_id = Column(UUID, ForeignKey('temporary_chunk.id', ondelete="CASCADE")) + vector = Column(Vector(1024)) + created_time = Column(TIMESTAMP(timezone=True), nullable=True, server_default=func.current_timestamp()) # 反射加载 chunk 表的元数据 @@ -237,21 +268,16 @@ class PostgresDB: @classmethod async def get_dynamic_vector_items_table(cls, uuid_str, vector_dim): - st = time.time() # 使用同步引擎 - sync_engine = create_engine(config['DATABASE_URL'].replace("postgresql+asyncpg", "postgresql"), echo=False, pool_recycle=300, pool_pre_ping=True) - logging.info(f'数据库同步引擎创建耗时:{time.time()-st}') + sync_engine = create_engine( + config['DATABASE_URL'].replace("postgresql+asyncpg", "postgresql"), + echo=False, pool_recycle=300, pool_pre_ping=True) # 反射加载 chunk 表的元数据 - st = time.time() chunk_table = reflect_chunk_table(sync_engine) - logging.info(f'反射表格耗时:{time.time()-st}') - 
st = time.time() metadata = MetaData() - logging.info(f'关联原始数据耗时:{time.time()-st}') - st = time.time() # 动态创建表 vector_items_table = Table( f'vector_items_{uuid_str}', metadata, @@ -267,9 +293,7 @@ class PostgresDB: server_default=func.current_timestamp(), onupdate=func.current_timestamp()), ) - logging.info(f'动态构建向量存储表耗时:{time.time()-st}') - st = time.time() # 动态创建索引 index = Index( f'general_text_vector_index_{uuid_str}', @@ -278,19 +302,14 @@ class PostgresDB: postgresql_with={'m': 16, 'ef_construction': 200}, postgresql_ops={'vector': 'vector_cosine_ops'} ) - logging.info(f'创建向量化索引耗时:{time.time()-st}') - st = time.time() # 将索引添加到表定义中 vector_items_table.append_constraint(index) - logging.info(f'添加向量化索引耗时:{time.time()-st}') - st = time.time() # 创建表 with sync_engine.begin() as conn: metadata.create_all(conn) - logging.info(f'同步创建表格耗时:{time.time()-st}') - + return vector_items_table @classmethod diff --git a/requirements.txt b/requirements.txt index e893031..4897c54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,7 @@ chardet==5.2.0 PyMuPDF==1.24.13 python-multipart==0.0.9 asyncpg==0.29.0 -psycopg2-binary==2.9.9 \ No newline at end of file +psycopg2-binary==2.9.9 +openpyxl==3.1.2 +beautifulsoup4==4.12.3 +tiktoken==0.8.0 \ No newline at end of file -- Gitee
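
A minimal usage sketch of the BaseConstant.get_all_values() helper introduced in data_chain/models/constant.py above. The ParseMethodEnum subclass mirrors the hunk in this patch; the __main__ driver and the printed result are illustrative assumptions only, and note that the isinstance(value, str) filter means integer members (such as those on ChunkRelevance) are not returned.

class BaseConstant():
    @classmethod
    def get_all_values(cls):
        # Collect the public string attributes defined directly on the calling class;
        # dunder names (__module__, __qualname__, ...) and non-string values are skipped.
        return [value for name, value in cls.__dict__.items()
                if not name.startswith("__") and isinstance(value, str)]


class ParseMethodEnum(BaseConstant):
    GENERAL = "general"
    OCR = "ocr"
    ENHANCED = "enhanced"


if __name__ == "__main__":
    # Typical use: validating a request field against the allowed parser methods.
    print(ParseMethodEnum.get_all_values())   # ['general', 'ocr', 'enhanced']
    assert ParseMethodEnum.OCR in ParseMethodEnum.get_all_values()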