From f032c8e994dcd56103994b74894b3b0c4ffef694 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 6 May 2025 10:21:01 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E8=B5=84=E4=BA=A7=E5=BA=93?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E7=9A=84manager?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/document_manager.py | 340 +++++--------------- data_chain/manager/document_type_manager.py | 133 +++----- data_chain/manager/image_manager.py | 117 ++----- data_chain/manager/knowledge_manager.py | 211 ++++++------ data_chain/manager/role_manager.py | 91 ++++++ data_chain/manager/session_manager.py | 37 +++ data_chain/manager/task_manager.py | 319 ++++-------------- data_chain/manager/task_queue_mamanger.py | 61 ++++ data_chain/manager/task_report_manager.py | 23 ++ data_chain/manager/team_manager.py | 152 +++++++++ data_chain/manager/user_manager.py | 106 +----- 11 files changed, 667 insertions(+), 923 deletions(-) create mode 100644 data_chain/manager/role_manager.py create mode 100644 data_chain/manager/session_manager.py create mode 100644 data_chain/manager/task_queue_mamanger.py create mode 100644 data_chain/manager/task_report_manager.py create mode 100644 data_chain/manager/team_manager.py diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 31bb0b2..b99e83f 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -1,297 +1,127 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. from sqlalchemy import select, delete, update, func, between, asc, desc, and_ from datetime import datetime, timezone import uuid -from data_chain.logger.logger import logger as logging from typing import Dict, List, Tuple -from data_chain.stores.postgres.postgres import PostgresDB, DocumentEntity, TaskEntity,TemporaryDocumentEntity -from data_chain.models.constant import DocumentEmbeddingConstant,TemporaryDocumentStatusEnum - - +from data_chain.stores.database.database import DataBase, KnowledgeBaseEntity, DocumentTypeEntity, DocumentEntity +from data_chain.entities.enum import KnowledgeBaseStatus, DocumentStatus +from data_chain.entities.request_data import ListDocumentRequest +from data_chain.logger.logger import logger as logging class DocumentManager(): + """文档管理类""" @staticmethod - async def insert(entity: DocumentEntity) -> DocumentEntity: - try: - async with await PostgresDB.get_session() as session: - session.add(entity) - await session.commit() - await session.refresh(entity) - return entity - except Exception as e: - logging.error(f"Failed to insert entity: {e}") - return None - - @staticmethod - async def insert_bulk(entity_list: List[DocumentEntity]) -> List[DocumentEntity]: + async def add_document(document_entity: DocumentEntity) -> DocumentEntity: + """添加文档""" try: - async with await PostgresDB.get_session() as session: - session.add_all(entity_list) + async with await DataBase.get_session() as session: + session.add(document_entity) await session.commit() - # 可以选择刷新所有实体,但这可能不是必要的,取决于实际需求 - for entity in entity_list: - await session.refresh(entity) - return entity_list - except Exception as e: - logging.error(f"Failed to insert bulk entities: {e}") - return [] - - @staticmethod - async def select_by_id(id: uuid.UUID) -> DocumentEntity: - try: - async with await PostgresDB.get_session() as session: - stmt = select(DocumentEntity).where(DocumentEntity.id == id) - result = await session.execute(stmt) - return result.scalars().first() - except Exception as e: - logging.error(f"Failed to select entity by ID: {e}") - return None - - @staticmethod - async def select_by_ids(ids: List[uuid.UUID]) -> List[DocumentEntity]: - try: - async with await PostgresDB.get_session() as session: - stmt = select(DocumentEntity).where(DocumentEntity.id.in_(ids)) - results = await session.execute(stmt) - return results.scalars().all() + return document_entity except Exception as e: - logging.error(f"Failed to select entity by ID: {e}") - return None + err = "添加文档失败" + logging.exception("[DocumentManager] %s", err) @staticmethod - async def select_by_knowledge_base_id(kb_id: uuid.UUID) -> List[DocumentEntity]: - try: - async with await PostgresDB.get_session() as session: # 确保 get_async_session 返回一个异步的session - stmt = select(DocumentEntity).where(DocumentEntity.kb_id == kb_id) - result = await session.execute(stmt) - return result.scalars().all() - except Exception as e: - logging.error(f"Failed to select by knowledge base id: {e}") - return [] + async def add_documents(document_entities: List[DocumentEntity]) -> List[DocumentEntity]: + """批量添加文档""" + pass @staticmethod - async def select_by_knowledge_base_id_and_file_name(kb_id: str, file_name: str): + async def list_all_document_by_kb_id(kb_id: uuid.UUID) -> List[DocumentEntity]: + """根据知识库ID获取文档列表""" try: - async with await PostgresDB.get_session() as session: + async with await DataBase.get_session() as session: stmt = select(DocumentEntity).where( and_(DocumentEntity.kb_id == kb_id, - DocumentEntity.name == file_name)) + DocumentEntity.status != DocumentStatus.DELETED.value)) result = await session.execute(stmt) - return result.scalars().first() + document_entities = result.scalars().all() + return document_entities except Exception as e: - logging.error(f"Failed to select by knowledge base id and file name: {e}") - return None - - @staticmethod - async def select_by_page(params: Dict, page_number: int, page_size: int) -> Tuple[int, List['DocumentEntity']]: - try: - async with await PostgresDB.get_session() as session: - # 子查询:找到每个文档最近的任务 - subq = ( - select( - TaskEntity.op_id, - TaskEntity.status, - func.row_number().over(partition_by=TaskEntity.op_id, order_by=desc(TaskEntity.created_time)).label('rn') - ).subquery() - ) - - # 主查询 - stmt = ( - select(DocumentEntity) - .outerjoin(subq, and_(DocumentEntity.id == subq.c.op_id, subq.c.rn == 1)) - ) - if 'kb_id' in params: - stmt = stmt.where(DocumentEntity.kb_id == params['kb_id']) - if 'id' in params: - stmt = stmt.where(DocumentEntity.id == params['id']) - if 'name' in params: - stmt = stmt.where(DocumentEntity.name.ilike(f"%{params['name']}%")) - if 'parser_method' in params: - stmt = stmt.where(DocumentEntity.parser_method.in_(params['parser_method'])) - if 'document_type_list' in params: - document_type_ids = params['document_type_list'] - stmt = stmt.where(DocumentEntity.type_id.in_(document_type_ids)) - if 'created_time_order' in params: - if params['created_time_order'] == 'desc': - stmt = stmt.order_by(desc(DocumentEntity.created_time)) - elif params['created_time_order'] == 'asc': - stmt = stmt.order_by(asc(DocumentEntity.created_time)) - stmt = stmt.order_by(desc(DocumentEntity.id)) - if 'created_time_start' in params and 'created_time_end' in params: - stmt = stmt.where(between(DocumentEntity.created_time, + err = "获取所有文档列表失败" + logging.exception("[DocumentManager] %s", err) + raise e + + @staticmethod + async def list_document(req: ListDocumentRequest) -> List[DocumentEntity]: + """ + 根据req的过滤条件获取文档列表 + kb_id: Optional[uuid.UUID] = Field(default=None, description="资产id", alias="kbId") + doc_id: Optional[uuid.UUID] = Field(default=None, description="文档id", min_length=1, max_length=30, alias="docId") + doc_name: Optional[str] = Field(default=None, description="文档名称", alias="docName") + doc_type_id: Optional[uuid.UUID] = Field(default=None, description="文档类型id", alias="docTypeId") + parse_status: Optional[TaskStatus] = Field(default=None, description="文档解析状态", alias="parseStatus") --要和task表做连结进行查询 + parse_method: Optional[ParseMethod] = Field(default=None, description="文档解析方法", alias="parseMethod") + author_name: Optional[str] = Field(default=None, description="文档创建者", alias="authorName") + created_time_start: Optional[str] = Field(default=None, description="文档创建时间开始", alias="createdTimeStart") + created_time_end: Optional[str] = Field(default=None, description="文档创建时间结束", alias="createdTimeEnd") + created_time_order: OrderType = Field(default=OrderType.DESC, description="文档创建时间排序", alias="createdTimeOrder") + page: int = Field(default=1, description="页码") + page_size: int = Field(default=40, description="每页数量", alias="pageSize") + + 时间范围查询可参考下面内容: + stmt.where(between(DocumentEntity.created_time, datetime.strptime(params['created_time_start'], '%Y-%m-%d %H:%M').replace(tzinfo=timezone.utc), datetime.strptime(params['created_time_end'], '%Y-%m-%d %H:%M').replace(tzinfo=timezone.utc))) - if 'status' in params: - stmt = stmt.where(subq.c.status.in_(params['status'])) - if 'enabled' in params: - stmt = stmt.where(DocumentEntity.enabled == params['enabled']) - # Execute the count part of the query separately - count_stmt = select(func.count()).select_from(stmt.subquery()) - total = (await session.execute(count_stmt)).scalar() - # Apply pagination - stmt = stmt.offset((page_number-1)*page_size).limit(page_size) - # Execute the main query - result = await session.execute(stmt) - document_list = result.scalars().all() + 默认不展示状态为deleted的文档 + """ + pass - return (total, document_list) - except Exception as e: - logging.error(f"Failed to select documents by page: {e}") - return (0, []) @staticmethod - async def select_cnt_and_sz_by_kb_id(kb_id: uuid.UUID) -> Tuple[int, int]: - try: - async with await PostgresDB.get_session() as session: - # 构造查询语句 - stmt = ( - select( - func.count(DocumentEntity.id).label('total_cnt'), - func.sum(DocumentEntity.size).label('total_sz') - ) - .where(DocumentEntity.kb_id == kb_id) - ) - - # 执行查询 - result = await session.execute(stmt) - - # 获取结果 - first_row = result.first() - - # 如果没有结果,返回 (0, 0) - if first_row is None: - return 0, 0 - - total_cnt, total_sz = first_row - return int(total_cnt) if total_cnt is not None else 0, int(total_sz) if total_sz is not None else 0 - except Exception as e: - logging.error(f"Failed to select count and size by knowledge base id: {e}") - return 0, 0 + async def get_document_by_doc_id(doc_id: uuid.UUID) -> DocumentEntity: + """根据文档ID获取文档""" + pass @staticmethod - async def update(id: str, update_dict: dict): + async def update_doc_type_by_kb_id( + kb_id: uuid.UUID, old_doc_type_ids: list[uuid.UUID], + new_doc_type_id: uuid.UUID) -> None: + """根据知识库ID更新文档类型""" try: - async with await PostgresDB.get_session() as session: - # 使用异步查询 - result = await session.execute( - select(DocumentEntity).where(DocumentEntity.id == id).with_for_update() - ) - document_entity = result.scalars().first() - if 'status' in update_dict.keys() and update_dict['status'] != DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING: - if document_entity.status != DocumentEmbeddingConstant.DOCUMENT_EMBEDDING_STATUS_PENDING: - return None - # 更新记录 - await session.execute( - update(DocumentEntity).where(DocumentEntity.id == id).values(**update_dict) - ) - + async with await DataBase.get_session() as session: + stmt = update(DocumentEntity).where( + and_(DocumentEntity.kb_id == kb_id, + DocumentEntity.status != DocumentStatus.DELETED.value, + DocumentEntity.type_id.in_(old_doc_type_ids)) + ).values(type_id=new_doc_type_id) + await session.execute(stmt) await session.commit() - return document_entity except Exception as e: - logging.error(f"Failed to update document: {e}") - return False + err = "更新文档类型失败" + logging.exception("[DocumentManager] %s", err) + raise e @staticmethod - async def delete_by_id(id: uuid.UUID) -> int: - async with await PostgresDB.get_session() as session: - stmt = delete(DocumentEntity).where(DocumentEntity.id == id) - result = await session.execute(stmt) - await session.commit() - return result.rowcount + async def update_document_by_doc_id(doc_id: uuid.UUID, doc_dict: Dict[str, str]) -> DocumentEntity: + """根据文档ID更新文档""" + pass @staticmethod - async def delete_by_ids(ids: List[uuid.UUID]) -> int: - async with await PostgresDB.get_session() as session: - stmt = delete(DocumentEntity).where(DocumentEntity.id.in_(ids)) - result = await session.execute(stmt) - await session.commit() - return result.rowcount + async def update_document_by_doc_ids(doc_ids: list[uuid.UUID], doc_dict: Dict[str, str]) -> None: + """根据文档ID批量更新文档""" + pass @staticmethod - async def delete_by_knowledge_base_id(kb_id: uuid.UUID) -> int: - async with await PostgresDB.get_session() as session: - stmt = delete(DocumentEntity).where(DocumentEntity.kb_id == kb_id) - result = await session.execute(stmt) - await session.commit() - return result.rowcount + async def delte_document_by_doc_id(doc_id: uuid.UUID) -> None: + """根据文档ID删除文档""" + pass -class TemporaryDocumentManager(): - @staticmethod - async def insert(entity: TemporaryDocumentEntity) -> TemporaryDocumentEntity: - try: - async with await PostgresDB.get_session() as session: - session.add(entity) - await session.commit() - await session.refresh(entity) - return entity - except Exception as e: - logging.error(f"Failed to insert temporary document: {e}") - return None - @staticmethod - async def delete_by_ids(ids: List[uuid.UUID]) -> int: - try: - async with await PostgresDB.get_session() as session: - stmt = delete(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id.in_(ids)) - result = await session.execute(stmt) - await session.commit() - return result.rowcount - except Exception as e: - logging.error(f"Failed to delete temporary documents: {e}") - return 0 @staticmethod - async def select_by_ids(ids: List[uuid.UUID]) -> List[TemporaryDocumentEntity]: - try: - async with await PostgresDB.get_session() as session: - stmt = select(TemporaryDocumentEntity).where( - and_( - TemporaryDocumentEntity.id.in_(ids), - TemporaryDocumentEntity.status !=TemporaryDocumentStatusEnum.DELETED - ) - ) - result = await session.execute(stmt) - tmp_list = result.scalars().all() - return tmp_list - except Exception as e: - logging.error(f"Failed to select temporary documents by page: {e}") - return [] - @staticmethod - async def select_by_id(id: uuid.UUID) -> TemporaryDocumentEntity: - try: - async with await PostgresDB.get_session() as session: - stmt = select(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id==id) - result = await session.execute(stmt) - return result.scalars().first() - except Exception as e: - logging.error(f"Failed to select temporary document by id: {e}") - return None + async def delete_document_by_kb_id(kb_id: uuid.UUID) -> None: + """根据知识库ID删除文档""" + pass + @staticmethod - async def update(id: uuid.UUID, update_dict: dict): - try: - async with await PostgresDB.get_session() as session: - await session.execute( - update(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id == id).values(**update_dict) - ) - await session.commit() - entity=await session.execute(select(TemporaryDocumentEntity).where( - TemporaryDocumentEntity.id==id, - ) - ) - return entity - except Exception as e: - logging.error(f"Failed to update temporary document by id: {e}") - return None + async def delete_document_by_doc_id(doc_id: uuid.UUID) -> None: + """根据文档ID删除文档""" + pass + @staticmethod - async def update_all(ids: List[uuid.UUID], update_dict: dict): - try: - async with await PostgresDB.get_session() as session: - await session.execute( - update(TemporaryDocumentEntity).where(TemporaryDocumentEntity.id.in_(ids)).values(**update_dict) - ) - await session.commit() - return True - except Exception as e: - logging.error(f"Failed to update temporary document by ids: {e}") - return False \ No newline at end of file + async def delete_document_by_doc_ids(doc_ids: list[uuid.UUID]) -> None: + """根据文档ID批量删除文档""" + pass diff --git a/data_chain/manager/document_type_manager.py b/data_chain/manager/document_type_manager.py index 03e988a..475875a 100644 --- a/data_chain/manager/document_type_manager.py +++ b/data_chain/manager/document_type_manager.py @@ -1,114 +1,53 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from sqlalchemy import select, delete, update -from typing import List +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from sqlalchemy import select, delete, update, func, between, asc, desc, and_ +from datetime import datetime, timezone import uuid from data_chain.logger.logger import logger as logging +from typing import Dict, List, Tuple -from data_chain.models.service import DocumentTypeDTO -from data_chain.stores.postgres.postgres import PostgresDB, DocumentEntity, DocumentTypeEntity - - +from data_chain.stores.database.database import DataBase, KnowledgeBaseEntity, DocumentTypeEntity, DocumentEntity +from data_chain.entities.enum import KnowledgeBaseStatus, DocumentStatus class DocumentTypeManager(): @staticmethod - async def select_by_id(id: str) -> DocumentTypeEntity: - async with await PostgresDB.get_session()as session: - stmt = select(DocumentTypeEntity).where(DocumentTypeEntity.id == id) - result = await session.execute(stmt) - return result.scalars().first() - @staticmethod - async def select_by_ids(ids: List[str]) -> List[DocumentTypeEntity]: - async with await PostgresDB.get_session()as session: - stmt = select(DocumentTypeEntity).where(DocumentTypeEntity.id.in_(ids)) - result = await session.execute(stmt) - return result.scalars().all() - @staticmethod - async def select_by_knowledge_base_id(kb_id: str) -> List[DocumentTypeEntity]: - async with await PostgresDB.get_session()as session: - stmt = select(DocumentTypeEntity).where(DocumentTypeEntity.kb_id == kb_id) - result = await session.execute(stmt) - return result.scalars().all() - - @staticmethod - async def insert_bulk(kb_id: uuid.UUID, types: List[str]) -> List[DocumentTypeEntity]: - if types is None or len(types) == 0: - return [] - async with await PostgresDB.get_session()as session: - document_type_entity_list = [ - DocumentTypeEntity(kb_id=kb_id, type=type) for type in types - ] - session.add_all(document_type_entity_list) - await session.commit() - # Refresh the entities after committing so that they have their primary keys filled in. - for entity in document_type_entity_list: - await session.refresh(entity) - return document_type_entity_list + async def add_document_type(document_type_entity: DocumentTypeEntity) -> DocumentTypeEntity: + """添加文档类型""" + try: + async with await DataBase.get_session() as session: + session.add(document_type_entity) + await session.commit() + return document_type_entity + except Exception as e: + err = "添加文档类型失败" + logging.exception("[DocumentTypeManager] %s", err) @staticmethod - async def update_knowledge_base_document_type(kb_id: str, types: List['DocumentTypeDTO'] = None): + async def add_document_types( + document_type_entities: List[DocumentTypeEntity]) -> List[DocumentTypeEntity]: + """批量添加文档类型""" try: - async with await PostgresDB.get_session()as session: - if types is not None: - new_document_type_map = {str(_type.id): _type.type for _type in types} - new_document_type_ids = {_type.id for _type in types} - old_document_type_ids = set((await session.execute( - select(DocumentTypeEntity.id).filter(DocumentTypeEntity.kb_id == kb_id))).scalars().all()) - delete_document_type_ids = old_document_type_ids - new_document_type_ids - add_document_type_ids = new_document_type_ids - old_document_type_ids - update_document_type_ids = old_document_type_ids & new_document_type_ids - - # 删掉document_type, 然后document对应的type_id修改为默认值 - if len(delete_document_type_ids) > 0: - default_document_type_id = uuid.UUID('00000000-0000-0000-0000-000000000000') - await session.execute( - delete(DocumentTypeEntity).where(DocumentTypeEntity.id.in_(delete_document_type_ids))) - await session.execute( - update(DocumentEntity).where(DocumentEntity.type_id.in_(delete_document_type_ids)).values(type_id=default_document_type_id)) - await session.commit() - - # 插入document_type - if len(add_document_type_ids) > 0: - add_document_type_entity_list = [ - DocumentTypeEntity( - id=add_document_type_id, kb_id=kb_id, type=new_document_type_map[str(add_document_type_id)]) - for add_document_type_id in add_document_type_ids - ] - session.add_all(add_document_type_entity_list) - await session.commit() - - # 修改document_type - if len(update_document_type_ids) > 0: - old_document_type_entity_list=( - await session.execute( - select(DocumentTypeEntity).filter(DocumentTypeEntity.id.in_(update_document_type_ids))) - ).scalars().all() - for old_document_type_entity in old_document_type_entity_list: - new_type = new_document_type_map.get(str(old_document_type_entity.id),None) - if old_document_type_entity.type != new_type: - await session.execute( - update(DocumentTypeEntity).where(DocumentTypeEntity.id == old_document_type_entity.id).values(type=new_type)) - await session.commit() - - results = await session.execute(select(DocumentTypeEntity).filter(DocumentTypeEntity.kb_id == kb_id)) - return results.scalars().all() + async with await DataBase.get_session() as session: + session.add_all(document_type_entities) + await session.commit() + return document_type_entities except Exception as e: - logging.error(f"Update document type faile by knowledge base id failed due to: {e}") - return [] + err = "批量添加文档类型失败" + logging.exception("[DocumentTypeManager] %s", err) @staticmethod - async def delete_by_knowledge_base_id(kb_id: str) -> int: + async def update_doc_type_by_doc_type_id( + doc_type_id: uuid.UUID, doc_type_name: str) -> None: + """根据文档类型ID更新文档类型名称""" try: - async with await PostgresDB.get_session() as session: - # 构建删除语句 - stmt = delete(DocumentTypeEntity).where(DocumentTypeEntity.kb_id == kb_id) - # 执行删除操作 - result = await session.execute(stmt) - # 提交事务 + async with await DataBase.get_session() as session: + stmt = update(DocumentTypeEntity).where( + DocumentTypeEntity.id == doc_type_id + ).values(name=doc_type_name) + await session.execute(stmt) await session.commit() - # 返回删除的数量 - return result.rowcount except Exception as e: - logging.error(f"Failed to delete by knowledge base id: {e}") - return 0 + err = "更新文档类型名称失败" + logging.exception("[DocumentTypeManager] %s", err) + raise e diff --git a/data_chain/manager/image_manager.py b/data_chain/manager/image_manager.py index ed091aa..023fdc9 100644 --- a/data_chain/manager/image_manager.py +++ b/data_chain/manager/image_manager.py @@ -1,116 +1,41 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from sqlalchemy import select -from typing import List +from sqlalchemy import select, update +from typing import List, Dict import uuid from data_chain.logger.logger import logger as logging -from data_chain.stores.postgres.postgres import PostgresDB, ImageEntity - - +from data_chain.stores.database.database import DataBase, ImageEntity class ImageManager: - - @staticmethod - async def add_image(image_entity: ImageEntity): - try: - async with await PostgresDB.get_session() as session: - session.add(image_entity) - await session.commit() - await session.refresh(image_entity) - except Exception as e: - logging.error(f"Add image failed due to error: {e}") - return None - return image_entity + """图片管理类""" @staticmethod - async def add_images(image_entity_list: List[ImageEntity]): + async def add_images(image_entity_list: List[ImageEntity]) -> List[ImageEntity]: try: - async with await PostgresDB.get_session() as session: + async with await DataBase.get_session() as session: session.add_all(image_entity_list) await session.commit() for image_entity in image_entity_list: await session.refresh(image_entity) except Exception as e: - logging.error(f"Add image failed due to error: {e}") - return None - return image_entity_list - @staticmethod - async def del_image_by_image_id(image_id: uuid.UUID): - try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取并删除用户 - stmt = select(ImageEntity).where(ImageEntity.id == image_id) - result = await session.execute(stmt) - image_to_delete = result.scalars().first() - - if image_to_delete is not None: - await session.delete(image_to_delete) - await session.commit() - except Exception as e: - logging.error(f"Delete image failed due to error: {e}") - return False - return True - - @staticmethod - async def query_image_by_image_id(image_id: uuid.UUID): - try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取并删除用户 - stmt = select(ImageEntity).where(ImageEntity.id == image_id) - result = await session.execute(stmt) - image_entity = result.scalars().first() - except Exception as e: - logging.error(f"Query image by image id failed due to error: {e}") - return None - return image_entity - - @staticmethod - async def query_image_by_doc_id(doc_id: uuid.UUID): - try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取并删除用户 - stmt = select(ImageEntity).where(ImageEntity.document_id == doc_id) - result = await session.execute(stmt) - image_entity_list = result.scalars().all() - except Exception as e: - logging.error(f"Query image by doc id failed due to error: {e}") - return [] - return image_entity_list - - @staticmethod - async def query_image_by_user_id(user_id: uuid.UUID): - try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取并删除用户 - stmt = select(ImageEntity).where(ImageEntity.user_id == user_id) - result = await session.execute(stmt) - image_entity_list = result.scalars().all() - except Exception as e: - logging.error(f"Query image by user id failed due to error: {e}") - return [] + err = "添加图片失败" + logging.exception("[ImageManager] %s", err) + raise e return image_entity_list @staticmethod - async def update_image_by_image_id(image_id, tmp_dict: dict): + async def update_images_by_doc_id(doc_id: uuid.UUID, image_dict: Dict[str, str]) -> None: + """根据文档ID更新图片""" try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取用户 - stmt = select(ImageEntity).where(ImageEntity.id == image_id) - result = await session.execute(stmt) - image_to_update = result.scalars().first() - - if image_to_update is None: - raise ValueError(f"No image found with id {image_id}") - - # 更新用户属性 - for key, value in tmp_dict.items(): - if hasattr(image_to_update, key): - setattr(image_to_update, key, value) - else: - logging.error(f"Attribute {key} does not exist on User model") - + async with await DataBase.get_session() as session: + stmt = ( + update(ImageEntity) + .where(ImageEntity.doc_id == doc_id) + .values(**image_dict) + ) + await session.execute(stmt) await session.commit() - return True except Exception as e: - logging.error(f"Failed to update user: {e}") - return False + err = "更新图片失败" + logging.exception("[ImageManager] %s", err) + raise e diff --git a/data_chain/manager/knowledge_manager.py b/data_chain/manager/knowledge_manager.py index 54d172b..540d64f 100644 --- a/data_chain/manager/knowledge_manager.py +++ b/data_chain/manager/knowledge_manager.py @@ -1,151 +1,130 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. from typing import Dict, List, Tuple, Optional import uuid from data_chain.logger.logger import logger as logging -from sqlalchemy import and_, select, delete, func,between -from datetime import datetime,timezone -from data_chain.stores.postgres.postgres import PostgresDB, KnowledgeBaseEntity -from data_chain.models.constant import KnowledgeStatusEnum -from data_chain.models.constant import embedding_model_out_dimensions - +from sqlalchemy import and_, select, delete, update, func, between +from datetime import datetime, timezone +from data_chain.entities.request_data import ListKnowledgeBaseRequest +from data_chain.stores.database.database import DataBase, KnowledgeBaseEntity, DocumentTypeEntity, DocumentEntity +from data_chain.entities.enum import KnowledgeBaseStatus, DocumentStatus class KnowledgeBaseManager(): - + """知识库管理类""" @staticmethod - async def insert(entity: KnowledgeBaseEntity) -> Optional[KnowledgeBaseEntity]: + async def add_knowledge_base(knowledge_base_entity: KnowledgeBaseEntity) -> KnowledgeBaseEntity: + """添加知识库""" try: - async with await PostgresDB.get_session() as session: - vector_items_table = await PostgresDB.get_dynamic_vector_items_table( - str(entity.vector_items_id), - embedding_model_out_dimensions[entity.embedding_model] - ) - await PostgresDB.create_table(vector_items_table) - session.add(entity) + async with await DataBase.get_session() as session: + session.add(knowledge_base_entity) await session.commit() - await session.refresh(entity) # Refresh the entity to get any auto-generated values. - return entity + return knowledge_base_entity except Exception as e: - import traceback - logging.error(f"Failed to insert entity: {traceback.format_exc()}") - logging.error(f"Failed to insert entity: {e}") - return None + err = "添加知识库失败" + logging.exception("[KnowledgeBaseManager] %s", err) @staticmethod - async def update(id: uuid.UUID, update_dict: Dict) -> Optional[KnowledgeBaseEntity]: + async def get_knowledge_base_by_kb_id(kb_id: uuid.UUID) -> KnowledgeBaseEntity: + """根据知识库ID获取知识库""" try: - async with await PostgresDB.get_session() as session: - stmt = select(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == id).with_for_update() + async with await DataBase.get_session() as session: + stmt = select(KnowledgeBaseEntity).where(and_(KnowledgeBaseEntity.id == kb_id, + KnowledgeBaseEntity.status != KnowledgeBaseStatus.DELTED)) result = await session.execute(stmt) - entity = result.scalars().first() - - if 'status' in update_dict.keys() and update_dict['status'] != KnowledgeStatusEnum.IDLE: - if entity.status != KnowledgeStatusEnum.IDLE: - return None - if entity is not None: - for key, value in update_dict.items(): - setattr(entity, key, value) - await session.commit() - await session.refresh(entity) # Refresh the entity to ensure it's up to date. - return entity + knowledge_base_entity = result.scalars().first() + return knowledge_base_entity except Exception as e: - logging.error(f"Failed to update entity: {e}") - return None + err = "获取知识库失败" + logging.exception("[KnowledgeBaseManager] %s", err) + raise e @staticmethod - async def select_by_id(id: uuid.UUID) -> KnowledgeBaseEntity: - async with await PostgresDB.get_session() as session: # 假设get_async_session返回一个异步会话 - stmt = select(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == id) - result = await session.execute(stmt) - entity = result.scalars().first() - return entity - return None + async def list_knowledge_base(req: ListKnowledgeBaseRequest) -> Tuple[int, List[KnowledgeBaseEntity]]: + """列出知识库""" + try: + async with await DataBase.get_session() as session: + stmt = select(KnowledgeBaseEntity).where(KnowledgeBaseEntity.status != KnowledgeBaseStatus.DELTED) + if req.team_id: + stmt = stmt.where(KnowledgeBaseEntity.team_id == req.team_id) + if req.kb_id: + stmt = stmt.where(KnowledgeBaseEntity.id == req.kb_id) + if req.kb_name: + stmt = stmt.where(KnowledgeBaseEntity.name.like(f"%{req.kb_name}%")) + if req.author_name: + stmt = stmt.where(KnowledgeBaseEntity.author_name.like(f"%{req.author_name}%")) + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = (await session.execute(count_stmt)).scalar() + stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.order_by(KnowledgeBaseEntity.created_time.desc()) + result = await session.execute(stmt) + knowledge_base_entities = result.scalars().all() + return (total, knowledge_base_entities) + except Exception as e: + err = "列出知识库失败" + logging.exception("[KnowledgeBaseManager] %s", err) + raise e @staticmethod - async def select_by_user_id_and_kb_name(user_id: uuid.UUID, kb_name: str) -> KnowledgeBaseEntity: + async def list_knowledge_base_by_team_ids(team_ids: List[uuid.UUID]) -> List[KnowledgeBaseEntity]: + """根据团队ID获取知识库""" try: - async with await PostgresDB.get_session() as session: + async with await DataBase.get_session() as session: stmt = select(KnowledgeBaseEntity).where( - and_(KnowledgeBaseEntity.user_id == user_id, KnowledgeBaseEntity.name == kb_name)) + and_(KnowledgeBaseEntity.team_id.in_(team_ids), + KnowledgeBaseEntity.status != KnowledgeBaseStatus.DELTED.value)) result = await session.execute(stmt) - entity = result.scalars().first() - return entity + knowledge_base_entities = result.scalars().all() + return knowledge_base_entities except Exception as e: - logging.error(f"Failed to select by user id and kb name: {e}") - return None + err = "获取知识库失败" + logging.exception("[KnowledgeBaseManager] %s", err) + raise e @staticmethod - async def select_by_page(params: Dict, page_number: int, page_size: int) -> Tuple[int, List[KnowledgeBaseEntity]]: + async def list_doc_types_by_kb_id(kb_id: uuid.UUID) -> List[DocumentTypeEntity]: + """列出知识库文档类型""" try: - async with await PostgresDB.get_session() as session: - base_query = select(KnowledgeBaseEntity).where( - KnowledgeBaseEntity.status != KnowledgeStatusEnum.IMPORTING) - # 构建过滤条件 - filters = [] - if 'id' in params.keys(): - filters.append(KnowledgeBaseEntity.id == params['id']) - if 'user_id' in params.keys(): - filters.append(KnowledgeBaseEntity.user_id == params['user_id']) - if 'name' in params.keys(): - filters.append(KnowledgeBaseEntity.name.ilike(f"%{params['name']}%")) - if 'created_time_start' in params and 'created_time_end' in params: - filters.append(between(KnowledgeBaseEntity.created_time, - datetime.strptime(params['created_time_start'], '%Y-%m-%d %H:%M').replace(tzinfo=timezone.utc), - datetime.strptime(params['created_time_end'], '%Y-%m-%d %H:%M').replace(tzinfo=timezone.utc))) - # 应用过滤条件 - query = base_query.where(*filters) - - # 排序 - if 'created_time_order' in params.keys(): - if params['created_time_order'] == 'desc': - query = query.order_by(KnowledgeBaseEntity.created_time.desc()) - elif params['created_time_order'] == 'asc': - query = query.order_by(KnowledgeBaseEntity.created_time.asc()) - if 'document_count_order' in params.keys(): - if params['document_count_order'] == 'desc': - query = query.order_by(KnowledgeBaseEntity.document_number.desc()) - elif params['document_count_order'] == 'asc': - query = query.order_by(KnowledgeBaseEntity.document_number.asc()) - - # 获取总数 - count_query = select(func.count()).select_from(query.subquery()) - total = await session.scalar(count_query) - # 分页查询 - paginated_query = query.offset((page_number - 1) * page_size).limit(page_size) - results = await session.scalars(paginated_query) - knowledge_base_entity_list = results.all() - - return (total, knowledge_base_entity_list) + async with await DataBase.get_session() as session: + stmt = select(DocumentTypeEntity).where(DocumentTypeEntity.kb_id == kb_id) + result = await session.execute(stmt) + document_type_entities = result.scalars().all() + return document_type_entities except Exception as e: - logging.error(f"Failed to select by page: {e}") - return (0, []) + err = "列出知识库文档类型失败" + logging.exception("[KnowledgeBaseManager] %s", err) + raise e @staticmethod - async def delete(kb_id: str) -> int: + async def update_knowledge_base_by_kb_id(kb_id: uuid.UUID, kb_dict: Dict[str, str]) -> KnowledgeBaseEntity: + """根据知识库ID更新知识库""" try: - async with await PostgresDB.get_session() as session: + async with await DataBase.get_session() as session: + stmt = update(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == kb_id).values(**kb_dict) + await session.execute(stmt) + await session.commit() stmt = select(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == kb_id) result = await session.execute(stmt) knowledge_base_entity = result.scalars().first() + return knowledge_base_entity + except Exception as e: + err = "更新知识库失败" + logging.exception("[KnowledgeBaseManager] %s", err) + raise e - if knowledge_base_entity: - try: - vector_items_id = str(knowledge_base_entity.vector_items_id) - vector_dim = embedding_model_out_dimensions[knowledge_base_entity.embedding_model] - vector_items_table = await PostgresDB.get_dynamic_vector_items_table( - vector_items_id, - vector_dim - ) - await PostgresDB.drop_table(vector_items_table) - except Exception as e: - logging.error(f"Failed to delete vector items table: {e}") - delete_stmt = delete(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == kb_id) - result = await session.execute(delete_stmt) - cnt = result.rowcount - await session.commit() - return cnt - else: - return 0 + @staticmethod + async def update_doc_cnt_and_doc_size(kb_id: uuid.UUID) -> None: + """根据知识库ID更新知识库文档数量和文档大小,获取document表内状态不是deleted的文档数量和大小""" + try: + async with await DataBase.get_session() as session: + stmt = select(func.count(DocumentEntity.id), func.sum(DocumentEntity.size)).where( + and_(DocumentEntity.kb_id == kb_id, DocumentEntity.status != DocumentStatus.DELETED)) + result = await session.execute(stmt) + doc_cnt, doc_size = result.first() + stmt = update(KnowledgeBaseEntity).where(KnowledgeBaseEntity.id == kb_id).values( + document_count=doc_cnt, document_size=doc_size) + await session.execute(stmt) + await session.commit() except Exception as e: - logging.error(f"Failed to delete knowledge base entity: {e}") - return 0 + err = "更新知识库文档数量和大小失败" + logging.exception("[KnowledgeBaseManager] %s", err) + raise e diff --git a/data_chain/manager/role_manager.py b/data_chain/manager/role_manager.py new file mode 100644 index 0000000..4b97da7 --- /dev/null +++ b/data_chain/manager/role_manager.py @@ -0,0 +1,91 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from sqlalchemy import select, delete, and_ +from typing import Dict +import uuid + +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ListTeamRequest +from data_chain.entities.enum import TeamStatus +from data_chain.stores.database.database import DataBase, RoleEntity, ActionEntity, RoleActionEntity, UserRoleEntity + + +class RoleManager: + @staticmethod + async def add_role(role_entity: RoleEntity) -> RoleEntity: + """添加角色""" + try: + async with await DataBase.get_session() as session: + session.add(role_entity) + await session.commit() + await session.refresh(role_entity) + except Exception as e: + err = "添加角色失败" + logging.exception("[RoleManager] %s", err) + raise e + return role_entity + + @staticmethod + async def add_user_role(user_role_entity: UserRoleEntity) -> UserRoleEntity: + """添加用户角色""" + try: + async with await DataBase.get_session() as session: + session.add(user_role_entity) + await session.commit() + await session.refresh(user_role_entity) + except Exception as e: + err = "添加用户角色失败" + logging.exception("[RoleManager] %s", err) + raise e + return user_role_entity + + @staticmethod + async def add_action(action_entity: ActionEntity) -> ActionEntity: + """添加操作""" + try: + async with await DataBase.get_session() as session: + session.add(action_entity) + await session.commit() + await session.refresh(action_entity) + return True + except Exception as e: + err = "添加操作失败" + logging.warning("[RoleManager] %s", err) + return False + + @staticmethod + async def add_role_actions(role_action_entities: list[RoleActionEntity]) -> bool: + """添加角色操作""" + try: + async with await DataBase.get_session() as session: + for role_action_entity in role_action_entities: + session.add(role_action_entity) + await session.commit() + for role_action_entity in role_action_entities: + await session.refresh(role_action_entity) + except Exception as e: + err = "添加角色操作失败" + logging.exception("[RoleManager] %s", err) + raise e + + @staticmethod + async def get_action_by_team_id_user_sub_and_action( + user_sub: str, team_id: uuid.UUID, action: str) -> ActionEntity: + """根据团队ID、用户ID和操作获取操作""" + try: + async with await DataBase.get_session() as session: + stmt = select(ActionEntity).join( + RoleActionEntity, ActionEntity.action == RoleActionEntity.action).join( + UserRoleEntity, RoleActionEntity.role_id == UserRoleEntity.role_id).where( + and_( + UserRoleEntity.user_id == user_sub, + UserRoleEntity.team_id == team_id, + ActionEntity.action == action, + ) + ) + result = await session.execute(stmt) + action_entity = result.scalars().first() + return action_entity + except Exception as e: + err = "根据团队ID、用户ID和操作获取操作失败" + logging.exception("[RoleManager] %s", err) + raise e diff --git a/data_chain/manager/session_manager.py b/data_chain/manager/session_manager.py new file mode 100644 index 0000000..e8c8153 --- /dev/null +++ b/data_chain/manager/session_manager.py @@ -0,0 +1,37 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from data_chain.logger.logger import logger as logging +from data_chain.stores.mongodb.mongodb import Session, MongoDB + + +class SessionManager: + """浏览器Session管理""" + + @staticmethod + async def verify_user(session_id: str) -> bool: + """验证用户是否在Session中""" + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return False + return Session(**data).user_sub is not None + except Exception as e: + err = "用户不在Session中" + logging.error("[SessionManager] %s", err) + raise e + + @staticmethod + async def get_user_sub(session_id: str) -> str | None: + """从Session中获取用户""" + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return None + user_sub = Session(**data).user_sub + except Exception as e: + err = "从Session中获取用户失败" + logging.error("[SessionManager] %s", err) + raise e + + return user_sub diff --git a/data_chain/manager/task_manager.py b/data_chain/manager/task_manager.py index 80060a5..a8608f5 100644 --- a/data_chain/manager/task_manager.py +++ b/data_chain/manager/task_manager.py @@ -4,290 +4,81 @@ from sqlalchemy.orm import aliased import uuid from typing import Dict, List, Optional, Tuple from data_chain.logger.logger import logger as logging -from data_chain.stores.postgres.postgres import PostgresDB, TaskEntity, TaskStatusReportEntity, DocumentEntity, KnowledgeBaseEntity -from data_chain.models.constant import TaskConstant - +from data_chain.stores.database.database import DataBase, TaskEntity +from data_chain.entities.enum import TaskStatus class TaskManager(): - @staticmethod - async def insert(entity: TaskEntity) -> TaskEntity: + async def add_task(task_entity: TaskEntity) -> TaskEntity: + """添加任务""" try: - async with await PostgresDB.get_session() as session: - session.add(entity) + async with await DataBase.get_session() as session: + session.add(task_entity) await session.commit() - await session.refresh(entity) # 刷新实体以获取可能由数据库生成的数据(例如自增ID) - return entity + await session.refresh(task_entity) except Exception as e: - logging.error(f"Failed to insert entity: {e}") - return None + err = "添加任务失败" + logging.exception("[TaskManager] %s", err) + raise e + return task_entity @staticmethod - async def update(task_id: uuid.UUID, update_dict: Dict): + async def get_task_by_id(task_id: uuid.UUID) -> TaskEntity: + """根据任务ID获取任务""" try: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.id == task_id) + async with await DataBase.get_session() as session: + stmt = select(TaskEntity).where(and_(TaskEntity.id == task_id, + TaskEntity.status != TaskStatus.DELETED.value)) result = await session.execute(stmt) - entity = result.scalars().first() - if entity is not None: - for key, value in update_dict.items(): - setattr(entity, key, value) - await session.commit() - await session.refresh(entity) # Refresh the entity to ensure it's up to date. - return entity + task_entity = result.scalars().first() + return task_entity except Exception as e: - logging.error(f"Failed to update entity: {e}") - return None + err = "获取任务失败" + logging.exception("[TaskManager] %s", err) + raise e @staticmethod - async def update_task_by_op_id(op_id: uuid.UUID, update_dict: Dict): + async def get_current_task_by_op_id(op_id: uuid.UUID) -> Optional[TaskEntity]: + """根据op_id获取当前最近的任务""" try: - async with await PostgresDB.get_session() as session: - stmt = update(TaskEntity).where(TaskEntity.op_id == op_id).values(**update_dict) - await session.execute(stmt) - await session.commit() - return True + async with await DataBase.get_session() as session: + stmt = select(TaskEntity).where(TaskEntity.op_id == op_id).order_by(desc(TaskEntity.created_time)) + result = await session.execute(stmt) + task_entity = result.scalars().first() + return task_entity except Exception as e: - logging.error(f"Failed to update entity: {e}") - return False + err = "获取任务失败" + logging.exception("[TaskManager] %s", err) + raise e @staticmethod - async def select_by_page(page_number: int, page_size: int, params: Dict) -> Tuple[int, List[TaskEntity]]: + async def get_task_by_task_status(task_status: str) -> List[TaskEntity]: + """根据任务状态获取任务""" try: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.status != TaskConstant.TASK_STATUS_DELETED) - stmt = stmt.where( - exists(select(1).where(or_( - DocumentEntity.id == TaskEntity.op_id, - KnowledgeBaseEntity.id == TaskEntity.op_id - ))) - ) - if 'user_id' in params: - stmt = stmt.where(TaskEntity.user_id == params['user_id']) - if 'id' in params: - stmt = stmt.where(TaskEntity.id == params['id']) - if 'op_id' in params: - stmt = stmt.where(TaskEntity.op_id == params['op_id']) - if 'types' in params: - stmt = stmt.where(TaskEntity.type.in_(params['types'])) - if 'status' in params: - stmt = stmt.where(TaskEntity.status == params['status']) - if 'created_time_order' in params: - if params['created_time_order'] == 'desc': - stmt = stmt.order_by(desc(TaskEntity.created_time)) - elif params['created_time_order'] == 'asc': - stmt = stmt.order_by(asc(TaskEntity.created_time)) - - # Execute the count part of the query separately - count_stmt = select(func.count()).select_from(stmt.subquery()) - total = (await session.execute(count_stmt)).scalar() - - # Apply pagination - stmt = stmt.offset((page_number-1)*page_size).limit(page_size) - - # Execute the main query + async with await DataBase.get_session() as session: + stmt = select(TaskEntity).where(TaskEntity.status == task_status) result = await session.execute(stmt) - task_list = result.scalars().all() - - return (total, task_list) + task_entity = result.scalars().all() + return task_entity except Exception as e: - logging.error(f"Failed to select tasks by page: {e}") - return (0, []) - - @staticmethod - async def select_by_id(id: uuid.UUID) -> TaskEntity: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.id == id) - result = await session.execute(stmt) - return result.scalar() - - @staticmethod - async def select_by_ids(ids: List[uuid.UUID]) -> List[TaskEntity]: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.id.in_(ids)) - result = await session.execute(stmt) - result = result.scalars().all() - return result - - @staticmethod - async def select_by_user_id(user_id: uuid.UUID) -> TaskEntity: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(and_(TaskEntity.user_id == user_id, - TaskEntity.status != TaskConstant.TASK_STATUS_DELETED)) - result = await session.execute(stmt) - return result.scalars().all() - - @staticmethod - async def select_by_op_id(op_id: uuid.UUID, method='one') -> TaskEntity: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where( - TaskEntity.op_id == op_id).order_by( - desc(TaskEntity.created_time)) - if method=='one': - stmt=stmt.limit(1) - result = await session.execute(stmt) - if method == 'one': - result = result.scalars().first() - else: - result = result.scalars().all() - return result - @staticmethod - async def select_latest_task_by_op_ids(op_ids: List[uuid.UUID]) -> List[TaskEntity]: - async with await PostgresDB.get_session() as session: - # 创建一个别名用于子查询 - task_alias = aliased(TaskEntity) - - # 构建子查询,为每个op_id分配一个行号 - subquery = ( - select( - task_alias, - func.row_number().over( - partition_by=task_alias.op_id, - order_by=desc(task_alias.created_time) - ).label('row_num') - ) - .where(task_alias.op_id.in_(op_ids)) - .subquery() - ) - - # 主查询选择row_num为1的记录,即每个op_id的最新任务 - stmt = ( - select(TaskEntity) - .join(subquery, TaskEntity.id == subquery.c.id) - .where(subquery.c.row_num == 1) - ) - - result = await session.execute(stmt) - return result.scalars().all() - - @staticmethod - async def delete_by_op_id(op_id: uuid.UUID) -> TaskEntity: - async with await PostgresDB.get_session() as session: - stmt = delete(TaskEntity).where( - TaskEntity.op_id == op_id) - await session.execute(stmt) - - @staticmethod - async def select_by_op_ids(op_ids: List[uuid.UUID]) -> List[TaskEntity]: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.op_id.in_(op_ids)) - result = await session.execute(stmt) - return result.scalars().all() + err = "获取任务失败" + logging.exception("[TaskManager] %s", err) + raise e @staticmethod - async def select_by_user_id_and_task_type( - user_id: uuid.UUID, task_type: Optional[str] = None) -> List[TaskEntity]: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.user_id == user_id) - if task_type: - stmt = stmt.where(TaskEntity.type == task_type) - stmt = stmt.order_by(TaskEntity.created_time.desc()) - result = await session.execute(stmt) - return result.scalars().all() - - @staticmethod - async def select_by_user_id_and_task_type_list( - user_id: uuid.UUID, task_type_list: Optional[List[str]] = []) -> List[TaskEntity]: - async with await PostgresDB.get_session() as session: - stmt = select(TaskEntity).where(TaskEntity.user_id == user_id) - stmt = stmt.where(TaskEntity.type.in_(task_type_list)) - stmt = stmt.where(TaskEntity.status != TaskConstant.TASK_STATUS_DELETED) - stmt = stmt.order_by(TaskEntity.created_time.desc()) - result = await session.execute(stmt) - return result.scalars().all() - - @staticmethod - async def delete_by_id(id: uuid.UUID): - async with await PostgresDB.get_session() as session: - # 构建删除语句 - stmt = ( - select(TaskEntity) - .where(TaskEntity.id == id) - ) - # 执行删除操作 - result = await session.execute(stmt) - instances = result.scalars().all() - if instances: - for instance in instances: - await session.delete(instance) - await session.commit() - - @staticmethod - async def delete_by_ids(ids: List[uuid.UUID]): - async with await PostgresDB.get_session() as session: - # 构建删除语句 - stmt = ( - select(TaskEntity) - .where(TaskEntity.id.in_(ids)) - ) - # 执行删除操作 - result = await session.execute(stmt) - instances = result.scalars().all() - if instances: - for instance in instances: - await session.delete(instance) - await session.commit() - - -class TaskStatusReportManager(): - - @staticmethod - async def insert(entity: TaskStatusReportEntity): - async with await PostgresDB.get_session() as session: - session.add(entity) - await session.commit() - return entity - - @staticmethod - async def del_by_task_id(task_id: uuid.UUID): - async with await PostgresDB.get_session() as session: - stmt = ( - select(TaskStatusReportEntity) - .where(TaskStatusReportEntity.task_id == task_id) - ) - result = await session.execute(stmt) - entities = result.scalars().all() - for entity in entities: - await session.delete(entity) - await session.commit() - - @staticmethod - async def select_by_task_id(task_id: uuid.UUID, limited: int = 10): - async with await PostgresDB.get_session() as session: - stmt = ( - select(TaskStatusReportEntity) - .where(TaskStatusReportEntity.task_id == task_id) - .order_by(desc(TaskStatusReportEntity.created_time)) - .limit(limited) - ) - result = await session.execute(stmt) - return result.scalars().all() - - @staticmethod - async def select_latest_report_by_task_ids(task_ids: List[uuid.UUID]) -> List[TaskStatusReportEntity]: - async with await PostgresDB.get_session() as session: - # 创建一个别名用于子查询 - report_alias = aliased(TaskStatusReportEntity) - - # 构建子查询,为每个task_id分配一个行号 - subquery = ( - select( - report_alias, - func.row_number().over( - partition_by=report_alias.task_id, - order_by=desc(report_alias.created_time) - ).label('row_num') - ) - .where(report_alias.task_id.in_(task_ids)) - .subquery() - ) - # 主查询选择row_num为1的记录,即每个op_id的最新任务 - stmt = ( - select(TaskStatusReportEntity) - .join(subquery, TaskStatusReportEntity.id == subquery.c.id) - .where(subquery.c.row_num == 1) - ) - - result = await session.execute(stmt) - return result.scalars().all() + async def update_task_by_id(task_id: uuid.UUID, task_dict: Dict) -> TaskEntity: + """根据任务ID更新任务""" + try: + async with await DataBase.get_session() as session: + stmt = update(TaskEntity).where(TaskEntity.id == task_id).values(**task_dict) + await session.execute(stmt) + await session.commit() + stmt = select(TaskEntity).where(TaskEntity.id == task_id) + result = await session.execute(stmt) + task_entity = result.scalars().first() + return task_entity + except Exception as e: + err = "更新任务失败" + logging.exception("[TaskManager] %s", err) + raise e diff --git a/data_chain/manager/task_queue_mamanger.py b/data_chain/manager/task_queue_mamanger.py new file mode 100644 index 0000000..4523e19 --- /dev/null +++ b/data_chain/manager/task_queue_mamanger.py @@ -0,0 +1,61 @@ + +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from sqlalchemy import select, delete, update, desc, asc, func, exists, or_, and_ +from sqlalchemy.orm import aliased +import uuid +from typing import Dict, List, Optional, Tuple +from data_chain.logger.logger import logger as logging +from data_chain.stores.database.database import DataBase, TaskEntity +from data_chain.stores.mongodb.mongodb import MongoDB, Task +from data_chain.entities.enum import TaskStatus + + +class TaskQueueManager(): + """任务队列管理类""" + + @staticmethod + async def add_task(task: Task): + try: + async with MongoDB.get_session() as session, await session.start_transaction(): + task_colletion = MongoDB.get_collection('witchiand_task') + await task_colletion.insert_one(task.model_dump(by_alias=True), session=session) + except Exception as e: + err = "添加任务到队列失败" + logging.exception("[TaskQueueManager] %s", err) + + @staticmethod + async def delete_task_by_id(task_id: uuid.UUID): + """根据任务ID删除任务""" + try: + async with MongoDB.get_session() as session, await session.start_transaction(): + task_colletion = MongoDB.get_collection('witchiand_task') + await task_colletion.delete_one({"_id": task_id}, session=session) + except Exception as e: + err = "删除任务失败" + logging.exception("[TaskQueueManager] %s", err) + raise e + + @staticmethod + async def get_oldest_tasks_by_status(status: TaskStatus) -> Task: + """根据任务状态获取最早的任务""" + try: + async with MongoDB.get_session() as session: + task_colletion = MongoDB.get_collection('witchiand_task') + task = await task_colletion.find_one({"status": status.value}, sort=[("created_time", 1)], session=session) + return Task(**task) if task else None + except Exception as e: + err = "获取最早的任务失败" + logging.exception("[TaskQueueManager] %s", err) + raise e + + @staticmethod + async def update_task_by_id(task_id: uuid.UUID, task: Task): + """根据任务ID更新任务""" + try: + async with MongoDB.get_session() as session, await session.start_transaction(): + task_colletion = MongoDB.get_collection('witchiand_task') + await task_colletion.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, session=session) + except Exception as e: + err = "更新任务失败" + logging.exception("[TaskQueueManager] %s", err) + raise e diff --git a/data_chain/manager/task_report_manager.py b/data_chain/manager/task_report_manager.py new file mode 100644 index 0000000..402748b --- /dev/null +++ b/data_chain/manager/task_report_manager.py @@ -0,0 +1,23 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +from sqlalchemy import select, delete, update, desc, asc, func, exists, or_, and_ +from sqlalchemy.orm import aliased +import uuid +from typing import Dict, List, Optional, Tuple +from data_chain.logger.logger import logger as logging +from data_chain.stores.database.database import DataBase, TaskReportEntity +from data_chain.entities.enum import TaskStatus + + +class TaskReportManager(): + @staticmethod + async def add_task_report(task_report_entity: TaskReportEntity) -> TaskReportEntity: + """添加任务报告""" + try: + async with await DataBase.get_session() as session: + session.add(task_report_entity) + await session.commit() + await session.refresh(task_report_entity) + return task_report_entity + except Exception as e: + err = "添加任务报告失败" + logging.exception("[TaskReportManager] %s", err) diff --git a/data_chain/manager/team_manager.py b/data_chain/manager/team_manager.py new file mode 100644 index 0000000..9efe572 --- /dev/null +++ b/data_chain/manager/team_manager.py @@ -0,0 +1,152 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from sqlalchemy import select, update, delete, and_, func +from typing import Dict +import uuid + +from data_chain.logger.logger import logger as logging +from data_chain.entities.request_data import ListTeamRequest +from data_chain.entities.enum import TeamStatus +from data_chain.stores.database.database import DataBase, TeamEntity, TeamUserEntity + + +class TeamManager: + + @staticmethod + async def add_team(team_entity: TeamEntity) -> TeamEntity: + """添加团队""" + try: + async with await DataBase.get_session() as session: + session.add(team_entity) + await session.commit() + await session.refresh(team_entity) + except Exception as e: + err = "添加团队失败" + logging.exception("[TeamManger] %s", err) + raise e + return team_entity + + @staticmethod + async def add_team_user(team_user_entity: TeamUserEntity) -> TeamUserEntity: + """添加团队成员""" + try: + async with await DataBase.get_session() as session: + session.add(team_user_entity) + await session.commit() + await session.refresh(team_user_entity) + except Exception as e: + err = "添加团队成员失败" + logging.exception("[TeamManger] %s", err) + return team_user_entity + + @staticmethod + async def list_team_myjoined_by_user_sub(user_sub: str, req: ListTeamRequest) -> list[TeamEntity]: + """列出我加入的团队,以及总数""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamEntity).join(TeamUserEntity, TeamEntity.id == TeamUserEntity.team_id).where( + and_(TeamUserEntity.user_id == user_sub, TeamEntity.status != TeamStatus.DELETED.value)) + if req.team_id: + stmt = stmt.where(TeamEntity.id == req.team_id) + if req.team_name: + stmt = stmt.where(TeamEntity.name.ilike(f"%{req.team_name}%")) + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = (await session.execute(count_stmt)).scalar() + stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.order_by(TeamEntity.created_time.desc()) + result = await session.execute(stmt) + team_entities = result.scalars().all() + return (total, team_entities) + except Exception as e: + err = "列出我加入的团队失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def list_team_mycreated_user_sub(user_sub: str, req: ListTeamRequest) -> list[TeamEntity]: + """列出我创建的团队""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamEntity).where(and_( + TeamEntity.author_id == user_sub, TeamEntity.status != TeamStatus.DELETED.value)) + if req.team_id: + stmt = stmt.where(TeamEntity.id == req.team_id) + if req.team_name: + stmt = stmt.where(TeamEntity.name.ilike(f"%{req.team_name}%")) + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = (await session.execute(count_stmt)).scalar() + stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.order_by(TeamEntity.created_time.desc()) + result = await session.execute(stmt) + team_entities = result.scalars().all() + return (total, team_entities) + except Exception as e: + err = "列出我创建的团队失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def list_pulic_team(req: ListTeamRequest) -> list[TeamEntity]: + """列出公开的团队""" + try: + async with await DataBase.get_session() as session: + stmt = select(TeamEntity).where(and_( + TeamEntity.status != TeamStatus.DELETED.value, TeamEntity.is_public == True)) + if req.team_id: + stmt = stmt.where(TeamEntity.id == req.team_id) + if req.team_name: + stmt = stmt.where(TeamEntity.name.ilike(f"%{req.team_name}%")) + count_stmt = select(func.count()).select_from(stmt.subquery()) + total = (await session.execute(count_stmt)).scalar() + stmt = stmt.limit(req.page_size).offset((req.page - 1) * req.page_size) + stmt = stmt.order_by(TeamEntity.created_time.desc()) + result = await session.execute(stmt) + team_entities = result.scalars().all() + return (total, team_entities) + except Exception as e: + err = "列出公开的团队失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def delete_team_by_id(team_id: uuid.UUID) -> uuid.UUID: + """删除团队""" + try: + async with await DataBase.get_session() as session: + stmt = delete(TeamEntity).where(TeamEntity.id == team_id) + await session.execute(stmt) + await session.commit() + except Exception as e: + err = "删除团队失败" + logging.exception("[TeamManager] %s", err) + raise e + return team_id + + @staticmethod + async def delete_teams_deleted() -> None: + """删除团队""" + try: + async with await DataBase.get_session() as session: + stmt = delete(TeamEntity).where(TeamEntity.status == TeamStatus.DELETED.value) + await session.execute(stmt) + await session.commit() + except Exception as e: + err = "删除团队失败" + logging.exception("[TeamManager] %s", err) + raise e + + @staticmethod + async def update_team_by_id(team_id: uuid.UUID, team_dict: Dict[str, str]) -> TeamEntity: + """更新团队""" + try: + async with await DataBase.get_session() as session: + stmt = update(TeamEntity).where(TeamEntity.id == team_id).values(**team_dict) + await session.execute(stmt) + await session.commit() + stmt = select(TeamEntity).where(TeamEntity.id == team_id) + result = await session.execute(stmt) + team_entity = result.scalars().first() + return team_entity + except Exception as e: + err = "更新团队失败" + logging.exception("[TeamManager] %s", err) + raise e diff --git a/data_chain/manager/user_manager.py b/data_chain/manager/user_manager.py index 90c9eb7..2277872 100644 --- a/data_chain/manager/user_manager.py +++ b/data_chain/manager/user_manager.py @@ -1,106 +1,22 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from sqlalchemy import select,delete +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +from sqlalchemy import select, delete from data_chain.logger.logger import logger as logging -from data_chain.stores.postgres.postgres import PostgresDB, User - - +from data_chain.entities.enum import UserStatus +from data_chain.stores.database.database import DataBase, UserEntity class UserManager: @staticmethod - async def add_user(name,email, account, passwd): - user_slice = User( - name=name, - email=email, - account=account, - passwd=passwd - ) + async def add_user(user_entity: UserEntity) -> bool: try: - async with await PostgresDB.get_session() as session: - session.add(user_slice) - await session.commit() - await session.refresh(user_slice) - except Exception as e: - logging.error(f"Add user failed due to error: {e}") - return None - return user_slice - - @staticmethod - async def del_user_by_user_id(user_id): - try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取并删除用户 - stmt = select(User).where(User.id == user_id) - result = await session.execute(stmt) - user_to_delete = result.scalars().first() - - if user_to_delete is not None: - delete_stmt = delete(User).where(User.id==user_id) - result = await session.execute(delete_stmt) - await session.commit() - except Exception as e: - logging.error(f"Delete user failed due to error: {e}") - return False - return True - - @staticmethod - async def update_user_by_user_id(user_id, tmp_dict: dict): - try: - async with await PostgresDB.get_session() as session: - # 使用执行SQL语句的方式获取用户 - stmt = select(User).where(User.id == user_id) - result = await session.execute(stmt) - user_to_update = result.scalars().first() - - if user_to_update is None: - raise ValueError(f"No user found with id {user_id}") - - # 更新用户属性 - for key, value in tmp_dict.items(): - if hasattr(user_to_update, key): - setattr(user_to_update, key, value) - else: - logging.error(f"Attribute {key} does not exist on User model") - + async with await DataBase.get_session() as session: + session.add(user_entity) await session.commit() + await session.refresh(user_entity) return True except Exception as e: - logging.error(f"Failed to update user: {e}") - return False - - @staticmethod - async def get_user_info_by_account(account): - try: - async with await PostgresDB.get_session() as session: - stmt = select(User).where(User.account == account) - result = await session.execute(stmt) - user = result.scalars().first() - return user - except Exception as e: - logging.error(f"Failed to get user info by account: {e}") - return None - @staticmethod - async def get_user_info_by_email(email): - try: - async with await PostgresDB.get_session() as session: - stmt = select(User).where(User.email == email) - result = await session.execute(stmt) - user = result.scalars().first() - return user - except Exception as e: - logging.error(f"Failed to get user info by account: {e}") - return None - @staticmethod - async def get_user_info_by_user_id(user_id): - result = None - try: - async with await PostgresDB.get_session() as session: - stmt = select(User).where(User.id == user_id) - result = await session.execute(stmt) - result = result.scalars().first() - except Exception as e: - logging.error(f"Get user failed due to error: {e}") - return result - + err = "用户添加失败" + logging.exception("[UserManger] %s", err) + return False -- Gitee