From dfbdc752044eeb16641b7f98c62510de8e59f139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=89=E5=AE=8F=E6=A2=85?= <591861959@qq.com> Date: Fri, 25 Mar 2022 09:44:27 +0000 Subject: [PATCH 1/2] update --- .../cv/classification/Vgg16_ID1630_for_PyTorch/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py index 7b5057c7d3..d1f7ab636b 100644 --- a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py +++ b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py @@ -236,14 +236,14 @@ def main_worker(gpu, ngpus_per_node, args): model = vgg16() model = model.to(loc) - optimizer = torch.optim.SGD(model.parameters(), args.lr, + optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) criterion = nn.CrossEntropyLoss().to(loc) if args.amp: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale_value) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale_value, combine_grad=True) #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False) # optionally resume from a checkpoint -- Gitee From abde93e6a4498301eaaaf44b0f63b765519d289c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=89=E5=AE=8F=E6=A2=85?= <591861959@qq.com> Date: Fri, 25 Mar 2022 09:45:04 +0000 Subject: [PATCH 2/2] update --- .../cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py index b28b64969f..2f87dff860 100644 --- a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py +++ 
b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py @@ -55,13 +55,11 @@ class VGG(nn.Module): x = self.fc1(x) x = self.relu(x) if self.training: - x = x.cpu() - x = self.drop(x).npu() + x = self.drop(x) x = self.fc2(x) x = self.relu(x) if self.training: - x = x.cpu() - x = self.drop(x).npu() + x = self.drop(x) x = self.fc3(x) return x -- Gitee