diff --git a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py index d1f7ab636b3653c7c86f4ea1ee2c9de2f6c76790..edb433043c144bf09607015e320b6f8547741d74 100644 --- a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py +++ b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py @@ -236,7 +236,7 @@ def main_worker(gpu, ngpus_per_node, args): model = vgg16() model = model.to(loc) - optimizer = apex.optimizers.NpuFusedSGDSGD(model.parameters(), args.lr, + optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)