From 349bec683b10617cf95480a752d782546edab91b Mon Sep 17 00:00:00 2001
From: commc
Date: Wed, 4 Sep 2024 15:21:51 +0800
Subject: [PATCH 1/3] Customer requirement: ONNX model export
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/multimodal/export_onnx.py | 126 ++++++++++++++++++
 1 file changed, 126 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
new file mode 100644
index 0000000000..0b5897cb28
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
@@ -0,0 +1,126 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import logging
+import argparse
+import torch
+import torch.onnx
+import torch.nn as nn
+from transformers.models.auto.modeling_auto import AutoModel
+
+logging.basicConfig(level=logging.INFO)
+
+
+class CLIPWrapper(nn.Module):
+    def __init__(self, clip):
+        super(CLIPWrapper, self).__init__()
+        self.model = clip
+        self.logit_scale = clip.logit_scale.exp()
+        self.logit_scale = self.logit_scale.to(self.model.device)
+
+    def forward(self, input_ids, pixel_values, attention_mask):
+        image_embeds = self.model.get_image_features(pixel_values)
+        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale
+        logits_per_text = logits_per_image.transpose(1, 0).contiguous()
+
+        return image_embeds, text_embeds, logits_per_text, logits_per_image
+
+
+def export_onnx(args):
+
+    # Load the PyTorch model
+    with torch.no_grad():
+        torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+        torch_model = CLIPWrapper(torch_model)
+
+    hf_config_path = os.path.join(args.hf_model_path, "config.json")
+    if not os.path.exists(hf_config_path):
+        raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}")
+    with open(hf_config_path, "r") as f:
+        config_dict = json.load(f)
+
+    # Construct model inputs
+    image_width = config_dict["vision_config"]["image_size"]
+    img_input_shape = (1, 3, image_width, image_width)
+    text_input_shape = (3, args.max_token_len)
+    input_img = torch.ones(img_input_shape, dtype=torch.float32)
+    input_ids = torch.randint(high=1, size=text_input_shape, dtype=torch.int32)
+    attention_mask = torch.ones_like(input_ids, dtype=torch.int32)
+    torch_model(input_ids, input_img, attention_mask)
+
+    # Export the ONNX model
+    file_name = f"CLIP-{args.model_version}.onnx"
+    model_save_dir = args.save_dir + file_name
+    logging.info("Starting to export dynamic onnx ...")
+    text_batch_size = "text_batch_size"
+    image_batch_size = "image_batch_size"
+    seq_len = "seq_len"
+    torch.onnx.export(
+        torch_model,
+        (input_ids, input_img, attention_mask),
+        model_save_dir,
+        input_names=['input_ids', "pixel_values", "attention_mask"],
+        output_names=['image_embeds', "text_embeds", "logits_per_text", "logits_per_image"],
+        export_params=True,
+        opset_version=13,
+        verbose=True,
+        dynamic_axes={
+            "input_ids": {0: text_batch_size, 1: seq_len},
+            "pixel_values": {0: image_batch_size},
+            "attention_mask": {0: text_batch_size, 1: seq_len},
+            "image_embeds": {0: image_batch_size},
+            "text_embeds": {0: text_batch_size},
+            "logits_per_text": {0: text_batch_size, 1: image_batch_size},
+            "logits_per_image": {0: image_batch_size, 1: text_batch_size},
+        }
+    )
+    logging.info("Successfully exported dynamic onnx!")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile Clip model")
+    parser.add_argument("--text-max-batch", type=int, default=80)
+    parser.add_argument("--img-max-batch", type=int, default=8)
+    parser.add_argument(
+        "--max-token-len",
+        type=int,
+        default=52,
+        help="The padded length of input text (including [CLS] & [SEP] tokens)."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="ViT-B-16",
+        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
+        help="Specify the architecture of CLIP model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Huggingface CLIP Model Path."
+    )
+    parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    compile_args = parse_args()
+    export_onnx(compile_args)
\ No newline at end of file
-- 
Gitee

From e4ba1e14597885c472814a94f3aa0c027e149d05 Mon Sep 17 00:00:00 2001
From: commc
Date: Wed, 4 Sep 2024 18:32:35 +0800
Subject: [PATCH 2/3] Change how the model config file is loaded
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
index 0b5897cb28..dc9abc69c1 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
@@ -18,6 +18,7 @@ import argparse
 import torch
 import torch.onnx
 import torch.nn as nn
+from transformers import AutoConfig
 from transformers.models.auto.modeling_auto import AutoModel
 
 logging.basicConfig(level=logging.INFO)
@@ -49,15 +50,10 @@ def export_onnx(args):
     with torch.no_grad():
         torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
         torch_model = CLIPWrapper(torch_model)
-
-    hf_config_path = os.path.join(args.hf_model_path, "config.json")
-    if not os.path.exists(hf_config_path):
-        raise FileNotFoundError(f"config.json not found at {args.hf_model_path}: {hf_config_path}")
-    with open(hf_config_path, "r") as f:
-        config_dict = json.load(f)
+    config = AutoConfig.from_pretrained(args.hf_model_path)
 
     # Construct model inputs
-    image_width = config_dict["vision_config"]["image_size"]
+    image_width = config.vision_config.image_size
     img_input_shape = (1, 3, image_width, image_width)
     text_input_shape = (3, args.max_token_len)
     input_img = torch.ones(img_input_shape, dtype=torch.float32)
-- 
Gitee

From 1f89585c9e112a5910beb259f24aaeed150599f3 Mon Sep 17 00:00:00 2001
From: commc
Date: Thu, 5 Sep 2024 17:13:13 +0800
Subject: [PATCH 3/3] Fix the folder hierarchy to the correct structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindIE/MindIE-Torch/built-in/multimodal/{ => CLIP}/export_onnx.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename MindIE/MindIE-Torch/built-in/multimodal/{ => CLIP}/export_onnx.py (100%)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py
similarity index 100%
rename from MindIE/MindIE-Torch/built-in/multimodal/export_onnx.py
rename to MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py
-- 
Gitee
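
Usage note (not part of the patch series): once the patches are applied and the script has been run, the exported graph can be sanity-checked with onnxruntime. The sketch below assumes onnxruntime is installed, that the script ran with its defaults (--save-dir ./, --model-version ViT-B-16, --max-token-len 52), and that ViT-B-16 uses a 224x224 vision input; adjust the file path, image size, and token length to match your export.

# sanity_check_onnx.py -- minimal sketch, assumes ./CLIP-ViT-B-16.onnx exists
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("./CLIP-ViT-B-16.onnx", providers=["CPUExecutionProvider"])

# Dummy inputs mirroring the shapes used at export time:
# one 224x224 image and three text sequences padded to 52 tokens (assumed defaults).
pixel_values = np.ones((1, 3, 224, 224), dtype=np.float32)
input_ids = np.zeros((3, 52), dtype=np.int32)
attention_mask = np.ones((3, 52), dtype=np.int32)

outputs = session.run(
    None,  # fetch every declared output
    {"input_ids": input_ids, "pixel_values": pixel_values, "attention_mask": attention_mask},
)
for name, value in zip(["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"], outputs):
    print(name, value.shape)

The dynamic axes declared during export should let the image and text batch sizes differ, so the shapes printed here are expected to be (1, dim), (3, dim), (3, 1), and (1, 3) respectively.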