From eb73a44335302f0474cf853ae8039550b1afa3af Mon Sep 17 00:00:00 2001 From: xingjinliang Date: Thu, 24 Mar 2022 07:42:24 +0000 Subject: [PATCH 1/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=A2=84=E5=A4=84=E7=90=86=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../DCAP_ID2836_for_PyTorch/prefetcher.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py diff --git a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py new file mode 100644 index 0000000000..be1cfe1118 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class Prefetcher(object): + """Prefetcher using on npu device. + + Origin Code URL: + https://github.com/implus/PytorchInsight/blob/master/classification/imagenet_fast.py#L280 + + Args: + loder (torch.utils.data.DataLoader or DataLoader like iterator): + Using to generate inputs after preprocessing. + stream (torch.npu.Stream): Default None. + Because of the limitation of NPU's memory mechanism, + if prefetcher is initialized repeatedly during training, + a defined stream should be introduced to prevent memory leakage; + if prefetcher is initialized only once during training, + a defined stream is not necessary. + + Returns: + float: tensors of shape (k, 5) and (k, 1). Labels are 0-based. + """ + + def __init__(self, loader, stream=None): + self.loader = iter(loader) + self.stream = stream if stream is not None else torch.npu.Stream() + self.preload() + + def preload(self): + try: + self.next_input, self.next_target = next(self.loader) + except StopIteration: + self.user = None + self.item = None + return + + with torch.npu.stream(self.stream): + self.next_input, self.next_target = self.next_input.to(torch.float), self.next_target.to(torch.float) + self.next_input, self.next_target = self.next_input.to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True), self.next_target.to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True) + + def next(self): + torch.npu.current_stream().wait_stream(self.stream) + next_input = self.next_input + next_target = self.next_target + if next_input is not None: + self.preload() + return next_input, next_target \ No newline at end of file -- Gitee From 30ca1e59af1fb28c153fad3a5ebc8b36ad7c7d46 Mon Sep 17 00:00:00 2001 From: xingjinliang Date: Thu, 24 Mar 2022 07:48:34 +0000 Subject: [PATCH 2/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B7=B7=E5=90=88?= =?UTF-8?q?=E7=B2=BE=E5=BA=A6+=E6=9B=BF=E6=8D=A2=E4=BA=B2=E5=92=8C?= =?UTF-8?q?=E6=80=A7=E4=BC=98=E5=8C=96=E5=99=A8=E6=8E=A5=E5=8F=A3+?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=A2=84=E5=A4=84=E7=90=86=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../DCAP_ID2836_for_PyTorch/DCAP.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py index d67d54af7d..745bd69fc9 100644 --- a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py +++ b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py @@ -178,8 +178,8 @@ class DCAP(nn.Module): self.embedding_dic[feat[0]](X[:, self.feature_index[feat[0]]].to(torch.int32)).reshape(X.shape[0], 1, -1) for feat in self.sparse_feature_columns] sparse_input = torch.cat(sparse_embedding, dim=1) - attn_mask = (torch.triu(torch.ones(len(self.sparse_feature_columns), len(self.sparse_feature_columns))) == 1) - attn_mask = attn_mask.float().masked_fill(attn_mask==0, float('-inf')).masked_fill(attn_mask==1, float(0.0)).to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True) + attn_mask = (torch.triu(torch.ones(len(self.sparse_feature_columns), len(self.sparse_feature_columns)).to(f'npu:{NPU_CALCULATE_DEVICE}')) == 1) + attn_mask = attn_mask.float().masked_fill(attn_mask==0, float('-inf')).masked_fill(attn_mask==1, float(0.0)) X, X_0 = sparse_input, sparse_input output = [] for layer in self.layers: @@ -238,7 +238,7 @@ if __name__ == '__main__': train_label = pd.DataFrame(train['label']) train = train.drop(columns=['label']) train_tensor_data = TensorDataset(torch.from_numpy(np.array(train)), torch.from_numpy(np.array(train_label))) - train_loader = DataLoader(train_tensor_data, shuffle=True, batch_size=batch_size) + train_loader = DataLoader(train_tensor_data, shuffle=True, batch_size=batch_size, pin_memory=True) test_label = pd.DataFrame(test['label']) test = test.drop(columns=['label']) @@ -246,37 +246,38 @@ if __name__ == '__main__': test_loader = DataLoader(test_tensor_data, batch_size=batch_size) loss_func = nn.BCELoss(reduction='mean') - # optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=lr, weight_decay=wd) - optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd) - # model, optimizer = amp.initialize(model, optimizer, opt_level = 'O2', loss_scale = 128.0, combine_grad=True) + optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=lr, weight_decay=wd) + model, optimizer = amp.initialize(model, optimizer, opt_level = 'O2', loss_scale = 128.0, combine_grad=True) step = 0 for epoch in range(epoches): total_loss_epoch = 0.0 total_tmp = 0 model.train() - for index, (x, y) in enumerate(train_loader): + step = 0 + from prefetcher import Prefetcher + prefetcher = Prefetcher(train_loader) + x, y = prefetcher.next() + while x is not None: if step > 10: pass start_time = time.time() - #x, y = x.to(f'npu:{NPU_CALCULATE_DEVICE}').float(), y.to(f'npu:{NPU_CALCULATE_DEVICE}').float() - x, y = x.float(), y.float() - x, y = x.to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True), y.to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True) y_hat = model(x) optimizer.zero_grad() loss = loss_func(y_hat, y) - # with amp.scale_loss(loss,optimizer) as scaled_loss: - # scaled_loss.backward() - loss.backward() + with amp.scale_loss(loss,optimizer) as scaled_loss: + scaled_loss.backward() optimizer.step() - total_loss_epoch += loss.item() + total_loss_epoch += loss.detach() total_tmp += 1 step_time = time.time() - start_time FPS = batch_size / step_time step += 1 print("Epoch:{}, step:{}, Loss:{:.4f}, time/step(s):{:.4f}, FPS:{:.3f}".format(epoch,step,loss.item(),step_time,FPS)) - + if step == 79: + break + x, y = prefetcher.next() auc = get_auc(test_loader, model) print('epoch/epoches: {}/{}, train loss: {:.3f}, test auc: {:.3f}'.format(epoch, epoches, total_loss_epoch / total_tmp, auc)) -- Gitee From 6f9ddec394d7596815a5b50070278fac13b97d5a Mon Sep 17 00:00:00 2001 From: xingjinliang Date: Thu, 24 Mar 2022 07:52:51 +0000 Subject: [PATCH 3/4] =?UTF-8?q?=E5=AF=BC=E5=85=A5apex=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py index 745bd69fc9..f56b2e9b4e 100644 --- a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py +++ b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/DCAP.py @@ -45,6 +45,8 @@ from collections import OrderedDict, namedtuple, defaultdict import time import torch.npu import os +import apex +from apex import amp NPU_CALCULATE_DEVICE = 0 if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) -- Gitee From fb638896740f12b8d2574df2ceaa172fcb3dbbc5 Mon Sep 17 00:00:00 2001 From: xingjinliang Date: Thu, 24 Mar 2022 07:56:50 +0000 Subject: [PATCH 4/4] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E4=B8=8B=E6=B2=89NPU?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py index be1cfe1118..a02e80d661 100644 --- a/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py +++ b/PyTorch/dev/cv/image_classification/DCAP_ID2836_for_PyTorch/prefetcher.py @@ -52,7 +52,7 @@ class Prefetcher(object): with torch.npu.stream(self.stream): self.next_input, self.next_target = self.next_input.to(torch.float), self.next_target.to(torch.float) - self.next_input, self.next_target = self.next_input.to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True), self.next_target.to(f'npu:{NPU_CALCULATE_DEVICE}', non_blocking=True) + self.next_input, self.next_target = self.next_input.npu(non_blocking=True), self.next_target.npu(non_blocking=True) def next(self): torch.npu.current_stream().wait_stream(self.stream) -- Gitee