diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md b/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md index 6c4771cab4b7b6edfcc15476f477705f8d9587d9..1af06cb455d20dce061aa8a5c6472f82ddd19094 100644 --- a/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md +++ b/AscendIE/TorchAIE/built-in/cv/gan/stargan/README.md @@ -108,12 +108,13 @@ StarGAN是 Yunjey Choi 等人于 17年11月 提出的一个模型。该模型可 调用脚本`StarGAN_pre_processing.py`,可以获得推理结果 ``` - python3 StarGAN_pre_processing.py --result_dir './result_baseline' --attr_path './dataset/celeba/list_attr_celeba.txt' --celeba_image_dir './dataset/celeba/images' --batch_size 16 --ts_model_path "./stargan.ts" + python3 StarGAN_pre_processing.py --result_dir './result_baseline' --attr_path './dataset/celeba/list_attr_celeba.txt' --celeba_image_dir './dataset/celeba/images' --batch_size 16 --ts_model_path "./stargan.ts" --model_save_dir "./" ``` - 参数说明: - ts_model_path:ts模型文件路径 + - model_save_dir: pth模型文件保存路径 3. 性能测试。 diff --git a/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py b/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py index 1128689a0b41539e212597b2256304df0dbcbfb6..6912783c9415c154604172df062a72b1d095fed8 100644 --- a/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py +++ b/AscendIE/TorchAIE/built-in/cv/gan/stargan/solver.py @@ -139,7 +139,7 @@ class Solver(object): os.makedirs(self.result_dir) ts_model = torch.jit.load(ts_model_path) - input_info = [torch_aie.Input((16, 3, 128, 128)), torch_aie.Input((16, 5))] + input_info = [torch_aie.Input((self.batch_size, 3, 128, 128)), torch_aie.Input((self.batch_size, 5))] torch_aie.set_device(0) print("start_compile") torchaie_model = torch_aie.compile(