diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py
index 623a2b6bc4830942b16c979631c0cbbdacf73dd4..5cad6c9d95623553c3c548983e062c662c372f52 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py
@@ -24,6 +24,8 @@ import math
 import itertools
 import numpy as np
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -204,7 +206,7 @@ class MultiHeadAttention(nn.Module):
 
     def _transpose_for_scores(self, x, bs, qlen):
         new_x_shape = (bs, qlen) + (self.n_heads, self.dim // self.n_heads)
-        return x.npu_confusion_transpose((0, 2, 1, 3), new_x_shape, False)
+        return torch.npu_confusion_transpose(x, (0, 2, 1, 3), new_x_shape, False)
 
     def forward(self, input, bs, qlen, mask, kv=None, cache=None):
         """
@@ -244,7 +246,7 @@ class MultiHeadAttention(nn.Module):
         weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
         context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
         # context = unshape(context)  # (bs, qlen, dim)
-        context = context.npu_confusion_transpose((0, 2, 1, 3), (context.size()[0] * context.size()[2], dim), True)
+        context = torch.npu_confusion_transpose(context, (0, 2, 1, 3), (context.size()[0] * context.size()[2], dim), True)
 
         return self.out_lin(context)
 
diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
index 71f8e3e12d047008183c3b2364c91593e60ab55d..8574079ddb983696efc22b6a6dc41863ef443927 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
@@ -63,9 +63,10 @@ class NpuFusedAdamV2(NpuFusedAdam):
            step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
 
            combined_param.data, exp_avg, exp_avg_sq = torch.npu_bert_apply_adam(
-               combined_param.data, exp_avg, exp_avg_sq, group['lr'], beta1,
+               group['lr'], beta1,
                beta2, group['eps'], combined_grad, 0.0, 0.0, group['weight_decay'],
-               step_size, adam_mode=1
+               step_size, adam_mode=1,
+               out = (combined_param.data, exp_avg, exp_avg_sq)
            )
 
 
diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
index 0b60a1a09c5019da7f89c8989cdf6917e71b58d2..718cb7ed1463fa7a0564570b2002c8050572818a 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
@@ -26,6 +26,8 @@ from logging import getLogger
 from collections import OrderedDict
 import numpy as np
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 from torch import nn
 from torch.nn import functional as F
 #from torch.nn.utils import clip_grad_norm_
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py
index 5bf2c44f0c99f73f336b510167f4becb4e109d48..dc0f82a7ec33379b9e94b756520d740f88649666 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py
@@ -7,6 +7,8 @@ import math
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.nn as nn
 from fairseq import utils
 from fairseq.models import (
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py
index 5d061ce0ca3cec518aa4cb1c209e8c5aca308706..253d3e51056e9e5c1504743ad23a5cfc1e0de3da 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py
@@ -7,6 +7,8 @@ import math
 from typing import Dict, Optional, Tuple
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.nn.functional as F
 from fairseq import utils
 from fairseq.incremental_decoding_utils import with_incremental_state
@@ -162,7 +164,7 @@ class MultiheadAttention(nn.Module):
 
     def transpose_for_scores(self, x):
         new_x_shape = (self.batch_size, self.squence_length) + (self.num_attention_heads, self.attention_head_size)
-        return x.npu_confusion_transpose((0, 2, 1, 3), new_x_shape, False)
+        return torch.npu_confusion_transpose(x, (0, 2, 1, 3), new_x_shape, False)
 
     def forward(
         self,
@@ -261,12 +263,12 @@ class MultiheadAttention(nn.Module):
         new_shape = (bsz, tgt_len) + (self.num_heads, self.head_dim)
         if k is not None:
             key_shape = (bsz, k.size(0) // bsz) + (self.num_heads, self.head_dim)
-        q = q.npu_confusion_transpose((0, 2, 1, 3), new_shape, False)
+        q = torch.npu_confusion_transpose(q, (0, 2, 1, 3), new_shape, False)
 
         if k is not None:
-            k = k.npu_confusion_transpose((0, 2, 1, 3), key_shape, False)
+            k = torch.npu_confusion_transpose(k, (0, 2, 1, 3), key_shape, False)
         if v is not None:
-            v = v.npu_confusion_transpose((0, 2, 1, 3), key_shape, False)
+            v = torch.npu_confusion_transpose(v, (0, 2, 1, 3), key_shape, False)
 
         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
@@ -369,7 +371,7 @@ class MultiheadAttention(nn.Module):
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
         else:
-            attn =attn.npu_confusion_transpose((0, 2, 1, 3),
+            attn = torch.npu_confusion_transpose(attn, (0, 2, 1, 3),
                                                (attn.size()[0]* attn.size()[2], embed_dim), True)
         attn = self.out_proj(attn)
 
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
index c26cac1fb2a73bf34326416ae502f4b4c141111c..91c62ed5cb1f8dbf8db7e8ee0aafaf26ce91ab3e 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
@@ -10,6 +10,8 @@ from dataclasses import dataclass, field
 from typing import List
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.distributed as dist
 import torch.optim
 from fairseq.dataclass import FairseqDataclass
@@ -200,10 +202,11 @@ class Adam(torch.optim.Optimizer):
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
 
                # Decay the first and second moment running average coefficient
                p_data_fp32, exp_avg, exp_avg_sq = \
-                   torch.npu_bert_apply_adam(p_data_fp32, exp_avg,
-                                             exp_avg_sq, group["lr"], beta1,
+                   torch.npu_bert_apply_adam(group["lr"], beta1,
                                              beta2, group["eps"], grad, 0.0, 0.0,
-                                             group["weight_decay"], step_size, adam_mode=1)
+                                             group["weight_decay"], step_size, adam_mode=1,
+                                             out = (p_data_fp32, exp_avg,exp_avg_sq)
+                                             )
 
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py
index 25332f9b38e63ec6d02752759a207596755d3fce..d96cd8ca34c034510d1cb72799351a968eb2496c 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py
@@ -15,6 +15,8 @@ from itertools import chain
 from typing import Any, Dict, List
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.distributed as dist
 from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
 from fairseq.file_io import PathManager
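The hunks above all follow the same migration pattern for newer Ascend PyTorch builds: torch_npu is imported alongside torch, the fused npu_confusion_transpose operator is called as a function on the torch namespace instead of as a tensor method, and npu_bert_apply_adam returns its results through an out= tuple rather than taking the parameter tensors as leading positional arguments. The sketch below restates that pattern outside the patched files; it is illustrative only, the helper names are not part of the repositories, and it assumes an Ascend NPU environment where importing torch_npu (PyTorch >= 1.8.1) provides the npu_* operators.

# Illustrative sketch of the call-site changes shown in the diff above;
# assumes an Ascend NPU build where importing torch_npu makes the
# torch.npu_* fused operators available (helper names are hypothetical).
import torch

if torch.__version__ >= "1.8.1":  # same guard as used in the patch
    import torch_npu  # noqa: F401


def transpose_for_scores(x, bs, qlen, n_heads, dim):
    # New style: the operator is a function taking the tensor as its first
    # argument, not a tensor method (old: x.npu_confusion_transpose(...)).
    new_x_shape = (bs, qlen) + (n_heads, dim // n_heads)
    return torch.npu_confusion_transpose(x, (0, 2, 1, 3), new_x_shape, False)


def fused_adam_update(param, exp_avg, exp_avg_sq, grad, lr, beta1, beta2,
                      eps, weight_decay, step_size):
    # New style: the updated tensors are passed through out= instead of as
    # the first three positional arguments, and are also returned as a tuple.
    return torch.npu_bert_apply_adam(
        lr, beta1, beta2, eps, grad, 0.0, 0.0, weight_decay,
        step_size, adam_mode=1,
        out=(param, exp_avg, exp_avg_sq),
    )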