diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/perf_test_aie.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/perf_test_aie.py
new file mode 100644
index 0000000000000000000000000000000000000000..18432a9577adae3cf98e5220f91f151a07b3bc2f
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/perf_test_aie.py
@@ -0,0 +1,91 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+import time
+
+import mindietorch
+import torch
+from PIL import Image
+from transformers import AutoImageProcessor
+from model.dinov2_model import Dinov2Model_WO_Embedding
+
+
+def test(inputs, model, stream, meta=""):
+    # warmup
+    for _ in range(10):
+        with mindietorch.npu.stream(stream):
+            model(*inputs)
+
+    # performance test
+    num_infer = 100
+    start = time.time()
+    for _ in range(num_infer):
+        with mindietorch.npu.stream(stream):
+            model(*inputs)
+    end = time.time()
+
+    print(f"{meta} aie latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"{meta} aie throughput: {num_infer / (end - start):.2f} fps")
+
+
+def dinov2_test(args):
+    device = f'npu:{args.device}'
+    stream = mindietorch.npu.Stream(device)
+    mindietorch.set_device(args.device)
+
+    model = torch.load(args.dinov2_aie_path).eval().to(device)
+    hf_model = Dinov2Model_WO_Embedding.from_pretrained(args.hf_model_path).float().eval()
+    embeddings = hf_model.embeddings.to(device)
+
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+    inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+    embedding_output = embeddings(inputs.pixel_values.to(device))
+    test([embedding_output], model, stream, "DINOV2")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dinov2-aie-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        help="NPU device id",
+        default=0
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default=""
+    )
+    parser.add_argument(
+        "--image-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--img-max-batch",
+        type=int,
+        default=8
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    input_args = parse_args()
+    dinov2_test(input_args)
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/perf_test_onnx.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/perf_test_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..151e51f1f567f3a9d113c1cee4299aa583d475b2
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/perf_test_onnx.py
@@ -0,0 +1,85 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import time
+
+import onnxruntime as ort
+from PIL import Image
+from transformers import AutoImageProcessor
+
+
+def test(onnx_model_path, provider, output_names, onnx_inputs, meta=""):
+    onnx_model = ort.InferenceSession(
+        onnx_model_path,
+        providers=[provider]
+    )
+
+    # warmup
+    for _ in range(10):
+        onnx_model.run(output_names, onnx_inputs)
+    # performance test
+    num_infer = 100
+    start = time.time()
+    for _ in range(num_infer):
+        onnx_model.run(output_names, onnx_inputs)
+    end = time.time()
+
+    print(f"{meta} onnx latency: {(end - start) / num_infer * 1000:.2f} ms")
+    print(f"{meta} onnx throughput: {num_infer / (end - start):.2f} fps")
+
+
+def test_dinov2(args, provider):
+    image = Image.open(args.image_path)
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+
+    inputs = processor(images=(image,), return_tensors="pt")
+    pixel_values = inputs.pixel_values.detach().numpy()
+
+    onnx_inputs = {"pixel_values": pixel_values}
+    output_names = ["sequence_output", "pooled_output"]
+
+    test(args.onnx_path, provider, output_names, onnx_inputs, "DINOV2")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--onnx-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--image-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default=""
+    )
+    parser.add_argument(
+        "--use-gpu",
+        action="store_true"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    input_args = parse_args()
+    if input_args.use_gpu:
+        provider = "CUDAExecutionProvider"
+    else:
+        provider = "CPUExecutionProvider"
+
+    test_dinov2(input_args, provider)
diff --git a/MindIE/MindIE-Torch/built-in/cv/DINOv2/precision_test.py b/MindIE/MindIE-Torch/built-in/cv/DINOv2/precision_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..913e8e8fcb8314721031098ef264ab9bde947abf
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/cv/DINOv2/precision_test.py
@@ -0,0 +1,160 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+
+import mindietorch
+import numpy as np
+import onnxruntime as ort
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoImageProcessor, AutoModel
+
+from model.dinov2_model import Dinov2Model_WO_Embedding
+
+
+def compare_onnx_hf_output(onnx_out, hf_out, sim_threshold=0.99):
+    num_sim = 0
+    for i, (a, b) in enumerate(zip(onnx_out, hf_out)):
+        a = a.reshape(1, -1).astype(np.float32)
+        b = b.reshape(1, -1)
+        sim = F.cosine_similarity(torch.from_numpy(a), b, dim=1)
+        if sim > sim_threshold:
+            num_sim += 1
+        else:
+            print(f'Output {i} similarity: {sim}')
+
+    print(f'Number of outputs to compare: {len(onnx_out)}')
+    print(f'Number of outputs with cosine similarity > {sim_threshold}: {num_sim}')
+
+
+def compare_hf_aie_output(hf_out, aie_out, sim_threshold=0.99):
+    num_sim = 0
+    for i, (a, b) in enumerate(zip(hf_out, aie_out)):
+        a = a.reshape(1, -1)
+        b = b.reshape(1, -1)
+        sim = F.cosine_similarity(a, b, dim=1)
+        if sim > sim_threshold:
+            num_sim += 1
+        else:
+            print(f'Output {i} similarity: {sim}')
+
+    print(f'Number of outputs to compare: {len(aie_out)}')
+    print(f'Number of outputs with cosine similarity > {sim_threshold}: {num_sim}')
+
+
+def get_embed_input(args):
+    device = f'npu:{args.device}'
+    # preprocess
+    processor = AutoImageProcessor.from_pretrained(args.hf_model_path)
+    inputs = processor(images=Image.open(args.image_path), return_tensors="pt")
+    pixel_values = inputs.pixel_values.to(torch.float32).to(device)
+    model = Dinov2Model_WO_Embedding.from_pretrained(args.hf_model_path).float().eval().to(device)
+    embeddings = model.embeddings
+    embed_output = embeddings(pixel_values)
+    return pixel_values, [embed_output, ]
+
+
+def compare(args):
+    device = f'npu:{args.device}'
+    pixel_values, embedding_output = get_embed_input(args)
+    # torch_npu
+    with torch.no_grad():
+        hf_model = AutoModel.from_pretrained(args.hf_model_path).float().eval().to(device)
+        hf_outputs = hf_model(pixel_values)
+        hf_outputs = [hf_outputs.last_hidden_state, hf_outputs.pooler_output]
+        hf_out = [x.cpu().detach() for x in hf_outputs]
+
+    # MindIETorch
+    mindietorch.set_device(args.device)
+    stream = mindietorch.npu.Stream(device)
+    aie_model = torch.load(args.dinov2_aie_path).to(device)
+    aie_model.eval()
+
+    with mindietorch.npu.stream(stream):
+        aie_out = aie_model(*embedding_output)
+    if isinstance(aie_out, (tuple, list)):
+        aie_out = [x.cpu() for x in aie_out]
+    else:
+        aie_out = aie_out.cpu()
+
+    # ONNX
+    pixel_values = pixel_values.cpu().detach().numpy()
+
+    if args.use_gpu:
+        provider = "CUDAExecutionProvider"
+    else:
+        provider = "CPUExecutionProvider"
+
+    onnx_model = ort.InferenceSession(
+        args.dinov2_onnx_path,
+        providers=[provider]
+    )
+    onnx_inputs = {"pixel_values": pixel_values}
+    output_names = ["sequence_output", "pooled_output"]
+    onnx_out = onnx_model.run(output_names, onnx_inputs)
+
+    compare_onnx_hf_output(onnx_out, hf_out, args.sim_threshold)
+    compare_hf_aie_output(hf_out, aie_out, args.sim_threshold)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dinov2-aie-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--dinov2-onnx-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default=""
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0,
+        help="NPU device id"
+    )
+    parser.add_argument(
+        "--image-path",
+        type=str,
+        default=""
+    )
+    parser.add_argument(
+        "--model-version",
+        default="base",
choices=["small", "base", "large", "giant"], + help="Specify the architecture of Dinov2-Vit model to be converted." + ) + parser.add_argument( + "--sim-threshold", + type=float, + default=0.99 + ) + parser.add_argument( + "--use-gpu", + action="store_true" + ) + return parser.parse_args() + + +if __name__ == "__main__": + input_args = parse_args() + print(f'----- Compare the outputs of ONNX and AIE dinov2 {input_args.model_version} model -----') + compare(input_args)