diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md b/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a0f49d274d8db7362b07c92eff8a6eb018caa37
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/README.md
@@ -0,0 +1,177 @@
+# DINOv2-ViT Inference Guide
+
+- [Overview](#overview)
+
+- [Inference Environment](#inference-environment-all-versions)
+
+- [Getting Started](#getting-started)
+
+- [Performance and Accuracy](#performance-and-accuracy)
+
+# Overview
+
+DINOv2 is a self-supervised learning method developed by Meta AI that focuses on learning robust visual features without supervision. It builds on the original DINO method and improves the performance of ViT models across a wide range of computer vision tasks. ([upstream repository](https://github.com/facebookresearch/dinov2/tree/main))
+
+# Inference Environment \[All Versions\]
+
+- The model requires the following software stack.
+
+  **Table 1** Version compatibility
+
+  | Component          | Version |
+  |--------------------|---------|
+  | Firmware & drivers | -       |
+  | CANN               | -       |
+  | Python             | 3.10.13 |
+  | PyTorch            | 2.1.0   |
+  | MindIE             | -       |
+
+  Note: MindIE has no commercial release that supports this model yet. Please contact a Huawei engineer for download links to matching firmware/driver, CANN, and MindIE PoC versions.
+  For installing the firmware/driver and CANN, see the official Ascend guide [Quick Environment Setup](https://www.hiascend.com/document/detail/zh/quick-installation/24.0.RC1/quickinstg/800_3000/quickinstg_800_3000_0001.html).
+
+  To install MindIE, first source the toolkit environment variables, then run the installer. With the default install path `/usr/local/Ascend`:
+  ```
+  source /usr/local/Ascend/ascend-toolkit/set_env.sh
+  bash Ascend-mindie_*.run --install
+  ```
+
+# Getting Started
+
+1. Install dependencies
+   ```shell
+   pip install transformers==4.44.1
+   pip install numpy==1.26.4
+   ```
+2. Download the weights
+
+   | Model | Download                                                       |
+   |-------|----------------------------------------------------------------|
+   | ViT-S | [backbone only](https://huggingface.co/facebook/dinov2-small)  |
+   | ViT-B | [backbone only](https://huggingface.co/facebook/dinov2-base)   |
+   | ViT-L | [backbone only](https://huggingface.co/facebook/dinov2-large)  |
+   | ViT-G | [backbone only](https://huggingface.co/facebook/dinov2-giant)  |
+
+   Download the weights from the links above; taking dinov2-base as an example:
+   ```shell
+   git lfs install
+   git clone https://huggingface.co/facebook/dinov2-base
+   ```
+3. Arguments
+
+   The export and inference scripts share most of their argument names; the common ones are:
+
+   | Argument      | Description                                         |
+   |---------------|-----------------------------------------------------|
+   | soc-version   | Chip type; currently only validated on Ascend910B4  |
+   | device        | NPU device ID                                       |
+   | img-max-batch | Maximum batch size of the image input               |
+   | image-path    | Path to the input image                             |
+   | model-version | Model size ("small", "base", "large", "giant")      |
+   | hf-model-path | Path to the model weights                           |
+   | save-dir      | Output directory for each converted model           |
+
+   For the remaining arguments, see the `parse_args` section of the respective script.
+4. Export the ONNX model
+   ```shell
+   python onnx_export.py \
+       --soc-version ${soc_version} \
+       --image-path ${image_path} \
+       --model-version ${model_version} \
+       --hf-model-path ${hf_model_path} \
+       --save-dir ${save_dir}
+   ```
+   On completion, `dinov2-${model_version}-onnx.pt` is written to `save_dir`.
+   Because the giant model is very large, its export takes a long time and writes a large number of intermediate computation nodes; please be patient. Also note that the `save_dir` used for the ONNX model must differ from the `save_dir` used for the MindIE Torch model.
+
+5. Compile the model
+
+   MindIE Torch does not support `nn.functional.interpolate` with mode "bicubic", so the embedding layer is split out of the model and run eagerly (online); only the encoder is compiled. Run the following script:
+   ```shell
+   python dino_compile.py \
+       --soc-version ${soc_version} \
+       --device ${device} \
+       --img-max-batch ${img_max_batch} \
+       --image-path ${image_path} \
+       --model-version ${model_version} \
+       --hf-model-path ${hf_model_path} \
+       --save-dir ${save_dir}
+   ```
+   On completion, `dinov2-${model_version}-MindIETorch.pt` is written to `save_dir`.
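+   After compilation you can sanity-check the artifact end to end: compute the patch embeddings eagerly (they are not part of the compiled graph) and feed them to the compiled encoder. The snippet below is a minimal sketch rather than a script shipped with this repo; the image and weight paths are placeholders, and it assumes the usual MindIE Torch convention of moving inputs to the NPU with an `npu:{id}` device string.
+   ```python
+   import mindietorch
+   import torch
+   from PIL import Image
+   from transformers import AutoImageProcessor
+
+   from model.dinov2_model import Dinov2Model_WO_Embedding
+
+   device_id = 0  # NPU device ID
+   mindietorch.set_device(device_id)
+
+   # The embedding layer runs eagerly, since its bicubic interpolation
+   # is unsupported by MindIE Torch and was excluded from compilation.
+   processor = AutoImageProcessor.from_pretrained("./dinov2-base")  # placeholder path
+   model = Dinov2Model_WO_Embedding.from_pretrained("./dinov2-base").float().eval()
+   pixel_values = processor(images=Image.open("./demo.jpg"), return_tensors="pt").pixel_values
+   embedding_output = model.embeddings(pixel_values)
+
+   # The compiled encoder runs on the NPU and returns (sequence_output, pooled_output).
+   compiled = torch.load("./dinov2-base-MindIETorch.pt")
+   sequence_output, pooled_output = compiled(embedding_output.to(f"npu:{device_id}"))
+   print(sequence_output.shape, pooled_output.shape)
+   ```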
+# Performance and Accuracy
+
+1. Accuracy check
+   ```shell
+   dinov2_aie_path="./dinov2-${model_version}-MindIETorch.pt"
+   dinov2_onnx_path="./dinov2-${model_version}-onnx.pt"
+   python precision_test.py \
+       --dinov2-aie-path ${dinov2_aie_path} \
+       --dinov2-onnx-path ${dinov2_onnx_path} \
+       --device ${device} \
+       --image-path ${image_path} \
+       --model-version ${model_version} \
+       --hf-model-path ${hf_model_path}
+   ```
+   Expected output:
+   ```
+   ----- Compare the outputs of ONNX and AIE dinov2 ${model_version} model -----
+   Number of outputs to compare: 2
+   Number of outputs with cosine similarity > 0.99: 2
+   Number of outputs to compare: 2
+   Number of outputs with cosine similarity > 0.99: 2
+   ```
+
+2. Performance check
+
+   (a) AIE model performance test
+   ```shell
+   dinov2_aie_path="./dinov2-${model_version}-MindIETorch.pt"
+   python perf_test_aie.py \
+       --dinov2-aie-path ${dinov2_aie_path} \
+       --device ${device} \
+       --image-path ${image_path} \
+       --img-max-batch ${img_max_batch} \
+       --hf-model-path ${hf_model_path}
+   ```
+
+   Expected output (base):
+   ```
+   DINOV2 aie latency: 31.11 ms
+   DINOV2 aie throughput: 32.14 fps
+   ```
+
+   (b) ONNX model performance test
+
+   (Optional) To benchmark on GPU, make sure CUDA and the GPU build of PyTorch are installed, and install onnxruntime-gpu:
+   ```shell
+   pip uninstall onnxruntime
+   pip install onnxruntime-gpu
+   ```
+   Verify that onnxruntime-gpu is installed correctly:
+   ```python
+   import onnxruntime
+   print(onnxruntime.get_device())  # "GPU" means the installation succeeded
+   ```
+   Run the performance test (CPU):
+   ```shell
+   dinov2_onnx_path="./dinov2-${model_version}-onnx.pt"
+   python perf_test_onnx.py \
+       --onnx-path ${dinov2_onnx_path} \
+       --image-path ${image_path} \
+       --hf-model-path ${hf_model_path}
+   ```
+
+   Expected output (base):
+   ```
+   DINOV2 onnx latency: 268.65 ms
+   DINOV2 onnx throughput: 3.72 fps
+   ```
+
+   (c) Performance comparison (GPU numbers pending):
+
+   | Model | MindIE-Torch (Ascend910B4) | ONNX (CPU)            |
+   |-------|----------------------------|-----------------------|
+   | small | 3.93 ms / 254.74 fps       | 164.50 ms / 6.08 fps  |
+   | base  | 4.00 ms / 250.14 fps       | 535.00 ms / 1.90 fps  |
+   | large | 9.59 ms / 104.30 fps       | 1086.95 ms / 0.92 fps |
+   | giant | 20.09 ms / 49.78 fps       | 3088.63 ms / 0.32 fps |
+
+   Measured performance may differ in absolute terms across machines (especially the CPU numbers), but the relative gap between the two backends is consistent.
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py
new file mode 100644
index 0000000000000000000000000000000000000000..71654b09386cfff033e98b7ad5d3ca8fdeded206
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/dino_compile.py
@@ -0,0 +1,130 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
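+"""Compile the DINOv2 encoder with MindIE Torch.
+
+The embedding layer depends on bicubic interpolation, which MindIE Torch does
+not support, so it stays in eager mode; only the encoder is exported via
+torch._export and compiled for the NPU (see the README for details).
+"""
+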
+import argparse
+import os
+import time
+
+import mindietorch
+import torch
+from PIL import Image
+from torch._export import export, dynamic_dim
+from transformers import AutoImageProcessor
+
+from model.dinov2_model import Dinov2Model_WO_Embedding
+
+
+def get_embed_input(args, model):
+    # Run the (uncompiled) embedding layer eagerly to obtain the encoder input.
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+    inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+    embeddings = model.embeddings
+    embedding_output = embeddings(inputs.pixel_values)
+    return embedding_output
+
+
+def export_dinov2(args):
+    model = Dinov2Model_WO_Embedding.from_pretrained(args.hf_model_path).float().eval()
+    embedding_output = get_embed_input(args, model)
+    embed_shape = embedding_output.shape
+    emb_input_shape = (args.img_max_batch, embed_shape[-2], embed_shape[-1])
+    input_emb = torch.ones(emb_input_shape, dtype=torch.float32)
+
+    # Only the batch dimension is dynamic; sequence length and hidden size are fixed.
+    constraints = [
+        dynamic_dim(input_emb, 0) >= 1,
+        dynamic_dim(input_emb, 0) <= args.img_max_batch,
+    ]
+
+    print("----- start exporting dynamic dinov2 -----")
+    intermediate_model = export(
+        model,
+        args=(input_emb,),
+        constraints=constraints
+    )
+    print("----- export dynamic dinov2 success! -----")
+    return embed_shape, intermediate_model
+
+
+def compile_dinov2(args):
+    # Export the encoder-only model, then compile it for the target NPU.
+    embed_shape, intermediate_model = export_dinov2(args)
+    mindietorch.set_device(args.device)
+    compile_inputs = [
+        mindietorch.Input(min_shape=(1, embed_shape[-2], embed_shape[-1]),
+                          max_shape=(args.img_max_batch, embed_shape[-2], embed_shape[-1])),
+    ]
+
+    print("----- start mindietorch compile -----")
+    ts = time.time()
+    compiled_model = mindietorch.compile(
+        intermediate_model,
+        inputs=compile_inputs,
+        precision_policy=mindietorch._enums.PrecisionPolicy.FP16,
+        soc_version=args.soc_version,
+    )
+    compile_cost = time.time() - ts
+    print(f"----- compile time cost: {compile_cost} -----")
+    print("----- end mindietorch compile -----")
+
+    print("----- start saving -----")
+    os.makedirs(args.save_dir, exist_ok=True)
+    compiled_file_name = f"dinov2-{args.model_version}-MindIETorch.pt"
+    torch.save(compiled_model, os.path.join(args.save_dir, compiled_file_name), pickle_protocol=4)
+    print("----- saving done -----")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compile the DINOv2-ViT model")
+    parser.add_argument(
+        "--soc-version",
+        default="Ascend910B4",
+        help="NPU version"
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0,
+        help="NPU device ID"
+    )
+    parser.add_argument(
+        "--img-max-batch",
+        type=int,
+        default=8,
+        help="Maximum batch size of the image input"
+    )
+    parser.add_argument(
+        "--image-path",
+        default="",
+        help="Path to the input image"
+    )
+    parser.add_argument(
+        "--model-version",
+        default="base",
+        choices=["small", "base", "large", "giant"],
+        help="Specify the architecture of the DINOv2-ViT model to be converted."
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="",
+        type=str,
+        help="Path of the Huggingface DINOv2-ViT model."
+    )
+    parser.add_argument(
+        "--save-dir",
+        default="./",
+        help="Output directory for the compiled model"
+    )
+    return parser.parse_args()
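+
+
+# Example invocation (all paths below are placeholders):
+#   python dino_compile.py --soc-version Ascend910B4 --device 0 --img-max-batch 8 \
+#       --image-path ./demo.jpg --model-version base \
+#       --hf-model-path ./dinov2-base --save-dir ./compiled/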
+ ) + parser.add_argument( + "--save-dir", + default="./" + ) + return parser.parse_args() + + +if __name__ == "__main__": + input_args = parse_args() + compile_dinov2(input_args) diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py new file mode 100644 index 0000000000000000000000000000000000000000..202175318474ecc8aad3f2b9986e828ba6f2bd5f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/model/dinov2_model.py @@ -0,0 +1,56 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple + +import torch +from transformers.models.dinov2.modeling_dinov2 import Dinov2Model + + +class Dinov2Model_WO_Embedding(Dinov2Model): + def forward( + self, + embedding_output: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Tuple: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if embedding_output is None: + raise ValueError("You have to specify embedding_output") + + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return sequence_output, pooled_output diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..d7082190276dc234f097d2b0119c42e87fe4fe66 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/onnx_export.py @@ -0,0 +1,81 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse
+import os
+
+import torch
+import torch.onnx
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModel
+
+
+def convert_dinov2(args):
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+    inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+    pixel_values = inputs.pixel_values
+    model = AutoModel.from_pretrained(args.hf_model_path).float().eval()
+
+    # The ONNX graph is saved with a .pt suffix so the file name matches the
+    # paths expected by the precision and performance test scripts.
+    onnx_path = os.path.join(args.save_dir, f"dinov2-{args.model_version}-onnx.pt")
+    print("----- Starting to export dynamic onnx -----")
+    torch.onnx.export(
+        model,
+        (pixel_values,),
+        onnx_path,
+        input_names=["pixel_values"],
+        output_names=["sequence_output", "pooled_output"],
+        export_params=True,
+        opset_version=13,
+        verbose=True,
+        dynamic_axes={
+            "pixel_values": {0: "image_batch_size"},
+            "sequence_output": {0: "image_batch_size"},
+            "pooled_output": {0: "image_batch_size"},
+        }
+    )
+    print("----- Successfully exported dynamic onnx! -----")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Export the DINOv2 model to ONNX")
+    parser.add_argument(
+        "--soc-version",
+        default="Ascend910B4",
+        help="NPU version (kept for interface consistency; unused during ONNX export)"
+    )
+    parser.add_argument(
+        "--image-path",
+        type=str,
+        default="",
+        help="Path to the input image"
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="",
+        type=str,
+        help="Path of the Huggingface DINOv2 model."
+    )
+    parser.add_argument(
+        "--model-version",
+        default="base",
+        choices=["small", "base", "large", "giant"],
+        help="Specify the architecture of the DINOv2-ViT model to be converted."
+    )
+    parser.add_argument(
+        "--save-dir",
+        default="./",
+        help="Output directory for the ONNX model (must differ from the MindIE Torch save-dir)"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    input_args = parse_args()
+    convert_dinov2(input_args)
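+
+
+# Example invocation (all paths below are placeholders):
+#   python onnx_export.py --image-path ./demo.jpg --model-version base \
+#       --hf-model-path ./dinov2-base --save-dir ./onnx/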