diff --git a/data_chain/parser/tools/ocr.py b/data_chain/parser/tools/ocr.py
index e49c465637cfdbd2f7bb974a2890c8099abfdc64..1d529edad38bd2fbf7e0f0a5ac6b32b6437029aa 100644
--- a/data_chain/parser/tools/ocr.py
+++ b/data_chain/parser/tools/ocr.py
@@ -100,9 +100,12 @@ class BaseOCR:
             return ''
         pre_part_description = ""
         ocr_result_parts = await self.cut_ocr_result_in_part(ocr_result, self.max_tokens // 5*2)
+        image_related_text = split_tools.get_k_tokens_words_from_content(image_related_text, self.max_tokens // 5*2)
         user_call = '请详细输出图片的摘要,不要输出其他内容'
         for part in ocr_result_parts:
             pre_part_description_cp = pre_part_description
+            pre_part_description = split_tools.get_k_tokens_words_from_content(
+                pre_part_description, self.max_tokens // 5)
             try:
                 prompt = prompt_template.format(
                     image_related_text=image_related_text,
diff --git a/data_chain/parser/tools/split.py b/data_chain/parser/tools/split.py
index 2de56ec09204bcb16975d467ff4c74eebdcd6d1e..d4c5d9a03d68aa79432c34c29894bc4ec0976c0b 100644
--- a/data_chain/parser/tools/split.py
+++ b/data_chain/parser/tools/split.py
@@ -13,6 +13,24 @@ class SplitTools:
             logging.error(f"Get tokens failed due to: {e}")
             return 0
 
+    def get_k_tokens_words_from_content(self, content: str, k: int) -> str:
+        try:
+            if (self.get_tokens(content) <= k):
+                return content
+            l = 0
+            r = len(content)
+            while l+1 < r:
+                mid = (l+r)//2
+                if (self.get_tokens(content[:mid]) <= k):
+                    l = mid
+                else:
+                    r = mid
+            return content[:l]
+        except Exception as e:
+            err = f"[TokenTool] 获取k个token的词失败 {e}"
+            logging.exception("[TokenTool] %s", err)
+            return ""
+
     def split_words(self, text):
         return list(jieba.cut(str(text)))
 