From 1b94f33a108dbcbd1c44f887c8d2867d0914ebba Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Mon, 11 Dec 2023 19:25:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=97=AE=E9=A2=98=E5=8D=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md | 3 ++- AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md b/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md index 6c4771cab4..1af06cb455 100644 --- a/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md +++ b/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md @@ -108,12 +108,13 @@ StarGAN是 Yunjey Choi 等人于 17年11月 提出的一个模型。该模型可 调用脚本`StarGAN_pre_processing.py`,可以获得推理结果 ``` - python3 StarGAN_pre_processing.py --result_dir './result_baseline' --attr_path './dataset/celeba/list_attr_celeba.txt' --celeba_image_dir './dataset/celeba/images' --batch_size 16 --ts_model_path "./stargan.ts" + python3 StarGAN_pre_processing.py --result_dir './result_baseline' --attr_path './dataset/celeba/list_attr_celeba.txt' --celeba_image_dir './dataset/celeba/images' --batch_size 16 --ts_model_path "./stargan.ts" --model_save_dir "./" ``` - 参数说明: - ts_model_path:ts模型文件路径 + - model_save_dir: pth模型文件保存路径 3. 性能测试。 diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py b/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py index 1128689a0b..6912783c94 100644 --- a/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py +++ b/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py @@ -139,7 +139,7 @@ class Solver(object): os.makedirs(self.result_dir) ts_model = torch.jit.load(ts_model_path) - input_info = [torch_aie.Input((16, 3, 128, 128)), torch_aie.Input((16, 5))] + input_info = [torch_aie.Input((self.batch_size, 3, 128, 128)), torch_aie.Input((self.batch_size, 5))] torch_aie.set_device(0) print("start_compile") torchaie_model = torch_aie.compile( -- Gitee