From 25f441261a3bef855bac6d843da6861fb7a31715 Mon Sep 17 00:00:00 2001
From: junqiang521
Date: Wed, 23 Mar 2022 14:30:41 +0800
Subject: [PATCH 1/2] XLM mBART PyTorch 1.8 adaptation

---
 .../XLM_ID0740_for_PyTorch/xlm/model/transformer.py | 6 ++++--
 .../built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py | 4 ++--
 .../nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py | 2 ++
 .../fairseq/models/transformer.py | 2 ++
 .../fairseq/modules/multihead_attention.py | 12 +++++++-----
 .../mBART_ID2372_for_PyTorch/fairseq/optim/adam.py | 9 ++++++---
 .../nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py | 2 ++
 7 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py
index 623a2b6bc4..5cad6c9d95 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/model/transformer.py
@@ -24,6 +24,8 @@ import math
 import itertools
 import numpy as np
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -204,7 +206,7 @@ class MultiHeadAttention(nn.Module):
 
     def _transpose_for_scores(self, x, bs, qlen):
         new_x_shape = (bs, qlen) + (self.n_heads, self.dim // self.n_heads)
-        return x.npu_confusion_transpose((0, 2, 1, 3), new_x_shape, False)
+        return torch.npu_confusion_transpose(x, (0, 2, 1, 3), new_x_shape, False)
 
     def forward(self, input, bs, qlen, mask, kv=None, cache=None):
         """
@@ -244,7 +246,7 @@ class MultiHeadAttention(nn.Module):
         weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
         context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
         # context = unshape(context)  # (bs, qlen, dim)
-        context = context.npu_confusion_transpose((0, 2, 1, 3), (context.size()[0] * context.size()[2], dim), True)
+        context = torch.npu_confusion_transpose(context, (0, 2, 1, 3), (context.size()[0] * context.size()[2], dim), True)
 
         return self.out_lin(context)
 
diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
index 71f8e3e12d..f23f0c9647 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
@@ -63,9 +63,9 @@ class NpuFusedAdamV2(NpuFusedAdam):
             step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
 
             combined_param.data, exp_avg, exp_avg_sq = torch.npu_bert_apply_adam(
-                combined_param.data, exp_avg, exp_avg_sq, group['lr'], beta1,
+                group['lr'], beta1,
                 beta2, group['eps'], combined_grad, 0.0, 0.0, group['weight_decay'],
-                step_size, adam_mode=1
+                out = (combined_param.data, exp_avg, exp_avg_sq)
             )
 
 
diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
index 0b60a1a09c..718cb7ed14 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/trainer.py
@@ -26,6 +26,8 @@ from logging import getLogger
 from collections import OrderedDict
 import numpy as np
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 from torch import nn
 from torch.nn import functional as F
 #from torch.nn.utils import clip_grad_norm_
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py
index 5bf2c44f0c..dc0f82a7ec 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/models/transformer.py
@@ -7,6 +7,8 @@ import math
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.nn as nn
 from fairseq import utils
 from fairseq.models import (
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py
index 5d061ce0ca..253d3e5105 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/modules/multihead_attention.py
@@ -7,6 +7,8 @@ import math
 from typing import Dict, Optional, Tuple
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.nn.functional as F
 from fairseq import utils
 from fairseq.incremental_decoding_utils import with_incremental_state
@@ -162,7 +164,7 @@ class MultiheadAttention(nn.Module):
 
     def transpose_for_scores(self, x):
         new_x_shape = (self.batch_size, self.squence_length) + (self.num_attention_heads, self.attention_head_size)
-        return x.npu_confusion_transpose((0, 2, 1, 3), new_x_shape, False)
+        return torch.npu_confusion_transpose(x, (0, 2, 1, 3), new_x_shape, False)
 
     def forward(
         self,
@@ -261,12 +263,12 @@ class MultiheadAttention(nn.Module):
         new_shape = (bsz, tgt_len) + (self.num_heads, self.head_dim)
         if k is not None:
             key_shape = (bsz, k.size(0) // bsz) + (self.num_heads, self.head_dim)
-        q = q.npu_confusion_transpose((0, 2, 1, 3), new_shape, False)
+        q = torch.npu_confusion_transpose(q, (0, 2, 1, 3), new_shape, False)
 
         if k is not None:
-            k = k.npu_confusion_transpose((0, 2, 1, 3), key_shape, False)
+            k = torch.npu_confusion_transpose(k, (0, 2, 1, 3), key_shape, False)
         if v is not None:
-            v = v.npu_confusion_transpose((0, 2, 1, 3), key_shape, False)
+            v = torch.npu_confusion_transpose(v, (0, 2, 1, 3), key_shape, False)
 
         if saved_state is not None:
             # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
@@ -369,7 +371,7 @@ class MultiheadAttention(nn.Module):
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
         else:
-            attn =attn.npu_confusion_transpose((0, 2, 1, 3),
+            attn = torch.npu_confusion_transpose(attn, (0, 2, 1, 3),
                                                  (attn.size()[0]* attn.size()[2], embed_dim), True)
         attn = self.out_proj(attn)
 
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
index c26cac1fb2..3515945fd8 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
@@ -10,6 +10,8 @@ from dataclasses import dataclass, field
 from typing import List
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.distributed as dist
 import torch.optim
 from fairseq.dataclass import FairseqDataclass
@@ -200,10 +202,11 @@ class Adam(torch.optim.Optimizer):
                 step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                 # Decay the first and second moment running average coefficient
                 p_data_fp32, exp_avg, exp_avg_sq = \
-                    torch.npu_bert_apply_adam(p_data_fp32, exp_avg,
-                                              exp_avg_sq, group["lr"], beta1,
+                    torch.npu_bert_apply_adam(group["lr"], beta1,
                                               beta2, group["eps"], grad, 0.0, 0.0,
-                                              group["weight_decay"], step_size, adam_mode=1)
+                                              group["weight_decay"],
+                                              out = (p_data_fp32, exp_avg,exp_avg_sq)
+                                              )
 
                 if p.data.dtype in {torch.float16, torch.bfloat16}:
                     p.data.copy_(p_data_fp32)
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py
index 25332f9b38..d96cd8ca34 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/trainer.py
@@ -15,6 +15,8 @@ from itertools import chain
 from typing import Any, Dict, List
 
 import torch
+if torch.__version__ >= "1.8.1":
+    import torch_npu
 import torch.distributed as dist
 from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
 from fairseq.file_io import PathManager
-- 
Gitee

From e87ab752e6ea73442b1aaba09200d9ef8f785427 Mon Sep 17 00:00:00 2001
From: junqiang521
Date: Thu, 24 Mar 2022 09:45:52 +0800
Subject: [PATCH 2/2] npu_bert_apply_adam parameter modification
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py | 1 +
 .../built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
index f23f0c9647..8574079ddb 100644
--- a/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
+++ b/PyTorch/built-in/nlp/XLM_ID0740_for_PyTorch/xlm/optim.py
@@ -65,6 +65,7 @@ class NpuFusedAdamV2(NpuFusedAdam):
             combined_param.data, exp_avg, exp_avg_sq = torch.npu_bert_apply_adam(
                 group['lr'], beta1,
                 beta2, group['eps'], combined_grad, 0.0, 0.0, group['weight_decay'],
+                step_size, adam_mode=1,
                 out = (combined_param.data, exp_avg, exp_avg_sq)
             )
 
diff --git a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
index 3515945fd8..91c62ed5cb 100644
--- a/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
+++ b/PyTorch/built-in/nlp/mBART_ID2372_for_PyTorch/fairseq/optim/adam.py
@@ -204,7 +204,7 @@ class Adam(torch.optim.Optimizer):
                 p_data_fp32, exp_avg, exp_avg_sq = \
                     torch.npu_bert_apply_adam(group["lr"], beta1,
                                               beta2, group["eps"], grad, 0.0, 0.0,
-                                              group["weight_decay"],
+                                              group["weight_decay"], step_size, adam_mode=1,
                                               out = (p_data_fp32, exp_avg,exp_avg_sq)
                                               )
 
-- 
Gitee