From 659e8cb6435a1eadf0435a014175c51984056e8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E4=BC=9F=E6=A0=B9?= <1101204667@qq.com> Date: Mon, 28 Mar 2022 06:03:00 +0000 Subject: [PATCH] update main.py. --- .../DeepLabV3+_ID0458_for_PyTorch/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py b/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py index 2ce87b2c20..a73edf6404 100644 --- a/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py +++ b/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py @@ -292,9 +292,9 @@ def main(): train_dst, val_dst = get_dataset(opts) train_loader = data.DataLoader( - train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2) + train_dst, batch_size=opts.batch_size, shuffle=True, pin_memory=True,num_workers=64) val_loader = data.DataLoader( - val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2) + val_dst, batch_size=opts.val_batch_size, shuffle=True, pin_memory=True,num_workers=2) print("Dataset: %s, Train set: %d, Val set: %d" % (opts.dataset, len(train_dst), len(val_dst))) @@ -400,8 +400,8 @@ def main(): for (images, labels) in train_loader: cur_itrs += 1 - images = images.to(f'npu:{NPU_CALCULATE_DEVICE}', dtype=torch.float32) - labels = labels.to(f'npu:{NPU_CALCULATE_DEVICE}', dtype=torch.long) + images = images.to(f'npu:{NPU_CALCULATE_DEVICE}', dtype=torch.float32, non_blocking=True) + labels = labels.to(f'npu:{NPU_CALCULATE_DEVICE}', dtype=torch.long, non_blocking=True) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) -- Gitee