From 65fa8ed9290aae067d382683572d566f1e5d9b24 Mon Sep 17 00:00:00 2001 From: tanghongyan <1349905607@qq.com> Date: Wed, 30 Mar 2022 16:57:33 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=9B=9E=E5=BD=92=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../MnasNet/modelArts/pth2onnx.py | 25 ++++++++++--------- .../MnasNet/modelArts/train-modelarts.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py b/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py index 5b66c0830c..2dece3cb8d 100644 --- a/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py +++ b/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py @@ -49,20 +49,21 @@ def proc_node_module(checkpoint, AttrName): return new_state_dict -def convert(pth_file, onnx_path, class_num, train_url): +def convert(pth_file, onnx_path, class_num, train_url, npu): - checkpoint = torch.load(pth_file, map_location=None) - + loc = 'npu:{}'.format(npu) + checkpoint = torch.load(pth_file, map_location=loc) + + checkpoint['state_dict'] = proc_node_module(checkpoint, 'state_dict') model = mnasnet.mnasnet1_0(num_classes=class_num) - model.load_state_dict(checkpoint) + model.to(loc) + model.load_state_dict(checkpoint['state_dict']) model.eval() - - input_names = ["image"] - output_names = ["class"] - dynamic_axes = {'image': {0: '-1'}, 'class': {0: '-1'}} - dummy_input = torch.randn(1, 3, 224, 224) - + input_names = ["actual_input_1"] + output_names = ["output1"] + dummy_input = torch.randn(16, 3, 224, 224) + dummy_input = dummy_input.to(loc, non_blocking=False) torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names, opset_version=11) mox.file.copy_parallel(onnx_path, train_url + 'model.onnx') @@ -75,7 +76,7 @@ def convert_pth_to_onnx(config_args): return pth_file = pth_file_list[0] onnx_path = pth_file.split(".")[0] + '.onnx' - convert(pth_file, onnx_path, config_args.class_num, config_args.train_url) + convert(pth_file, onnx_path, config_args.class_num, config_args.train_url, config_args.npu) if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # modelarts @@ -94,4 +95,4 @@ if __name__ == '__main__': print('===========================') print(args) print('===========================') - convert_pth_to_onnx(args) + convert_pth_to_onnx(args) \ No newline at end of file diff --git a/PyTorch/contrib/cv/classification/MnasNet/modelArts/train-modelarts.py b/PyTorch/contrib/cv/classification/MnasNet/modelArts/train-modelarts.py index 77791c2842..26ad841ce2 100644 --- a/PyTorch/contrib/cv/classification/MnasNet/modelArts/train-modelarts.py +++ b/PyTorch/contrib/cv/classification/MnasNet/modelArts/train-modelarts.py @@ -349,7 +349,7 @@ def main_worker(npu, ngpus_per_node, args): acc1 = validate(val_loader, model, criterion, args) # remember best acc@1 and save checkpoint - is_best = acc1 > best_acc1 + is_best = acc1 >= best_acc1 best_acc1 = max(acc1, best_acc1) if not args.multiprocessing_distributed or (args.multiprocessing_distributed -- Gitee From 1c3351fe07df1a18a1c64dbe8aa9c120a124255a Mon Sep 17 00:00:00 2001 From: tanghongyan <1349905607@qq.com> Date: Wed, 30 Mar 2022 17:00:41 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=9B=9E=E5=BD=92=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../MnasNet/infer/convert/onnx2om.sh | 4 +- .../MnasNet/mnasnet_pthtar2onnx.py | 42 ------------------- 2 files changed, 2 insertions(+), 44 deletions(-) delete mode 100644 PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py diff --git a/PyTorch/contrib/cv/classification/MnasNet/infer/convert/onnx2om.sh b/PyTorch/contrib/cv/classification/MnasNet/infer/convert/onnx2om.sh index 4d05ff4e7b..c1113f056e 100755 --- a/PyTorch/contrib/cv/classification/MnasNet/infer/convert/onnx2om.sh +++ b/PyTorch/contrib/cv/classification/MnasNet/infer/convert/onnx2om.sh @@ -16,12 +16,12 @@ model_path=$1 output_model_name=$2 aipp_cfg=$3 -/usr/local/Ascend/atc/bin/atc \ +atc \ --model=$model_path \ --framework=5 \ --output=$output_model_name \ --input_format=NCHW \ - --input_shape="image:1,3,256,256" \ + --input_shape="actual_input_1:1,3,256,256" \ --enable_small_channel=1 \ --log=error \ --soc_version=Ascend310 \ diff --git a/PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py b/PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py deleted file mode 100644 index 92d41e71aa..0000000000 --- a/PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2022 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import sys -import torch -import mnasnet -import torch.onnx - -from collections import OrderedDict - - -def convert(): - checkpoint = torch.load(input_file, map_location=None) - model = mnasnet.mnasnet1_0() - model.load_state_dict(checkpoint) - model.eval() - print(model) - - input_names = ["image"] - output_names = ["class"] - dynamic_axes = {'image': {0: '-1'}, 'class': {0: '-1'}} - dummy_input = torch.randn(1, 3, 224, 224) - - torch.onnx.export(model, dummy_input, "mnasnet1_0.onnx", input_names=input_names, output_names=output_names, - opset_version=11) - - -if __name__ == "__main__": - input_file = sys.argv[1] - convert() -- Gitee