diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5e8dee83a985ef8a61c498340348e2839e6f4968
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/README.md
@@ -0,0 +1,119 @@
+# Siglip Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Preparation](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Model Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview
+
+CLIP pre-trains its network with a contrastive loss over image-text pairs. CLIP faces two technical challenges: (1) it requires a very large batch size; for example, CLIP used a batch size of 32K, which demands a large number of GPUs; (2) it requires heavy communication between those GPUs. SigLIP was therefore proposed to reduce CLIP's batch-size requirement. SigLIP is short for Sigmoid Language Image Pre-training, and its core idea is to replace the Softmax operation with a Sigmoid operation, so that every image-text pair is scored independently.
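+
+The difference can be made concrete with a minimal sketch of the pairwise sigmoid loss (illustrative only, this is not the training code of this repository; the temperature `t` and bias `b` values are assumptions):
+
+```python
+import torch
+import torch.nn.functional as F
+
+def siglip_style_loss(img_emb, txt_emb, t=10.0, b=-10.0):
+    # img_emb, txt_emb: L2-normalized embeddings of shape (N, D)
+    logits = img_emb @ txt_emb.t() * t + b        # (N, N) pairwise similarity logits
+    labels = 2 * torch.eye(logits.shape[0]) - 1   # +1 on the diagonal (matched pairs), -1 elsewhere
+    # Each pair contributes its own sigmoid term, so no batch-wide softmax (and no huge batch) is required.
+    return -F.logsigmoid(labels * logits).mean()
+```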
+
+
+# Inference Environment Preparation \[All Versions\]
+
+- This model requires the following dependencies.
+
+  **Table 1** Version compatibility
+
+  | Component | Version |
+  |---------|---------|
+  | CANN | 8.0RC3 |
+  | MindIE | 1.0RC3 |
+  | Python | 3.10.13 |
+  | torch | 2.1.0 |
+  | torch_npu | 2.1.0post6 |
+  | transformers | 4.44.2 |
+  | Processor architecture | aarch64 |
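+
+A quick optional check that the Python-level dependencies match Table 1 (a sketch that assumes the packages are already installed and expose a `__version__` attribute):
+
+```python
+import torch, torch_npu, transformers
+print(torch.__version__, torch_npu.__version__, transformers.__version__)
+```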
+
+# Quick Start
+
+
+
+## Model Conversion
+1. Download the model
+
+Download the required files (the entire folder) from the following HuggingFace link:
+https://huggingface.co/google/siglip-so400m-patch14-384/tree/main
+
+Create a directory for storing the files and models, referred to below by the variable name model_dir,
+and place the siglip-so400m-patch14-384 folder inside model_dir.
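+
+One possible way to fetch the files programmatically (a sketch that assumes the `huggingface_hub` package is installed and HuggingFace is reachable; adjust `local_dir` to your own model_dir):
+
+```python
+from huggingface_hub import snapshot_download
+
+# Download the whole repository into ${model_dir}/siglip-so400m-patch14-384
+snapshot_download(
+    repo_id="google/siglip-so400m-patch14-384",
+    local_dir="siglip_model/siglip-so400m-patch14-384",
+)
+```
+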
+## Path Variables
+| Variable | Meaning |
+| ----------- | --------------------------------------------------------------------------- |
+| working_dir | A newly created working directory used for testing the code |
+| model_dir | The directory for storing files and models, located at `${working_dir}/siglip_model/` |
+
+
+
+2. Export the ts model
+
+    1. Modify the code
+
+    In the installed package file site-packages/mindietorch/_dynamo/lowering/passes/mindie_lowering.py,
+    comment out the line `from mindietorch._dynamo.lowering.passes.lower_linear import lower_linear`
+    and comment out the line `lower_linear,`.
+
+    In the installed package file site-packages/transformers/configuration_utils.py,
+    change `True` to `False` in the line `self.return_dict = kwargs.pop("return_dict", True)`.
+
+    These modifications prevent branch statements from interfering with MindIE-Torch compilation, which would otherwise cause the compiled model to fail to load. Because the location of Python's site-packages differs from machine to machine, the source files are edited manually.
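+
+    A quick way to locate the directories that contain these two files (a sketch, assuming both packages are importable in the current environment):
+
+    ```python
+    import os
+    import mindietorch
+    import transformers
+    # Print the site-packages locations of mindietorch and transformers on this machine.
+    print(os.path.dirname(mindietorch.__file__))
+    print(os.path.dirname(transformers.__file__))
+    ```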
+
+
+    2. Convert to a MindIE Torch model
+
+    Here test_img is the file name of a test image stored under model_dir.
+
+    Run siglip_export.py:
+
+    ```sh
+    python siglip_export.py \
+        --pytorch_model ${model_dir}/siglip-so400m-patch14-384 \
+        --exported_program_path ${model_dir}/exported_siglip.pt2 \
+        --saved_compile_model_path ${model_dir}/compile_siglip.ts \
+        --images_path ${model_dir}/${test_img}
+    ```
+
+    This generates the file compile_siglip.ts under model_dir.
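+
+To sanity-check the compiled model outside of the provided scripts, a minimal loading-and-inference sketch could look like this (the paths and the image file name are placeholders; input preparation and output indexing follow precision_test.py):
+
+```python
+import torch
+import mindietorch
+from PIL import Image
+from transformers import AutoProcessor
+
+mindietorch.set_device(0)
+stream = mindietorch.npu.Stream()
+mindietorch.npu.set_stream(stream)
+
+processor = AutoProcessor.from_pretrained("siglip_model/siglip-so400m-patch14-384")
+compiled_model = torch.load("siglip_model/compile_siglip.ts")
+
+image = Image.open("siglip_model/test.jpg").convert("RGB")  # placeholder image name
+inputs = processor(text=["a photo of 2 cats"], images=[image],
+                   padding="max_length", return_tensors="pt")
+text_npu = inputs.input_ids.to(torch.int64).to("npu:0")
+img_npu = inputs.pixel_values.to(torch.float32).to("npu:0")
+
+# The first two outputs are the image-side and text-side results compared in precision_test.py.
+outputs = compiled_model(text_npu, img_npu)
+img_out, text_out = outputs[0].to("cpu"), outputs[1].to("cpu")
+print(img_out.shape, text_out.shape)
+```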
+
+## Accuracy Verification
+
+To verify model accuracy, run the script below; if the Cosine Similarity printed on screen is greater than 0.9999 for both Image and Text, accuracy is considered normal.
+
+Run precision_test.py:
+
+```sh
+python precision_test.py \
+    --pytorch_model ${model_dir}/siglip-so400m-patch14-384 \
+    --exported_program_path ${model_dir}/exported_siglip.pt2 \
+    --saved_compile_model_path ${model_dir}/compile_siglip.ts
+```
+
+## Performance Verification
+
+For the NPU performance test, run pref_test_npu.py:
+
+```sh
+python pref_test_npu.py \
+    --config_path ${model_dir}/siglip-so400m-patch14-384 \
+    --encoder_aie_path ${model_dir}/compile_siglip.ts
+```
+
+For the GPU performance test, run pref_test_gpu.py:
+
+```sh
+python pref_test_gpu.py \
+    --model_path ${model_dir}/siglip-so400m-patch14-384
+```
+
+The performance figures are printed on screen, reported in FPS.
+
+
+## Performance Data (Latency / Throughput)
+| Model | MindIE Torch | T4 |
+|------| ----------------- |------|
+| encoder | 114.98 ms / 8.7 FPS | 580.67 ms / 1.72 FPS |
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0278a9e8967f4fd04f6fd81f63d13d82c11b53
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/precision_test.py
@@ -0,0 +1,115 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import torch
+import mindietorch
+import numpy as np
+from transformers import AutoProcessor
+from transformers.models.auto.modeling_auto import AutoModel
+
+
+
+def cosine_similarity(gt_tensor, pred_tensor):
+ gt_tensor = gt_tensor.flatten().to(torch.float32)
+ pred_tensor = pred_tensor.flatten().to(torch.float32)
+ if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
+ if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
+ return 1.0
+ res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
+ res = res.cpu().detach().item()
+ return res
+
+
+def inference(args):
+
+ # Load compiled model
+ mindietorch.set_device(args.device_id)
+ stream = mindietorch.npu.Stream()
+ mindietorch.npu.set_stream(stream)
+ compiled_model = torch.load(args.saved_compile_model_path)
+
+
+ # prepare PyTorch implemented model
+ pytorch_model = AutoModel.from_pretrained(args.pytorch_model)
+ processor = AutoProcessor.from_pretrained(args.pytorch_model)
+
+
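+    # Random but identical inputs are fed to both the PyTorch model and the compiled model
+    # so that their outputs can be compared numerically.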
+ img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+ img_c = img_c.astype(np.float32)
+ img_c = (img_c / 255. - 0.5) / 0.5 # torch style norm
+ img_c = torch.from_numpy(img_c)
+
+ text_c = np.random.randint(0, 25566, size=(3, 64), dtype=np.int64)
+ text_c = torch.from_numpy(text_c)
+
+ text_npu = text_c.to(torch.int64).to(f"npu:{args.device_id}")
+ img_npu = img_c.to(torch.float32).to(f"npu:{args.device_id}")
+
+ pytorch_results = pytorch_model(text_c, img_c)
+ pytorch_img_result, pytorch_text_results = pytorch_results[0], pytorch_results[1]
+ pytorch_text_results = pytorch_text_results / pytorch_text_results.norm(p=2, dim=-1, keepdim=True)
+ pytorch_img_result = pytorch_img_result / pytorch_img_result.norm(p=2, dim=-1, keepdim=True)
+
+ compiled_results = compiled_model(text_npu, img_npu)
+ compiled_img_results, compiled_text_results = compiled_results[0].to('cpu'), compiled_results[1].to('cpu')
+ compiled_text_results = compiled_text_results / compiled_text_results.norm(p=2, dim=-1, keepdim=True)
+ compiled_img_results = compiled_img_results / compiled_img_results.norm(p=2, dim=-1, keepdim=True)
+
+ img_similarity = cosine_similarity(pytorch_img_result, compiled_img_results)
+ text_similarity = cosine_similarity(pytorch_text_results, compiled_text_results)
+ print("Cosine Similarity: ", f"\n Image: {img_similarity}, Text: {text_similarity}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="mindietorch siglip precision test")
+ parser.add_argument("--soc_version", default="Ascend910A")
+ parser.add_argument("--device_id", type=int, default=0)
+ parser.add_argument("--text_max_batch", type=int, default=64)
+ parser.add_argument("--img_max_batch", type=int, default=16)
+    parser.add_argument(
+        "--context-length", type=int, default=64, help="The padded length of the input text tokens."
+    )
+
+ parser.add_argument("--pytorch_model", type=str, required=True)
+
+ parser.add_argument("--exported_program_path",
+ type=str, required=True
+ )
+ parser.add_argument("--saved_compile_model_path",
+ type=str, required=True)
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Precision policy the compiled model is expected to use."
+    )
+ parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
+ args = parser.parse_args()
+ return args
+
+def main():
+ args = parse_args()
+ inference(args)
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..8538c4824ddb888dcaeeddb2d23ede234f09de3c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_gpu.py
@@ -0,0 +1,85 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import time
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import AutoProcessor, AutoModel
+
+
+def test_encoder(model_path, device_id=0):
+    batch_size = 1
+    device = f'cuda:{device_id}'
+
+    print("Start loading PyTorch model...")
+    model = AutoModel.from_pretrained(model_path).to(device)
+    model.eval()
+    print("PyTorch model loaded.")
+
+    processor = AutoProcessor.from_pretrained(model_path)
+
+ img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+ img_c = img_c.astype(np.float32)
+ img_c = torch.from_numpy(img_c)
+
+ texts = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]
+ inputs = processor(text=texts, images=img_c, padding="max_length", return_tensors="pt")
+
+ text = inputs.input_ids
+ img = inputs.pixel_values
+ text = text.to(device)
+ img = img.to(device)
+
+ print("Start infering...")
+ # warmup
+ for _ in range(20):
+        model(text, img)
+
+ # performance test
+ num_infer = 200
+
+ start = time.time()
+ for _ in tqdm(range(num_infer)):
+        model(text, img)
+ end = time.time()
+
+ print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+ print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} fps")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--device_id", type=int, help="GPU device id", default=0)
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ test_encoder(args.model_path, args.device_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c29a5bf18ee4f7c70ac30459a801051fb67567
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/pref_test_npu.py
@@ -0,0 +1,91 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import time
+import argparse
+import numpy as np
+import torch
+import mindietorch
+from tqdm import tqdm
+from transformers import AutoProcessor
+
+
+def test_encoder(aie_path, config_path, device_id=0):
+ batch_size = 1
+ device = f'npu:{device_id}'
+ stream = mindietorch.npu.Stream(device)
+ print("Start loading ts module...")
+
+ compiled_model = torch.load(aie_path)
+ print("Ts module loaded.")
+ compiled_model.eval()
+    processor = AutoProcessor.from_pretrained(config_path)
+ text = ["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"]
+
+ img_c = np.random.randint(0, 255, size=(2, 3, 384, 384), dtype=np.uint8)
+ img_c = img_c.astype(np.float32)
+ # img_c = (img_c / 255. - 0.5) / 0.5 # torch style norm
+ img_c = torch.from_numpy(img_c)
+ image = img_c
+
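+    # padding="max_length" pads the token ids to the model's fixed text length (64),
+    # matching the input shape range the model was compiled with.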
+ inputs = processor(text=text, images=image, return_tensors="pt", padding="max_length")
+
+ text = inputs.input_ids
+ img = inputs.pixel_values
+ text_npu = text.to(torch.int64).to(device)
+ img_npu = img.to(torch.float32).to(device)
+
+ print("Start infering...")
+ # warmup
+ for _ in range(20):
+ with mindietorch.npu.stream(stream):
+ compiled_model(text_npu, img_npu)
+ stream.synchronize()
+ # performance test
+ num_infer = 200
+ start = time.time()
+ for _ in tqdm(range(num_infer)):
+ with mindietorch.npu.stream(stream):
+ compiled_model(text_npu, img_npu)
+ stream.synchronize()
+ end = time.time()
+
+ print(f"Encoder latency: {(end - start) / num_infer * 1000:.2f} ms")
+ print(f"Encoder throughput: {num_infer * batch_size / (end - start):.2f} FPS")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--encoder_aie_path", type=str, required=True)
+ parser.add_argument("--config_path", type=str, required=True)
+ parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+
+ args = parse_args()
+ mindietorch.set_device(args.device_id)
+
+ test_encoder(args.encoder_aie_path, args.config_path, args.device_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..233dea5d0ce5881028f818d6813c5ff2716def36
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/Siglip-SO400M/siglip_export.py
@@ -0,0 +1,121 @@
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import os
+import torch
+import mindietorch
+from PIL import Image
+from mindietorch import _enums
+from torch._export import export, dynamic_dim
+from transformers import AutoProcessor
+from transformers.models.auto.modeling_auto import AutoModel
+
+
+def export_compile(args):
+ model = AutoModel.from_pretrained(args.pytorch_model)
+ processor = AutoProcessor.from_pretrained(args.pytorch_model)
+
+ image = Image.open(args.images_path)
+ if image.mode != 'RGB':
+ image = image.convert('RGB')
+ images = [image, image]
+
+ inputs = processor(text=args.text, images=images, padding="max_length", return_tensors="pt")
+
+ input_ids = inputs["input_ids"]
+ pixel_values = inputs["pixel_values"]
+
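+    # Allow the batch dimension of the text and image inputs to vary between 1 and the configured maxima.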
+ constraints = [
+ dynamic_dim(input_ids, 0) >= 1, dynamic_dim(input_ids, 0) <= args.text_max_batch,
+ dynamic_dim(pixel_values, 0) >= 1, dynamic_dim(pixel_values, 0) <= args.img_max_batch,
+ ]
+
+    if not os.path.exists(args.exported_program_path):
+        encoder_ts_model = export(model, (input_ids, pixel_values), constraints=constraints)
+        torch.export.save(encoder_ts_model, args.exported_program_path)
+    else:
+        # Reuse the previously exported program instead of re-exporting it.
+        encoder_ts_model = torch.export.load(args.exported_program_path)
+
+ img_min_shape = (1, 3, 384, 384)
+ img_max_shape = (16, 3, 384, 384)
+ text_min_shape = (1, 64)
+ text_max_shape = (64, 64)
+ compile_inputs = [
+ mindietorch.Input(min_shape=text_min_shape, max_shape=text_max_shape, dtype=torch.int64,),
+ mindietorch.Input(min_shape=img_min_shape, max_shape=img_max_shape, dtype=torch.float32,)
+ ]
+
+ mindietorch.set_device(args.device_id)
+
+    if not os.path.exists(args.saved_compile_model_path):
+        precision_policy = (_enums.PrecisionPolicy.FP16 if args.precision == "fp16"
+                            else _enums.PrecisionPolicy.FP32)
+        compiled_encoder_model = mindietorch.compile(
+            encoder_ts_model,
+            inputs=compile_inputs,
+            precision_policy=precision_policy,
+            soc_version=args.soc_version
+        )
+
+ torch.save(compiled_encoder_model, args.saved_compile_model_path, pickle_protocol=4)
+ return compiled_encoder_model
+ else:
+ compiled_model = torch.load(args.saved_compile_model_path)
+ return compiled_model
+
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="mindietorch siglip model compilation")
+ parser.add_argument("--soc_version", default="Ascend910B4")
+ parser.add_argument("--device_id", type=int, default=0)
+ parser.add_argument("--text_max_batch", type=int, default=64)
+ parser.add_argument("--img_max_batch", type=int, default=16)
+    parser.add_argument(
+        "--context-length", type=int, default=64, help="The padded length of the input text tokens."
+    )
+
+ parser.add_argument("--pytorch_model", type=str, required=True)
+
+ parser.add_argument("--exported_program_path",
+ type=str,
+ required=True
+ )
+ parser.add_argument("--saved_compile_model_path",
+ type=str,
+ required=True
+ )
+    parser.add_argument(
+        "--precision",
+        default="fp16",
+        choices=["fp16", "fp32"],
+        help="Precision policy used for MindIE-Torch compilation."
+    )
+ parser.add_argument("--images_path", type=str,
+ required=True
+ )
+ parser.add_argument("--text", default=["a photo of 2 cats", "a photo of 2 dogs", "a photo of 1 cat"])
+ args = parser.parse_args()
+ return args
+
+def main():
+ args = parse_args()
+ export_compile(args)
+
+if __name__ == "__main__":
+ main()