From 452db3779aaf75901a902d86e774457c0e0448ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=89=E5=AE=8F=E6=A2=85?= <591861959@qq.com> Date: Mon, 21 Mar 2022 11:34:12 +0000 Subject: [PATCH 1/2] update --- .../cv/classification/Vgg16_ID1630_for_PyTorch/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py index 7b5057c7d3..cec7c5f1da 100644 --- a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py +++ b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/main.py @@ -236,14 +236,14 @@ def main_worker(gpu, ngpus_per_node, args): model = vgg16() model = model.to(loc) - optimizer = torch.optim.SGD(model.parameters(), args.lr, + optimizer = apex.optimizers.NpuFusedSGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) criterion = nn.CrossEntropyLoss().to(loc) if args.amp: - model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale_value) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, loss_scale=args.loss_scale_value,combine_grad=True) #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False) # optionally resume from a checkpoint @@ -293,7 +293,7 @@ def main_worker(gpu, ngpus_per_node, args): train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), - num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) + num_workers=128, pin_memory=True, sampler=train_sampler, drop_last=True) val_loader = torch.utils.data.DataLoader( val_dataset, -- Gitee From d027ee0166b7442a50de01b0dc2ce2c4c9f484be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=89=E5=AE=8F=E6=A2=85?= <591861959@qq.com> Date: Mon, 21 Mar 2022 11:35:02 +0000 Subject: [PATCH 2/2] update --- .../cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py index b28b64969f..68be33f88a 100644 --- a/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py +++ b/PyTorch/contrib/cv/classification/Vgg16_ID1630_for_PyTorch/vgg.py @@ -55,13 +55,15 @@ class VGG(nn.Module): x = self.fc1(x) x = self.relu(x) if self.training: - x = x.cpu() - x = self.drop(x).npu() + # x = x.cpu() + # x = self.drop(x).npu() + x = self.drop(x) x = self.fc2(x) x = self.relu(x) if self.training: - x = x.cpu() - x = self.drop(x).npu() + # x = x.cpu() + # x = self.drop(x).npu() + x = self.drop(x) x = self.fc3(x) return x -- Gitee