diff --git a/data_chain/apps/app.py b/data_chain/apps/app.py index ed4821a0ea14eef89e3fb50914b8fce38c886d41..c39d4e9333961136732bb6c181a2803d75b4b667 100644 --- a/data_chain/apps/app.py +++ b/data_chain/apps/app.py @@ -53,7 +53,9 @@ from data_chain.rag import ( keyword_and_vector_searcher, doc2chunk_searcher, doc2chunk_bfs_searcher, - enhanced_by_llm_searcher + enhanced_by_llm_searcher, + query_extend_searcher, + dynamic_weighted_keyword_searcher ) from data_chain.stores.database.database import ( DataBase, diff --git a/data_chain/apps/base/task/worker/generate_dataset_worker.py b/data_chain/apps/base/task/worker/generate_dataset_worker.py index 530cbafcc2d378ab1551a81fd897a1c1440f2ea5..94b2c3a7ea03acc4882c5b2e0c328ed8e02c0425 100644 --- a/data_chain/apps/base/task/worker/generate_dataset_worker.py +++ b/data_chain/apps/base/task/worker/generate_dataset_worker.py @@ -132,7 +132,9 @@ class GenerateDataSetWorker(BaseWorker): chunk_cnt = len(chunk_index_list) division = data_cnt // chunk_cnt remainder = data_cnt % chunk_cnt + logging.error(f"数据集总条目 {dataset_entity.data_cnt}, 分块数量: {chunk_cnt}, 每块数据量: {division}, 余数: {remainder}") index = 0 + d_index = 0 random.shuffle(doc_chunks) with open(config['PROMPT_PATH'], 'r', encoding='utf-8') as f: prompt_dict = yaml.load(f, Loader=yaml.SafeLoader) @@ -140,6 +142,7 @@ class GenerateDataSetWorker(BaseWorker): answer_generate_prompt_template = prompt_dict.get('GENERATE_ANSWER_FROM_QUESTION_AND_CONTENT_PROMPT', '') cal_qa_score_prompt_template = prompt_dict.get('CAL_QA_SCORE_PROMPT', '') dataset_score = 0 + logging.error(f"{chunk_index_list}") for i in range(len(doc_chunks)): doc_chunk = doc_chunks[i] for j in range(len(doc_chunk.chunks)): @@ -147,33 +150,34 @@ class GenerateDataSetWorker(BaseWorker): index += 1 continue index += 1 + d_index += 1 chunk = doc_chunk.chunks[j].text if dataset_entity.is_chunk_related: l = j-1 r = j+1 tokens_sub = 0 while TokenTool.get_tokens(chunk) < llm.max_tokens: - if l < 0 and r >= len(doc_chunks): + if l < 0 and r >= len(doc_chunk.chunks): break if tokens_sub > 0: if l >= 0: - tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[l].text) - chunk = doc_chunks[i].chunks[l].text+chunk + tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text) + chunk = doc_chunk.chunks[l].text+chunk l -= 1 else: - tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[r].text) - chunk += doc_chunks[i].chunks[r].text + tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text) + chunk += doc_chunk.chunks[r].text r += 1 else: - if r < len(doc_chunks): - tokens_sub += TokenTool.get_tokens(doc_chunks[i].chunks[r].text) - chunk += doc_chunks[i].chunks[r].text + if r < len(doc_chunk.chunks): + tokens_sub += TokenTool.get_tokens(doc_chunk.chunks[r].text) + chunk += doc_chunk.chunks[r].text r += 1 else: - tokens_sub -= TokenTool.get_tokens(doc_chunks[i].chunks[l].text) - chunk = doc_chunks[i].chunks[l].text+chunk + tokens_sub -= TokenTool.get_tokens(doc_chunk.chunks[l].text) + chunk = doc_chunk.chunks[l].text+chunk l -= 1 - qa_cnt = division+(index < remainder) + qa_cnt = division+(d_index <= remainder) qs = [] answers = [] rd = 5 @@ -181,7 +185,7 @@ class GenerateDataSetWorker(BaseWorker): rd -= 1 try: sys_call = q_generate_prompt_template.format( - k=qa_cnt, + k=qa_cnt-len(qs), content=TokenTool.get_k_tokens_words_from_content(chunk, llm.max_tokens) ) usr_call = '请输出问题的列表' @@ -191,6 +195,7 @@ class GenerateDataSetWorker(BaseWorker): err = f"[GenerateDataSetWorker] 生成问题失败,错误信息: {e}" logging.exception(err) continue + sub_qs = 
sub_qs[:qa_cnt-len(qs)] sub_answers = [] try: for q in sub_qs: diff --git a/data_chain/apps/base/task/worker/parse_document_worker.py b/data_chain/apps/base/task/worker/parse_document_worker.py index 731a067f4ea72b31b5cf291721f42b1bec3a3e58..b6373c0f7a8d8589d6f07217f3e2f4e72c161c46 100644 --- a/data_chain/apps/base/task/worker/parse_document_worker.py +++ b/data_chain/apps/base/task/worker/parse_document_worker.py @@ -273,7 +273,7 @@ class ParseDocumentWorker(BaseWorker): node.content = await OcrTool.image_to_text(img_np, image_related_text, llm) node.text_feature = node.content except Exception as e: - err = f"[ParseDocumentWorker] OCR失败,doc_id: {node.doc_id}, error: {e}" + err = f"[ParseDocumentWorker] OCR失败 error: {e}" logging.exception(err) continue @@ -351,6 +351,7 @@ class ParseDocumentWorker(BaseWorker): link_nodes=[] ) nodes.append(tmp_node) + parse_result.nodes = nodes @staticmethod async def push_up_words_feature(parse_result: ParseResult, llm: LLM = None) -> None: diff --git a/data_chain/apps/router/other.py b/data_chain/apps/router/other.py index 3085e9a009d3a50fc8468a501903ae4cb4ad7093..24f9d001a069d4bc8fe4059392c7a5f78e21d1bd 100644 --- a/data_chain/apps/router/other.py +++ b/data_chain/apps/router/other.py @@ -52,7 +52,7 @@ async def list_llms_by_user_sub( @router.get('/embedding', response_model=ListEmbeddingResponse, dependencies=[Depends(verify_user)]) async def list_embeddings(): - embeddings = [embedding.value for embedding in Embedding] + embeddings = [config['EMBEDDING_MODEL_NAME']] return ListEmbeddingResponse(result=embeddings) diff --git a/data_chain/apps/service/chunk_service.py b/data_chain/apps/service/chunk_service.py index e93d6ccc1d41c59f94c7eb133ec229e34d12d282..1e7b26d7972e19dc7574ebac402dbab2254ac1a4 100644 --- a/data_chain/apps/service/chunk_service.py +++ b/data_chain/apps/service/chunk_service.py @@ -80,17 +80,6 @@ class ChunkService: logging.error("[ChunkService] 搜索分片,查询条件: %s", req) chunk_entities = [] for kb_id in req.kb_ids: - if kb_id != DEFAULT_KNOWLEDGE_BASE_ID: - try: - if not (await KnowledgeBaseService.validate_user_action_to_knowledge_base( - user_sub, kb_id, action)): - err = f"用户没有权限访问知识库中的块,知识库ID: {kb_id}" - logging.error("[ChunkService] %s", err) - continue - except Exception as e: - err = f"验证用户对知识库的操作权限失败,error: {e}" - logging.exception(err) - continue try: chunk_entities += await BaseSearcher.search(req.search_method.value, kb_id, req.query, 2*req.top_k, req.doc_ids, req.banned_ids) except Exception as e: diff --git a/data_chain/common/prompt.yaml b/data_chain/common/prompt.yaml index 9ce98d6098b8b0178a29f78c7d6e7f1cad6f44a5..611949edaafa4c7f0043e7a11d7d4fe0491dfe8f 100644 --- a/data_chain/common/prompt.yaml +++ b/data_chain/common/prompt.yaml @@ -351,4 +351,22 @@ CHUNK_QUERY_MATCH_PROMPT: '你是一个文本分析专家,你的任务是根 下面是给出的片段和问题: 片段:{chunk} 问题:{question} + ' +QUERY_EXTEND_PROMPT: '你是一个问题扩写专家,你的任务是根据给出的问题扩写{k}个问题 + 注意: + #01 扩写的问题的内容必须来源于原问题中的内容 + #02 扩写的问题长度不超过50个字 + #03 可以通过近义词替换 问题内词序交换 修改英文大小写等方式来改写问题 + #04 请仅输出扩写的问题列表,不要输出其他内容 + 例子: + 输入:openEuler是什么操作系统? 
+ 输出: + [ + \"openEuler是一个什么样的操作系统?\", + \"openEuler操作系统的特点是什么?\", + \"openEuler操作系统有哪些功能?\", + \"openEuler操作系统的优势是什么?\" + ] + 下面是给出的问题: + {question} ' \ No newline at end of file diff --git a/data_chain/entities/enum.py b/data_chain/entities/enum.py index d59ce12dfda8db73c2d5d9a704e95e65c2e75e8f..227ea1d0578a80b4d2976e7435eeb8f39a19ab79 100644 --- a/data_chain/entities/enum.py +++ b/data_chain/entities/enum.py @@ -160,6 +160,8 @@ class SearchMethod(str, Enum): DOC2CHUNK = "doc2chunk" DOC2CHUNK_BFS = "doc2chunk_bfs" ENHANCED_BY_LLM = "enhanced_by_llm" + QUERY_EXTEND = "query_extend" + DYNAMIC_WEIGHTED_KEYWORD = "dynamic_weighted_keyword" class TaskType(str, Enum): diff --git a/data_chain/manager/chunk_manager.py b/data_chain/manager/chunk_manager.py index 88e0dc2779efae2e658afb7d606dc06549bdf654..53e210c30772719dba4c5f7d773b169bbcef2419 100644 --- a/data_chain/manager/chunk_manager.py +++ b/data_chain/manager/chunk_manager.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -from sqlalchemy import select, update, func, text, or_, and_ +from sqlalchemy import select, update, func, text, or_, and_, Float, literal_column from typing import List, Tuple, Dict, Optional import uuid from data_chain.entities.enum import DocumentStatus, ChunkStatus, Tokenizer @@ -172,38 +172,50 @@ class ChunkManager(): """根据知识库ID和向量查询文档解析结果""" try: async with await DataBase.get_session() as session: + # 计算相似度分数 + similarity_score = ChunkEntity.text_vector.cosine_distance(vector).label("similarity_score") + + # 构建基础查询条件 stmt = ( - select(ChunkEntity) + select(ChunkEntity, similarity_score) .join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id ) .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) + .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) - .order_by(ChunkEntity.text_vector.cosine_distance(vector).desc()) - .limit(top_k) ) + + # 添加可选条件 if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + + # 应用排序条件 + stmt = stmt.order_by(similarity_score.desc()) + stmt = stmt.limit(top_k) + + # 执行最终查询 result = await session.execute(stmt) chunk_entities = result.scalars().all() + return chunk_entities except Exception as e: err = "根据知识库ID和向量查询文档解析结果失败" logging.exception("[ChunkManager] %s", err) - raise e + return [] - @staticmethod async def get_top_k_chunk_by_kb_id_keyword( - kb_id: uuid.UUID, query: str, top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], + kb_id: uuid.UUID, query: str, + top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: - """根据知识库ID和关键词查询文档解析结果""" + """根据知识库ID和向量查询文档解析结果""" try: async with await DataBase.get_session() as session: kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) @@ -219,37 +231,138 @@ class ChunkManager(): tokenizer = 'chparser' else: tokenizer = 'zhparser' + + # 计算相似度分数并选择它 + similarity_score = func.ts_rank_cd( + func.to_tsvector(tokenizer, ChunkEntity.text), + func.plainto_tsquery(tokenizer, query) + ).label("similarity_score") + stmt = ( - select(ChunkEntity) + select(ChunkEntity, similarity_score) 
.join(DocumentEntity, DocumentEntity.id == ChunkEntity.doc_id ) + .where(similarity_score > 0) .where(DocumentEntity.enabled == True) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(ChunkEntity.kb_id == kb_id) + .where(ChunkEntity.enabled == True) .where(ChunkEntity.status != ChunkStatus.DELETED.value) .where(ChunkEntity.id.notin_(banned_ids)) - .order_by( - func.ts_rank_cd( - func.to_tsvector(tokenizer, ChunkEntity.text), - func.plainto_tsquery(tokenizer, query) - ).desc() - ) - .limit(top_k) ) + if doc_ids is not None: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) if chunk_to_type is not None: stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) if pre_ids is not None: stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + + # 按相似度分数排序 + stmt = stmt.order_by(similarity_score.desc()) + stmt = stmt.limit(top_k) + + # 执行最终查询 result = await session.execute(stmt) chunk_entities = result.scalars().all() + return chunk_entities except Exception as e: - err = "根据知识库ID和关键词查询文档解析结果失败" + err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" logging.exception("[ChunkManager] %s", err) - raise e + return [] + + async def get_top_k_chunk_by_kb_id_dynamic_weighted_keyword( + kb_id: uuid.UUID, keywords: List[str], weights: List[float], + top_k: int, doc_ids: list[uuid.UUID] = None, banned_ids: list[uuid.UUID] = [], + chunk_to_type: str = None, pre_ids: list[uuid.UUID] = None) -> List[ChunkEntity]: + """根据知识库ID和关键词和关键词权重查询文档解析结果""" + try: + async with await DataBase.get_session() as session: + kb_entity = await KnowledgeBaseManager.get_knowledge_base_by_kb_id(kb_id) + if kb_entity.tokenizer == Tokenizer.ZH.value: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + elif kb_entity.tokenizer == Tokenizer.EN.value: + tokenizer = 'english' + else: + if config['DATABASE_TYPE'].lower() == 'opengauss': + tokenizer = 'chparser' + else: + tokenizer = 'zhparser' + + # 构建VALUES子句的参数 + params = {} + values_clause = [] + + for i, (term, weight) in enumerate(zip(keywords, weights)): + # 使用单独的参数名,避免与类型转换冲突 + params[f"term_{i}"] = term + params[f"weight_{i}"] = weight + # 在VALUES子句中使用类型转换函数 + values_clause.append(f"(CAST(:term_{i} AS TEXT), CAST(:weight_{i} AS FLOAT8))") + + # 构建VALUES子句 + values_text = f"(VALUES {', '.join(values_clause)}) AS t(term, weight)" + + # 创建weighted_terms CTE + weighted_terms = ( + select( + literal_column("t.term").label("term"), + literal_column("t.weight").cast(Float).label("weight") + ) + .select_from(text(values_text)) + .cte("weighted_terms") + ) + + # 计算相似度得分 + similarity_score = func.sum( + func.ts_rank_cd( + func.to_tsvector(tokenizer, ChunkEntity.text), + func.to_tsquery(tokenizer, weighted_terms.c.term) + ) * weighted_terms.c.weight + ).label("similarity_score") + + stmt = ( + select(ChunkEntity, similarity_score) + .join(DocumentEntity, + DocumentEntity.id == ChunkEntity.doc_id + ) + .where(DocumentEntity.enabled == True) + .where(DocumentEntity.status != DocumentStatus.DELETED.value) + .where(ChunkEntity.kb_id == kb_id) + .where(ChunkEntity.enabled == True) + .where(ChunkEntity.status != ChunkStatus.DELETED.value) + .where(ChunkEntity.id.notin_(banned_ids)) + ) + # 添加 GROUP BY 子句,按 ChunkEntity.id 分组 + stmt = stmt.group_by(ChunkEntity.id) + + stmt.having(similarity_score > 0) + + if doc_ids is not None: + stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) + if chunk_to_type is not None: + stmt = stmt.where(ChunkEntity.parse_topology_type == chunk_to_type) + if pre_ids is not None: 
+ stmt = stmt.where(ChunkEntity.pre_id_in_parse_topology.in_(pre_ids)) + + # 按相似度分数排序 + stmt = stmt.order_by(similarity_score.desc()) + stmt = stmt.limit(top_k) + + # 执行最终查询 + result = await session.execute(stmt, params=params) + chunk_entities = result.scalars().all() + + return chunk_entities + except Exception as e: + err = f"根据知识库ID和向量查询文档解析结果失败: {str(e)}" + logging.exception("[ChunkManager] %s", err) + return [] @staticmethod async def fetch_surrounding_chunk_by_doc_id_and_global_offset( diff --git a/data_chain/manager/document_manager.py b/data_chain/manager/document_manager.py index 17ac0a70c9223c3f0f55dab13d4260b91ae611f3..29b6cab863600646ad5bb83b8f6c3ef7c6184fe5 100644 --- a/data_chain/manager/document_manager.py +++ b/data_chain/manager/document_manager.py @@ -50,19 +50,24 @@ class DocumentManager(): """根据知识库ID和向量获取前K个文档""" try: async with await DataBase.get_session() as session: + similarity_score = DocumentEntity.abstract.cosine_distance(vector).label("similarity_score") stmt = ( - select(DocumentEntity) + select(DocumentEntity, similarity_score) + .where(similarity_score > 0) .where(DocumentEntity.kb_id == kb_id) .where(DocumentEntity.id.notin_(banned_ids)) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.enabled == True) .where(DocumentEntity.abstract_vector.cosine_distance(vector).desc()) - .order_by(DocumentEntity.abstract_vector.cosine_distance(vector).desc()) - .limit(top_k) ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) + stmt = stmt.order_by( + similarity_score.desc() + ) + result = await session.execute(stmt) + document_entities = result.scalars().all() return document_entities except Exception as e: @@ -85,20 +90,24 @@ class DocumentManager(): tokenizer = 'zhparser' elif kb_entity.tokenizer == Tokenizer.EN.value: tokenizer = 'english' + similarity_score = func.ts_rank_cd( + func.to_tsvector(tokenizer, DocumentEntity.abstract), + func.plainto_tsquery(tokenizer, query) + ).label("similarity_score") stmt = ( - select(DocumentEntity) + select(DocumentEntity, similarity_score) .where(DocumentEntity.kb_id == kb_id) .where(DocumentEntity.id.notin_(banned_ids)) .where(DocumentEntity.status != DocumentStatus.DELETED.value) .where(DocumentEntity.enabled == True) - .where(func.ts_rank_cd( - func.to_tsvector(tokenizer, DocumentEntity.abstract), - func.plainto_tsquery(tokenizer, query) - ).desc()) - .limit(top_k) ) if doc_ids: stmt = stmt.where(DocumentEntity.id.in_(doc_ids)) + + stmt = stmt.order_by( + similarity_score.desc() + ) + stmt = stmt.limit(top_k) result = await session.execute(stmt) document_entities = result.scalars().all() return document_entities diff --git a/data_chain/manager/knowledge_manager.py b/data_chain/manager/knowledge_manager.py index 0c109c49f0a01927be2b047913c2e8aaf9b865ac..ae51c6f6e77b9db76e1d5c79ac8d8febc33c5b52 100644 --- a/data_chain/manager/knowledge_manager.py +++ b/data_chain/manager/knowledge_manager.py @@ -21,7 +21,7 @@ class KnowledgeBaseManager(): return knowledge_base_entity except Exception as e: err = "添加知识库失败" - logging.exception("[KnowledgeBaseManager] %s", err) + logging.error("[KnowledgeBaseManager] %s", err) @staticmethod async def get_knowledge_base_by_kb_id(kb_id: uuid.UUID) -> KnowledgeBaseEntity: diff --git a/data_chain/parser/handler/html_parser.py b/data_chain/parser/handler/html_parser.py index 6c4e06031493268ce1fd8e7581741b9b4f28d3df..9a98c196a6dffcd5d2ce83f0f4cfb825e1f6f3c0 100644 --- a/data_chain/parser/handler/html_parser.py +++ b/data_chain/parser/handler/html_parser.py @@ 
-45,26 +45,29 @@ class HTMLParser(BaseParser): # 获取body元素作为起点(如果是完整HTML文档) root = soup.body if soup.body else soup - + valid_headers = ["h1", "h2", "h3", "h4", "h5", "h6"] # 获取当前层级的直接子元素 current_level_elements = list(root.children) subtree = [] while current_level_elements: element = current_level_elements.pop(0) if not isinstance(element, Tag): - node = ParseNode( - id=uuid.uuid4(), - lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, - content=element.get_text(strip=True), - type=ChunkType.TEXT, - link_nodes=[] - ) - subtree.append(node) + try: + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=ChunkParseTopology.TREELEAF, + content=element.get_text(strip=True), + type=ChunkType.TEXT, + link_nodes=[] + ) + subtree.append(node) + except Exception as e: + logging.warning(f"[HTMLParser] 处理非标签元素失败: {e}") continue if element.name == 'div' or element.name == 'head' or element.name == 'header' or \ element.name == 'body' or element.name == 'section' or element.name == 'article' or \ - element.name == 'nav' or element.name == 'main' or element.name == 'ol': + element.name == 'nav' or element.name == 'main' or element.name == 'p' or element.name == 'ol': # 处理div内部元素 inner_html = ''.join(str(child) for child in element.children) child_subtree = await HTMLParser.build_subtree(inner_html, current_level+1) @@ -91,7 +94,7 @@ class HTMLParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name.startswith('h'): + elif element.name in valid_headers: try: level = int(element.name[1:]) except Exception: @@ -101,7 +104,7 @@ class HTMLParser(BaseParser): content_elements = [] while current_level_elements: sibling = current_level_elements[0] - if sibling.name and sibling.name.startswith('h'): + if sibling.name and sibling.name in valid_headers: next_level = int(sibling.name[1:]) else: next_level = level + 1 @@ -138,7 +141,7 @@ class HTMLParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name == 'p' or element.name == 'title' or element.name == 'span' or element.name == 'pre'\ + elif element.name == 'title' or element.name == 'span' or element.name == 'pre'\ or element.name == 'li': para_text = element.get_text().strip() if para_text: diff --git a/data_chain/parser/handler/md_parser.py b/data_chain/parser/handler/md_parser.py index e226836ae722cc610ddd87b7d7bda5811ed8e35b..97ff373160e3f7f5a6ae58813075476e3e7060c3 100644 --- a/data_chain/parser/handler/md_parser.py +++ b/data_chain/parser/handler/md_parser.py @@ -61,20 +61,25 @@ class MdParser(BaseParser): current_level_elements = list(root.children) # 过滤掉非标签节点(如文本节点) subtree = [] + valid_headers = ["h1", "h2", "h3", "h4", "h5", "h6"] while current_level_elements: element = current_level_elements.pop(0) if not isinstance(element, Tag): - node = ParseNode( - id=uuid.uuid4(), - lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, - content=element.get_text(strip=True), - type=ChunkType.TEXT, - link_nodes=[] - ) - subtree.append(node) + try: + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=ChunkParseTopology.TREELEAF, + content=element.get_text(strip=True), + type=ChunkType.TEXT, + link_nodes=[] + ) + subtree.append(node) + except Exception as e: + logging.error(f"[MdParser] 处理非标签节点失败: {e}") + continue - if element.name == 'ol' or element.name == 'hr': + if element.name == 'p' or element.name == 'ol' or element.name == 'hr': inner_html = ''.join(str(child) for child in element.children) child_subtree = await 
MdParser.build_subtree(inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( @@ -100,7 +105,7 @@ class MdParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name.startswith('h'): + elif element.name in valid_headers: try: level = int(element.name[1:]) except Exception: @@ -110,7 +115,7 @@ class MdParser(BaseParser): content_elements = [] while current_level_elements: sibling = current_level_elements[0] - if sibling.name and sibling.name.startswith('h'): + if sibling.name and sibling.name in valid_headers: next_level = int(sibling.name[1:]) else: next_level = level + 1 @@ -137,7 +142,7 @@ class MdParser(BaseParser): ) subtree.append(node) elif element.name == 'code': - code_text = element.find('code').get_text() + code_text = element.get_text().strip() node = ParseNode( id=uuid.uuid4(), @@ -148,7 +153,7 @@ class MdParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name == 'p' or element.name == 'li': + elif element.name == 'li': para_text = element.get_text().strip() if para_text: node = ParseNode( @@ -224,4 +229,3 @@ class MdParser(BaseParser): nodes=nodes ) return parse_result - diff --git a/data_chain/parser/handler/md_zip_parser.py b/data_chain/parser/handler/md_zip_parser.py index bccdbb91f4f9be56beedfa1244deaf2732507781..3804f9fe81615ac31a134065a222264f76d16f93 100644 --- a/data_chain/parser/handler/md_zip_parser.py +++ b/data_chain/parser/handler/md_zip_parser.py @@ -2,6 +2,7 @@ import asyncio from bs4 import BeautifulSoup, Tag import markdown import os +import re import requests import uuid from data_chain.entities.enum import DocParseRelutTopology, ChunkParseTopology, ChunkType @@ -48,6 +49,14 @@ class MdZipParser(BaseParser): logging.warning(warining) return None else: + # 1. 处理反斜杠问题 - 替换所有反斜杠为正斜杠 + img_src = img_src.replace('\\', '/') + + # 2. 移除可能的相对路径前缀(如 ./ 或 ../) + img_src = re.sub(r'^\./+', '', img_src) + + # 3. 
规范化路径(处理多余的斜杠等) + img_src = os.path.normpath(img_src) img_path = os.path.join(base_dir, img_src) if os.path.exists(img_path): try: @@ -71,24 +80,29 @@ class MdZipParser(BaseParser): # 获取当前层级的直接子元素 current_level_elements = list(root.children) + valid_headers = ["h1", "h2", "h3", "h4", "h5", "h6"] # 过滤掉非标签节点(如文本节点) subtree = [] while current_level_elements: element = current_level_elements.pop(0) + if not isinstance(element, Tag): - node = ParseNode( - id=uuid.uuid4(), - lv=current_level, - parse_topology_type=ChunkParseTopology.TREELEAF, - content=element.get_text(strip=True), - type=ChunkType.TEXT, - link_nodes=[] - ) - subtree.append(node) + try: + node = ParseNode( + id=uuid.uuid4(), + lv=current_level, + parse_topology_type=ChunkParseTopology.TREELEAF, + content=element.get_text(strip=True), + type=ChunkType.TEXT, + link_nodes=[] + ) + subtree.append(node) + except Exception as e: + logging.error(f"[MdZipParser] 处理非标签节点失败: {e}") continue - if element.name == 'ol': + if element.name == 'p' or element.name == 'ol' or element.name == 'hr': inner_html = ''.join(str(child) for child in element.children) - child_subtree = await MdZipParser.build_subtree(inner_html, current_level+1) + child_subtree = await MdZipParser.build_subtree(file_path, inner_html, current_level+1) parse_topology_type = ChunkParseTopology.TREENORMAL if len( child_subtree) else ChunkParseTopology.TREELEAF if child_subtree: @@ -112,7 +126,7 @@ class MdZipParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name.startswith('h'): + elif element.name in valid_headers: try: level = int(element.name[1:]) except Exception: @@ -122,7 +136,7 @@ class MdZipParser(BaseParser): content_elements = [] while current_level_elements: sibling = current_level_elements[0] - if sibling.name and sibling.name.startswith('h'): + if sibling.name and sibling.name in valid_headers: next_level = int(sibling.name[1:]) else: next_level = level + 1 @@ -132,7 +146,7 @@ class MdZipParser(BaseParser): # 如果有内容,处理这些内容 if content_elements: content_html = ''.join(str(el) for el in content_elements) - child_subtree = await MdZipParser.build_subtree(content_html, level) + child_subtree = await MdZipParser.build_subtree(file_path, content_html, level) parse_topology_type = ChunkParseTopology.TREENORMAL else: child_subtree = [] @@ -149,7 +163,7 @@ class MdZipParser(BaseParser): ) subtree.append(node) elif element.name == 'code': - code_text = element.find('code').get_text() + code_text = element.get_text().strip() node = ParseNode( id=uuid.uuid4(), @@ -160,7 +174,7 @@ class MdZipParser(BaseParser): link_nodes=[] ) subtree.append(node) - elif element.name == 'p' or element.name == 'li': + elif element.name == 'li': para_text = element.get_text().strip() if para_text: node = ParseNode( @@ -212,6 +226,7 @@ class MdZipParser(BaseParser): @staticmethod async def markdown_to_tree(file_path: str, markdown_text: str) -> ParseNode: html = markdown.markdown(markdown_text, extensions=['tables']) + logging.error(html) root = ParseNode( id=uuid.uuid4(), title="", @@ -230,15 +245,23 @@ class MdZipParser(BaseParser): async def parser(file_path: str) -> ParseResult: target_file_path = os.path.join(os.path.dirname(file_path), 'temp') await ZipHandler.unzip_file(file_path, target_file_path) - markdown_file = [f for f in os.listdir(target_file_path) if f.endswith('.md')] - if not markdown_file: + # 递归查找markdown文件 + markdown_file_path_list = [] + for root, dirs, files in os.walk(target_file_path): + for file in files: + if file.endswith('.md'): + target_file_path 
= os.path.join(root, file) + markdown_file_path_list.append(target_file_path) + else: + continue + if not markdown_file_path_list: err = f"[MdZipParser] markdown文件不存在" logging.error(err) raise FileNotFoundError(err) - markdown_file_path = os.path.join(target_file_path, markdown_file[0]) if markdown_file else None + markdown_file_path = markdown_file_path_list[0] if markdown_file_path_list else None with open(markdown_file_path, 'r', encoding='utf-8', errors='ignore') as f: markdown_text = f.read() - nodes = await MdZipParser.markdown_to_tree(target_file_path, markdown_text) + nodes = await MdZipParser.markdown_to_tree(markdown_file_path, markdown_text) return ParseResult( parse_topology_type=DocParseRelutTopology.TREE, nodes=nodes diff --git a/data_chain/parser/tools/token_tool.py b/data_chain/parser/tools/token_tool.py index fb53d270188229d0abe9d4e8121371327069655f..9f585d48a7c0ad267a4c7b3ba72bd58c1d155f6b 100644 --- a/data_chain/parser/tools/token_tool.py +++ b/data_chain/parser/tools/token_tool.py @@ -24,6 +24,20 @@ class TokenTool: with open(stop_words_path, 'r', encoding='utf-8') as f: stopwords = set(line.strip() for line in f) + @staticmethod + def filter_stopwords(content: str) -> str: + """ + 过滤停用词 + """ + try: + words = TokenTool.split_words(content) + filtered_words = [word for word in words if word not in TokenTool.stopwords] + return ' '.join(filtered_words) + except Exception as e: + err = f"[TokenTool] 过滤停用词失败 {e}" + logging.exception("[TokenTool] %s", err) + return content + @staticmethod def get_leave_tokens_from_content_len(content: str) -> int: """ @@ -146,6 +160,19 @@ class TokenTool: logging.exception("[TokenTool] %s", err) return [] + @staticmethod + def get_top_k_keywords_and_weights(content: str, k=10) -> list: + try: + # 使用jieba提取关键词 + keyword_weight_list = extract_tags(content, topK=k, withWeight=True) + keywords = [keyword for keyword, weight in keyword_weight_list] + weights = [weight for keyword, weight in keyword_weight_list] + return keywords, weights + except Exception as e: + err = f"[TokenTool] 获取关键词失败 {e}" + logging.exception("[TokenTool] %s", err) + return [] + @staticmethod def compress_tokens(content: str, k: int = None) -> str: try: diff --git a/data_chain/rag/dynamic_weighted_keyword_searcher.py b/data_chain/rag/dynamic_weighted_keyword_searcher.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8ac64780eddcde3c0c0f81d9a6242a76532a91 --- /dev/null +++ b/data_chain/rag/dynamic_weighted_keyword_searcher.py @@ -0,0 +1,38 @@ +import asyncio +import uuid +from pydantic import BaseModel, Field +import random +from data_chain.logger.logger import logger as logging +from data_chain.stores.database.database import ChunkEntity +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.rag.base_searcher import BaseSearcher +from data_chain.entities.enum import SearchMethod + + +class DynamicWeightKeyWordSearcher(BaseSearcher): + """ + 动态加权关键词检索 + """ + name = SearchMethod.DYNAMIC_WEIGHTED_KEYWORD.value + + @staticmethod + async def search( + query: str, kb_id: uuid.UUID, top_k: int = 5, doc_ids: list[uuid.UUID] = None, + banned_ids: list[uuid.UUID] = [] + ) -> list[ChunkEntity]: + """ + 向量检索 + :param query: 查询 + :param top_k: 返回的结果数量 + :return: 检索结果 + """ + try: + query_filtered = TokenTool.filter_stopwords(query) + keywords, weights = TokenTool.get_top_k_keywords_and_weights(query_filtered) + chunk_entities = await 
ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k, doc_ids, banned_ids) + except Exception as e: + err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" + logging.exception(err) + return [] + return chunk_entities diff --git a/data_chain/rag/keyword_and_vector_searcher.py b/data_chain/rag/keyword_and_vector_searcher.py index d3753029c57c1ae2ae0a9f53632c0f543e445856..0237eaa61e30fb2168a19df12874e4ee090d0d77 100644 --- a/data_chain/rag/keyword_and_vector_searcher.py +++ b/data_chain/rag/keyword_and_vector_searcher.py @@ -30,7 +30,10 @@ class KeywordVectorSearcher(BaseSearcher): """ vector = await Embedding.vectorize_embedding(query) try: - chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, top_k//2, doc_ids, banned_ids) + query_filtered = TokenTool.filter_stopwords(query) + keywords, weights = TokenTool.get_top_k_keywords_and_weights(query_filtered) + logging.error(f"[KeywordVectorSearcher] keywords: {keywords}, weights: {weights}") + chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_dynamic_weighted_keyword(kb_id, keywords, weights, top_k//2, doc_ids, banned_ids) chunk_ids = [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] chunk_entities_get_by_vector = [] for _ in range(3): @@ -42,6 +45,9 @@ class KeywordVectorSearcher(BaseSearcher): logging.error(err) continue chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_vector + for chunk_entity in chunk_entities: + logging.error( + f"[KeywordVectorSearcher] chunk_entity: {chunk_entity.id}, text: {chunk_entity.text[:100]}...") except Exception as e: err = f"[KeywordVectorSearcher] 关键词向量检索失败,error: {e}" logging.exception(err) diff --git a/data_chain/rag/query_extend_searcher.py b/data_chain/rag/query_extend_searcher.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd8bf559a8feef96759f988c37e396c6af15658 --- /dev/null +++ b/data_chain/rag/query_extend_searcher.py @@ -0,0 +1,73 @@ +import asyncio +import uuid +import yaml +from pydantic import BaseModel, Field +import random +import json +from data_chain.logger.logger import logger as logging +from data_chain.stores.database.database import ChunkEntity +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.manager.chunk_manager import ChunkManager +from data_chain.rag.base_searcher import BaseSearcher +from data_chain.embedding.embedding import Embedding +from data_chain.entities.enum import SearchMethod +from data_chain.parser.tools.token_tool import TokenTool +from data_chain.llm.llm import LLM +from data_chain.config.config import config + + +class QueryExtendSearcher(BaseSearcher): + """ + 基于查询扩展的搜索 + """ + name = SearchMethod.QUERY_EXTEND.value + + @staticmethod + async def search( + query: str, kb_id: uuid.UUID, top_k: int = 5, doc_ids: list[uuid.UUID] = None, + banned_ids: list[uuid.UUID] = [] + ) -> list[ChunkEntity]: + """ + 向量检索 + :param query: 查询 + :param top_k: 返回的结果数量 + :return: 检索结果 + """ + with open('./data_chain/common/prompt.yaml', 'r', encoding='utf-8') as f: + prompt_dict = yaml.safe_load(f) + prompt_template = prompt_dict['QUERY_EXTEND_PROMPT'] + chunk_entities = [] + llm = LLM( + openai_api_key=config['OPENAI_API_KEY'], + openai_api_base=config['OPENAI_API_BASE'], + model_name=config['MODEL_NAME'], + max_tokens=config['MAX_TOKENS'], + ) + sys_call = prompt_template.format(k=2*top_k, question=query) + user_call = "请输出扩写的问题列表" + queries = await llm.nostream([], sys_call, user_call) + 
try: + queries = json.loads(queries) + queries += [query] + except Exception as e: + logging.error(f"[QueryExtendSearcher] JSON解析失败,error: {e}") + queries = [query] + queries = list(set(queries)) + chunk_entities = [] + for query in queries: + vector = await Embedding.vectorize_embedding(query) + chunk_entities_get_by_keyword = await ChunkManager.get_top_k_chunk_by_kb_id_keyword(kb_id, query, 2, doc_ids, banned_ids) + banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_keyword] + chunk_entities_get_by_vector = [] + for _ in range(3): + try: + chunk_entities_get_by_vector = await asyncio.wait_for(ChunkManager.get_top_k_chunk_by_kb_id_vector(kb_id, vector, top_k-len(chunk_entities_get_by_keyword), doc_ids, banned_ids), timeout=3) + break + except Exception as e: + err = f"[KeywordVectorSearcher] 向量检索失败,error: {e}" + logging.error(err) + continue + banned_ids += [chunk_entity.id for chunk_entity in chunk_entities_get_by_vector] + sub_chunk_entities = chunk_entities_get_by_keyword + chunk_entities_get_by_vector + chunk_entities += sub_chunk_entities + return chunk_entities
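
For context on the new DYNAMIC_WEIGHTED_KEYWORD retrieval path added in this patch, the sketch below shows its intended end-to-end shape: jieba TF-IDF keywords and weights (as in TokenTool.get_top_k_keywords_and_weights) feed a per-keyword ts_rank_cd score that is scaled by the keyword's weight and summed per chunk, which is what the VALUES/CTE construction in get_top_k_chunk_by_kb_id_dynamic_weighted_keyword builds through SQLAlchemy. This is a minimal standalone sketch, assuming jieba is installed and a PostgreSQL zhparser text-search configuration; the table and column names ("chunk", "text") and the helper names are illustrative only and not part of the patch.

    # Reviewer sketch (not part of the patch): weighted-keyword ranking, standalone.
    # Assumes jieba; "chunk"/"text" mirror ChunkEntity only for illustration.
    import jieba.analyse


    def extract_weighted_keywords(query: str, k: int = 10) -> tuple[list[str], list[float]]:
        """Mirror of TokenTool.get_top_k_keywords_and_weights: TF-IDF keywords plus weights."""
        pairs = jieba.analyse.extract_tags(query, topK=k, withWeight=True)
        keywords = [term for term, _ in pairs]
        weights = [weight for _, weight in pairs]
        return keywords, weights


    def build_weighted_rank_sql(keywords: list[str], weights: list[float],
                                tokenizer: str = "zhparser", top_k: int = 5) -> tuple[str, dict]:
        """Roughly the SQL the new ChunkManager method is meant to emit: each keyword
        contributes ts_rank_cd(...) scaled by its jieba weight, summed per chunk,
        kept only when the total score is positive, and sorted descending."""
        params, values = {}, []
        for i, (term, weight) in enumerate(zip(keywords, weights)):
            # Separate named parameters per term/weight, cast inside VALUES,
            # matching the parameter-binding style used in the patch.
            params[f"term_{i}"] = term
            params[f"weight_{i}"] = weight
            values.append(f"(CAST(:term_{i} AS TEXT), CAST(:weight_{i} AS FLOAT8))")
        sql = f"""
        SELECT c.id,
               SUM(ts_rank_cd(to_tsvector('{tokenizer}', c.text),
                              to_tsquery('{tokenizer}', t.term)) * t.weight) AS similarity_score
        FROM chunk AS c
        CROSS JOIN (VALUES {', '.join(values)}) AS t(term, weight)
        GROUP BY c.id
        HAVING SUM(ts_rank_cd(to_tsvector('{tokenizer}', c.text),
                              to_tsquery('{tokenizer}', t.term)) * t.weight) > 0
        ORDER BY similarity_score DESC
        LIMIT {top_k}
        """
        return sql, params


    if __name__ == "__main__":
        keywords, weights = extract_weighted_keywords("openEuler 支持哪些处理器架构")
        sql, params = build_weighted_rank_sql(keywords, weights)
        print(keywords, weights)
        print(sql)

Terms that are absent from a chunk contribute a ts_rank_cd of 0, so the weighted sum naturally favors chunks matching the higher-weight keywords; the HAVING clause then drops chunks with no match at all.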