diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/README.md b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/README.md
index f61869b4734f29d983d79e5bfd34a15a67b3765a..82e9c8a1913fa5ce36346a0b96849f93895331aa 100644
--- a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/README.md
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/README.md
@@ -1,30 +1,151 @@
-## 构建虚拟环境
+# VGG16模型-推理指导
-`conda create --name vgg16 python=3.9`
-激活:`conda activate vgg16`
+- [开发者自测](#ZH-CN_TOPIC_0000001172161500)
+- [概述](#ZH-CN_TOPIC_0000001172161501)
-## 安装依赖
+ - [输入输出数据](#ZH-CN_TOPIC_0000001126281702)
-`pip3 install -r requirements.txt`
+- [推理环境准备](#ZH-CN_TOPIC_0000001126281702)
-编译pt插件,在dist目录下安装torh_aie
+- [快速上手](#ZH-CN_TOPIC_0000001126281700)
-## 下载pth模型
+ - [获取源码](#section4622531142816)
+ - [准备数据集](#section183221994411)
+ - [模型推理](#section741711594517)
-自行下载模型pth文件并放置在`vgg16`路径下
-链接:https://gitee.com/link?target=https%3A%2F%2Fdownload.pytorch.org%2Fmodels%2Fvgg16-397923af.pth
+- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573)
-## trace得到ts文件
+ ******
-将model_path改为自己tar模型的路径
-`python3 export.py --model_path=./vgg16-397923af.pth`
+# 开发者自测
-## 模型推理 - 获取精度
+本项目由开发者完成测试用例设计、以及验证,平均测试覆盖率为88%:
-将data_path改为自己目录下数据集label的路径
-`python3 run.py --data_path /home/pttest_models/datasets/ImageNet_50000/val`
+
-## 推理性能 - ts
+# 概述
-将--ts_path改为自己目录下的ts路径
-`python3 perf_right_one.py --mode=ts --ts_path=./shufflenetv1.ts`
\ No newline at end of file
+VGGNet是牛津大学计算机视觉组(Visual Geometry Group)和Google DeepMind公司的研究员一起研发的深度卷积神经网络,它探索了卷积神经网络的深度与其性能之间的关系,通过反复堆叠3*3的小型卷积核和2*2的最大池化层,成功地构筑了16~19层深的卷积神经网络。VGGNet相比之前state-of-the-art的网络结构,错误率大幅下降,VGGNet论文中全部使用了3*3的小型卷积核和2*2的最大池化核,通过不断加深网络结构来提升性能。
+VGG16包含了16个隐藏层(13个卷积层和3个全连接层)
+
+- 参考实现:
+
+ ```shell
+ url=https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
+ branch=master
+ commit_id=78ed10cc51067f1a6bac9352831ef37a3f842784
+ ```
+
+## 输入输出数据
+
+- 输入数据
+
+ | 输入数据 | 数据类型 | 大小 | 数据排布格式 |
+ | -------- | -------- | ------------------------- | ------------ |
+ | input | RGB_FP32 | batchsize x 3 x 224 x 224 | NCHW |
+
+- 输出数据
+
+ | 输出数据 | 数据类型 | 大小 | 数据排布格式 |
+ | -------- | -------- | ---------------- | ------------ |
+ | output1 | FLOAT32 | batchsize x 1000 | ND |
+
+# 推理环境准备
+
+- 该模型需要以下插件与驱动
+
+ **表 1** 版本配套表
+
+ | 配套 | 版本 | 环境准备指导 |
+ |---------| ------- | ------------------------------------------------------------ |
+ | 固件与驱动 | 23.0.rc1 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+ | CANN | 7.0.RC1.alpha003 | - |
+ | Python | 3.9.11 | - |
+ | PyTorch | 2.0.1 | - |
+ | Torch_AIE | 6.3.rc2 | - |
+
+- 该模型需要以下依赖
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+# 快速上手
+
+## 准备数据集
+
+1. 获取原始数据集。(解压命令参考tar –xvf \*.tar与 unzip \*.zip)
+
+ 本模型支持ImageNet 50000张图片的验证集。以ILSVRC2012为例,请用户需自行获取ILSVRC2012数据集,上传数据集到服务器任意目录并解压(假设 `dataset_dir=/home/HwHiAiUser/dataset`)。
+
+ 数据目录结构请参考:
+ ```
+ |-- dataset
+ |-- imagenet
+ |-- val
+ |-- n01440764
+ |-- n01443537
+ |-- ...
+ ```
+
+## 模型推理
+
+1. 模型转换。
+
+ 使用PyTorch将模型权重文件.pth转换为.ts文件
+
+ 1. 获取权重文件。
+
+ 从源码包中获取权重文件:“vgg16-397923af.pth” 或者通过[下载链接](https://download.pytorch.org/models/vgg16-397923af.pth)
+
+ 2. 导出ts文件。
+
+ 使用 `export.py` 导出文件。
+
+ ```shell
+ python3 export.py --model_path ./vgg16-397923af.pth --ts_save_path vgg16.ts --batch_size 1
+ ```
+
+ - 参数说明:
+
+ - model_path:pth模型文件所在路径
+ - ts_save_path:ts文件保存的路径
+ - batch_size:bs数
+
+2. 开始推理验证。
+
+ 1. 精度验证
+ ```shell
+ python3 run.py --data_path ./datasets/imagenet/val/ --ts_model_path ./vgg16.ts --batch_size 1 --image_size 224
+ ```
+
+ - 参数说明:
+
+ - data_path:数据集所在路径
+ - ts_model_path:ts模型文件所在路径
+ - batch_size:bs数
+ - image_size:图像尺寸
+
+ 2. 性能验证
+ ```shell
+ python3 perf.py --ts_path ./vgg16.ts --batch_size 1 --image_size 224
+ ```
+
+ - 参数说明:
+
+ - ts_path:ts模型文件所在路径
+ - batch_size:bs数
+ - image_size:图像尺寸
+
+# 模型推理性能&精度
+
+调用ACL接口推理计算,性能参考下列数据。
+
+| 芯片型号 | Batch Size | 数据集 | 精度 | 性能 |
+| --------- | ---------------- | ---------- | ---------- | --------------- |
+| Ascend310P3 | 1 | ILSVRC2012 | 95.512% | 796 FPS |
+| Ascend310P3 | 4 | ILSVRC2012 | 95.512% | 877 FPS |
+| Ascend310P3 | 8 | ILSVRC2012 | 95.512% | 900 FPS |
+| Ascend310P3 | 16 | ILSVRC2012 | 95.512% | 922 FPS |
+| Ascend310P3 | 32 | ILSVRC2012 | 95.512% | 941 FPS |
+| Ascend310P3 | 64 | ILSVRC2012 | 95.512% | 925 FPS |
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/export.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/export.py
index d6dfc4340195112af32496d7064a404eb2d9c8af..7639fdfb984222da44191625ad797eb7b949a1dc 100644
--- a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/export.py
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/export.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import sys
import os
import argparse
@@ -27,22 +26,24 @@ def parse_args():
)
parser.add_argument('--ts_save_path', help='VGG16 torch script model save path', type=str,
default='vgg16.ts')
-
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
args = parser.parse_args()
return args
+
def check_args(args):
if not os.path.exists(args.model_path):
raise FileNotFoundError(f'VGG16 model file {args.model_path} not exists')
-def trace_ts_model(model_path, ts_save_path):
+
+def trace_ts_model(model_path, ts_save_path, batch_size):
# load model
model = models.vgg16(pretrained=False)
model.load_state_dict(torch.load(model_path))
model.eval()
# trace model
- input_data = torch.ones(1, 3, 224, 224)
+ input_data = torch.ones(batch_size, 3, 224, 224)
ts_model = torch.jit.trace(model, input_data)
ts_model.save(ts_save_path)
print(f'VGG16 torch script model saved to {ts_save_path}')
@@ -54,5 +55,5 @@ if __name__ == '__main__':
check_args(opts)
# load & trace model
- trace_ts_model(opts.model_path, opts.ts_save_path)
- print("Finish Tracing VGG16 model")
\ No newline at end of file
+ trace_ts_model(opts.model_path, opts.ts_save_path, opts.batch_size)
+ print("Finish Tracing VGG16 model")
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/perf.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..6834bc3a8937ca6932cd703f939bbd0dee7be759
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/perf.py
@@ -0,0 +1,95 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import time
+from tqdm import tqdm
+
+import torch
+import numpy as np
+
+import torch_aie
+from torch_aie import _enums
+
+
+def parse_args():
+ args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.")
+ args.add_argument('--ts_path',help='MobilenetV1 ts file path', type=str,
+ default='/onnx/vgg16-torchaie/vgg16.ts'
+ )
+ args.add_argument("--batch_size", type=int, default=1, help="batch size.")
+ args.add_argument('--image_size', type=int, default=224, help='Image size')
+ return args.parse_args()
+
+
+def perf(torchaie_model, batch_size, image_size):
+ dummy_input = np.random.randn(batch_size, 3, image_size, image_size).astype(np.float32)
+ input_tensor = torch.Tensor(dummy_input)
+ input_tensor = input_tensor.to("npu:0")
+ loops = 100
+ warm_ctr = 10
+
+ default_stream = torch_aie.npu.default_stream()
+ time_cost = 0
+
+ while warm_ctr:
+ _ = torchaie_model(input_tensor)
+ default_stream.synchronize()
+ warm_ctr -= 1
+
+ print("send to npu")
+ input_tensor = input_tensor.to("npu:0")
+ print("finish sent")
+ for i in range(loops):
+ t0 = time.time()
+ _ = torchaie_model(input_tensor)
+ default_stream.synchronize()
+ t1 = time.time()
+ time_cost += (t1 - t0)
+
+ print(f"fps: {loops} * {batch_size} / {time_cost : .3f} samples/s")
+ print("torch_aie fps: ", loops * batch_size / time_cost)
+
+ from datetime import datetime
+ current_time = datetime.now()
+ formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
+ print("Current Time:", formatted_time)
+
+
+
+if __name__ == '__main__':
+ infer_times = 100
+ pt_cost = 0
+ opts = parse_args()
+ ts_path = opts.ts_path
+ batch_size = opts.batch_size
+ image_size = opts.image_size
+
+ ts_model = torch.jit.load(ts_path)
+
+ input_info = [torch_aie.Input((batch_size, 3, image_size, image_size))]
+
+ torch_aie.set_device(0)
+ print("start compile")
+ torchaie_model = torch_aie.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=_enums.PrecisionPolicy.FP16,
+ soc_version='Ascend310P3',
+ )
+ print("end compile")
+ torchaie_model.eval()
+
+ perf(torchaie_model, batch_size, image_size)
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/perf_right_one.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/perf_right_one.py
deleted file mode 100644
index 3986dca00fea4c716778f8449e6847ac7d4a4065..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/perf_right_one.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import argparse
-import time
-from tqdm import tqdm
-
-import torch
-import numpy as np
-
-import torch_aie
-from torch_aie import _enums
-
-from ais_bench.infer.interface import InferSession
-
-INPUT_WIDTH = 224
-INPUT_HEIGHT = 224
-
-def parse_args():
- args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.")
- args.add_argument("--mode", choices=["om", "ts"], required=True, help="Specify the mode ('om' or 'ts').")
- args.add_argument('--om_path',help='MobilenetV1 om file path', type=str,
- default='/onnx/mobilenetv1/mobilenet-v1_bs1.om'
- )
- args.add_argument('--ts_path',help='MobilenetV1 ts file path', type=str,
- default='/onnx/vgg16-torchaie/vgg16.ts'
- )
- args.add_argument("--batch-size", type=int, default=4, help="batch size.")
- return args.parse_args()
-
-if __name__ == '__main__':
- infer_times = 100
- om_cost = 0
- pt_cost = 0
- opts = parse_args()
- OM_PATH = opts.om_path
- TS_PATH = opts.ts_path
- BATCH_SIZE = opts.batch_size
-
- if opts.mode == "om":
- om_model = InferSession(0, OM_PATH)
- for _ in tqdm(range(0, infer_times)):
- dummy_input = np.random.randn(BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.float32)
- start = time.time()
- output = om_model.infer([dummy_input], 'static', custom_sizes=90000000)
- cost = time.time() - start
- om_cost += cost
-
- if opts.mode == "ts":
- ts_model = torch.jit.load(TS_PATH)
-
- input_info = [torch_aie.Input((BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT))]
-
- torch_aie.set_device(0)
- print("start compile")
- torchaie_model = torch_aie.compile(
- ts_model,
- inputs=input_info,
- precision_policy=_enums.PrecisionPolicy.FP16,
- soc_version='Ascend310P3',
- )
- print("end compile")
- torchaie_model.eval()
-
- dummy_input = np.random.randn(BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.float32)
- input_tensor = torch.Tensor(dummy_input)
- input_tensor = input_tensor.to("npu:0")
- loops = 100
- warm_ctr = 10
-
- default_stream = torch_aie.npu.default_stream()
- time_cost = 0
-
- while warm_ctr:
- _ = torchaie_model(input_tensor)
- default_stream.synchronize()
- warm_ctr -= 1
-
- print("send to npu")
- input_tensor = input_tensor.to("npu:0")
- print("finish sent")
- for i in range(loops):
- t0 = time.time()
- _ = torchaie_model(input_tensor)
- default_stream.synchronize()
- t1 = time.time()
- time_cost += (t1 - t0)
-
- print(f"fps: {loops} * {BATCH_SIZE} / {time_cost : .3f} samples/s")
- print("torch_aie fps: ", loops * BATCH_SIZE / time_cost)
-
- from datetime import datetime
- current_time = datetime.now()
- formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
- print("Current Time:", formatted_time)
-
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/run.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/run.py
index 06eac084a5867ebd39277675a0d4bc9db3a7a1d3..e8fd533da5ad6764072c346fb01cc28fcb33c78d 100644
--- a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/run.py
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/run.py
@@ -1,23 +1,42 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
import argparse
+from tqdm.auto import tqdm
+
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
-from tqdm.auto import tqdm
import torch_aie
from torch_aie import _enums
+TEST_MODE = False
+
def parse_args():
parser = argparse.ArgumentParser(description='VGG16 Evaluation.')
- parser.add_argument('--data_path', type=str, default='/home/devkit1/xiefeng/datasets/imagenet/val/',
+ parser.add_argument('--data_path', type=str, default='/home/ascend/datasets/imagenet/val/',
help='Evaluation dataset path')
- parser.add_argument('--ts_model_path', type=str, default='./vgg16.ts',
+ parser.add_argument('--ts_model_path', type=str, default='../vgg16-torchaie/vgg16.ts',
help='Original TorchScript model path')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--image_size', type=int, default=224, help='Image size')
return parser.parse_args()
-def compute_acc(y_pred, y_true, topk_list=(1, 5)):
+
+def compute_acc(y_pred, y_true, topk_list=(1, 5)):
maxk = max(topk_list)
batch_size = y_true.size(0)
@@ -31,6 +50,7 @@ def compute_acc(y_pred, y_true, topk_list=(1, 5)):
res.append(correct_k.mul_(100.0 / batch_size))
return res
+
def validate(model, args):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
@@ -60,12 +80,16 @@ def validate(model, args):
step = i + 1
if step % 100 == 0:
print(f'top1 is {avg_top1 / step}, top5 is {avg_top5 / step}, step is {step}')
+
+ if TEST_MODE:
+ if i > 1:
+ break
if __name__ == '__main__':
args = parse_args()
ts_model = torch.jit.load(args.ts_model_path)
- input_info = [torch_aie.Input((1, 3, 224, 224))]
+ input_info = [torch_aie.Input((args.batch_size, 3, args.image_size, args.image_size))]
torchaie_model = torch_aie.compile(
ts_model,
inputs=input_info,
@@ -73,4 +97,4 @@ if __name__ == '__main__':
soc_version='Ascend310P3'
)
torchaie_model.eval()
- validate(torchaie_model, args)
\ No newline at end of file
+ validate(torchaie_model, args)
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_export.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..58edd8ab0942f8bcc6e216adbff11c5923163fff
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_export.py
@@ -0,0 +1,129 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+import argparse
+import os
+import shutil
+from unittest.mock import patch
+
+import torch
+from torchvision import models
+
+from export import check_args
+from export import trace_ts_model
+from export import parse_args
+
+
+class TestCheckArgs(unittest.TestCase):
+ def test_existing_model_path(self):
+ """
+ Existing Model Path
+ Input: A model path that exists
+ Expected Output: The function should not raise an exception.
+ """
+ # Create a temporary file to use as the model path
+ model_path = 'existing_model.pth'
+ # create the file
+ with open(model_path, 'w') as f:
+ f.write('dummy content')
+
+ # Create a namespace with the model_path attribute
+ args = argparse.Namespace(model_path=model_path)
+
+ # Call the function and it should not raise an exception
+ check_args(args)
+
+ # Clean up the temporary file
+ os.remove(model_path)
+
+ def test_nonexistent_model_path(self):
+ """
+ Nonexistent Model Path
+ Input: A model path that does not exist
+ Expected Output: The function should raise a FileNotFoundError with the correct error message.
+ """
+ # Create a namespace with a non-existent model path
+ args = argparse.Namespace(model_path='nonexistent_model.pth')
+
+ # Call the function, and it should raise a FileNotFoundError
+ with self.assertRaises(FileNotFoundError) as context:
+ check_args(args)
+
+ # Check that the correct error message is raised
+ expected_message = f'VGG16 model file {args.model_path} not exists'
+ self.assertEqual(str(context.exception), expected_message)
+
+
+class TestTraceTSModel(unittest.TestCase):
+ def setUp(self):
+ # Create a temporary directory for saving the traced model
+ self.temp_dir = 'temp_dir'
+ os.makedirs(self.temp_dir, exist_ok=True)
+
+ def tearDown(self):
+ # Remove the temporary directory and its contents
+ shutil.rmtree(self.temp_dir)
+
+ def test_trace_ts_model(self):
+ """
+ Successful Tracing
+ Input: A valid model path, valid save path, and a reasonable batch size
+ Expected Output: The function should successfully trace the model, and the traced model file should be created.
+ """
+ # Paths for the model and traced model
+ model_path = 'vgg16.pth'
+ ts_save_path = os.path.join(self.temp_dir, 'traced_model.pt')
+ batch_size = 1
+
+ # Create a dummy VGG16 model file
+ vgg16_model = models.vgg16(pretrained=False)
+ torch.save(vgg16_model.state_dict(), model_path)
+
+ # Call the function
+ trace_ts_model(model_path, ts_save_path, batch_size)
+
+ # Check if the traced model file exists
+ self.assertTrue(os.path.exists(ts_save_path))
+
+ # Clean up the dummy VGG16 model file
+ os.remove(model_path)
+
+class TestParseArgs(unittest.TestCase):
+ @patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ model_path='./test_model.pth',
+ ts_save_path='test_model.ts',
+ batch_size=32
+ ))
+ def test_parse_args_with_arguments(self, mock_parse_args):
+ """
+ Custom Arguments
+ Input: Command-line arguments specifying custom values
+ Expected Output: Arguments should reflect the provided values.
+ """
+ # Call the function
+ args = parse_args()
+
+ # Check if the function returns the expected arguments
+ self.assertEqual(args.model_path, './test_model.pth')
+ self.assertEqual(args.ts_save_path, 'test_model.ts')
+ self.assertEqual(args.batch_size, 32)
+
+ # Verify that argparse.ArgumentParser.parse_args was called
+ mock_parse_args.assert_called_once()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_perf.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_perf.py
new file mode 100644
index 0000000000000000000000000000000000000000..3157e2f7604ec660cc39145b9a679e9ae73d7778
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_perf.py
@@ -0,0 +1,70 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+import argparse
+from unittest.mock import patch
+
+import torch
+
+from perf import parse_args, perf
+import torch_aie
+from torch_aie import _enums
+
+class TestVGG16Perf(unittest.TestCase):
+
+ def test_parse_args(self):
+ """
+ Custom Arguments
+ Input: Command-line arguments specifying custom values
+ Expected Output: Arguments should reflect the provided values.
+ """
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ ts_path='../vgg16-torchaie/vgg16.ts',
+ batch_size=16,
+ image_size=224
+ )):
+ args = parse_args()
+
+ self.assertEqual(args.ts_path, '../vgg16-torchaie/vgg16.ts')
+ self.assertEqual(args.batch_size, 16)
+ self.assertEqual(args.image_size, 224)
+
+ def test_perf(self):
+ """
+ Performance Evaluation
+ Input: A model and its parameters
+ Expected Output: Validate that the model is correctly evaluated on its performance
+ """
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ ts_path='../vgg16-torchaie/vgg16.ts',
+ batch_size=16,
+ image_size=224
+ )):
+ args = parse_args()
+ ts_model = torch.jit.load(args.ts_path)
+ input_info = [torch_aie.Input((args.batch_size, 3, args.image_size, args.image_size))]
+ torchaie_model = torch_aie.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=_enums.PrecisionPolicy.FP16,
+ soc_version='Ascend310P3'
+ )
+ torchaie_model.eval()
+ perf(torchaie_model, args.batch_size, args.image_size)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_run.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_run.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ff5dd6eeee9f0edfdda2165d1fc698e9a505d68
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/test_run.py
@@ -0,0 +1,74 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+import argparse
+from unittest.mock import patch
+
+import torch
+
+from run import parse_args, validate
+import torch_aie
+from torch_aie import _enums
+
+
+class TestVGG16Evaluation(unittest.TestCase):
+
+ def test_parse_args(self):
+ """
+ Custom Arguments
+ Input: Command-line arguments specifying custom values
+ Expected Output: Arguments should reflect the provided values.
+ """
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ data_path='./datasets/imagenet/val/',
+ ts_model_path='./vgg16.ts',
+ batch_size=16,
+ image_size=224
+ )):
+ args = parse_args()
+
+ self.assertEqual(args.data_path, './datasets/imagenet/val/')
+ self.assertEqual(args.ts_model_path, './vgg16.ts')
+ self.assertEqual(args.batch_size, 16)
+ self.assertEqual(args.image_size, 224)
+
+ def test_validate(self):
+ """
+ Model Evaluation
+ Input: A model and dataset argument
+ Expected Output: Validate that the model is correctly evaluated on the dataset.
+ """
+ with patch('argparse.ArgumentParser.parse_args', return_value=argparse.Namespace(
+ data_path='./datasets/imagenet/val/',
+ ts_model_path='./vgg16.ts',
+ batch_size=16,
+ image_size=224
+ )):
+ args = parse_args()
+ ts_model = torch.jit.load(args.ts_model_path)
+ input_info = [torch_aie.Input((args.batch_size, 3, args.image_size, args.image_size))]
+ torchaie_model = torch_aie.compile(
+ ts_model,
+ inputs=input_info,
+ precision_policy=_enums.PrecisionPolicy.FP16,
+ soc_version='Ascend310P3'
+ )
+ torchaie_model.eval()
+ validate(torchaie_model, args)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/vgg16_pth2onnx.py b/AscendIE/TorchAIE/built-in/cv/classification/vgg16/vgg16_pth2onnx.py
deleted file mode 100644
index f2bca78e78252c40c4802111f318364e68304385..0000000000000000000000000000000000000000
--- a/AscendIE/TorchAIE/built-in/cv/classification/vgg16/vgg16_pth2onnx.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import ssl
-import torchvision.models as models
-import argparse
-
-
-def convert(args):
- model = models.vgg16(pretrained=False)
- ckpt = torch.load(args.pth_path, map_location=torch.device('cpu'))
- model.load_state_dict(ckpt)
- model.eval()
- input_names = ["actual_input_1"]
- output_names = ["output1"]
- dummy_input = torch.randn(16, 3, 224, 224)
- dynamic_axes = {'actual_input_1': {0: '-1'}, 'output1': {0: '-1'}}
- torch.onnx.export(model, dummy_input, args.out,
- input_names=input_names,
- dynamic_axes=dynamic_axes,
- output_names=output_names,
- opset_version=11)
-
-
-if __name__ == "__main__":
- ssl._create_default_https_context = ssl._create_unverified_context
- parser = argparse.ArgumentParser()
- parser.add_argument('--out', help='onnx output name')
- parser.add_argument('--pth_path', help='model pth path')
- args = parser.parse_args()
- convert(args)
\ No newline at end of file