diff --git a/data_chain/apps/base/convertor/chunk_convertor.py b/data_chain/apps/base/convertor/chunk_convertor.py deleted file mode 100644 index 7a4a4a610c4f4ed29643c12a5fc900cd12c16695..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/convertor/chunk_convertor.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from typing import Dict, List -from data_chain.models.service import ChunkDTO -from data_chain.stores.postgres.postgres import ChunkEntity - - -class ChunkConvertor(): - @staticmethod - def convert_entity_to_dto(chunk_entity: ChunkEntity) -> ChunkDTO: - return ChunkDTO( - id=str(chunk_entity.id), - text=chunk_entity.text, - enabled=chunk_entity.enabled, - type=chunk_entity.type.split('.')[1] - ) diff --git a/data_chain/apps/base/convertor/document_convertor.py b/data_chain/apps/base/convertor/document_convertor.py deleted file mode 100644 index a204a3fc09887e45561ae17b89b2ac5e73155e18..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/convertor/document_convertor.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from data_chain.models.service import DocumentDTO, DocumentTypeDTO -from data_chain.stores.postgres.postgres import DocumentEntity, DocumentTypeEntity - - -class DocumentConvertor(): - - @staticmethod - def convert_entity_to_dto(document_entity: DocumentEntity, document_type_entity: DocumentTypeEntity - ) -> DocumentDTO: - document_type_dto = DocumentTypeDTO(id=str(document_type_entity.id), type=document_type_entity.type) - return DocumentDTO( - id=str(document_entity.id), - name=document_entity.name, - extension=document_entity.extension, - document_type=document_type_dto, - chunk_size=document_entity.chunk_size, - status=document_entity.status, - enabled=document_entity.enabled, - parser_method=document_entity.parser_method, - created_time=document_entity.created_time.strftime('%Y-%m-%d %H:%M') - ) diff --git a/data_chain/apps/base/convertor/knowledge_convertor.py b/data_chain/apps/base/convertor/knowledge_convertor.py deleted file mode 100644 index 313b68f2000b50290bb2cc01807df6c8c1ba3ac8..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/convertor/knowledge_convertor.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-from typing import List -import uuid -from data_chain.models.service import DocumentTypeDTO, KnowledgeBaseDTO -from data_chain.stores.postgres.postgres import DocumentTypeEntity, KnowledgeBaseEntity -from data_chain.models.constant import KnowledgeStatusEnum - -class KnowledgeConvertor(): - - @staticmethod - def convert_dict_to_entity(tmp_dict:dict): - return KnowledgeBaseEntity( - name=tmp_dict['name'], - user_id=tmp_dict['user_id'], - language=tmp_dict['language'], - description=tmp_dict.get('description',''), - embedding_model=tmp_dict['embedding_model'], - document_number=0, - document_size=0, - vector_items_id=uuid.uuid4(), - status=KnowledgeStatusEnum.IDLE, - default_parser_method=tmp_dict['default_parser_method'], - default_chunk_size=tmp_dict['default_chunk_size']) - - @staticmethod - def convert_entity_to_dto(entity: KnowledgeBaseEntity, - document_type_entity_list: List[DocumentTypeEntity] - ) -> KnowledgeBaseDTO: - document_type_dto_list = [DocumentTypeDTO(id=str(document_type_entity.id), type=document_type_entity.type) - for document_type_entity in document_type_entity_list] - return KnowledgeBaseDTO( - id=str(entity.id), - name=entity.name, - language=entity.language, - description=entity.description, - embedding_model=entity.embedding_model, - default_parser_method=entity.default_parser_method, - default_chunk_size=entity.default_chunk_size, - document_count=entity.document_number, - status=entity.status, - document_size=entity.document_size, - document_type_list=document_type_dto_list, - created_time=entity.created_time.strftime('%Y-%m-%d %H:%M') - ) diff --git a/data_chain/apps/base/convertor/model_convertor.py b/data_chain/apps/base/convertor/model_convertor.py deleted file mode 100644 index 23a60e5c6ea189d8d0d7d233d64a2f4ed78f39ed..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/convertor/model_convertor.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from typing import Optional -import json -from data_chain.models.service import ModelDTO -from data_chain.stores.postgres.postgres import ModelEntity -from data_chain.apps.base.security.security import Security -from data_chain.config.config import ModelConfig - - -class ModelConvertor(): - @staticmethod - def convert_entity_to_dto(model_entity: Optional[ModelEntity] = None) -> ModelDTO: - if model_entity is None: - return ModelDTO() - return ModelDTO( - id=str(model_entity.id), - model_name=model_entity.model_name, - model_type=model_entity.model_type, - openai_api_base=model_entity.openai_api_base, - openai_api_key=Security.decrypt( - model_entity.encrypted_openai_api_key, - json.loads(model_entity.encrypted_config) - ), - max_tokens=model_entity.max_tokens, - is_online=model_entity.is_online - ) - - @staticmethod - def convert_config_to_entity(model_config: ModelConfig) -> ModelDTO: - if model_config is None: - return ModelEntity() - return ModelDTO( - id=str(model_config['MODEL_ID']), - model_name=model_config['MODEL_NAME'], - model_type=model_config['MODEL_TYPE'], - ) diff --git a/data_chain/apps/base/convertor/task_convertor.py b/data_chain/apps/base/convertor/task_convertor.py deleted file mode 100644 index f1f16d245ce5bc7cd0c6db2606e1b1750fb47bb2..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/convertor/task_convertor.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-from typing import Dict, List -from data_chain.models.service import TaskDTO, TaskReportDTO -from data_chain.stores.postgres.postgres import TaskEntity, TaskStatusReportEntity - - -class TaskConvertor(): - @staticmethod - def convert_entity_to_dto(task_entity: TaskEntity, TaskStatusReportEntityList: List[TaskStatusReportEntity] = None) -> TaskDTO: - reports = [] - for task_status_report_entity in TaskStatusReportEntityList: - reports.append( - TaskReportDTO( - id=task_status_report_entity.id, - message=task_status_report_entity.message, - current_stage=task_status_report_entity.current_stage, - stage_cnt=task_status_report_entity.stage_cnt, - create_time=task_status_report_entity.created_time.strftime('%Y-%m-%d %H:%M') - )) - return TaskDTO( - id=task_entity.id, - type=task_entity.type, - retry=task_entity.retry, - status=task_entity.status, - reports=reports, - create_time=task_entity.created_time.strftime('%Y-%m-%d %H:%M') - ) diff --git a/data_chain/apps/base/document/zip_handler.py b/data_chain/apps/base/document/zip_handler.py deleted file mode 100644 index 2dab9103c24ca0601ef7ed6290de284f1fcfaa5f..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/document/zip_handler.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -import zipfile -import os -import concurrent -from data_chain.logger.logger import logger as logging -import chardet - - - -class ZipHandler(): - - @staticmethod - def check_zip_file(zip_file_path,max_file_num=2048,max_file_size = 2*1024 * 1024 * 1024): - total_size = 0 - try: - to_zip_file = zipfile.ZipFile(zip_file_path) - if len(to_zip_file.filelist) > max_file_num: - logging.error(f"压缩文件{zip_file_path}的数量超过了上限") - return False - for file in to_zip_file.filelist: - total_size += file.file_size - if total_size > max_file_size: - logging.error(f"压缩文件{zip_file_path}的尺寸超过了上限") - return False - to_zip_file.namelist() - for member in to_zip_file.infolist(): - to_zip_file.open(member) - return True - except zipfile.BadZipFile: - logging.error(f"文件 {zip_file_path} 可能不是有效的ZIP文件.") - return False - except Exception as e: - logging.error(f"处理文件 {zip_file_path} 时出错: {e}") - return False - - @staticmethod - async def zip_dir(start_dir, zip_name): - def zip_dir_excutor(start_dir, zip_name): - with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as zipf: - for root, dirs, files in os.walk(start_dir): - for file in files: - file_path = os.path.join(root, file) - file_path_in_zip = os.path.relpath(file_path, start_dir) - zipf.write(file_path, file_path_in_zip) - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(zip_dir_excutor, start_dir, zip_name) - return future.result() - - @staticmethod - async def unzip_file(zip_file_path, target_dir, files_to_extract=None): - def unzip_file_executor(zip_file_path, target_dir, files_to_extract=None): - if not os.path.exists(target_dir): - os.makedirs(target_dir) - - with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: - # 尝试自动检测文件名的编码 - sample_data = zip_ref.read(zip_ref.namelist()[0]) - detected_encoding = chardet.detect(sample_data)['encoding'] - - if files_to_extract is None: - zip_ref.extractall(target_dir) - else: - files_to_extract = set(files_to_extract) - for file_name in files_to_extract: - if file_name in zip_ref.namelist(): - zip_ref.extract(file_name, path=target_dir) - try: - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(unzip_file_executor, zip_file_path, target_dir, 
files_to_extract) - # 阻塞等待结果 - future.result() - except Exception as e: - logging.error(f"Error occurred while extracting files from {zip_file_path} due to {e}") - return False - - return True diff --git a/data_chain/apps/base/model/llm.py b/data_chain/apps/base/model/llm.py deleted file mode 100644 index eccc32c003e46ede5a750b37d3bd306d78b64518..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/model/llm.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -import asyncio -import time -import re -import json -import tiktoken -from langchain_openai import ChatOpenAI -from langchain.schema import SystemMessage, HumanMessage -from data_chain.logger.logger import logger as logging - - -class LLM: - def __init__(self, openai_api_key, openai_api_base, model_name, max_tokens, request_timeout=60, temperature=0.1): - self.client = ChatOpenAI(model_name=model_name, - openai_api_base=openai_api_base, - openai_api_key=openai_api_key, - request_timeout=request_timeout, - max_tokens=max_tokens, - temperature=temperature) - - def assemble_chat(self, chat=None, system_call='', user_call=''): - if chat is None: - chat = [] - chat.append(SystemMessage(content=system_call)) - chat.append(HumanMessage(content=user_call)) - return chat - - async def nostream(self, chat, system_call, user_call): - chat = self.assemble_chat(chat, system_call, user_call) - response = await self.client.ainvoke(chat) - content = re.sub(r'.*?\n\n', '', response.content, flags=re.DOTALL) - return content - - async def data_producer(self, q: asyncio.Queue, history, system_call, user_call): - message = self.assemble_chat(history, system_call, user_call) - try: - async for frame in self.client.astream(message): - await q.put(frame.content) - except Exception as e: - await q.put(None) - import traceback - logging.error(f"Error in data producer due to: {traceback.format_exc()}") - logging.error(f"Error in data producer due to: {e}") - return - await q.put(None) - - async def stream(self, chat, system_call, user_call): - st = time.time() - q = asyncio.Queue(maxsize=10) - - # 启动生产者任务 - producer_task = asyncio.create_task(self.data_producer(q, chat, system_call, user_call)) - first_token_reach = False - enc = tiktoken.encoding_for_model("gpt-4") - input_tokens=len(enc.encode(system_call)) - output_tokens=0 - while True: - data = await q.get() - if data is None: - break - if not first_token_reach: - first_token_reach = True - logging.info(f"大模型回复第一个字耗时 = {time.time() - st}") - output_tokens+=len(enc.encode(data)) - yield "data: " + json.dumps( - {'content': data, - 'input_tokens':input_tokens, - 'output_tokens':output_tokens - }, ensure_ascii=False - ) + '\n\n' - await asyncio.sleep(0.03) # 使用异步 sleep - - yield "data: [DONE]" - logging.info(f"大模型回复耗时 = {time.time() - st}") diff --git a/data_chain/apps/base/security/security.py b/data_chain/apps/base/security/security.py deleted file mode 100644 index f6a8122a647f53878660f2dd03dc4cd3c5d50676..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/security/security.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
- -import base64 -import binascii -import hashlib -import secrets - -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - -from data_chain.config.config import config - - -class Security: - - @staticmethod - def encrypt(plaintext: str) -> tuple[str, dict]: - """ - 加密公共方法 - :param plaintext: - :return: - """ - half_key1 = config['HALF_KEY1'] - - encrypted_work_key, encrypted_work_key_iv = Security._generate_encrypted_work_key( - half_key1) - encrypted_plaintext, encrypted_iv = Security._encrypt_plaintext(half_key1, encrypted_work_key, - encrypted_work_key_iv, plaintext) - del plaintext - secret_dict = { - "encrypted_work_key": encrypted_work_key, - "encrypted_work_key_iv": encrypted_work_key_iv, - "encrypted_iv": encrypted_iv, - "half_key1": half_key1 - } - return encrypted_plaintext, secret_dict - - @staticmethod - def decrypt(encrypted_plaintext: str, secret_dict: dict): - """ - 解密公共方法 - :param encrypted_plaintext: 待解密的字符串 - :param secret_dict: 存放工作密钥的dict - :return: - """ - plaintext = Security._decrypt_plaintext(half_key1=secret_dict.get("half_key1"), - encrypted_work_key=secret_dict.get( - "encrypted_work_key"), - encrypted_work_key_iv=secret_dict.get( - "encrypted_work_key_iv"), - encrypted_iv=secret_dict.get( - "encrypted_iv"), - encrypted_plaintext=encrypted_plaintext) - return plaintext - - @staticmethod - def _get_root_key(half_key1: str) -> bytes: - half_key2 = config['HALF_KEY2'] - key = (half_key1 + half_key2).encode("utf-8") - half_key3 = config['HALF_KEY3'].encode("utf-8") - hash_key = hashlib.pbkdf2_hmac("sha256", key, half_key3, 10000) - return binascii.hexlify(hash_key)[13:45] - - @staticmethod - def _generate_encrypted_work_key(half_key1: str) -> tuple[str, str]: - bin_root_key = Security._get_root_key(half_key1) - bin_work_key = secrets.token_bytes(32) - bin_encrypted_work_key_iv = secrets.token_bytes(16) - bin_encrypted_work_key = Security._root_encrypt(bin_root_key, bin_encrypted_work_key_iv, bin_work_key) - encrypted_work_key = base64.b64encode(bin_encrypted_work_key).decode("ascii") - encrypted_work_key_iv = base64.b64encode(bin_encrypted_work_key_iv).decode("ascii") - return encrypted_work_key, encrypted_work_key_iv - - @staticmethod - def _get_work_key(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str) -> bytes: - bin_root_key = Security._get_root_key(half_key1) - bin_encrypted_work_key = base64.b64decode(encrypted_work_key.encode("ascii")) - bin_encrypted_work_key_iv = base64.b64decode(encrypted_work_key_iv.encode("ascii")) - return Security._root_decrypt(bin_root_key, bin_encrypted_work_key_iv, bin_encrypted_work_key) - - @staticmethod - def _root_encrypt(key: bytes, encrypted_iv: bytes, plaintext: bytes) -> bytes: - encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor() - encrypted = encryptor.update(plaintext) + encryptor.finalize() - return encrypted - - @staticmethod - def _root_decrypt(key: bytes, encrypted_iv: bytes, encrypted: bytes) -> bytes: - encryptor = Cipher(algorithms.AES(key), modes.GCM(encrypted_iv), default_backend()).encryptor() - plaintext = encryptor.update(encrypted) - return plaintext - - @staticmethod - def _encrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str, - plaintext: str) -> tuple[str, str]: - bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv) - salt = f"{half_key1}{plaintext}" - plaintext_temp = salt.encode("utf-8") - 
del plaintext - del salt - bin_encrypted_iv = secrets.token_bytes(16) - bin_encrypted_plaintext = Security._root_encrypt(bin_work_key, bin_encrypted_iv, plaintext_temp) - encrypted_plaintext = base64.b64encode(bin_encrypted_plaintext).decode("ascii") - encrypted_iv = base64.b64encode(bin_encrypted_iv).decode("ascii") - return encrypted_plaintext, encrypted_iv - - @staticmethod - def _decrypt_plaintext(half_key1: str, encrypted_work_key: str, encrypted_work_key_iv: str, - encrypted_plaintext: str, encrypted_iv) -> str: - bin_work_key = Security._get_work_key(half_key1, encrypted_work_key, encrypted_work_key_iv) - bin_encrypted_plaintext = base64.b64decode(encrypted_plaintext.encode("ascii")) - bin_encrypted_iv = base64.b64decode(encrypted_iv.encode("ascii")) - plaintext_temp = Security._root_decrypt(bin_work_key, bin_encrypted_iv, bin_encrypted_plaintext) - plaintext_salt = plaintext_temp.decode("utf-8") - plaintext = plaintext_salt[len(half_key1):] - return plaintext \ No newline at end of file diff --git a/data_chain/apps/base/session/session.py b/data_chain/apps/base/session/session.py deleted file mode 100644 index 10e096fda3acab9ba4f393769e7caaf8d91a5288..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/session/session.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from __future__ import annotations - -import uuid -import base64 -import hashlib -import hmac -from data_chain.logger.logger import logger as logging -import secrets - -from data_chain.config.config import config -from data_chain.stores.redis.redis import RedisConnectionPool - - - - -class SessionManager: - def __init__(self): - raise NotImplementedError("SessionManager不可以被实例化") - - @staticmethod - def create_session(user_id: uuid) -> str: - session_id = secrets.token_hex(16) - with RedisConnectionPool.get_redis_connection() as r: - try: - data = {"session_id": session_id} - r.hmset(str(user_id), data) - r.expire(str(user_id), config["SESSION_TTL"] * 60) - data = {"user_id": str(user_id)} - r.hmset(session_id, data) - r.expire(session_id, config["SESSION_TTL"] * 60) - except Exception as e: - logging.error(f"Session error: {e}") - return session_id - - @staticmethod - def delete_session(user_id: uuid) -> bool: - with RedisConnectionPool.get_redis_connection() as r: - try: - old_session_id=None - if r.hexists(str(user_id), "session_id"): - old_session_id=r.hget(str(user_id), "session_id") - r.hdel(str(user_id), "session_id") - if old_session_id and r.hexists(old_session_id, "user_id"): - r.hdel(old_session_id, "user_id") - if old_session_id and r.hexists(old_session_id, "nonce"): - r.hdel(old_session_id, "nonce") - except Exception as e: - logging.error(f"Delete session error: {e}") - return False - - @staticmethod - def verify_user(session_id: str) -> bool: - with RedisConnectionPool.get_redis_connection() as r: - try: - user_exist = r.hexists(session_id, "user_id") - r.expire(session_id, config["SESSION_TTL"] * 60) - return user_exist - except Exception as e: - logging.error(f"User not in session: {e}") - return False - - @staticmethod - def get_user_id(session_id: str) -> str | None: - - with RedisConnectionPool.get_redis_connection() as r: - try: - user_id = r.hget(session_id, "user_id") - r.expire(session_id, config["SESSION_TTL"] * 60) - except Exception as e: - logging.error(f"Get user from session error: {e}") - return None - - return uuid.UUID(user_id) - - @staticmethod - def create_csrf_token(session_id: str) -> str | None: - rand 
= secrets.token_hex(8) - - with RedisConnectionPool.get_redis_connection() as r: - try: - r.hset(session_id, "nonce", rand) - r.expire(session_id, config["SESSION_TTL"] * 60) - except Exception as e: - logging.error(f"Create csrf token from session error: {e}") - return None - - csrf_value = f"{session_id}{rand}" - csrf_b64 = base64.b64encode(bytes.fromhex(csrf_value)) - - hmac_processor = hmac.new(key=base64.b64decode(config["CSRF_KEY"]), msg=csrf_b64, digestmod=hashlib.sha256) - signature = base64.b64encode(hmac_processor.digest()) - - csrf_b64 = csrf_b64.decode("utf-8") - signature = signature.decode("utf-8") - return f"{csrf_b64}.{signature}" - - @staticmethod - def verify_csrf_token(session_id: str, token: str) -> bool: - if not token: - return False - - token_msg = token.split(".") - if len(token_msg) != 2: - return False - - first_part = base64.b64decode(token_msg[0]).hex() - current_session_id = first_part[:32] - logging.error(f"current_session_id: {current_session_id}, session_id: {session_id}") - if current_session_id != session_id: - return False - - current_nonce = first_part[32:] - with RedisConnectionPool.get_redis_connection() as r: - try: - nonce = r.hget(current_session_id, "nonce") - if nonce != current_nonce: - return False - r.expire(current_session_id, config["SESSION_TTL"] * 60) - except Exception as e: - logging.error(f"Get csrf token from session error: {e}") - - hmac_obj = hmac.new(key=base64.b64decode(config["CSRF_KEY"]), - msg=token_msg[0].encode("utf-8"), digestmod=hashlib.sha256) - signature = hmac_obj.digest() - current_signature = base64.b64decode(token_msg[1]) - - return hmac.compare_digest(signature, current_signature) diff --git a/data_chain/apps/base/task/document_task_handler.py b/data_chain/apps/base/task/document_task_handler.py deleted file mode 100644 index eedbdd6d70ed0653eb43bb6fdfd62d0e7a0b162b..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/task/document_task_handler.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import os -import secrets -import uuid -import shutil -from data_chain.logger.logger import logger as logging -from fastapi import File, UploadFile -import aiofiles -from typing import List -from data_chain.apps.base.task.task_handler import TaskRedisHandler -from data_chain.manager.document_manager import DocumentManager, TemporaryDocumentManager -from data_chain.manager.knowledge_manager import KnowledgeBaseManager -from data_chain.manager.task_manager import TaskManager, TaskStatusReportManager -from data_chain.models.constant import DocumentEmbeddingConstant, OssConstant, TaskConstant -from data_chain.parser.service.parser_service import ParserService -from data_chain.stores.minio.minio import MinIO -from data_chain.stores.postgres.postgres import DocumentEntity, TaskEntity, TaskStatusReportEntity -from data_chain.config.config import config - - -class DocumentTaskHandler(): - - @staticmethod - async def save_document_file_to_local(target_dir: str, file: UploadFile): - document_file_path = os.path.join(target_dir, file.filename) - async with aiofiles.open(document_file_path, "wb") as f: - content = await file.read() - await f.write(content) - return document_file_path - - @staticmethod - async def handle_parser_document_task(t_id: uuid.UUID): - target_dir = None - try: - task_entity = await TaskManager.select_by_id(t_id) - if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: - TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], str(task_entity.id)) - return - # 文件解析主入口 - - # 下载文件 - document_entity = await DocumentManager.select_by_id(task_entity.op_id) - target_dir = os.path.join(OssConstant.PARSER_SAVE_FOLDER, str(document_entity.id), secrets.token_hex(16)) - if not os.path.exists(target_dir): - os.makedirs(target_dir) - file_extension = document_entity.extension.lower() - file_path = target_dir + '/' + str(document_entity.id) + file_extension - await MinIO.download_object(OssConstant.MINIO_BUCKET_DOCUMENT, str(document_entity.id), - file_path) - - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'document {document_entity.name} begin to parser', - current_stage=1, - stage_cnt=7 - )) - - parser = ParserService() - answer = await parser.parser(document_entity.id, file_path) - chunk_list, chunk_link_list, images = answer['chunk_list'], answer['chunk_link_list'], answer[ - 'image_chunks'] - new_chunk_list=[] - for chunk in chunk_list: - if len(chunk['text'].strip())!=0: - new_chunk_list.append(chunk) - chunk_list = new_chunk_list - chunk_id_set = set() - for chunk in chunk_list: - chunk_id_set.add(chunk['id']) - new_chunk_link_list = [] - for chunk_link in chunk_link_list: - if chunk_link['chunk_a'] in chunk_id_set: - new_chunk_link_list.append(chunk_link) - chunk_link_list = new_chunk_link_list - new_images = [] - for image in images: - if image['chunk_id'] in chunk_id_set: - new_images.append(image) - images = new_images - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Parse document {document_entity.name} completed, waiting for uploading', - current_stage=2, - stage_cnt=7 - )) - full_text = '' - for chunk in chunk_list: - full_text += chunk['text'] - await parser.upload_full_text_to_database(document_entity.id, full_text) - await parser.upload_chunks_to_database(chunk_list) - await parser.upload_chunk_links_to_database(chunk_link_list) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload document 
{document_entity.name} full text chunk and link completed', - current_stage=3, - stage_cnt=7 - )) - - await parser.upload_images_to_minio(images) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload document {document_entity.name} images to minio completed', - current_stage=4, - stage_cnt=7 - )) - await parser.upload_images_to_database(images) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload document {document_entity.name} images to pg completed', - current_stage=5, - stage_cnt=7 - )) - - vectors = await parser.embedding_chunks(chunk_list) - await parser.upload_vectors_to_database(vectors) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload document {document_entity.name} vectors completed', - current_stage=6, - stage_cnt=7 - )) - - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Parse document {task_entity.id} succcessfully', - current_stage=7, - stage_cnt=7 - )) - await DocumentManager.update(task_entity.op_id, {'status': DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING}) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) - TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) - except Exception as e: - import traceback - print(traceback.format_exc()) - TaskRedisHandler.put_task_by_tail(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_entity.id)) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Parse document {task_entity.id} failed due to {e}', - current_stage=7, - stage_cnt=7 - )) - finally: - if target_dir and os.path.exists(target_dir): - shutil.rmtree(target_dir) # 清理文件 - - @staticmethod - async def handle_parser_temporary_document_task(t_id: uuid.UUID): - target_dir = None - try: - task_entity = await TaskManager.select_by_id(t_id) - if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: - TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], str(task_entity.id)) - return - # 文件解析主入口 - - document_entity = await TemporaryDocumentManager.select_by_id(task_entity.op_id) - target_dir = os.path.join(OssConstant.PARSER_SAVE_FOLDER, str(document_entity.id), secrets.token_hex(16)) - if not os.path.exists(target_dir): - os.makedirs(target_dir) - file_extension = document_entity.extension.lower() - file_path = target_dir + '/' + str(document_entity.id) + file_extension - await MinIO.download_object(document_entity.bucket_name, str(document_entity.id), - file_path) - - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Temporary document {document_entity.name} begin to parse', - current_stage=1, - stage_cnt=6 - )) - - parser = ParserService() - parser_result = await parser.parser(document_entity.id, file_path, is_temporary_document=True) - chunk_list, chunk_link_list, images = parser_result['chunk_list'], parser_result['chunk_link_list'], parser_result[ - 'image_chunks'] - new_chunk_list=[] - for chunk in chunk_list: - if len(chunk['text'].strip())!=0: - new_chunk_list.append(chunk) - chunk_list = new_chunk_list - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Parse temporary document {document_entity.name} completed, waiting for uploading', - current_stage=2, - stage_cnt=6 - )) - full_text = '' - for 
chunk in chunk_list: - full_text += chunk['text'] - await parser.upload_full_text_to_database(document_entity.id, full_text, is_temporary_document=True) - await parser.upload_chunks_to_database(chunk_list, is_temporary_document=True) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload temporary document {document_entity.name} full text chunks completed', - current_stage=3, - stage_cnt=6 - )) - await parser.upload_images_to_minio(images, is_temporary_document=True) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload temporary document {document_entity.name} images completed', - current_stage=4, - stage_cnt=6 - )) - vectors = await parser.embedding_chunks(chunk_list, is_temporary_document=True) - await parser.upload_vectors_to_database(vectors, is_temporary_document=True) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Upload temporary document {document_entity.name} vectors completed', - current_stage=5, - stage_cnt=6 - )) - - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Parse temporary document {task_entity.id} succcessfully', - current_stage=6, - stage_cnt=6 - )) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) - TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) - except Exception as e: - TaskRedisHandler.put_task_by_tail(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_entity.id)) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Parse temporary document {task_entity.id} failed due to {e}', - current_stage=5, - stage_cnt=5 - )) - finally: - if target_dir and os.path.exists(target_dir): - shutil.rmtree(target_dir) # 清理文件 diff --git a/data_chain/apps/base/task/knowledge_base_task_handler.py b/data_chain/apps/base/task/knowledge_base_task_handler.py deleted file mode 100644 index 42bd77bf53640f6506fe15aece89e6a16ec6874b..0000000000000000000000000000000000000000 --- a/data_chain/apps/base/task/knowledge_base_task_handler.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import os -import secrets -import shutil -import uuid -from data_chain.logger.logger import logger as logging - -from data_chain.models.constant import KnowledgeStatusEnum,ParseMethodEnum -from data_chain.apps.base.document.zip_handler import ZipHandler -from data_chain.apps.base.task.task_handler import TaskRedisHandler -from data_chain.manager.document_manager import DocumentManager -from data_chain.manager.document_type_manager import DocumentTypeManager -from data_chain.manager.knowledge_manager import KnowledgeBaseManager -from data_chain.manager.task_manager import TaskManager -from data_chain.manager.task_manager import TaskStatusReportManager -from data_chain.manager.user_manager import UserManager -from data_chain.models.constant import DocumentEmbeddingConstant, OssConstant, TaskConstant -from data_chain.stores.minio.minio import MinIO -from data_chain.stores.postgres.postgres import DocumentEntity, KnowledgeBaseEntity, TaskEntity, TaskStatusReportEntity -from data_chain.config.config import config -import yaml - - - - - -class KnowledgeBaseTaskHandler(): - @staticmethod - async def parse_knowledge_base_document_yaml_file( - knowledge_base_entity: KnowledgeBaseEntity, unzip_folder_path: str): - document_path = os.path.join(unzip_folder_path, "document") - document_yaml_path = os.path.join(unzip_folder_path, "document_yaml") - document_type_entity_list = await DocumentTypeManager.select_by_knowledge_base_id(knowledge_base_entity.id) - document_type_dict = {document_type_entity.type: document_type_entity.id - for document_type_entity in document_type_entity_list} - document_entity_list = [] - - file_path_list = [] - parse_methods=set(ParseMethodEnum.get_all_values()) - for root, _, files in os.walk(document_path): - for file in files: - file_path = os.path.join(root, file) - file_size = os.path.getsize(file_path) - yaml_file_path = os.path.join(document_yaml_path, file+'.yaml') - try: - with open(yaml_file_path, 'r')as yaml_file: - data = yaml.safe_load(yaml_file) - except Exception: - logging.error('文件缺少配置文件或者配置文件损坏') - continue - # 写入document表 - file_name = data.get('name', file) - if await DocumentManager.select_by_knowledge_base_id_and_file_name(knowledge_base_entity.id, file_name=file_name): - name = os.path.splitext(file_name)[0] - extension = os.path.splitext(file_name)[1] - if len(name)>=128: - name = name[:128] - file_name = name+'_'+secrets.token_hex(16)+extension - parser_method=data.get('parser_method', knowledge_base_entity.default_parser_method) - if parser_method not in parse_methods: - parser_method=ParseMethodEnum.GENERAL - document_entity = DocumentEntity( - kb_id=knowledge_base_entity.id,\ - user_id=knowledge_base_entity.user_id,\ - name=file_name,\ - extension=data.get('extension', ''),\ - size=file_size,\ - parser_method=parser_method,\ - type_id=document_type_dict.get( - data['type'], - uuid.UUID('00000000-0000-0000-0000-000000000000') - ),\ - chunk_size=data.get( - 'chunk_size', - knowledge_base_entity.default_chunk_size - ),\ - enabled=True,\ - status=DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING - ) - document_entity_list.append(document_entity) - file_path_list.append(file_path) - await DocumentManager.insert_bulk(document_entity_list) - for i in range(len(document_entity_list)): - await MinIO.put_object(OssConstant.MINIO_BUCKET_DOCUMENT, file_index=str(document_entity_list[i].id), - file_path=file_path_list[i]) - # 最后更新下knowledge_base的文档数量和文档大小 - total_cnt,total_sz=await DocumentManager.select_cnt_and_sz_by_kb_id(knowledge_base_entity.id) - 
update_dict = {'document_number': total_cnt, - 'document_size': total_sz} - await KnowledgeBaseManager.update(knowledge_base_entity.id, update_dict) - return document_entity_list - - @staticmethod - async def handle_import_knowledge_base_task(t_id: uuid.UUID): - unzip_folder_path=None - try: - task_entity=await TaskManager.select_by_id(t_id) - if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: - TaskRedisHandler.put_task_by_tail( - config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], - str(task_entity.id) - ) - return - # 定义目标目录 - kb_id = task_entity.op_id - knowledge_base_entity = await KnowledgeBaseManager.select_by_id(kb_id) - user_entity = await UserManager.get_user_info_by_user_id(task_entity.user_id) - unzip_folder_path = os.path.join(OssConstant.IMPORT_FILE_SAVE_FOLDER, - str(user_entity.id), secrets.token_hex(16)) - # 创建目录 - if os.path.exists(unzip_folder_path): - shutil.rmtree(unzip_folder_path) - os.makedirs(unzip_folder_path) - zip_file_path = os.path.join(unzip_folder_path, str(kb_id)+'.zip') - # todo 下面两个子过程记录task report - if not await MinIO.download_object(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(kb_id), zip_file_path): - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Download knowledge base {kb_id} zip file failed', - current_stage=3, - stage_cnt=3 - ) - ) - await KnowledgeBaseManager.update(kb_id, {'status': KnowledgeStatusEnum.IDLE}) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_FAILED}) - TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) - return - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Download knowledge base {kb_id} zip file succcessfully', - current_stage=1, - stage_cnt=3 - )) - if not await ZipHandler.unzip_file(zip_file_path, unzip_folder_path): - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Unzip knowledge base {kb_id} zip file failed', - current_stage=3, - stage_cnt=3 - )) - await KnowledgeBaseManager.update(kb_id, {'status': KnowledgeStatusEnum.IDLE}) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) - TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) - return - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Unzip knowledge base {kb_id} zip file succcessfully', - current_stage=2, - stage_cnt=3 - )) - document_entity_list = await KnowledgeBaseTaskHandler.\ - parse_knowledge_base_document_yaml_file( - knowledge_base_entity,\ - unzip_folder_path - ) - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Save document and parse yaml from knowledge base {kb_id} zip file succcessfully', - current_stage=3, - stage_cnt=3 - )) - await KnowledgeBaseManager.update(kb_id, {'status': KnowledgeStatusEnum.IDLE}) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) - for document_entity in document_entity_list: - doc_task_entity = await TaskManager.insert(TaskEntity(user_id=user_entity.id, - op_id=document_entity.id, - type=TaskConstant.PARSE_DOCUMENT, - retry=0, - status=TaskConstant.TASK_STATUS_PENDING)) - # 提交redis任务队列 - TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(doc_task_entity.id)) - TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) 
- except Exception as e: - await TaskStatusReportManager.insert(TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Export knowledge base {kb_id} zip file failed due to {e}', - current_stage=0, - stage_cnt=4 - )) - TaskRedisHandler.put_task_by_tail(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_entity.id)) - logging.error("Import knowledge base error: {}".format(e)) - finally: - if unzip_folder_path and os.path.exists(unzip_folder_path): - shutil.rmtree(unzip_folder_path) - - @staticmethod - async def handle_export_knowledge_base_task(t_id: uuid.UUID): - knowledge_yaml_path=None - try: - task_entity=await TaskManager.select_by_id(t_id) - if task_entity.status != TaskConstant.TASK_STATUS_RUNNING: - TaskRedisHandler.put_task_by_tail(config['REDIS_SILENT_ERROR_TASK_QUEUE_NAME'], str(task_entity.id)) - return - knowledge_base_entity = await KnowledgeBaseManager.select_by_id(task_entity.op_id) - user_entity = await UserManager.get_user_info_by_user_id(knowledge_base_entity.user_id) - zip_file_path = os.path.join(OssConstant.ZIP_FILE_SAVE_FOLDER, str(user_entity.id)) - knowledge_yaml_path = os.path.join(zip_file_path, secrets.token_hex(16)) - document_path = os.path.join(knowledge_yaml_path, "document") - document_yaml_path = os.path.join(knowledge_yaml_path, "document_yaml") - if not os.path.exists(zip_file_path): - os.makedirs(zip_file_path) - if os.path.exists(knowledge_yaml_path): - shutil.rmtree(knowledge_yaml_path) - os.makedirs(knowledge_yaml_path) - os.makedirs(document_path) - os.makedirs(document_yaml_path) - document_type_entity_list = await DocumentTypeManager.select_by_knowledge_base_id(knowledge_base_entity.id) - # 写入knowledge_base.yaml文件 - with open(os.path.join(knowledge_yaml_path, "knowledge_base.yaml"), 'w') as knowledge_yaml_file: - - knowledge_dict = knowledge_dict = { - 'default_chuk_size': knowledge_base_entity.default_chunk_size, - 'default_parser_method': knowledge_base_entity.default_parser_method, - 'description': knowledge_base_entity.description, - 'document_number': knowledge_base_entity.document_number, - 'document_size': knowledge_base_entity.document_size, - 'embedding_model': knowledge_base_entity.embedding_model, - 'language': knowledge_base_entity.language, - 'name': knowledge_base_entity.name - } - knowledge_dict['document_type_list'] = list( - set([document_type_entity.type - for document_type_entity in document_type_entity_list] - ) - ) - yaml.dump(knowledge_dict, knowledge_yaml_file) - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id,\ - message=f'Save knowledge base yaml from knowledge base {knowledge_base_entity.id} succcessfully',\ - current_stage=1,\ - stage_cnt=4\ - ) - ) - document_type_entity_map = { - str(document_type_entity.id): document_type_entity.type - for document_type_entity in document_type_entity_list} - document_entity_list = await DocumentManager.select_by_knowledge_base_id(knowledge_base_entity.id) - for document_entity in document_entity_list: - try: - if not await MinIO.download_object(OssConstant.MINIO_BUCKET_DOCUMENT, str( - document_entity.id), os.path.join(document_path, document_entity.name)): - continue - with open(os.path.join(document_yaml_path, document_entity.name+".yaml"), 'w') as yaml_file: - doc_type="default type" - if str(document_entity.type_id) in document_type_entity_map.keys(): - doc_type = document_type_entity_map[str(document_entity.type_id)] - yaml.dump( - { - 'name': document_entity.name,\ - 'extension': document_entity.extension,\ - 'parser_method': 
document_entity.parser_method,\ - 'chunk_size':document_entity.chunk_size,\ - 'type': doc_type}, - yaml_file) - except Exception as e: - logging.error(f"Download document {document_entity.id} failed due to: {e}") - continue - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Save document and yaml from knowledge base {knowledge_base_entity.id} succcessfully', - current_stage=2, - stage_cnt=4 - )) - # 最后压缩export目录, 并放入minIO - save_knowledge_base_zip_file_name = str(task_entity.id)+".zip" - await ZipHandler.zip_dir(knowledge_yaml_path, os.path.join(zip_file_path, save_knowledge_base_zip_file_name)) - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Zip knowledge base {knowledge_base_entity.id} succcessfully', - current_stage=3, - stage_cnt=4 - )) - # 上传到minIO exportzip桶 - res = await MinIO.put_object( - OssConstant.MINIO_BUCKET_KNOWLEDGEBASE,\ - str(task_entity.id),\ - os.path.join( - zip_file_path,\ - save_knowledge_base_zip_file_name\ - ) - ) - if res: - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Update knowledge base {knowledge_base_entity.id} zip file {save_knowledge_base_zip_file_name} to Minio succcessfully', - current_stage=4, - stage_cnt=4 - ) - ) - else: - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id, - message=f'Update knowledge base {knowledge_base_entity.id} zip file {save_knowledge_base_zip_file_name} to Minio failed', - current_stage=4, - stage_cnt=4 - ) - ) - raise Exception(f'Update knowledge base {knowledge_base_entity.id} zip file {save_knowledge_base_zip_file_name} to Minio failed') - await KnowledgeBaseManager.update(knowledge_base_entity.id, {'status': KnowledgeStatusEnum.IDLE}) - await TaskManager.update(task_entity.id, {'status': TaskConstant.TASK_STATUS_SUCCESS}) - TaskRedisHandler.put_task_by_tail(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_entity.id)) - except Exception as e: - await TaskStatusReportManager.insert( - TaskStatusReportEntity( - task_id=task_entity.id,\ - message=f'Import knowledge base {task_entity.op_id} failed due to {e}',\ - current_stage=0,\ - stage_cnt=4\ - ) - ) - TaskRedisHandler.put_task_by_tail( - config['REDIS_RESTART_TASK_QUEUE_NAME'],\ - str(task_entity.id) - ) - logging.error(f"Export knowledge base zip files errordue to {e}") - finally: - if knowledge_yaml_path and os.path.exists(knowledge_yaml_path): - shutil.rmtree(knowledge_yaml_path) diff --git a/data_chain/apps/base/task/message_queue.py b/data_chain/apps/base/task/message_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..5e270618158cdbb9cbe3a37242ffeb3ee5478954 --- /dev/null +++ b/data_chain/apps/base/task/message_queue.py @@ -0,0 +1,33 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+from data_chain.stores.mongodb import MongoDB
+
+
+class TaskQueue:
+    """Task queue"""
+    @staticmethod
+    async def init_task_queue():
+        pass
+
+    @staticmethod
+    async def init_task():
+        pass
+
+    @staticmethod
+    async def handle_pending_tasks():
+        pass
+
+    @staticmethod
+    async def handle_running_tasks():
+        pass
+
+    @staticmethod
+    async def handle_successed_tasks():
+        pass
+
+    @staticmethod
+    async def handle_failed_tasks():
+        pass
+
+    @staticmethod
+    async def handle_tasks():
+        pass
diff --git a/data_chain/apps/base/task/process_handler.py b/data_chain/apps/base/task/process_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0171a67ddf57a40f59cb810283de557b5ac5685b
--- /dev/null
+++ b/data_chain/apps/base/task/process_handler.py
@@ -0,0 +1,68 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+from data_chain.logger.logger import logger as logging
+import os
+import signal
+import multiprocessing
+import uuid
+import asyncio
+from data_chain.config.config import config
+
+multiprocessing = multiprocessing.get_context('spawn')  # rebind to the spawn context: children start with fresh state
+
+
+class ProcessHandler:
+    '''Runs task coroutines in dedicated OS processes'''
+    tasks = {}  # map of task_id -> multiprocessing.Process
+    lock = multiprocessing.Lock()  # guards the tasks map
+    max_processes = min(
+        max((os.cpu_count() or 1) // 2, 1),
+        config['DOCUMENT_PARSE_USE_CPU_LIMIT'])  # half the CPU cores (at least 1), capped by config
+
+    @staticmethod
+    def subprocess_target(target, *args, **kwargs):
+        # Child-process entry point: drive the async target on a fresh event loop.
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            loop.run_until_complete(target(*args, **kwargs))
+        finally:
+            loop.close()
+
+    @staticmethod
+    def add_task(task_id: uuid.UUID, target, *args, **kwargs):
+        with ProcessHandler.lock:
+            if len(ProcessHandler.tasks) >= ProcessHandler.max_processes:
+                warning = f"Task count reached the limit ({ProcessHandler.max_processes}), please retry later."
+                logging.warning("[ProcessHandler] %s", warning)
+                return False
+
+            if task_id not in ProcessHandler.tasks:
+                process = multiprocessing.Process(target=ProcessHandler.subprocess_target,
+                                                  args=(target,) + args, kwargs=kwargs)
+                ProcessHandler.tasks[task_id] = process
+                process.start()
+            else:
+                info = f"Task ID {task_id} already exists, skip adding."
+                logging.info("[ProcessHandler] %s", info)
+            return True
+
+    @staticmethod
+    def remove_task(task_id: uuid.UUID):
+        with ProcessHandler.lock:
+            if task_id in ProcessHandler.tasks.keys():
+                process = ProcessHandler.tasks[task_id]
+                try:
+                    if process.is_alive():
+                        pid = process.pid
+                        os.kill(pid, signal.SIGKILL)
+                        info = f"Process {task_id} ({pid}) killed."
+                        logging.info("[ProcessHandler] %s", info)
+                except Exception as e:
+                    warning = f"Failed to kill process {task_id}: {e}"
+                    logging.warning("[ProcessHandler] %s", warning)
+                del ProcessHandler.tasks[task_id]
+                info = f"Task ID {task_id} removed."
+                logging.info("[ProcessHandler] %s", info)
+            else:
+                warning = f"Task ID {task_id} does not exist, nothing to remove."
+                logging.warning("[ProcessHandler] %s", warning)
diff --git a/data_chain/apps/base/task/task_handler.py b/data_chain/apps/base/task/task_handler.py
deleted file mode 100644
index 4ea0c21596ef8ea7654b13f7119317bbc02b87cc..0000000000000000000000000000000000000000
--- a/data_chain/apps/base/task/task_handler.py
+++ /dev/null
@@ -1,179 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-from data_chain.logger.logger import logger as logging -from typing import List -import os -import signal -import multiprocessing -import uuid -import asyncio -import sys -from data_chain.config.config import config -from data_chain.stores.redis.redis import RedisConnectionPool -from data_chain.manager.task_manager import TaskManager -from data_chain.manager.document_manager import DocumentManager -from data_chain.manager.knowledge_manager import KnowledgeBaseManager -from data_chain.manager.chunk_manager import ChunkManager,TemporaryChunkManager -from data_chain.models.constant import TaskConstant, DocumentEmbeddingConstant, KnowledgeStatusEnum, OssConstant, TaskActionEnum -from data_chain.stores.minio.minio import MinIO - -multiprocessing = multiprocessing.get_context('spawn') -class TaskHandler: - tasks = {} # 存储进程的字典 - lock = multiprocessing.Lock() # 创建一个锁对象 - max_processes = min(max((os.cpu_count() or 1)//2, 1),config['DOCUMENT_PARSE_USE_CPU_LIMIT']) # 获取CPU核心数作为最大进程数,默认为1 - - @staticmethod - def subprocess_target(target, *args, **kwargs): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(target(*args, **kwargs)) - finally: - loop.close() - - @staticmethod - def add_task(task_id: uuid.UUID, target, *args, **kwargs): - with TaskHandler.lock: - if len(TaskHandler.tasks)>= TaskHandler.max_processes: - logging.info("Reached maximum number of active processes.") - return False - - if task_id not in TaskHandler.tasks: - process = multiprocessing.Process(target=TaskHandler.subprocess_target, - args=(target,) + args, kwargs=kwargs) - TaskHandler.tasks[task_id] = process - process.start() - else: - logging.info(f"Task ID {task_id} already exists.") - return True - - @staticmethod - def remove_task(task_id: uuid.UUID): - with TaskHandler.lock: - if task_id in TaskHandler.tasks.keys(): - process = TaskHandler.tasks[task_id] - try: - if process.is_alive(): - pid = process.pid - # TODO:优化杀死机制,考虑僵尸队列 - os.kill(pid, signal.SIGKILL) - logging.info(f"Process {task_id} ({pid}) killed.") - logging.info(f"Process {task_id} ({pid}) removed.") - except Exception as e: - logging.error(f"Process killed failed due to {e}") - del TaskHandler.tasks[task_id] - else: - logging.info(f"Task ID {task_id} does not exist.") - - @staticmethod - def get_task(task_id): - with TaskHandler.lock: - return TaskHandler.tasks.get(task_id, None) - - @staticmethod - def list_tasks(): - with TaskHandler.lock: - return list(TaskHandler.tasks.keys()) - - @staticmethod - def is_alive(task_id): - process = TaskHandler.get_task(task_id) - try: - alive = process.is_alive() - except Exception as e: - alive = False - logging.error(f"get process status failed due to {e}") - return alive - - @staticmethod - def check_and_adjust_active_count(): - with TaskHandler.lock: - TaskHandler.active_count = sum(process.is_alive() for process in TaskHandler.tasks.values()) - - @staticmethod - async def restart_or_clear_task(task_id: uuid.UUID, method=TaskActionEnum.RESTART): - TaskHandler.remove_task(task_id) - task_entity = await TaskManager.select_by_id(task_id) - if task_entity is None: - return - op_id = task_entity.op_id - task_type = task_entity.type - if task_entity.retry < 3 and method == TaskActionEnum.RESTART: - await TaskManager.update(task_id, {"retry": task_entity.retry+1}) - if task_type == TaskConstant.PARSE_DOCUMENT: - await DocumentManager.update(op_id, {'status': DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_RUNNING}) - await ChunkManager.delete_by_document_ids([op_id]) - elif 
task_type == TaskConstant.PARSE_TEMPORARY_DOCUMENT: - await TemporaryChunkManager.delete_by_temporary_document_ids([op_id]) - elif task_type == TaskConstant.IMPORT_KNOWLEDGE_BASE: - await KnowledgeBaseManager.delete(op_id) - elif task_type == TaskConstant.EXPORT_KNOWLEDGE_BASE: - await KnowledgeBaseManager.update(op_id, {'status': KnowledgeStatusEnum.EXPROTING}) - await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_PENDING}) - TaskRedisHandler.put_task_by_tail(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) - elif method == TaskActionEnum.RESTART or method == TaskActionEnum.CANCEL or method == TaskActionEnum.DELETE: - TaskRedisHandler.remove_task_by_task_id(config['REDIS_PENDING_TASK_QUEUE_NAME'], str(task_id)) - TaskRedisHandler.remove_task_by_task_id(config['REDIS_SUCCESS_TASK_QUEUE_NAME'], str(task_id)) - TaskRedisHandler.remove_task_by_task_id(config['REDIS_RESTART_TASK_QUEUE_NAME'], str(task_id)) - if method == TaskActionEnum.CANCEL: - await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_CANCELED}) - elif method == TaskActionEnum.DELETE: - await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_DELETED}) - else: - await TaskManager.update(task_id, {"status": TaskConstant.TASK_STATUS_FAILED}) - if task_type == TaskConstant.PARSE_DOCUMENT: - await DocumentManager.update(op_id, {'status': DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING}) - await ChunkManager.delete_by_document_ids([op_id]) - elif task_type == TaskConstant.PARSE_TEMPORARY_DOCUMENT: - await TemporaryChunkManager.delete_by_temporary_document_ids([op_id]) - elif task_type == TaskConstant.IMPORT_KNOWLEDGE_BASE: - await KnowledgeBaseManager.delete(op_id) - await MinIO.delete_object(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(task_entity.op_id)) - elif task_type == TaskConstant.EXPORT_KNOWLEDGE_BASE: - await KnowledgeBaseManager.update(op_id, {'status': KnowledgeStatusEnum.IDLE}) - await MinIO.delete_object(OssConstant.MINIO_BUCKET_KNOWLEDGEBASE, str(task_id)) - - -class TaskRedisHandler(): - - @staticmethod - def clear_all_task(queue_name: str) -> None: - with RedisConnectionPool.get_redis_connection() as r: - try: - r.delete(queue_name) - except Exception as e: - logging.error(f"Clear queue error: {e}") - - @staticmethod - def select_all_task(queue_name: str) -> List[str]: - with RedisConnectionPool.get_redis_connection() as r: - try: - return r.lrange(queue_name, 0, -1) - except Exception as e: - logging.error(f"Select task error: {e}") - return [] - - @staticmethod - def get_task_by_head(queue_name: str): - with RedisConnectionPool.get_redis_connection() as r: - try: - return r.lpop(queue_name) - except Exception as e: - logging.error(f"Get first task error: {e}") - return None - - @staticmethod - def put_task_by_tail(queue_name: str, task_id: str): - with RedisConnectionPool.get_redis_connection() as r: - try: - return r.rpush(queue_name, task_id) - except Exception as e: - logging.error(f"Remove task error: {e}") - - @staticmethod - def remove_task_by_task_id(queue_name: str, task_id: str): - with RedisConnectionPool.get_redis_connection() as r: - try: - return r.lrem(queue_name, 0, task_id) - except Exception as e: - logging.error(f"Remove task error: {e}") diff --git a/data_chain/apps/base/task/worker/base_worker.py b/data_chain/apps/base/task/worker/base_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..3eee410765b107708c542c9a6330d3aed8485094 --- /dev/null +++ b/data_chain/apps/base/task/worker/base_worker.py @@ -0,0 
+1,106 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +import uuid + +from data_chain.apps.base.task.process_handler import ProcessHandler +from data_chain.config.config import config +from data_chain.entities.enum import TaskStatus +from data_chain.stores.database.database import DataBase, TaskReportEntity +from data_chain.stores.mongodb.mongodb import MongoDB, Task +from data_chain.manager.task_manager import TaskManager +from data_chain.manager.task_report_manager import TaskReportManager +from data_chain.manager.task_queue_mamanger import TaskQueueManager +from data_chain.logger.logger import logger as logging + + +class BaseWorker: + """ + BaseWorker + """ + name = "BaseWorker" + + @staticmethod + def find_worker_class(worker_name): + subclasses = BaseWorker.__subclasses__() + for subclass in subclasses: + if subclass.name == worker_name: + return subclass + return None + + @staticmethod + async def get_worker_name(task_id: uuid.UUID) -> str: + '''获取worker_name''' + task = TaskManager.get_task_by_id(task_id) + if task is None: + err = f"[BaseWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + raise err + return task.type + + @staticmethod + async def init(worker_name: str, op_id: uuid.UUID) -> uuid.UUID: + '''初始化任务''' + task_id = await (BaseWorker.find_worker_class(worker_name).init(op_id)) + await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.PENDING.value}) + + @staticmethod + async def reinit(task_id: uuid.UUID) -> bool: + '''重新初始化任务''' + worker_name = await BaseWorker.get_worker_name(task_id) + flag = await (BaseWorker.find_worker_class(worker_name).reinit(task_id)) + task_entity = await TaskManager.get_task_by_id(task_id) + if flag: + TaskManager.update_task_by_id(task_id, {"status": TaskStatus.PENDING.value, "retry": task_entity.retry + 1}) + return True + else: + await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.FAILED.value}) + return False + + @staticmethod + async def deinit(task_id: uuid.UUID) -> uuid.UUID: + '''析构任务''' + worker_name = await BaseWorker.get_worker_name(task_id) + await (BaseWorker.find_worker_class(worker_name).deinit(task_id)) + await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.SUCCESS.value}) + + @staticmethod + async def run(task_id: uuid.UUID) -> bool: + '''运行任务''' + worker_name = await BaseWorker.get_worker_name(task_id) + flag = ProcessHandler.add_task(BaseWorker.find_worker_class(worker_name).run, task_id) + await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.RUNNING.value}) + return flag + + @staticmethod + async def stop(task_id: uuid.UUID) -> bool: + '''停止任务''' + worker_name = await BaseWorker.get_worker_name(task_id) + task_entity = await TaskManager.get_task_by_id(task_id) + if task_entity.status == TaskStatus.RUNNING.value: + await ProcessHandler.remove_task(task_id) + elif task_entity.status == TaskStatus.PENDING.value: + await TaskQueueManager.delete_task_by_id(task_id) + else: + return False + task_id = await (BaseWorker.find_worker_class(worker_name).stop(task_id)) + if task_entity.status == TaskStatus.PENDING.value or task_entity.status == TaskStatus.RUNNING.value: + await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.CANCLED.value}) + return (task_id is not None) + + @staticmethod + async def delete(task_id: uuid.UUID) -> bool: + '''删除任务''' + worker_name = await BaseWorker.get_worker_name(task_id) + task_id = await (BaseWorker.find_worker_class(worker_name).delete(task_id)) + await 
diff --git a/data_chain/apps/base/task/worker/export_knowledge_base_worker.py b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..62267f791209f98a85a0dc2d3d7e30158732ba39
--- /dev/null
+++ b/data_chain/apps/base/task/worker/export_knowledge_base_worker.py
@@ -0,0 +1,221 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+import uuid
+import os
+import shutil
+import yaml
+from data_chain.apps.base.zip_handler import ZipHandler
+from data_chain.config.config import config
+from data_chain.logger.logger import logger as logging
+from data_chain.apps.base.task.worker.base_worker import BaseWorker
+from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus
+from data_chain.entities.common import EXPORT_KB_PATH_IN_OS, DOC_PATH_IN_MINIO, EXPORT_KB_PATH_IN_MINIO
+from data_chain.manager.task_manager import TaskManager
+from data_chain.manager.knowledge_manager import KnowledgeBaseManager
+from data_chain.manager.document_manager import DocumentManager
+from data_chain.manager.task_queue_mamanger import TaskQueueManager
+from data_chain.stores.database.database import TaskEntity, DocumentEntity
+from data_chain.stores.minio.minio import MinIO
+from data_chain.stores.mongodb.mongodb import Task
+
+
+class ExportKnowledgeBaseWorker(BaseWorker):
+    """
+    ExportKnowledgeBaseWorker
+    """
+    name = TaskType.KB_EXPORT.value
+
+    @staticmethod
+    async def init(kb_id: uuid.UUID) -> uuid.UUID:
+        '''初始化任务'''
+        knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+        if knowledge_base_entity is None:
+            err = f"[ExportKnowledgeBaseWorker] 知识库不存在,知识库ID: {kb_id}"
+            logging.exception(err)
+            return None
+        if knowledge_base_entity.status != KnowledgeBaseStatus.IDLE.value:
+            warning = f"[ExportKnowledgeBaseWorker] 无法导出知识库,知识库ID: {kb_id},知识库状态: {knowledge_base_entity.status}"
+            logging.warning(warning)
+            return None
+        knowledge_base_entity = await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+            kb_id, {"status": KnowledgeBaseStatus.PENDING.value})
+        task_entity = TaskEntity(
+            team_id=knowledge_base_entity.team_id,
+            user_id=knowledge_base_entity.author_id,
+            op_id=knowledge_base_entity.id,
+            type=TaskType.KB_EXPORT.value,
+            retry=0,
+            status=TaskStatus.PENDING.value)
+        task_entity = await TaskManager.add_task(task_entity)
+        return task_entity.id
+
+    @staticmethod
+    async def reinit(task_id: uuid.UUID) -> bool:
+        '''重新初始化任务'''
+        task_entity = await TaskManager.get_task_by_id(task_id)
+        if task_entity is None:
+            err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+            logging.exception(err)
+            return None
+        tmp_path = os.path.join(EXPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        if task_entity.retry < config['TASK_RETRY_TIME_LIMIT']:
+            await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+                task_entity.op_id, {"status": KnowledgeBaseStatus.PENDING.value})
+            return True
+        else:
+            await MinIO.delete_object(EXPORT_KB_PATH_IN_MINIO, str(task_id))
+            await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+                task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value})
+            return False
+
+    @staticmethod
+    async def deinit(task_id: uuid.UUID) -> uuid.UUID:
+        '''析构任务'''
+        task_entity = await TaskManager.get_task_by_id(task_id)
+        if task_entity is None:
+            err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+            logging.exception(err)
+            return None
+        await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.SUCCESS.value})
+        await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+            task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value})
+        tmp_path = os.path.join(EXPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        return task_id
+
+    @staticmethod
+    async def init_path(task_id: uuid.UUID) -> tuple:
+        '''初始化存放配置文件和文档的路径'''
+        tmp_path = os.path.join(EXPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        os.makedirs(tmp_path)
+        source_path = os.path.join(tmp_path, "source")
+        target_path = os.path.join(tmp_path, f"{task_id}.zip")
+        os.mkdir(source_path)
+        doc_config_path = os.path.join(source_path, "doc_config")
+        doc_download_path = os.path.join(source_path, "doc_download")
+        os.mkdir(doc_config_path)
+        os.mkdir(doc_download_path)
+        return (source_path, target_path, doc_config_path, doc_download_path)
+
+    @staticmethod
+    async def create_knowledge_base_yaml_config(source_path: str, kb_id: uuid.UUID) -> None:
+        '''创建知识库yaml文件'''
+        knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+        kb_dict = {
+            "name": knowledge_base_entity.name,
+            "tokenizer": knowledge_base_entity.tokenizer,
+            "description": knowledge_base_entity.description,
+            "embedding_model": knowledge_base_entity.embedding_model,
+            "upload_count_limit": knowledge_base_entity.upload_count_limit,
+            "upload_size_limit": knowledge_base_entity.upload_size_limit,
+            "default_parse_method": knowledge_base_entity.default_parse_method,
+            "default_chunk_size": knowledge_base_entity.default_chunk_size,
+            "doc_types": []
+        }
+        doc_type_entities = await KnowledgeBaseManager.list_doc_types_by_kb_id(kb_id)
+        for doc_type_entity in doc_type_entities:
+            kb_dict["doc_types"].append({"id": doc_type_entity.id, "name": doc_type_entity.name})
+        yaml_path = os.path.join(source_path, "kb_config.yaml")
+        with open(yaml_path, "w", encoding="utf-8", errors='ignore') as f:
+            yaml.dump(kb_dict, f, allow_unicode=True)
+
+    @staticmethod
+    async def create_document_yaml_config(doc_config_path: str, kb_id: uuid.UUID) -> None:
+        '''创建文档yaml文件'''
+        doc_entities = await DocumentManager.list_all_document_by_kb_id(kb_id)
+        for doc_entity in doc_entities:
+            doc_dict = {
+                "name": doc_entity.name,
+                "extension": doc_entity.extension,
+                "size": doc_entity.size,
+                "parse_method": doc_entity.parse_method,
+                "chunk_size": doc_entity.chunk_size,
+                "type_id": doc_entity.type_id,
+                "enabled": doc_entity.enabled,
+            }
+            yaml_path = os.path.join(doc_config_path, f"{doc_entity.id}.yaml")
+            with open(yaml_path, "w", encoding="utf-8", errors='ignore') as f:
+                yaml.dump(doc_dict, f, allow_unicode=True)
+
+    @staticmethod
+    async def download_document_from_minio(doc_download_path: str, kb_id: uuid.UUID) -> None:
+        '''从minio下载文档'''
+        doc_entities = await DocumentManager.list_all_document_by_kb_id(kb_id)
+        for doc_entity in doc_entities:
+            local_path = os.path.join(doc_download_path, f"{doc_entity.id}")
+            await MinIO.download_object(DOC_PATH_IN_MINIO, str(doc_entity.id), local_path)
+
+    @staticmethod
+    async def zip_config_and_document(source_path: str, target_path: str) -> None:
+        '''压缩配置文件和文档'''
+        await ZipHandler.zip_dir(source_path, target_path)
+
+    @staticmethod
+    async def upload_zip_to_minio(target_path: str, task_id: uuid.UUID) -> None:
+        '''上传压缩包到minio'''
+        await MinIO.put_object(EXPORT_KB_PATH_IN_MINIO, str(task_id), target_path)
+
+    @staticmethod
+    async def run(task_id: uuid.UUID) -> None:
+        '''运行任务'''
+        current_stage = 0
+        stage_cnt = 6
+        try:
+            task_entity = await TaskManager.get_task_by_id(task_id)
+            if task_entity is None:
+                err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+                logging.exception(err)
+                return None
+            await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+                task_entity.op_id, {"status": KnowledgeBaseStatus.EXPORTING.value})
+            source_path, target_path, doc_config_path, doc_download_path = await ExportKnowledgeBaseWorker.init_path(task_id)
+            current_stage += 1
+            await ExportKnowledgeBaseWorker.report(task_id, "创建临时目录", current_stage, stage_cnt)
+            await ExportKnowledgeBaseWorker.create_knowledge_base_yaml_config(source_path, task_entity.op_id)
+            current_stage += 1
+            await ExportKnowledgeBaseWorker.report(task_id, "创建知识库yaml配置文件", current_stage, stage_cnt)
+            await ExportKnowledgeBaseWorker.create_document_yaml_config(doc_config_path, task_entity.op_id)
+            current_stage += 1
+            await ExportKnowledgeBaseWorker.report(task_id, "创建文档yaml配置文件", current_stage, stage_cnt)
+            await ExportKnowledgeBaseWorker.download_document_from_minio(doc_download_path, task_entity.op_id)
+            current_stage += 1
+            await ExportKnowledgeBaseWorker.report(task_id, "下载文档", current_stage, stage_cnt)
+            await ExportKnowledgeBaseWorker.zip_config_and_document(source_path, target_path)
+            current_stage += 1
+            await ExportKnowledgeBaseWorker.report(task_id, "压缩配置文件和文档", current_stage, stage_cnt)
+            await ExportKnowledgeBaseWorker.upload_zip_to_minio(target_path, task_id)
+            current_stage += 1
+            await ExportKnowledgeBaseWorker.report(task_id, "上传压缩包到minio", current_stage, stage_cnt)
+            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value))
+        except Exception as e:
+            err = f"[ExportKnowledgeBaseWorker] 运行任务失败,task_id: {task_id},错误信息: {e}"
+            logging.exception(err)
+            await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value))
+            await ExportKnowledgeBaseWorker.report(task_id, err, current_stage, stage_cnt)
+
+    @staticmethod
+    async def stop(task_id: uuid.UUID) -> uuid.UUID:
+        '''停止任务'''
+        task_entity = await TaskManager.get_task_by_id(task_id)
+        if task_entity is None:
+            err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+            logging.exception(err)
+            return None
+        await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+            task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value})
+        tmp_path = os.path.join(EXPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        return task_id
+
+    @staticmethod
+    async def delete(task_id) -> uuid.UUID:
+        '''删除任务'''
+        task_entity = await TaskManager.get_task_by_id(task_id)
+        if task_entity is None:
+            err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+            logging.exception(err)
+            return None
+        if task_entity.status == TaskStatus.CANCLED.value or task_entity.status == TaskStatus.FAILED.value:
+            await MinIO.delete_object(EXPORT_KB_PATH_IN_MINIO, str(task_id))
+        return task_id
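For reference, the archive produced by the worker above (and consumed by ImportKnowledgeBaseWorker below) has the following layout, assuming ZipHandler.zip_dir places the children of source_path at the zip root:

<task_id>.zip
├── kb_config.yaml      (knowledge-base settings plus the doc_types id/name list)
├── doc_config/
│   └── <doc_id>.yaml   (per-document metadata, keyed by the exporting side's doc id)
└── doc_download/
    └── <doc_id>        (raw document bytes, named by the same doc id, no extension)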
diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data_chain/apps/base/task/worker/import_knowledge_base_worker.py b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3327acd0c80ec349c9cb582a22143864893ad8
--- /dev/null
+++ b/data_chain/apps/base/task/worker/import_knowledge_base_worker.py
@@ -0,0 +1,254 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+import uuid
+import os
+import shutil
+import yaml
+from data_chain.apps.base.zip_handler import ZipHandler
+from data_chain.config.config import config
+from data_chain.logger.logger import logger as logging
+from data_chain.apps.base.task.worker.base_worker import BaseWorker
+from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus
+from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, IMPORT_KB_PATH_IN_OS, DOC_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO
+from data_chain.manager.task_manager import TaskManager
+from data_chain.manager.knowledge_manager import KnowledgeBaseManager
+from data_chain.manager.document_type_manager import DocumentTypeManager
+from data_chain.manager.document_manager import DocumentManager
+from data_chain.manager.task_queue_mamanger import TaskQueueManager
+from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity
+from data_chain.stores.minio.minio import MinIO
+from data_chain.stores.mongodb.mongodb import Task
+
+
+class ImportKnowledgeBaseWorker(BaseWorker):
+    """
+    ImportKnowledgeBaseWorker
+    """
+    name = TaskType.KB_IMPORT.value
+
+    @staticmethod
+    async def init(kb_id: uuid.UUID) -> uuid.UUID:
+        '''初始化任务'''
+        knowledge_base_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+        if knowledge_base_entity is None:
+            err = f"[ImportKnowledgeBaseWorker] 知识库不存在,知识库ID: {kb_id}"
+            logging.exception(err)
+            return None
+        knowledge_base_entity = await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+            kb_id, {"status": KnowledgeBaseStatus.PENDING.value})
+        task_entity = TaskEntity(
+            team_id=knowledge_base_entity.team_id,
+            user_id=knowledge_base_entity.author_id,
+            op_id=knowledge_base_entity.id,
+            type=TaskType.KB_IMPORT.value,
+            retry=0,
+            status=TaskStatus.PENDING.value)
+        task_entity = await TaskManager.add_task(task_entity)
+        return task_entity.id
+
+    @staticmethod
+    async def reinit(task_id: uuid.UUID) -> bool:
+        '''重新初始化任务'''
+        task_entity = await TaskManager.get_task_by_id(task_id)
+        if task_entity is None:
+            err = f"[ImportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+            logging.exception(err)
+            return None
+        tmp_path = os.path.join(IMPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        if task_entity.retry < config['TASK_RETRY_TIME_LIMIT']:
+            await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+                task_entity.op_id, {"status": KnowledgeBaseStatus.PENDING.value})
+            return True
+        else:
+            await MinIO.delete_object(IMPORT_KB_PATH_IN_MINIO, str(task_entity.op_id))
+            await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+                task_entity.op_id, {"status": KnowledgeBaseStatus.DELTED.value})
+            return False
+
+    @staticmethod
+    async def deinit(task_id: uuid.UUID) -> uuid.UUID:
+        '''析构任务'''
+        task_entity = await TaskManager.get_task_by_id(task_id)
+        if task_entity is None:
+            err = f"[ImportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}"
+            logging.exception(err)
+            return None
+        await TaskManager.update_task_by_id(task_id, {"status": TaskStatus.SUCCESS.value})
+        await KnowledgeBaseManager.update_knowledge_base_by_kb_id(
+            task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value})
+        tmp_path = os.path.join(IMPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        return task_id
+
+    @staticmethod
+    async def init_path(task_id: uuid.UUID) -> tuple:
+        '''初始化存放配置文件和文档的路径'''
+        tmp_path = os.path.join(IMPORT_KB_PATH_IN_OS, str(task_id))
+        if os.path.exists(tmp_path):
+            shutil.rmtree(tmp_path)
+        os.makedirs(tmp_path)
+        source_path = os.path.join(tmp_path, f"{task_id}.zip")
+        target_path = os.path.join(tmp_path, "source")
+        os.mkdir(target_path)
+        doc_config_path = os.path.join(target_path, "doc_config")
+        doc_download_path = os.path.join(target_path, "doc_download")
+        os.mkdir(doc_config_path)
+        os.mkdir(doc_download_path)
+        return (source_path, target_path, doc_config_path, doc_download_path)
+
+    @staticmethod
+    async def download_zip_from_minio(source_path: str, kb_id: uuid.UUID) -> None:
+        '''从minio下载zip文件'''
+        await MinIO.download_object(IMPORT_KB_PATH_IN_MINIO, str(kb_id), source_path)
+
+    @staticmethod
+    async def unzip_config_and_document(source_path: str, target_path: str) -> None:
+        '''解压zip文件'''
+        await ZipHandler.unzip_file(source_path, target_path)
+
+    @staticmethod
+    async def add_doc_types_to_kb(kb_id: uuid.UUID, source_path: str) -> dict[uuid.UUID, uuid.UUID]:
+        '''添加文档类型到知识库'''
+        yaml_path = os.path.join(source_path, "kb_config.yaml")
+        with open(yaml_path, "r", encoding="utf-8") as f:
+            kb_config = yaml.load(f, Loader=yaml.SafeLoader)
+        doc_types_old_id_map_to_new_id = {}
+        doc_type_dicts = kb_config.get("doc_types", [])
+        for doc_type_dict in doc_type_dicts:
+            doc_type_entity = DocumentTypeEntity(
+                kb_id=kb_id,
+                name=doc_type_dict.get("name")
+            )
+            doc_type_entity = await DocumentTypeManager.add_document_type(doc_type_entity)
+            if doc_type_entity:
+                doc_types_old_id_map_to_new_id[doc_type_dict['id']] = doc_type_entity.id
+        return doc_types_old_id_map_to_new_id
+
+    @staticmethod
+    async def add_docs_to_kb(kb_id: uuid.UUID, doc_config_path: str, doc_download_path: str,
+                             doc_types_old_id_map_to_new_id: dict[uuid.UUID, uuid.UUID]) -> dict[uuid.UUID, uuid.UUID]:
+        '''添加文档到知识库'''
+        kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id)
+        doc_old_id_map_to_new_id = {}
+        doc_config_names = os.listdir(doc_config_path)
+        for doc_config_name in doc_config_names:
+            try:
+                doc_old_id = os.path.splitext(doc_config_name)[0]
+                doc_config_file_path = os.path.join(doc_config_path, doc_config_name)
+                doc_path = os.path.join(doc_download_path, doc_old_id)
+                if not os.path.exists(doc_path):
+                    continue
+                with open(doc_config_file_path, "r", encoding="utf-8") as f:
+                    doc_config = yaml.load(f, Loader=yaml.SafeLoader)
+                doc_type_id = doc_types_old_id_map_to_new_id.get(doc_config.get("type_id"), DEFAULt_DOC_TYPE_ID)
+                document_entity = DocumentEntity(
+                    team_id=kb_entity.team_id,
+                    kb_id=kb_entity.id,
+                    author_id=kb_entity.author_id,
+                    author_name=kb_entity.author_name,
+                    name=doc_config.get("name", ''),
+                    extension=doc_config.get("extension", ''),
+                    size=doc_config.get("size", 0),
+                    parse_method=doc_config.get("parse_method", kb_entity.default_parse_method),
+                    chunk_size=doc_config.get("chunk_size", kb_entity.default_chunk_size),
+                    type_id=doc_type_id,
+                    enabled=doc_config.get("enabled", True),
+                    status=DocumentStatus.IDLE.value,
+                )
+                document_entity = await DocumentManager.add_document(document_entity)
+                if document_entity:
+                    doc_old_id_map_to_new_id[doc_old_id] = document_entity.id
+            except Exception as e:
+                err = f"[ImportKnowledgeBaseWorker] 添加文档失败,文档配置文件: {doc_config_name},错误信息: {e}"
+                logging.exception(err)
+                continue
+        await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id)
+        return doc_old_id_map_to_new_id
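The two mappings built above implement one contract: every id read from the exported yaml is translated to the id issued by the importing deployment, with unknown type ids degrading to the default document type. A self-contained sketch of that contract (DEFAULT_TYPE stands in for DEFAULt_DOC_TYPE_ID; the helper name is illustrative):

import uuid

DEFAULT_TYPE = uuid.UUID(int=0)  # stand-in for DEFAULt_DOC_TYPE_ID


def remap_type(old_id: uuid.UUID, old_to_new: dict[uuid.UUID, uuid.UUID]) -> uuid.UUID:
    # unknown exported type ids fall back to the default type instead of failing
    return old_to_new.get(old_id, DEFAULT_TYPE)


old, new = uuid.uuid4(), uuid.uuid4()
assert remap_type(old, {old: new}) == new
assert remap_type(uuid.uuid4(), {old: new}) == DEFAULT_TYPE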
f"[ImportKnowledgeBaseWorker] 添加文档失败,文档配置文件: {doc_config_path},错误信息: {e}" + logging.exception(err) + continue + await KnowledgeBaseManager.update_doc_cnt_and_doc_size(kb_id) + + @staticmethod + async def upload_document_to_minio( + doc_download_path: str, doc_old_id_map_to_new_id: dict[uuid.UUID, uuid.UUID]) -> None: + '''上传文档到minio''' + doc_names = os.listdir(doc_download_path) + for doc_name in doc_names: + try: + doc_path = os.path.join(doc_download_path, doc_name) + if not os.path.exists(doc_path): + continue + if doc_name in doc_old_id_map_to_new_id.keys(): + await MinIO.upload_object(DOC_PATH_IN_MINIO, doc_old_id_map_to_new_id.get(doc_name), doc_path) + except Exception as e: + err = f"[ImportKnowledgeBaseWorker] 上传文档失败,文档路径: {doc_path},错误信息: {e}" + logging.exception(err) + continue + + @staticmethod + async def init_doc_parse_tasks(kb_id: uuid.UUID) -> None: + '''初始化文档解析任务''' + document_entities = await DocumentManager.list_all_document_by_kb_id(kb_id) + for document_entity in document_entities: + await BaseWorker.init(TaskType.DOC_PARSE.value, document_entity.id) + + @staticmethod + async def run(task_id: uuid.UUID) -> None: + '''运行任务''' + try: + task_entity = await TaskManager.get_task_by_id(task_id) + if task_entity is None: + err = f"[ImportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.IMPORTING.value}) + current_stage = 0 + stage_cnt = 7 + source_path, target_path, doc_config_path, doc_download_path = await ImportKnowledgeBaseWorker.init_path(task_id) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "初始化路径", current_stage, stage_cnt) + kb_id = task_entity.op_id + await ImportKnowledgeBaseWorker.download_zip_from_minio(source_path, kb_id) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "下载zip文件", current_stage, stage_cnt) + await ImportKnowledgeBaseWorker.unzip_config_and_document(source_path, target_path) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "解压zip文件", current_stage, stage_cnt) + doc_types_old_id_map_to_new_id = await ImportKnowledgeBaseWorker.add_doc_types_to_kb(kb_id, doc_config_path) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "添加文档类型到知识库", current_stage, stage_cnt) + doc_old_id_map_to_new_id = await ImportKnowledgeBaseWorker.add_docs_to_kb(kb_id, doc_config_path, doc_download_path, doc_types_old_id_map_to_new_id) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "添加文档到知识库", current_stage, stage_cnt) + await ImportKnowledgeBaseWorker.upload_document_to_minio(doc_download_path, doc_old_id_map_to_new_id) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "上传文档到minio", current_stage, stage_cnt) + await ImportKnowledgeBaseWorker.init_doc_parse_tasks(kb_id) + current_stage += 1 + await ImportKnowledgeBaseWorker.report(task_id, "初始化文档解析任务", current_stage, stage_cnt) + await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.SUCCESS.value)) + except Exception as e: + err = f"[ImportKnowledgeBaseWorker] 任务失败,task_id: {task_id},错误信息: {e}" + logging.exception(err) + await TaskQueueManager.add_task(Task(_id=task_id, status=TaskStatus.FAILED.value)) + return None + + @staticmethod + async def stop(task_id: uuid.UUID) -> uuid.UUID: + '''停止任务''' + task_entity = await TaskManager.get_task_by_id(task_id) + if task_entity is None: + err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: 
{task_id}" + logging.exception(err) + return None + await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value}) + tmp_path = os.path.join(IMPORT_KB_PATH_IN_OS, str(task_id)) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + return task_id + + @staticmethod + async def delete(task_id) -> uuid.UUID: + '''删除任务''' + task_entity = await TaskManager.get_task_by_id(task_id) + if task_entity is None: + err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + if task_entity.status == TaskStatus.CANCLED or TaskStatus.FAILED.value: + await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.DELTED.value}) + await MinIO.delete_object(IMPORT_KB_PATH_IN_OS, str(task_entity.op_id)) + return task_id diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..98f41601a7df9f4b9294615eb3819ac45d3ffe9b --- /dev/null +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -0,0 +1,120 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +import uuid +import os +import shutil +import yaml +from data_chain.apps.base.zip_handler import ZipHandler +from data_chain.config.config import config +from data_chain.logger.logger import logger as logging +from data_chain.apps.base.task.worker.base_worker import BaseWorker +from data_chain.entities.enum import TaskType, TaskStatus, KnowledgeBaseStatus, DocumentStatus +from data_chain.entities.common import DEFAULt_DOC_TYPE_ID, IMPORT_KB_PATH_IN_OS, DOC_PATH_IN_MINIO, IMPORT_KB_PATH_IN_MINIO +from data_chain.manager.task_manager import TaskManager +from data_chain.manager.knowledge_manager import KnowledgeBaseManager +from data_chain.manager.document_type_manager import DocumentTypeManager +from data_chain.manager.document_manager import DocumentManager +from data_chain.manager.task_queue_mamanger import TaskQueueManager +from data_chain.stores.database.database import TaskEntity, DocumentEntity, DocumentTypeEntity +from data_chain.stores.minio.minio import MinIO +from data_chain.stores.mongodb.mongodb import Task + + +class ParseDocumentWorker(BaseWorker): + name = TaskType.DOC_PARSE + + @staticmethod + async def init(doc_id: uuid.UUID) -> uuid.UUID: + '''初始化任务''' + pass + + @staticmethod + async def reinit(task_id: uuid.UUID) -> bool: + '''重新初始化任务''' + pass + + @staticmethod + async def deinit(task_id: uuid.UUID) -> uuid.UUID: + '''析构任务''' + pass + + @staticmethod + async def init_path(task_id: uuid.UUID) -> tuple: + '''初始化存放配置文件和文档的路径''' + pass + + @staticmethod + async def download_doc_from_minio(doc_id: uuid.UUID) -> str: + '''下载文档''' + pass + + @staticmethod + async def parse_doc(download_path: str): + '''解析文档''' + pass + + @staticmethod + async def get_doc_abstrace(): + '''获取文档摘要''' + pass + + @staticmethod + async def ocr_from_parse_image(): + '''从解析图片中获取ocr''' + pass + + @staticmethod + async def upload_parse_image_to_minio(): + '''上传解析图片到minio''' + pass + + @staticmethod + async def push_down_words_feature(): + '''下推words特征''' + pass + + @staticmethod + async def handle_doc_parse_result(): + '''处理解析结果''' + pass + + @staticmethod + async def embedding_chunk(): + '''嵌入chunk''' + pass + + @staticmethod + async def add_parse_result_to_db(): + '''添加解析结果到数据库''' + pass + + @staticmethod + async def run(task_id: uuid.UUID) -> None: + 
'''运行任务''' + pass + + @staticmethod + async def stop(task_id: uuid.UUID) -> uuid.UUID: + '''停止任务''' + task_entity = await TaskManager.get_task_by_id(task_id) + if task_entity is None: + err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.IDLE.value}) + tmp_path = os.path.join(IMPORT_KB_PATH_IN_OS, str(task_id)) + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + return task_id + + @staticmethod + async def delete(task_id) -> uuid.UUID: + '''删除任务''' + task_entity = await TaskManager.get_task_by_id(task_id) + if task_entity is None: + err = f"[ExportKnowledgeBaseWorker] 任务不存在,task_id: {task_id}" + logging.exception(err) + return None + if task_entity.status == TaskStatus.CANCLED or TaskStatus.FAILED.value: + await KnowledgeBaseManager.update_knowledge_base_by_kb_id(task_entity.op_id, {"status": KnowledgeBaseStatus.DELTED.value}) + await MinIO.delete_object(IMPORT_KB_PATH_IN_OS, str(task_entity.op_id)) + return task_id diff --git a/data_chain/apps/base/task/worker/testing_acc_worker.py b/data_chain/apps/base/task/worker/testing_acc_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_chain/apps/router/acc_testing.py b/data_chain/apps/router/acc_testing.py index f7a6c7cf59950fd8e8352297afd29bbdccae981a..a8dade50ffa0b1c50428f476d077d23836f270cd 100644 --- a/data_chain/apps/router/acc_testing.py +++ b/data_chain/apps/router/acc_testing.py @@ -16,22 +16,24 @@ from data_chain.entities.response_data import ( DeleteTestingResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/testing', tags=['Testing']) -@router.get('', response_model=ListTestingResponse, dependencies=[Depends(verify_user)]) +@router.post('/list', response_model=ListTestingResponse, dependencies=[Depends(verify_user)]) async def list_testing_by_dataset_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[ListTestingRequest, Body()], ): return ListTestingResponse() -@router.get('/testcase', response_model=ListTestCaseResponse, - dependencies=[Depends(verify_user)]) +@router.post('/testcase', response_model=ListTestCaseResponse, + dependencies=[Depends(verify_user)]) async def list_test_case_by_testing_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], testing_id: Annotated[UUID, Query(alias="testingId")]): return ListTestCaseResponse() @@ -39,6 +41,7 @@ async def list_test_case_by_testing_id( @router.get('/download', dependencies=[Depends(verify_user)]) async def download_testing_report_by_testing_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], testing_id: Annotated[UUID, Query(alias="testingId")]): # try: # await _validate_doucument_belong_to_user(user_id, id) @@ -69,6 +72,7 @@ async def download_testing_report_by_testing_id( '', response_model=CreateTestingResponsing, dependencies=[Depends(verify_user)]) async def create_testing( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[CreateTestingRequest, Body()]): return CreateTestingResponsing() @@ -77,6 +81,7 @@ async def create_testing( dependencies=[Depends(verify_user)]) 
async def run_testing_by_testing_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], testing_id: Annotated[UUID, Query(alias="testingId")], run: Annotated[bool, Query()]): return RunTestingResponse() @@ -86,6 +91,7 @@ async def run_testing_by_testing_id( dependencies=[Depends(verify_user)]) async def update_testing_by_testing_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], testing_id: Annotated[UUID, Query(alias="testingId")], req: Annotated[UpdateTestingRequest, Body(...)]): return UpdateTestingResponse() @@ -95,5 +101,6 @@ async def update_testing_by_testing_id( dependencies=[Depends(verify_user)]) async def delete_testing_by_testing_ids( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], testing_ids: Annotated[list[UUID], Query(alias="testingId")]): return DeleteTestingResponse() diff --git a/data_chain/apps/router/chunk.py b/data_chain/apps/router/chunk.py index 7cf85b47b7d96753863fdb002cd89b54171d6ef4..feb98c27043da526a0ef6ee38ab5b676e7a3c2ad 100644 --- a/data_chain/apps/router/chunk.py +++ b/data_chain/apps/router/chunk.py @@ -12,11 +12,11 @@ from data_chain.entities.response_data import ( ListChunkResponse, UpdateChunkResponse ) - +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/chunk', tags=['Chunk']) -@router.get('', response_model=ListChunkResponse, dependencies=[Depends(verify_user)]) +@router.post('/list', response_model=ListChunkResponse, dependencies=[Depends(verify_user)]) async def list_chunks_by_document_id( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListChunkRequest, Body()], diff --git a/data_chain/apps/router/dataset.py b/data_chain/apps/router/dataset.py index 6ee24618422508a1a6df7f7c3d452ecc2286894b..19f1d88cf12b8d34755c230f068520a04b69b7a4 100644 --- a/data_chain/apps/router/dataset.py +++ b/data_chain/apps/router/dataset.py @@ -20,11 +20,11 @@ from data_chain.entities.response_data import ( DeleteDatasetResponse, ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/dataset', tags=['Dataset']) -@router.get('', response_model=ListDatasetResponse, dependencies=[Depends(verify_user)]) +@router.post('/list', response_model=ListDatasetResponse, dependencies=[Depends(verify_user)]) async def list_datasets_by_kb_id( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListDatasetRequest, Body()], @@ -32,7 +32,7 @@ async def list_datasets_by_kb_id( return ListDatasetResponse() -@router.get('/data', response_model=ListDataInDatasetResponse, dependencies=[Depends(verify_user)]) +@router.post('/data', response_model=ListDataInDatasetResponse, dependencies=[Depends(verify_user)]) async def list_data_in_dataset_by_dataset_id( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListDataInDatasetRequest, Body()]): diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index 25f3a5e49c75264cd02fed480066fd3d7cfbe350..619b28da55e2f210527969cf35fbbc8ad980c821 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -16,11 +16,11 @@ from data_chain.entities.response_data import ( DeleteDocumentResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.router_service import get_route_info router = 
APIRouter(prefix='/doc', tags=['Document']) -@router.get('', response_model=ListDocumentResponse, dependencies=[Depends(verify_user)]) +@router.post('/list', response_model=ListDocumentResponse, dependencies=[Depends(verify_user)]) async def list_doc( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListDocumentRequest, Body()] diff --git a/data_chain/apps/router/health_check.py b/data_chain/apps/router/health_check.py index e020fcf1074856eef9980908df30c03e5a4d5f6f..f126594a16b7913f9bc52d18677987690e316cc4 100644 --- a/data_chain/apps/router/health_check.py +++ b/data_chain/apps/router/health_check.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Response, status router = APIRouter( prefix="/health_check", - tags=["health_check"] + tags=["Health check"] ) diff --git a/data_chain/apps/router/knowledge_base.py b/data_chain/apps/router/knowledge_base.py index 80f2d9e674f6520040edb6cf181b6a1cda3b61dd..701bc3ee11cce906929bec4781c5af99e46f08ea 100644 --- a/data_chain/apps/router/knowledge_base.py +++ b/data_chain/apps/router/knowledge_base.py @@ -1,8 +1,11 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from fastapi import APIRouter, Depends, Query, Body, File, UploadFile +from fastapi.responses import StreamingResponse, HTMLResponse, Response from typing import Annotated +import urllib from uuid import UUID +from httpx import AsyncClient from data_chain.entities.request_data import ( ListKnowledgeBaseRequest, CreateKnowledgeBaseRequest, @@ -10,6 +13,7 @@ from data_chain.entities.request_data import ( ) from data_chain.entities.response_data import ( + ListAllKnowledgeBaseMsg, ListAllKnowledgeBaseResponse, ListKnowledgeBaseResponse, ListDocumentTypesResponse, @@ -19,93 +23,134 @@ from data_chain.entities.response_data import ( UpdateKnowledgeBaseResponse, DeleteKnowledgeBaseResponse, ) +from data_chain.apps.service.team_service import TeamService +from data_chain.apps.service.knwoledge_base_service import KnowledgeBaseService +from data_chain.apps.service.task_service import TaskService from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/kb', tags=['Knowledge Base']) -@router.get('/', response_model=ListAllKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +@router.get('', response_model=ListAllKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) async def list_kb_by_user_sub( - user_sub: Annotated[str, Depends(get_user_sub)] + user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)] ): - return ListAllKnowledgeBaseResponse() + list_all_kb_msg = await KnowledgeBaseService.list_kb_by_user_sub(user_sub) + return ListAllKnowledgeBaseResponse(result=list_all_kb_msg) -@router.get('/team', response_model=ListKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +@router.post('/team', response_model=ListKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) async def list_kb_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[ListKnowledgeBaseRequest, Body()] ): - return ListKnowledgeBaseResponse() + if not await TeamService.validate_user_action_in_team(user_sub, req.team_id, action): + raise Exception("用户没有权限访问该团队的知识库") + list_kb_msg = await KnowledgeBaseService.list_kb_by_team_id(req) + return ListKnowledgeBaseResponse(result=list_kb_msg) @router.get('/doc_type', 
            response_model=ListDocumentTypesResponse, dependencies=[Depends(verify_user)])
-async def list_doc_types_by_team_id(
+async def list_doc_types_by_kb_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
-    team_id: Annotated[UUID, Query(alias="teamId")],
+    action: Annotated[str, Depends(get_route_info)],
+    kb_id: Annotated[UUID, Query(alias="kbId")],
 ):
-    return ListDocumentTypesResponse()
+    if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action):
+        raise Exception("用户没有权限访问该知识库的文档类型")
+    list_doc_types_msg = await KnowledgeBaseService.list_doc_types_by_kb_id(kb_id)
+    return ListDocumentTypesResponse(result=list_doc_types_msg)
 
 
 @router.get('/download', dependencies=[Depends(verify_user)])
 async def download_kb_by_task_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     task_id: Annotated[UUID, Query(alias="taskId")]):
-    # try:
-    #     await _validate_doucument_belong_to_user(user_id, id)
-    #     document_link_url = await generate_document_download_link(id)
-    #     document_name, extension = await get_file_name_and_extension(id)
-    #     async with AsyncClient() as async_client:
-    #         response = await async_client.get(document_link_url)
-    #         if response.status_code == 200:
-    #             content_disposition = f"attachment; filename={urllib.parse.quote(document_name.encode('utf-8'))}"
-
-    #             async def stream_generator():
-    #                 async for chunk in response.aiter_bytes(chunk_size=8192):
-    #                     yield chunk
-
-    #             return StreamingResponse(stream_generator(), headers={
-    #                 "Content-Disposition": content_disposition,
-    #                 "Content-Length": str(response.headers.get('content-length'))
-    #             }, media_type="application/" + extension)
-    #         else:
-    #             return BaseResponse(
-    #                 retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg="Failed to retrieve the file.", data=None)
-    # except Exception as e:
-    #     return BaseResponse(retcode=ErrorCode.DOWNLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None)
-    pass
+    if not await TaskService.validate_user_action_to_task(user_sub, task_id, action):
+        raise Exception("用户没有权限访问该知识库的任务")
+    zip_download_url = await KnowledgeBaseService.generate_knowledge_base_download_link(task_id)
+    if not zip_download_url:
+        raise Exception("知识库下载链接生成失败")
+    async with AsyncClient() as async_client:
+        response = await async_client.get(zip_download_url)
+        if response.status_code == 200:
+            zip_name = f"{task_id}.zip"
+            content_disposition = f"attachment; filename={urllib.parse.quote(zip_name.encode('utf-8'))}"
+            content_length = response.headers.get('content-length')
+            headers = {"Content-Disposition": content_disposition}
+            if content_length:
+                headers["Content-Length"] = str(content_length)
+
+            # 定义一个协程函数来生成数据流
+            async def stream_generator():
+                async for chunk in response.aiter_bytes(chunk_size=8192):
+                    yield chunk
+
+            return StreamingResponse(
+                stream_generator(),
+                headers=headers,
+                media_type="application/zip"
+            )
+        else:
+            raise Exception(f"下载知识库 zip 失败,状态码: {response.status_code}")
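Client side, the download route above can be consumed as a stream; a sketch with httpx (the taskId query parameter and the /kb/download path come from this patch, while base_url, authentication handling and the on-disk file name are illustrative):

import asyncio

import httpx


async def fetch_kb_zip(base_url: str, task_id: str) -> None:
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", f"{base_url}/kb/download", params={"taskId": task_id}) as resp:
            resp.raise_for_status()
            # write the proxied zip to disk chunk by chunk
            with open(f"{task_id}.zip", "wb") as f:
                async for chunk in resp.aiter_bytes(chunk_size=8192):
                    f.write(chunk)


if __name__ == "__main__":
    asyncio.run(fetch_kb_zip("http://localhost:8000", "<task-id>"))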
 
 
 @router.post('', response_model=CreateKnowledgeBaseResponse, dependencies=[Depends(verify_user)])
 async def create_kb(user_sub: Annotated[str, Depends(get_user_sub)],
+                    action: Annotated[str, Depends(get_route_info)],
+                    team_id: Annotated[UUID, Query(alias="teamId")],
                     req: Annotated[CreateKnowledgeBaseRequest, Body()]):
-    return CreateKnowledgeBaseResponse()
+    if not await TeamService.validate_user_action_in_team(user_sub, team_id, action):
+        raise Exception("用户没有权限在该团队创建知识库")
+    kb_id = await KnowledgeBaseService.create_kb(user_sub, team_id, req)
+    return CreateKnowledgeBaseResponse(result=kb_id)
 
 
 @router.post('/import', response_model=ImportKnowledgeBaseResponse, dependencies=[Depends(verify_user)])
-async def import_kb(user_sub: Annotated[str, Depends(get_user_sub)],
-                    team_id: Annotated[UUID, Query(alias="teamId")],
-                    kb_packages: list[UploadFile] = File(...)):
-    return ImportKnowledgeBaseResponse()
+async def import_kbs(user_sub: Annotated[str, Depends(get_user_sub)],
+                     action: Annotated[str, Depends(get_route_info)],
+                     team_id: Annotated[UUID, Query(alias="teamId")],
+                     kb_packages: list[UploadFile] = File(...)):
+    if not await TeamService.validate_user_action_in_team(user_sub, team_id, action):
+        raise Exception("用户没有权限在该团队导入知识库")
+    kb_import_task_ids = await KnowledgeBaseService.import_kb(user_sub, team_id, kb_packages)
+    return ImportKnowledgeBaseResponse(result=kb_import_task_ids)
 
 
 @router.post('/export', response_model=ExportKnowledgeBaseResponse, dependencies=[Depends(verify_user)])
 async def export_kb_by_kb_ids(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     kb_ids: Annotated[list[UUID], Query(alias="kbIds")]):
-    return ExportKnowledgeBaseResponse()
+    for kb_id in kb_ids:
+        if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action):
+            raise Exception("用户没有权限在该知识库导出知识库")
+    kb_export_task_ids = await KnowledgeBaseService.export_kb_by_kb_ids(kb_ids)
+    return ExportKnowledgeBaseResponse(result=kb_export_task_ids)
 
 
 @router.put('', response_model=UpdateKnowledgeBaseResponse, dependencies=[Depends(verify_user)])
 async def update_kb_by_kb_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     kb_id: Annotated[UUID, Query(alias="kbId")],
     req: Annotated[UpdateKnowledgeBaseRequest, Body()]):
-    return UpdateKnowledgeBaseResponse()
+    if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action):
+        raise Exception("用户没有权限在该知识库更新知识库")
+    kb_id = await KnowledgeBaseService.update_kb_by_kb_id(kb_id, req)
+    return UpdateKnowledgeBaseResponse(result=kb_id)
 
 
 @router.delete('', response_model=DeleteKnowledgeBaseResponse, dependencies=[Depends(verify_user)])
 async def delete_kb_by_kb_ids(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     kb_ids: Annotated[list[UUID], Query(alias="kbIds")]):
-    return DeleteKnowledgeBaseResponse()
+    for kb_id in kb_ids:
+        if not await KnowledgeBaseService.validate_user_action_to_knowledge_base(user_sub, kb_id, action):
+            raise Exception("用户没有权限在该知识库删除知识库")
+    kb_ids_deleted = await KnowledgeBaseService.delete_kb_by_kb_ids(kb_ids)
+    return DeleteKnowledgeBaseResponse(result=kb_ids_deleted)
diff --git a/data_chain/apps/router/model.py b/data_chain/apps/router/model.py
deleted file mode 100644
index acec0ab05d1c04344039c887cbabdf58ccf1ddf6..0000000000000000000000000000000000000000
--- a/data_chain/apps/router/model.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-from fastapi import Depends -from fastapi import APIRouter -from typing import List -from data_chain.models.service import ModelDTO -from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user -from data_chain.exceptions.err_code import ErrorCode -from data_chain.exceptions.exception import DocumentException -from data_chain.models.api import BaseResponse -from data_chain.models.api import UpdateModelRequest -from data_chain.apps.service.model_service import get_model_by_user_id, list_offline_model, update_model - - -router = APIRouter(prefix='/model', tags=['Model']) - - -@router.post('/update', response_model=BaseResponse[ModelDTO], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def update(req: UpdateModelRequest, user_id=Depends(get_user_id)): - try: - update_dict = dict(req) - update_dict['user_id']=user_id - model_dto = await update_model(user_id, update_dict) - model_dto.openai_api_key=None - return BaseResponse(data=model_dto) - except Exception as e: - return BaseResponse(retcode=ErrorCode.UPDATE_MODEL_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.get('/get', response_model=BaseResponse[ModelDTO], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def get(user_id=Depends(get_user_id)): - try: - model_dto = await get_model_by_user_id(user_id) - model_dto.openai_api_key = None - return BaseResponse(data=model_dto) - except Exception as e: - return BaseResponse(retcode=ErrorCode.UPDATE_MODEL_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.get('/list', response_model=BaseResponse[List[ModelDTO]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def list(): - try: - model_dto_list = await list_offline_model() - return BaseResponse(data=model_dto_list) - except Exception as e: - return BaseResponse(retcode=ErrorCode.UPDATE_MODEL_ERROR, retmsg=str(e.args[0]), data=None) diff --git a/data_chain/apps/router/other.py b/data_chain/apps/router/other.py index cb28731251a775b3559e0ff90630f00dbc7d0dd7..f7e2ffef93f8992684d998b7dd047d80fb474836 100644 --- a/data_chain/apps/router/other.py +++ b/data_chain/apps/router/other.py @@ -8,10 +8,10 @@ from data_chain.entities.response_data import ( ListLLMResponse, ListEmbeddingResponse, ListTokenizerResponse, - ListParserMethodResponse + ListParseMethodResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/other', tags=['Other']) @@ -32,6 +32,6 @@ async def list_tokenizers(): return ListTokenizerResponse() -@router.get('parser_method', response_model=ListParserMethodResponse, dependencies=[Depends(verify_user)]) -async def list_parser_method(): - return ListParserMethodResponse() +@router.get('/parse_method', response_model=ListParseMethodResponse, dependencies=[Depends(verify_user)]) +async def list_parse_method(): + return ListParseMethodResponse() diff --git a/data_chain/apps/router/role.py b/data_chain/apps/router/role.py index 1d60593281d4b781d2c98aa5c9ac9d806130ad20..5198b5f577ce2b4d394895331dd76cbd7c08d089 100644 --- a/data_chain/apps/router/role.py +++ b/data_chain/apps/router/role.py @@ -10,17 +10,25 @@ from data_chain.entities.request_data import ( ) from data_chain.entities.response_data import ( + ListActionResponse, ListRoleResponse, CreateRoleResponse, UpdateRoleResponse, DeleteRoleResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from 
data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/role', tags=['Role']) -@router.get('', response_model=ListRoleResponse, dependencies=[Depends(verify_user)]) +@router.get('/action', response_model=ListActionResponse, dependencies=[Depends(verify_user)]) +async def list_actions( + user_sub: Annotated[str, Depends(get_user_sub)], +): + return ListActionResponse() + + +@router.post('/list', response_model=ListRoleResponse, dependencies=[Depends(verify_user)]) async def list_role_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListRoleRequest, Body()], diff --git a/data_chain/apps/router/task.py b/data_chain/apps/router/task.py index ae3260c5fb14d976e4f583d8055afbfa45997737..ea07552a75ed9924dd2a28690e9b8b7ce0183a0b 100644 --- a/data_chain/apps/router/task.py +++ b/data_chain/apps/router/task.py @@ -13,10 +13,11 @@ from data_chain.entities.response_data import ( DeleteTaskResponse ) from data_chain.entities.enum import TaskType +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/task', tags=['Task']) -@router.get('', response_model=ListTaskResponse, dependencies=[Depends(verify_user)]) +@router.post('', response_model=ListTaskResponse, dependencies=[Depends(verify_user)]) async def list_task( user_sub: Annotated[str, Depends(get_user_sub)], req: Annotated[ListTaskRequest, Body()] diff --git a/data_chain/apps/router/team.py b/data_chain/apps/router/team.py index 4f4a62fcee33237260b2ddaaff786ea30b8a8f72..a24de47cf3cd92468ea37472e3677c6f60955daa 100644 --- a/data_chain/apps/router/team.py +++ b/data_chain/apps/router/team.py @@ -12,6 +12,7 @@ from data_chain.entities.request_data import ( ) from data_chain.entities.response_data import ( + ListTeamMsg, ListTeamResponse, ListTeamMsgResponse, ListTeamUserResponse, @@ -25,41 +26,49 @@ from data_chain.entities.response_data import ( InviteTeamUserResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.team_service import TeamService +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/team', tags=['Team']) -@router.get('', response_model=ListTeamResponse, dependencies=[Depends(verify_user)]) +@router.post('/list', response_model=ListTeamResponse, dependencies=[Depends(verify_user)]) async def list_teams( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[ListTeamRequest, Body()] ): - return ListTeamResponse() + list_team_msg = await TeamService.list_teams(user_sub, req) + return ListTeamResponse(message='团队列表获取成功', result=list_team_msg) -@router.get('/usr', response_model=ListTeamUserResponse, dependencies=[Depends(verify_user)]) +@router.post('/usr', response_model=ListTeamUserResponse, dependencies=[Depends(verify_user)]) async def list_team_user_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[ListTeamUserRequest, Body()]): return ListTeamUserResponse() -@router.get('/msg', response_model=ListTeamMsgResponse, dependencies=[Depends(verify_user)]) +@router.post('/msg', response_model=ListTeamMsgResponse, dependencies=[Depends(verify_user)]) async def list_team_msg_by_team_id( user_sub: Annotated[str, Depends(get_user_sub)], + action: Annotated[str, Depends(get_route_info)], req: Annotated[ListTeamMsgRequest, Body()]): return ListTeamMsgResponse() @router.post('', response_model=CreateTeamResponse, 
             dependencies=[Depends(verify_user)])
 async def create_team(user_sub: Annotated[str, Depends(get_user_sub)],
+                      action: Annotated[str, Depends(get_route_info)],
                       req: Annotated[CreateTeamRequest, Body()]):
-    return CreateTeamResponse()
+    team_id = await TeamService.create_team(user_sub, req)
+    return CreateTeamResponse(message='团队创建成功', result=team_id)
 
 
 @router.post('/invitation', response_model=InviteTeamUserResponse, dependencies=[Depends(verify_user)])
 async def invite_team_user_by_user_sub(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     team_id: Annotated[UUID, Query(alias="teamId")],
     user_sub_invite: Annotated[str, Query(alias="userSubInvite")]):
     return InviteTeamUserResponse()
 
@@ -68,6 +77,7 @@ async def invite_team_user_by_user_sub(
 @router.post('/application', response_model=JoinTeamResponse, dependencies=[Depends(verify_user)])
 async def join_team(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     team_id: Annotated[UUID, Query(alias="teamId")]):
     return JoinTeamResponse()
 
@@ -75,14 +85,19 @@ async def join_team(
 @router.put('', response_model=UpdateTeamResponse, dependencies=[Depends(verify_user)])
 async def update_team_by_team_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     team_id: Annotated[UUID, Query(alias="teamId")],
     req: Annotated[UpdateTeamRequest, Body()]):
-    return UpdateTeamResponse()
+    if not await TeamService.validate_user_action_in_team(user_sub, team_id, action):
+        raise Exception('用户没有权限修改该团队')
+    team_id = await TeamService.update_team_by_team_id(user_sub, team_id, req)
+    return UpdateTeamResponse(message='团队更新成功', result=team_id)
 
 
 @router.put('/usr', response_model=UpdateTeamUserRoleResponse, dependencies=[Depends(verify_user)])
 async def update_team_by_team_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     team_id: Annotated[UUID, Query(alias="teamId")],
     role_id: Annotated[UUID, Query(alias="roleId")]):
     return UpdateTeamUserRoleResponse()
 
@@ -91,6 +106,7 @@ async def update_team_by_team_id(
 @router.put('/author', response_model=UpdateTeamAuthorResponse, dependencies=[Depends(verify_user)])
 async def update_team_author_by_team_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     recriver_sub: Annotated[str, Query(alias="recriverSub")],
     team_id: Annotated[UUID, Query(alias="teamId")]):
     return UpdateTeamAuthorResponse()
 
@@ -99,13 +115,18 @@ async def update_team_author_by_team_id(
 @router.delete('', response_model=DeleteTeamResponse, dependencies=[Depends(verify_user)])
 async def delete_team_by_team_id(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     team_id: Annotated[UUID, Query(alias="teamId")]):
-    return DeleteTeamResponse()
+    if not await TeamService.validate_user_action_in_team(user_sub, team_id, action):
+        raise Exception('用户没有权限删除该团队')
+    team_id = await TeamService.soft_delete_team_by_team_id(team_id)
+    return DeleteTeamResponse(message='团队删除成功', result=team_id)
 
 
 @router.delete('/usr', response_model=DeleteTeamUserResponse, dependencies=[Depends(verify_user)])
 async def delete_team_user_by_team_id_and_user_subs(
     user_sub: Annotated[str, Depends(get_user_sub)],
+    action: Annotated[str, Depends(get_route_info)],
     team_id: Annotated[UUID, Query(alias="teamId")],
     user_subs: Annotated[list[str], Query(alias="userSub")]):
     return DeleteTeamUserResponse()
diff --git
a/data_chain/apps/router/user.py b/data_chain/apps/router/user.py index 74217d7166bbfda3512c806088e9a4fb906f0ab7..96cb6023b928fed87b9b503b4988544880dfcaf7 100644 --- a/data_chain/apps/router/user.py +++ b/data_chain/apps/router/user.py @@ -1,16 +1,16 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from __future__ import annotations -import uuid -from data_chain.logger.logger import logger as logging - -from fastapi import APIRouter, Depends, Request, Response, status -from data_chain.config.config import config -from data_chain.apps.service.user_service import verify_csrf_token, verify_passwd, get_user_id, verify_user -from data_chain.apps.base.session.session import SessionManager -from data_chain.manager.knowledge_manager import KnowledgeBaseManager -from data_chain.manager.user_manager import UserManager -from data_chain.models.api import BaseResponse, AddUserRequest, UpdateUserRequest +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from fastapi import APIRouter, Depends, Query, Body +from typing import Annotated +from uuid import UUID +from data_chain.entities.request_data import ( + ListUserRequest +) +from data_chain.entities.response_data import ( + ListUserResponse +) +from data_chain.apps.service.session_service import get_user_sub, verify_user +from data_chain.apps.service.router_service import get_route_info router = APIRouter( prefix="/user", @@ -18,151 +18,9 @@ router = APIRouter( ) -@router.post("/add", response_model=BaseResponse) -async def add_user(request: AddUserRequest): - name = request.name - email = request.email - account = request.account - passwd = request.passwd - user_entity = await UserManager.get_user_info_by_account(account) - if user_entity is not None: - return BaseResponse( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - retmsg="Sign failed due to duplicate account", - data={} - ) - if email is not None: - user_entity = await UserManager.get_user_info_by_email(email) - if user_entity is not None: - return BaseResponse( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - retmsg="Sign failed due to duplicate email", - data={} - ) - user_entity = await UserManager.add_user(name,email,account, passwd) - if user_entity is None: - return BaseResponse( - code=status.HTTP_422_UNPROCESSABLE_ENTITY, - retmsg="Sign failed due to add user failed", - data={} - ) - return BaseResponse( - code=status.HTTP_200_OK, - retmsg="Sign successful", - data={} - ) - - -@router.post("/del", response_model=BaseResponse, dependencies=[Depends(verify_user), Depends(verify_csrf_token)]) -async def del_user(request: Request, response: Response, user_id=Depends(get_user_id)): - session_id = request.cookies['WD_ECSESSION'] - if not SessionManager.verify_user(session_id): - logging.info("User already logged out.") - return BaseResponse(code=200, retmsg="ok", data={}) - - SessionManager.delete_session(user_id) - response.delete_cookie("WD_ECSESSION") - response.delete_cookie("wd_csrf_tk") - await UserManager.del_user_by_user_id(user_id) - response_data = BaseResponse( - code=status.HTTP_200_OK, - retmsg="Cancel successful", - data={} - ) - return response_data - - -@router.get("/login", response_model=BaseResponse, dependencies=[Depends(verify_passwd)]) -async def login(request: Request, response: Response, account: str): - user_info = await UserManager.get_user_info_by_account(account) - if user_info is None: - return BaseResponse( - code=status.HTTP_401_UNAUTHORIZED, - retmsg="Login failed", - data={} - ) - - user_id = 
user_info.id - try: - SessionManager.delete_session(user_id) - current_session = SessionManager.create_session(user_id) - except Exception as e: - logging.error(f"Change session failed: {e}") - return BaseResponse( - code=status.HTTP_401_UNAUTHORIZED, - retmsg="Login failed", - data={} - ) - - new_csrf_token = SessionManager.create_csrf_token(current_session) - if config['COOKIE_MODE'] == 'DEBUG': - response.set_cookie( - "wd_csrf_tk", - new_csrf_token - ) - response.set_cookie( - "WD_ECSESSION", - current_session - ) - else: - response.set_cookie( - "wd_csrf_tk", - new_csrf_token, - max_age=config["SESSION_TTL"] * 60, - secure=config['SSL_ENABLE'], - domain=config["DOMAIN"], - samesite="strict" - ) - response.set_cookie( - "WD_ECSESSION", - current_session, - max_age=config["SESSION_TTL"] * 60, - secure=config['SSL_ENABLE'], - domain=config["DOMAIN"], - httponly=True, - samesite="strict" - ) - response_data = BaseResponse( - code=status.HTTP_200_OK, - retmsg="Login successful", - data={ - 'name': user_info.name, - 'language': user_info.language - } - ) - return response_data - - -@router.get("/logout", response_model=BaseResponse, dependencies=[Depends(verify_csrf_token)]) -async def logout(request: Request, response: Response, user_id=Depends(get_user_id)): - session_id = request.cookies['WD_ECSESSION'] - if not SessionManager.verify_user(session_id): - logging.info("User already logged out.") - return BaseResponse(code=200, retmsg="ok", data={}) - - SessionManager.delete_session(user_id) - response.delete_cookie("WD_ECSESSION") - response.delete_cookie("wd_csrf_tk") - return { - "code": status.HTTP_200_OK, - "rtmsg": "Logout success", - "data": {} - } - - -@router.post("/update", response_model=BaseResponse, dependencies=[Depends(verify_user), Depends(verify_csrf_token)]) -async def switch(req: UpdateUserRequest, user_id=Depends(get_user_id)): - user_info = UserManager.get_user_info_by_user_id(user_id) - if user_info is None: - return BaseResponse( - code=status.HTTP_401_UNAUTHORIZED, - retmsg="User is not exist", - data={} - ) - tmp_dict = dict(req) - UserManager.update_user_by_user_id(user_id, tmp_dict) - return { - "code": status.HTTP_200_OK, - "rtmsg": "Update success", - "data": {} - } +@router.post("/list", response_model=ListUserResponse, dependencies=[Depends(verify_user)]) +async def list_users( + user_sub: Annotated[str, Query(default=None, alias="userSub")], + req: Annotated[ListUserRequest, Body()] +): + return ListUserResponse() diff --git a/data_chain/apps/router/usr_message.py b/data_chain/apps/router/usr_message.py index 9c76316de50d696f131cab5cf9475e96e6df906f..eecc594f5b36c0dcbee4a3d5700d161665b30bfb 100644 --- a/data_chain/apps/router/usr_message.py +++ b/data_chain/apps/router/usr_message.py @@ -10,11 +10,11 @@ from data_chain.entities.response_data import ( DeleteUserMessageResponse ) from data_chain.apps.service.session_service import get_user_sub, verify_user - +from data_chain.apps.service.router_service import get_route_info router = APIRouter(prefix='/usr_msg', tags=['User Message']) -@router.get('', response_model=ListUserMessageResponse, dependencies=[Depends(verify_user)]) +@router.post('/list', response_model=ListUserMessageResponse, dependencies=[Depends(verify_user)]) async def list_user_msgs_by_user_sub( user_sub: Annotated[str, Depends(get_user_sub)], msg_type: Annotated[UserMessageType, Query(alias="msgType")],
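Across the routers in this patch, list endpoints move from GET (whose request bodies many proxies drop) to POST under a /list suffix. A client-side sketch of the new calling convention, using the user-message route above; the empty JSON filter and the msgType value are illustrative, not prescribed by the patch:

import httpx


def list_user_msgs(base_url: str, msg_type: str) -> dict:
    # was: GET /usr_msg ; now: POST /usr_msg/list with the filter in the body
    resp = httpx.post(f"{base_url}/usr_msg/list", params={"msgType": msg_type}, json={})
    resp.raise_for_status()
    return resp.json()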