diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md b/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a0f49d274d8db7362b07c92eff8a6eb018caa37
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md
@@ -0,0 +1,177 @@
+# DINOv2-ViT Inference Guide
+
+- [Overview](#ZH-CN_TOPIC_0000001172161123)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126217823)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126288456)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172156835)
+
+# Overview
+
+DINOv2 is a self-supervised learning method developed by Meta AI that focuses on learning robust visual features without supervision. It builds on the original DINO method with improvements that raise the performance of ViT models across a wide range of computer vision tasks. ([upstream repository](https://github.com/facebookresearch/dinov2/tree/main))
+
+# Inference Environment Setup \[All Versions\]
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Component         | Version |
+  |-------------------|---------|
+  | Firmware & Driver | -       |
+  | CANN              | -       |
+  | Python            | 3.10.13 |
+  | PyTorch           | 2.1.0   |
+  | MindIE            | -       |
+
+  Note: MindIE does not yet have a commercial release that supports this model; please contact a Huawei engineer for links to the matching firmware/driver, CANN, and MindIE PoC versions.
+  To install the firmware/driver and CANN, follow the Ascend documentation: [Quick Environment Deployment](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html).
+
+  To install MindIE, first source the toolkit environment variables, then run the installer directly. Using the default install path `/usr/local/Ascend` as an example:
+  ```
+  source /usr/local/Ascend/ascend-toolkit/set_env.sh
+  bash Ascend-mindie_*.run --install
+  ```
+
+# Quick Start
+
+1. Install dependencies
+   ```shell
+   pip install transformers==4.44.1
+   pip install numpy==1.26.4
+   ```
+2. Download the weights
+
+   | Model | Download |
+   |-------|---------------------------------------------------------------|
+   | ViT-S | [backbone only](https://huggingface.co/facebook/dinov2-small) |
+   | ViT-B | [backbone only](https://huggingface.co/facebook/dinov2-base)  |
+   | ViT-L | [backbone only](https://huggingface.co/facebook/dinov2-large) |
+   | ViT-G | [backbone only](https://huggingface.co/facebook/dinov2-giant) |
+
+   Download the model weights from the links above, using dinov2-vit-base as an example:
+   ```shell
+   git lfs install
+   git clone https://huggingface.co/facebook/dinov2-base
+   ```
+3. Parameter reference
+
+   Most parameter names are shared between export and inference; the common ones are described below:
+
+   | Parameter     | Description                                      |
+   |---------------|--------------------------------------------------|
+   | soc-version   | Chip type; currently only tested on Ascend910B4  |
+   | device        | NPU ID                                           |
+   | img-max-batch | Maximum batch size of the image input            |
+   | image-path    | Path to the input image                          |
+   | model-version | Model type ("small", "base", "large", "giant")   |
+   | hf-model-path | Path to the model weights                        |
+   | save-dir      | Save directory for each model type               |
+
+   For further parameters, see the `parse_args` section of each script.
+4. Export the ONNX model
+   ```shell
+   python onnx_export.py \
+       --soc-version ${soc_version} \
+       --image-path ${image_path} \
+       --model-version ${model_version} \
+       --hf-model-path ${hf_model_path} \
+       --save-dir ${save_dir}
+   ```
+   When it finishes, a `dinov2-${model_version}-onnx.pt` file is generated under `save_dir`.
+   Because the giant model is very large, its export takes a long time (please be patient) and writes a large number of intermediate computation nodes; the `save_dir` used for the ONNX model must therefore differ from the `save_dir` used for the MindIE Torch model. A quick way to sanity-check the exported file is sketched below.
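+
+   To sanity-check the export before compiling, a minimal sketch along the following lines (the file names, image path, and expected shapes are illustrative, not part of this repository) loads the exported file with onnxruntime and runs a single image:
+   ```python
+   # Hypothetical sanity check for the exported ONNX file; paths are placeholders.
+   import onnxruntime as ort
+   from PIL import Image
+   from transformers import AutoImageProcessor
+
+   processor = AutoImageProcessor.from_pretrained("./dinov2-base")
+   pixel_values = processor(images=Image.open("demo.jpg"), return_tensors="np").pixel_values
+
+   session = ort.InferenceSession("./dinov2-base-onnx.pt", providers=["CPUExecutionProvider"])
+   seq, pooled = session.run(["sequence_output", "pooled_output"], {"pixel_values": pixel_values})
+   print(seq.shape, pooled.shape)  # e.g. (1, 257, 768) and (1, 768) for the base model
+   ```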
+
+
+5. Compile the model
+
+   MindIE Torch does not support `nn.functional.interpolate` with `mode="bicubic"`, so the embedding stage is split out of the model and run online (eagerly); only the encoder part of the model is compiled. Run the following script to compile:
+   ```shell
+   python dino_compile.py \
+       --soc-version ${soc_version} \
+       --device ${device} \
+       --img-max-batch ${img_max_batch} \
+       --image-path ${image_path} \
+       --model-version ${model_version} \
+       --hf-model-path ${hf_model_path} \
+       --save-dir ${save_dir}
+   ```
+   When it finishes, a `dinov2-${model_version}-MindIETorch.pt` file is generated under `save_dir`. A sketch of how the split model is used at inference time follows.
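+
+   A minimal inference sketch under the split described above (file names, image path, and the NPU tensor-transfer convention are assumptions, not repository code): the HF embedding stage runs eagerly, and its output feeds the compiled encoder.
+   ```python
+   import torch
+   import mindietorch
+   from PIL import Image
+   from transformers import AutoImageProcessor
+   from model.dinov2_model import Dinov2Model_WO_Embedding
+
+   device = 0
+   mindietorch.set_device(device)
+
+   hf_model = Dinov2Model_WO_Embedding.from_pretrained("./dinov2-base").float().eval()
+   processor = AutoImageProcessor.from_pretrained("./dinov2-base")
+   compiled = torch.load("./dinov2-base-MindIETorch.pt")
+
+   pixel_values = processor(images=Image.open("demo.jpg"), return_tensors="pt").pixel_values
+   with torch.no_grad():
+       emb = hf_model.embeddings(pixel_values)          # eager bicubic embedding on the host
+       seq, pooled = compiled(emb.to(f"npu:{device}"))  # compiled encoder on the NPU (assumed device string)
+   ```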
+
+# Inference Performance and Accuracy
+
+1. Accuracy verification
+   ```shell
+   dinov2_aie_path="./dinov2-${model_version}-MindIETorch.pt"
+   dinov2_onnx_path="./dinov2-${model_version}-onnx.pt"
+   python precision_test.py \
+       --dinov2-aie-path ${dinov2_aie_path} \
+       --dinov2-onnx-path ${dinov2_onnx_path} \
+       --device ${device} \
+       --image-path ${image_path} \
+       --model-version ${model_version} \
+       --hf-model-path ${hf_model_path}
+   ```
+   After execution, the expected output is as follows (a sketch of the kind of similarity check being reported appears after the output):
+ ```
+ ----- Compare the outputs of ONNX and AIE dinov2 ${model_version} model -----
+ Number of outputs to compare: 2
+ Number of outputs with cosine similarity > 0.99: 2
+ Number of outputs to compare: 2
+ Number of outputs with cosine similarity > 0.99: 2
+ ```
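+
+   A minimal sketch of the kind of check being reported (the exact precision_test.py implementation may differ): cosine similarity between the flattened ONNX and MindIE Torch outputs, with 0.99 as the pass threshold.
+   ```python
+   import torch
+
+   def cosine_ok(onnx_out: torch.Tensor, aie_out: torch.Tensor, thresh: float = 0.99) -> bool:
+       # Flatten both outputs and compare direction, not magnitude.
+       sim = torch.nn.functional.cosine_similarity(
+           onnx_out.flatten().float(), aie_out.flatten().float(), dim=0
+       )
+       return bool(sim > thresh)
+   ```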
+
+2. Performance verification
+
+   (a) AIE model performance test
+   ```shell
+   dinov2_aie_path="./dinov2-${model_version}-MindIETorch.pt"
+   python perf_test_aie.py \
+       --dinov2-aie-path ${dinov2_aie_path} \
+       --device ${device} \
+       --image-path ${image_path} \
+       --img-max-batch ${img_max_batch} \
+       --hf-model-path ${hf_model_path}
+   ```
+
+   Expected output after execution (base):
+ ```
+ DINOV2 aie latency: 31.11 ms
+ DINOV2 aie throughput: 32.14 fps
+ ```
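+
+   Latency and throughput relate as throughput = batch_size / latency; at batch size 1, 31.11 ms corresponds to 1 / 0.03111 ≈ 32.14 fps. A hypothetical timing loop in that spirit (perf_test_aie.py may differ in warmup counts and device synchronization):
+   ```python
+   import time
+
+   def measure(model, sample, warmup=10, iters=100, batch_size=1):
+       for _ in range(warmup):      # warm up caches and lazy initialization
+           model(sample)
+       start = time.perf_counter()
+       for _ in range(iters):
+           model(sample)
+       latency_ms = (time.perf_counter() - start) / iters * 1000
+       fps = batch_size * 1000 / latency_ms
+       return latency_ms, fps
+   ```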
+
+   (b) ONNX model performance test
+
+   (Optional) To run on a GPU, make sure CUDA and the GPU build of PyTorch are installed, and install onnxruntime-gpu as follows:
+   ```shell
+   pip uninstall onnxruntime
+   pip install onnxruntime-gpu
+   ```
+   Verify that onnxruntime-gpu installed successfully:
+   ```python
+   import onnxruntime
+   print(onnxruntime.get_device())  # "GPU" means the installation succeeded
+   ```
+   Run the performance test (CPU):
+   ```shell
+   dinov2_onnx_path="./dinov2-${model_version}-onnx.pt"
+   python perf_test_onnx.py \
+       --onnx-path ${dinov2_onnx_path} \
+       --image-path ${image_path} \
+       --hf-model-path ${hf_model_path}
+   ```
+
+   Expected output after execution (base):
+ ```
+ DINOV2 onnx latency: 268.65 ms
+ DINOV2 onnx throughput: 3.72 fps
+ ```
+
+   (c) Performance comparison (GPU performance not yet measured):
+
+   | Model | MindIE-Torch (Ascend910B4) | ONNX (CPU)            |
+   |-------|----------------------------|-----------------------|
+   | small | 3.93 ms / 254.74 fps       | 164.50 ms / 6.08 fps  |
+   | base  | 4.00 ms / 250.14 fps       | 535.00 ms / 1.90 fps  |
+   | large | 9.59 ms / 104.30 fps       | 1086.95 ms / 0.92 fps |
+   | giant | 20.09 ms / 49.78 fps       | 3088.63 ms / 0.32 fps |
+
+   Performance measured on different machines may differ in absolute terms (especially on CPU), but the relative differences stay consistent.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..71654b09386cfff033e98b7ad5d3ca8fdeded206
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py
@@ -0,0 +1,130 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+import time
+
+import mindietorch
+import torch
+from PIL import Image
+from torch._export import export, dynamic_dim
+from transformers import AutoImageProcessor
+
+from model.dinov2_model import Dinov2Model_WO_Embedding
+
+
+def get_embed_input(args, model):
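+    # Run only the HF embedding stage eagerly on the sample image; its output
+    # shape determines the input signature of the compiled encoder.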
+ processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+ inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+ embeddings = model.embeddings
+ embedding_output = embeddings(inputs.pixel_values)
+ return embedding_output
+
+
+def export_dinov2(args):
+ model = Dinov2Model_WO_Embedding.from_pretrained(args.hf_model_path).float().eval()
+ embedding_output = get_embed_input(args, model)
+ embed_shape = embedding_output.shape
+ emb_input_shape = (args.img_max_batch, embed_shape[-2], embed_shape[-1])
+ input_emb = torch.ones(emb_input_shape, dtype=torch.float32)
+
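+    # Constrain dim 0 (batch) to be dynamic within [1, img_max_batch];
+    # dynamic_dim is the torch 2.1 torch._export constraint API.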
+ constraints = [
+ dynamic_dim(input_emb, 0) >= 1,
+ dynamic_dim(input_emb, 0) <= args.img_max_batch,
+ ]
+
+ print("----- start exporting dynamic dinov2 -----")
+ intermediate_model = export(
+ model,
+ args=(input_emb,),
+ constraints=constraints
+ )
+ print("----- export dynamic dinov2 success! -----")
+ return embed_shape, intermediate_model
+
+
+def compile_dinov2(args):
+ # export dinov2
+ embed_shape, intermediate_model = export_dinov2(args)
+ # compile dinov2
+ mindietorch.set_device(args.device)
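+    # Mirror the export-time constraint: batch is dynamic in [1, img_max_batch],
+    # while the sequence-length and hidden dims stay fixed.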
+ compile_inputs = [
+ mindietorch.Input(min_shape=(1, embed_shape[-2], embed_shape[-1]),
+ max_shape=(args.img_max_batch, embed_shape[-2], embed_shape[-1])),
+ ]
+
+ print("----- start mindietorch compile -----")
+ ts = time.time()
+ compiled_model = mindietorch.compile(
+ intermediate_model,
+ inputs=compile_inputs,
+ precision_policy=mindietorch._enums.PrecisionPolicy.FP16,
+ soc_version=args.soc_version,
+ )
+ compile_cost = time.time() - ts
+ print(f"----- compile time cost: {compile_cost} -----")
+ print("----- end mindietorch compile -----")
+
+ print("----- start saving -----")
+    model_save_dir = args.save_dir
+    if not os.path.exists(model_save_dir):
+        os.makedirs(model_save_dir)
+    compiled_file_name = f"dinov2-{args.model_version}-MindIETorch.pt"
+    # os.path.join avoids a malformed path when save_dir lacks a trailing slash
+    torch.save(compiled_model, os.path.join(model_save_dir, compiled_file_name), pickle_protocol=4)
+ print("----- saving done -----")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Compile DINOv2-Vit model")
+ parser.add_argument(
+ "--soc-version",
+ default="Ascend910B4",
+ help="NPU version"
+ )
+ parser.add_argument(
+ "--device",
+ type=int,
+ default=0
+ )
+ parser.add_argument(
+ "--img-max-batch",
+ type=int,
+ default=8
+ )
+ parser.add_argument(
+ "--image-path",
+ default=""
+ )
+ parser.add_argument(
+ "--model-version",
+ default="base",
+ choices=["small", "base", "large", "giant"],
+ help="Specify the architecture of DINOv2-Vit model to be converted."
+ )
+ parser.add_argument(
+ "--hf-model-path",
+ default="",
+ type=str,
+ help="Path of the Huggingface DINOv2-Vit model."
+ )
+ parser.add_argument(
+ "--save-dir",
+ default="./"
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ input_args = parse_args()
+ compile_dinov2(input_args)
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..202175318474ecc8aad3f2b9986e828ba6f2bd5f
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py
@@ -0,0 +1,56 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.dinov2.modeling_dinov2 import Dinov2Model
+
+
+class Dinov2Model_WO_Embedding(Dinov2Model):
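+    """Dinov2Model variant whose forward consumes a precomputed embedding tensor
+    instead of pixel values, so that the embedding stage (which needs bicubic
+    interpolation unsupported by MindIE Torch) can run eagerly outside the
+    compiled encoder."""
+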
+ def forward(
+ self,
+ embedding_output: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Tuple:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if embedding_output is None:
+ raise ValueError("You have to specify embedding_output")
+
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = sequence_output[:, 0, :]
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output)
+ return head_outputs + encoder_outputs[1:]
+
+ return sequence_output, pooled_output
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7082190276dc234f097d2b0119c42e87fe4fe66
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py
@@ -0,0 +1,81 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+
+import torch
+import torch.onnx
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModel
+
+
+def convert_dinov2(args):
+ processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+ inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+ pixel_values = inputs.pixel_values
+ model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+
+    # os.path.join avoids a malformed path when save_dir lacks a trailing slash
+    onnx_path = os.path.join(args.save_dir, f"dinov2-{args.model_version}-onnx.pt")
+ print("----- Starting to export dynamic onnx -----")
+ torch.onnx.export(
+ model,
+ (pixel_values,),
+ onnx_path,
+ input_names=["pixel_values"],
+ output_names=["sequence_output", "pooled_output"],
+ export_params=True,
+ opset_version=13,
+ verbose=True,
+ dynamic_axes={
+ "pixel_values": {0: "image_batch_size"},
+ "sequence_output": {0: "image_batch_size"},
+ "pooled_output": {0: "image_batch_size"},
+ }
+ )
+ print("----- Successfully exported dynamic onnx! -----")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Export DINOv2 model to ONNX")
+ parser.add_argument(
+ "--soc-version",
+ default="Ascend910B4",
+ help="NPU version"
+ )
+ parser.add_argument(
+ "--image-path",
+ type=str,
+ default=""
+ )
+ parser.add_argument(
+ "--hf-model-path",
+ default="",
+ type=str,
+ help="Path of the Huggingface DINOv2 model."
+ )
+ parser.add_argument(
+ "--model-version",
+ default="base",
+ choices=["small", "base", "large", "giant"],
+ help="Specify the architecture of DINOv2-Vit model to be converted."
+ )
+ parser.add_argument(
+ "--save-dir",
+ default="./"
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ input_args = parse_args()
+ convert_dinov2(input_args)