diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e8dee83a985ef8a61c498340348e2839e6f4968
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
@@ -0,0 +1,119 @@
+# Siglip Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+CLIP is pre-trained with a contrastive loss over image-text pairs. This approach faces two technical challenges: (1) it requires a very large batch size; CLIP, for example, used a batch size of 32K, which needs a large number of GPUs; (2) it requires heavy communication between those GPUs. SigLIP was therefore proposed to reduce CLIP's batch-size requirement. SigLIP stands for Sigmoid Language-Image Pre-training; its core idea is to use a sigmoid operation instead of a softmax operation.
+
+
+# Inference Environment \[All Versions\]
+
+- The model requires the following dependencies.
+
+  **Table 1** Version compatibility
+
+  | Dependency | Version |
+  |---------|---------|
+  | CANN | 8.0RC3 |
+  | MindIE | 1.0RC3 |
+  | Python | 3.10.13 |
+  | torch | 2.1.0 |
+  | torch_npu | 2.1.0post6 |
+  | transformers | 4.44.2 |
+  | CPU architecture | aarch64 |
+
+# Quick Start
+
+
+
+## Model Conversion
+1. Download the model
+
+   Download the required files (the entire folder) from the following HuggingFace link:
+   https://huggingface.co/google/siglip-so400m-patch14-384/tree/main
+
+   Create a directory for the downloaded files and the exported models, denoted by the variable model_dir below, and place the siglip-so400m-patch14-384 folder inside model_dir.
+
+## Path Variables
+| Variable | Meaning |
+| ----------- | ------------------------------------------------------------------------------ |
+| working_dir | Newly created working directory for testing the code |
+| model_dir | Directory for saving files and models, located at `${working_dir}/siglip_model/` |
+
+
+
+2. Export the TorchScript (ts) model
+
+   1. Modify library code
+
+      In site-packages/mindietorch/_dynamo/lowering/passes/mindie_lowering.py,
+      comment out the line `from mindietorch._dynamo.lowering.passes.lower_linear import lower_linear`
+      and comment out the line `lower_linear,`.
+
+      In site-packages/transformers/configuration_utils.py,
+      change `True` to `False` in the line `self.return_dict = kwargs.pop("return_dict", True)`.
+
+      These changes prevent MindIE Torch compilation from being affected by branch statements, which would otherwise make the compiled model fail to load. Because the location of Python's site-packages differs from machine to machine, the sources are patched manually.
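+
+      The following convenience sketch (assuming `mindietorch` and `transformers` are importable in the current environment) prints the absolute paths of the two files to edit:
+
+      ```python
+      import os
+
+      import mindietorch
+      import transformers
+
+      # mindie_lowering.py sits inside the installed mindietorch package
+      mindie_pkg = os.path.dirname(mindietorch.__file__)
+      print(os.path.join(mindie_pkg, "_dynamo", "lowering", "passes", "mindie_lowering.py"))
+
+      # configuration_utils.py sits at the top level of the transformers package
+      transformers_pkg = os.path.dirname(transformers.__file__)
+      print(os.path.join(transformers_pkg, "configuration_utils.py"))
+      ```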
+   2. Convert to a MindIE Torch model
+
+      Here test_img is the file name of a test image stored under model_dir.
+
+      Run siglip_export.py:
+
+      ```sh
+      python siglip_export.py \
+          --pytorch_model ${model_dir}/siglip-so400m-patch14-384 \
+          --exported_program_path ${model_dir}/exported_siglip.pt2 \
+          --saved_compile_model_path ${model_dir}/compile_siglip.ts \
+          --images_path ${model_dir}/{test_img}
+      ```
+
+      This generates compile_siglip.ts under model_dir.
+
+## Accuracy Verification
+
+Accuracy is considered correct when the Cosine Similarity scores printed on screen are greater than 0.9999 for both Image and Text.
+
+Run precision_test.py:
+
+  ```sh
+  python precision_test.py \
+      --pytorch_model ${model_dir}/siglip-so400m-patch14-384 \
+      --exported_program_path ${model_dir}/exported_siglip.pt2 \
+      --saved_compile_model_path ${model_dir}/compile_siglip.ts
+  ```
+
+## Performance Verification
+
+For the NPU performance test, run pref_test_npu.py:
+
+  ```sh
+  python pref_test_npu.py \
+      --config_path ${model_dir}/siglip-so400m-patch14-384 \
+      --encoder_aie_path ${model_dir}/compile_siglip.ts
+  ```
+
+For the GPU performance test, run pref_test_gpu.py:
+
+  ```sh
+  python pref_test_gpu.py \
+      --model_path ${model_dir}/siglip-so400m-patch14-384
+  ```
+
+The performance numbers are printed on screen, reported in FPS.
+
+
+## Performance Data (Latency / Throughput)
+|Model| MindIE Torch | T4 |
+|------| ----------------- |------|
+|encoder| 114.98 ms / 8.7 FPS | 580.67 ms / 1.72 FPS |
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0278a9e8967f4fd04f6fd81f63d13d82c11b53
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
@@ -0,0 +1,115 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+
+import numpy as np
+import torch
+import mindietorch
+from transformers import AutoProcessor
+from transformers.models.auto.modeling_auto import AutoModel
+
+
+def cosine_similarity(gt_tensor, pred_tensor):
+    gt_tensor = gt_tensor.flatten().to(torch.float32)
+    pred_tensor = pred_tensor.flatten().to(torch.float32)
+    # Cosine similarity is not meaningful for all-zero tensors; fall back to an allclose check.
+    if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
+        if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
+            return 1.0
+    res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
+    res = res.cpu().detach().item()
+    return res
+
+
+def inference(args):
+
+    # Load the compiled MindIE Torch model
+    mindietorch.set_device(args.device_id)
+    stream = mindietorch.npu.Stream()
+    mindietorch.npu.set_stream(stream)
+    compiled_model = torch.load(args.saved_compile_model_path)
+
+    # Prepare the reference PyTorch model
+    pytorch_model = AutoModel.from_pretrained(args.pytorch_model)
+    processor = AutoProcessor.from_pretrained(args.pytorch_model)
+
+    # Random image batch; the normalization below mirrors the processor's
+    # rescale to [0, 1] followed by (mean=0.5, std=0.5) standardization.
+    img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+    img_c = img_c.astype(np.float32)
+    img_c = (img_c / 255. - 0.5) / 0.5  # torch style norm
+    img_c = torch.from_numpy(img_c)
+
+    # Random token ids with the (batch, context_length) layout the processor would produce
+    text_c = np.random.randint(0, 25566, size=(3, 64), dtype=np.int64)
+    text_c = torch.from_numpy(text_c)
+
+    text_npu = text_c.to(torch.int64).to(f"npu:{args.device_id}")
+    img_npu = img_c.to(torch.float32).to(f"npu:{args.device_id}")
+
+    pytorch_results = pytorch_model(text_c, img_c)
+    pytorch_img_result, pytorch_text_results = pytorch_results[0], pytorch_results[1]
+    pytorch_text_results = pytorch_text_results / pytorch_text_results.norm(p=2, dim=-1, keepdim=True)
+    pytorch_img_result = pytorch_img_result / pytorch_img_result.norm(p=2, dim=-1, keepdim=True)
+
+    compiled_results = compiled_model(text_npu, img_npu)
+    compiled_img_results, compiled_text_results = compiled_results[0].to('cpu'), compiled_results[1].to('cpu')
+    compiled_text_results = compiled_text_results / compiled_text_results.norm(p=2, dim=-1, keepdim=True)
+    compiled_img_results = compiled_img_results / compiled_img_results.norm(p=2, dim=-1, keepdim=True)
+
+    img_similarity = cosine_similarity(pytorch_img_result, compiled_img_results)
+    text_similarity = cosine_similarity(pytorch_text_results, compiled_text_results)
+    print("Cosine Similarity: ", f"\n Image: {img_similarity}, Text: {text_similarity}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="mindietorch siglip model accuracy test")
+    parser.add_argument("--soc_version", default="Ascend910A")
+    parser.add_argument("--device_id", type=int, default=0)
+    parser.add_argument("--text_max_batch", type=int, default=64)
+    parser.add_argument("--img_max_batch", type=int, default=16)
+    parser.add_argument(
+        "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
+    )
+
+    parser.add_argument("--pytorch_model", type=str, required=True)
+
+    parser.add_argument("--exported_program_path",
+        type=str, required=True
+    )
+    parser.add_argument("--saved_compile_model_path",
+        type=str, required=True)
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Precision policy used when the model was compiled."
+    )
+    parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    inference(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..8538c4824ddb888dcaeeddb2d23ede234f09de3c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
@@ -0,0 +1,85 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
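+
+"""GPU (CUDA) performance baseline for the SigLIP encoder.
+
+Loads the original HuggingFace model with AutoModel, feeds a fixed text/image
+batch on the GPU, and reports average latency (ms) and throughput (FPS) over
+200 iterations after a 20-iteration warmup.
+"""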
+
+
+import argparse
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import AutoProcessor, AutoModel
+
+
+def test_encoder(aie_path, device_id=0):
+    batch_size = 1
+    device = f'cuda:{device_id}'
+
+    print("Start loading ts module...")
+    ts = AutoModel.from_pretrained(aie_path)
+
+    ts = ts.to(device)
+    print("Ts module loaded.")
+
+    processor = AutoProcessor.from_pretrained(aie_path)
+
+    # Random image batch (2 images, channels-first) fed through the processor
+    img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+    img_c = img_c.astype(np.float32)
+    img_c = torch.from_numpy(img_c)
+
+    texts = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]
+    inputs = processor(text=texts, images=img_c, padding="max_length", return_tensors="pt")
+
+    text = inputs.input_ids
+    img = inputs.pixel_values
+    text = text.to(device)
+    img = img.to(device)
+
+    print("Start inferring...")
+    # warmup
+    for _ in range(20):
+        ts(text, img)
+
+    # performance test
+    num_infer = 200
+
+    start = time.time()
+    for _ in tqdm(range(num_infer)):
+        ts(text, img)
+    end = time.time()
+
+    print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} fps")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--device_id", type=int, help="GPU device id", default=0)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    test_encoder(args.model_path, args.device_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c29a5bf18ee4f7c70ac30459a801051fb67567
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
@@ -0,0 +1,91 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+import argparse
+
+import numpy as np
+import torch
+import mindietorch
+from tqdm import tqdm
+from transformers import AutoProcessor
+
+
+def test_encoder(aie_path, config_path, device_id=0):
+    batch_size = 1
+    device = f'npu:{device_id}'
+    stream = mindietorch.npu.Stream(device)
+    print("Start loading ts module...")
+
+    compiled_model = torch.load(aie_path)
+    print("Ts module loaded.")
+    compiled_model.eval()
+    processor = AutoProcessor.from_pretrained(config_path)
+    text = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]
+
+    img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+    img_c = img_c.astype(np.float32)
+    # img_c = (img_c / 255. - 0.5) / 0.5  # torch style norm
+    img_c = torch.from_numpy(img_c)
+    image = img_c
+
+    inputs = processor(text=text, images=image, return_tensors="pt", padding="max_length")
+
+    text = inputs.input_ids
+    img = inputs.pixel_values
+    text_npu = text.to(torch.int64).to(device)
+    img_npu = img.to(torch.float32).to(device)
+
+    print("Start inferring...")
+    # warmup
+    for _ in range(20):
+        with mindietorch.npu.stream(stream):
+            compiled_model(text_npu, img_npu)
+        stream.synchronize()
+    # performance test
+    num_infer = 200
+    start = time.time()
+    for _ in tqdm(range(num_infer)):
+        with mindietorch.npu.stream(stream):
+            compiled_model(text_npu, img_npu)
+        stream.synchronize()
+    end = time.time()
+
+    print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} FPS")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--encoder_aie_path", type=str, required=True)
+    parser.add_argument("--config_path", type=str, required=True)
+    parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    mindietorch.set_device(args.device_id)
+
+    test_encoder(args.encoder_aie_path, args.config_path, args.device_id)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..233dea5d0ce5881028f818d6813c5ff2716def36
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
@@ -0,0 +1,121 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
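+
+"""Export and compile SigLIP SO400M for MindIE Torch.
+
+Exports the HuggingFace model to a torch.export ExportedProgram with a dynamic
+batch dimension for both the text and image inputs, then compiles it with
+mindietorch.compile (FP16) and saves the compiled module to disk.
+"""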
+
+
+import argparse
+import os
+
+import torch
+import mindietorch
+from PIL import Image
+from mindietorch import _enums
+from torch._export import export, dynamic_dim
+from transformers import AutoProcessor
+from transformers.models.auto.modeling_auto import AutoModel
+
+
+def export_compile(args):
+    model = AutoModel.from_pretrained(args.pytorch_model)
+    processor = AutoProcessor.from_pretrained(args.pytorch_model)
+
+    image = Image.open(args.images_path)
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+    images = [image, image]
+
+    inputs = processor(text=args.text, images=images, padding="max_length", return_tensors="pt")
+
+    input_ids = inputs["input_ids"]
+    pixel_values = inputs["pixel_values"]
+
+    # Allow the batch dimension of both inputs to vary at runtime.
+    constraints = [
+        dynamic_dim(input_ids, 0) >= 1, dynamic_dim(input_ids, 0) <= args.text_max_batch,
+        dynamic_dim(pixel_values, 0) >= 1, dynamic_dim(pixel_values, 0) <= args.img_max_batch,
+    ]
+
+    if not os.path.exists(args.exported_program_path):
+        encoder_ts_model = export(model, (input_ids, pixel_values), constraints=constraints)
+        torch.export.save(encoder_ts_model, args.exported_program_path)
+    else:
+        # Reuse a previously exported program instead of re-exporting.
+        encoder_ts_model = torch.export.load(args.exported_program_path)
+
+    img_min_shape = (1, 3, 384, 384)
+    img_max_shape = (args.img_max_batch, 3, 384, 384)
+    text_min_shape = (1, 64)
+    text_max_shape = (args.text_max_batch, 64)
+    compile_inputs = [
+        mindietorch.Input(min_shape=text_min_shape, max_shape=text_max_shape, dtype=torch.int64,),
+        mindietorch.Input(min_shape=img_min_shape, max_shape=img_max_shape, dtype=torch.float32,)
+    ]
+
+    mindietorch.set_device(args.device_id)
+
+    if not os.path.exists(args.saved_compile_model_path):
+        compiled_encoder_model = mindietorch.compile(
+            encoder_ts_model,
+            inputs=compile_inputs,
+            precision_policy=_enums.PrecisionPolicy.FP16,
+            soc_version=args.soc_version
+        )
+
+        torch.save(compiled_encoder_model, args.saved_compile_model_path, pickle_protocol=4)
+        return compiled_encoder_model
+    else:
+        compiled_model = torch.load(args.saved_compile_model_path)
+        return compiled_model
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="mindietorch siglip model compilation")
+    parser.add_argument("--soc_version", default="Ascend910B4")
+    parser.add_argument("--device_id", type=int, default=0)
+    parser.add_argument("--text_max_batch", type=int, default=64)
+    parser.add_argument("--img_max_batch", type=int, default=16)
+    parser.add_argument(
+        "--context-length", type=int, default=64, help="The padded length of input text (include [CLS] & [SEP] tokens)."
+    )
+
+    parser.add_argument("--pytorch_model", type=str, required=True)
+
+    parser.add_argument("--exported_program_path",
+        type=str,
+        required=True
+    )
+    parser.add_argument("--saved_compile_model_path",
+        type=str,
+        required=True
+    )
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Precision policy used for mindietorch compilation."
+    )
+    parser.add_argument("--images_path", type=str,
+        required=True
+    )
+    parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    export_compile(args)
+
+
+if __name__ == "__main__":
+    main()