From b63ae1c8dd5c7a32113d733abd7d42a250111853 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 5 Mar 2025 11:30:07 +0800 Subject: [PATCH 1/5] =?UTF-8?q?refactor(parser):=20=E9=87=8D=E6=9E=84=20PD?= =?UTF-8?q?F=20=E8=A7=A3=E6=9E=90=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化了 PDF 文件的打开和解析流程 - 改进了文本、表格和图像的提取方法 - 完善了附近文本的查找算法 - 重构了列表合并逻辑,提高合并效率 - 添加了类型注解和函数文档字符串,提升代码可读性 --- data_chain/parser/handler/base_parser.py | 17 +- data_chain/parser/handler/pdf_parser.py | 310 ++++++++++++++--------- 2 files changed, 193 insertions(+), 134 deletions(-) diff --git a/data_chain/parser/handler/base_parser.py b/data_chain/parser/handler/base_parser.py index 6de4b39..6147ddf 100644 --- a/data_chain/parser/handler/base_parser.py +++ b/data_chain/parser/handler/base_parser.py @@ -1,22 +1,17 @@ +import json import os import uuid -import json -from data_chain.logger.logger import logger as logging -from pandas import DataFrame + from docx.table import Table as DocxTable +from pandas import DataFrame from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity -import secrets -import shutil -from data_chain.manager.document_manager import DocumentManager -from data_chain.manager.model_manager import ModelManager -from data_chain.parser.tools.split import split_tools -from data_chain.models.constant import OssConstant + from data_chain.apps.base.model.llm import LLM from data_chain.apps.base.security.security import Security -from data_chain.stores.minio.minio import MinIO +from data_chain.logger.logger import logger as logging from data_chain.models.constant import OssConstant - +from data_chain.parser.tools.split import split_tools # TODO chunk和chunk_link可以封装成类 diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index 659442a..f151802 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -1,11 +1,12 @@ import io + import fitz -import uuid import numpy as np -from data_chain.logger.logger import logger as logging from PIL import Image -from data_chain.parser.tools.ocr import BaseOCR + +from data_chain.logger.logger import logger as logging from data_chain.parser.handler.base_parser import BaseService +from data_chain.parser.tools.ocr import BaseOCR class PdfService(BaseService): @@ -13,123 +14,169 @@ class PdfService(BaseService): def __init__(self): super().__init__() self.image_model = None - self.page_numbers = None - self.pdf = None + self.total_pages = None + self.pdf_document = None - def open_pdf(self, file_path): + def open_pdf(self, file_path: str) -> None: + """打开PDF文件并初始化文档对象 + + :param file_path: PDF文件的路径 + :type file_path: str + """ try: - self.pdf = fitz.open(file_path) - self.page_numbers = len(self.pdf) + self.pdf_document = fitz.open(file_path) + self.total_pages = len(self.pdf_document) except Exception as e: logging.error(f"Error opening file {file_path} :{e}") raise e - def extract_text(self, page_number): - page = self.pdf.load_page(page_number) - lines = [] + def extract_text(self, page_number: int) -> list[dict]: + """从PDF页面中提取文本块及其位置信息 - text_blocks = page.get_text('blocks') + :param page_number: PDF页面的页码 + :type page_number: int + :return: 包含文本块及其位置信息的列表 + :rtype: list[dict] + """ + if self.pdf_document is None: + return [] + page = self.pdf_document.load_page(page_number) + text_lines = [] + + text_blocks = page.get_text("blocks") for block in text_blocks: if block[6] == 0: # 确保是文本块 text = block[4].strip() - rect = block[:4] # (x0, y0, x1, y1) + bounding_box = block[:4] # (x0, y0, x1, y1) if text: - lines.append({'bbox': rect, - 'text': text, - 'type': 'para', - }) - return sorted_lines + text_lines.append({"bbox": bounding_box, + "text": text, + "type": "paragraph", + }) + sorted_text_lines = sorted(text_lines, key=lambda x: (x["bbox"][1], x["bbox"][0])) + return sorted_text_lines - def extract_table(self, page_number): - """ - 读取pdf中的列表 - :param page_number:pdf页码 - :返回 pdf中的列表的内容(pandas格式)和坐标(x0,y0,x1,y1) + def extract_table(self, page_number: int) -> list[dict]: + """从PDF页面中提取表格 + + :param page_number: PDF页面的页码 + :type page_number: int + :return: 包含表格内容(pandas格式)和边界框(x0, y0, x1, y1)的列表 + :rtype: list[dict] """ - page = self.pdf.load_page(page_number) - tabs = page.find_tables() - dfs = [] - for tab in tabs: - tab_bbox = fitz.Rect(tab.bbox) - page.add_redact_annot(tab.bbox) - df = tab.to_pandas() - lines = self.split_table(df) - for line in lines: - dfs.append({ - 'text': line, - 'bbox': tab_bbox, - 'type': 'table' - }) + if self.pdf_document is None: + return [] + page = self.pdf_document.load_page(page_number) + tables = page.find_tables() + table_data = [] + for table in tables: + table_bbox = fitz.Rect(table.bbox) + page.add_redact_annot(table.bbox) + table_df = table.to_pandas() + table_lines = self.split_table(table_df) + for line in table_lines: + table_data.extend([{ + "text": line, + "bbox": table_bbox, + "type": "table", + } for line in table_lines]) + page.apply_redactions() - return dfs + return table_data + + async def extract_image(self, page_number: int, text: list[dict]) -> tuple[list[dict], list[dict]]: + """提取图片并返回图片的识别结果和图片的id - async def extract_image(self, page_number, text): - page = self.pdf.load_page(page_number) + :param page_number: PDF页面的页码 + :type page_number: int + :param text: 从PDF中提取的文本块 + :type text: list[dict] + :return: 包含图片识别结果和图片块的列表 + :rtype: tuple[list[dict], list[dict]] + """ + if self.pdf_document is None: + return [], [] + page = self.pdf_document.load_page(page_number) + if page is None: + return [], [] image_list = page.get_images(full=True) - results = [] + image_results = [] image_chunks = [] - for img in image_list: - # 获取图像的xref - xref = img[0] - # 获取图像的base图像(如果存在) - base_image = self.pdf.extract_image(xref) - pos = page.get_image_rects(xref)[0] - # 获取图像的二进制数据 + for image_info in image_list: + # 获取图片的xref + xref = image_info[0] + # 提取基础图片(如果存在) + base_image = self.pdf_document.extract_image(xref) + position = page.get_image_rects(xref)[0] + # 获取图片的二进制数据 image_bytes = base_image["image"] - # 获取图像的扩展名 + # 获取图片的扩展名 image_ext = base_image["ext"] - # 获取图像的位置信息 - - bbox = (pos.x0, pos.y0, pos.x1, pos.y1) - near = self.find_near_words(bbox, text) + # 获取图片的边界框 + bounding_box = (position.x0, position.y0, position.x1, position.y1) + nearby_text = self.find_near_words(bounding_box, text) image = Image.open(io.BytesIO(image_bytes)) image_id = self.get_uuid() await self.insert_image_to_tmp_folder(image_bytes, image_id, image_ext) try: - img_np = np.array(image) + image_np = np.array(image) except Exception as e: logging.error(f"Error converting image to numpy array: {e}") continue - ocr_results = await self.image_model.run(img_np, text=near) + ocr_results = await self.image_model.run(image_np, text=nearby_text) - # 获取OCR + # 获取OCR结果 chunk_id = self.get_uuid() - results.append({ - 'type': 'image', - 'text': ocr_results, - 'bbox': bbox, - 'xref': xref, - 'id': chunk_id, + image_results.append({ + "type": "image", + "text": ocr_results, + "bbox": bounding_box, + "xref": xref, + "id": chunk_id, }) image_chunks.append({ - 'id': image_id, - 'chunk_id': chunk_id, - 'extension': image_ext, + "id": image_id, + "chunk_id": chunk_id, + "extension": image_ext, }) - return results, image_chunks + return image_results, image_chunks + + def extract_text_with_position(self, page_number: int) -> list[dict]: + """提取带有位置的文本块 - def extract_text_with_position(self, page_number): - """获取带坐标的文本块""" - page = self.pdf.load_page(page_number) + :param page_number: PDF页面的页码 + :type page_number: int + :return: 包含文本块及其位置信息的列表 + :rtype: list[dict] + """ + page = self.pdf_document.load_page(page_number) text_blocks = [] for block in page.get_text("dict")["blocks"]: if "lines" in block: # 确保是文本块 for line in block["lines"]: for span in line["spans"]: text_blocks.append({ - 'text': span['text'], - 'bbox': span['bbox'], # 文本的矩形区域 (x0, y0, x1, y1) - 'type': 'para' + "text": span["text"], + "bbox": span["bbox"], # 文本边界框 (x0, y0, x1, y1) + "type": "paragraph", }) return text_blocks - def find_near_words(self, bbox, texts): - """寻找相邻文本""" - image_x0, image_y0, image_x1, image_y1 = bbox + def find_near_words(self, bounding_box: tuple[float, float, float, float], texts: list[dict]) -> str: + """查找附近的文本 + + :param bounding_box: 图片的边界框 (x0, y0, x1, y1) + :type bounding_box: tuple[float, float, float, float] + :param texts: 文本块列表 + :type texts: list[dict] + :return: 附近的文本内容 + :rtype: str + """ + image_x0, image_y0, image_x1, image_y1 = bounding_box threshold = 100 image_x0 -= threshold image_y0 -= threshold @@ -137,53 +184,63 @@ class PdfService(BaseService): image_y1 += threshold line = "" for text in texts: - text_x0, text_y0, text_x1, text_y1 = text['bbox'] - text_content = text['text'] - # 左右相邻:水平距离小于等于阈值,且垂直方向有重叠 + text_x0, text_y0, text_x1, text_y1 = text["bbox"] + text_content = text["text"] + # 检查文本是否水平相邻 horizontally_adjacent = (text_x1 >= image_x0 - threshold and text_x0 <= image_x1 + threshold) - # 上下相邻:垂直距离小于等于阈值,且水平方向有重叠 + # 检查文本是否垂直相邻 vertically_adjacent = (text_y1 >= image_y0 - threshold and text_y0 <= image_y1 + threshold) - # 判断相交或相邻 + # 检查文本是否相交或相邻 if horizontally_adjacent and vertically_adjacent: line = line + text_content return line @staticmethod - def merge_list(list_a, list_b): - """ - 按照x0,y0,x1,y1合并list_a,list_b - :param - list_a:文字list - list_b:图像或者列表的list + def merge_list(text_list: list[dict], image_or_table_list: list[dict]) -> list[dict]: + """根据边界框合并文本列表和图片/表格列表 + + :param text_list: 文本块列表 + :type text_list: list[dict] + :param image_or_table_list: 图片或表格列表 + :type image_or_table_list: list[dict] + :return: 合并后的列表 + :rtype: list[dict] """ - if list_a is None: - return list_b - if list_b is None: - return list_a - len_b = len(list_b) - now_b = 0 + if text_list is None: + return image_or_table_list + if image_or_table_list is None: + return text_list + image_or_table_list_length = len(image_or_table_list) + current_index = 0 max_x = 0 - result_list = [] - - for part_a in list_a: - max_x = max(max_x, part_a['bbox'][2]) - if now_b < len_b: - part_b = list_b[now_b] - while now_b < len_b and part_b['bbox'][0] < max_x and part_b['bbox'][1] < part_a['bbox'][1]: - result_list.append(part_b) - now_b += 1 - if now_b < len_b: - part_b = list_b[now_b] - result_list.append(part_a) - while now_b < len_b: - part_b = list_b[now_b] - result_list.append(part_b) - now_b += 1 - - return result_list - - async def parser(self, file_path): + merged_list = [] + + for text_block in text_list: + max_x = max(max_x, text_block["bbox"][2]) + if current_index < image_or_table_list_length: + image_or_table_block = image_or_table_list[current_index] + while current_index < image_or_table_list_length and image_or_table_block["bbox"][0] < max_x and image_or_table_block["bbox"][1] < text_block["bbox"][1]: + merged_list.append(image_or_table_block) + current_index += 1 + if current_index < image_or_table_list_length: + image_or_table_block = image_or_table_list[current_index] + merged_list.append(text_block) + while current_index < image_or_table_list_length: + image_or_table_block = image_or_table_list[current_index] + merged_list.append(image_or_table_block) + current_index += 1 + + return merged_list + + async def parser(self, file_path: str) -> tuple[list[dict], list[dict], list[dict]]: + """解析PDF文件并返回文本块、链接和图片块 + + :param file_path: PDF文件的路径 + :type file_path: str + :return: 包含文本块、链接和图片块的元组 + :rtype: tuple[list[dict], list[dict], list[dict]] + """ self.open_pdf(file_path) method = self.parser_method sentences = [] @@ -191,25 +248,32 @@ class PdfService(BaseService): if method != "general": self.image_model = BaseOCR(llm=self.llm, llm_max_tokens=self.llm_max_tokens, method=self.parser_method) - for page_num in range(self.page_numbers): + for page_num in range(self.total_pages): tables = self.extract_table(page_num) text = self.extract_text(page_num) - temp_list = self.merge_list(text, tables) + merged_list = self.merge_list(text, tables) if method != "general": images, image_chunks = await self.extract_image(page_num, text) - merge_list = self.merge_list(temp_list, images) + merged_list = self.merge_list(merged_list, images) all_image_chunks.extend(image_chunks) - else: - merge_list = temp_list - sentences.extend(merge_list) - sentences = sorted(sentences, key=lambda x: (x['bbox'][1], x['bbox'][0])) + sentences.extend(merged_list) + chunks = self.build_chunks_by_lines(sentences) chunk_links = self.build_chunk_links_by_line(chunks) return chunks, chunk_links, all_image_chunks def __del__(self): - if self.pdf: - self.pdf.close() - self.page_numbers = None - self.pdf = None + """析构函数,关闭PDF文档并释放资源""" + if self.pdf_document: + self.pdf_document.close() + self.total_pages = None + self.pdf_document = None self.image_model = None + +import asyncio +if __name__ == "__main__": + parser = PdfService() + asyncio.run(parser.init_service(llm_entity=None, tokens=1024, parser_method="ocr")) + chunk, chunk_links, image_chunks = asyncio.run(parser.parser("./data_chain/test/test.pdf")) + for chunk_item in chunk: + print(chunk_item["text"]) \ No newline at end of file -- Gitee From a51c13f48142e4d233961888d728a1669f94dc97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 5 Mar 2025 11:58:32 +0800 Subject: [PATCH 2/5] =?UTF-8?q?refactor(data=5Fchain):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=20PDF=20=E6=96=87=E6=9C=AC=E5=9D=97=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 引入临时列表 temp_blocks 以简化嵌套循环结构 - 将文本块数据先存储到 temp_blocks 中,再统一添加到 text_blocks - 保留原有的逻辑和功能,仅对代码结构进行优化 --- data_chain/parser/handler/pdf_parser.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index 4f2b9d2..771674a 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -155,15 +155,19 @@ class PdfService(BaseService): """ page = self.pdf_document.load_page(page_number) text_blocks = [] + temp_blocks = [] + for block in page.get_text("dict")["blocks"]: if "lines" in block: # 确保是文本块 for line in block["lines"]: for span in line["spans"]: - text_blocks.append({ + temp_blocks.extend([{ "text": span["text"], "bbox": span["bbox"], # 文本边界框 (x0, y0, x1, y1) "type": "paragraph", - }) + }]) + + text_blocks.extend(temp_blocks) return text_blocks def find_near_words(self, bounding_box: tuple[float, float, float, float], texts: list[dict]) -> str: -- Gitee From 7c2923f2336cabe5e0aedd40b8a8bdb820bc81c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 5 Mar 2025 12:02:37 +0800 Subject: [PATCH 3/5] =?UTF-8?q?refactor(data=5Fchain):=20=E9=87=8D?= =?UTF-8?q?=E6=9E=84=20PDF=20=E8=A7=A3=E6=9E=90=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修正图像处理逻辑,将 image_np 重命名为 img_np - 更新 image_to_text 方法调用,使用 image_related_text 参数 - 移除未使用的导入和测试代码 --- data_chain/parser/handler/pdf_parser.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index 771674a..e00e77f 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -122,11 +122,11 @@ class PdfService(BaseService): await self.insert_image_to_tmp_folder(image_bytes, image_id, image_ext) try: - image_np = np.array(image) + img_np = np.array(image) except Exception as e: logging.error(f"Error converting image to numpy array: {e}") continue - ocr_results = await self.image_model.image_to_text(img_np, text=near) + ocr_results = await self.image_model.image_to_text(img_np, image_related_text=nearby_text) # 获取OCR结果 chunk_id = self.get_uuid() @@ -273,11 +273,3 @@ class PdfService(BaseService): self.total_pages = None self.pdf_document = None self.image_model = None - -import asyncio -if __name__ == "__main__": - parser = PdfService() - asyncio.run(parser.init_service(llm_entity=None, tokens=1024, parser_method="ocr")) - chunk, chunk_links, image_chunks = asyncio.run(parser.parser("./data_chain/test/test.pdf")) - for chunk_item in chunk: - print(chunk_item["text"]) \ No newline at end of file -- Gitee From 911ddcd52c780efdcc0b907c46543a9e8498735a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 5 Mar 2025 15:04:39 +0800 Subject: [PATCH 4/5] =?UTF-8?q?refactor(parser):=20=E4=BC=98=E5=8C=96=20PD?= =?UTF-8?q?F=20=E8=A7=A3=E6=9E=90=E5=99=A8=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除了 PdfService 类中冗余的 sorted 赋值语句 - 添加了 __init__ 和 __del__ 方法的类型注解 - 优化了代码结构,提高了代码的可读性和维护性 --- data_chain/parser/handler/pdf_parser.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/data_chain/parser/handler/pdf_parser.py b/data_chain/parser/handler/pdf_parser.py index e00e77f..4a23322 100644 --- a/data_chain/parser/handler/pdf_parser.py +++ b/data_chain/parser/handler/pdf_parser.py @@ -11,7 +11,7 @@ from data_chain.parser.tools.ocr import BaseOCR class PdfService(BaseService): - def __init__(self): + def __init__(self) -> None: super().__init__() self.image_model = None self.total_pages = None @@ -53,8 +53,7 @@ class PdfService(BaseService): "text": text, "type": "paragraph", }) - sorted_text_lines = sorted(text_lines, key=lambda x: (x["bbox"][1], x["bbox"][0])) - return sorted_text_lines + return sorted(text_lines, key=lambda x: (x["bbox"][1], x["bbox"][0])) def extract_table(self, page_number: int) -> list[dict]: """从PDF页面中提取表格 @@ -266,7 +265,7 @@ class PdfService(BaseService): chunk_links = self.build_chunk_links_by_line(chunks) return chunks, chunk_links, all_image_chunks - def __del__(self): + def __del__(self) -> None: """析构函数,关闭PDF文档并释放资源""" if self.pdf_document: self.pdf_document.close() -- Gitee From 32b833ac84f519ec734dd346b8d4085a7a9ff88f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 5 Mar 2025 17:49:13 +0800 Subject: [PATCH 5/5] refactor(parser): fix bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新导入模块以支持新的功能需求 - 修正和统一数据结构中的字段名称 - 优化 OCR 相关函数,支持异步操作 - 改进文本切割工具,增强对不同内容类型的处理 --- data_chain/parser/handler/base_parser.py | 6 ++++-- data_chain/parser/handler/pptx_parser.py | 7 +++---- data_chain/parser/service/parser_service.py | 4 ++-- data_chain/parser/tools/ocr.py | 20 ++++++++++---------- data_chain/parser/tools/split.py | 2 +- 5 files changed, 20 insertions(+), 19 deletions(-) diff --git a/data_chain/parser/handler/base_parser.py b/data_chain/parser/handler/base_parser.py index c302372..b4d2350 100644 --- a/data_chain/parser/handler/base_parser.py +++ b/data_chain/parser/handler/base_parser.py @@ -3,6 +3,8 @@ import os import uuid import json +import numpy as np + import pptx.table from data_chain.logger.logger import logger as logging from pandas import DataFrame @@ -192,7 +194,7 @@ class BaseService: pass for line in lines: if line['type'] == 'image': - line['text'] = await self.ocr_tool.image_to_text(line['image'], text=line['related_text']) + line['text'] = await self.ocr_tool.image_to_text(line['image'], image_related_text=line['related_text']) return lines async def change_lines(self, lines): @@ -241,7 +243,7 @@ class BaseService: 'type': line['type']}) last_para_id = new_lines[-1]['id'] - elif line[1] == 'table': + elif line["type"] == 'table': # 处理表格 new_lines.append({'id': self.get_uuid(), 'text': line['text'], diff --git a/data_chain/parser/handler/pptx_parser.py b/data_chain/parser/handler/pptx_parser.py index 24164f7..f8879fc 100644 --- a/data_chain/parser/handler/pptx_parser.py +++ b/data_chain/parser/handler/pptx_parser.py @@ -30,7 +30,7 @@ class PptxService(BaseService): if text.strip(): lines.append({ "text": text, - "type": 'text' + "type": 'para' }) # 提取表格 elif shape.has_table: @@ -39,7 +39,7 @@ class PptxService(BaseService): for row in rows: lines.append({ "text": text, - "type": table + "type": "table" }) # 提取图片 elif shape.shape_type == 13: # 13 表示图片类型 @@ -51,7 +51,7 @@ class PptxService(BaseService): continue lines.append({ "image": Image.open(BytesIO(image.blob)), - "type": table, + "type": "image", "extension": image_ext }) @@ -76,7 +76,6 @@ class PptxService(BaseService): if self.parser_method != "general": self.ocr_tool = BaseOCR(llm=self.llm, method=self.parser_method) lines = await self.extract_ppt_content(pptx) - lines, images = await self.change_lines(lines) lines = await self.ocr_from_images_in_lines(lines) chunks = self.build_chunks_by_lines(lines) diff --git a/data_chain/parser/service/parser_service.py b/data_chain/parser/service/parser_service.py index bd79fe2..8bdb01c 100644 --- a/data_chain/parser/service/parser_service.py +++ b/data_chain/parser/service/parser_service.py @@ -50,11 +50,11 @@ class ParserService: if not is_temporary_document: llm_entity = await ModelManager.select_by_user_id(self.doc.user_id) await model.init_service(llm_entity=llm_entity, - tokens=self.doc.chunk_size, + chunk_tokens=self.doc.chunk_size, parser_method=self.doc.parser_method) else: await model.init_service(llm_entity=None, - tokens=self.doc.chunk_size, + chunk_tokens=self.doc.chunk_size, parser_method=self.doc.parser_method) chunk_list, chunk_link_list, image_chunks = await model.parser(file_path) if not is_temporary_document: diff --git a/data_chain/parser/tools/ocr.py b/data_chain/parser/tools/ocr.py index 5f3ceed..28c68df 100644 --- a/data_chain/parser/tools/ocr.py +++ b/data_chain/parser/tools/ocr.py @@ -29,7 +29,7 @@ class BaseOCR: self.max_tokens = 1024 self.method = method - def ocr_from_image(self, image): + async def ocr_from_image(self, image): """ 图片ocr接口 参数: @@ -44,7 +44,7 @@ class BaseOCR: logging.error(f"Ocr from image failed due to: {e}") return None - def merge_text_from_ocr_result(self, ocr_result): + async def merge_text_from_ocr_result(self, ocr_result): """ ocr结果文字内容合并接口 参数: @@ -52,14 +52,14 @@ class BaseOCR: """ text = '' try: - for _ in ocr_result[0][0]: - text += _[1][0] + for _ in ocr_result[0]: + text += str(_[1][0]) return text except Exception as e: logging.error(f'Get text from ocr result failed due to: {e}') return '' - def cut_ocr_result_in_part(self, ocr_result, max_tokens=1024): + async def cut_ocr_result_in_part(self, ocr_result, max_tokens=1024): """ ocr结果切割接口 参数: @@ -99,7 +99,7 @@ class BaseOCR: logging.error(f'Get prompt template failed due to :{e}') return '' pre_part_description = "" - ocr_result_parts = self.cut_ocr_result_in_part(ocr_result, self.max_tokens // 5*2) + ocr_result_parts = await self.cut_ocr_result_in_part(ocr_result, self.max_tokens // 5*2) user_call = '请详细输出图片的摘要,不要输出其他内容' for part in ocr_result_parts: pre_part_description_cp = pre_part_description @@ -125,16 +125,16 @@ class BaseOCR: image_related_text: 图片相关文字 """ if self.method == 'ocr': - text = self.merge_text_from_ocr_result(ocr_result) + text = await self.merge_text_from_ocr_result(ocr_result) return text elif self.method == 'enhanced': try: text = await self.enhance_ocr_result(ocr_result, image_related_text) if len(text) == 0: - text = self.merge_text_from_ocr_result(ocr_result) + text = await self.merge_text_from_ocr_result(ocr_result) except Exception as e: logging.error(f"LLM ERROR with: {e}") - text = self.merge_text_from_ocr_result(ocr_result) + text = await self.merge_text_from_ocr_result(ocr_result) return text else: return "" @@ -146,7 +146,7 @@ class BaseOCR: image:图像文件 image_related_text:图像相关的文本 """ - ocr_result = self.ocr_from_image(image) + ocr_result = await self.ocr_from_image(image) if ocr_result is None: return "" text = await self.get_text_from_image(ocr_result, image_related_text) diff --git a/data_chain/parser/tools/split.py b/data_chain/parser/tools/split.py index 0a8bbb5..2de56ec 100644 --- a/data_chain/parser/tools/split.py +++ b/data_chain/parser/tools/split.py @@ -8,7 +8,7 @@ class SplitTools: def get_tokens(self, content): try: enc = tiktoken.encoding_for_model("gpt-4") - return len(enc.encode(content)) + return len(enc.encode(str(content))) except Exception as e: logging.error(f"Get tokens failed due to: {e}") return 0 -- Gitee