From 8914e549d6f15bf8239e149f1bad8b8ab03249be Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Tue, 3 Sep 2024 12:48:43 +0800
Subject: [PATCH 1/7] add siglip model

---
 .../multimodal/Siglip-SO400M/README.md        | 111 ++++++++++++++++++
 .../Siglip-SO400M/precision_test.py           | 107 +++++++++++++++++
 .../multimodal/Siglip-SO400M/pref_test_gpu.py |  70 +++++++++++
 .../multimodal/Siglip-SO400M/pref_test_npu.py |  93 +++++++++++++++
 .../multimodal/Siglip-SO400M/siglip_export.py | 102 ++++++++++++++++
 5 files changed, 483 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
 create mode 100644 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
new file mode 100644
index 0000000000..d97162bee5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
@@ -0,0 +1,111 @@
+# SigLIP Model Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+CLIP pre-trains its network with a contrastive loss over image-text pairs. CLIP faces two technical challenges: (1) it needs a large batch size; CLIP used a batch size of 32K, for example, which requires a large number of GPUs; and (2) it requires heavy communication among those GPUs. SigLIP was proposed as a follow-up to reduce CLIP's batch-size requirement. SigLIP is short for Sigmoid Loss for Language-Image Pre-training; its core idea is to use a sigmoid operation instead of a softmax operation.
+
+
+# Inference Environment Setup \[All Versions\]
+
+- This model requires the following dependencies.
+
+  **Table 1** Version compatibility
+
+  | Component | Version |
+  |---------|---------|
+  | CANN | 8.0RC3 |
+  | MindIE | 1.0RC3 |
+  | Python | 3.10.13 |
+  | torch | 2.1.0 |
+  | torch_npu | 2.1.0post6 |
+  | transformers | 4.44.2 |
+  | NPU type | Ascend910B4 |
+  | CPU architecture | arm64 |
+
+# Quick Start
+
+
+
+## Model Conversion
+1. Download the model
+
+Download the required files (the entire folder) from the following Hugging Face link:
+https://huggingface.co/google/siglip-so400m-patch14-384/tree/main
+
+Create a directory for the downloaded files and the converted models, referred to below as model_dir.
+Place the siglip-so400m-patch14-384 folder inside model_dir.
+
+
+
+1. Export the ts model
+
+   1. Modify the source code
+   In site-packages/mindietorch/_dynamo/lowering/passes/mindie_lowering.py,
+   comment out the line: from mindietorch._dynamo.lowering.passes.lower_linear import lower_linear
+   and comment out the line: lower_linear,
+
+   In site-packages/transformers/configuration_utils.py,
+   change True to False in the line: self.return_dict = kwargs.pop("return_dict", True)
+
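+   These edits can also be scripted. The sketch below is one possible way to apply
+   them; it assumes a standard pip layout and resolves site-packages at run time,
+   but review the matched lines on your machine before relying on it, since it
+   patches installed packages in place.
+
+   ```sh
+   # Resolve site-packages (its location differs from machine to machine).
+   SP=$(python -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")
+   # Comment out the two lower_linear references.
+   sed -i 's/^from mindietorch\._dynamo\.lowering\.passes\.lower_linear import lower_linear/# &/' \
+       "$SP/mindietorch/_dynamo/lowering/passes/mindie_lowering.py"
+   sed -i 's/^\([[:space:]]*\)lower_linear,/\1# lower_linear,/' \
+       "$SP/mindietorch/_dynamo/lowering/passes/mindie_lowering.py"
+   # Default return_dict to False.
+   sed -i 's/self\.return_dict = kwargs\.pop("return_dict", True)/self.return_dict = kwargs.pop("return_dict", False)/' \
+       "$SP/transformers/configuration_utils.py"
+   ```
+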
+   2. Convert to a MindIE Torch model
+   Here test_img is the name of a test image stored under model_dir.
+
+   Run siglip_export.py:
+
+   ```sh
+   python siglip_export.py \
+       --pytorch_model model_dir/siglip-so400m-patch14-384 \
+       --exported_program_path model_dir/exported_siglip.pt2 \
+       --saved_compile_model_path model_dir/compile_siglip.ts \
+       --images_path model_dir/{test_img}
+   ```
+
+   This generates compile_siglip.ts under the model_dir directory.
+
+## Accuracy Validation
+
+For accuracy validation, the model passes if the printed Cosine Similarity scores for both Image and Text are greater than 0.9999.
+
+Run precision_test.py:
+
+   ```sh
+   python precision_test.py \
+       --pytorch_model model_dir/siglip-so400m-patch14-384 \
+       --exported_program_path model_dir/exported_siglip.pt2 \
+       --saved_compile_model_path model_dir/compile_siglip.ts
+   ```
+
+## Performance Validation
+
+For the NPU performance test, run pref_test_npu.py:
+
+   ```sh
+   python pref_test_npu.py \
+       --config_path model_dir/siglip-so400m-patch14-384 \
+       --encoder_aie_path model_dir/compile_siglip.ts
+   ```
+
+For the GPU performance test, run pref_test_gpu.py:
+
+   ```sh
+   python pref_test_gpu.py \
+       --model_path model_dir/siglip-so400m-patch14-384
+   ```
+
+The performance figures are printed to the screen, reported in FPS.
+
+
+## Performance Data (Latency/Throughput)
+|Model| MindIE Torch | T4|
+|------| ----------------- |------|
+|encoder| 114.98ms/8.7FPS | 20.53ms/48.70FPS |
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
new file mode 100644
index 0000000000..20d4282349
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
@@ -0,0 +1,107 @@
+import argparse
+import os
+import time
+import sys
+import torch
+import mindietorch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+from mindietorch import _enums
+from torch._export import export, dynamic_dim
+from transformers import AutoProcessor
+from transformers.models.auto.modeling_auto import AutoModel
+
+
+
+def cosine_similarity(gt_tensor, pred_tensor):
+    gt_tensor = gt_tensor.flatten().to(torch.float32)
+    pred_tensor = pred_tensor.flatten().to(torch.float32)
+    if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
+        if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
+            return 1.0
+    res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
+    res = res.cpu().detach().item()
+    return res
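+
+
+# The check below feeds identical random inputs through the eager PyTorch model
+# and the compiled MindIE Torch model, L2-normalizes both sets of embeddings,
+# and scores them with the cosine similarity above; the README treats scores
+# greater than 0.9999 for both image and text as a pass. The zero-sum guard in
+# cosine_similarity keeps all-zero outputs, for which cosine similarity is
+# ill-defined, from silently producing a meaningless score.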
+
+
+def inference(args):
+
+    # Load compiled model
+    mindietorch.set_device(args.device)
+    stream = mindietorch.npu.Stream()
+    mindietorch.npu.set_stream(stream)
+    compiled_model = torch.load(args.saved_compile_model_path)
+
+
+    # prepare PyTorch implemented model
+    pytorch_model = AutoModel.from_pretrained(args.pytorch_model)
+    processor = AutoProcessor.from_pretrained(args.pytorch_model)
+
+
+    img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+    img_c = img_c.astype(np.float32)
+    img_c = (img_c / 255. - 0.5) / 0.5 # torch style norm
+    img_c = torch.from_numpy(img_c)
+
+    text_c = np.random.randint(0, 25566, size=(3, 64), dtype=np.int64)
+    text_c = torch.from_numpy(text_c)
+
+    text = text_c
+    img = img_c
+
+    text_npu = text.to(torch.int64).to(f"npu:{args.device}")
+    img_npu = img.to(torch.float32).to(f"npu:{args.device}")
+
+    pytorch_results = pytorch_model(text, img)
+    pytorch_img_result, pytorch_text_results = pytorch_results[0], pytorch_results[1]
+    pytorch_text_results = pytorch_text_results / pytorch_text_results.norm(p=2, dim=-1, keepdim=True)
+    pytorch_img_result = pytorch_img_result / pytorch_img_result.norm(p=2, dim=-1, keepdim=True)
+
+    compiled_results = compiled_model(text_npu, img_npu)
+    compiled_img_results, compiled_text_results = compiled_results[0].to('cpu'), compiled_results[1].to('cpu')
+    compiled_text_results = compiled_text_results / compiled_text_results.norm(p=2, dim=-1, keepdim=True)
+    compiled_img_results = compiled_img_results / compiled_img_results.norm(p=2, dim=-1, keepdim=True)
+
+    img_similarity = cosine_similarity(pytorch_img_result, compiled_img_results)
+    text_similarity = cosine_similarity(pytorch_text_results, compiled_text_results)
+    print("Cosine Similarity: ", f"\n Image: {img_similarity}, Text: {text_similarity}")
+
+    # pytorch_logits = (100.0 * pytorch_img_result @ pytorch_text_results.T).softmax(dim=-1)
+    # compile_logits = (100.0 * compiled_img_results @ compiled_text_results.T).softmax(dim=-1)
+    # print("Prediction score: ", f"\nPytorch Model Score: {pytorch_logits}, ", f"\nCompile Model Score: {compile_logits}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="mindietorch clip model compilation")
+    parser.add_argument("--soc_version", default="Ascend910A")
+    parser.add_argument("--device", type=int, default=0)
+    parser.add_argument("--text_max_batch", type=int, default=64)
+    parser.add_argument("--img_max_batch", type=int, default=16)
+    parser.add_argument(
+        "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
+    )
+
+    parser.add_argument("--pytorch_model", default="/home/lizhongyang/siglip-so400m-patch14-384", type=str)
+
+    parser.add_argument("--exported_program_path",
+                        default="/home/lizhongyang/code/siglip/exported_siglip.pt2"
+                        )
+    parser.add_argument("--saved_compile_model_path",
+                        default="/home/lizhongyang/code/siglip/compile_siglip.ts")
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Precision policy used when compiling the model."
+    )
+    parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = parse_args()
+    inference(args)
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
new file mode 100644
index 0000000000..61b0d7e949
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
@@ -0,0 +1,70 @@
+import sys
+import torch
+import numpy as np
+from torch.export import export, dynamic_dim
+from PIL import Image
+import requests
+import argparse
+import time
+from tqdm import tqdm
+from transformers import AutoProcessor, AutoModel
+
+
+def test_encoder(aie_path, device_id=0):
+    batch_size = 1
+    device = f'cuda:{device_id}'
+
+    print("Start loading model...")
+    ts = AutoModel.from_pretrained(aie_path)
+
+    ts = ts.to(device)
+    print("Model loaded.")
+
+    processor = AutoProcessor.from_pretrained(aie_path)
+
+    img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+    img_c = img_c.astype(np.float32)
+    img_c = torch.from_numpy(img_c)
+
+    texts = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]
+    inputs = processor(text=texts, images=img_c, padding="max_length", return_tensors="pt")
+
+    text = inputs.input_ids
+    img = inputs.pixel_values
+    text = text.to(device)
+    img = img.to(device)
+
+    print("Start inferring...")
+    # warmup
+    for _ in range(100):
+        ts(text, img)
+
+    # performance test
+    num_infer = 200
+
+    start = time.time()
+    for _ in tqdm(range(num_infer)):
+        ts(text, img)
+    end = time.time()
+
+    print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} fps")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--model_path", type=str, default="/root/lizhongyang/data/siglip-so400m-patch14-384")
+    parser.add_argument("--device_id", type=int, help="GPU device id", default=0)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    test_encoder(args.model_path, args.device_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
new file mode 100644
index 0000000000..f14e573824
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
@@ -0,0 +1,93 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
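+#
+# Note: this script measures encoder performance on the NPU. Work submitted
+# through a mindietorch NPU stream executes asynchronously, so the timing
+# loops below call stream.synchronize() before the elapsed time is read;
+# otherwise only the host-side launch cost would be measured.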
+
+
+import time
+import argparse
+import json
+from PIL import Image
+import numpy as np
+import torch
+import mindietorch
+from tqdm import tqdm
+
+
+def test_encoder(aie_path, config_path, device_id=0):
+    batch_size = 1
+    device = f'npu:{device_id}'
+    stream = mindietorch.npu.Stream(device)
+    print("Start loading ts module...")
+    # ts = torch.jit.load(aie_path)
+    compiled_model = torch.load(aie_path)
+    print("Ts module loaded.")
+    compiled_model.eval()
+    from transformers import AutoProcessor
+    processor = AutoProcessor.from_pretrained(config_path)
+    text = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]
+
+    # image = Image.open('/home/lizhongyang/code/3.png')
+    # if image.mode != 'RGB':
+    #     image = image.convert('RGB')
+    img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+    img_c = img_c.astype(np.float32)
+    # img_c = (img_c / 255. - 0.5) / 0.5 # torch style norm
+    img_c = torch.from_numpy(img_c)
+    image = img_c
+
+    inputs = processor(text=text, images=image, return_tensors="pt", padding="max_length")
+
+    text = inputs.input_ids
+    img = inputs.pixel_values
+    text_npu = text.to(torch.int64).to(device)
+    img_npu = img.to(torch.float32).to(device)
+
+    print("Start inferring...")
+    # warmup
+    for _ in range(100):
+        with mindietorch.npu.stream(stream):
+            compiled_model(text_npu, img_npu)
+        stream.synchronize()
+    # performance test
+    num_infer = 200
+    start = time.time()
+    for _ in tqdm(range(num_infer)):
+        with mindietorch.npu.stream(stream):
+            compiled_model(text_npu, img_npu)
+        stream.synchronize()
+    end = time.time()
+
+    print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} FPS")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--encoder_aie_path", type=str, default='/data1/lizhongyang/imgs/compile_siglip.ts')
+    parser.add_argument("--config_path", type=str, default='/data1/lizhongyang/imgs/siglip-so400m-patch14-384')
+    parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    mindietorch.set_device(0)
+    args = parse_args()
+
+    test_encoder(args.encoder_aie_path, args.config_path, args.device_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
new file mode 100644
index 0000000000..e595e487fe
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
@@ -0,0 +1,102 @@
+import argparse
+import os
+import time
+import sys
+import torch
+import mindietorch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+from mindietorch import _enums
+from torch._export import export, dynamic_dim
+from transformers import AutoProcessor
+from transformers.models.auto.modeling_auto import AutoModel
+
+
+def export_compile(args):
+    model = AutoModel.from_pretrained(args.pytorch_model)
+    processor = AutoProcessor.from_pretrained(args.pytorch_model)
+
+    images = []
+    image = Image.open(args.images_path)
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+    images.append(image)
+
+    inputs = processor(text=args.text, images=images, padding="max_length", return_tensors="pt")
+
+    input_ids = inputs["input_ids"]
+    pixel_values = inputs["pixel_values"]
+
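+    # Dynamic-shape export: dim 0 (the batch dimension) of input_ids may vary
+    # within [1, text_max_batch], and dim 0 of pixel_values within
+    # [1, img_max_batch]. The min/max shapes handed to mindietorch.Input below
+    # must stay inside these ranges, since the compiled model only covers
+    # shapes permitted by the exported constraints.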
+    constraints = [
+        dynamic_dim(input_ids, 0) >= 1, dynamic_dim(input_ids, 0) <= args.text_max_batch,
+        dynamic_dim(pixel_values, 0) >= 1, dynamic_dim(pixel_values, 0) <= args.img_max_batch,
+    ]
+
+    if not os.path.exists(args.exported_program_path):
+        encoder_ts_model = export(model, (input_ids, pixel_values), constraints=constraints)
+        torch.export.save(encoder_ts_model, args.exported_program_path)
+
+    img_min_shape = (1, 3, 384, 384)
+    img_max_shape = (16, 3, 384, 384)
+    text_min_shape = (1, 64)
+    text_max_shape = (64, 64)
+    compile_inputs = [
+        mindietorch.Input(min_shape=text_min_shape, max_shape=text_max_shape, dtype=torch.int64,),
+        mindietorch.Input(min_shape=img_min_shape, max_shape=img_max_shape, dtype=torch.float32,)
+    ]
+
+    mindietorch.set_device(0)
+
+    if not os.path.exists(args.saved_compile_model_path):
+        compiled_encoder_model = mindietorch.compile(
+            encoder_ts_model,
+            inputs=compile_inputs,
+            # precision_policy=mindietorch.PrecisionPolicy.FP16,
+            precision_policy=_enums.PrecisionPolicy.FP16,
+            # truncate_long_and_double=True,
+            soc_version="Ascend910A"
+        )
+
+        torch.save(compiled_encoder_model, args.saved_compile_model_path, pickle_protocol=4)
+        return compiled_encoder_model
+    else:
+        compiled_model = torch.load(args.saved_compile_model_path)
+        return compiled_model
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="mindietorch clip model compilation")
+    parser.add_argument("--soc_version", default="Ascend910A")
+    parser.add_argument("--device", type=int, default=0)
+    parser.add_argument("--text_max_batch", type=int, default=64)
+    parser.add_argument("--img_max_batch", type=int, default=16)
+    parser.add_argument(
+        "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
+    )
+
+    parser.add_argument("--pytorch_model", default="/home/lizhongyang/siglip-so400m-patch14-384", type=str)
+
+    parser.add_argument("--exported_program_path",
+                        default="/home/lizhongyang/code/siglip/exported_siglip.pt2"
+                        )
+    parser.add_argument("--saved_compile_model_path",
+                        default="/home/lizhongyang/code/siglip/compile_siglip.ts")
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Precision policy used when compiling the model."
+    )
+    parser.add_argument("--images_path", default="/home/lizhongyang/code/3.png")
+    parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = parse_args()
+    export_compile(args)
+
+if __name__ == "__main__":
+    main()
-- 
Gitee

From 5c80c36c42f3a3120aa9897a94566d9ed250673c Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Tue, 3 Sep 2024 20:53:22 +0800
Subject: [PATCH 2/7] check code

---
 .../multimodal/Siglip-SO400M/README.md        | 32 ++++++++++------
 .../Siglip-SO400M/precision_test.py           | 38 +++++++++++--------
 .../multimodal/Siglip-SO400M/pref_test_gpu.py | 17 ++++++++-
 .../multimodal/Siglip-SO400M/pref_test_npu.py | 10 ++---
 .../multimodal/Siglip-SO400M/siglip_export.py | 31 +++++++++++----
 5 files changed, 86 insertions(+), 42 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
index d97162bee5..a43c25d7d4 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
@@ -29,8 +29,7 @@ CLIP pre-trains its network with a contrastive loss over image-text pairs.
 | torch | 2.1.0 |
 | torch_npu | 2.1.0post6 |
 | transformers | 4.44.2 |
- | NPU type | Ascend910B4 |
- | CPU architecture | arm64 |
+ | CPU architecture | aarch64 |

 # Quick Start

@@ -44,12 +43,18 @@ https://huggingface.co/google/siglip-so400m-patch14-384/tree/main
 Create a directory for the downloaded files and the converted models, referred to below as model_dir.
 Place the siglip-so400m-patch14-384 folder inside model_dir.

+## Path variables
+| Variable | Meaning |
+| ----------- | ------------------------------------------------------------------------------ |
+| working_dir | Newly created working directory for testing the code |
+| model_dir | Directory for saving files and models; its path is `${working_dir}/siglip_model/` |


 1. Export the ts model

    1. Modify the source code
+
    In site-packages/mindietorch/_dynamo/lowering/passes/mindie_lowering.py,
    comment out the line: from mindietorch._dynamo.lowering.passes.lower_linear import lower_linear
    and comment out the line: lower_linear,
@@ -57,6 +62,9 @@ https://huggingface.co/google/siglip-so400m-patch14-384/tree/main

    In site-packages/transformers/configuration_utils.py,
    change True to False in the line: self.return_dict = kwargs.pop("return_dict", True)

+   These code changes mainly keep branch statements from interfering with mindie-torch compilation; because the location of Python's site-packages differs from machine to machine, the sources are edited by hand.
+
+
    2. Convert to a MindIE Torch model
    Here test_img is the name of a test image stored under model_dir.

    Run siglip_export.py:

    ```sh
    python siglip_export.py \
-       --pytorch_model model_dir/siglip-so400m-patch14-384 \
-       --exported_program_path model_dir/exported_siglip.pt2 \
-       --saved_compile_model_path model_dir/compile_siglip.ts \
-       --images_path model_dir/{test_img}
+       --pytorch_model ${model_dir}/siglip-so400m-patch14-384 \
+       --exported_program_path ${model_dir}/exported_siglip.pt2 \
+       --saved_compile_model_path ${model_dir}/compile_siglip.ts \
+       --images_path ${model_dir}/{test_img}
    ```

    This generates compile_siglip.ts under the model_dir directory.

@@ -80,9 +88,9 @@

 Run precision_test.py:

    ```sh
    python precision_test.py \
-       --pytorch_model model_dir/siglip-so400m-patch14-384 \
-       --exported_program_path model_dir/exported_siglip.pt2 \
-       --saved_compile_model_path model_dir/compile_siglip.ts
+       --pytorch_model ${model_dir}/siglip-so400m-patch14-384 \
+       --exported_program_path ${model_dir}/exported_siglip.pt2 \
+       --saved_compile_model_path ${model_dir}/compile_siglip.ts
    ```

 ## Performance Validation

@@ -91,15 +99,15 @@
 For the NPU performance test, run pref_test_npu.py:

    ```sh
    python pref_test_npu.py \
-       --config_path model_dir/siglip-so400m-patch14-384 \
-       --encoder_aie_path model_dir/compile_siglip.ts
+       --config_path ${model_dir}/siglip-so400m-patch14-384 \
+       --encoder_aie_path ${model_dir}/compile_siglip.ts
    ```

 For the GPU performance test, run pref_test_gpu.py:

    ```sh
    python pref_test_gpu.py \
-       --model_path model_dir/siglip-so400m-patch14-384
+       --model_path ${model_dir}/siglip-so400m-patch14-384
    ```

 The performance figures are printed to the screen, reported in FPS.
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
index 20d4282349..4c8c74e9b4 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
@@ -1,3 +1,18 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + import argparse import os import time @@ -28,7 +43,7 @@ def cosine_similarity(gt_tensor, pred_tensor): def inference(args): # Load compiled model - mindietorch.set_device(args.device) + mindietorch.set_device(args.device_id) stream = mindietorch.npu.Stream() mindietorch.npu.set_stream(stream) compiled_model = torch.load(args.saved_compile_model_path) @@ -47,13 +62,10 @@ def inference(args): text_c = np.random.randint(0, 25566, size=(3, 64), dtype=np.int64) text_c = torch.from_numpy(text_c) - text = text_c - img = img_c + text_npu = text_c.to(torch.int64).to(f"npu:{args.device_id}") + img_npu = img_c.to(torch.float32).to(f"npu:{args.device_id}") - text_npu = text.to(torch.int64).to(f"npu:{args.device}") - img_npu = img.to(torch.float32).to(f"npu:{args.device}") - - pytorch_results = pytorch_model(text, img) + pytorch_results = pytorch_model(text_c, img_c) pytorch_img_result, pytorch_text_results = pytorch_results[0], pytorch_results[1] pytorch_text_results = pytorch_text_results / pytorch_text_results.norm(p=2, dim=-1, keepdim=True) pytorch_img_result = pytorch_img_result / pytorch_img_result.norm(p=2, dim=-1, keepdim=True) @@ -67,28 +79,24 @@ def inference(args): text_similarity = cosine_similarity(pytorch_text_results, compiled_text_results) print("Cosine Similarity: ", f"\n Image: {img_similarity}, Text: {text_similarity}") - # pytorch_logits = (100.0 * pytorch_img_result @ pytorch_text_results.T).softmax(dim=-1) - # compile_logits = (100.0 * compiled_img_results @ compiled_text_results.T).softmax(dim=-1) - # print("Prediction score: ", f"\nPytorch Model Score: {pytorch_logits}, ", f"\nCompile Model Score: {compile_logits}") - def parse_args(): parser = argparse.ArgumentParser(description="mindietorch clip model compilation") parser.add_argument("--soc_version", default="Ascend910A") - parser.add_argument("--device", type=int, default=0) + parser.add_argument("--device_id", type=int, default=0) parser.add_argument("--text_max_batch", type=int, default=64) parser.add_argument("--img_max_batch", type=int, default=16) parser.add_argument( "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)." ) - parser.add_argument("--pytorch_model", default="/home/lizhongyang/siglip-so400m-patch14-384", type=str) + parser.add_argument("--pytorch_model", default="", type=str) parser.add_argument("--exported_program_path", - default="/home/lizhongyang/code/siglip/exported_siglip.pt2" + default="" ) parser.add_argument("--saved_compile_model_path", - default="/home/lizhongyang/code/siglip/compile_siglip.ts") + default="") parser.add_argument( "--precision", default="fp16", diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py index 61b0d7e949..bb6872b63a 100644 --- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py +++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py @@ -1,3 +1,18 @@ +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import sys
 import torch
 import numpy as np
@@ -36,7 +51,7 @@ def test_encoder(aie_path, device_id=0):

     print("Start inferring...")
     # warmup
-    for _ in range(100):
+    for _ in range(20):
         ts(text, img)

     # performance test
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
index f14e573824..9b0d8eb08e 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
@@ -36,9 +36,6 @@ def test_encoder(aie_path, config_path, device_id=0):
     processor = AutoProcessor.from_pretrained(config_path)
     text = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]

-    # image = Image.open('/home/lizhongyang/code/3.png')
-    # if image.mode != 'RGB':
-    #     image = image.convert('RGB')
     img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
     img_c = img_c.astype(np.float32)
     # img_c = (img_c / 255. - 0.5) / 0.5 # torch style norm
@@ -74,8 +71,8 @@ def test_encoder(aie_path, config_path, device_id=0):
 def parse_args():
     parser = argparse.ArgumentParser()

-    parser.add_argument("--encoder_aie_path", type=str, default='/data1/lizhongyang/imgs/compile_siglip.ts')
-    parser.add_argument("--config_path", type=str, default='/data1/lizhongyang/imgs/siglip-so400m-patch14-384')
+    parser.add_argument("--encoder_aie_path", type=str, default='')
+    parser.add_argument("--config_path", type=str, default='')
     parser.add_argument("--device_id", type=int, help="NPU device id", default=0)

     args = parser.parse_args()
@@ -83,8 +80,9 @@ def test_encoder(aie_path, config_path, device_id=0):


 def main():
-    mindietorch.set_device(0)
+
     args = parse_args()
+    mindietorch.set_device(args.device_id)

     test_encoder(args.encoder_aie_path, args.config_path, args.device_id)
+
+
 import argparse
 import os
 import time
 import sys
@@ -46,7 +61,7 @@ def export_compile(args):
         mindietorch.Input(min_shape=img_min_shape, max_shape=img_max_shape, dtype=torch.float32,)
     ]

-    mindietorch.set_device(0)
+    mindietorch.set_device(args.device_id)

     if not os.path.exists(args.saved_compile_model_path):
         compiled_encoder_model = mindietorch.compile(
@@ -55,7 +70,7 @@ def export_compile(args):
             # precision_policy=mindietorch.PrecisionPolicy.FP16,
             precision_policy=_enums.PrecisionPolicy.FP16,
             # truncate_long_and_double=True,
-            soc_version="Ascend910A"
+            soc_version="Ascend910B4"
         )

         torch.save(compiled_encoder_model, args.saved_compile_model_path, pickle_protocol=4)
@@ -68,28 +83,28 @@ def export_compile(args):

 def parse_args():
     parser = argparse.ArgumentParser(description="mindietorch clip model compilation")
-    parser.add_argument("--soc_version", default="Ascend910A")
-    parser.add_argument("--device", type=int, default=0)
+    parser.add_argument("--soc_version", default="Ascend910B4")
+    parser.add_argument("--device_id", type=int, default=0)
     parser.add_argument("--text_max_batch", type=int, default=64)
     parser.add_argument("--img_max_batch", type=int, default=16)
     parser.add_argument(
         "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
     )

-    parser.add_argument("--pytorch_model", default="/home/lizhongyang/siglip-so400m-patch14-384", type=str)
+    parser.add_argument("--pytorch_model", default="", type=str)

     parser.add_argument("--exported_program_path",
-                        default="/home/lizhongyang/code/siglip/exported_siglip.pt2"
+                        default=""
                         )
     parser.add_argument("--saved_compile_model_path",
-                        default="/home/lizhongyang/code/siglip/compile_siglip.ts")
+                        default="")
     parser.add_argument(
         "--precision",
         default="fp16",
         choices=["fp16", "fp32"],
         help="Precision policy used when compiling the model."
     )
-    parser.add_argument("--images_path", default="/home/lizhongyang/code/3.png")
+    parser.add_argument("--images_path", default="")
     parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
     args = parser.parse_args()
     return args
-- 
Gitee

From 5cc31cfeb5bf02f123f7d84290e4056cef87ec97 Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Tue, 3 Sep 2024 20:57:06 +0800
Subject: [PATCH 3/7] check code.

---
 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
index a43c25d7d4..7cde37cfb8 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
@@ -62,7 +62,7 @@ https://huggingface.co/google/siglip-so400m-patch14-384/tree/main
    In site-packages/transformers/configuration_utils.py,
    change True to False in the line: self.return_dict = kwargs.pop("return_dict", True)

-   These code changes mainly keep branch statements from interfering with mindie-torch compilation; because the location of Python's site-packages differs from machine to machine, the sources are edited by hand.
+   These code changes mainly keep branch statements from interfering with mindie-torch compilation, which could otherwise make the compiled model fail to load; because the location of Python's site-packages differs from machine to machine, the sources are edited by hand.


    2. Convert to a MindIE Torch model
-- 
Gitee

From 8de049ec227864ba404519a0f0b1d15eddd56eea Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Tue, 3 Sep 2024 21:04:25 +0800
Subject: [PATCH 4/7] codecheck

---
 .../built-in/multimodal/Siglip-SO400M/pref_test_gpu.py | 2 +-
 .../built-in/multimodal/Siglip-SO400M/pref_test_npu.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
index bb6872b63a..98b4d33f90 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
@@ -69,7 +69,7 @@ def test_encoder(aie_path, device_id=0):
 def parse_args():
     parser = argparse.ArgumentParser()

-    parser.add_argument("--model_path", type=str, default="/root/lizhongyang/data/siglip-so400m-patch14-384")
+    parser.add_argument("--model_path", type=str, default="")
     parser.add_argument("--device_id", type=int, help="GPU device id", default=0)

     args = parser.parse_args()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
index 9b0d8eb08e..c1c27866e1 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
@@ -28,7 +28,7 @@ def test_encoder(aie_path, config_path, device_id=0):
     device = f'npu:{device_id}'
     stream = mindietorch.npu.Stream(device)
     print("Start loading ts module...")
-    # ts = torch.jit.load(aie_path)
+
     compiled_model = torch.load(aie_path)
     print("Ts module loaded.")
     compiled_model.eval()
@@ -51,7 +51,7 @@ def test_encoder(aie_path, config_path, device_id=0):

     print("Start inferring...")
     # warmup
-    for _ in range(100):
+    for _ in range(20):
         with mindietorch.npu.stream(stream):
             compiled_model(text_npu, img_npu)
         stream.synchronize()
-- 
Gitee

From 04bdba6caa6ce6585ae448755d3dc17b9560d83a Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Wed, 4 Sep 2024 10:36:30 +0800
Subject: [PATCH 5/7] add siglip codecheck

---
 .../multimodal/Siglip-SO400M/precision_test.py |  6 +++---
 .../multimodal/Siglip-SO400M/pref_test_gpu.py  |  2 +-
 .../multimodal/Siglip-SO400M/pref_test_npu.py  |  4 ++--
 .../multimodal/Siglip-SO400M/siglip_export.py  | 11 ++++++++---
 4 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
index 4c8c74e9b4..1b0278a9e8 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
@@ -90,13 +90,13 @@ def parse_args():
         "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
     )
-    parser.add_argument("--pytorch_model", default="", type=str)
+    parser.add_argument("--pytorch_model", type=str, required=True)

     parser.add_argument("--exported_program_path",
-                        default=""
+                        type=str, required=True
                         )
     parser.add_argument("--saved_compile_model_path",
-                        default="")
+                        type=str, required=True)
     parser.add_argument(
         "--precision",
         default="fp16",
         choices=["fp16", "fp32"],
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
index 98b4d33f90..8538c4824d 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
@@ -69,7 +69,7 @@ def test_encoder(aie_path, device_id=0):
 def parse_args():
     parser = argparse.ArgumentParser()

-    parser.add_argument("--model_path", type=str, default="")
+    parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--device_id", type=int, help="GPU device id", default=0)

     args = parser.parse_args()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
index c1c27866e1..78c29a5bf1 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
@@ -71,8 +71,8 @@ def test_encoder(aie_path, config_path, device_id=0):
 def parse_args():
     parser = argparse.ArgumentParser()

-    parser.add_argument("--encoder_aie_path", type=str, default='')
-    parser.add_argument("--config_path", type=str, default='')
+    parser.add_argument("--encoder_aie_path", type=str, required=True)
+    parser.add_argument("--config_path", type=str, required=True)
     parser.add_argument("--device_id", type=int, help="NPU device id", default=0)

     args = parser.parse_args()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
index 17d16971c7..612ac2d60e 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
@@ -94,17 +94,22 @@ def parse_args():

     parser.add_argument("--pytorch_model", default="", type=str)

     parser.add_argument("--exported_program_path",
-                        default=""
+                        type=str,
+                        required=True
                         )
     parser.add_argument("--saved_compile_model_path",
-                        default="")
+                        type=str,
+                        required=True
+                        )
     parser.add_argument(
         "--precision",
         default="fp16",
         choices=["fp16", "fp32"],
         help="Precision policy used when compiling the model."
     )
-    parser.add_argument("--images_path", default="")
+    parser.add_argument("--images_path", type=str,
+                        required=True
+                        )
     parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
     args = parser.parse_args()
     return args
-- 
Gitee

From 29e447fbb84f39e9c40f8ba7bff93f4c66873067 Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Wed, 4 Sep 2024 08:05:58 +0000
Subject: [PATCH 6/7] update
 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py.
Signed-off-by: lizhongyang10
---
 .../built-in/multimodal/Siglip-SO400M/siglip_export.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
index 612ac2d60e..233dea5d0c 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
@@ -32,11 +32,10 @@ def export_compile(args):
     model = AutoModel.from_pretrained(args.pytorch_model)
     processor = AutoProcessor.from_pretrained(args.pytorch_model)

-    images = []
     image = Image.open(args.images_path)
     if image.mode != 'RGB':
         image = image.convert('RGB')
-    images.append(image)
+    images = [image, image]

     inputs = processor(text=args.text, images=images, padding="max_length", return_tensors="pt")

@@ -91,7 +90,7 @@ def parse_args():
         "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
     )

-    parser.add_argument("--pytorch_model", default="", type=str)
+    parser.add_argument("--pytorch_model", type=str, required=True)

     parser.add_argument("--exported_program_path",
                         type=str,
-- 
Gitee

From 607574d5af58ce971ac7199de7fa169b8c795296 Mon Sep 17 00:00:00 2001
From: lizhongyang10
Date: Wed, 4 Sep 2024 08:09:13 +0000
Subject: [PATCH 7/7] update
 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md.

Signed-off-by: lizhongyang10
---
 MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
index 7cde37cfb8..5e8dee83a9 100644
--- a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
@@ -116,4 +116,4 @@ For the GPU performance test, run pref_test_gpu.py:
 ## Performance Data (Latency/Throughput)
 |Model| MindIE Torch | T4|
 |------| ----------------- |------|
-|encoder| 114.98ms/8.7FPS | 20.53ms/48.70FPS |
+|encoder| 114.98ms/8.7FPS | 580.67ms/1.72FPS |
-- 
Gitee