From b90f5070224e38d75bf1f4b911ff89efca72e77a Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 3 Jun 2025 18:56:20 +0800 Subject: [PATCH 01/15] =?UTF-8?q?=E5=A2=9E=E5=8A=A0query=E6=94=B9=E5=86=99?= =?UTF-8?q?=E7=9A=84=E6=A3=80=E7=B4=A2=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/app.py | 3 +- data_chain/common/prompt.yaml | 18 ++++++ data_chain/entities/enum.py | 1 + data_chain/rag/query_extend_searcher.py | 73 +++++++++++++++++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 data_chain/rag/query_extend_searcher.py diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index ed4821a..b8071dd 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -53,7 +53,8 @@ from data_chain.rag import ( keyword_and_vector_searcher, doc2chunk_searcher, doc2chunk_bfs_searcher, - enhanced_by_llm_searcher + enhanced_by_llm_searcher, + query_extend_searcher ) from data_chain.stores.database.database import ( DataBase, diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml index 9ce98d6..29e9c85 100644 --- a/data_chain/common/prompt.yaml +++ b/data_chain/common/prompt.yaml @@ -351,4 +351,22 @@ CHUNK_QUERY_MATCH_PROMPT: '你是一个文本分析专家,你的任务是根 下面是给出的片段和问题: 片段:{chunk} 问题:{question} + ' + QUERY_EXTEND_PROMPT: '你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题 + 注意: + #01 扩写的问题的内容必须来源于原问题中的内容 + #02 扩写的问题长度不超过50个字 + #03 可以通过近义词替换 问题内词序交换 修改英文大小写等方式来改写问题 + #04 请仅输出扩写的问题列表,不要输出其他内容 + 例子: + 输入:openEuler是什么操作系统? + 输出: + [ + "openEuler是一个什么样的操作系统?", + "openEuler操作系统的特点是什么?", + "openEuler操作系统有哪些功能?", + "openEuler操作系统的优势是什么?" + ] + 下面是给出的问题: + {question} ' \ No newline at end of file diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py index d59ce12..70e488d 100644 --- a/data_chain/entities/enum.py +++ b/data_chain/entities/enum.py @@ -160,6 +160,7 @@ class SearchMethod(str, Enum): DOC2CHUNK = "doc2chunk" DOC2CHUNK_BFS = "doc2chunk_bfs" ENHANCED_BY_LLM = "enhanced_by_llm" + QUERY_EXTEND = "query_extend" class TaskType(str, Enum): diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py new file mode 100644 index 0000000..0d8f0ff --- /dev/null +++ b/data_chain/rag/query_extend_searcher.py @@ -0,0 +1,73 @@ +import asyncio +import uuid +import yaml +from pydantic import BaseModel, Field +import random +import json +from data_chain.logger.logger import logger as logging +from data_chain.stores.database.database import ChunkEntity +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.rag.base_searcher import BaseSearcher +from data_chain.embedding.embedding import Embedding +from data_chain.entities.enum import SearchMethod +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.llm.llm import LLM +from data_chain.config.config import config + + +class QueryExtendSearcher(BaseSearcher): + """ + 基于查询扩展的搜索 + """ + name = SearchMethod.QUERY_EXTEND.value + + @staticmethod + async def search( + query: str, kb_id: uuid.UUID, top_k: int = 5, doc_ids: list[uuid.UUID] = None, + banned_ids: list[uuid.UUID] = [] + ) -> list[ChunkEntity]: + """ + 向量检索 + :param query: 查询 + :param top_k: 返回的结果数量 + :return: 检索结果 + """ + with open('./data_chain/common/prompt.yaml', 'r', encoding='utf-8') as f: + prompt_dict = yaml.safe_load(f) + prompt_template = prompt_dict['QUERY_EXTEND_PROMPT'] + chunk_entities = [] + llm = LLM( + openai_api_key=config['OPENAI_API_KEY'], + 
openai_api_base=config['OPENAI_API_BASE'], + model_name=config['MODEL_NAME'], + max_tokens=config['MAX_TOKENS'], + ) + sys_call = prompt_template.format(k=2*top_k, query=query) + user_call = "请输出扩写的问题列表" + queries = await llm.nostream([], sys_call, user_call) + try: + queries = json.loads(queries) + queries += [query] + except Exception as e: + logging.error(f"[QueryExtendSearcher] JSON解析失败,error: {e}") + queries = [query] + queries = list(set(queries)) + chunk_entities = [] + for query in queries: + vector = await Embedding.vectorize_embedding(query) + chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, 2, doc_ids, banned_ids) + banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] + chunk_entities_get_by_vector = [] + for _ in range(3): + try: + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3) + break + except Exception as e: + err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" + logging.error(err) + continue + banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_vector] + sub_chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_vector + chunk_entities += sub_chunk_entities + return chunk_entities -- Gitee From 0f3709cf24d5d5290f9ecd16a09b4913ef561ba4 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 3 Jun 2025 19:40:27 +0800 Subject: [PATCH 02/15] =?UTF-8?q?=E6=89=93=E5=8D=B0=E5=8F=A5=E5=AD=90?= =?UTF-8?q?=E5=88=87=E5=89=B2=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apps/base/task/worker/parse_document_worker.py | 2 ++ data_chain/common/prompt.yaml | 10 +++++----- data_chain/rag/query_extend_searcher.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 731a067..6d9d348 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -310,6 +310,8 @@ class ParseDocumentWorker(BaseWorker): tmp = tmp[len(sub_sentence):] else: new_sentences.append(sentence) + for sentence in new_sentences: + logging.error(f"[ParseDocumentWorker] {sentence} {TokenTool.get_tokens(sentence)}") sentences = new_sentences for sentence in sentences: if TokenTool.get_tokens(tmp) + TokenTool.get_tokens(sentence) > doc_entity.chunk_size: diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml index 29e9c85..611949e 100644 --- a/data_chain/common/prompt.yaml +++ b/data_chain/common/prompt.yaml @@ -352,7 +352,7 @@ CHUNK_QUERY_MATCH_PROMPT: '你是一个文本分析专家,你的任务是根 片段:{chunk} 问题:{question} ' - QUERY_EXTEND_PROMPT: '你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题 +QUERY_EXTEND_PROMPT: '你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题 注意: #01 扩写的问题的内容必须来源于原问题中的内容 #02 扩写的问题长度不超过50个字 @@ -362,10 +362,10 @@ CHUNK_QUERY_MATCH_PROMPT: '你是一个文本分析专家,你的任务是根 输入:openEuler是什么操作系统? 输出: [ - "openEuler是一个什么样的操作系统?", - "openEuler操作系统的特点是什么?", - "openEuler操作系统有哪些功能?", - "openEuler操作系统的优势是什么?" 
+ \"openEuler是一个什么样的操作系统?\", + \"openEuler操作系统的特点是什么?\", + \"openEuler操作系统有哪些功能?\", + \"openEuler操作系统的优势是什么?\" ] 下面是给出的问题: {question} diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py index 0d8f0ff..ebd8bf5 100644 --- a/data_chain/rag/query_extend_searcher.py +++ b/data_chain/rag/query_extend_searcher.py @@ -43,7 +43,7 @@ class QueryExtendSearcher(BaseSearcher): model_name=config['MODEL_NAME'], max_tokens=config['MAX_TOKENS'], ) - sys_call = prompt_template.format(k=2*top_k, query=query) + sys_call = prompt_template.format(k=2*top_k, question=query) user_call = "请输出扩写的问题列表" queries = await llm.nostream([], sys_call, user_call) try: -- Gitee From 3d1922b29594183931d10a240f6c21f32ec51024 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 3 Jun 2025 19:57:13 +0800 Subject: [PATCH 03/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E8=A7=A3=E6=9E=90=E6=97=B6=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/base/task/worker/parse_document_worker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 6d9d348..b09274e 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -310,8 +310,6 @@ class ParseDocumentWorker(BaseWorker): tmp = tmp[len(sub_sentence):] else: new_sentences.append(sentence) - for sentence in new_sentences: - logging.error(f"[ParseDocumentWorker] {sentence} {TokenTool.get_tokens(sentence)}") sentences = new_sentences for sentence in sentences: if TokenTool.get_tokens(tmp) + TokenTool.get_tokens(sentence) > doc_entity.chunk_size: @@ -353,6 +351,7 @@ class ParseDocumentWorker(BaseWorker): link_nodes=[] ) nodes.append(tmp_node) + parse_result.nodes = nodes @staticmethod async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None) -> None: -- Gitee From 7d988998379dc504b99ec60b4355f405e85fc990 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 3 Jun 2025 20:32:29 +0800 Subject: [PATCH 04/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dheader=E6=A0=87?= =?UTF-8?q?=E7=AD=BE=E8=A7=A3=E6=9E=90=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/base/task/worker/generate_dataset_worker.py | 1 + data_chain/parser/handler/html_parser.py | 6 +++--- data_chain/parser/handler/md_parser.py | 6 +++--- data_chain/parser/handler/md_zip_parser.py | 5 +++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index 530cbaf..8997947 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -140,6 +140,7 @@ class GenerateDataSetWorker(BaseWorker): answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', '') cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '') dataset_score = 0 + logging.error(f"{chunk_index_list}") for i in range(len(doc_chunks)): doc_chunk = doc_chunks[i] for j in range(len(doc_chunk.chunks)): diff --git a/data_chain/parser/handler/html_parser.py b/data_chain/parser/handler/html_parser.py index 6c4e060..6ad61ca 100644 --- a/data_chain/parser/handler/html_parser.py +++ 
b/data_chain/parser/handler/html_parser.py @@ -45,7 +45,7 @@ class HTMLParser(BaseParser): # 获取body元素作为起点(如果是完整HTML文档) root = soup.body if soup.body else soup - + valid_headers = ["h1", "h2", "h3", "h4", "h5", "h6"] # 获取当前层级的直接子元素 current_level_elements = list(root.children) subtree = [] @@ -91,7 +91,7 @@ class HTMLParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name.startswith('h'): + elif element.name in valid_headers: try: level = int(element.name[1:]) except Exception: @@ -101,7 +101,7 @@ class HTMLParser(BaseParser): content_elements = [] while current_level_elements: sibling = current_level_elements[0] - if sibling.name and sibling.name.startswith('h'): + if sibling.name and sibling.name in valid_headers: next_level = int(sibling.name[1:]) else: next_level = level + 1 diff --git a/data_chain/parser/handler/md_parser.py b/data_chain/parser/handler/md_parser.py index e226836..d549584 100644 --- a/data_chain/parser/handler/md_parser.py +++ b/data_chain/parser/handler/md_parser.py @@ -61,6 +61,7 @@ class MdParser(BaseParser): current_level_elements = list(root.children) # 过滤掉非标签节点(如文本节点) subtree = [] + valid_headers = ["h1", "h2", "h3", "h4", "h5", "h6"] while current_level_elements: element = current_level_elements.pop(0) if not isinstance(element, Tag): @@ -100,7 +101,7 @@ class MdParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name.startswith('h'): + elif element.name in valid_headers: try: level = int(element.name[1:]) except Exception: @@ -110,7 +111,7 @@ class MdParser(BaseParser): content_elements = [] while current_level_elements: sibling = current_level_elements[0] - if sibling.name and sibling.name.startswith('h'): + if sibling.name and sibling.name in valid_headers: next_level = int(sibling.name[1:]) else: next_level = level + 1 @@ -224,4 +225,3 @@ class MdParser(BaseParser): nodes=nodes ) return parse_result - diff --git a/data_chain/parser/handler/md_zip_parser.py b/data_chain/parser/handler/md_zip_parser.py index bccdbb9..24bbb4d 100644 --- a/data_chain/parser/handler/md_zip_parser.py +++ b/data_chain/parser/handler/md_zip_parser.py @@ -71,6 +71,7 @@ class MdZipParser(BaseParser): # 获取当前层级的直接子元素 current_level_elements = list(root.children) + valid_headers = ["h1", "h2", "h3", "h4", "h5", "h6"] # 过滤掉非标签节点(如文本节点) subtree = [] while current_level_elements: @@ -112,7 +113,7 @@ class MdZipParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name.startswith('h'): + elif element.name in valid_headers: try: level = int(element.name[1:]) except Exception: @@ -122,7 +123,7 @@ class MdZipParser(BaseParser): content_elements = [] while current_level_elements: sibling = current_level_elements[0] - if sibling.name and sibling.name.startswith('h'): + if sibling.name and sibling.name in valid_headers: next_level = int(sibling.name[1:]) else: next_level = level + 1 -- Gitee From 45c88f6287d1aff235fef12d8401fd1f897524e0 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 3 Jun 2025 20:52:53 +0800 Subject: [PATCH 05/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E7=94=9F=E6=88=90=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apps/base/task/worker/generate_dataset_worker.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index 8997947..7fe1359 100644 --- 
a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -132,7 +132,9 @@ class GenerateDataSetWorker(BaseWorker): chunk_cnt = len(chunk_index_list) division = data_cnt // chunk_cnt remainder = data_cnt % chunk_cnt + logging.error(f"数据集总条目 {dataset_entity.data_cnt}, 分块数量: {chunk_cnt}, 每块数据量: {division}, 余数: {remainder}") index = 0 + d_index = 0 random.shuffle(doc_chunks) with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) @@ -148,6 +150,7 @@ class GenerateDataSetWorker(BaseWorker): index += 1 continue index += 1 + d_index += 1 chunk = doc_chunk.chunks[j].text if dataset_entity.is_chunk_related: l = j-1 @@ -158,11 +161,11 @@ class GenerateDataSetWorker(BaseWorker): break if tokens_sub > 0: if l >= 0: - tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[l].text) + tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[l].text) chunk = doc_chunks[i].chunks[l].text+chunk l -= 1 else: - tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[r].text) + tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[r].text) chunk += doc_chunks[i].chunks[r].text r += 1 else: @@ -174,7 +177,7 @@ class GenerateDataSetWorker(BaseWorker): tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[l].text) chunk = doc_chunks[i].chunks[l].text+chunk l -= 1 - qa_cnt = division+(index < remainder) + qa_cnt = division+(d_index <= remainder) qs = [] answers = [] rd = 5 @@ -182,7 +185,7 @@ class GenerateDataSetWorker(BaseWorker): rd -= 1 try: sys_call = q_generate_prompt_template.format( - k=qa_cnt, + k=qa_cnt-len(qs), content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens) ) usr_call = '请输出问题的列表' @@ -192,6 +195,7 @@ class GenerateDataSetWorker(BaseWorker): err = f"[GenerateDataSetWorker] 生成问题失败,错误信息: {e}" logging.exception(err) continue + sub_qs = sub_qs[:qa_cnt-len(qs)] sub_answers = [] try: for q in sub_qs: -- Gitee From 911d6d1eef7f37d0961da298639af5367407f972 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 4 Jun 2025 12:13:24 +0800 Subject: [PATCH 06/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=90=91=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E5=8C=B9=E9=85=8D=E5=85=B3=E9=94=AE=E5=AD=97=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/chunk_manager.py | 16 +++++++--------- data_chain/manager/document_manager.py | 12 +++++------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 88e0dc2..dcc88a4 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -182,8 +182,6 @@ class ChunkManager(): .where(ChunkEntity.kb_id == kb_id) .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) - .order_by(ChunkEntity.text_vector.cosine_distance(vector).desc()) - .limit(top_k) ) if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) @@ -192,6 +190,7 @@ class ChunkManager(): if pre_ids is not None: stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) result = await session.execute(stmt) + stmt = stmt.order_by(ChunkEntity.text_vector.cosine_distance(vector).desc()).limit(top_k) chunk_entities = result.scalars().all() return chunk_entities except Exception as e: @@ -229,13 +228,6 @@ class ChunkManager(): .where(ChunkEntity.kb_id == kb_id) .where(ChunkEntity.status != 
ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) - .order_by( - func.ts_rank_cd( - func.to_tsvector(tokenizer, ChunkEntity.text), - func.plainto_tsquery(tokenizer, query) - ).desc() - ) - .limit(top_k) ) if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) @@ -243,6 +235,12 @@ class ChunkManager(): stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + stmt = stmt.order_by( + func.ts_rank_cd( + func.to_tsvector(tokenizer, ChunkEntity.text), + func.plainto_tsquery(tokenizer, query) + ).desc() + ).limit(top_k) result = await session.execute(stmt) chunk_entities = result.scalars().all() return chunk_entities diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 17ac0a7..060d889 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -57,11 +57,10 @@ class DocumentManager(): .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.enabled == True) .where(DocumentEntity.abstract_vector.cosine_distance(vector).desc()) - .order_by(DocumentEntity.abstract_vector.cosine_distance(vector).desc()) - .limit(top_k) ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) + stmt = stmt.order_by(DocumentEntity.abstract_vector.cosine_distance(vector).desc()).limit(top_k) result = await session.execute(stmt) document_entities = result.scalars().all() return document_entities @@ -91,14 +90,13 @@ class DocumentManager(): .where(DocumentEntity.id.notin_(banned_ids)) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.enabled == True) - .where(func.ts_rank_cd( - func.to_tsvector(tokenizer, DocumentEntity.abstract), - func.plainto_tsquery(tokenizer, query) - ).desc()) - .limit(top_k) ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) + stmt = stmt.order_by(func.ts_rank_cd( + func.to_tsvector(tokenizer, DocumentEntity.abstract), + func.plainto_tsquery(tokenizer, query) + ).desc()).limit(top_k) result = await session.execute(stmt) document_entities = result.scalars().all() return document_entities -- Gitee From ca61e73f9f8b952c8ac03f114d025b5635dbc048 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 4 Jun 2025 14:51:05 +0800 Subject: [PATCH 07/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8Docr=E6=8A=A5=E9=94=99?= =?UTF-8?q?=E6=97=B6=E6=97=A5=E5=BF=97=E6=89=93=E5=8D=B0=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/base/task/worker/parse_document_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index b09274e..b6373c0 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -273,7 +273,7 @@ class ParseDocumentWorker(BaseWorker): node.content = await OcrTool.image_to_text(img_np, image_related_text, llm) node.text_feature = node.content except Exception as e: - err = f"[ParseDocumentWorker] OCR失败,doc_id: {node.doc_id}, error: {e}" + err = f"[ParseDocumentWorker] OCR失败 error: {e}" logging.exception(err) continue -- Gitee From 2dfd12e27507989d59e2b07c5a475e341c7a563b Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 4 Jun 2025 15:17:02 +0800 Subject: [PATCH 08/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=9F=A5=E8=AF=A2?= 
=?UTF-8?q?=E5=88=B0=E8=A2=AB=E7=A6=81=E6=AD=A2=E7=9A=84chunk=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/chunk_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index dcc88a4..834858f 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -180,6 +180,7 @@ class ChunkManager(): .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) + .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) ) @@ -226,6 +227,7 @@ class ChunkManager(): .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) + .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) ) @@ -242,6 +244,7 @@ class ChunkManager(): ).desc() ).limit(top_k) result = await session.execute(stmt) + chunk_entities = result.scalars().all() return chunk_entities except Exception as e: -- Gitee From 5481e005e041cef4e3c8318be0dbdd47027e085d Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 4 Jun 2025 17:26:15 +0800 Subject: [PATCH 09/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=85=B3=E9=94=AE?= =?UTF-8?q?=E5=AD=97=E6=A3=80=E7=B4=A2=E5=92=8C=E5=90=91=E9=87=8F=E5=8C=96?= =?UTF-8?q?=E6=A3=80=E7=B4=A2=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/chunk_manager.py | 49 ++++++++++++++++++-------- data_chain/manager/document_manager.py | 25 +++++++++---- 2 files changed, 53 insertions(+), 21 deletions(-) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 834858f..58ace2c 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -172,8 +172,12 @@ class ChunkManager(): """根据知识库ID和向量查询文档解析结果""" try: async with await DataBase.get_session() as session: + # 计算相似度分数 + similarity_score = ChunkEntity.text_vector.cosine_distance(vector).label("similarity_score") + + # 构建基础查询条件 stmt = ( - select(ChunkEntity) + select(ChunkEntity, similarity_score) .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id ) @@ -184,26 +188,34 @@ class ChunkManager(): .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) ) + + # 添加可选条件 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + + # 应用排序条件 + stmt = stmt.order_by(similarity_score.desc()) + stmt = stmt.limit(top_k) + + # 执行最终查询 result = await session.execute(stmt) - stmt = stmt.order_by(ChunkEntity.text_vector.cosine_distance(vector).desc()).limit(top_k) chunk_entities = result.scalars().all() + return chunk_entities except Exception as e: err = "根据知识库ID和向量查询文档解析结果失败" logging.exception("[ChunkManager] %s", err) raise e - @staticmethod async def get_top_k_chunk_by_kb_id_keyword( - kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], + kb_id: uuid.UUID, query: str, + top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: 
list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和关键词查询文档解析结果""" + """根据知识库ID和向量查询文档解析结果""" try: async with await DataBase.get_session() as session: kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) @@ -219,11 +231,19 @@ class ChunkManager(): tokenizer = 'chparser' else: tokenizer = 'zhparser' + + # 计算相似度分数并选择它 + similarity_score = func.ts_rank_cd( + func.to_tsvector(tokenizer, ChunkEntity.text), + func.plainto_tsquery(tokenizer, query) + ).label("similarity_score") + stmt = ( - select(ChunkEntity) + select(ChunkEntity, similarity_score) .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id ) + .where(similarity_score > 0) .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) @@ -231,24 +251,25 @@ class ChunkManager(): .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) ) + if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) - stmt = stmt.order_by( - func.ts_rank_cd( - func.to_tsvector(tokenizer, ChunkEntity.text), - func.plainto_tsquery(tokenizer, query) - ).desc() - ).limit(top_k) - result = await session.execute(stmt) + # 按相似度分数排序 + stmt = stmt.order_by(similarity_score.desc()) + stmt = stmt.limit(top_k) + + # 执行最终查询 + result = await session.execute(stmt) chunk_entities = result.scalars().all() + return chunk_entities except Exception as e: - err = "根据知识库ID和关键词查询文档解析结果失败" + err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" logging.exception("[ChunkManager] %s", err) raise e diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 060d889..29b6cab 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -50,8 +50,10 @@ class DocumentManager(): """根据知识库ID和向量获取前K个文档""" try: async with await DataBase.get_session() as session: + similarity_score = DocumentEntity.abstract.cosine_distance(vector).label("similarity_score") stmt = ( - select(DocumentEntity) + select(DocumentEntity, similarity_score) + .where(similarity_score > 0) .where(DocumentEntity.kb_id == kb_id) .where(DocumentEntity.id.notin_(banned_ids)) .where(DocumentEntity.status != DocumentStatus.DELETED.value) @@ -60,8 +62,12 @@ class DocumentManager(): ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - stmt = stmt.order_by(DocumentEntity.abstract_vector.cosine_distance(vector).desc()).limit(top_k) + stmt = stmt.order_by( + similarity_score.desc() + ) + result = await session.execute(stmt) + document_entities = result.scalars().all() return document_entities except Exception as e: @@ -84,8 +90,12 @@ class DocumentManager(): tokenizer = 'zhparser' elif kb_entity.tokenizer == Tokenizer.EN.value: tokenizer = 'english' + similarity_score = func.ts_rank_cd( + func.to_tsvector(tokenizer, DocumentEntity.abstract), + func.plainto_tsquery(tokenizer, query) + ).label("similarity_score") stmt = ( - select(DocumentEntity) + select(DocumentEntity, similarity_score) .where(DocumentEntity.kb_id == kb_id) .where(DocumentEntity.id.notin_(banned_ids)) .where(DocumentEntity.status != DocumentStatus.DELETED.value) @@ -93,10 +103,11 @@ class DocumentManager(): ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) - stmt = 
stmt.order_by(func.ts_rank_cd( - func.to_tsvector(tokenizer, DocumentEntity.abstract), - func.plainto_tsquery(tokenizer, query) - ).desc()).limit(top_k) + + stmt = stmt.order_by( + similarity_score.desc() + ) + stmt = stmt.limit(top_k) result = await session.execute(stmt) document_entities = result.scalars().all() return document_entities -- Gitee From 6c8747795fa4d99e310d5e0b54a842f3c83ebe62 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 5 Jun 2025 15:00:25 +0800 Subject: [PATCH 10/15] =?UTF-8?q?=E9=9D=99=E9=BB=98=E5=A4=84=E7=90=86rag?= =?UTF-8?q?=20chunk=E6=A3=80=E7=B4=A2=E5=A4=B1=E8=B4=A5=E7=9A=84=E6=83=85?= =?UTF-8?q?=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/manager/chunk_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 58ace2c..648a113 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -209,7 +209,7 @@ class ChunkManager(): except Exception as e: err = "根据知识库ID和向量查询文档解析结果失败" logging.exception("[ChunkManager] %s", err) - raise e + return [] async def get_top_k_chunk_by_kb_id_keyword( kb_id: uuid.UUID, query: str, @@ -271,7 +271,7 @@ class ChunkManager(): except Exception as e: err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" logging.exception("[ChunkManager] %s", err) - raise e + return [] @staticmethod async def fetch_surrounding_chunk_by_doc_id_and_global_offset( -- Gitee From 2b2e100444231cf5eac53e3a1a7c939fc333e698 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 5 Jun 2025 15:55:27 +0800 Subject: [PATCH 11/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E7=94=9F=E6=88=90=E7=9A=84=E5=92=8Cmarkdown.zip?= =?UTF-8?q?=E8=A7=A3=E6=9E=90=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../task/worker/generate_dataset_worker.py | 4 ++-- data_chain/apps/router/other.py | 2 +- data_chain/manager/knowledge_manager.py | 2 +- data_chain/parser/handler/md_zip_parser.py | 18 +++++++++++++----- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index 7fe1359..0cd7c3a 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -157,7 +157,7 @@ class GenerateDataSetWorker(BaseWorker): r = j+1 tokens_sub = 0 while TokenTool.get_tokens(chunk) < llm.max_tokens: - if l < 0 and r >= len(doc_chunks): + if l < 0 and r >= len(doc_chunks[i]): break if tokens_sub > 0: if l >= 0: @@ -169,7 +169,7 @@ class GenerateDataSetWorker(BaseWorker): chunk += doc_chunks[i].chunks[r].text r += 1 else: - if r < len(doc_chunks): + if r < len(doc_chunks[i]): tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[r].text) chunk += doc_chunks[i].chunks[r].text r += 1 diff --git a/data_chain/apps/router/other.py b/data_chain/apps/router/other.py index 3085e9a..24f9d00 100644 --- a/data_chain/apps/router/other.py +++ b/data_chain/apps/router/other.py @@ -52,7 +52,7 @@ async def list_llms_by_user_sub( @router.get('/embedding', response_model=ListEmbeddingResponse, dependencies=[Depends(verify_user)]) async def list_embeddings(): - embeddings = [embedding.value for embedding in Embedding] + embeddings = [config['EMBEDDING_MODEL_NAME']] return ListEmbeddingResponse(result=embeddings) diff --git 
a/data_chain/manager/knowledge_manager.py b/data_chain/manager/knowledge_manager.py index 0c109c4..ae51c6f 100644 --- a/data_chain/manager/knowledge_manager.py +++ b/data_chain/manager/knowledge_manager.py @@ -21,7 +21,7 @@ class KnowledgeBaseManager(): return knowledge_base_entity except Exception as e: err = "添加知识库失败" - logging.exception("[KnowledgeBaseManager] %s", err) + logging.error("[KnowledgeBaseManager] %s", err) @staticmethod async def get_knowledge_base_by_kb_id(kb_id: uuid.UUID) -> KnowledgeBaseEntity: diff --git a/data_chain/parser/handler/md_zip_parser.py b/data_chain/parser/handler/md_zip_parser.py index 24bbb4d..69e94ab 100644 --- a/data_chain/parser/handler/md_zip_parser.py +++ b/data_chain/parser/handler/md_zip_parser.py @@ -89,7 +89,7 @@ class MdZipParser(BaseParser): continue if element.name == 'ol': inner_html = ''.join(str(child) for child in element.children) - child_subtree = await MdZipParser.build_subtree(inner_html, current_level+1) + child_subtree = await MdZipParser.build_subtree(file_path, inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( child_subtree) else ChunkParseTopology.TREELEAF if child_subtree: @@ -133,7 +133,7 @@ class MdZipParser(BaseParser): # 如果有内容,处理这些内容 if content_elements: content_html = ''.join(str(el) for el in content_elements) - child_subtree = await MdZipParser.build_subtree(content_html, level) + child_subtree = await MdZipParser.build_subtree(file_path, content_html, level) parse_topology_type = ChunkParseTopology.TREENORMAL else: child_subtree = [] @@ -231,12 +231,20 @@ class MdZipParser(BaseParser): async def parser(file_path: str) -> ParseResult: target_file_path = os.path.join(os.path.dirname(file_path), 'temp') await ZipHandler.unzip_file(file_path, target_file_path) - markdown_file = [f for f in os.listdir(target_file_path) if f.endswith('.md')] - if not markdown_file: + # 递归查找markdown文件 + markdown_file_path_list = [] + for root, dirs, files in os.walk(target_file_path): + for file in files: + if file.endswith('.md'): + target_file_path = os.path.join(root, file) + markdown_file_path_list.append(target_file_path) + else: + continue + if not markdown_file_path_list: err = f"[MdZipParser] markdown文件不存在" logging.error(err) raise FileNotFoundError(err) - markdown_file_path = os.path.join(target_file_path, markdown_file[0]) if markdown_file else None + markdown_file_path = markdown_file_path_list[0] if markdown_file_path_list else None with open(markdown_file_path, 'r', encoding='utf-8', errors='ignore') as f: markdown_text = f.read() nodes = await MdZipParser.markdown_to_tree(target_file_path, markdown_text) -- Gitee From 2f327f456c76dde21c955e9e52ac87753e556b24 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 5 Jun 2025 15:56:51 +0800 Subject: [PATCH 12/15] =?UTF-8?q?=E6=9A=82=E6=97=B6=E5=8E=BB=E9=99=A4?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E5=86=85=E6=96=87=E6=A1=A3=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=E7=BB=93=E6=9E=9C=E8=AE=BF=E9=97=AE=E7=9A=84=E9=99=90?= =?UTF-8?q?=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/service/chunk_service.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index e93d6cc..1e7b26d 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -80,17 +80,6 @@ class ChunkService: logging.error("[ChunkService] 搜索分片,查询条件: %s", req) chunk_entities = [] for kb_id in 
req.kb_ids: - if kb_id != DEFAULT_KNOWLEDGE_BASE_ID: - try: - if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base( - user_sub, kb_id, action)): - err = f"用户没有权限访问知识库中的块,知识库ID: {kb_id}" - logging.error("[ChunkService] %s", err) - continue - except Exception as e: - err = f"验证用户对知识库的操作权限失败,error: {e}" - logging.exception(err) - continue try: chunk_entities += await BaseSearcher.search(req.search_method.value, kb_id, req.query, 2*req.top_k, req.doc_ids, req.banned_ids) except Exception as e: -- Gitee From 51229b6bdb3b871b5c0912b82bbad3cac0b33ea3 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 5 Jun 2025 17:11:40 +0800 Subject: [PATCH 13/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dmd.zip=E8=A7=A3?= =?UTF-8?q?=E6=9E=90&=E5=9B=BE=E7=89=87=E6=8F=90=E5=8F=96=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/parser/handler/html_parser.py | 25 ++++++++------ data_chain/parser/handler/md_parser.py | 28 ++++++++------- data_chain/parser/handler/md_zip_parser.py | 40 +++++++++++++++------- 3 files changed, 57 insertions(+), 36 deletions(-) diff --git a/data_chain/parser/handler/html_parser.py b/data_chain/parser/handler/html_parser.py index 6ad61ca..9a98c19 100644 --- a/data_chain/parser/handler/html_parser.py +++ b/data_chain/parser/handler/html_parser.py @@ -52,19 +52,22 @@ class HTMLParser(BaseParser): while current_level_elements: element = current_level_elements.pop(0) if not isinstance(element, Tag): - node = ParseNode( - id=uuid.uuid4(), - lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, - content=element.get_text(strip=True), - type=ChunkType.TEXT, - link_nodes=[] - ) - subtree.append(node) + try: + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=ChunkParseTopology.TREELEAF, + content=element.get_text(strip=True), + type=ChunkType.TEXT, + link_nodes=[] + ) + subtree.append(node) + except Exception as e: + logging.warning(f"[HTMLParser] 处理非标签元素失败: {e}") continue if element.name == 'div' or element.name == 'head' or element.name == 'header' or \ element.name == 'body' or element.name == 'section' or element.name == 'article' or \ - element.name == 'nav' or element.name == 'main' or element.name == 'ol': + element.name == 'nav' or element.name == 'main' or element.name == 'p' or element.name == 'ol': # 处理div内部元素 inner_html = ''.join(str(child) for child in element.children) child_subtree = await HTMLParser.build_subtree(inner_html, current_level+1) @@ -138,7 +141,7 @@ class HTMLParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name == 'p' or element.name == 'title' or element.name == 'span' or element.name == 'pre'\ + elif element.name == 'title' or element.name == 'span' or element.name == 'pre'\ or element.name == 'li': para_text = element.get_text().strip() if para_text: diff --git a/data_chain/parser/handler/md_parser.py b/data_chain/parser/handler/md_parser.py index d549584..97ff373 100644 --- a/data_chain/parser/handler/md_parser.py +++ b/data_chain/parser/handler/md_parser.py @@ -65,17 +65,21 @@ class MdParser(BaseParser): while current_level_elements: element = current_level_elements.pop(0) if not isinstance(element, Tag): - node = ParseNode( - id=uuid.uuid4(), - lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, - content=element.get_text(strip=True), - type=ChunkType.TEXT, - link_nodes=[] - ) - subtree.append(node) + try: + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + 
parse_topology_type=ChunkParseTopology.TREELEAF, + content=element.get_text(strip=True), + type=ChunkType.TEXT, + link_nodes=[] + ) + subtree.append(node) + except Exception as e: + logging.error(f"[MdParser] 处理非标签节点失败: {e}") + continue - if element.name == 'ol' or element.name == 'hr': + if element.name == 'p' or element.name == 'ol' or element.name == 'hr': inner_html = ''.join(str(child) for child in element.children) child_subtree = await MdParser.build_subtree(inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( @@ -138,7 +142,7 @@ class MdParser(BaseParser): ) subtree.append(node) elif element.name == 'code': - code_text = element.find('code').get_text() + code_text = element.get_text().strip() node = ParseNode( id=uuid.uuid4(), @@ -149,7 +153,7 @@ class MdParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name == 'p' or element.name == 'li': + elif element.name == 'li': para_text = element.get_text().strip() if para_text: node = ParseNode( diff --git a/data_chain/parser/handler/md_zip_parser.py b/data_chain/parser/handler/md_zip_parser.py index 69e94ab..3804f9f 100644 --- a/data_chain/parser/handler/md_zip_parser.py +++ b/data_chain/parser/handler/md_zip_parser.py @@ -2,6 +2,7 @@ import asyncio from bs4 import BeautifulSoup, Tag import markdown import os +import re import requests import uuid from data_chain.entities.enum import DocParseRelutTopology, ChunkParseTopology, ChunkType @@ -48,6 +49,14 @@ class MdZipParser(BaseParser): logging.warning(warining) return None else: + # 1. 处理反斜杠问题 - 替换所有反斜杠为正斜杠 + img_src = img_src.replace('\\', '/') + + # 2. 移除可能的相对路径前缀(如 ./ 或 ../) + img_src = re.sub(r'^\./+', '', img_src) + + # 3. 规范化路径(处理多余的斜杠等) + img_src = os.path.normpath(img_src) img_path = os.path.join(base_dir, img_src) if os.path.exists(img_path): try: @@ -76,18 +85,22 @@ class MdZipParser(BaseParser): subtree = [] while current_level_elements: element = current_level_elements.pop(0) + if not isinstance(element, Tag): - node = ParseNode( - id=uuid.uuid4(), - lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, - content=element.get_text(strip=True), - type=ChunkType.TEXT, - link_nodes=[] - ) - subtree.append(node) + try: + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=ChunkParseTopology.TREELEAF, + content=element.get_text(strip=True), + type=ChunkType.TEXT, + link_nodes=[] + ) + subtree.append(node) + except Exception as e: + logging.error(f"[MdZipParser] 处理非标签节点失败: {e}") continue - if element.name == 'ol': + if element.name == 'p' or element.name == 'ol' or element.name == 'hr': inner_html = ''.join(str(child) for child in element.children) child_subtree = await MdZipParser.build_subtree(file_path, inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( @@ -150,7 +163,7 @@ class MdZipParser(BaseParser): ) subtree.append(node) elif element.name == 'code': - code_text = element.find('code').get_text() + code_text = element.get_text().strip() node = ParseNode( id=uuid.uuid4(), @@ -161,7 +174,7 @@ class MdZipParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name == 'p' or element.name == 'li': + elif element.name == 'li': para_text = element.get_text().strip() if para_text: node = ParseNode( @@ -213,6 +226,7 @@ class MdZipParser(BaseParser): @staticmethod async def markdown_to_tree(file_path: str, markdown_text: str) -> ParseNode: html = markdown.markdown(markdown_text, extensions=['tables']) + logging.error(html) root = 
ParseNode( id=uuid.uuid4(), title="", @@ -247,7 +261,7 @@ class MdZipParser(BaseParser): markdown_file_path = markdown_file_path_list[0] if markdown_file_path_list else None with open(markdown_file_path, 'r', encoding='utf-8', errors='ignore') as f: markdown_text = f.read() - nodes = await MdZipParser.markdown_to_tree(target_file_path, markdown_text) + nodes = await MdZipParser.markdown_to_tree(markdown_file_path, markdown_text) return ParseResult( parse_topology_type=DocParseRelutTopology.TREE, nodes=nodes -- Gitee From 80502dbb7b2b4001963fd3421f7fcd2e7f809e36 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 5 Jun 2025 22:12:09 +0800 Subject: [PATCH 14/15] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=8A=A8=E6=80=81?= =?UTF-8?q?=E5=85=B3=E9=94=AE=E5=AD=97=E6=A3=80=E7=B4=A2=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_chain/apps/app.py | 3 +- data_chain/entities/enum.py | 1 + data_chain/manager/chunk_manager.py | 93 ++++++++++++++++++- data_chain/parser/tools/token_tool.py | 27 ++++++ .../rag/dynamic_weighted_keyword_searcher.py | 38 ++++++++ data_chain/rag/keyword_and_vector_searcher.py | 8 +- 6 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 data_chain/rag/dynamic_weighted_keyword_searcher.py diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index b8071dd..c39d4e9 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -54,7 +54,8 @@ from data_chain.rag import ( doc2chunk_searcher, doc2chunk_bfs_searcher, enhanced_by_llm_searcher, - query_extend_searcher + query_extend_searcher, + dynamic_weighted_keyword_searcher ) from data_chain.stores.database.database import ( DataBase, diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py index 70e488d..227ea1d 100644 --- a/data_chain/entities/enum.py +++ b/data_chain/entities/enum.py @@ -161,6 +161,7 @@ class SearchMethod(str, Enum): DOC2CHUNK_BFS = "doc2chunk_bfs" ENHANCED_BY_LLM = "enhanced_by_llm" QUERY_EXTEND = "query_extend" + DYNAMIC_WEIGHTED_KEYWORD = "dynamic_weighted_keyword" class TaskType(str, Enum): diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 648a113..53e210c 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-from sqlalchemy import select, update, func, text, or_, and_ +from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column from typing import List, Tuple, Dict, Optional import uuid from data_chain.entities.enum import DocumentStatus, ChunkStatus, Tokenizer @@ -273,6 +273,97 @@ class ChunkManager(): logging.exception("[ChunkManager] %s", err) return [] + async def get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( + kb_id: uuid.UUID, keywords: List[str], weights: List[float], + top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], + chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: + """根据知识库ID和关键词和关键词权重查询文档解析结果""" + try: + async with await DataBase.get_session() as session: + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) + if kb_entity.tokenizer == Tokenizer.ZH.value: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + elif kb_entity.tokenizer == Tokenizer.EN.value: + tokenizer = 'english' + else: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + + # 构建VALUES子句的参数 + params = {} + values_clause = [] + + for i, (term, weight) in enumerate(zip(keywords, weights)): + # 使用单独的参数名,避免与类型转换冲突 + params[f"term_{i}"] = term + params[f"weight_{i}"] = weight + # 在VALUES子句中使用类型转换函数 + values_clause.append(f"(CAST(:term_{i} AS TEXT), CAST(:weight_{i} AS FLOAT8))") + + # 构建VALUES子句 + values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" + + # 创建weighted_terms CTE + weighted_terms = ( + select( + literal_column("t.term").label("term"), + literal_column("t.weight").cast(Float).label("weight") + ) + .select_from(text(values_text)) + .cte("weighted_terms") + ) + + # 计算相似度得分 + similarity_score = func.sum( + func.ts_rank_cd( + func.to_tsvector(tokenizer, ChunkEntity.text), + func.to_tsquery(tokenizer, weighted_terms.c.term) + ) * weighted_terms.c.weight + ).label("similarity_score") + + stmt = ( + select(ChunkEntity, similarity_score) + .join(DocumentEntity, + DocumentEntity.id == ChunkEntity.doc_id + ) + .where(DocumentEntity.enabled == True) + .where(DocumentEntity.status != DocumentStatus.DELETED.value) + .where(ChunkEntity.kb_id == kb_id) + .where(ChunkEntity.enabled == True) + .where(ChunkEntity.status != ChunkStatus.DELETED.value) + .where(ChunkEntity.id.notin_(banned_ids)) + ) + # 添加 GROUP BY 子句,按 ChunkEntity.id 分组 + stmt = stmt.group_by(ChunkEntity.id) + + stmt.having(similarity_score > 0) + + if doc_ids is not None: + stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) + if chunk_to_type is not None: + stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) + if pre_ids is not None: + stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + + # 按相似度分数排序 + stmt = stmt.order_by(similarity_score.desc()) + stmt = stmt.limit(top_k) + + # 执行最终查询 + result = await session.execute(stmt, params=params) + chunk_entities = result.scalars().all() + + return chunk_entities + except Exception as e: + err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" + logging.exception("[ChunkManager] %s", err) + return [] + @staticmethod async def fetch_surrounding_chunk_by_doc_id_and_global_offset( doc_id: uuid.UUID, global_offset: int, diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py index fb53d27..9f585d4 100644 --- a/data_chain/parser/tools/token_tool.py +++ b/data_chain/parser/tools/token_tool.py @@ -24,6 +24,20 @@ class 
TokenTool: with open(stop_words_path, 'r', encoding='utf-8') as f: stopwords = set(line.strip() for line in f) + @staticmethod + def filter_stopwords(content: str) -> str: + """ + 过滤停用词 + """ + try: + words = TokenTool.split_words(content) + filtered_words = [word for word in words if word not in TokenTool.stopwords] + return ' '.join(filtered_words) + except Exception as e: + err = f"[TokenTool] 过滤停用词失败 {e}" + logging.exception("[TokenTool] %s", err) + return content + @staticmethod def get_leave_tokens_from_content_len(content: str) -> int: """ @@ -146,6 +160,19 @@ class TokenTool: logging.exception("[TokenTool] %s", err) return [] + @staticmethod + def get_top_k_keywords_and_weights(content: str, k=10) -> list: + try: + # 使用jieba提取关键词 + keyword_weight_list = extract_tags(content, topK=k, withWeight=True) + keywords = [keyword for keyword, weight in keyword_weight_list] + weights = [weight for keyword, weight in keyword_weight_list] + return keywords, weights + except Exception as e: + err = f"[TokenTool] 获取关键词失败 {e}" + logging.exception("[TokenTool] %s", err) + return [] + @staticmethod def compress_tokens(content: str, k: int = None) -> str: try: diff --git a/data_chain/rag/dynamic_weighted_keyword_searcher.py b/data_chain/rag/dynamic_weighted_keyword_searcher.py new file mode 100644 index 0000000..6a8ac64 --- /dev/null +++ b/data_chain/rag/dynamic_weighted_keyword_searcher.py @@ -0,0 +1,38 @@ +import asyncio +import uuid +from pydantic import BaseModel, Field +import random +from data_chain.logger.logger import logger as logging +from data_chain.stores.database.database import ChunkEntity +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.rag.base_searcher import BaseSearcher +from data_chain.entities.enum import SearchMethod + + +class DynamicWeightKeyWordSearcher(BaseSearcher): + """ + 动态加权关键词检索 + """ + name = SearchMethod.DYNAMIC_WEIGHTED_KEYWORD.value + + @staticmethod + async def search( + query: str, kb_id: uuid.UUID, top_k: int = 5, doc_ids: list[uuid.UUID] = None, + banned_ids: list[uuid.UUID] = [] + ) -> list[ChunkEntity]: + """ + 向量检索 + :param query: 查询 + :param top_k: 返回的结果数量 + :return: 检索结果 + """ + try: + query_filtered = TokenTool.filter_stopwords(query) + keywords, weights = TokenTool.get_top_k_keywords_and_weights(query_filtered) + chunk_entities = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k, doc_ids, banned_ids) + except Exception as e: + err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" + logging.exception(err) + return [] + return chunk_entities diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py index d375302..0237eaa 100644 --- a/data_chain/rag/keyword_and_vector_searcher.py +++ b/data_chain/rag/keyword_and_vector_searcher.py @@ -30,7 +30,10 @@ class KeywordVectorSearcher(BaseSearcher): """ vector = await Embedding.vectorize_embedding(query) try: - chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//2, doc_ids, banned_ids) + query_filtered = TokenTool.filter_stopwords(query) + keywords, weights = TokenTool.get_top_k_keywords_and_weights(query_filtered) + logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}") + chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids) chunk_ids = 
[chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] chunk_entities_get_by_vector = [] for _ in range(3): @@ -42,6 +45,9 @@ class KeywordVectorSearcher(BaseSearcher): logging.error(err) continue chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_vector + for chunk_entity in chunk_entities: + logging.error( + f"[KeywordVectorSearcher] chunk_entity: {chunk_entity.id}, text: {chunk_entity.text[:100]}...") except Exception as e: err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" logging.exception(err) -- Gitee From 9c276aa262d08aef53283cf7a69b92c10c5e3b3b Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 6 Jun 2025 09:33:01 +0800 Subject: [PATCH 15/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E7=94=9F=E6=88=90=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../task/worker/generate_dataset_worker.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index 0cd7c3a..94b2c3a 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -157,25 +157,25 @@ class GenerateDataSetWorker(BaseWorker): r = j+1 tokens_sub = 0 while TokenTool.get_tokens(chunk) < llm.max_tokens: - if l < 0 and r >= len(doc_chunks[i]): + if l < 0 and r >= len(doc_chunk.chunks): break if tokens_sub > 0: if l >= 0: - tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[l].text) - chunk = doc_chunks[i].chunks[l].text+chunk + tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text) + chunk = doc_chunk.chunks[l].text+chunk l -= 1 else: - tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[r].text) - chunk += doc_chunks[i].chunks[r].text + tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text) + chunk += doc_chunk.chunks[r].text r += 1 else: - if r < len(doc_chunks[i]): - tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[r].text) - chunk += doc_chunks[i].chunks[r].text + if r < len(doc_chunk.chunks): + tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text) + chunk += doc_chunk.chunks[r].text r += 1 else: - tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[l].text) - chunk = doc_chunks[i].chunks[l].text+chunk + tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text) + chunk = doc_chunk.chunks[l].text+chunk l -= 1 qa_cnt = division+(d_index <= remainder) qs = [] -- Gitee
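A standalone sketch may help when reading the final patch's window-expansion fix outside the diff: the corrected loop grows a context window around chunk j, alternating between the left and right neighbours so the token counts added on each side stay roughly balanced, and stops once the token budget is reached or both ends of the chunk list are exhausted. In the sketch below, count_tokens stands in for TokenTool.get_tokens and plain strings stand in for chunk entities; the function name, signature, and toy data are illustrative and not part of the repository.

from typing import Callable, List


def expand_window(chunks: List[str], j: int, max_tokens: int,
                  count_tokens: Callable[[str], int]) -> str:
    """Grow a context window around chunks[j] until max_tokens is reached
    or both ends of the chunk list are exhausted."""
    text = chunks[j]
    left, right = j - 1, j + 1
    balance = 0  # tokens added on the right minus tokens added on the left
    while count_tokens(text) < max_tokens:
        if left < 0 and right >= len(chunks):
            break
        # Extend the left side when the right side is ahead (or the right end
        # is exhausted); otherwise extend the right side.
        if (balance > 0 and left >= 0) or right >= len(chunks):
            balance -= count_tokens(chunks[left])
            text = chunks[left] + text
            left -= 1
        else:
            balance += count_tokens(chunks[right])
            text = text + chunks[right]
            right += 1
    return text


if __name__ == "__main__":
    # Toy usage: character count as the token counter.
    demo_chunks = ["aaa", "bbbb", "cc", "ddddd", "e"]
    print(expand_window(demo_chunks, 2, 12, len))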