diff --git a/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/data.py b/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/data.py index ebdafffd1f85ca5e91a6ad0322b0022c742dcb3c..4c7fb24a47e9721405b49bb00ae73b2e7cf1be3f 100644 --- a/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/data.py +++ b/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/data.py @@ -123,8 +123,8 @@ class SampleGenerator(object): users.append(int(row.userId)) items.append(int(row.negatives[i])) ratings.append(float(0)) # negative samples get 0 rating - dataset = UserItemRatingDataset(user_tensor=torch.LongTensor(users), - item_tensor=torch.LongTensor(items), + dataset = UserItemRatingDataset(user_tensor=torch.IntTensor(users), + item_tensor=torch.IntTensor(items), target_tensor=torch.FloatTensor(ratings)) return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=16) diff --git a/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/engine.py b/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/engine.py index 41c71feef54c8aad8721a3fe4b7d4685f9fcbbda..405a722b3ee4b7209b004283ce60ff33eb639fe7 100644 --- a/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/engine.py +++ b/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/engine.py @@ -72,30 +72,30 @@ class Engine(object): def train_single_batch(self, users, items, ratings): assert hasattr(self, 'model'), 'Please specify the exact model !' 
- #if self.config['use_npu'] is True: - users, items, ratings = users.npu(), items.npu(), ratings.npu() self.opt.zero_grad() ratings_pred = self.model(users, items) loss = self.crit(ratings_pred.view(-1), ratings) - #loss.backward() with amp.scale_loss(loss, self.opt) as scaled_loss: scaled_loss.backward() self.opt.step() - loss = loss.item() + loss = loss.detach() return loss def train_an_epoch(self, train_loader, epoch_id): assert hasattr(self, 'model'), 'Please specify the exact model !' self.model.train() total_loss = 0 - for batch_id, batch in enumerate(train_loader): + batch_id = 0 + from prefetcher import Prefetcher + prefetcher = Prefetcher(train_loader) + user, item, rating = prefetcher.next() + while user is not None: start_time=time.time() - assert isinstance(batch[0], torch.LongTensor) - user, item, rating = batch[0], batch[1], batch[2] - rating = rating.float() loss = self.train_single_batch(user, item, rating) print('[Training Epoch {}] Batch {}/{}, Loss:{:.3f}, Train-time:{:.4f}'.format(epoch_id, batch_id, len(train_loader), loss,time.time()-start_time)) total_loss += loss + batch_id += 1 + user, item, rating = prefetcher.next() self._writer.add_scalar('model/loss', total_loss, epoch_id) def evaluate(self, evaluate_data, epoch_id): diff --git a/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/prefetcher.py b/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/prefetcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5e05ead50092cc817ea11c9a17e4a724d0e0deb8 --- /dev/null +++ b/PyTorch/dev/cv/image_classification/NeuMF_ID0351_for_PyTorch/src/prefetcher.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class Prefetcher(object): + """Prefetcher for use on an NPU device. + + Original code URL: + https://github.com/implus/PytorchInsight/blob/master/classification/imagenet_fast.py#L280 + + Args: + loader (torch.utils.data.DataLoader or DataLoader-like iterator): + Used to generate inputs after preprocessing. + stream (torch.npu.Stream): Default None. + Because of the limitation of NPU's memory mechanism, + if prefetcher is initialized repeatedly during training, + a defined stream should be introduced to prevent memory leakage; + if prefetcher is initialized only once during training, + a defined stream is not necessary. + + Returns: + float: tensors of shape (k, 5) and (k, 1). Labels are 0-based. + """ + + def __init__(self, loader, stream=None): + self.loader = iter(loader) + self.stream = stream if stream is not None else torch.npu.Stream() + self.preload() + + def preload(self): + try: + self.user, self.item, self.rating = next(self.loader) + assert isinstance(self.user, torch.IntTensor) + self.rating = self.rating.float() + except StopIteration: + self.user = None + self.item = None + return + + with torch.npu.stream(self.stream): + self.user = self.user.npu(non_blocking=True) + self.item = self.item.npu(non_blocking=True) + self.rating = self.rating.npu(non_blocking=True) + + def next(self): + torch.npu.current_stream().wait_stream(self.stream) + user = self.user + item = self.item + rating = self.rating + if user is not None: + self.preload() + return user, item, rating