From b6f9aa9e3bb32e6dbf822843421aad56ca510875 Mon Sep 17 00:00:00 2001
From: xinxinL
Date: Mon, 11 Dec 2023 08:28:39 +0000
Subject: [PATCH 1/2] update
 PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt. Pin
 missing dependency library version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/foundation/GPT-NeoX/requirements/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt b/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt
index b1930922cf..9058a148c3 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt
+++ b/PyTorch/built-in/foundation/GPT-NeoX/requirements/requirements.txt
@@ -5,7 +5,7 @@ protobuf==3.20.3
 pybind11>=2.6.2
 regex
 sentencepiece
-best_download
+best_download==0.0.9
 cloudpickle==2.2.1
 decorator==5.1.1
 psutil==5.9.5
--
Gitee

From 2e844c8288e3db2f15d9082a4609a1c70dbe375e Mon Sep 17 00:00:00 2001
From: xinxinL
Date: Wed, 13 Dec 2023 02:28:23 +0000
Subject: [PATCH 2/2] update
 PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py. Add NPU
 flash attention (FA)

Signed-off-by: xinxinL
---
 .../GPT-NeoX/megatron/model/transformer.py    | 78 ++++++++++++++++++-
 1 file changed, 76 insertions(+), 2 deletions(-)

diff --git a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
index 8392c92f2f..528eb76554 100644
--- a/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
+++ b/PyTorch/built-in/foundation/GPT-NeoX/megatron/model/transformer.py
@@ -40,6 +40,9 @@ from megatron.model.fused_bias_dropout import (
 )
 from megatron.model.utils import configure_sparse_attention
 
+import torch_npu
+from einops import rearrange
+
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
 torch._C._jit_set_profiling_executor(False)
@@ -175,6 +178,58 @@ class ParallelLinear(nn.Module):
     def forward(self, hidden_states):
         return self.final_linear(hidden_states)
 
+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                      (default: 1.0; callers are expected to pass
+                      1/sqrt(d_keys) themselves)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+
+    def __init__(self, causal=True, seq_len=4096, softmax_scale=1.0, attention_dropout=0., device=None, dtype=None, layerout="SBH"):
+        super().__init__()
+        self.causal = causal
+        self.seq_len = seq_len
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        self.layerout = layerout
+
+    def forward(self, q, k, v, n, attention_mask):
+        attention_mask_shape = attention_mask.shape
+        #if attention_mask_shape[0] == 1:
+        #    attention_mask = attention_mask.view((attention_mask_shape[-2], attention_mask_shape[-1]))
+        #    print(f'q : {q.shape} k: {k.shape} v: {v.shape}, mask {attention_mask.shape}')
+
+        if self.causal:
+            output = torch_npu.npu_fusion_attention(
+                q, k, v, n, self.layerout,
+                pse=None,
+                padding_mask=None,
+                atten_mask=attention_mask,
+                scale=self.softmax_scale,
+                pre_tockens=self.seq_len,
+                next_tockens=0,
+                keep_prob=1 - self.dropout_p,
+            )[0]
+        else:
+            output = torch_npu.npu_fusion_attention(
+                q, k, v, n, self.layerout,
+                pse=None,
+                padding_mask=None,
+                atten_mask=attention_mask,
+                scale=self.softmax_scale,
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1 - self.dropout_p,
+            )[0]
+        return output
+
+
+
+
 class ParallelSelfAttention(nn.Module):
     """Parallel self-attention layer abstract class.
 
@@ -269,7 +324,9 @@ class ParallelSelfAttention(nn.Module):
         self.rotary_emb = None
 
         self.attention_type = neox_args.attention_config[layer_number]
+        self.use_npu_flash_attn = False
         self.use_flash_attention = self.attention_type == "flash"
+        self.layerout = "BSND"
         self.sparse = self.attention_type not in ("global", "flash")
         self.sparse = self.attention_type != "global" and not self.use_flash_attention
         if self.sparse:
@@ -288,7 +345,9 @@
             from megatron.model.flash_attention import (
                 flash_attn_unpadded_qkvpacked_func,
             )
-
+        elif self.use_npu_flash_attn:
+            self.npu_flash_attention = FlashSelfAttention(causal=True, seq_len=2048, softmax_scale=(1.0 / self.norm_factor),
+                                                          attention_dropout=0, layerout=self.layerout)
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,
@@ -318,10 +377,25 @@
         self.checkpoint_activations = neox_args.checkpoint_activations
         self.checkpoint_selective = neox_args.checkpoint_selective
+        print(f'triangle: {self.use_triangle_attn} flash: {self.use_npu_flash_attn}')
 
     def attention(
         self, query_layer, key_layer, value_layer, layer_past, attention_mask
     ):
-        if self.use_triangle_attn and layer_past is None and query_layer.size(
+        if self.use_npu_flash_attn:
+            # context layer shape: [b, np, sq, hn]
+
+            output_size = (value_layer.size(1),
+                           value_layer.size(2),
+                           query_layer.size(0),
+                           value_layer.size(3))
+
+            sequence_len, bsz, head_num, head_dim = query_layer.shape
+
+            query_layer, key_layer, value_layer = [rearrange(x, 's b n d -> b s n d').contiguous() for x in (query_layer, key_layer, value_layer)]
+            attn_output = self.npu_flash_attention(query_layer, key_layer, value_layer, head_num, attention_mask)
+            context_layer = attn_output.view(*output_size)
+            return context_layer
+        elif self.use_triangle_attn and layer_past is None and query_layer.size(
                 0) >= self.block_size * 2 and query_layer.size(0) % self.block_size == 0:
             # context layer shape: [b, np, sq, hn]
             output_size = (value_layer.size(1),
--
Gitee
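
A minimal smoke test for the FlashSelfAttention wrapper added in PATCH 2/2 could look like the sketch below. It is illustrative only: it assumes an Ascend environment with torch, torch_npu and einops installed, an NPU device available, and the patched GPT-NeoX tree on PYTHONPATH. The script name, tensor sizes, import path and mask construction are assumptions, not part of the patch, and the exact atten_mask layout accepted by torch_npu.npu_fusion_attention can vary with the torch_npu/CANN version.

# smoke_test_npu_flash_attention.py -- hypothetical, illustrative only
import torch
import torch_npu  # noqa: F401  (registers the "npu" device and npu_fusion_attention)

from megatron.model.transformer import FlashSelfAttention  # class added by PATCH 2/2


def main():
    bsz, seq_len, heads, head_dim = 2, 2048, 16, 128  # illustrative sizes

    # Query/key/value in BSND layout, matching self.layerout = "BSND" and the
    # 's b n d -> b s n d' rearrange done in ParallelSelfAttention.attention().
    q = torch.randn(bsz, seq_len, heads, head_dim, dtype=torch.float16).npu()
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Causal mask in the GPT-NeoX [1, 1, s, s] convention: True marks positions
    # that must not be attended to (strict upper triangle).
    atten_mask = torch.triu(
        torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    ).npu()

    attn = FlashSelfAttention(
        causal=True,
        seq_len=seq_len,
        softmax_scale=1.0 / (head_dim ** 0.5),  # mirrors 1.0 / self.norm_factor
        attention_dropout=0.0,
        layerout="BSND",
    )
    out = attn(q, k, v, heads, atten_mask)  # -> [bsz, seq_len, heads, head_dim]
    print(out.shape)


if __name__ == "__main__":
    main()

With causal=True the wrapper sets pre_tockens=seq_len and next_tockens=0, so the boolean mask passed in must describe the same lower-triangular (causal) pattern as those token-window arguments.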