diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceff8b864867ace4a0c4f20b0c964a6993c779e9
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_aie.py
@@ -0,0 +1,95 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import logging
+import argparse
+import time
+import torch
+import mindietorch
+from transformers import AutoConfig
+
+logging.basicConfig(level=logging.INFO)
+
+
+def test(inputs, model, stream, meta=""):
+    # warmup
+    for _ in range(10):
+        with mindietorch.npu.stream(stream):
+            model(*inputs)
+        stream.synchronize()
+
+    # performance test
+    num_infer = 100
+    start = time.time()
+    for _ in range(num_infer):
+        with mindietorch.npu.stream(stream):
+            model(*inputs)
+        stream.synchronize()
+    end = time.time()
+
+    logging.info("%s latency: %.2f ms", meta, (end - start) / num_infer * 1000)
+    logging.info("%s throughput: %.2f fps", meta, num_infer / (end - start))
+
+
+def test_clip(args):
+    device = f'npu:{args.device_id}'
+    stream = mindietorch.npu.Stream(device)
+    if args.clip_aie_path.endswith(".ts"):
+        model = torch.jit.load(args.clip_aie_path)
+    else:
+        model = torch.load(args.clip_aie_path)
+    model.eval().to(device)
+    config = AutoConfig.from_pretrained(args.hf_model_path)
+
+    image_width = config.vision_config.image_size
+    img_input_shape = (args.image_batchsize, 3, image_width, image_width)
+    text_input_shape = (args.text_batchsize, args.token_len)
+    input_img = torch.randn(img_input_shape, dtype=torch.float32).to(device)
+    input_ids = torch.randint(high=1000, size=text_input_shape, dtype=torch.int32).to(device)
+    attention_mask = torch.ones(text_input_shape, dtype=torch.int32).to(device)
+    inputs = [input_ids, input_img, attention_mask]
+
+    test(inputs, model, stream, "CLIP")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device-id", type=int, help="NPU device id", default=0)
+    parser.add_argument(
+        "--clip-aie-path",
+        type=str,
+        default="/Path/to/compiled/aie_or_ts_model"
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Huggingface CLIP Model Path."
+    )
+    parser.add_argument("--text-batchsize", type=int, default=80)
+    parser.add_argument("--image-batchsize", type=int, default=1)
+    parser.add_argument("--token-len", type=int, default=52)
+
+    return parser.parse_args()
+
+
+def main():
+    perf_args = parse_args()
+    mindietorch.set_device(perf_args.device_id)
+    test_clip(perf_args)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbc7f9d67fc591eafe14d96bf0080e23fe1db7c
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/perf_test_onnx.py
@@ -0,0 +1,94 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import logging
+import argparse
+import time
+import torch
+import onnxruntime as ort
+from transformers import AutoConfig
+
+logging.basicConfig(level=logging.INFO)
+
+
+def test(encoder_path, provider, output_names, onnx_inputs, meta=""):
+    onnx_model = ort.InferenceSession(
+        encoder_path,
+        providers=[provider]
+    )
+
+    # warmup
+    for _ in range(10):
+        onnx_model.run(output_names, onnx_inputs)
+    # performance test
+    num_infer = 100
+    start = time.time()
+    for _ in range(num_infer):
+        onnx_model.run(output_names, onnx_inputs)
+    end = time.time()
+
+    logging.info("%s latency: %.2f ms", meta, (end - start) / num_infer * 1000)
+    logging.info("%s throughput: %.2f fps", meta, num_infer / (end - start))
+
+
+def test_clip(args, provider):
+    config = AutoConfig.from_pretrained(args.hf_model_path)
+
+    image_width = config.vision_config.image_size
+    img_input_shape = (args.image_batchsize, 3, image_width, image_width)
+    text_input_shape = (args.text_batchsize, args.token_len)
+    input_img = torch.randn(img_input_shape, dtype=torch.float32).detach().numpy()
+    input_ids = torch.randint(high=1000, size=text_input_shape, dtype=torch.int32).detach().numpy()
+    attention_mask = torch.ones(text_input_shape, dtype=torch.int32).detach().numpy()
+
+    onnx_inputs = {"input_ids": input_ids, "pixel_values": input_img, "attention_mask": attention_mask}
+    output_names = ["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"]
+
+    test(args.onnx_path, provider, output_names, onnx_inputs, "CLIP")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--onnx-path",
+        type=str,
+        default="/Path/to/onnx_model"
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Huggingface CLIP Model Path."
+    )
+    parser.add_argument("--text-batchsize", type=int, default=80)
+    parser.add_argument("--image-batchsize", type=int, default=1)
+    parser.add_argument("--token-len", type=int, default=52)
+    parser.add_argument("--use-gpu", action="store_true")
+
+    return parser.parse_args()
+
+
+def main():
+    perf_args = parse_args()
+    if perf_args.use_gpu:
+        provider = "CUDAExecutionProvider"
+    else:
+        provider = "CPUExecutionProvider"
+
+    test_clip(perf_args, provider)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6995367ed97e1b17d300b656c3e5cd0bf2455582
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/multimodal/CLIP/precision_test.py
@@ -0,0 +1,129 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import logging
+import argparse
+import torch
+import mindietorch
+import onnxruntime as ort
+import numpy as np
+import torch.nn.functional as F
+from transformers import AutoConfig
+
+logging.basicConfig(level=logging.INFO)
+
+
+def compare_onnx_aie_output(onnx_out, aie_out, sim_threshold=0.99):
+    num_sim = 0
+    for i, (a, b) in enumerate(zip(onnx_out, aie_out)):
+        a = a.reshape(1, -1).astype(np.float32)
+        b = b.reshape(1, -1)
+        sim = F.cosine_similarity(torch.from_numpy(a), b, dim=1)
+        if sim > sim_threshold:
+            num_sim += 1
+        else:
+            logging.info('Output %d similarity: %f', i, sim)
+
+    logging.info('Number of outputs to compare: %d', len(onnx_out))
+    logging.info('Number of outputs with cosine similarity > %.2f: %d', sim_threshold, num_sim)
+
+
+def compare(args):
+    # MindIETorch
+    device = f'npu:{args.device_id}'
+    stream = mindietorch.npu.Stream(device)
+
+    if args.clip_aie_path.endswith(".ts"):
+        aie_model = torch.jit.load(args.clip_aie_path)
+    else:
+        aie_model = torch.load(args.clip_aie_path)
+    aie_model.eval().to(device)
+    config = AutoConfig.from_pretrained(args.hf_model_path)
+
+    image_width = config.vision_config.image_size
+    img_input_shape = (args.image_batchsize, 3, image_width, image_width)
+    text_input_shape = (args.text_batchsize, args.token_len)
+    input_img = torch.randn(img_input_shape, dtype=torch.float32).to(device)
+    input_ids = torch.randint(high=1000, size=text_input_shape, dtype=torch.int32).to(device)
+    attention_mask = torch.ones(text_input_shape, dtype=torch.int32).to(device)
+    inputs = [input_ids, input_img, attention_mask]
+
+    with mindietorch.npu.stream(stream):
+        aie_out = aie_model(*inputs)
+    stream.synchronize()
+
+    if isinstance(aie_out, (tuple, list)):
+        aie_out = [x.cpu() for x in aie_out]
+    else:
+        aie_out = aie_out.cpu()
+
+    # ONNX
+    input_img = input_img.cpu().detach().numpy()
+    input_ids = input_ids.cpu().detach().numpy()
+    attention_mask = attention_mask.cpu().detach().numpy()
+
+    if args.use_gpu:
+        provider = "CUDAExecutionProvider"
+    else:
+        provider = "CPUExecutionProvider"
+
+    onnx_model = ort.InferenceSession(
+        args.clip_onnx_path,
+        providers=[provider]
+    )
+    onnx_inputs = {"input_ids": input_ids, "pixel_values": input_img, "attention_mask": attention_mask}
+    output_names = ["image_embeds", "text_embeds", "logits_per_text", "logits_per_image"]
+    onnx_out = onnx_model.run(output_names, onnx_inputs)
+
+    compare_onnx_aie_output(onnx_out, aie_out, args.sim_threshold)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--device-id", type=int, default=0, help="NPU device id")
+    parser.add_argument(
+        "--clip-aie-path",
+        type=str,
+        default="/Path/to/compiled/aie_or_ts_model"
+    )
+    parser.add_argument(
+        "--clip-onnx-path",
+        type=str,
+        default="/Path/to/onnx_model"
+    )
+    parser.add_argument(
+        "--hf-model-path",
+        default="/Path/to/Huggingface_model_path",
+        type=str,
+        help="Huggingface CLIP Model Path."
+    )
+    parser.add_argument("--text-batchsize", type=int, default=80)
+    parser.add_argument("--image-batchsize", type=int, default=1)
+    parser.add_argument("--token-len", type=int, default=52)
+    parser.add_argument('--sim-threshold', type=float, default=0.99)
+    parser.add_argument("--use-gpu", action="store_true")
+
+    return parser.parse_args()
+
+
+def main():
+    compare_args = parse_args()
+    mindietorch.set_device(compare_args.device_id)
+    logging.info('=== Compare the outputs of ONNX and AIE ===')
+    compare(compare_args)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file