diff --git a/data_chain/apps/router/chunk.py b/data_chain/apps/router/chunk.py index 794afb0acf8e3676d690011bc94bde30b82cf54b..7cf85b47b7d96753863fdb002cd89b54171d6ef4 100644 --- a/data_chain/apps/router/chunk.py +++ b/data_chain/apps/router/chunk.py @@ -1,79 +1,29 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from typing import List -import tiktoken -from fastapi import APIRouter, Depends, status - - -from data_chain.models.service import ChunkDTO -from data_chain.models.api import Page, BaseResponse, ListChunkRequest, SwitchChunkRequest,GetChunkRequest -from data_chain.exceptions.err_code import ErrorCode -from data_chain.exceptions.exception import DocumentException -from data_chain.apps.service.chunk_service import _validate_chunk_belong_to_user, list_chunk, switch_chunk,get_similar_chunks,get_similar_full_text, get_keywords_from_chunk -from data_chain.apps.service.document_service import _validate_doucument_belong_to_user -from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user - -router = APIRouter(prefix='/chunk', tags=['Corpus']) - - -@router.post('/list', response_model=BaseResponse[Page[ChunkDTO]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def list(req: ListChunkRequest, user_id=Depends(get_user_id)): - try: - await _validate_doucument_belong_to_user(user_id, req.document_id) - params = dict(req) - chunk_list, total = await list_chunk(params, req.page_number, req.page_size) - chunk_page = Page(page_number=req.page_number, page_size=req.page_size, - total=total, - data_list=chunk_list) - return BaseResponse(data=chunk_page) - except Exception as e: - return BaseResponse(retcode=ErrorCode.CREATE_CHUNK_ERROR, data=str(e.args[0])) - - -@router.post('/switch', response_model=BaseResponse[str], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def switch(req: SwitchChunkRequest, user_id=Depends(get_user_id)): - try: - for id in 
req.ids: - await _validate_chunk_belong_to_user(user_id, id) - for id in req.ids: - await switch_chunk(id, req.enabled) - return BaseResponse(data='success') - except Exception as e: - return BaseResponse(retcode=ErrorCode.SWITCH_CHUNK_ERROR, data=str(e.args[0])) - -@router.post('/get', response_model=BaseResponse[List[str]]) -async def get(req:GetChunkRequest): - try: - content = req.content - kb_sn=req.kb_sn - topk=req.topk - retrieval_mode=req.retrieval_mode - enc = tiktoken.encoding_for_model("gpt-4") - str_len_keywords_len_ratio_pair_list=[(30,1),(60,0.75),(120,0.55),(240,0.35),(1000,0.1)] - content_len=len(enc.encode(content)) - ratio=0 - for str_len,keywords_len_ratio in str_len_keywords_len_ratio_pair_list: - if content_len<=str_len: - ratio=keywords_len_ratio - break - if ratio==0: - keywords_cnt=100 - else: - keywords_cnt=int(content_len*ratio) - if len(enc.encode(content)) > 100: - keywords=await get_keywords_from_chunk(content,keywords_cnt) - content='' - for keyword in keywords: - content+=keyword+' ' - if retrieval_mode=='chunk': - chunk_list=await get_similar_chunks(content=content,kb_id=kb_sn,topk=topk,devided_by_document_id=False) - elif retrieval_mode=='full_text': - chunk_list=await get_similar_full_text(content=content,kb_id=kb_sn,topk=topk) - else: - chunk_list=[] - return BaseResponse(data=chunk_list) - except Exception as e: - return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, data=str(e.args[0])) \ No newline at end of file +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+from data_chain.apps.service.session_service import get_user_sub, verify_user +from fastapi import APIRouter, Depends, Query, Body +from typing import Annotated +from uuid import UUID +from data_chain.entities.request_data import ( + ListChunkRequest, + UpdateChunkRequest +) + +from data_chain.entities.response_data import ( + ListChunkResponse, + UpdateChunkResponse +) + +router = APIRouter(prefix='/chunk', tags=['Chunk']) + + +@router.get('', response_model=ListChunkResponse, dependencies=[Depends(verify_user)]) +async def list_chunks_by_document_id( + user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[ListChunkRequest, Body()], +): + return ListChunkResponse() + + +@router.put('', response_model=UpdateChunkResponse, dependencies=[Depends(verify_user)]) +async def update_chunk_by_id(user_sub: Annotated[str, Depends(get_user_sub)], req: UpdateChunkRequest): + return UpdateChunkResponse() diff --git a/data_chain/apps/router/document.py b/data_chain/apps/router/document.py index 5d0e86f66b9a31807f33da77298edab4b1c9be85..25f3a5e49c75264cd02fed480066fd3d7cfbe350 100644 --- a/data_chain/apps/router/document.py +++ b/data_chain/apps/router/document.py @@ -1,209 +1,88 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import urllib -from typing import Dict, List -import uuid -from fastapi import HTTPException, status -from data_chain.models.service import DocumentDTO, TemporaryDocumentDTO -from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user -from data_chain.exceptions.err_code import ErrorCode -from data_chain.exceptions.exception import DocumentException -from data_chain.models.api import BaseResponse, Page -from data_chain.models.api import DeleteDocumentRequest, ListDocumentRequest, UpdateDocumentRequest, \ - RunDocumentRequest, SwitchDocumentRequest, ParserTemporaryDocumenRequest, GetTemporaryDocumentStatusRequest, \ - DeleteTemporaryDocumentRequest, RelatedTemporaryDocumenRequest -from data_chain.apps.service.knwoledge_base_service import _validate_knowledge_base_belong_to_user -from data_chain.apps.service.document_service import _validate_doucument_belong_to_user, delete_document, \ - generate_document_download_link, \ - list_documents_by_knowledgebase_id, run_document, submit_upload_document_task, switch_document, update_document, \ - get_file_name_and_extension, init_temporary_document_parse_task, delete_temporary_document, get_temporary_document_parse_status, \ - get_related_document - -from httpx import AsyncClient -from fastapi import Depends -from fastapi import APIRouter, File, UploadFile -from fastapi.responses import StreamingResponse + +from fastapi import APIRouter, Depends, Query, Body, File, UploadFile +from typing import Annotated +from uuid import UUID +from data_chain.entities.request_data import ( + ListDocumentRequest, + UpdateDocumentRequest +) + +from data_chain.entities.response_data import ( + ListDocumentResponse, + UploadDocumentResponse, + ParseDocumentResponse, + UpdateDocumentResponse, + DeleteDocumentResponse +) +from data_chain.apps.service.session_service import get_user_sub, verify_user router = APIRouter(prefix='/doc', tags=['Document']) -@router.post('/list', 
response_model=BaseResponse[Page[DocumentDTO]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def list(req: ListDocumentRequest, user_id=Depends(get_user_id)): - try: - await _validate_knowledge_base_belong_to_user(user_id, req.kb_id) - params = dict(req) - page_number = req.page_number - page_size = req.page_size - document_list_tuple = await list_documents_by_knowledgebase_id(params, page_number, page_size) - document_page = Page(page_number=req.page_number, page_size=req.page_size, - total=document_list_tuple[1], - data_list=document_list_tuple[0]) - return BaseResponse(data=document_page) - except Exception as e: - return BaseResponse(retcode=ErrorCode.LIST_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/update', response_model=BaseResponse[DocumentDTO], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def update(req: UpdateDocumentRequest, user_id=Depends(get_user_id)): - try: - await _validate_doucument_belong_to_user(user_id, req.id) - tmp_dict = dict(req) - document = await update_document(tmp_dict) - return BaseResponse(data=document) - except Exception as e: - return BaseResponse(retcode=ErrorCode.RENAME_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/run', response_model=BaseResponse[List[DocumentDTO]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def run(reqs: RunDocumentRequest, user_id=Depends(get_user_id)): - try: - run = reqs.run - ids = reqs.ids - document_dto_list = [] - for req_id in ids: - await _validate_doucument_belong_to_user(user_id, req_id) - document = await run_document(dict(id=req_id, run=run)) - document_dto_list.append(document) - return BaseResponse(data=document_dto_list) - except Exception as e: - return BaseResponse(retcode=ErrorCode.RUN_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/switch', response_model=BaseResponse[DocumentDTO], - dependencies=[Depends(verify_user), - 
Depends(verify_csrf_token)]) -async def switch(req: SwitchDocumentRequest, user_id=Depends(get_user_id)): - try: - await _validate_doucument_belong_to_user(user_id, req.id) - document = await switch_document(req.id, req.enabled) - return BaseResponse(data=document) - except Exception as e: - return BaseResponse(retcode=ErrorCode.SWITCH_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/rm', response_model=BaseResponse[int], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def rm(req: DeleteDocumentRequest, user_id=Depends(get_user_id)): - try: - for id in req.ids: - await _validate_doucument_belong_to_user(user_id, id) - deleted_cnt = await delete_document(req.ids) - return BaseResponse(data=deleted_cnt) - except Exception as e: - return BaseResponse(retcode=ErrorCode.DELETE_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/upload', response_model=BaseResponse[List[str]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def upload(kb_id: str, files: List[UploadFile] = File(...), user_id=Depends(get_user_id)): - MAX_FILES = 128 - MAX_SIZE = 50 * 1024 * 1024 - MAX_TOTAL_SIZE = 500 * 1024 * 1024 - if len(files) > MAX_FILES: - raise HTTPException(status_code=400, detail="Too many files. 
Maximum allowed is 50.") - - total_size = 0 - for file in files: - if file.size > MAX_SIZE: - raise HTTPException(status_code=400, detail="File size exceeds the limit (25MB).") - total_size += file.size - - if total_size > MAX_TOTAL_SIZE: - raise HTTPException(status_code=400, detail="Total size of all files exceeds the limit (500MB).") - try: - await _validate_knowledge_base_belong_to_user(user_id, kb_id) - res = await submit_upload_document_task(user_id, kb_id, files) - return BaseResponse(data=res) - except Exception as e: - return BaseResponse(retcode=ErrorCode.UPLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.get('/download', response_model=BaseResponse[Dict], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def download(id: uuid.UUID, user_id=Depends(get_user_id)): - try: - await _validate_doucument_belong_to_user(user_id, id) - document_link_url = await generate_document_download_link(id) - document_name, extension = await get_file_name_and_extension(id) - async with AsyncClient() as async_client: - response = await async_client.get(document_link_url) - if response.status_code == 200: - content_disposition = f"attachment; filename={urllib.parse.quote(document_name.encode('utf-8'))}" - - async def stream_generator(): - async for chunk in response.aiter_bytes(chunk_size=8192): - yield chunk - - return StreamingResponse(stream_generator(), headers={ - "Content-Disposition": content_disposition, - "Content-Length": str(response.headers.get('content-length')) - }, media_type="application/" + extension) - else: - return BaseResponse( - retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg="Failed to retrieve the file.", data=None) - except Exception as e: - return BaseResponse(retcode=ErrorCode.DOWNLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/temporary/related', response_model=BaseResponse[List[uuid.UUID]]) -async def related_temporary_doc(req: RelatedTemporaryDocumenRequest): - try: - 
results = await get_related_document(req.content, req.top_k, req.document_ids, req.kb_sn) - return BaseResponse(data=results) - except Exception as e: - return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) - - -@router.post('/temporary/parser', response_model=BaseResponse[List[uuid.UUID]]) -async def parser_temporary_doc(req: ParserTemporaryDocumenRequest): - try: - temporary_document_list = [] - for i in range(len(req.document_list)): - tmp_dict = dict(req.document_list[i]) - if tmp_dict['type'] == 'application/pdf': - tmp_dict['type'] = '.pdf' - elif tmp_dict['type'] == 'text/html': - tmp_dict['type'] = '.html' - elif tmp_dict['type'] == 'text/plain': - tmp_dict['type'] = '.txt' - elif tmp_dict['type'] == 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': - tmp_dict['type'] = '.xlsx' - elif tmp_dict['type'] == 'text/x-markdown': - tmp_dict['type'] = '.md' - elif tmp_dict['type'] == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': - tmp_dict['type'] = '.docx' - elif tmp_dict['type'] == 'application/msword': - tmp_dict['type'] = '.doc' - elif tmp_dict['type'] == 'application/vnd.openxmlformats-officedocument.presentationml.presentation': - tmp_dict['type'] = '.pptx' - temporary_document_list.append(tmp_dict) - result = await init_temporary_document_parse_task(temporary_document_list) - return BaseResponse(data=result) - except Exception as e: - return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) - - -@router.post('/temporary/status', response_model=BaseResponse[List[TemporaryDocumentDTO]]) -async def get_temporary_doc_parse_status(req: GetTemporaryDocumentStatusRequest): - try: - result = await get_temporary_document_parse_status(req.ids) - return BaseResponse(data=result) - except Exception as e: - return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) - - -@router.post('/temporary/delete', 
response_model=BaseResponse[List[uuid.UUID]]) -async def delete_temporary_doc(req: DeleteTemporaryDocumentRequest): - try: - result = await delete_temporary_document(req.ids) - return BaseResponse(data=result) - except Exception as e: - return BaseResponse(retcode=status.HTTP_500_INTERNAL_SERVER_ERROR, retmsg=str(e), data=None) +@router.get('', response_model=ListDocumentResponse, dependencies=[Depends(verify_user)]) +async def list_doc( + user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[ListDocumentRequest, Body()] +): + return ListDocumentResponse() + + +@router.get('/download', dependencies=[Depends(verify_user)]) +async def download_doc_by_id( + user_sub: Annotated[str, Depends(get_user_sub)], + doc_id: Annotated[UUID, Query(alias="docId")]): + # try: + # await _validate_doucument_belong_to_user(user_id, id) + # document_link_url = await generate_document_download_link(id) + # document_name, extension = await get_file_name_and_extension(id) + # async with AsyncClient() as async_client: + # response = await async_client.get(document_link_url) + # if response.status_code == 200: + # content_disposition = f"attachment; filename={urllib.parse.quote(document_name.encode('utf-8'))}" + + # async def stream_generator(): + # async for chunk in response.aiter_bytes(chunk_size=8192): + # yield chunk + + # return StreamingResponse(stream_generator(), headers={ + # "Content-Disposition": content_disposition, + # "Content-Length": str(response.headers.get('content-length')) + # }, media_type="application/" + extension) + # else: + # return BaseResponse( + # retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg="Failed to retrieve the file.", data=None) + # except Exception as e: + # return BaseResponse(retcode=ErrorCode.DOWNLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) + pass + + +@router.post('', response_model=UploadDocumentResponse, dependencies=[Depends(verify_user)]) +async def upload_docs( + user_sub: Annotated[str, Depends(get_user_sub)], + 
kb_id: Annotated[UUID, Query(alias="kbId")],
+    docs: list[UploadFile] = File(...)):
+    return UploadDocumentResponse()
+
+
+@router.post('/parse', response_model=ParseDocumentResponse, dependencies=[Depends(verify_user)])
+async def parse_document_by_doc_ids(
+    user_sub: Annotated[str, Depends(get_user_sub)],
+    doc_ids: Annotated[list[UUID], Query(alias="docIds")],
+    parse: Annotated[bool, Query()]):
+    return ParseDocumentResponse()
+
+
+@router.put('', response_model=UpdateDocumentResponse, dependencies=[Depends(verify_user)])
+async def update_doc_by_doc_id(
+    user_sub: Annotated[str, Depends(get_user_sub)],
+    doc_id: Annotated[UUID, Query(alias="docId")],
+    req: Annotated[UpdateDocumentRequest, Body()]):
+    return UpdateDocumentResponse()
+
+
+@router.delete('', response_model=DeleteDocumentResponse, dependencies=[Depends(verify_user)])
+async def delete_docs_by_ids(
+    user_sub: Annotated[str, Depends(get_user_sub)],
+    doc_ids: Annotated[list[UUID], Query(alias="docIds")]):
+    return DeleteDocumentResponse() diff --git a/data_chain/apps/router/knowledge_base.py b/data_chain/apps/router/knowledge_base.py index b04f1613d45ba2546b42ec0f2c21257c26aecd05..80f2d9e674f6520040edb6cf181b6a1cda3b61dd 100644 --- a/data_chain/apps/router/knowledge_base.py +++ b/data_chain/apps/router/knowledge_base.py @@ -1,325 +1,111 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-import urllib -import uuid -import time -from typing import List -import uuid -from httpx import AsyncClient -from fastapi import APIRouter, File, UploadFile, status, HTTPException -from fastapi import Depends -from fastapi.responses import StreamingResponse, HTMLResponse, Response -from data_chain.logger.logger import logger as logging -from data_chain.apps.service.user_service import verify_csrf_token, get_user_id, verify_user -from data_chain.exceptions.err_code import ErrorCode -from data_chain.exceptions.exception import KnowledgeBaseException -from data_chain.models.api import Page, BaseResponse, ExportKnowledgeBaseRequest, \ - CreateKnowledgeBaseRequest, DeleteKnowledgeBaseRequest, ListKnowledgeBaseRequest, StopTaskRequest, \ - UpdateKnowledgeBaseRequest, RmoveTaskRequest, ListTaskRequest, QueryRequest -from data_chain.apps.service.knwoledge_base_service import _validate_knowledge_base_belong_to_user, \ - create_knowledge_base, list_knowledge_base, rm_knowledge_base, generate_knowledge_base_download_link, submit_import_knowledge_base_task, \ - update_knowledge_base, list_knowledge_base_task, stop_knowledge_base_task, submit_export_knowledge_base_task, rm_knowledge_base_task, rm_all_knowledge_base_task -from data_chain.apps.service.model_service import get_model_by_kb_id -from data_chain.models.constant import KnowledgeLanguageEnum, TaskConstant -from data_chain.models.service import KnowledgeBaseDTO -from data_chain.apps.service.task_service import _validate_task_belong_to_user -from data_chain.apps.service.llm_service import question_rewrite, get_llm_answer, filter_stopwords -from data_chain.config.config import config -from data_chain.apps.service.chunk_service import get_similar_chunks, split_chunk -router = APIRouter(prefix='/kb', tags=['Knowledge Base']) - - -@router.post('/create', response_model=BaseResponse[KnowledgeBaseDTO], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def create(req: CreateKnowledgeBaseRequest, 
user_id=Depends(get_user_id)): - try: - tmp_dict = dict(req) - tmp_dict['user_id'] = user_id - knowledge_base = await create_knowledge_base(tmp_dict) - return BaseResponse(data=knowledge_base) - except Exception as e: - logging.error(f"Create knowledge base failed due to: {e}") - return BaseResponse(retcode=ErrorCode.CREATE_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/update', response_model=BaseResponse[KnowledgeBaseDTO], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def update(req: UpdateKnowledgeBaseRequest, user_id=Depends(get_user_id)): - try: - update_dict = dict(req) - update_dict['user_id'] = user_id - knowledge_base = await update_knowledge_base(update_dict) - return BaseResponse(data=knowledge_base) - except Exception as e: - logging.error(f"Update knowledge base failed due to: {e}") - return BaseResponse(retcode=ErrorCode.UPDATE_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/list', response_model=BaseResponse[Page[KnowledgeBaseDTO]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def list(req: ListKnowledgeBaseRequest, user_id=Depends(get_user_id)): - try: - params = dict(req) - params['user_id'] = user_id - page_number = req.page_number - page_size = req.page_size - knowledge_base_list_tuple = await list_knowledge_base(params, page_number, page_size) - knowledge_base_page = Page(page_number=req.page_number, page_size=req.page_size, - total=knowledge_base_list_tuple[1], - data_list=knowledge_base_list_tuple[0]) - return BaseResponse(data=knowledge_base_page) - except Exception as e: - logging.error(f"List knowledge base failed due to: {e}") - return BaseResponse(retcode=ErrorCode.LIST_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/rm', response_model=BaseResponse[bool], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def rm(req: DeleteKnowledgeBaseRequest, 
user_id=Depends(get_user_id)): - try: - await _validate_knowledge_base_belong_to_user(user_id, req.id) - res = await rm_knowledge_base(req.id) - return BaseResponse(data=res) - except Exception as e: - logging.error(f"Rmove knowledge base failed due to: {e}") - return BaseResponse(retcode=ErrorCode.DELETE_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) - - -@router.post('/import', response_model=BaseResponse[List[str]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def import_(files: List[UploadFile] = File(...), user_id=Depends(get_user_id)): - try: - res = await submit_import_knowledge_base_task(user_id, files) - return BaseResponse(data=res) - except Exception as e: - logging.error(f"Import knowledge base failed due to: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) - - -@router.post('/export', response_model=BaseResponse[str], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def export(req: ExportKnowledgeBaseRequest, user_id=Depends(get_user_id)): - try: - await _validate_knowledge_base_belong_to_user(user_id, req.id) - res = await submit_export_knowledge_base_task(user_id, req.id) - return BaseResponse(data=res) - except Exception as e: - logging.error(f"Export knowledge base failed due to: {e}") - return BaseResponse(retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg=str(e.args[0]), data=None) +from fastapi import APIRouter, Depends, Query, Body, File, UploadFile +from typing import Annotated +from uuid import UUID +from data_chain.entities.request_data import ( + ListKnowledgeBaseRequest, + CreateKnowledgeBaseRequest, + UpdateKnowledgeBaseRequest, +) + +from data_chain.entities.response_data import ( + ListAllKnowledgeBaseResponse, + ListKnowledgeBaseResponse, + ListDocumentTypesResponse, + CreateKnowledgeBaseResponse, + ImportKnowledgeBaseResponse, + ExportKnowledgeBaseResponse, + UpdateKnowledgeBaseResponse, + DeleteKnowledgeBaseResponse, +) 
+from data_chain.apps.service.session_service import get_user_sub, verify_user - -@router.get('/download', dependencies=[Depends(verify_user), Depends(verify_csrf_token)]) -async def download(task_id: uuid.UUID, user_id=Depends(get_user_id)): - try: - await _validate_task_belong_to_user(user_id, task_id) - zip_download_url = await generate_knowledge_base_download_link(task_id) - if not zip_download_url: - return BaseResponse( - retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, - retmsg="zip download url is empty", - data=None - ) - async with AsyncClient() as async_client: - response = await async_client.get(zip_download_url) - if response.status_code == 200: - # 保持 response 对象打开直到所有数据都被发送 - zip_name = f"{task_id}.zip" - content_disposition = f"attachment; filename={urllib.parse.quote(zip_name.encode('utf-8'))}" - content_length = response.headers.get('content-length') - - # 定义一个协程函数来生成数据流 - async def stream_generator(): - async for chunk in response.aiter_bytes(chunk_size=8192): - yield chunk - - return StreamingResponse( - stream_generator(), - headers={ - "Content-Disposition": content_disposition, - "Content-Length": str(content_length) if content_length else None - }, - media_type="application/zip" - ) - else: - return BaseResponse( - retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, - retmsg="Failed to retrieve the file.", - data=None - ) - except Exception as e: - logging.error(f"Download knowledge base zip failed due to: {e}") - return BaseResponse( - retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, - retmsg=str(e.args[0]), - data=None - ) - - -@router.get('/language', response_model=BaseResponse[List[str]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -def language(): - return BaseResponse(data=KnowledgeLanguageEnum.get_all_values()) - - -@router.post('/task/list', response_model=BaseResponse[Page[KnowledgeBaseDTO]], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def list_kb_task(req: ListTaskRequest, 
user_id=Depends(get_user_id)): - try: - params = dict(req) - params['user_id'] = user_id - if 'types' not in params.keys(): - params['types'] = [TaskConstant.EXPORT_KNOWLEDGE_BASE, TaskConstant.IMPORT_KNOWLEDGE_BASE] - total, knowledge_dto_list = await list_knowledge_base_task(req.page_number, req.page_size, params) - knowledge_base_page = Page(page_number=req.page_number, page_size=req.page_size, - total=total, - data_list=knowledge_dto_list) - return BaseResponse(data=knowledge_base_page) - except Exception as e: - logging.error(f"List knowledge base task error due to: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) - - -@router.post('/task/rm', response_model=BaseResponse[bool], - dependencies=[Depends(verify_user), - Depends(verify_csrf_token)]) -async def rm_kb_task(req: RmoveTaskRequest, user_id=Depends(get_user_id)): - try: - if req.task_id is not None: - await _validate_task_belong_to_user(user_id, req.task_id) - res = await rm_knowledge_base_task(req.task_id) - else: - if req.types is None: - types = [TaskConstant.EXPORT_KNOWLEDGE_BASE, TaskConstant.IMPORT_KNOWLEDGE_BASE] - else: - types = req.types - res = await rm_all_knowledge_base_task(user_id, types) - return BaseResponse(data=res) - except Exception as e: - logging.error(f"Remove knowledge base task failed due to: {e}") - return BaseResponse(retcode=ErrorCode.STOP_KNOWLEDGE_BASE_TASK_ERROR, retmsg=e.args[0], data=None) - - -@router.post('/get_stream_answer', response_class=HTMLResponse) -async def get_stream_answer(req: QueryRequest, response: Response): - model_dto = None - if req.kb_sn is not None: - model_dto = await get_model_by_kb_id(req.kb_sn) - if model_dto is None: - if len(config['MODELS']) > 0: - tokens_upper = config['MODELS'][0]['MAX_TOKENS'] - else: - logging.error("Can not find model config locally") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Can not find model config locally") - else: - tokens_upper = 
model_dto.max_tokens - try: - question = await question_rewrite(req.history, req.question, model_dto) - max_tokens = tokens_upper//3*2 - bac_info = '' - document_chunk_list = await get_similar_chunks(content=question, kb_id=req.kb_sn, temporary_document_ids=req.document_ids, max_tokens=2*tokens_upper, topk=req.top_k) - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - bac_info += '文档名称:'+document_name+':\n\n' - for j in range(len(chunk_list)): - bac_info += '段落'+str(j)+':\n\n' - bac_info += chunk_list[j]+'\n\n' - bac_info = split_chunk(bac_info) - if len(bac_info) > max_tokens: - bac_info = '' - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - bac_info += '文档名称:'+document_name+':\n\n' - for j in range(len(chunk_list)): - bac_info += '段落'+str(j)+':\n\n' - bac_info += ''.join(filter_stopwords(chunk_list[j]))+'\n\n' - bac_info = split_chunk(bac_info) - bac_info = bac_info[:max_tokens] - bac_info = ''.join(bac_info) - except Exception as e: - bac_info = '' - logging.error(f"get bac info failed due to: {e}") - try: - response.headers["Content-Type"] = "text/event-stream" - res = await get_llm_answer(req.history, bac_info, req.question, is_stream=True, model_dto=model_dto) - return StreamingResponse( - res, - status_code=status.HTTP_200_OK, - headers=response.headers - ) - except Exception as e: - logging.error(f"Get stream answer failed due to: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) +router = APIRouter(prefix='/kb', tags=['Knowledge Base']) -@router.post('/get_answer', response_model=BaseResponse[dict]) -async def get_answer(req: QueryRequest): - model_dto = None - if req.kb_sn is not None: - model_dto = await get_model_by_kb_id(req.kb_sn) - if model_dto is None: - if len(config['MODELS']) > 0: - tokens_upper = 
config['MODELS'][0]['MAX_TOKENS'] - else: - logging.error("Can not find model config locally") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Can not find model config locally") - else: - tokens_upper = model_dto.max_tokens - try: - question = await question_rewrite(req.history, req.question, model_dto) - max_tokens = tokens_upper//3*2 - bac_info = '' - t_cost_dict, document_chunk_list = await get_similar_chunks(content=question, kb_id=req.kb_sn, temporary_document_ids=req.document_ids, max_tokens=2*tokens_upper, topk=req.top_k, return_t_cost=True) - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - bac_info += '文档名称:'+document_name+':\n\n' - for j in range(len(chunk_list)): - bac_info += '段落'+str(j)+':\n\n' - bac_info += chunk_list[j]+'\n\n' - bac_info = split_chunk(bac_info) - if len(bac_info) > max_tokens: - bac_info = '' - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - bac_info += '文档名称:'+document_name+':\n\n' - for j in range(len(chunk_list)): - bac_info += '段落'+str(j)+':\n\n' - bac_info += ''.join(filter_stopwords(chunk_list[j]))+'\n\n' - bac_info = split_chunk(bac_info) - bac_info = bac_info[:max_tokens] - bac_info = ''.join(bac_info) - except Exception as e: - bac_info = '' - logging.error(f"get bac info failed due to: {e}") - try: - st = time.time() - answer = await get_llm_answer(req.history, bac_info, req.question, is_stream=False, model_dto=model_dto) - t_cost_dict['llm_answer'] = time.time()-st - tmp_dict = { - 'answer': answer, - 'time_cost': t_cost_dict - } - if req.fetch_source: - tmp_dict['source'] = [] - for i in range(len(document_chunk_list)): - document_name = document_chunk_list[i]['document_name'] - chunk_list = document_chunk_list[i]['chunk_list'] - for j in range(len(chunk_list)): - 
tmp_dict['source'].append({'document_name': document_name, 'chunk': chunk_list[j]}) - return BaseResponse(data=tmp_dict) - except Exception as e: - logging.error(f"Get stream answer failed due to: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) +@router.get('/', response_model=ListAllKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def list_kb_by_user_sub( + user_sub: Annotated[str, Depends(get_user_sub)] +): + return ListAllKnowledgeBaseResponse() + + +@router.get('/team', response_model=ListKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def list_kb_by_team_id( + user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[ListKnowledgeBaseRequest, Body()] +): + return ListKnowledgeBaseResponse() + + +@router.get('/doc_type', response_model=ListDocumentTypesResponse, dependencies=[Depends(verify_user)]) +async def list_doc_types_by_team_id( + user_sub: Annotated[str, Depends(get_user_sub)], + team_id: Annotated[UUID, Query(alias="teamId")], +): + return ListDocumentTypesResponse() + + +@router.get('/download', dependencies=[Depends(verify_user)]) +async def download_kb_by_task_id( + user_sub: Annotated[str, Depends(get_user_sub)], + task_id: Annotated[UUID, Query(alias="taskId")]): + # try: + # await _validate_doucument_belong_to_user(user_id, id) + # document_link_url = await generate_document_download_link(id) + # document_name, extension = await get_file_name_and_extension(id) + # async with AsyncClient() as async_client: + # response = await async_client.get(document_link_url) + # if response.status_code == 200: + # content_disposition = f"attachment; filename={urllib.parse.quote(document_name.encode('utf-8'))}" + + # async def stream_generator(): + # async for chunk in response.aiter_bytes(chunk_size=8192): + # yield chunk + + # return StreamingResponse(stream_generator(), headers={ + # "Content-Disposition": content_disposition, + # "Content-Length": 
str(response.headers.get('content-length')) + # }, media_type="application/" + extension) + # else: + # return BaseResponse( + # retcode=ErrorCode.EXPORT_KNOWLEDGE_BASE_ERROR, retmsg="Failed to retrieve the file.", data=None) + # except Exception as e: + # return BaseResponse(retcode=ErrorCode.DOWNLOAD_DOCUMENT_ERROR, retmsg=str(e.args[0]), data=None) + pass + + +@router.post('', response_model=CreateKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def create_kb(user_sub: Annotated[str, Depends(get_user_sub)], + req: Annotated[CreateKnowledgeBaseRequest, Body()]): + return CreateKnowledgeBaseResponse() + + +@router.post('/import', response_model=ImportKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def import_kb(user_sub: Annotated[str, Depends(get_user_sub)], + team_id: Annotated[UUID, Query(alias="teamId")], + kb_packages: list[UploadFile] = File(...)): + return ImportKnowledgeBaseResponse() + + +@router.post('/export', response_model=ExportKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def export_kb_by_kb_ids( + user_sub: Annotated[str, Depends(get_user_sub)], + kb_ids: Annotated[list[UUID], Query(alias="kbIds")]): + return ExportKnowledgeBaseResponse() + + +@router.put('', response_model=UpdateKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def update_kb_by_kb_id( + user_sub: Annotated[str, Depends(get_user_sub)], + kb_id: Annotated[UUID, Query(alias="kbId")], + req: Annotated[UpdateKnowledgeBaseRequest, Body()]): + return UpdateKnowledgeBaseResponse() + + +@router.delete('', response_model=DeleteKnowledgeBaseResponse, dependencies=[Depends(verify_user)]) +async def delete_kb_by_kb_ids( + user_sub: Annotated[str, Depends(get_user_sub)], + kb_ids: Annotated[list[UUID], Query(alias="kbIds")]): + return DeleteKnowledgeBaseResponse()