diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc9abc69c1045deb0024c3fcdc57394ce349b64b
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/export_onnx.py
@@ -0,0 +1,122 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import logging
+import argparse
+import torch
+import torch.onnx
+import torch.nn as nn
+from transformers import AutoConfig
+from transformers.models.auto.modeling_auto import AutoModel
+
+logging.basicConfig(level=logging.INFO)
+
+
+class CLIPWrapper(nn.Module):
+    def __init__(self, clip):
+        super(CLIPWrapper, self).__init__()
+        self.model = clip
+        self.logit_scale = clip.logit_scale.exp()
+        self.logit_scale = self.logit_scale.to(self.model.device)  # Tensor.to() is not in-place, so reassign
+
+    def forward(self, input_ids, pixel_values, attention_mask):
+        image_embeds = self.model.get_image_features(pixel_values)
+        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)  # L2-normalize embeddings
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        logits_per_image = image_embeds @ text_embeds.transpose(1, 0).contiguous() * self.logit_scale
+        logits_per_text = logits_per_image.transpose(1, 0).contiguous()
+
+        return image_embeds, text_embeds, logits_per_text, logits_per_image
+
+
+def export_onnx(args):
+
+    # Load the PyTorch model
+    with torch.no_grad():
+        torch_model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+        torch_model = CLIPWrapper(torch_model)
+    config = AutoConfig.from_pretrained(args.hf_model_path)
+
+    # Construct dummy model inputs
+    image_width = config.vision_config.image_size
+    img_input_shape = (1, 3, image_width, image_width)
+    text_input_shape = (3, args.max_token_len)
+    input_img = torch.ones(img_input_shape, dtype=torch.float32)
+    input_ids = torch.randint(high=1, size=text_input_shape, dtype=torch.int32)
+    attention_mask = torch.ones_like(input_ids, dtype=torch.int32)
+    torch_model(input_ids, input_img, attention_mask)  # warm-up forward pass to verify the wrapper
+
+    # Export the ONNX model
+    file_name = f"CLIP-{args.model_version}.onnx"
+    model_save_path = os.path.join(args.save_dir, file_name)
+    logging.info("Starting to export dynamic onnx ...")
+    text_batch_size = "text_batch_size"
+    image_batch_size = "image_batch_size"
+    seq_len = "seq_len"
+    torch.onnx.export(
+        torch_model,
+        (input_ids, input_img, attention_mask),
+        model_save_path,
+        input_names=["input_ids", "pixel_values", "attention_mask"],
+        output_names=["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"],
+        export_params=True,
+        opset_version=13,
+        verbose=True,
+        dynamic_axes={
+            "input_ids": {0: text_batch_size, 1: seq_len},
+            "pixel_values": {0: image_batch_size},
+            "attention_mask": {0: text_batch_size, 1: seq_len},
+            "image_embeds": {0: image_batch_size},
+            "text_embeds": {0: text_batch_size},
+            "logits_per_text": {0: text_batch_size, 1: image_batch_size},
+            "logits_per_image": {0: image_batch_size, 1: text_batch_size},
+        }
+    )
+    logging.info("Successfully exported dynamic onnx!")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Export the CLIP model to ONNX")
+    parser.add_argument("--text-max-batch", type=int, default=80)
+    parser.add_argument("--img-max-batch", type=int, default=8)
+    parser.add_argument(
+        "--max-token-len",
+        type=int,
+        default=52,
+        help="The padded length of the input text (including [CLS] & [SEP] tokens)."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="ViT-B-16",
+        choices=["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"],
+        help="Specify the architecture of the CLIP model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Hugging Face CLIP model path."
+    )
+    parser.add_argument("--save-dir", type=str, default="./", help="Path to save the exported model")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    compile_args = parse_args()
+    export_onnx(compile_args)
\ No newline at end of file