From 60934347d823971c7a13fa0af2739c1ddbc12d2c Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 29 May 2025 20:26:45 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=B4=E6=97=B6?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E4=B8=8A=E4=BC=A0=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/app.py | 21 ++- .../base/task/worker/acc_testing_worker.py | 2 +- .../base/task/worker/export_dataset_worker.py | 2 +- .../task/worker/generate_dataset_worker.py | 2 +- .../base/task/worker/import_dataset_worker.py | 2 +- .../worker/import_knowledge_base_worker.py | 4 +- .../base/task/worker/parse_document_worker.py | 2 +- data_chain/apps/router/document.py | 34 ++++- data_chain/apps/service/chunk_service.py | 21 +-- data_chain/apps/service/document_service.py | 135 +++++++++++++++++- .../apps/service/knwoledge_base_service.py | 4 +- data_chain/entities/common.py | 3 +- data_chain/entities/request_data.py | 23 ++- data_chain/entities/response_data.py | 17 +++ data_chain/manager/document_manager.py | 16 +++ data_chain/parser/parse_result.py | 2 +- 16 files changed, 259 insertions(+), 31 deletions(-) diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index 64fdf4e..3f28e26 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -8,7 +8,7 @@ import os import shutil from data_chain.config.config import config from data_chain.logger.logger import logger as logging -from data_chain.entities.common import actions, DEFAULt_DOC_TYPE_ID +from data_chain.entities.common import actions, DEFAULT_DOC_TYPE_ID from data_chain.apps.router import ( team, knowledge_base, @@ -55,9 +55,16 @@ from data_chain.rag import ( doc2chunk_bfs_searcher, enhanced_by_llm_searcher ) -from data_chain.stores.database.database import DataBase, ActionEntity, DocumentTypeEntity +from data_chain.stores.database.database import ( + DataBase, + ActionEntity, + KnowledgeBaseEntity, + DocumentTypeEntity +) from data_chain.manager.role_manager import RoleManager +from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager +from data_chain.entities.enum import ParseMethod from data_chain.entities.common import ( DOC_PATH_IN_OS, EXPORT_KB_PATH_IN_OS, @@ -96,9 +103,17 @@ async def add_acitons(): await RoleManager.add_action(action_entity) +async def add_knowledge_base(): + knowledge_base_entity = KnowledgeBaseEntity( + id=DEFAULT_DOC_TYPE_ID, + default_parse_method=ParseMethod.OCR.value + ) + await KnowledgeBaseManager.add_knowledge_base(knowledge_base_entity) + + async def add_document_type(): document_type_entity = DocumentTypeEntity( - id=DEFAULt_DOC_TYPE_ID, + id=DEFAULT_DOC_TYPE_ID, name="default", ) await DocumentTypeManager.add_document_type(document_type_entity) diff --git a/data_chain/apps/base/task/worker/acc_testing_worker.py b/data_chain/apps/base/task/worker/acc_testing_worker.py index 8ae3ceb..429ebf0 100644 --- a/data_chain/apps/base/task/worker/acc_testing_worker.py +++ b/data_chain/apps/base/task/worker/acc_testing_worker.py @@ -13,7 +13,7 @@ from data_chain.apps.base.task.worker.base_worker import BaseWorker from data_chain.llm.llm import LLM from data_chain.rag.base_searcher import BaseSearcher from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus, DataSetStatus, QAStatus, TestingStatus, TestCaseStatus -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, TESTING_REPORT_PATH_IN_OS, 
TESTING_REPORT_PATH_IN_MINIO +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, TESTING_REPORT_PATH_IN_OS, TESTING_REPORT_PATH_IN_MINIO from data_chain.parser.parse_result import ParseResult, ParseNode from data_chain.parser.tools.token_tool import TokenTool from data_chain.parser.handler.json_parser import JsonParser diff --git a/data_chain/apps/base/task/worker/export_dataset_worker.py b/data_chain/apps/base/task/worker/export_dataset_worker.py index 311d5b5..8d9b081 100644 --- a/data_chain/apps/base/task/worker/export_dataset_worker.py +++ b/data_chain/apps/base/task/worker/export_dataset_worker.py @@ -12,7 +12,7 @@ from data_chain.logger.logger import logger as logging from data_chain.apps.base.task.worker.base_worker import BaseWorker from data_chain.llm.llm import LLM from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus, DataSetStatus, QAStatus -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, EXPORT_DATASET_PATH_IN_OS, EXPORT_DATASET_PATH_IN_MINIO +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, EXPORT_DATASET_PATH_IN_OS, EXPORT_DATASET_PATH_IN_MINIO from data_chain.parser.parse_result import ParseResult, ParseNode from data_chain.parser.tools.token_tool import TokenTool from data_chain.parser.handler.json_parser import JsonParser diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index b297c5a..530cbaf 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -12,7 +12,7 @@ from data_chain.logger.logger import logger as logging from data_chain.apps.base.task.worker.base_worker import BaseWorker from data_chain.llm.llm import LLM from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus, DataSetStatus, QAStatus -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID from data_chain.parser.tools.token_tool import TokenTool from data_chain.manager.task_manager import TaskManager from data_chain.manager.document_manager import DocumentManager diff --git a/data_chain/apps/base/task/worker/import_dataset_worker.py b/data_chain/apps/base/task/worker/import_dataset_worker.py index 1442c64..02451f5 100644 --- a/data_chain/apps/base/task/worker/import_dataset_worker.py +++ b/data_chain/apps/base/task/worker/import_dataset_worker.py @@ -11,7 +11,7 @@ from data_chain.logger.logger import logger as logging from data_chain.apps.base.task.worker.base_worker import BaseWorker from data_chain.llm.llm import LLM from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus, DataSetStatus, QAStatus, ChunkType -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, IMPORT_DATASET_PATH_IN_OS, IMPORT_DATASET_PATH_IN_MINIO +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, IMPORT_DATASET_PATH_IN_OS, IMPORT_DATASET_PATH_IN_MINIO from data_chain.parser.parse_result import ParseResult, ParseNode from data_chain.parser.tools.token_tool import TokenTool from data_chain.parser.handler.json_parser import JsonParser diff --git a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py index ac69907..b7dcc5e 100644 --- a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py +++ b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py 
@@ -9,7 +9,7 @@ from data_chain.logger.logger import logger as logging from data_chain.apps.service.task_queue_service import TaskQueueService from data_chain.apps.base.task.worker.base_worker import BaseWorker from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, IMPORT_KB_PATH_IN_OS, DOC_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, IMPORT_KB_PATH_IN_OS, DOC_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO from data_chain.manager.task_manager import TaskManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager @@ -139,7 +139,7 @@ class ImportKnowledgeBaseWorker(BaseWorker): doc_path = os.path.join(doc_download_path, doc_name) if not os.path.exists(doc_path): continue - doc_type_id = doc_types_old_id_map_to_new_id.get(doc_config.get("type_id"), DEFAULt_DOC_TYPE_ID) + doc_type_id = doc_types_old_id_map_to_new_id.get(doc_config.get("type_id"), DEFAULT_DOC_TYPE_ID) document_entity = DocumentEntity( team_id=kb_entity.team_id, kb_id=kb_entity.id, diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 7b9a75e..731a067 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -20,7 +20,7 @@ from data_chain.config.config import config from data_chain.logger.logger import logger as logging from data_chain.apps.base.task.worker.base_worker import BaseWorker from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, ParseMethod, DocumentStatus, ChunkStatus, ImageStatus, DocParseRelutTopology, ChunkParseTopology, ChunkType -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, REPORT_PATH_IN_MINIO, DOC_PATH_IN_MINIO, DOC_PATH_IN_OS, IMAGE_PATH_IN_MINIO +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, REPORT_PATH_IN_MINIO, DOC_PATH_IN_MINIO, DOC_PATH_IN_OS, IMAGE_PATH_IN_MINIO from data_chain.manager.task_manager import TaskManager from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index b7e9736..1e98523 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -10,7 +10,10 @@ from typing import Annotated from uuid import UUID from data_chain.entities.request_data import ( ListDocumentRequest, - UpdateDocumentRequest + UpdateDocumentRequest, + GetTemporaryDocumentStatusRequest, + UploadTemporaryRequest, + DeleteTemporaryDocumentRequest ) from data_chain.entities.response_data import ( @@ -20,7 +23,10 @@ from data_chain.entities.response_data import ( UploadDocumentResponse, ParseDocumentResponse, UpdateDocumentResponse, - DeleteDocumentResponse + DeleteDocumentResponse, + GetTemporaryDocumentStatusResponse, + UploadTemporaryDocumentResponse, + DeleteTemporaryDocumentResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user from data_chain.apps.service.router_service import get_route_info @@ -152,3 +158,27 @@ async def delete_docs_by_ids( raise Exception("用户没有权限删除该文档") await DocumentService.delete_docs_by_ids(doc_ids) return DeleteDocumentResponse(result=doc_ids) + + +@router.get('/temporary/status', 
response_class=GetTemporaryDocumentStatusResponse, dependencies=[Depends(verify_user)]) +async def get_temporary_docs_status( + user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[GetTemporaryDocumentStatusRequest, Body()]): + doc_status_list = await DocumentService.get_temporary_docs_status(user_sub, req.ids) + return GetTemporaryDocumentStatusResponse(result=doc_status_list) + + +@router.post('/temporary/parser', response_model=UploadTemporaryDocumentResponse, dependencies=[Depends(verify_user)]) +async def upload_temporary_docs( + user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[UploadTemporaryRequest, Body()]): + doc_ids = await DocumentService.upload_temporary_docs(user_sub, req) + return UploadTemporaryDocumentResponse(result=doc_ids) + + +@router.get('/temporary/delete', response_model=DeleteTemporaryDocumentResponse, dependencies=[Depends(verify_user)]) +async def delete_temporary_docs( + user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[DeleteTemporaryDocumentRequest, Body()]): + doc_ids = await DocumentService.delete_temporary_docs(user_sub, req.ids) + return DeleteTemporaryDocumentResponse(result=doc_ids) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 5c7fd21..0693c2e 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -30,7 +30,7 @@ from data_chain.manager.task_report_manager import TaskReportManager from data_chain.stores.database.database import DocumentEntity from data_chain.stores.minio.minio import MinIO from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType -from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, DEFAULt_DOC_TYPE_ID +from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID from data_chain.logger.logger import logger as logging from data_chain.rag.base_searcher import BaseSearcher from data_chain.parser.tools.token_tool import TokenTool @@ -79,16 +79,17 @@ class ChunkService: logging.error("[ChunkService] 搜索分片,查询条件: %s", req) chunk_entities = [] for kb_id in req.kb_ids: - try: - if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base( - user_sub, kb_id, action)): - err = f"用户没有权限访问知识库中的块,知识库ID: {kb_id}" - logging.error("[ChunkService] %s", err) + if kb_id != DEFAULT_KNOWLEDGE_BASE_ID: + try: + if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base( + user_sub, kb_id, action)): + err = f"用户没有权限访问知识库中的块,知识库ID: {kb_id}" + logging.error("[ChunkService] %s", err) + continue + except Exception as e: + err = f"验证用户对知识库的操作权限失败,error: {e}" + logging.exception(err) continue - except Exception as e: - err = f"验证用户对知识库的操作权限失败,error: {e}" - logging.exception(err) - continue try: chunk_entities += await BaseSearcher.search(req.search_method.value, kb_id, req.query, 2*req.top_k, req.doc_ids, req.banned_ids) except Exception as e: diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py index 320653b..919ee5f 100644 --- a/data_chain/apps/service/document_service.py +++ b/data_chain/apps/service/document_service.py @@ -7,11 +7,13 @@ import shutil import os from data_chain.entities.request_data import ( ListDocumentRequest, + UploadTemporaryRequest, UpdateDocumentRequest ) from data_chain.entities.response_data import ( Task, Document, + DOC_STATUS, ListDocumentMsg, ListDocumentResponse ) @@ -25,8 +27,8 @@ from 
data_chain.manager.task_manager import TaskManager from data_chain.manager.task_report_manager import TaskReportManager from data_chain.stores.database.database import DocumentEntity from data_chain.stores.minio.minio import MinIO -from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType -from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, REPORT_PATH_IN_MINIO, DEFAULt_DOC_TYPE_ID +from data_chain.entities.enum import ParseMethod, DataSetStatus, DocumentStatus, TaskType, TaskStatus +from data_chain.entities.common import DOC_PATH_IN_OS, DOC_PATH_IN_MINIO, REPORT_PATH_IN_MINIO, DEFAULT_KNOWLEDGE_BASE_ID, DEFAULT_DOC_TYPE_ID from data_chain.logger.logger import logger as logging @@ -146,6 +148,133 @@ class DocumentService: logging.exception("[DocumentService] %s", err) raise e + @staticmethod + async def get_temporary_docs_status(user_sub: str, doc_ids: list[uuid.UUID]) -> list[DOC_STATUS]: + """获取临时文档状态""" + doc_entities = await DocumentManager.list_document_by_doc_ids(doc_ids) + doc_ids = [] + for doc_entity in doc_entities: + if doc_entity.author_id != user_sub: + err = f"用户没有权限访问临时文档, 文档ID: {doc_entity.id}, 用户ID: {user_sub}" + logging.error("[DocumentService] %s", err) + continue + doc_ids.append(doc_entity.id) + task_entities = await TaskManager.list_current_tasks_by_op_ids(doc_ids) + task_dict = {task_entity.op_id: task_entity for task_entity in task_entities} + doc_status_list = [] + for doc_id in doc_ids: + task_entity = task_dict.get(doc_id, None) + if task_entity is not None: + doc_status = DOC_STATUS( + id=doc_id, + status=task_entity.status, + ) + else: + doc_status = DOC_STATUS( + id=doc_id, + status=TaskStatus.PENDING.value + ) + doc_status_list.append(doc_status) + return doc_status_list + + @staticmethod + async def upload_temporary_docs(user_sub: str, req: UploadTemporaryRequest): + """上传临时文档""" + try: + doc_entities = [] + for doc in req.document_list: + id = doc.id + bucket_name = doc.bucket + file_name = doc.name + extension = file_name.split('.')[-1] + if not extension: + extension = doc.type + tmp_path = os.path.join(DOC_PATH_IN_OS, str(id)) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + os.makedirs(tmp_path) + document_file_path = os.path.join(tmp_path, file_name) + flag = await MinIO.download_object( + bucket_name=bucket_name, + file_index=str(id), + file_path=document_file_path + ) + if not flag: + err = f"下载临时文档失败, 文档ID: {id}, 存储桶: {bucket_name}" + logging.error("[DocumentService] %s", err) + continue + document = await DocumentManager.get_document_by_doc_id(id) + if document is not None: + err = f"文档已存在, 文档ID: {id}" + logging.error("[DocumentService] %s", err) + continue + doc_entity = DocumentEntity( + id=id, + kb_id=DEFAULT_KNOWLEDGE_BASE_ID, + author_id=user_sub, + author_name=user_sub, + name=file_name, + extension=extension, + size=os.path.getsize(document_file_path), + parse_method=ParseMethod.OCR.value, + parse_relut_topology=None, + chunk_size=1024, + type_id=DEFAULT_DOC_TYPE_ID, + enabled=True, + status=DataSetStatus.IDLE.value, + full_text='', + abstract='', + abstract_vector=None + ) + doc_entities.append(doc_entity) + await MinIO.put_object( + bucket_name=DOC_PATH_IN_MINIO, + file_index=str(id), + file_path=document_file_path + ) + except Exception as e: + err = f"上传临时文档失败, 错误信息: {e}" + logging.error("[DocumentService] %s", err) + raise e + index = 0 + while index < len(doc_entities): + try: + await DocumentManager.add_documents(doc_entities[index:index+1024]) + index += 1024 + except 
Exception as e: + err = f"上传文档失败, 文档名: {doc_entity.name}, 错误信息: {e}" + logging.error("[DocumentService] %s", err) + continue + for doc_entity in doc_entities: + await TaskQueueService.init_task(TaskType.DOC_PARSE.value, doc_entity.id) + doc_ids = [doc_entity.id for doc_entity in doc_entities] + await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id=DEFAULT_KNOWLEDGE_BASE_ID) + return doc_ids + + @staticmethod + async def delete_temporary_docs(user_sub: str, doc_ids: list[uuid.UUID]) -> list[uuid.UUID]: + """删除临时文档""" + try: + doc_entities = await DocumentManager.list_document_by_doc_ids(doc_ids) + doc_ids = [] + for doc_entity in doc_entities: + if doc_entity.author_id != user_sub: + err = f"用户没有权限删除临时文档, 文档ID: {doc_entity.id}, 用户ID: {user_sub}" + logging.error("[DocumentService] %s", err) + continue + doc_ids.append(doc_entity.id) + task_entities = await TaskManager.list_current_tasks_by_op_ids(doc_ids) + for task_entity in task_entities: + await TaskQueueService.stop_task(task_entity.id) + doc_entities = await DocumentManager.update_document_by_doc_ids( + doc_ids, {"status": DocumentStatus.DELETED.value}) + doc_ids = [doc_entity.id for doc_entity in doc_entities] + return doc_ids + except Exception as e: + err = "删除临时文档失败" + logging.exception("[DocumentService] %s", err) + raise e + @staticmethod async def upload_docs(user_sub: str, kb_id: uuid.UUID, docs: list[UploadFile]) -> list[uuid.UUID]: """上传文档""" @@ -193,7 +322,7 @@ class DocumentService: parse_method=kb_entity.default_parse_method, parse_relut_topology=None, chunk_size=kb_entity.default_chunk_size, - type_id=DEFAULt_DOC_TYPE_ID, + type_id=DEFAULT_DOC_TYPE_ID, enabled=True, status=DataSetStatus.IDLE.value, full_text='', diff --git a/data_chain/apps/service/knwoledge_base_service.py b/data_chain/apps/service/knwoledge_base_service.py index 9d3a922..8dcc3da 100644 --- a/data_chain/apps/service/knwoledge_base_service.py +++ b/data_chain/apps/service/knwoledge_base_service.py @@ -23,7 +23,7 @@ from data_chain.entities.response_data import ( from data_chain.apps.base.zip_handler import ZipHandler from data_chain.apps.service.task_queue_service import TaskQueueService from data_chain.entities.enum import Tokenizer, ParseMethod, TeamType, TeamStatus, KnowledgeBaseStatus, TaskType -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, default_roles, IMPORT_KB_PATH_IN_OS, EXPORT_KB_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID, default_roles, IMPORT_KB_PATH_IN_OS, EXPORT_KB_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO from data_chain.stores.database.database import TeamEntity, KnowledgeBaseEntity, DocumentTypeEntity from data_chain.stores.minio.minio import MinIO from data_chain.apps.base.convertor import Convertor @@ -300,7 +300,7 @@ class KnowledgeBaseService: delete_doc_type_ids = old_doc_type_ids - new_doc_type_ids add_doc_type_ids = new_doc_type_ids - old_doc_type_ids update_doc_type_ids = old_doc_type_ids & new_doc_type_ids - await DocumentManager.update_doc_type_by_kb_id(kb_id, delete_doc_type_ids, DEFAULt_DOC_TYPE_ID) + await DocumentManager.update_doc_type_by_kb_id(kb_id, delete_doc_type_ids, DEFAULT_DOC_TYPE_ID) doc_type_entities = [] for doc_type_id in add_doc_type_ids: doc_type_entity = DocumentTypeEntity( diff --git a/data_chain/entities/common.py b/data_chain/entities/common.py index 9d0ec02..7511840 100644 --- a/data_chain/entities/common.py +++ b/data_chain/entities/common.py @@ -1,5 +1,6 @@ import uuid -DEFAULt_DOC_TYPE_ID = 
uuid.UUID("00000000-0000-0000-0000-000000000000") +DEFAULT_DOC_TYPE_ID = uuid.UUID("00000000-0000-0000-0000-000000000000") +DEFAULT_KNOWLEDGE_BASE_ID = uuid.UUID("00000000-0000-0000-0000-000000000000") DEFAULt_DOC_TYPE_NAME = "default" actions = [ {'type': 'team', diff --git a/data_chain/entities/request_data.py b/data_chain/entities/request_data.py index 71d9dd2..ff39c16 100644 --- a/data_chain/entities/request_data.py +++ b/data_chain/entities/request_data.py @@ -22,7 +22,7 @@ from data_chain.entities.enum import ( TaskType, TaskStatus, OrderType) -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID class ListTeamRequest(BaseModel): @@ -116,13 +116,32 @@ class ListDocumentRequest(BaseModel): class UpdateDocumentRequest(BaseModel): doc_name: str = Field(default='这是一个默认的文档名称', min_length=1, max_length=150, alias="docName") - doc_type_id: uuid.UUID = Field(default=DEFAULt_DOC_TYPE_ID, description="文档类型的id", alias="docTypeId") + doc_type_id: uuid.UUID = Field(default=DEFAULT_DOC_TYPE_ID, description="文档类型的id", alias="docTypeId") parse_method: ParseMethod = Field( default=ParseMethod.GENERAL, description="知识库默认解析方法", alias="parseMethod") chunk_size: int = Field(default=512, description="知识库默认文件分块大小", alias="chunkSize", min=128, max=2048) enabled: bool = Field(default=True, description="文档是否启用") +class GetTemporaryDocumentStatusRequest(BaseModel): + ids: List[uuid.UUID] = Field(default=[], description="临时文档id列表", alias="ids") + + +class TemporaryDocument(BaseModel): + id: uuid.UUID = Field(description="临时文档id", alias="id") + name: str = Field(default='这是一个默认的临时文档名称', min_length=1, max_length=150, alias="name") + bucket: str = Field(default='default', description="临时文档存储的桶名称", alias="bucket") + type: str = Field(default='txt', description="临时文档的类型", alias="type") + + +class UploadTemporaryRequest(BaseModel): + document_list: List[TemporaryDocument] = Field(default=[], description="临时文档列表") + + +class DeleteTemporaryDocumentRequest(BaseModel): + ids: List[uuid.UUID] = Field(default=[], description="临时文档id列表", alias="ids") + + class ListChunkRequest(BaseModel): doc_id: uuid.UUID = Field(description="文档id", alias="docId") text: Optional[str] = Field(default=None, description="分块文本内容", alias="text") diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py index 776fa28..a03f00b 100644 --- a/data_chain/entities/response_data.py +++ b/data_chain/entities/response_data.py @@ -259,6 +259,23 @@ class UploadDocumentResponse(ResponseData): result: list[uuid.UUID] = Field(default=[], description="文档ID列表") +class DOC_STATUS(BaseModel): + id: uuid.UUID = Field(description="文档ID") + status: TaskStatus = Field(description="文档状态", alias="status") + + +class GetTemporaryDocumentStatusResponse(ResponseData): + result: list[DOC_STATUS] = Field(default=[], description="临时文档状态列表", alias="result") + + +class UploadTemporaryDocumentResponse(ResponseData): + """POST /doc/temporary 响应""" + result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表") + + +class DeleteTemporaryDocumentResponse(ResponseData): + """DELETE /doc/temporary 响应""" + result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表") class ParseDocumentResponse(ResponseData): """POST /doc/parse 响应""" result: list[uuid.UUID] = Field(default=[], description="文档ID列表") diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 4c1a540..d0894f9 100644 --- a/data_chain/manager/document_manager.py +++ 
b/data_chain/manager/document_manager.py @@ -195,6 +195,22 @@ class DocumentManager(): logging.exception("[DocumentManager] %s", err) raise e + @staticmethod + async def list_document_by_doc_ids(doc_ids: list[uuid.UUID]) -> List[DocumentEntity]: + """根据文档ID获取文档列表""" + try: + async with await DataBase.get_session() as session: + stmt = select(DocumentEntity).where( + and_(DocumentEntity.id.in_(doc_ids), + DocumentEntity.status != DocumentStatus.DELETED.value)) + result = await session.execute(stmt) + document_entities = result.scalars().all() + return document_entities + except Exception as e: + err = "获取文档列表失败" + logging.exception("[DocumentManager] %s", err) + raise e + @staticmethod async def get_document_by_doc_id(doc_id: uuid.UUID) -> DocumentEntity: """根据文档ID获取文档""" diff --git a/data_chain/parser/parse_result.py b/data_chain/parser/parse_result.py index 24dee22..965aa92 100644 --- a/data_chain/parser/parse_result.py +++ b/data_chain/parser/parse_result.py @@ -2,7 +2,7 @@ from typing import Any, Optional import uuid from pydantic import BaseModel, Field, validator, constr from data_chain.entities.enum import DocParseRelutTopology, ChunkParseTopology, ChunkType -from data_chain.entities.common import DEFAULt_DOC_TYPE_ID +from data_chain.entities.common import DEFAULT_DOC_TYPE_ID class ParseNode(BaseModel): -- Gitee From 1c84b6c446a492bc8a9f2632ea10bed19a28f1dc Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 29 May 2025 21:01:30 +0800 Subject: [PATCH 2/5] =?UTF-8?q?rag=E8=BF=94=E5=9B=9Echunk=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E8=BF=94=E5=9B=9E=E5=85=B3=E8=81=94=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E7=9A=84=E9=93=BE=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/service/chunk_service.py | 3 +++ data_chain/entities/response_data.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 0693c2e..f6697bb 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -20,6 +20,7 @@ from data_chain.entities.response_data import ( from data_chain.apps.base.convertor import Convertor from data_chain.apps.service.task_queue_service import TaskQueueService from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService +from data_chain.apps.service.document_service import DocumentService from data_chain.manager.knowledge_manager import KnowledgeBaseManager from data_chain.manager.document_type_manager import DocumentTypeManager from data_chain.manager.document_manager import DocumentManager @@ -143,6 +144,8 @@ class ChunkService: chunk.text = TokenTool.compress_tokens(chunk.text) dc = DocChunk(docId=chunk_entity.doc_id, docName=chunk_entity.doc_name, chunks=[chunk]) search_chunk_msg.doc_chunks.append(dc) + for doc_chunk in search_chunk_msg.doc_chunks: + doc_chunk.doc_link = await DocumentService.generate_doc_download_url(doc_chunk.doc_id) return search_chunk_msg async def update_chunk_by_id(chunk_id: uuid.UUID, req: UpdateChunkRequest) -> uuid.UUID: diff --git a/data_chain/entities/response_data.py b/data_chain/entities/response_data.py index a03f00b..4b9c478 100644 --- a/data_chain/entities/response_data.py +++ b/data_chain/entities/response_data.py @@ -276,6 +276,8 @@ class UploadTemporaryDocumentResponse(ResponseData): class DeleteTemporaryDocumentResponse(ResponseData): """DELETE /doc/temporary 响应""" result: list[uuid.UUID] = Field(default=[], description="临时文档ID列表") + + class 
ParseDocumentResponse(ResponseData): """POST /doc/parse 响应""" result: list[uuid.UUID] = Field(default=[], description="文档ID列表") @@ -324,6 +326,7 @@ class DocChunk(BaseModel): """Post /chunk/search 数据结构""" doc_id: uuid.UUID = Field(description="文档ID", alias="docId") doc_name: str = Field(description="文档名称", alias="docName") + doc_link: str = Field(default="", description="文档链接", alias="docLink") chunks: list[Chunk] = Field(default=[], description="分片列表", alias="chunks") -- Gitee From b210e845dbdf60db38a0ae6ea8ac22a2760ac81b Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 30 May 2025 10:10:10 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E5=AE=8C=E5=96=84=E4=B8=B4=E6=97=B6?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E4=B8=8A=E4=BC=A0=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/router/document.py | 5 +++-- data_chain/apps/service/document_service.py | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index 1e98523..44075d7 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -160,7 +160,8 @@ async def delete_docs_by_ids( return DeleteDocumentResponse(result=doc_ids) -@router.get('/temporary/status', response_class=GetTemporaryDocumentStatusResponse, dependencies=[Depends(verify_user)]) +@router.post('/temporary/status', response_class=GetTemporaryDocumentStatusResponse, dependencies=[ + Depends(verify_user)]) async def get_temporary_docs_status( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[GetTemporaryDocumentStatusRequest, Body()]): @@ -176,7 +177,7 @@ async def upload_temporary_docs( return UploadTemporaryDocumentResponse(result=doc_ids) -@router.get('/temporary/delete', response_model=DeleteTemporaryDocumentResponse, dependencies=[Depends(verify_user)]) +@router.delete('/temporary/delete', response_model=DeleteTemporaryDocumentResponse, dependencies=[Depends(verify_user)]) async def delete_temporary_docs( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[DeleteTemporaryDocumentRequest, Body()]): diff --git a/data_chain/apps/service/document_service.py b/data_chain/apps/service/document_service.py index 919ee5f..21aa64a 100644 --- a/data_chain/apps/service/document_service.py +++ b/data_chain/apps/service/document_service.py @@ -232,6 +232,8 @@ class DocumentService: file_index=str(id), file_path=document_file_path ) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) except Exception as e: err = f"上传临时文档失败, 错误信息: {e}" logging.error("[DocumentService] %s", err) -- Gitee From 32d2a8388da131e6c1b803d35867a6476776f491 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 30 May 2025 11:05:53 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93url=E6=9E=84=E5=BB=BA=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/config/config.py | 7 ++++++- data_chain/stores/database/database.py | 14 ++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/data_chain/config/config.py b/data_chain/config/config.py index a413651..1eb3c9d 100644 --- a/data_chain/config/config.py +++ b/data_chain/config/config.py @@ -25,7 +25,12 @@ class ConfigModel(DictBaseModel): # LOG METHOD LOG_METHOD: str = Field('stdout', description="日志记录方式") # Database - DATABASE_URL: str = Field(None, description="Database数据库链接url") + DATABASE_TYPE: str = 
Field(default="postgres", description="数据库类型") + DATABASE_HOST: str = Field(None, description="数据库地址") + DATABASE_PORT: int = Field(None, description="数据库端口") + DATABASE_USER: str = Field(None, description="数据库用户名") + DATABASE_PASSWORD: str = Field(None, description="数据库密码") + DATABASE_DB: str = Field(None, description="数据库名称") # MinIO MINIO_ENDPOINT: str = Field(None, description="MinIO连接地址") MINIO_ACCESS_KEY: str = Field(None, description="Minio认证ak") diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 6a9249a..9e7f2e7 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy import Index from uuid import uuid4 +import urllib.parse from data_chain.logger.logger import logger as logging from pgvector.sqlalchemy import Vector from sqlalchemy import Boolean, Column, ForeignKey, Integer, Float, String, func @@ -536,8 +537,17 @@ class TaskReportEntity(Base): class DataBase: + + # 对密码进行 URL 编码 + password = config['DATABASE_PASSWORD'] + encoded_password = urllib.parse.quote_plus(password) + + if config['DATABASE_TYPE'] == 'opengauss': + database_url = f"opengauss+asyncpg://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" + else: + database_url = f"postgresql+asyncpg://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" engine = create_async_engine( - config['DATABASE_URL'], + database_url, echo=False, pool_recycle=300, pool_pre_ping=True @@ -545,7 +555,7 @@ class DataBase: @classmethod async def init_all_table(cls): - if 'opengauss' in config['DATABASE_URL']: + if config['DATABASE_TYPE'] == 'opengauss': from sqlalchemy import event from opengauss_sqlalchemy.register_async import register_vector -- Gitee From 2d7533d70bceaa0c4d4a100d4791de3f42440af4 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 30 May 2025 15:04:27 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E5=AE=8C=E5=96=84=E4=B8=8D=E5=90=8C?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=A8=A1=E5=BC=8F=E4=B8=8B=E5=88=86?= =?UTF-8?q?=E8=AF=8D=E5=99=A8=E7=9A=84=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/chunk_manager.py | 4 ++-- data_chain/manager/document_manager.py | 2 +- data_chain/stores/database/database.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 5387815..88e0dc2 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -208,14 +208,14 @@ class ChunkManager(): async with await DataBase.get_session() as session: kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) if kb_entity.tokenizer == Tokenizer.ZH.value: - if 'opengauss' in config['DATABASE_URL']: + if config['DATABASE_TYPE'].lower() == 'opengauss': tokenizer = 'chparser' else: tokenizer = 'zhparser' elif kb_entity.tokenizer == Tokenizer.EN.value: tokenizer = 'english' else: - if 'opengauss' in config['DATABASE_URL']: + if config['DATABASE_TYPE'].lower() == 'opengauss': tokenizer = 'chparser' else: tokenizer = 'zhparser' diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index d0894f9..242db0c 100644 --- a/data_chain/manager/document_manager.py +++ 
b/data_chain/manager/document_manager.py @@ -79,7 +79,7 @@ class DocumentManager(): kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) tokenizer = '' if kb_entity.tokenizer == Tokenizer.ZH.value: - if 'opengauss' in config['DATABASE_URL']: + if config['DATABASE_TYPE'].lower() == 'opengauss': tokenizer = 'chparser' else: tokenizer = 'zhparser' diff --git a/data_chain/stores/database/database.py b/data_chain/stores/database/database.py index 9e7f2e7..fae18cf 100644 --- a/data_chain/stores/database/database.py +++ b/data_chain/stores/database/database.py @@ -542,7 +542,7 @@ class DataBase: password = config['DATABASE_PASSWORD'] encoded_password = urllib.parse.quote_plus(password) - if config['DATABASE_TYPE'] == 'opengauss': + if config['DATABASE_TYPE'].lower() == 'opengauss': database_url = f"opengauss+asyncpg://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" else: database_url = f"postgresql+asyncpg://{config['DATABASE_USER']}:{encoded_password}@{config['DATABASE_HOST']}:{config['DATABASE_PORT']}/{config['DATABASE_DB']}" -- Gitee
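
Reviewer notes with two hedged sketches follow; neither is part of the patches above.

First, a wire-format sketch for the temporary-document endpoints added in patches 1 and 3. The field names come from the request models in data_chain/entities/request_data.py; the /doc prefix is inferred from the response-model docstrings, and the bucket and file values are placeholders, so treat the paths and sample data as assumptions rather than a tested call sequence.

import json
import uuid

doc_id = str(uuid.uuid4())

# POST /doc/temporary/parser: register objects already stored in MinIO as temporary
# documents and kick off parse tasks (UploadTemporaryRequest / TemporaryDocument fields).
upload_body = {
    "document_list": [
        {"id": doc_id, "name": "notes.txt", "bucket": "tmp-uploads", "type": "txt"},
    ]
}

# POST /doc/temporary/status: poll the current parse-task status for those documents.
status_body = {"ids": [doc_id]}

# DELETE /doc/temporary/delete: stop outstanding tasks and mark the documents deleted.
delete_body = {"ids": [doc_id]}

for name, body in (("upload", upload_body), ("status", status_body), ("delete", delete_body)):
    print(name, json.dumps(body))

Second, a standalone restatement of the DSN construction that patches 4 and 5 move into data_chain/stores/database/database.py. build_database_url and the plain-dict config are illustrative stand-ins for the project's global config object, but the percent-encoding of the password and the dialect selection mirror the diff above.

import urllib.parse


def build_database_url(cfg: dict) -> str:
    """Assemble an asyncpg DSN, URL-encoding the password so special characters survive."""
    encoded_password = urllib.parse.quote_plus(cfg["DATABASE_PASSWORD"])
    # openGauss and PostgreSQL both go through asyncpg, but need different dialect prefixes.
    if cfg["DATABASE_TYPE"].lower() == "opengauss":
        driver = "opengauss+asyncpg"
    else:
        driver = "postgresql+asyncpg"
    return (
        f"{driver}://{cfg['DATABASE_USER']}:{encoded_password}"
        f"@{cfg['DATABASE_HOST']}:{cfg['DATABASE_PORT']}/{cfg['DATABASE_DB']}"
    )


if __name__ == "__main__":
    # A password with reserved characters shows why quote_plus is needed.
    print(build_database_url({
        "DATABASE_TYPE": "postgres",
        "DATABASE_HOST": "127.0.0.1",
        "DATABASE_PORT": 5432,
        "DATABASE_USER": "data_chain",
        "DATABASE_PASSWORD": "p@ss:word/1",
        "DATABASE_DB": "data_chain",
    }))

Keeping the "opengauss" comparison case-insensitive (patch 5) also matters for the tokenizer switch in chunk_manager.py and document_manager.py, which selects chparser on openGauss and zhparser on PostgreSQL from the same DATABASE_TYPE field.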