From 73cefec4ea2504f40a6a0f8180d11a10569b2f1f Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 16 May 2025 15:38:31 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=8A=E4=B8=8B=E6=96=87?= =?UTF-8?q?=E8=A1=A5=E5=85=A8=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/service/chunk_service.py | 2 +- data_chain/manager/chunk_manager.py | 4 ++-- data_chain/rag/base_searcher.py | 11 ++++------- requirements.txt | 1 + 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index 6098afe..49cf06d 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -88,7 +88,6 @@ class ChunkService: if req.is_rerank: chunk_entities = await BaseSearcher.rerank(chunk_entities, req.query) chunk_entities = chunk_entities[:req.top_k] - chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities] if req.is_related_surrounding: # 关联上下文 @@ -119,6 +118,7 @@ class ChunkService: for chunk in doc_chunk.chunks: if req.is_compress: chunk.text = TokenTool.compress_tokens(chunk.text) + search_chunk_msg.doc_chunks.append(doc_chunk) else: for chunk_entity in chunk_entities: chunk = await Convertor.convert_chunk_entity_to_chunk(chunk_entity) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 13ef02b..b0b3654 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -185,7 +185,7 @@ class ChunkManager(): .order_by(ChunkEntity.text_vector.cosine_distance(vector).desc()) .limit(top_k) ) - if doc_ids: + if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type: stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) @@ -233,7 +233,7 @@ class ChunkManager(): ) .limit(top_k) ) - if doc_ids: + if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type: stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) diff --git a/data_chain/rag/base_searcher.py b/data_chain/rag/base_searcher.py index 7a33f3d..72e65a9 100644 --- a/data_chain/rag/base_searcher.py +++ b/data_chain/rag/base_searcher.py @@ -67,6 +67,8 @@ class BaseSearcher: """ chunk_entities = await ChunkManager.fetch_surrounding_chunk_by_doc_id_and_global_offset(chunk_entity.doc_id, chunk_entity.global_offset, 50, banned_ids) chunk_entity_dict = {} + lower = chunk_entity.global_offset-1 + upper = chunk_entity.global_offset+1 min_offset = chunk_entity.global_offset max_offset = chunk_entity.global_offset for chunk_entity in chunk_entities: @@ -75,8 +77,6 @@ class BaseSearcher: if chunk_entity.global_offset > max_offset: max_offset = chunk_entity.global_offset chunk_entity_dict[chunk_entity.global_offset] = chunk_entity - lower = chunk_entity.global_offset-1 - upper = chunk_entity.global_offset+1 related_chunk_entities = [] tokens_sub = 0 tokens_sum = 0 @@ -116,7 +116,7 @@ class BaseSearcher: lower -= 1 else: if chunk_entity_dict.get(upper) is not None: - tokens_sub += chunk_entity_dict[upper].tokens + tokens_sub -= chunk_entity_dict[upper].tokens related_chunk_entities.append(chunk_entity_dict[upper]) tokens_sum += chunk_entity_dict[upper].tokens upper += 1 @@ -149,10 +149,7 @@ class BaseSearcher: for chunk_entity in chunk_entities: if chunk_entity.doc_id not in doc_chunk_dict: doc_chunk_dict[chunk_entity.doc_id] = DocChunk( - doc_id=chunk_entity.doc_id, doc_name=chunk_entity.doc_name, chunks=[]) + docId=chunk_entity.doc_id, docName=chunk_entity.doc_name, chunks=[]) chunk = await Convertor.convert_chunk_entity_to_chunk(chunk_entity) - if chunk_entity.doc_id not in doc_chunk_dict: - doc_chunk_dict[chunk_entity.doc_id] = DocChunk( - doc_id=chunk_entity.doc_id, doc_name=chunk_entity.doc_name, chunks=[]) doc_chunk_dict[chunk_entity.doc_id].chunks.append(chunk) return list(doc_chunk_dict.values()) diff --git a/requirements.txt b/requirements.txt index e01b146..4f0bb33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,6 +33,7 @@ python-multipart==0.0.9 python-pptx==0.6.23 pytz==2020.5 pyyaml==6.0.1 +pymongo==4.12.1 redis==5.0.3 requests==2.32.2 scikit-learn==1.5.0 -- Gitee