diff --git a/MindIE/MindIE-Torch/built-in/MT5/MT5_modeling_patch.py b/MindIE/MindIE-Torch/built-in/MT5/MT5_modeling_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a6ec861384f827a00ab99fb20ac0f24f8bc65d
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/MT5_modeling_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import transformers
+
+
+def main():
+    transformers_path = transformers.__path__
+    transformers_version = transformers.__version__
+
+    assert transformers_version == '4.42.0', "expected transformers==4.42.0"
+    os.system(f'patch -p0 {transformers_path[0]}/models/mt5/modeling_mt5.py modeling_mt5.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/MT5/export_mt5.py b/MindIE/MindIE-Torch/built-in/MT5/export_mt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..138728fc16ba0360c6efa07508fc78f6731f5d9f
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/export_mt5.py
@@ -0,0 +1,192 @@
+import torch
+import torch_npu
+import argparse
+import os
+import math
+import mindietorch
+from transformers import MT5ForConditionalGeneration
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="./models",
+        help="save dir"
+    )
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="./MT5-Small",
+        help="MT5 model path"
+    )
+    parser.add_argument(
+        "--max_batchsize",
+        type=int,
+        default=1,
+        help="max batch size when running"
+    )
+    parser.add_argument(
+        "--max_input_seq_len",
+        type=int,
+        default=256,
+        help="max input sequence length when running"
+    )
+    parser.add_argument(
+        "--device_id",
+        type=int,
+        default=0,
+        help="npu device id"
+    )
+    return parser.parse_args()
+
+
+class TextEncoderExport(torch.nn.Module):
+    def __init__(self, textencoder_model):
+        super(TextEncoderExport, self).__init__()
+        self.textencoder_model = textencoder_model
+
+    def forward(self, input_ids):
+        return self.textencoder_model(input_ids=input_ids)
+
+
+class TextDecoderExport(torch.nn.Module):
+    def __init__(self, textdecoder_model):
+        super(TextDecoderExport, self).__init__()
+        self.textdecoder_model = textdecoder_model
+
+    def forward(self, *args):
+        return self.textdecoder_model(*args)
+
+
+def export_textencoder(args, model, save_dir, batch_size):
+    encoder_path = os.path.join(save_dir, "encoder")
+    if not os.path.exists(encoder_path):
+        os.makedirs(encoder_path, mode=0o750)
+    traced_path = os.path.join(encoder_path, "encoder.pt")
+    compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+    if not os.path.exists(traced_path):
+        text_encoder = model.encoder
+        dummy_input = (
+            torch.ones([1, 128], dtype=torch.int64).npu()
+        )
+        encoder = TextEncoderExport(text_encoder)
+        encoder.eval()
+        torch.jit.trace(encoder, dummy_input, strict=False).save(traced_path)
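+    # The 1x128 dummy input above is only used for tracing; the compile step
+    # below declares a dynamic shape range from (1, 1) up to
+    # (max_batchsize, max_input_seq_len), so the compiled encoder should accept
+    # variable batch sizes and input lengths at runtime.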
+    if not os.path.exists(compiled_path):
+        traced_model = torch.jit.load(traced_path).eval()
+
+        inputs0 = []
+        inputs0.append(mindietorch.Input(min_shape=(1, 1), max_shape=(args.max_batchsize, args.max_input_seq_len), dtype=torch.int64))
+        print("compiling encoder")
+        compiled_model = mindietorch.compile(
+            traced_model,
+            inputs=inputs0,
+            allow_tensor_replace_int=True,
+            require_full_compilation=False,
+            truncate_long_and_double=True,
+            precision_policy=mindietorch.PrecisionPolicy.FP16,
+            soc_version="Ascend910B4",
+            optimization_level=0
+        )
+        compiled_model.save(compiled_path)
+
+
+def export_textdecoder(args, model, save_dir, batch_size):
+    decoder_path = os.path.join(save_dir, "decoder")
+    if not os.path.exists(decoder_path):
+        os.makedirs(decoder_path, mode=0o750)
+    traced_path = os.path.join(decoder_path, "decoder.pt")
+    compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+    if not os.path.exists(traced_path):
+        text_decoder = model
+        all_past_keys = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+        all_past_values = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+        all_past_cross_keys = [torch.randn([1, 16, model.config.num_heads * model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+        all_past_cross_values = [torch.randn([1, 16, model.config.num_heads * model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+        dummy_input = [torch.randn(1, 16, model.config.d_model).to(torch.float16).npu()]
+        dummy_input.extend(all_past_cross_keys)
+        dummy_input.extend(all_past_cross_values)
+        dummy_input.extend(all_past_keys)
+        dummy_input.extend(all_past_values)
+        dummy_input.append(torch.ones(1, 16).npu())
+        dummy_input.append(torch.ones([1, 1], dtype=torch.int64).npu())
+        decoder = TextDecoderExport(text_decoder).npu()
+        decoder.eval()
+        torch.jit.trace(decoder, dummy_input, strict=False).save(traced_path)
+    if not os.path.exists(compiled_path):
+        traced_model = torch.jit.load(traced_path).eval()
+        print("compiling decoder")
+        input_info = [mindietorch.Input(min_shape=(1, 1, model.config.d_model),
+                                        max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+                                        dtype=mindietorch.dtype.FLOAT16)]
+        past_cross_key_infos = [mindietorch.Input(min_shape=(1, 1, model.config.d_model),
+                                                  max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+                                                  dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        past_cross_value_infos = [mindietorch.Input(min_shape=(1, 1, model.config.d_model),
+                                                    max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+                                                    dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        past_key_infos = [mindietorch.Input(min_shape=(1, model.config.num_heads, 0, model.config.d_kv),
+                                            max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv),
+                                            dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        past_value_infos = [mindietorch.Input(min_shape=(1, model.config.num_heads, 0, model.config.d_kv),
+                                              max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv),
+                                              dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        decoder_input_ids_info = [mindietorch.Input(min_shape=(1, 1),
+                                                    max_shape=(args.max_batchsize, 1),
+                                                    dtype=mindietorch.dtype.INT64)]
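+        # The positional order assembled below must match the decoder's traced
+        # forward signature: hidden_states, per-layer cross keys, per-layer cross
+        # values, per-layer self-attention keys, per-layer self-attention values,
+        # the encoder attention mask, and finally decoder_input_ids.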
+        encoder_attention_mask_info = [mindietorch.Input(min_shape=(1, 1),
+                                                         max_shape=(args.max_batchsize, args.max_input_seq_len),
+                                                         dtype=mindietorch.dtype.INT64)]
+        input_info.extend(past_cross_key_infos)
+        input_info.extend(past_cross_value_infos)
+        input_info.extend(past_key_infos)
+        input_info.extend(past_value_infos)
+        input_info.extend(encoder_attention_mask_info)
+        input_info.extend(decoder_input_ids_info)
+        buffer = []
+        for _ in range(2 * model.config.num_layers):
+            buffer.append(math.ceil((args.max_batchsize * args.max_input_seq_len * model.config.d_model * 2) / 1024 / 1024))
+        buffer_size0 = math.ceil((args.max_batchsize * 1 * model.config.vocab_size * 4) / 1024 / 1024)
+        buffer.append(buffer_size0)
+        print("buffer =", buffer)
+        compiled_model = mindietorch.compile(
+            traced_model,
+            inputs=input_info,
+            allow_tensor_replace_int=True,
+            require_full_compilation=False,
+            truncate_long_and_double=True,
+            precision_policy=mindietorch.PrecisionPolicy.FP16,
+            soc_version="Ascend910B4",
+            default_buffer_size_vec=buffer,
+            optimization_level=0
+        )
+        compiled_model.save(compiled_path)
+
+
+def main():
+    args = parse_arguments()
+    device_id = args.device_id
+    save_dir = args.output_dir
+    torch.npu.set_device(device_id)
+    batch_size = 1
+    model = MT5ForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.float).npu()
+    encoder_path = os.path.join(save_dir, "encoder")
+    compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+    if not os.path.exists(compiled_path):
+        export_textencoder(args, model, save_dir, batch_size)
+    print("export encoder_model done!")
+
+    decoder_path = os.path.join(save_dir, "decoder")
+    compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+    if not os.path.exists(compiled_path):
+        export_textdecoder(args, model, save_dir, batch_size)
+    print("export decoder_model done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MindIE-Torch/built-in/MT5/modeling_mt5.patch b/MindIE/MindIE-Torch/built-in/MT5/modeling_mt5.patch
new file mode 100644
index 0000000000000000000000000000000000000000..a5afef98e2a5de7d41dc57313605becde8835002
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/modeling_mt5.patch
@@ -0,0 +1,1557 @@
+diff --git a/modeling_mt5.py b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/mt5/modeling_mt5.py
+index 1336b9196..5b94d69c7 100644
+--- a/modeling_mt5.py
++++ b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/mt5/modeling_mt5.py
+@@ -19,22 +19,26 @@ import math
+ import os
+ import warnings
+ from typing import List, Optional, Tuple, Union
+-
++from dataclasses import dataclass
+ import torch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
++# import torch_npu
++import mindietorch
++
++
++
+
+ from ...activations import ACT2FN
+ from ...modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPastAndCrossAttentions,
+-    Seq2SeqLMOutput,
+     Seq2SeqModelOutput,
+     Seq2SeqQuestionAnsweringModelOutput,
+     Seq2SeqSequenceClassifierOutput,
+     TokenClassifierOutput,
+ )
+-from ...modeling_utils import PreTrainedModel
++from ...modeling_utils import PreTrainedModel, ModuleUtilsMixin
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+ from ...utils import (
+     DUMMY_INPUTS,
+@@ -47,8 +51,44 @@ from ...utils import (
+ )
+ from ...utils.model_parallel_utils import assert_device_map, get_device_map
+ from .configuration_mt5 import MT5Config
++from transformers.generation.logits_process import LogitsProcessorList
++from transformers.generation.stopping_criteria import 
StoppingCriteriaList ++from transformers.generation.configuration_utils import GenerationMode ++from transformers.utils.generic import ModelOutput + + ++@dataclass ++class Seq2SeqLMOutput(ModelOutput): ++ """ ++ Base class for model's outputs, with potential hidden states and attentions. ++ ++ Args: ++ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): ++ Sequence of hidden-states at the output of the last layer of the model. ++ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): ++ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + ++ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. ++ ++ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. ++ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): ++ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, ++ sequence_length)`. ++ ++ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention ++ heads. ++ """ ++ loss: Optional[torch.FloatTensor] = None ++ logits: torch.FloatTensor = None ++ past_keys: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ past_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None ++ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ encoder_last_hidden_state: Optional[torch.FloatTensor] = None ++ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None ++ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ + logger = logging.get_logger(__name__) + + _CONFIG_FOR_DOC = "MT5Config" +@@ -323,7 +363,10 @@ class MT5Attention(nn.Module): + mask=None, + key_value_states=None, + position_bias=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, ++ past_cross_key=None, ++ past_cross_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, +@@ -339,17 +382,15 @@ class MT5Attention(nn.Module): + + real_seq_length = seq_length + +- if past_key_value is not None: +- if len(past_key_value) != 2: +- raise ValueError( +- f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" +- ) +- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length ++ if past_key is not None: ++ real_seq_length += past_key.shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" ++ # import pdb ++ # pdb.set_trace() + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): +@@ -368,16 +409,17 @@ class MT5Attention(nn.Module): + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: ++ past_key_value = shape(past_key_value) + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) +- elif past_key_value.shape[2] != key_value_states.shape[1]: +- # checking that the `sequence_length` of the `past_key_value` is the same as +- # the provided `key_value_states` to support prefix tuning +- # cross-attn +- # (batch_size, n_heads, seq_length, dim_per_head) +- hidden_states = shape(proj_layer(key_value_states)) ++ # elif past_key_value.shape[2] != key_value_states.shape[1]: ++ # # checking that the `sequence_length` of the `past_key_value` is the same as ++ # # the provided `key_value_states` to support prefix tuning ++ # # cross-attn ++ # # (batch_size, n_heads, seq_length, dim_per_head) ++ # hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value +@@ -388,10 +430,10 @@ class MT5Attention(nn.Module): + + # get key/value states + key_states = project( +- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None ++ hidden_states, self.k, key_value_states, past_key if past_key is not None else None + ) + value_states = project( +- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None ++ hidden_states, self.v, key_value_states, past_value if past_value is not None else None + ) + + # compute scores +@@ -411,7 +453,7 @@ class MT5Attention(nn.Module): + + # if key and values are already calculated + # we want only the last query position bias +- if past_key_value is not None: ++ if past_key is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: +@@ -439,14 +481,124 @@ class MT5Attention(nn.Module): + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + +- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None +- outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) +- ++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None ++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None ++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None ++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,) ++ + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + ++class MT5SelfAttention(MT5Attention): ++ def __init__(self, config: MT5Config, has_relative_attention_bias=False): ++ super().__init__(config, has_relative_attention_bias) ++ ++ def forward( ++ self, ++ hidden_states, ++ mask=None, ++ position_bias=None, ++ 
past_key=None, ++ past_value=None, ++ layer_head_mask=None, ++ use_cache=False, ++ output_attentions=False, ++ ): ++ """ ++ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). ++ """ ++ # Input is (batch_size, seq_length, dim) ++ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) ++ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) ++ batch_size, seq_length = hidden_states.shape[:2] ++ ++ real_seq_length = seq_length ++ ++ if past_key is not None: ++ real_seq_length += past_key.shape[2] ++ key_length = real_seq_length ++ def shape(states): ++ """projection""" ++ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) ++ ++ def unshape(states): ++ """reshape""" ++ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) ++ ++ def project(hidden_states, proj_layer, past_key_value): ++ """projects hidden states correctly to key/query states""" ++ if past_key_value is None: ++ # cross-attn ++ # (batch_size, n_heads, seq_length, dim_per_head) ++ hidden_states = shape(proj_layer(hidden_states)) ++ ++ if past_key_value is not None: ++ hidden_states = shape(proj_layer(hidden_states)) ++ hidden_states = torch.cat([past_key_value, hidden_states], dim=2) ++ return hidden_states ++ ++ # get query states ++ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) ++ ++ # get key/value states ++ key_states = project( ++ hidden_states, self.k, past_key if past_key is not None else None ++ ) ++ value_states = project( ++ hidden_states, self.v, past_value if past_value is not None else None ++ ) ++ # compute scores ++ scores = torch.matmul( ++ query_states, key_states.transpose(3, 2) ++ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 ++ if position_bias is None: ++ if not self.has_relative_attention_bias: ++ position_bias = torch.zeros( ++ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype ++ ) ++ if self.gradient_checkpointing and self.training: ++ position_bias.requires_grad = True ++ else: ++ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) ++ ++ # if key and values are already calculated ++ # we want only the last query position bias ++ if past_key is not None: ++ position_bias = position_bias[:, :, -hidden_states.size(1) :, :] ++ if mask is not None: ++ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) ++ ++ if self.pruned_heads: ++ mask = torch.ones(position_bias.shape[1]) ++ mask[list(self.pruned_heads)] = 0 ++ position_bias_masked = position_bias[:, mask.bool()] ++ else: ++ position_bias_masked = position_bias ++ scores += position_bias_masked ++ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( ++ scores ++ ) # (batch_size, n_heads, seq_length, key_length) ++ attn_weights = nn.functional.dropout( ++ attn_weights, p=self.dropout, training=self.training ++ ) # (batch_size, n_heads, seq_length, key_length) ++ ++ # Mask heads if we want to ++ if layer_head_mask is not None: ++ attn_weights = attn_weights * layer_head_mask ++ ++ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) ++ attn_output = self.o(attn_output) ++ ++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None ++ present_key_state = (key_states.half(), ) if 
(self.is_decoder and use_cache) else None ++ present_value_state = (value_states.half(), ) if (self.is_decoder and use_cache) else None ++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,) ++ if output_attentions: ++ outputs = outputs + (attn_weights,) ++ return outputs ++ + # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 + class MT5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): +@@ -461,7 +613,8 @@ class MT5LayerSelfAttention(nn.Module): + attention_mask=None, + position_bias=None, + layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, + use_cache=False, + output_attentions=False, + ): +@@ -471,7 +624,8 @@ class MT5LayerSelfAttention(nn.Module): + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +@@ -495,7 +649,8 @@ class MT5LayerCrossAttention(nn.Module): + attention_mask=None, + position_bias=None, + layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, + use_cache=False, + query_length=None, + output_attentions=False, +@@ -507,7 +662,8 @@ class MT5LayerCrossAttention(nn.Module): + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, +@@ -539,39 +695,34 @@ class MT5Block(nn.Module): + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, ++ past_cross_key=None, ++ past_cross_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): +- if past_key_value is not None: +- if not self.is_decoder: +- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") +- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 +- +- if len(past_key_value) != expected_num_past_key_values: +- raise ValueError( +- f"There should be {expected_num_past_key_values} past states. " +- f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" +- f"Got {len(past_key_value)} past key / value states" +- ) +- +- self_attn_past_key_value = past_key_value[:2] +- cross_attn_past_key_value = past_key_value[2:] ++ if past_key is not None: ++ self_attn_past_key = past_key ++ self_attn_past_value = past_value ++ cross_attn_past_key = past_cross_key ++ cross_attn_past_value = past_cross_value + else: +- self_attn_past_key_value, cross_attn_past_key_value = None, None ++ self_attn_past_key, self_attn_past_value, cross_attn_past_key, cross_attn_past_value = None, None, None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=self_attn_past_key_value, ++ past_key=self_attn_past_key, ++ past_value=self_attn_past_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +- hidden_states, present_key_value_state = self_attention_outputs[:2] +- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights ++ hidden_states, present_key_state, present_value_state = self_attention_outputs[:3] ++ attention_outputs = self_attention_outputs[3:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: +@@ -586,8 +737,8 @@ class MT5Block(nn.Module): + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here +- if present_key_value_state is not None: +- query_length = present_key_value_state[0].shape[2] ++ if present_key_state is not None: ++ query_length = present_key_state[0].shape[2] + else: + query_length = None + +@@ -597,7 +748,8 @@ class MT5Block(nn.Module): + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, +- past_key_value=cross_attn_past_key_value, ++ past_key=cross_attn_past_key, ++ past_value=cross_attn_past_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, +@@ -614,11 +766,9 @@ class MT5Block(nn.Module): + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states +- if present_key_value_state is not None: +- present_key_value_state = present_key_value_state + cross_attention_outputs[1] +- ++ # cross_attn_past_key_values = cross_attention_outputs[1] + # Keep cross-attention outputs and relative position weights +- attention_outputs = attention_outputs + cross_attention_outputs[2:] ++ attention_outputs = attention_outputs + cross_attention_outputs[3:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) +@@ -635,7 +785,7 @@ class MT5Block(nn.Module): + outputs = (hidden_states,) + + if use_cache: +- outputs = outputs + (present_key_value_state,) + attention_outputs ++ outputs = outputs + (present_key_state,) +(present_value_state,)+ attention_outputs + else: + outputs = outputs + attention_outputs + +@@ -884,11 +1034,15 @@ class MT5PreTrainedModel(PreTrainedModel): + + # Copied from transformers.models.t5.modeling_t5.T5Stack with T5->MT5 + class MT5Stack(MT5PreTrainedModel): +- def __init__(self, config, embed_tokens=None): ++ def __init__(self, config, embed_tokens=None,lm_head=None, encodecrosskey=None, encodecrossvalue=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + 
self.is_decoder = config.is_decoder ++ self.lm_head=lm_head ++ self.encodecrosskey = encodecrosskey ++ self.encodecrossvalue = encodecrossvalue ++ self.model_dim = config.d_model + + self.block = nn.ModuleList( + [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] +@@ -953,20 +1107,63 @@ class MT5Stack(MT5PreTrainedModel): + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + ++ def invert_attention_mask(self, encoder_attention_mask): ++ if encoder_attention_mask.dim() == 3: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] ++ if encoder_attention_mask.dim() == 2: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] ++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility ++ ++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000 ++ ++ return encoder_extended_attention_mask ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, device=None, dtype=None ++ ): ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ + def forward( + self, + input_ids=None, +- attention_mask=None, + encoder_hidden_states=None, ++ past_keys=None, ++ past_values=None, ++ past_cross_keys=None, ++ past_cross_values=None, + encoder_attention_mask=None, ++ attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, +- past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ++ **model_kwargs + ): + # Model parallel + if self.model_parallel: +@@ -985,8 +1182,10 @@ class MT5Stack(MT5PreTrainedModel): + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: ++ + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) ++ input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: +@@ -999,18 +1198,19 @@ class MT5Stack(MT5PreTrainedModel): + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape +- + # required mask seq length can be calculated via length of past +- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length ++ mask_seq_length = past_keys[0].shape[2] + seq_length if past_keys is not None else seq_length + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize 
past_key_values with `None` if past does not exist +- if past_key_values is None: +- past_key_values = [None] * len(self.block) +- ++ if not self.is_decoder: ++ past_keys = [None] * len(self.block) ++ past_values = [None] * len(self.block) ++ past_cross_keys = [None] * len(self.block) ++ past_cross_values = [None] * len(self.block) + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + +@@ -1041,7 +1241,8 @@ class MT5Stack(MT5PreTrainedModel): + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) +- present_key_value_states = () if use_cache else None ++ present_key_states = () if use_cache else None ++ present_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None +@@ -1049,8 +1250,8 @@ class MT5Stack(MT5PreTrainedModel): + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) +- +- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): ++ # for i, layer_module in enumerate(self.block): ++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel +@@ -1099,7 +1300,10 @@ class MT5Stack(MT5PreTrainedModel): + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, ++ past_cross_key=past_cross_key, ++ past_cross_value=past_cross_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +@@ -1107,19 +1311,20 @@ class MT5Stack(MT5PreTrainedModel): + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: +- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] ++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:] + +- hidden_states, present_key_value_state = layer_outputs[:2] ++ hidden_states, present_key_state, present_value_state = layer_outputs[:3] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) +- position_bias = layer_outputs[2] ++ position_bias = layer_outputs[3] + if self.is_decoder and encoder_hidden_states is not None: +- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] ++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4] + # append next layer key value states + if use_cache: +- present_key_value_states = present_key_value_states + (present_key_value_state,) ++ present_key_states = present_key_states + present_key_state ++ present_value_states = present_value_states + present_value_state + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) +@@ -1133,7 +1338,7 @@ class 
MT5Stack(MT5PreTrainedModel): + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) +- hidden_states = self.dropout(hidden_states) ++ hidden_states = self.dropout(hidden_states).half() + + # Add last layer + if output_hidden_states: +@@ -1151,13 +1356,216 @@ class MT5Stack(MT5PreTrainedModel): + ] + if v is not None + ) +- return BaseModelOutputWithPastAndCrossAttentions( +- last_hidden_state=hidden_states, +- past_key_values=present_key_value_states, +- hidden_states=all_hidden_states, +- attentions=all_attentions, +- cross_attentions=all_cross_attentions, ++ if not self.is_decoder: ++ cross_keys = None ++ cross_values = None ++ if self.encodecrosskey: ++ cross_keys = self.encodecrosskey(hidden_states) ++ if self.encodecrossvalue: ++ cross_values = self.encodecrossvalue(hidden_states) ++ return tuple((hidden_states, cross_keys, cross_values)) ++ lm_logits = None ++ if self.is_decoder: ++ if self.config.tie_word_embeddings: ++ hidden_states = hidden_states * (self.model_dim ** -0.5) ++ lm_logits = self.lm_head(hidden_states) ++ return tuple((lm_logits, present_key_states, present_value_states)) ++ ++ ++class MT5Stack_Encoder(MT5PreTrainedModel): ++ def __init__(self, config, embed_tokens=None, encodecrosskey=None, encodecrossvalue=None): ++ super().__init__(config) ++ self.embed_tokens = embed_tokens ++ self.is_decoder = config.is_decoder ++ self.encodecrosskey = encodecrosskey ++ self.encodecrossvalue = encodecrossvalue ++ self.model_dim = config.d_model ++ ++ self.block = nn.ModuleList( ++ [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) ++ self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) ++ self.dropout = nn.Dropout(config.dropout_rate) ++ ++ # Initialize weights and apply final processing ++ self.post_init() ++ # Model parallel ++ self.model_parallel = False ++ self.device_map = None ++ self.gradient_checkpointing = False ++ ++ def get_input_embeddings(self): ++ return self.embed_tokens ++ ++ def set_input_embeddings(self, new_embeddings): ++ self.embed_tokens = new_embeddings ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, device=None, dtype=None ++ ): ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ ++ def forward( ++ self, ++ input_ids=None, ++ attention_mask=None, ++ head_mask=None, ++ cross_attn_head_mask=None, ++ use_cache=None, ++ output_attentions=None, ++ output_hidden_states=None, ++ return_dict=None, ++ **model_kwargs ++ ): ++ # Model parallel ++ use_cache = use_cache if use_cache is not None else 
self.config.use_cache ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ++ ) ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ input_shape = input_ids.size() ++ input_ids = input_ids.view(-1, input_shape[-1]) ++ ++ inputs_embeds = self.embed_tokens(input_ids) ++ ++ batch_size, seq_length = input_shape ++ # required mask seq length can be calculated via length of past ++ mask_seq_length = seq_length ++ ++ # initialize past_key_values with `None` if past does not exist ++ past_keys = [None] * len(self.block) ++ past_values = [None] * len(self.block) ++ past_cross_keys = [None] * len(self.block) ++ past_cross_values = [None] * len(self.block) ++ if attention_mask is None: ++ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) ++ ++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] ++ # ourselves in which case we just need to make it broadcastable to all heads. ++ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) ++ ++ # If a 2D or 3D attention mask is provided for the cross-attention ++ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] ++ ++ encoder_extended_attention_mask = None ++ ++ # Prepare head mask if needed ++ head_mask = self.get_head_mask(head_mask, self.config.num_layers) ++ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) ++ present_key_states = () if use_cache else None ++ present_value_states = () if use_cache else None ++ all_hidden_states = () if output_hidden_states else None ++ all_attentions = () if output_attentions else None ++ all_cross_attentions = () if (output_attentions and self.is_decoder) else None ++ position_bias = None ++ encoder_decoder_position_bias = None ++ ++ hidden_states = self.dropout(inputs_embeds) ++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)): ++ layer_head_mask = head_mask[i] ++ cross_attn_layer_head_mask = cross_attn_head_mask[i] ++ if output_hidden_states: ++ all_hidden_states = all_hidden_states + (hidden_states,) ++ ++ layer_outputs = layer_module( ++ hidden_states, ++ attention_mask=extended_attention_mask, ++ position_bias=position_bias, ++ encoder_hidden_states=None, ++ encoder_attention_mask=encoder_extended_attention_mask, ++ encoder_decoder_position_bias=encoder_decoder_position_bias, ++ layer_head_mask=layer_head_mask, ++ cross_attn_layer_head_mask=cross_attn_layer_head_mask, ++ past_key=past_key, ++ past_value=past_value, ++ past_cross_key=past_cross_key, ++ past_cross_value=past_cross_value, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ ) ++ ++ # layer_outputs is a tuple with: ++ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) ++ if use_cache is False: ++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:] ++ ++ hidden_states, present_key_state, present_value_state = layer_outputs[:3] ++ ++ # We share the position biases between the layers - the first layer store them ++ # layer_outputs = hidden-states, key-value-states (self-attention position bias), 
(self-attention weights), ++ # (cross-attention position bias), (cross-attention weights) ++ position_bias = layer_outputs[3] ++ if self.is_decoder and encoder_hidden_states is not None: ++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4] ++ # append next layer key value states ++ if use_cache: ++ present_key_states = present_key_states + present_key_state ++ present_value_states = present_value_states + present_value_state ++ ++ if output_attentions: ++ all_attentions = all_attentions + (layer_outputs[3],) ++ if self.is_decoder: ++ all_cross_attentions = all_cross_attentions + (layer_outputs[5],) ++ ++ # Model Parallel: If it's the last layer for that device, put things on the next device ++ if self.model_parallel: ++ for k, v in self.device_map.items(): ++ if i == v[-1] and "cuda:" + str(k) != self.last_device: ++ hidden_states = hidden_states.to("cuda:" + str(k + 1)) ++ ++ hidden_states = self.final_layer_norm(hidden_states) ++ hidden_states = self.dropout(hidden_states).half() ++ ++ # Add last layer ++ if output_hidden_states: ++ all_hidden_states = all_hidden_states + (hidden_states,) ++ ++ if not return_dict: ++ return tuple( ++ v ++ for v in [ ++ hidden_states, ++ present_key_value_states, ++ all_hidden_states, ++ all_attentions, ++ all_cross_attentions, ++ ] ++ if v is not None ++ ) ++ # present_key_value_states = torch.concat(present_key_value_states).reshape(len(self.block),2,*present_key_value_states[0].shape).half() if use_cache else None ++ if not self.is_decoder: ++ cross_keys = None ++ cross_values = None ++ if self.encodecrosskey: ++ cross_keys = self.encodecrosskey(hidden_states) ++ if self.encodecrossvalue: ++ cross_values = self.encodecrossvalue(hidden_states) ++ return tuple((hidden_states, cross_keys, cross_values)) + + + MT5_START_DOCSTRING = r""" +@@ -1549,6 +1957,39 @@ class MT5Model(MT5PreTrainedModel): + ) + + ++class EncoderToCrossKey(nn.Module): ++ def __init__(self, cross_key, num_heads, d_kv): ++ super().__init__() ++ self.cross_key = cross_key ++ self.num_heads = num_heads ++ self.d_kv = d_kv ++ ++ ++ def forward(self, hidden_states): ++ batch_size = hidden_states.shape[0] ++ past_cross_keys = () ++ for i in range(len(self.cross_key)): ++ past_cross_keys += (self.cross_key[i](hidden_states),) ++ # import pdb ++ # pdb.set_trace() ++ return past_cross_keys ++ ++ ++class EncoderToCrossValue(nn.Module): ++ def __init__(self, cross_value, num_heads, d_kv): ++ super().__init__() ++ self.cross_value = cross_value ++ self.num_heads = num_heads ++ self.d_kv = d_kv ++ ++ ++ def forward(self, hidden_states): ++ batch_size = hidden_states.shape[0] ++ past_cross_values = () ++ for i in range(len(self.cross_value)): ++ past_cross_values += (self.cross_value[i](hidden_states),) ++ return past_cross_values ++ + @add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING) + class MT5ForConditionalGeneration(MT5PreTrainedModel): + r""" +@@ -1573,33 +2014,52 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5 +- def __init__(self, config: MT5Config): ++ def __init__(self, config: MT5Config, encoder_path=None, decoder_path=None, device_id=0): + super().__init__(config) +- self.model_dim = config.d_model +- +- self.shared = nn.Embedding(config.vocab_size, config.d_model) +- +- encoder_config = 
copy.deepcopy(config) +- encoder_config.is_decoder = False +- encoder_config.use_cache = False +- encoder_config.is_encoder_decoder = False +- self.encoder = MT5Stack(encoder_config, self.shared) +- +- decoder_config = copy.deepcopy(config) +- decoder_config.is_decoder = True +- decoder_config.is_encoder_decoder = False +- decoder_config.num_layers = config.num_decoder_layers +- self.decoder = MT5Stack(decoder_config, self.shared) +- +- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) ++ self.encoder_path = encoder_path ++ self.decoder_path = decoder_path ++ self.is_mindie = False ++ if not self.encoder_path or not self.decoder_path: ++ self.model_dim = config.d_model ++ ++ self.shared = nn.Embedding(config.vocab_size, config.d_model) ++ ++ decoder_config = copy.deepcopy(config) ++ decoder_config.is_decoder = True ++ decoder_config.is_encoder_decoder = False ++ decoder_config.num_layers = config.num_decoder_layers ++ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) ++ self.decoder = MT5Stack(decoder_config, self.shared, self.lm_head) ++ cross_key = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.k for i in range(config.num_decoder_layers)) ++ cross_value = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.v for i in range(config.num_decoder_layers)) ++ encodecrosskey = EncoderToCrossKey(cross_key, config.num_heads, config.d_kv) ++ encodecrossvalue = EncoderToCrossValue(cross_value, config.num_heads, config.d_kv) ++ encoder_config = copy.deepcopy(config) ++ encoder_config.is_decoder = False ++ encoder_config.use_cache = False ++ encoder_config.is_encoder_decoder = False ++ self.encoder = MT5Stack_Encoder(encoder_config, self.shared, encodecrosskey=encodecrosskey, encodecrossvalue=encodecrossvalue) ++ self.encoder_mindie = None ++ self.decoder_mindie = None ++ if self.encoder_path: ++ self.encoder_mindie = torch.jit.load(self.encoder_path) ++ self.is_mindie = True ++ if self.decoder_path: ++ self.decoder_mindie = torch.jit.load(self.decoder_path) ++ self.stream = torch.npu.Stream(f"npu:{device_id}") ++ self.device_id = device_id + + # Initialize weights and apply final processing +- self.post_init() ++ if not self.is_mindie: ++ self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + ++ def get_device(self): ++ return f"npu:{self.device_id}" ++ + @add_start_docstrings(PARALLELIZE_DOCSTRING) + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize + def parallelize(self, device_map=None): +@@ -1666,25 +2126,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): + @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with T5->MT5, t5->mt5 +- def forward( +- self, +- input_ids: Optional[torch.LongTensor] = None, +- attention_mask: Optional[torch.FloatTensor] = None, +- decoder_input_ids: Optional[torch.LongTensor] = None, +- decoder_attention_mask: Optional[torch.BoolTensor] = None, +- head_mask: Optional[torch.FloatTensor] = None, +- decoder_head_mask: Optional[torch.FloatTensor] = None, +- cross_attn_head_mask: Optional[torch.Tensor] = None, +- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, +- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, +- inputs_embeds: Optional[torch.FloatTensor] = None, +- decoder_inputs_embeds: 
Optional[torch.FloatTensor] = None, +- labels: Optional[torch.LongTensor] = None, +- use_cache: Optional[bool] = None, +- output_attentions: Optional[bool] = None, +- output_hidden_states: Optional[bool] = None, +- return_dict: Optional[bool] = None, +- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ++ def forward(self,*args) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., +@@ -1716,114 +2158,37 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" +- use_cache = use_cache if use_cache is not None else self.config.use_cache +- return_dict = return_dict if return_dict is not None else self.config.use_return_dict +- +- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +- if head_mask is not None and decoder_head_mask is None: +- if self.config.num_layers == self.config.num_decoder_layers: +- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) +- decoder_head_mask = head_mask +- +- # Encode if needed (training, first prediction pass) +- if encoder_outputs is None: +- # Convert encoder inputs in embeddings if needed +- encoder_outputs = self.encoder( +- input_ids=input_ids, +- attention_mask=attention_mask, +- inputs_embeds=inputs_embeds, +- head_mask=head_mask, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) +- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): +- encoder_outputs = BaseModelOutput( +- last_hidden_state=encoder_outputs[0], +- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, +- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, +- ) +- +- hidden_states = encoder_outputs[0] +- +- if self.model_parallel: +- torch.cuda.set_device(self.decoder.first_device) +- +- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: +- # get decoder inputs from shifting lm labels to the right +- decoder_input_ids = self._shift_right(labels) +- +- # Set device for model parallelism +- if self.model_parallel: +- torch.cuda.set_device(self.decoder.first_device) +- hidden_states = hidden_states.to(self.decoder.first_device) +- if decoder_input_ids is not None: +- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) +- if attention_mask is not None: +- attention_mask = attention_mask.to(self.decoder.first_device) +- if decoder_attention_mask is not None: +- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) +- +- # Decode +- decoder_outputs = self.decoder( +- input_ids=decoder_input_ids, +- attention_mask=decoder_attention_mask, +- inputs_embeds=decoder_inputs_embeds, +- past_key_values=past_key_values, +- encoder_hidden_states=hidden_states, +- encoder_attention_mask=attention_mask, +- head_mask=decoder_head_mask, +- cross_attn_head_mask=cross_attn_head_mask, +- use_cache=use_cache, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) +- +- sequence_output = decoder_outputs[0] +- +- # Set device for model parallelism +- if self.model_parallel: +- torch.cuda.set_device(self.encoder.first_device) +- self.lm_head = self.lm_head.to(self.encoder.first_device) +- 
sequence_output = sequence_output.to(self.lm_head.weight.device) +- +- if self.config.tie_word_embeddings: +- # Rescale output before projecting on vocab +- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 +- sequence_output = sequence_output * (self.model_dim**-0.5) +- +- lm_logits = self.lm_head(sequence_output) ++ if self.is_mindie: ++ with torch.npu.stream(self.stream): # set stream ++ decoder_outputs = self.decoder_mindie.forward(*args) ++ self.stream.synchronize() # synchronize ++ else: ++ hidden_states = args[0] ++ past_cross_keys = args[1:self.config.num_decoder_layers+1] ++ past_cross_values = args[self.config.num_decoder_layers+1:2*self.config.num_decoder_layers+1] ++ past_keys= args[2*self.config.num_decoder_layers+1:3*self.config.num_decoder_layers+1] ++ past_values= args[3*self.config.num_decoder_layers+1:4*self.config.num_decoder_layers+1] ++ encoder_attention_mask = args[-2] ++ decoder_input_ids = args[-1] ++ decoder_outputs = self.decoder(input_ids=decoder_input_ids, ++ encoder_hidden_states=hidden_states, ++ past_keys=past_keys, ++ past_values=past_values, ++ past_cross_keys=past_cross_keys, ++ past_cross_values=past_cross_values, ++ encoder_attention_mask=encoder_attention_mask) ++ + + loss = None +- if labels is not None: +- loss_fct = CrossEntropyLoss(ignore_index=-100) +- # move labels to correct device to enable PP +- labels = labels.to(lm_logits.device) +- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) +- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 +- +- if not return_dict: +- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs +- return ((loss,) + output) if loss is not None else output +- +- return Seq2SeqLMOutput( +- loss=loss, +- logits=lm_logits, +- past_key_values=decoder_outputs.past_key_values, +- decoder_hidden_states=decoder_outputs.hidden_states, +- decoder_attentions=decoder_outputs.attentions, +- cross_attentions=decoder_outputs.cross_attentions, +- encoder_last_hidden_state=encoder_outputs.last_hidden_state, +- encoder_hidden_states=encoder_outputs.hidden_states, +- encoder_attentions=encoder_outputs.attentions, +- ) ++ return (decoder_outputs[0],decoder_outputs[1],decoder_outputs[2]) + +- # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, +- past_key_values=None, ++ past_cross_keys=None, ++ past_cross_values=None, ++ past_keys=None, ++ past_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, +@@ -1834,8 +2199,8 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used +- if past_key_values is not None: +- past_length = past_key_values[0][0].shape[2] ++ if past_keys is not None: ++ past_length = past_keys[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: +@@ -1848,7 +2213,10 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): + + return { + "decoder_input_ids": input_ids, +- "past_key_values": past_key_values, ++ "past_cross_keys":past_cross_keys, ++ "past_cross_values":past_cross_values, ++ "past_keys":past_keys, ++ "past_values":past_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, +@@ -1893,6 +2261,419 
@@ class MT5ForConditionalGeneration(MT5PreTrainedModel): + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + ++ def _prepare_encoder_decoder_kwargs_for_generation( ++ self, ++ inputs_tensor: torch.Tensor, ++ model_kwargs, ++ model_input_name, ++ generation_config, ++ ): ++ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] ++ encoder_kwargs = { ++ argument: value ++ for argument, value in model_kwargs.items() ++ if not any(argument.startswith(p) for p in irrelevant_prefix) ++ } ++ encoder_kwargs["output_attentions"] = generation_config.output_attentions ++ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states ++ model_input_name = model_input_name if model_input_name is not None else self.main_input_name ++ encoder_kwargs["return_dict"] = True ++ encoder_kwargs[model_input_name] = inputs_tensor ++ import time ++ start_time = time.time() ++ with torch.npu.stream(self.stream): # set stream ++ encoder_outputs=self.encoder_mindie.forward(encoder_kwargs["input_ids"]) ++ self.stream.synchronize() # synchronize ++ model_kwargs["encoder_outputs"]={"last_hidden_state":encoder_outputs[0]} ++ model_kwargs["past_cross_keys"] = encoder_outputs[1] ++ model_kwargs["past_cross_values"] =encoder_outputs[2] ++ return model_kwargs ++ ++ def _update_model_kwargs_for_generation( ++ self, ++ outputs, ++ model_kwargs, ++ is_encoder_decoder = False, ++ standardize_cache_format = False, ++ num_new_tokens = 1, ++ ): ++ # update past_key_values keeping its naming used in model code ++ cache_name, cache = self._extract_past_from_model_output( ++ outputs, standardize_cache_format=standardize_cache_format ++ ) ++ model_kwargs[cache_name] = cache ++ if "past_keys" in outputs: ++ past_keys = outputs.past_keys ++ model_kwargs["past_keys"] = past_keys ++ if "past_values" in outputs: ++ past_values = outputs.past_values ++ model_kwargs["past_values"] = past_values ++ # update decoder attention mask ++ if "decoder_attention_mask" in model_kwargs: ++ decoder_attention_mask = model_kwargs["decoder_attention_mask"] ++ model_kwargs["decoder_attention_mask"] = torch.cat( ++ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], ++ dim=-1, ++ ) ++ return model_kwargs ++ ++ @torch.no_grad() ++ def generate( ++ self, ++ inputs = None, ++ generation_config = None, ++ logits_processor = None, ++ stopping_criteria = None, ++ prefix_allowed_tokens_fn = None, ++ assistant_model = None, ++ negative_prompt_ids = None, ++ negative_prompt_attention_mask = None, ++ **kwargs, ++ ): ++ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call ++ import time ++ start_time = time.time() ++ self._validate_model_class() ++ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria ++ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) ++ self._validate_model_kwargs(model_kwargs.copy()) ++ ++ ++ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() ++ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() ++ ++ accepts_attention_mask = True ++ requires_attention_mask = "encoder_outputs" not in model_kwargs ++ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None ++ ++ # 3. 
Define model inputs ++ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( ++ inputs, generation_config.bos_token_id, model_kwargs ++ ) ++ batch_size = inputs_tensor.shape[0] ++ ++ device = inputs_tensor.device ++ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) ++ ++ # 4. Define other model kwargs ++ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are ++ # generating the first new token or not, and we only want to use the embeddings for the first new token) ++ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": ++ model_kwargs["use_cache"] = True ++ else: ++ model_kwargs["use_cache"] = generation_config.use_cache ++ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: ++ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( ++ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ++ ) ++ ++ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: ++ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` ++ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( ++ inputs_tensor, model_kwargs, model_input_name, generation_config ++ ) ++ ++ # 5. Prepare `input_ids` which will be used for auto-regressive generation ++ if self.config.is_encoder_decoder: ++ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( ++ batch_size=batch_size, ++ model_input_name=model_input_name, ++ model_kwargs=model_kwargs, ++ decoder_start_token_id=generation_config.decoder_start_token_id, ++ device=inputs_tensor.device, ++ ) ++ else: ++ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") ++ ++ if generation_config.token_healing: ++ input_ids = self.heal_tokens(input_ids, tokenizer) ++ ++ # 6. Prepare `max_length` depending on other stopping criteria. ++ input_ids_length = input_ids.shape[-1] ++ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None ++ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None ++ generation_config = self._prepare_generated_length( ++ generation_config=generation_config, ++ has_default_max_length=has_default_max_length, ++ has_default_min_length=has_default_min_length, ++ model_input_name=model_input_name, ++ inputs_tensor=inputs_tensor, ++ input_ids_length=input_ids_length, ++ ) ++ ++ use_dynamic_cache_by_default = False ++ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: ++ raise ValueError( ++ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " ++ "Cache object) is unsupported. Please use only one of the two." ++ ) ++ elif generation_config.cache_implementation is not None: ++ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: ++ if generation_config.cache_implementation == "static" and not self._supports_static_cache: ++ raise ValueError( ++ "This model does not support `cache_implementation='static'`. 
Please check the following " ++ "issue: https://github.com/huggingface/transformers/issues/28981" ++ ) ++ model_kwargs["past_key_values"] = self._get_cache( ++ generation_config.cache_implementation, ++ getattr(generation_config, "num_beams", 1) * batch_size, ++ generation_config.max_length, ++ ) ++ elif generation_config.cache_implementation == "quantized": ++ if not self._supports_quantized_cache: ++ raise ValueError( ++ "This model does not support the quantized cache. If you want your model to support quantized " ++ "cache, please open an issue." ++ ) ++ ++ cache_config = ( ++ generation_config.cache_config ++ if generation_config.cache_config is not None ++ else QuantizedCacheConfig() ++ ) ++ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] ++ ++ if cache_config.backend == "quanto" and not is_quanto_available(): ++ raise ImportError( ++ "You need to install `quanto` in order to use KV cache quantization with quanto backend. " ++ "Please install it via with `pip install quanto`" ++ ) ++ elif cache_config.backend == "HQQ" and not is_hqq_available(): ++ raise ImportError( ++ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " ++ "Please install it via with `pip install hqq`" ++ ) ++ ++ model_kwargs["past_key_values"] = cache_class(cache_config) ++ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that ++ # keeps copying the cache thus using much more memory ++ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): ++ past = model_kwargs.get("past_key_values", None) ++ if past is None: ++ model_kwargs["past_key_values"] = DynamicCache() ++ use_dynamic_cache_by_default = True ++ elif isinstance(past, tuple): ++ model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past) ++ use_dynamic_cache_by_default = True ++ ++ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) ++ ++ # 7. determine generation mode ++ generation_mode = generation_config.get_generation_mode(assistant_model) ++ # 8. prepare distribution pre_processing samplers ++ prepared_logits_processor = self._get_logits_processor( ++ generation_config=generation_config, ++ input_ids_seq_length=input_ids_length, ++ encoder_input_ids=inputs_tensor, ++ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, ++ logits_processor=logits_processor, ++ device=inputs_tensor.device, ++ model_kwargs=model_kwargs, ++ negative_prompt_ids=negative_prompt_ids, ++ negative_prompt_attention_mask=negative_prompt_attention_mask, ++ ) ++ ++ # 9. prepare stopping criteria ++ prepared_stopping_criteria = self._get_stopping_criteria( ++ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ++ ) ++ ++ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): ++ # 11. prepare logits warper ++ prepared_logits_warper = ( ++ self._get_logits_warper(generation_config, device=input_ids.device) ++ if generation_config.do_sample ++ else None ++ ) ++ ++ # 12. expand input_ids with `num_return_sequences` additional sequences per batch ++ input_ids, model_kwargs = self._expand_inputs_for_generation( ++ input_ids=input_ids, ++ expand_size=generation_config.num_return_sequences, ++ is_encoder_decoder=self.config.is_encoder_decoder, ++ **model_kwargs, ++ ) ++ # 13. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`) ++ result = self._sample( ++ input_ids, ++ logits_processor=prepared_logits_processor, ++ logits_warper=prepared_logits_warper, ++ stopping_criteria=prepared_stopping_criteria, ++ generation_config=generation_config, ++ **model_kwargs, ++ ) ++ return result ++ ++ def _sample( ++ self, ++ input_ids, ++ logits_processor, ++ stopping_criteria, ++ generation_config, ++ logits_warper = None, ++ **model_kwargs, ++ ): ++ # init values ++ pad_token_id = generation_config.pad_token_id ++ output_attentions = generation_config.output_attentions ++ output_hidden_states = generation_config.output_hidden_states ++ output_scores = generation_config.output_scores ++ output_logits = generation_config.output_logits ++ return_dict_in_generate = generation_config.return_dict_in_generate ++ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) ++ do_sample = generation_config.do_sample ++ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): ++ raise ValueError( ++ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " ++ f"{logits_warper})." ++ ) ++ ++ # init attention / hidden states / scores tuples ++ scores = () if (return_dict_in_generate and output_scores) else None ++ raw_logits = () if (return_dict_in_generate and output_logits) else None ++ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None ++ cross_attentions = () if (return_dict_in_generate and output_attentions) else None ++ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None ++ ++ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states ++ if return_dict_in_generate and self.config.is_encoder_decoder: ++ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None ++ encoder_hidden_states = ( ++ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ++ ) ++ ++ this_peer_finished = False ++ batch_size = input_ids.shape[0] ++ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) ++ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) ++ ++ # keep track of which sequences are already finished ++ if self.is_mindie or self.config.architectures[0]=="MT5ForConditionalGeneration": ++ num_layers = self.config.num_layers ++ num_heads = self.config.num_heads ++ d_kv = self.config.d_kv ++ model_kwargs["past_keys"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)] ++ model_kwargs["past_values"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)] ++ ++ ++ while self._has_unfinished_sequences(this_peer_finished, False, device=input_ids.device): ++ # prepare model inputs ++ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) ++ model_args = [model_kwargs["encoder_outputs"]["last_hidden_state"]] ++ model_args.extend(model_kwargs["past_cross_keys"]) ++ model_args.extend(model_kwargs["past_cross_values"]) ++ model_args.extend(model_inputs["past_keys"]) ++ model_args.extend(model_inputs["past_values"]) ++ model_args.append(model_inputs["attention_mask"]) ++ model_args.append(model_inputs["decoder_input_ids"]) ++ ++ # forward pass to get next token ++ outputs = self(*model_args) ++ outputs = Seq2SeqLMOutput(logits=outputs[0], ++ past_keys=outputs[1], 
++ past_values=outputs[2]) ++ ++ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration ++ # (the clone itself is always small) ++ next_token_logits = outputs.logits[:, -1, :].clone() ++ ++ # pre-process distribution ++ next_token_scores = logits_processor(input_ids, next_token_logits) ++ if do_sample: ++ next_token_scores = logits_warper(input_ids, next_token_scores) ++ ++ # Store scores, attentions and hidden_states when required ++ if return_dict_in_generate: ++ if output_scores: ++ scores += (next_token_scores,) ++ if output_logits: ++ raw_logits += (next_token_logits,) ++ if output_attentions: ++ decoder_attentions += ( ++ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ++ ) ++ if self.config.is_encoder_decoder: ++ cross_attentions += (outputs.cross_attentions,) ++ ++ if output_hidden_states: ++ decoder_hidden_states += ( ++ (outputs.decoder_hidden_states,) ++ if self.config.is_encoder_decoder ++ else (outputs.hidden_states,) ++ ) ++ ++ # token selection ++ if do_sample: ++ probs = nn.functional.softmax(next_token_scores, dim=-1) ++ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) ++ else: ++ next_tokens = torch.argmax(next_token_scores, dim=-1) ++ ++ # finished sentences should have their next token be a padding token ++ if has_eos_stopping_criteria: ++ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) ++ ++ # update generated ids, model inputs, and length for next step ++ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) ++ model_kwargs = self._update_model_kwargs_for_generation( ++ outputs, ++ model_kwargs, ++ is_encoder_decoder=self.config.is_encoder_decoder, ++ ) ++ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) ++ this_peer_finished = unfinished_sequences.max() == 0 ++ # This is needed to properly delete outputs.logits which may be very large for first iteration ++ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration ++ del outputs ++ return input_ids ++ ++ def invert_attention_mask(self, encoder_attention_mask): ++ if encoder_attention_mask.dim() == 3: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] ++ if encoder_attention_mask.dim() == 2: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] ++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility ++ ++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000 ++ ++ return encoder_extended_attention_mask ++ ++ @property ++ def device(self) -> torch.device: ++ """ ++ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same ++ device). 
++        """
++        return self.get_device()
++
++    def get_extended_attention_mask(
++        self, attention_mask, input_shape, device=None, dtype=None
++    ):
++        if dtype is None:
++            dtype = self.dtype
++
++        if not (attention_mask.dim() == 2 and self.config.is_decoder):
++            if device is not None:
++                warnings.warn(
++                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++                )
++        if attention_mask.dim() == 3:
++            extended_attention_mask = attention_mask[:, None, :, :]
++        elif attention_mask.dim() == 2:
++            if self.config.is_decoder:
++                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++                    input_shape, attention_mask, device
++                )
++            else:
++                extended_attention_mask = attention_mask[:, None, None, :]
++        else:
++            raise ValueError(
++                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++            )
++        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
++        extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++        return extended_attention_mask
++
+
+ @add_start_docstrings(
+     "The bare MT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
diff --git a/MindIE/MindIE-Torch/built-in/MT5/readme.md b/MindIE/MindIE-Torch/built-in/MT5/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..4862823913f94b04d6601e366d109b757012f524
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/readme.md
@@ -0,0 +1,96 @@
+# MT5 Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Model Inference](#section741711594517)
+
+
+
+# Overview
+
+   T5 (Text-to-Text Transfer Transformer) is a model architecture, or rather a paradigm for solving NLP tasks: classification, similarity scoring, text generation and other tasks are all handled within a single text-to-text framework. MT5 is its multilingual variant.
+   Weights download: https://huggingface.co/collections/google/mt5-release-65005f1a520f8d7b4d039509
+
+## Input and Output Data
+
+- Input data
+
+  | Input  | Shape                     | Data Type | Data Layout |
+  | ------ | ------------------------- | --------- | ----------- |
+  | input  | batchsize x input_seq_len | INT64     | ND          |
+
+
+- Output data
+
+  | Output | Shape                      | Data Type | Data Layout |
+  | ------ | -------------------------- | --------- | ----------- |
+  | output | batchsize x output_seq_len | INT64     | ND          |
+
+
+# Inference Environment
+
+- The model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Dependency | Version | Notes                                                 |
+  | ---------- | ------- | ----------------------------------------------------- |
+  | Python     | 3.10.2  | -                                                     |
+  | torch      | 2.1.0   | version required for exporting the traced .pt model   |
+  | torch_npu  | 2.1.0   | version required for model compilation and inference  |
+
+
+# Quick Start
+
+
+1. Install transformers 4.42.0.
+   ```bash
+   pip3 install transformers==4.42.0
+   ```
+
+2. Install the MindIE package. It must be used together with torch_npu; set up the environment according to the mindietorch / torch_npu compatibility documentation.
+
+   ```bash
+   # install mindie
+   chmod +x ./Ascend-mindie_xxx.run
+   ./Ascend-mindie_xxx.run --install
+   source /usr/local/Ascend/mindie/set_env.sh
+   ```
+
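+   After installation, you can optionally confirm that the NPU software stack is importable before patching. This is only a minimal sketch; it assumes nothing beyond the packages installed in steps 1 and 2 and the standard torch_npu device API:
+
+   ```python
+   # Minimal environment check (assumes torch_npu and mindietorch from the steps above)
+   import torch
+   import torch_npu      # registers the torch.npu backend used by the scripts below
+   import mindietorch    # MindIE-Torch compile front end
+   import transformers
+
+   print("transformers:", transformers.__version__)   # the patch step expects 4.42.0
+   print("NPU available:", torch.npu.is_available())
+   ```
+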
+3. Apply the code patch. Run the following command in the MT5 directory:
+
+   ```bash
+   python MT5_modeling_patch.py
+   ```
+4. Export the MindIE-Torch models.
+   ```bash
+   python export_mt5.py --output_dir {output_path} --model_path {model_path} --max_batchsize {max_batchsize} --max_input_seq_len {max_input_seq_len} --device_id {device_id}
+   ```
+Parameter description:
+{output_path}: output directory
+{model_path}: directory containing the model weights
+{max_batchsize}: maximum batch size used during inference
+{max_input_seq_len}: maximum input sequence length used during inference
+{device_id}: which NPU device to use
+
+Running this command generates the optimized encoder and decoder models.
+
+5. Accuracy test
+
+   ```bash
+   python test_mt5.py --hf_model_path {model_path} --encoder_aie_path {encoder_aie_path} --decoder_aie_path {decoder_aie_path} --device_id {device_id}
+   ```
+
+Parameter description:
+{model_path}: directory containing the model weights
+{encoder_aie_path}: path to the optimized encoder model (the specific .pt file)
+{decoder_aie_path}: path to the optimized decoder model (the specific .pt file)
+{device_id}: which NPU device to use
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/MT5/test_mt5.py b/MindIE/MindIE-Torch/built-in/MT5/test_mt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..92717df66f9e0df305d6a1fd9df3a71dcb9fb2f5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/test_mt5.py
@@ -0,0 +1,48 @@
+import torch
+import time
+import argparse
+import torch_npu
+from transformers import MT5ForConditionalGeneration, AutoTokenizer, MT5Config
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--hf_model_path", type=str, required=True)
+
+    parser.add_argument("--encoder_aie_path", type=str, required=True)
+    parser.add_argument("--decoder_aie_path", type=str, required=True)
+
+    parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = parse_args()
+    torch.npu.set_device(args.device_id)
+    model = MT5ForConditionalGeneration.from_pretrained(args.hf_model_path, torch_dtype=torch.float16).npu()
+    encoder = model.encoder
+    decoder = model.decoder
+    encoder_input = torch.randint(0, 2000, (8, 10), dtype=torch.int64).npu()
+    t5_config = MT5Config.from_pretrained(args.hf_model_path)
+
+    encoder_output = encoder(encoder_input)[0]
+    model = MT5ForConditionalGeneration(config=t5_config,
+                                        encoder_path=args.encoder_aie_path,
+                                        decoder_path=args.decoder_aie_path,
+                                        device_id=args.device_id).half().npu()
+
+    encoder_mindie = model.encoder_mindie
+    decoder_mindie = model.decoder_mindie
+    mindie_stream = model.stream
+    with torch.npu.stream(mindie_stream):  # set stream
+        mindie_encoder_output = encoder_mindie(encoder_input)[0]
+    mindie_stream.synchronize()  # synchronize
+    if torch.cosine_similarity(encoder_output.cpu().flatten(), mindie_encoder_output.cpu().flatten(), dim=0) < 0.99:
+        print("encoder precision failed")
+    else:
+        print("test OK")
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/MindIE/MindIE-Torch/built-in/T5/T5_modeling_t5_patch.py b/MindIE/MindIE-Torch/built-in/T5/T5_modeling_t5_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6733e6904cdc4acd5ef4589a380fee0f71c9447
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/T5_modeling_t5_patch.py
@@ -0,0 +1,40 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import transformers +import argparse + + +def main(args): + transformers_path = transformers.__path__ + transformers_version = transformers.__version__ + + assert transformers_version =='4.42.0', "expectation transformers==4.42.0" + if args.ascend_soc == "Ascend910B4": + os.system(f'patch -p0 {transformers_path[0]}/models/t5/modeling_t5.py modeling_t5_800IA2.patch') + elif args.ascend_soc == "Ascend310P3": + os.system(f'patch -p0 {transformers_path[0]}/models/t5/modeling_t5.py modeling_t5.patch') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ascend_soc", type=str, default="Ascend910B4",required=True) + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/MindIE/MindIE-Torch/built-in/T5/export_t5.py b/MindIE/MindIE-Torch/built-in/T5/export_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..9c67b7c7efa4bc037fcf5622c8c4610944a5b088 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/T5/export_t5.py @@ -0,0 +1,194 @@ + +import torch +import torch_npu +import argparse +import os +import math +import mindietorch +from transformers import T5ForConditionalGeneration + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="save dir" + ) + parser.add_argument( + "--model_path", + type=str, + default="./T5-Small", + help="T5 model path" + ) + parser.add_argument( + "--max_batchsize", + type=int, + default=1, + help="max batchsize when running" + ) + + parser.add_argument( + "--max_input_seq_len", + type=int, + default=256, + help="max input_sequence length when running" + ) + + + parser.add_argument( + "--device_id", + type=int, + default=0, + help="npu device id" + ) + return parser.parse_args() + + +class TextEncoderExport(torch.nn.Module): + def __init__(self, textencoder_model): + super(TextEncoderExport, self).__init__() + self.textencoder_model = textencoder_model + + def forward(self, input_ids,attention_mask): + return self.textencoder_model(input_ids=input_ids, attention_mask=attention_mask) + +class TextDecoderExport(torch.nn.Module): + def __init__(self, textdecoder_model): + super(TextDecoderExport, self).__init__() + self.textdecoder_model = textdecoder_model + + def forward(self, + *args): + return self.textdecoder_model(*args) + +def export_textencoder(args, model, save_dir, batch_size): + encoder_path = os.path.join(save_dir, "encoder") + if not os.path.exists(encoder_path): + os.makedirs(encoder_path, mode=0o640) + traced_path = os.path.join(encoder_path, "encoder.pt") + compiled_path = os.path.join(encoder_path, "encoder_compiled.pt") + if not os.path.exists(traced_path): + text_encoder = model.encoder + dummy_input = ( + torch.ones([1, 128], dtype=torch.int64).npu(), + torch.ones([1, 128], dtype=torch.int64).npu() + ) + encoder = TextEncoderExport(text_encoder) + encoder.eval() + torch.jit.trace(encoder, dummy_input, strict=False).save(traced_path) + if not os.path.exists(compiled_path): + traced_model = torch.jit.load(traced_path).eval() + + inputs0 = [] + inputs0.append(mindietorch.Input(min_shape = (1,1), max_shape= (args.max_batchsize, args.max_input_seq_len), dtype=torch.int64)) + inputs0.append(mindietorch.Input(min_shape = (1,1), max_shape= (args.max_batchsize, args.max_input_seq_len), dtype=torch.int64)) + print("compiling encoder") + 
compiled_model = mindietorch.compile( + traced_model, + inputs=inputs0, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version="Ascend310P3", + optimization_level=0 + ) + compiled_model.save(compiled_path) + +def export_textdecoder(args, model, save_dir, batch_size): + decoder_path = os.path.join(save_dir, "decoder") + if not os.path.exists(decoder_path): + os.makedirs(decoder_path, mode=0o640) + traced_path = os.path.join(decoder_path, "decoder.pt") + compiled_path = os.path.join(decoder_path, "decoder_compiled.pt") + model_path = args.model_path + max_lenth = 120 + if not os.path.exists(traced_path): + text_decoder = model + all_past_keys = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers + all_past_values = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers + all_past_cross_keys = [torch.randn([1, 16, model.config.d_model]).to(torch.float16).npu()] * model.config.num_layers + all_past_cross_values = [torch.randn([1, 16, model.config.d_model]).to(torch.float16).npu()] * model.config.num_layers + dummy_input = [torch.randn(1, 16, model.config.d_model).to(torch.float16).npu()] + dummy_input.extend(all_past_cross_keys) + dummy_input.extend(all_past_cross_values) + dummy_input.extend(all_past_keys) + dummy_input.extend(all_past_values) + dummy_input.append(torch.ones(1,16).npu()) + dummy_input.append(torch.ones([1, 1], dtype=torch.int64).npu()) + decoder = TextDecoderExport(text_decoder).npu() + decoder.eval() + torch.jit.trace(decoder, dummy_input,strict=False).save(traced_path) + if not os.path.exists(compiled_path): + traced_model = torch.jit.load(traced_path).eval() + print("compiling decoder") + input_info = [mindietorch.Input(min_shape =(1, 1, model.config.d_model), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model), + dtype=mindietorch.dtype.FLOAT16)] + past_cross_key_infos = [mindietorch.Input(min_shape =(1, 1, model.config.d_model), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + past_cross_value_infos = [mindietorch.Input(min_shape =(1, 1, model.config.d_model), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + past_key_infos = [mindietorch.Input(min_shape =(1, model.config.num_heads, 0, model.config.d_kv), + max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + past_value_infos = [mindietorch.Input(min_shape =(1, model.config.num_heads, 0, model.config.d_kv), + max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + decoder_input_ids_info = [mindietorch.Input(min_shape =(1, 1), + max_shape = (args.max_batchsize,1), + dtype=mindietorch.dtype.INT64)] + encoder_attention_mask_info = [mindietorch.Input(min_shape =(1, 1), + max_shape = (args.max_batchsize,args.max_input_seq_len), + dtype=mindietorch.dtype.INT64)] + input_info.extend(past_cross_key_infos) + input_info.extend(past_cross_value_infos) + input_info.extend(past_key_infos) + input_info.extend(past_value_infos) + input_info.extend(encoder_attention_mask_info) + 
input_info.extend(decoder_input_ids_info) + buffer = [] + for _ in range(2*model.config.num_layers): + buffer.append(math.ceil((args.max_batchsize * args.max_input_seq_len * model.config.d_model * 2) / 1024 / 1024)) + buffer_size0 = math.ceil((args.max_batchsize * 1 * model.config.vocab_size * 4) / 1024 / 1024) + buffer.append(buffer_size0) + print("buffer=",buffer) + compiled_model = mindietorch.compile( + traced_model, + inputs=input_info, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version="Ascend310P3", + default_buffer_size_vec=buffer, + optimization_level=0 + ) + compiled_model.save(compiled_path) + + +def main(): + args = parse_arguments() + device_id = args.device_id + save_dir = args.output_dir + torch.npu.set_device(device_id) + batch_size = 1 + model = T5ForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.float).npu() + encoder_path = os.path.join(save_dir, "encoder") + compiled_path = os.path.join(encoder_path, "encoder_compiled.pt") + if not os.path.exists(compiled_path): + export_textencoder(args, model, save_dir, batch_size) + print("export encoder_model done!") + + decoder_path = os.path.join(save_dir, "decoder") + compiled_path = os.path.join(decoder_path, "decoder_compiled.pt") + if not os.path.exists(compiled_path): + export_textdecoder(args, model, save_dir, batch_size) + print("export decoder_model done!") + + +if __name__ == "__main__": + main() diff --git a/MindIE/MindIE-Torch/built-in/T5/export_t5_800IA2.py b/MindIE/MindIE-Torch/built-in/T5/export_t5_800IA2.py new file mode 100644 index 0000000000000000000000000000000000000000..e150e8e93a644539df463321f279fdd6929d7312 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/T5/export_t5_800IA2.py @@ -0,0 +1,202 @@ + +import torch +import torch_npu +import argparse +import os +import math +import mindietorch +from transformers import T5ForConditionalGeneration + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="save dir" + ) + parser.add_argument( + "--model_path", + type=str, + default="./T5-Small", + help="T5 model path" + ) + parser.add_argument( + "--max_batchsize", + type=int, + default=1, + help="max batchsize when running" + ) + + parser.add_argument( + "--max_input_seq_len", + type=int, + default=256, + help="max input_sequence length when running" + ) + + + parser.add_argument( + "--device_id", + type=int, + default=0, + help="npu device id" + ) + return parser.parse_args() + + +class TextEncoderExport(torch.nn.Module): + def __init__(self, textencoder_model): + super(TextEncoderExport, self).__init__() + self.textencoder_model = textencoder_model + + def forward(self, input_ids,attention_mask): + return self.textencoder_model(input_ids=input_ids,attention_mask=attention_mask) + +class TextDecoderExport(torch.nn.Module): + def __init__(self, textdecoder_model): + super(TextDecoderExport, self).__init__() + self.textdecoder_model = textdecoder_model + + def forward(self, + *args): + return self.textdecoder_model(*args) + +def export_textencoder(args, model, save_dir, batch_size): + encoder_path = os.path.join(save_dir, "encoder") + if not os.path.exists(encoder_path): + os.makedirs(encoder_path, mode=0o640) + traced_path = os.path.join(encoder_path, "encoder.pt") + compiled_path = os.path.join(encoder_path, "encoder_compiled.pt") + if not os.path.exists(traced_path): + text_encoder = 
model.encoder + dummy_input = ( + torch.ones([1, 128], dtype=torch.int64).npu(), + torch.ones([1, 1,128,128], dtype=torch.bool).npu() + ) + encoder = TextEncoderExport(text_encoder) + encoder.eval() + torch.jit.trace(encoder, dummy_input, strict=False).save(traced_path) + if not os.path.exists(compiled_path): + traced_model = torch.jit.load(traced_path).eval() + + inputs0 = [] + inputs0.append(mindietorch.Input(min_shape = (1,1), max_shape= (args.max_batchsize, args.max_input_seq_len), dtype=torch.int64)) + inputs0.append(mindietorch.Input(min_shape = (1,1,1,1), max_shape= (args.max_batchsize, 1,args.max_input_seq_len,args.max_input_seq_len), dtype=torch.bool)) + print("compiling encoder") + compiled_model = mindietorch.compile( + traced_model, + inputs=inputs0, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version="Ascend910B4", + optimization_level=0 + ) + compiled_model.save(compiled_path) + +def export_textdecoder(args, model, save_dir, batch_size): + decoder_path = os.path.join(save_dir, "decoder") + if not os.path.exists(decoder_path): + os.makedirs(decoder_path, mode=0o640) + traced_path = os.path.join(decoder_path, "decoder.pt") + compiled_path = os.path.join(decoder_path, "decoder_compiled.pt") + model_path = args.model_path + max_lenth = 120 + if not os.path.exists(traced_path): + text_decoder = model + all_past_keys = [torch.randn([1, 1, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers + all_past_values = [torch.randn([1, 1, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers + all_past_cross_keys = [torch.randn([1, 16, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers + all_past_cross_values = [torch.randn([1, 16, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers + dummy_input = [torch.randn(1, 16, model.config.d_kv*model.config.num_heads).to(torch.float16).npu()] + dummy_input.extend(all_past_cross_keys) + dummy_input.extend(all_past_cross_values) + dummy_input.extend(all_past_keys) + dummy_input.extend(all_past_values) + # encoder_attention_mask + dummy_input.append(torch.ones((1,1,16,16),dtype=torch.bool).npu()) + # decoder_input_ids + dummy_input.append(torch.ones([1, 1], dtype=torch.int64).npu()) + dummy_input.append(torch.ones([1, 1, 1, 1], dtype=torch.bool).npu()) + # decoder_attention_mask + decoder = TextDecoderExport(text_decoder).npu() + decoder.eval() + torch.jit.trace(decoder, dummy_input,strict=False).save(traced_path) + if not os.path.exists(compiled_path): + traced_model = torch.jit.load(traced_path).eval() + print("compiling decoder") + input_info = [mindietorch.Input(min_shape =(1, 1, model.config.d_model), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model), + dtype=mindietorch.dtype.FLOAT16)] + past_cross_key_infos = [mindietorch.Input(min_shape =(1, 1, model.config.num_heads*model.config.d_kv), + max_shape=(args.max_batchsize,args.max_input_seq_len, model.config.num_heads*model.config.d_kv), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + past_cross_value_infos = [mindietorch.Input(min_shape =(1, 1, model.config.d_kv*model.config.num_heads), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_kv*model.config.num_heads), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + past_key_infos = 
[mindietorch.Input(min_shape =(1, 0, model.config.d_kv*model.config.num_heads), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_kv*model.config.num_heads), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + past_value_infos = [mindietorch.Input(min_shape =(1, 0, model.config.d_kv*model.config.num_heads), + max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_kv*model.config.num_heads), + dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers + decoder_input_ids_info = [mindietorch.Input(min_shape =(1, 1), + max_shape = (args.max_batchsize,1), + dtype=mindietorch.dtype.INT64)] + encoder_attention_mask_info = [mindietorch.Input(min_shape =(1, 1,1, 1), + max_shape = (args.max_batchsize, 1, args.max_input_seq_len,args.max_input_seq_len), + dtype=mindietorch.dtype.BOOL)] + decoder_attention_mask_info = [mindietorch.Input(min_shape =(1, 1,1,1), + max_shape = (args.max_batchsize,1,args.max_input_seq_len,args.max_input_seq_len), + dtype=mindietorch.dtype.BOOL)] + input_info.extend(past_cross_key_infos) + input_info.extend(past_cross_value_infos) + input_info.extend(past_key_infos) + input_info.extend(past_value_infos) + input_info.extend(encoder_attention_mask_info) + input_info.extend(decoder_input_ids_info) + input_info.extend(decoder_attention_mask_info) + buffer = [] + for _ in range(2*model.config.num_layers): + buffer.append(math.ceil((args.max_batchsize * args.max_input_seq_len * model.config.d_model * 2) / 1024 / 1024)) + buffer_size0 = math.ceil((args.max_batchsize * 1 * model.config.vocab_size * 4) / 1024 / 1024) + buffer.append(buffer_size0) + print("buffer=",buffer) + compiled_model = mindietorch.compile( + traced_model, + inputs=input_info, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version="Ascend910B4", + default_buffer_size_vec=buffer, + optimization_level=0 + ) + compiled_model.save(compiled_path) + + +def main(): + args = parse_arguments() + device_id = args.device_id + save_dir = args.output_dir + torch.npu.set_device(device_id) + batch_size = 1 + model = T5ForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.float).npu() + encoder_path = os.path.join(save_dir, "encoder") + compiled_path = os.path.join(encoder_path, "encoder_compiled.pt") + if not os.path.exists(compiled_path): + export_textencoder(args, model, save_dir, batch_size) + print("export encoder_model done!") + + decoder_path = os.path.join(save_dir, "decoder") + compiled_path = os.path.join(decoder_path, "decoder_compiled.pt") + if not os.path.exists(compiled_path): + export_textdecoder(args, model, save_dir, batch_size) + print("export decoder_model done!") + + +if __name__ == "__main__": + main() diff --git a/MindIE/MindIE-Torch/built-in/T5/main.py b/MindIE/MindIE-Torch/built-in/T5/main.py new file mode 100644 index 0000000000000000000000000000000000000000..6e20f1e05ed7c8458fcdebb882105277d445c809 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/T5/main.py @@ -0,0 +1,51 @@ +import torch +import torch_npu +import mindietorch +import time +import argparse +from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--hf_model_path", type=str, required=True) + + parser.add_argument("--encoder_aie_path", type=str, required=True) + parser.add_argument("--decoder_aie_path", type=str, required=True) + + 
parser.add_argument("--device_id", type=int, help="NPU device id", default=0) + + parser.add_argument("--performance", action="store_true") + + args = parser.parse_args() + return args + +def main(): + args = parse_args() + + torch.npu.set_device(args.device_id) + tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) + text = ["今年2月26日,阿富汗塔里班的最高领秀下令销毁全国范围内所有“非伊斯兰“的古文化遗产,其中包括矗立于巴米扬的世高(大界最约58米)的立式佛像。"] + t5_config = T5Config.from_pretrained(args.hf_model_path) + # model = T5ForConditionalGeneration.from_pretrained(args.hf_model_path).half().npu() + model = T5ForConditionalGeneration(config=t5_config, + encoder_path=args.encoder_aie_path, + decoder_path=args.decoder_aie_path, + device_id=args.device_id).half().npu() + input_ids = tokenizer(text, return_tensors = "pt", padding=True).input_ids + if args.performance: + input_ids = torch.randint(0,32000,(1,512)) + outputs = model.generate(input_ids.npu(),max_new_tokens=512) + print("token length : ", input_ids.shape) + start_time = time.time() + + outputs = model.generate(input_ids.npu(),max_new_tokens=512) + inference_time = time.time()-start_time + print("time_cost=", inference_time) + print("output token length : ", outputs[0].shape[0]) + print("throught output is : ", outputs[0].shape[0] / inference_time) + if not args.performance: + print(tokenizer.batch_decode(outputs, skip_special_tokens=True)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/T5/modeling_t5.patch b/MindIE/MindIE-Torch/built-in/T5/modeling_t5.patch new file mode 100644 index 0000000000000000000000000000000000000000..15f81df2a4a590bc64353f172b06f03a9371377f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/T5/modeling_t5.patch @@ -0,0 +1,1635 @@ +diff --git a/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5_origin.py b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py +index 224769f..24f868b 100644 +--- a/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5_origin.py ++++ b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py +@@ -19,8 +19,10 @@ import math + import os + import warnings + from typing import List, Optional, Tuple, Union +- ++from dataclasses import dataclass + import torch ++import torch_npu ++import mindietorch + from torch import nn + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +@@ -28,13 +30,12 @@ from ...activations import ACT2FN + from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, +- Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, + ) +-from ...modeling_utils import PreTrainedModel ++from ...modeling_utils import PreTrainedModel,ModuleUtilsMixin + from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer + from ...utils import ( + DUMMY_INPUTS, +@@ -47,7 +48,43 @@ from ...utils import ( + ) + from ...utils.model_parallel_utils import assert_device_map, get_device_map + from .configuration_t5 import T5Config ++from transformers.generation.logits_process import LogitsProcessorList ++from transformers.generation.stopping_criteria import StoppingCriteriaList ++from transformers.generation.configuration_utils import GenerationMode ++from transformers.utils.generic import ModelOutput ++ ++ ++@dataclass ++class Seq2SeqLMOutput(ModelOutput): ++ 
""" ++ Base class for model's outputs, with potential hidden states and attentions. + ++ Args: ++ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): ++ Sequence of hidden-states at the output of the last layer of the model. ++ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): ++ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + ++ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. ++ ++ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. ++ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): ++ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, ++ sequence_length)`. ++ ++ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention ++ heads. ++ """ ++ loss: Optional[torch.FloatTensor] = None ++ logits: torch.FloatTensor = None ++ past_keys: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ past_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None ++ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ encoder_last_hidden_state: Optional[torch.FloatTensor] = None ++ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None ++ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + logger = logging.get_logger(__name__) + +@@ -448,7 +485,10 @@ class T5Attention(nn.Module): + mask=None, + key_value_states=None, + position_bias=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, ++ past_cross_key=None, ++ past_cross_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, +@@ -464,12 +504,8 @@ class T5Attention(nn.Module): + + real_seq_length = seq_length + +- if past_key_value is not None: +- if len(past_key_value) != 2: +- raise ValueError( +- f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" +- ) +- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length ++ if past_key is not None: ++ real_seq_length += past_key.shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + +@@ -493,16 +529,17 @@ class T5Attention(nn.Module): + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: ++ past_key_value = shape(past_key_value) + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) +- elif past_key_value.shape[2] != key_value_states.shape[1]: +- # checking that the `sequence_length` of the `past_key_value` is the same as +- # the provided `key_value_states` to support prefix tuning +- # cross-attn +- # (batch_size, n_heads, seq_length, dim_per_head) +- hidden_states = shape(proj_layer(key_value_states)) ++ # elif past_key_value.shape[2] != key_value_states.shape[1]: ++ # # checking that the `sequence_length` of the `past_key_value` is the same as ++ # # the provided `key_value_states` to support prefix tuning ++ # # cross-attn ++ # # (batch_size, n_heads, seq_length, dim_per_head) ++ # hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value +@@ -513,17 +550,16 @@ class T5Attention(nn.Module): + + # get key/value states + key_states = project( +- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None ++ hidden_states, self.k, key_value_states, past_key if past_key is not None else None + ) + value_states = project( +- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None ++ hidden_states, self.v, key_value_states, past_value if past_value is not None else None + ) +- ++ # torch.ops.mindie.flash_attention_plugin(query_states, key_states, value_states,) + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 +- + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( +@@ -536,7 +572,7 @@ class T5Attention(nn.Module): + + # if key and values are already calculated + # we want only the last query position bias +- if past_key_value is not None: ++ if past_key is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: +@@ -548,7 +584,6 @@ class T5Attention(nn.Module): + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias +- + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores +@@ -564,18 +599,131 @@ class T5Attention(nn.Module): + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + +- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None +- outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) ++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None ++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None ++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None ++ outputs = 
(attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,) ++ ++ if output_attentions: ++ outputs = outputs + (attn_weights,) ++ return outputs ++ ++ ++class T5SelfAttention(T5Attention): ++ def __init__(self, config: T5Config, has_relative_attention_bias=False): ++ super().__init__(config, has_relative_attention_bias) ++ ++ def forward( ++ self, ++ hidden_states, ++ mask=None, ++ position_bias=None, ++ past_key=None, ++ past_value=None, ++ layer_head_mask=None, ++ use_cache=False, ++ output_attentions=False, ++ ): ++ """ ++ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). ++ """ ++ # Input is (batch_size, seq_length, dim) ++ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) ++ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) ++ batch_size, seq_length = hidden_states.shape[:2] ++ ++ real_seq_length = seq_length ++ ++ if past_key is not None: ++ real_seq_length += past_key.shape[2] ++ key_length = real_seq_length ++ def shape(states): ++ """projection""" ++ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) ++ ++ def unshape(states): ++ """reshape""" ++ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) ++ ++ def project(hidden_states, proj_layer, past_key_value): ++ """projects hidden states correctly to key/query states""" ++ if past_key_value is None: ++ # cross-attn ++ # (batch_size, n_heads, seq_length, dim_per_head) ++ hidden_states = shape(proj_layer(hidden_states)) + ++ if past_key_value is not None: ++ hidden_states = shape(proj_layer(hidden_states)) ++ hidden_states = torch.cat([past_key_value, hidden_states], dim=2) ++ return hidden_states ++ ++ # get query states ++ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) ++ ++ # get key/value states ++ key_states = project( ++ hidden_states, self.k, past_key if past_key is not None else None ++ ) ++ value_states = project( ++ hidden_states, self.v, past_value if past_value is not None else None ++ ) ++ # compute scores ++ scores = torch.matmul( ++ query_states, key_states.transpose(3, 2) ++ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 ++ if position_bias is None: ++ if not self.has_relative_attention_bias: ++ position_bias = torch.zeros( ++ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype ++ ) ++ if self.gradient_checkpointing and self.training: ++ position_bias.requires_grad = True ++ else: ++ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) ++ ++ # if key and values are already calculated ++ # we want only the last query position bias ++ if past_key is not None: ++ position_bias = position_bias[:, :, -hidden_states.size(1) :, :] ++ if mask is not None: ++ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) ++ ++ if self.pruned_heads: ++ mask = torch.ones(position_bias.shape[1]) ++ mask[list(self.pruned_heads)] = 0 ++ position_bias_masked = position_bias[:, mask.bool()] ++ else: ++ position_bias_masked = position_bias ++ scores += position_bias_masked ++ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( ++ scores ++ ) # (batch_size, n_heads, seq_length, key_length) ++ attn_weights = nn.functional.dropout( ++ attn_weights, p=self.dropout, training=self.training ++ ) # (batch_size, n_heads, 
seq_length, key_length) ++ ++ # Mask heads if we want to ++ if layer_head_mask is not None: ++ attn_weights = attn_weights * layer_head_mask ++ ++ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) ++ attn_output = self.o(attn_output) ++ ++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None ++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None ++ present_value_state = (value_states.half(), ) if (self.is_decoder and use_cache) else None ++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,) + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + ++ ++ + class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() +- self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) ++ self.SelfAttention = T5SelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + +@@ -585,7 +733,8 @@ class T5LayerSelfAttention(nn.Module): + attention_mask=None, + position_bias=None, + layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, + use_cache=False, + output_attentions=False, + ): +@@ -595,7 +744,8 @@ class T5LayerSelfAttention(nn.Module): + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +@@ -618,7 +768,8 @@ class T5LayerCrossAttention(nn.Module): + attention_mask=None, + position_bias=None, + layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, + use_cache=False, + query_length=None, + output_attentions=False, +@@ -630,7 +781,8 @@ class T5LayerCrossAttention(nn.Module): + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, +@@ -661,39 +813,34 @@ class T5Block(nn.Module): + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, ++ past_cross_key=None, ++ past_cross_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): +- if past_key_value is not None: +- if not self.is_decoder: +- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") +- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 +- +- if len(past_key_value) != expected_num_past_key_values: +- raise ValueError( +- f"There should be {expected_num_past_key_values} past states. " +- f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" +- f"Got {len(past_key_value)} past key / value states" +- ) +- +- self_attn_past_key_value = past_key_value[:2] +- cross_attn_past_key_value = past_key_value[2:] ++ if past_key is not None: ++ self_attn_past_key = past_key ++ self_attn_past_value = past_value ++ cross_attn_past_key = past_cross_key ++ cross_attn_past_value = past_cross_value + else: +- self_attn_past_key_value, cross_attn_past_key_value = None, None ++ self_attn_past_key, self_attn_past_value, cross_attn_past_key, cross_attn_past_value = None, None, None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=self_attn_past_key_value, ++ past_key=self_attn_past_key, ++ past_value=self_attn_past_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +- hidden_states, present_key_value_state = self_attention_outputs[:2] +- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights ++ hidden_states, present_key_state, present_value_state = self_attention_outputs[:3] ++ attention_outputs = self_attention_outputs[3:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: +@@ -706,22 +853,23 @@ class T5Block(nn.Module): + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: ++ + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here +- if present_key_value_state is not None: +- query_length = present_key_value_state[0].shape[2] ++ if present_key_state is not None: ++ query_length = present_key_state[0].shape[2] + else: + query_length = None +- + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, +- past_key_value=cross_attn_past_key_value, ++ past_key=cross_attn_past_key, ++ past_value=cross_attn_past_value, + query_length=query_length, +- use_cache=use_cache, ++ use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] +@@ -736,11 +884,9 @@ class T5Block(nn.Module): + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states +- if present_key_value_state is not None: +- present_key_value_state = present_key_value_state + cross_attention_outputs[1] +- ++ # cross_attn_past_key_values = cross_attention_outputs[1] + # Keep cross-attention outputs and relative position weights +- attention_outputs = attention_outputs + cross_attention_outputs[2:] ++ attention_outputs = attention_outputs + cross_attention_outputs[3:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) +@@ -757,7 +903,7 @@ class T5Block(nn.Module): + outputs = (hidden_states,) + + if use_cache: +- outputs = outputs + (present_key_value_state,) + attention_outputs ++ outputs = outputs + (present_key_state,) +(present_value_state,)+ attention_outputs + else: + outputs = outputs + attention_outputs + +@@ -897,11 +1043,15 @@ class T5PreTrainedModel(PreTrainedModel): + + + class T5Stack(T5PreTrainedModel): +- def __init__(self, config, embed_tokens=None): ++ def __init__(self, config, 
embed_tokens=None,lm_head=None, encodecrosskey=None, encodecrossvalue=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder ++ self.lm_head=lm_head ++ self.encodecrosskey = encodecrosskey ++ self.encodecrossvalue = encodecrossvalue ++ self.model_dim = config.d_model + + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] +@@ -966,20 +1116,63 @@ class T5Stack(T5PreTrainedModel): + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + ++ def invert_attention_mask(self, encoder_attention_mask): ++ if encoder_attention_mask.dim() == 3: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] ++ if encoder_attention_mask.dim() == 2: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] ++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility ++ ++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000 ++ ++ return encoder_extended_attention_mask ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, device=None, dtype=None ++ ): ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ + def forward( + self, + input_ids=None, +- attention_mask=None, + encoder_hidden_states=None, ++ past_keys=None, ++ past_values=None, ++ past_cross_keys=None, ++ past_cross_values=None, + encoder_attention_mask=None, ++ attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, +- past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ++ **model_kwargs + ): + # Model parallel + if self.model_parallel: +@@ -998,8 +1191,10 @@ class T5Stack(T5PreTrainedModel): + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: ++ + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) ++ input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: +@@ -1012,18 +1207,19 @@ class T5Stack(T5PreTrainedModel): + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape +- + # required mask seq length can be calculated via length of past +- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length ++ mask_seq_length = past_keys[0].shape[2] + seq_length if past_keys is not None else seq_length + + if use_cache is True: + if 
not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values with `None` if past does not exist +- if past_key_values is None: +- past_key_values = [None] * len(self.block) +- ++ if not self.is_decoder: ++ past_keys = [None] * len(self.block) ++ past_values = [None] * len(self.block) ++ past_cross_keys = [None] * len(self.block) ++ past_cross_values = [None] * len(self.block) + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + +@@ -1054,7 +1250,8 @@ class T5Stack(T5PreTrainedModel): + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) +- present_key_value_states = () if use_cache else None ++ present_key_states = () if use_cache else None ++ present_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None +@@ -1062,8 +1259,8 @@ class T5Stack(T5PreTrainedModel): + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) +- +- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): ++ # for i, layer_module in enumerate(self.block): ++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel +@@ -1112,7 +1309,10 @@ class T5Stack(T5PreTrainedModel): + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, ++ past_cross_key=past_cross_key, ++ past_cross_value=past_cross_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +@@ -1120,19 +1320,20 @@ class T5Stack(T5PreTrainedModel): + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: +- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] ++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:] + +- hidden_states, present_key_value_state = layer_outputs[:2] ++ hidden_states, present_key_state, present_value_state = layer_outputs[:3] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) +- position_bias = layer_outputs[2] ++ position_bias = layer_outputs[3] + if self.is_decoder and encoder_hidden_states is not None: +- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] ++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4] + # append next layer key value states + if use_cache: +- present_key_value_states = present_key_value_states + (present_key_value_state,) ++ present_key_states = present_key_states + present_key_state ++ present_value_states = present_value_states + 
present_value_state + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) +@@ -1146,7 +1347,7 @@ class T5Stack(T5PreTrainedModel): + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) +- hidden_states = self.dropout(hidden_states) ++ hidden_states = self.dropout(hidden_states).half() + + # Add last layer + if output_hidden_states: +@@ -1164,13 +1365,216 @@ class T5Stack(T5PreTrainedModel): + ] + if v is not None + ) +- return BaseModelOutputWithPastAndCrossAttentions( +- last_hidden_state=hidden_states, +- past_key_values=present_key_value_states, +- hidden_states=all_hidden_states, +- attentions=all_attentions, +- cross_attentions=all_cross_attentions, ++ if not self.is_decoder: ++ cross_keys = None ++ cross_values = None ++ if self.encodecrosskey: ++ cross_keys = self.encodecrosskey(hidden_states) ++ if self.encodecrossvalue: ++ cross_values = self.encodecrossvalue(hidden_states) ++ return tuple((hidden_states, cross_keys, cross_values)) ++ lm_logits = None ++ if self.is_decoder: ++ if self.config.tie_word_embeddings: ++ hidden_states = hidden_states * (self.model_dim ** -0.5) ++ lm_logits = self.lm_head(hidden_states) ++ return tuple((lm_logits, present_key_states, present_value_states)) ++ ++ ++class T5Stack_Encoder(T5PreTrainedModel): ++ def __init__(self, config, embed_tokens=None, encodecrosskey=None, encodecrossvalue=None): ++ super().__init__(config) ++ self.embed_tokens = embed_tokens ++ self.is_decoder = config.is_decoder ++ self.encodecrosskey = encodecrosskey ++ self.encodecrossvalue = encodecrossvalue ++ self.model_dim = config.d_model ++ ++ self.block = nn.ModuleList( ++ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + ) ++ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) ++ self.dropout = nn.Dropout(config.dropout_rate) ++ ++ # Initialize weights and apply final processing ++ self.post_init() ++ # Model parallel ++ self.model_parallel = False ++ self.device_map = None ++ self.gradient_checkpointing = False ++ ++ def get_input_embeddings(self): ++ return self.embed_tokens ++ ++ def set_input_embeddings(self, new_embeddings): ++ self.embed_tokens = new_embeddings ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, device=None, dtype=None ++ ): ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ ++ def forward( ++ self, ++ input_ids=None, ++ attention_mask=None, ++ head_mask=None, ++ cross_attn_head_mask=None, ++ use_cache=None, ++ output_attentions=None, ++ output_hidden_states=None, ++ 
return_dict=None, ++ **model_kwargs ++ ): ++ # Model parallel ++ use_cache = use_cache if use_cache is not None else self.config.use_cache ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ++ ) ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ input_shape = input_ids.size() ++ input_ids = input_ids.view(-1, input_shape[-1]) ++ ++ inputs_embeds = self.embed_tokens(input_ids) ++ ++ batch_size, seq_length = input_shape ++ # required mask seq length can be calculated via length of past ++ mask_seq_length = seq_length ++ ++ # initialize past_key_values with `None` if past does not exist ++ past_keys = [None] * len(self.block) ++ past_values = [None] * len(self.block) ++ past_cross_keys = [None] * len(self.block) ++ past_cross_values = [None] * len(self.block) ++ if attention_mask is None: ++ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) ++ ++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] ++ # ourselves in which case we just need to make it broadcastable to all heads. ++ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) ++ ++ # If a 2D or 3D attention mask is provided for the cross-attention ++ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] ++ ++ encoder_extended_attention_mask = None ++ ++ # Prepare head mask if needed ++ head_mask = self.get_head_mask(head_mask, self.config.num_layers) ++ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) ++ present_key_states = () if use_cache else None ++ present_value_states = () if use_cache else None ++ all_hidden_states = () if output_hidden_states else None ++ all_attentions = () if output_attentions else None ++ all_cross_attentions = () if (output_attentions and self.is_decoder) else None ++ position_bias = None ++ encoder_decoder_position_bias = None ++ ++ hidden_states = self.dropout(inputs_embeds) ++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)): ++ layer_head_mask = head_mask[i] ++ cross_attn_layer_head_mask = cross_attn_head_mask[i] ++ if output_hidden_states: ++ all_hidden_states = all_hidden_states + (hidden_states,) ++ ++ layer_outputs = layer_module( ++ hidden_states, ++ attention_mask=extended_attention_mask, ++ position_bias=position_bias, ++ encoder_hidden_states=None, ++ encoder_attention_mask=encoder_extended_attention_mask, ++ encoder_decoder_position_bias=encoder_decoder_position_bias, ++ layer_head_mask=layer_head_mask, ++ cross_attn_layer_head_mask=cross_attn_layer_head_mask, ++ past_key=past_key, ++ past_value=past_value, ++ past_cross_key=past_cross_key, ++ past_cross_value=past_cross_value, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ ) ++ ++ # layer_outputs is a tuple with: ++ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) ++ if use_cache is False: ++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:] ++ ++ hidden_states, present_key_state, present_value_state = layer_outputs[:3] ++ ++ # We share the position biases between the layers - 
the first layer store them ++ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), ++ # (cross-attention position bias), (cross-attention weights) ++ position_bias = layer_outputs[3] ++ if self.is_decoder and encoder_hidden_states is not None: ++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4] ++ # append next layer key value states ++ if use_cache: ++ present_key_states = present_key_states + present_key_state ++ present_value_states = present_value_states + present_value_state ++ ++ if output_attentions: ++ all_attentions = all_attentions + (layer_outputs[3],) ++ if self.is_decoder: ++ all_cross_attentions = all_cross_attentions + (layer_outputs[5],) ++ ++ # Model Parallel: If it's the last layer for that device, put things on the next device ++ if self.model_parallel: ++ for k, v in self.device_map.items(): ++ if i == v[-1] and "cuda:" + str(k) != self.last_device: ++ hidden_states = hidden_states.to("cuda:" + str(k + 1)) ++ ++ hidden_states = self.final_layer_norm(hidden_states) ++ hidden_states = self.dropout(hidden_states).half() ++ ++ # Add last layer ++ if output_hidden_states: ++ all_hidden_states = all_hidden_states + (hidden_states,) ++ ++ if not return_dict: ++ return tuple( ++ v ++ for v in [ ++ hidden_states, ++ present_key_value_states, ++ all_hidden_states, ++ all_attentions, ++ all_cross_attentions, ++ ] ++ if v is not None ++ ) ++ # present_key_value_states = torch.concat(present_key_value_states).reshape(len(self.block),2,*present_key_value_states[0].shape).half() if use_cache else None ++ if not self.is_decoder: ++ cross_keys = None ++ cross_values = None ++ if self.encodecrosskey: ++ cross_keys = self.encodecrosskey(hidden_states) ++ if self.encodecrossvalue: ++ cross_values = self.encodecrossvalue(hidden_states) ++ return tuple((hidden_states, cross_keys, cross_values)) + + + T5_START_DOCSTRING = r""" +@@ -1541,6 +1945,38 @@ class T5Model(T5PreTrainedModel): + ) + + ++class EncoderToCrossKey(nn.Module): ++ def __init__(self, cross_key, num_heads, d_kv): ++ super().__init__() ++ self.cross_key = cross_key ++ self.num_heads = num_heads ++ self.d_kv = d_kv ++ ++ ++ def forward(self, hidden_states): ++ batch_size = hidden_states.shape[0] ++ past_cross_keys = () ++ for i in range(len(self.cross_key)): ++ past_cross_keys += (self.cross_key[i](hidden_states),) ++ return past_cross_keys ++ ++ ++class EncoderToCrossValue(nn.Module): ++ def __init__(self, cross_value, num_heads, d_kv): ++ super().__init__() ++ self.cross_value = cross_value ++ self.num_heads = num_heads ++ self.d_kv = d_kv ++ ++ ++ def forward(self, hidden_states): ++ batch_size = hidden_states.shape[0] ++ past_cross_values = () ++ for i in range(len(self.cross_value)): ++ past_cross_values += (self.cross_value[i](hidden_states),) ++ return past_cross_values ++ ++ + @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) + class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ +@@ -1548,28 +1984,51 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + +- def __init__(self, config: T5Config): ++ def __init__(self, config: T5Config, encoder_path=None, decoder_path=None, device_id=0): + super().__init__(config) +- self.model_dim = config.d_model +- +- self.shared = nn.Embedding(config.vocab_size, config.d_model) +- +- encoder_config = 
copy.deepcopy(config) +- encoder_config.is_decoder = False +- encoder_config.use_cache = False +- encoder_config.is_encoder_decoder = False +- self.encoder = T5Stack(encoder_config, self.shared) +- +- decoder_config = copy.deepcopy(config) +- decoder_config.is_decoder = True +- decoder_config.is_encoder_decoder = False +- decoder_config.num_layers = config.num_decoder_layers +- self.decoder = T5Stack(decoder_config, self.shared) +- +- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) ++ self.encoder_path = encoder_path ++ self.decoder_path = decoder_path ++ self.is_mindie = False ++ if not self.encoder_path or not self.decoder_path: ++ self.model_dim = config.d_model ++ ++ self.shared = nn.Embedding(config.vocab_size, config.d_model) ++ ++ decoder_config = copy.deepcopy(config) ++ decoder_config.is_decoder = True ++ decoder_config.is_encoder_decoder = False ++ decoder_config.num_layers = config.num_decoder_layers ++ ++ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) ++ self.decoder = T5Stack(decoder_config, self.shared, self.lm_head) ++ ++ cross_key = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.k for i in range(config.num_decoder_layers)) ++ cross_value = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.v for i in range(config.num_decoder_layers)) ++ encodecrosskey = EncoderToCrossKey(cross_key, config.num_heads, config.d_kv) ++ encodecrossvalue = EncoderToCrossValue(cross_value, config.num_heads, config.d_kv) ++ ++ encoder_config = copy.deepcopy(config) ++ encoder_config.is_decoder = False ++ encoder_config.use_cache = False ++ encoder_config.is_encoder_decoder = False ++ self.encoder = T5Stack_Encoder(encoder_config, self.shared, encodecrosskey=encodecrosskey, encodecrossvalue=encodecrossvalue) ++ self.encoder_mindie = None ++ self.decoder_mindie = None ++ if self.encoder_path: ++ self.encoder_mindie = torch.jit.load(self.encoder_path) ++ self.is_mindie = True ++ if self.decoder_path: ++ self.decoder_mindie = torch.jit.load(self.decoder_path) ++ ++ self.stream = torch.npu.Stream(f"npu:{device_id}") ++ self.device_id = device_id ++ ++ ++ def get_device(self): ++ return f"npu:{self.device_id}" + + # Initialize weights and apply final processing +- self.post_init() ++ # self.post_init() + + # Model parallel + self.model_parallel = False +@@ -1637,25 +2096,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) +- def forward( +- self, +- input_ids: Optional[torch.LongTensor] = None, +- attention_mask: Optional[torch.FloatTensor] = None, +- decoder_input_ids: Optional[torch.LongTensor] = None, +- decoder_attention_mask: Optional[torch.BoolTensor] = None, +- head_mask: Optional[torch.FloatTensor] = None, +- decoder_head_mask: Optional[torch.FloatTensor] = None, +- cross_attn_head_mask: Optional[torch.Tensor] = None, +- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, +- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, +- inputs_embeds: Optional[torch.FloatTensor] = None, +- decoder_inputs_embeds: Optional[torch.FloatTensor] = None, +- labels: Optional[torch.LongTensor] = None, +- use_cache: Optional[bool] = None, +- output_attentions: Optional[bool] = None, +- output_hidden_states: Optional[bool] = None, +- return_dict: Optional[bool] = None, +- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ++ def forward(self,*args) -> 
Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., +@@ -1687,113 +2128,37 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" +- use_cache = use_cache if use_cache is not None else self.config.use_cache +- return_dict = return_dict if return_dict is not None else self.config.use_return_dict +- +- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +- if head_mask is not None and decoder_head_mask is None: +- if self.config.num_layers == self.config.num_decoder_layers: +- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) +- decoder_head_mask = head_mask +- +- # Encode if needed (training, first prediction pass) +- if encoder_outputs is None: +- # Convert encoder inputs in embeddings if needed +- encoder_outputs = self.encoder( +- input_ids=input_ids, +- attention_mask=attention_mask, +- inputs_embeds=inputs_embeds, +- head_mask=head_mask, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) +- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): +- encoder_outputs = BaseModelOutput( +- last_hidden_state=encoder_outputs[0], +- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, +- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, +- ) +- +- hidden_states = encoder_outputs[0] +- +- if self.model_parallel: +- torch.cuda.set_device(self.decoder.first_device) +- +- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: +- # get decoder inputs from shifting lm labels to the right +- decoder_input_ids = self._shift_right(labels) +- +- # Set device for model parallelism +- if self.model_parallel: +- torch.cuda.set_device(self.decoder.first_device) +- hidden_states = hidden_states.to(self.decoder.first_device) +- if decoder_input_ids is not None: +- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) +- if attention_mask is not None: +- attention_mask = attention_mask.to(self.decoder.first_device) +- if decoder_attention_mask is not None: +- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) +- +- # Decode +- decoder_outputs = self.decoder( +- input_ids=decoder_input_ids, +- attention_mask=decoder_attention_mask, +- inputs_embeds=decoder_inputs_embeds, +- past_key_values=past_key_values, +- encoder_hidden_states=hidden_states, +- encoder_attention_mask=attention_mask, +- head_mask=decoder_head_mask, +- cross_attn_head_mask=cross_attn_head_mask, +- use_cache=use_cache, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) +- +- sequence_output = decoder_outputs[0] +- +- # Set device for model parallelism +- if self.model_parallel: +- torch.cuda.set_device(self.encoder.first_device) +- self.lm_head = self.lm_head.to(self.encoder.first_device) +- sequence_output = sequence_output.to(self.lm_head.weight.device) +- +- if self.config.tie_word_embeddings: +- # Rescale output before projecting on vocab +- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 +- sequence_output = sequence_output * 
(self.model_dim**-0.5) +- +- lm_logits = self.lm_head(sequence_output) ++ if self.is_mindie: ++ with torch.npu.stream(self.stream): # set stream ++ decoder_outputs = self.decoder_mindie.forward(*args) ++ self.stream.synchronize() # synchronize ++ else: ++ hidden_states = args[0] ++ past_cross_keys = args[1:self.config.num_decoder_layers+1] ++ past_cross_values = args[self.config.num_decoder_layers+1:2*self.config.num_decoder_layers+1] ++ past_keys= args[2*self.config.num_decoder_layers+1:3*self.config.num_decoder_layers+1] ++ past_values= args[3*self.config.num_decoder_layers+1:4*self.config.num_decoder_layers+1] ++ encoder_attention_mask = args[-2] ++ decoder_input_ids = args[-1] ++ decoder_outputs = self.decoder(input_ids=decoder_input_ids, ++ encoder_hidden_states=hidden_states, ++ past_keys=past_keys, ++ past_values=past_values, ++ past_cross_keys=past_cross_keys, ++ past_cross_values=past_cross_values, ++ encoder_attention_mask=encoder_attention_mask) ++ + + loss = None +- if labels is not None: +- loss_fct = CrossEntropyLoss(ignore_index=-100) +- # move labels to correct device to enable PP +- labels = labels.to(lm_logits.device) +- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) +- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 +- +- if not return_dict: +- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs +- return ((loss,) + output) if loss is not None else output +- +- return Seq2SeqLMOutput( +- loss=loss, +- logits=lm_logits, +- past_key_values=decoder_outputs.past_key_values, +- decoder_hidden_states=decoder_outputs.hidden_states, +- decoder_attentions=decoder_outputs.attentions, +- cross_attentions=decoder_outputs.cross_attentions, +- encoder_last_hidden_state=encoder_outputs.last_hidden_state, +- encoder_hidden_states=encoder_outputs.hidden_states, +- encoder_attentions=encoder_outputs.attentions, +- ) ++ return (decoder_outputs[0],decoder_outputs[1],decoder_outputs[2]) + + def prepare_inputs_for_generation( + self, + input_ids, +- past_key_values=None, ++ past_cross_keys=None, ++ past_cross_values=None, ++ past_keys=None, ++ past_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, +@@ -1804,8 +2169,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used +- if past_key_values is not None: +- past_length = past_key_values[0][0].shape[2] ++ if past_keys is not None: ++ past_length = past_keys[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: +@@ -1813,12 +2178,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 +- + input_ids = input_ids[:, remove_prefix_length:] + + return { + "decoder_input_ids": input_ids, +- "past_key_values": past_key_values, ++ "past_cross_keys":past_cross_keys, ++ "past_cross_values":past_cross_values, ++ "past_keys":past_keys, ++ "past_values":past_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, +@@ -1826,6 +2193,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, ++ + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): +@@ -1861,6 +2229,459 @@ 
class T5ForConditionalGeneration(T5PreTrainedModel): + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + ++ def _prepare_encoder_decoder_kwargs_for_generation( ++ self, ++ inputs_tensor: torch.Tensor, ++ model_kwargs, ++ model_input_name, ++ generation_config, ++ ): ++ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] ++ encoder_kwargs = { ++ argument: value ++ for argument, value in model_kwargs.items() ++ if not any(argument.startswith(p) for p in irrelevant_prefix) ++ } ++ encoder_kwargs["output_attentions"] = generation_config.output_attentions ++ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states ++ model_input_name = model_input_name if model_input_name is not None else self.main_input_name ++ encoder_kwargs["return_dict"] = True ++ encoder_kwargs[model_input_name] = inputs_tensor ++ if self.is_mindie: ++ with torch.npu.stream(self.stream): # set stream ++ encoder_outputs=self.encoder_mindie.forward(encoder_kwargs["input_ids"],encoder_kwargs["attention_mask"]) ++ self.stream.synchronize() # synchronize ++ else: ++ encoder_outputs=self.encoder.forward(**encoder_kwargs) ++ model_kwargs["encoder_outputs"]={"last_hidden_state":encoder_outputs[0]} ++ model_kwargs["past_cross_keys"] = encoder_outputs[1] ++ model_kwargs["past_cross_values"] =encoder_outputs[2] ++ return model_kwargs ++ ++ def _update_model_kwargs_for_generation( ++ self, ++ outputs, ++ model_kwargs, ++ is_encoder_decoder = False, ++ standardize_cache_format = False, ++ num_new_tokens = 1, ++ ): ++ # update past_key_values keeping its naming used in model code ++ cache_name, cache = self._extract_past_from_model_output( ++ outputs, standardize_cache_format=standardize_cache_format ++ ) ++ model_kwargs[cache_name] = cache ++ if "past_keys" in outputs: ++ past_keys = outputs.past_keys ++ model_kwargs["past_keys"] = past_keys ++ if "past_values" in outputs: ++ past_values = outputs.past_values ++ model_kwargs["past_values"] = past_values ++ # update decoder attention mask ++ if "decoder_attention_mask" in model_kwargs: ++ decoder_attention_mask = model_kwargs["decoder_attention_mask"] ++ model_kwargs["decoder_attention_mask"] = torch.cat( ++ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], ++ dim=-1, ++ ) ++ return model_kwargs ++ ++ @torch.no_grad() ++ def generate( ++ self, ++ inputs = None, ++ generation_config = None, ++ logits_processor = None, ++ stopping_criteria = None, ++ prefix_allowed_tokens_fn = None, ++ assistant_model = None, ++ negative_prompt_ids = None, ++ negative_prompt_attention_mask = None, ++ **kwargs, ++ ): ++ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call ++ self._validate_model_class() ++ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria ++ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) ++ self._validate_model_kwargs(model_kwargs.copy()) ++ ++ ++ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() ++ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() ++ ++ accepts_attention_mask = True ++ requires_attention_mask = "encoder_outputs" not in model_kwargs ++ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None ++ ++ # 3. 
Define model inputs ++ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( ++ inputs, generation_config.bos_token_id, model_kwargs ++ ) ++ batch_size = inputs_tensor.shape[0] ++ ++ device = inputs_tensor.device ++ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) ++ ++ # 4. Define other model kwargs ++ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are ++ # generating the first new token or not, and we only want to use the embeddings for the first new token) ++ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": ++ model_kwargs["use_cache"] = True ++ else: ++ model_kwargs["use_cache"] = generation_config.use_cache ++ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: ++ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( ++ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ++ ) ++ ++ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: ++ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` ++ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( ++ inputs_tensor, model_kwargs, model_input_name, generation_config ++ ) ++ ++ # 5. Prepare `input_ids` which will be used for auto-regressive generation ++ if self.config.is_encoder_decoder: ++ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( ++ batch_size=batch_size, ++ model_input_name=model_input_name, ++ model_kwargs=model_kwargs, ++ decoder_start_token_id=generation_config.decoder_start_token_id, ++ device=inputs_tensor.device, ++ ) ++ else: ++ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") ++ ++ if generation_config.token_healing: ++ input_ids = self.heal_tokens(input_ids, tokenizer) ++ ++ # 6. Prepare `max_length` depending on other stopping criteria. ++ input_ids_length = input_ids.shape[-1] ++ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None ++ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None ++ generation_config = self._prepare_generated_length( ++ generation_config=generation_config, ++ has_default_max_length=has_default_max_length, ++ has_default_min_length=has_default_min_length, ++ model_input_name=model_input_name, ++ inputs_tensor=inputs_tensor, ++ input_ids_length=input_ids_length, ++ ) ++ ++ use_dynamic_cache_by_default = False ++ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: ++ raise ValueError( ++ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " ++ "Cache object) is unsupported. Please use only one of the two." ++ ) ++ elif generation_config.cache_implementation is not None: ++ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: ++ if generation_config.cache_implementation == "static" and not self._supports_static_cache: ++ raise ValueError( ++ "This model does not support `cache_implementation='static'`. 
Please check the following " ++ "issue: https://github.com/huggingface/transformers/issues/28981" ++ ) ++ model_kwargs["past_key_values"] = self._get_cache( ++ generation_config.cache_implementation, ++ getattr(generation_config, "num_beams", 1) * batch_size, ++ generation_config.max_length, ++ ) ++ elif generation_config.cache_implementation == "quantized": ++ if not self._supports_quantized_cache: ++ raise ValueError( ++ "This model does not support the quantized cache. If you want your model to support quantized " ++ "cache, please open an issue." ++ ) ++ ++ cache_config = ( ++ generation_config.cache_config ++ if generation_config.cache_config is not None ++ else QuantizedCacheConfig() ++ ) ++ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] ++ ++ if cache_config.backend == "quanto" and not is_quanto_available(): ++ raise ImportError( ++ "You need to install `quanto` in order to use KV cache quantization with quanto backend. " ++ "Please install it via with `pip install quanto`" ++ ) ++ elif cache_config.backend == "HQQ" and not is_hqq_available(): ++ raise ImportError( ++ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " ++ "Please install it via with `pip install hqq`" ++ ) ++ ++ model_kwargs["past_key_values"] = cache_class(cache_config) ++ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that ++ # keeps copying the cache thus using much more memory ++ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): ++ past = model_kwargs.get("past_key_values", None) ++ if past is None: ++ model_kwargs["past_key_values"] = DynamicCache() ++ use_dynamic_cache_by_default = True ++ elif isinstance(past, tuple): ++ model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past) ++ use_dynamic_cache_by_default = True ++ ++ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) ++ ++ # 7. determine generation mode ++ generation_mode = generation_config.get_generation_mode(assistant_model) ++ # 8. prepare distribution pre_processing samplers ++ prepared_logits_processor = self._get_logits_processor( ++ generation_config=generation_config, ++ input_ids_seq_length=input_ids_length, ++ encoder_input_ids=inputs_tensor, ++ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, ++ logits_processor=logits_processor, ++ device=inputs_tensor.device, ++ model_kwargs=model_kwargs, ++ negative_prompt_ids=negative_prompt_ids, ++ negative_prompt_attention_mask=negative_prompt_attention_mask, ++ ) ++ ++ # 9. prepare stopping criteria ++ prepared_stopping_criteria = self._get_stopping_criteria( ++ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ++ ) ++ ++ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): ++ # 11. prepare logits warper ++ prepared_logits_warper = ( ++ self._get_logits_warper(generation_config, device=input_ids.device) ++ if generation_config.do_sample ++ else None ++ ) ++ ++ # 12. expand input_ids with `num_return_sequences` additional sequences per batch ++ input_ids, model_kwargs = self._expand_inputs_for_generation( ++ input_ids=input_ids, ++ expand_size=generation_config.num_return_sequences, ++ is_encoder_decoder=self.config.is_encoder_decoder, ++ **model_kwargs, ++ ) ++ # 13. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`) ++ result = self._sample( ++ input_ids, ++ logits_processor=prepared_logits_processor, ++ logits_warper=prepared_logits_warper, ++ stopping_criteria=prepared_stopping_criteria, ++ generation_config=generation_config, ++ **model_kwargs, ++ ) ++ return result ++ ++ def _sample( ++ self, ++ input_ids, ++ logits_processor, ++ stopping_criteria, ++ generation_config, ++ logits_warper = None, ++ **model_kwargs, ++ ): ++ # init values ++ pad_token_id = generation_config.pad_token_id ++ output_attentions = generation_config.output_attentions ++ output_hidden_states = generation_config.output_hidden_states ++ output_scores = generation_config.output_scores ++ output_logits = generation_config.output_logits ++ return_dict_in_generate = generation_config.return_dict_in_generate ++ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) ++ do_sample = generation_config.do_sample ++ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): ++ raise ValueError( ++ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " ++ f"{logits_warper})." ++ ) ++ ++ # init attention / hidden states / scores tuples ++ scores = () if (return_dict_in_generate and output_scores) else None ++ raw_logits = () if (return_dict_in_generate and output_logits) else None ++ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None ++ cross_attentions = () if (return_dict_in_generate and output_attentions) else None ++ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None ++ ++ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states ++ if return_dict_in_generate and self.config.is_encoder_decoder: ++ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None ++ encoder_hidden_states = ( ++ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ++ ) ++ ++ this_peer_finished = False ++ batch_size = input_ids.shape[0] ++ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) ++ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) ++ ++ # keep track of which sequences are already finished ++ if self.is_mindie or self.config.architectures[0]=="T5ForConditionalGeneration": ++ num_layers = self.config.num_layers ++ num_heads = self.config.num_heads ++ d_kv = self.config.d_kv ++ model_kwargs["past_keys"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)] ++ model_kwargs["past_values"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)] ++ ++ ++ while self._has_unfinished_sequences(this_peer_finished, False, device=input_ids.device): ++ # prepare model inputs ++ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) ++ model_args = [model_kwargs["encoder_outputs"]["last_hidden_state"]] ++ model_args.extend(model_kwargs["past_cross_keys"]) ++ model_args.extend(model_kwargs["past_cross_values"]) ++ model_args.extend(model_inputs["past_keys"]) ++ model_args.extend(model_inputs["past_values"]) ++ model_args.append(model_inputs["attention_mask"]) ++ model_args.append(model_inputs["decoder_input_ids"]) ++ ++ # forward pass to get next token ++ outputs = self(*model_args) ++ outputs = Seq2SeqLMOutput(logits=outputs[0], ++ past_keys=outputs[1], 
++ past_values=outputs[2]) ++ ++ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration ++ # (the clone itself is always small) ++ next_token_logits = outputs.logits[:, -1, :].clone() ++ ++ # pre-process distribution ++ next_token_scores = logits_processor(input_ids, next_token_logits) ++ if do_sample: ++ next_token_scores = logits_warper(input_ids, next_token_scores) ++ ++ # Store scores, attentions and hidden_states when required ++ if return_dict_in_generate: ++ if output_scores: ++ scores += (next_token_scores,) ++ if output_logits: ++ raw_logits += (next_token_logits,) ++ if output_attentions: ++ decoder_attentions += ( ++ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ++ ) ++ if self.config.is_encoder_decoder: ++ cross_attentions += (outputs.cross_attentions,) ++ ++ if output_hidden_states: ++ decoder_hidden_states += ( ++ (outputs.decoder_hidden_states,) ++ if self.config.is_encoder_decoder ++ else (outputs.hidden_states,) ++ ) ++ ++ # token selection ++ if do_sample: ++ probs = nn.functional.softmax(next_token_scores, dim=-1) ++ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) ++ else: ++ next_tokens = torch.argmax(next_token_scores, dim=-1) ++ ++ # finished sentences should have their next token be a padding token ++ if has_eos_stopping_criteria: ++ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) ++ ++ # update generated ids, model inputs, and length for next step ++ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) ++ model_kwargs = self._update_model_kwargs_for_generation( ++ outputs, ++ model_kwargs, ++ is_encoder_decoder=self.config.is_encoder_decoder, ++ ) ++ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) ++ this_peer_finished = unfinished_sequences.max() == 0 ++ # This is needed to properly delete outputs.logits which may be very large for first iteration ++ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration ++ del outputs ++ return input_ids ++ ++ def invert_attention_mask(self, encoder_attention_mask): ++ """ ++ Invert an attention mask (e.g., switches 0. and 1.). ++ ++ Args: ++ encoder_attention_mask (`torch.Tensor`): An attention mask. ++ ++ Returns: ++ `torch.Tensor`: The inverted attention mask. ++ """ ++ if encoder_attention_mask.dim() == 3: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] ++ if encoder_attention_mask.dim() == 2: ++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] ++ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition ++ # Cf. 
https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow ++ # /transformer/transformer_layers.py#L270 ++ # encoder_extended_attention_mask = (encoder_extended_attention_mask == ++ # encoder_extended_attention_mask.transpose(-1, -2)) ++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility ++ #encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min ++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000 ++ ++ return encoder_extended_attention_mask ++ ++ @property ++ def device(self) -> torch.device: ++ """ ++ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same ++ device). ++ """ ++ return self.get_device() ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, devic=None, dtype=None ++ ): ++ """ ++ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. ++ ++ Arguments: ++ attention_mask (`torch.Tensor`): ++ Mask with ones indicating tokens to attend to, zeros for tokens to ignore. ++ input_shape (`Tuple[int]`): ++ The shape of the input to the model. ++ ++ Returns: ++ `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. ++ """ ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] ++ # ourselves in which case we just need to make it broadcastable to all heads. ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ # Provided a padding mask of dimensions [batch_size, seq_length] ++ # - if the model is a decoder, apply a causal mask in addition to the padding mask ++ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ ++ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for ++ # masked positions, this operation will create a tensor which is 0.0 for ++ # positions we want to attend and the dtype's smallest value for masked positions. ++ # Since we are adding it to the raw scores before the softmax, this is ++ # effectively the same as removing these entirely. 
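++        # Patch note: the arithmetic below uses a fixed -1000 instead of torch.finfo(dtype).min (the
++        # original finfo variant is kept commented out just after), presumably so the additive mask stays
++        # finite under fp16: a kept position yields (1.0 - 1.0) * -1000 = 0.0, while a masked position
++        # yields (1.0 - 0.0) * -1000 = -1000.0, which is enough to drive its softmax weight to ~0.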
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ #extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ ++ ++ + + @add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", +@@ -1967,7 +2788,6 @@ class T5EncoderModel(T5PreTrainedModel): + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict +- + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, diff --git a/MindIE/MindIE-Torch/built-in/T5/modeling_t5_800IA2.patch b/MindIE/MindIE-Torch/built-in/T5/modeling_t5_800IA2.patch new file mode 100644 index 0000000000000000000000000000000000000000..664b4359cea4c27a50785ab38844644e14a0d3a1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/T5/modeling_t5_800IA2.patch @@ -0,0 +1,1594 @@ +diff --git a/modeling_t5_origin.py b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py +index 224769fdf..4f9ffd74f 100644 +--- a/modeling_t5_origin.py ++++ b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py +@@ -19,22 +19,26 @@ import math + import os + import warnings + from typing import List, Optional, Tuple, Union +- ++from dataclasses import dataclass + import torch + from torch import nn + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ++# import torch_npu ++# import mindietorch ++ ++ ++ + + from ...activations import ACT2FN + from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, +- Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, + TokenClassifierOutput, + ) +-from ...modeling_utils import PreTrainedModel ++from ...modeling_utils import PreTrainedModel,ModuleUtilsMixin + from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer + from ...utils import ( + DUMMY_INPUTS, +@@ -47,7 +51,43 @@ from ...utils import ( + ) + from ...utils.model_parallel_utils import assert_device_map, get_device_map + from .configuration_t5 import T5Config ++from transformers.generation.logits_process import LogitsProcessorList ++from transformers.generation.stopping_criteria import StoppingCriteriaList ++from transformers.generation.configuration_utils import GenerationMode ++from transformers.utils.generic import ModelOutput ++ ++ ++@dataclass ++class Seq2SeqLMOutput(ModelOutput): ++ """ ++ Base class for model's outputs, with potential hidden states and attentions. + ++ Args: ++ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): ++ Sequence of hidden-states at the output of the last layer of the model. ++ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): ++ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + ++ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. ++ ++ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
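++            past_keys (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed):
++                Per-layer decoder self-attention key states added by this patched output class so they can be
++                fed back as `past_keys` on the next decoding step.
++            past_values (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True` is passed):
++                Per-layer decoder self-attention value states, the counterpart of `past_keys`.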
++ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): ++ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, ++ sequence_length)`. ++ ++ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention ++ heads. ++ """ ++ loss: Optional[torch.FloatTensor] = None ++ logits: torch.FloatTensor = None ++ past_keys: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ past_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None ++ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None ++ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None ++ encoder_last_hidden_state: Optional[torch.FloatTensor] = None ++ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None ++ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + logger = logging.get_logger(__name__) + +@@ -360,6 +400,7 @@ class T5Attention(nn.Module): + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False ++ self.lay_out = "BSH" + + def prune_heads(self, heads): + if len(heads) == 0: +@@ -448,7 +489,10 @@ class T5Attention(nn.Module): + mask=None, + key_value_states=None, + position_bias=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, ++ past_cross_key=None, ++ past_cross_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, +@@ -464,81 +508,86 @@ class T5Attention(nn.Module): + + real_seq_length = seq_length + +- if past_key_value is not None: +- if len(past_key_value) != 2: +- raise ValueError( +- f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" +- ) +- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length ++ if past_key is not None: ++ real_seq_length += past_key.shape[1] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] ++ # BSH ++ query_states = self.q(hidden_states) ++ key_states = past_key ++ value_states = past_value ++ attn_output = torch.ops.aie.flash_attention(query_states,key_states,value_states,self.n_heads,attn_mask=mask) ++ # mask = mask.expand(3,1,16,mask.shape[3]).bool() ++ # attn_output = torch_npu.npu_prompt_flash_attention(query_states,key_states,value_states,atten_mask=mask,num_heads=self.n_heads,input_layout="BSH") ++ attn_output = self.o(attn_output) ++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None ++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None ++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,) ++ return outputs ++ ++ ++class T5SelfAttention(T5Attention): ++ def __init__(self, config: T5Config, has_relative_attention_bias=False): ++ super().__init__(config, has_relative_attention_bias) + +- def shape(states): +- """projection""" +- return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) ++ def forward( ++ self, ++ hidden_states, ++ mask=None, ++ position_bias=None, ++ past_key=None, ++ past_value=None, ++ layer_head_mask=None, ++ use_cache=False, ++ output_attentions=False, ++ ): ++ """ ++ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). ++ """ ++ # Input is (batch_size, seq_length, dim) ++ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) ++ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) ++ batch_size, seq_length = hidden_states.shape[:2] + +- def unshape(states): +- """reshape""" +- return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) ++ real_seq_length = seq_length ++ ++ if past_key is not None: ++ real_seq_length += past_key.shape[1] ++ key_length = real_seq_length + +- def project(hidden_states, proj_layer, key_value_states, past_key_value): ++ def project(hidden_states, proj_layer, past_key_value): + """projects hidden states correctly to key/query states""" +- if key_value_states is None: +- # self-attn +- # (batch_size, n_heads, seq_length, dim_per_head) +- hidden_states = shape(proj_layer(hidden_states)) +- elif past_key_value is None: +- # cross-attn +- # (batch_size, n_heads, seq_length, dim_per_head) +- hidden_states = shape(proj_layer(key_value_states)) ++ if past_key_value is None: ++ hidden_states = proj_layer(hidden_states) + + if past_key_value is not None: +- if key_value_states is None: +- # self-attn +- # (batch_size, n_heads, key_length, dim_per_head) +- hidden_states = torch.cat([past_key_value, hidden_states], dim=2) +- elif past_key_value.shape[2] != key_value_states.shape[1]: +- # checking that the `sequence_length` of the `past_key_value` is the same as +- # the provided `key_value_states` to support prefix tuning +- # cross-attn +- # (batch_size, n_heads, seq_length, dim_per_head) +- hidden_states = shape(proj_layer(key_value_states)) +- else: +- # cross-attn +- hidden_states = past_key_value ++ hidden_states = proj_layer(hidden_states) ++ hidden_states = torch.cat([past_key_value, hidden_states], dim=1) + return 
hidden_states + + # get query states +- query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) +- ++ query_states = self.q(hidden_states) + # get key/value states + key_states = project( +- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None ++ hidden_states, self.k, past_key if past_key is not None else None + ) + value_states = project( +- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None ++ hidden_states, self.v, past_value if past_value is not None else None + ) +- +- # compute scores +- scores = torch.matmul( +- query_states, key_states.transpose(3, 2) +- ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 +- + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( +- (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype ++ (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: +- position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) ++ position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device) + + # if key and values are already calculated + # we want only the last query position bias +- if past_key_value is not None: ++ if past_key is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] +- + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + +@@ -548,34 +597,26 @@ class T5Attention(nn.Module): + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias +- +- scores += position_bias_masked +- attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( +- scores +- ) # (batch_size, n_heads, seq_length, key_length) +- attn_weights = nn.functional.dropout( +- attn_weights, p=self.dropout, training=self.training +- ) # (batch_size, n_heads, seq_length, key_length) +- +- # Mask heads if we want to +- if layer_head_mask is not None: +- attn_weights = attn_weights * layer_head_mask +- +- attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) ++ # scores += position_bias_masked ++ # attn_output = torch.ops.aie.flash_attention(query_states,key_states,value_states,self.n_heads,pse=position_bias_masked) ++ attn_output = torch.ops.aie.flash_attention(query_states,key_states,value_states,self.n_heads,pse=position_bias_masked,attn_mask=mask) ++ # print("mask=",mask,mask.shape) ++ # mask = mask.expand(3,1,16,mask.shape[3]).bool() ++ # attn_output = torch_npu.npu_prompt_flash_attention(query_states,key_states,value_states,pse_shift=position_bias_masked, atten_mask=mask,num_heads=self.n_heads,input_layout="BSH") + attn_output = self.o(attn_output) ++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None ++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None ++ present_value_state = (value_states.half(), ) if (self.is_decoder and use_cache) else None ++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,) ++ return outputs + +- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None +- outputs = 
(attn_output,) + (present_key_value_state,) + (position_bias,) + +- if output_attentions: +- outputs = outputs + (attn_weights,) +- return outputs + + + class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() +- self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) ++ self.SelfAttention = T5SelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + +@@ -585,7 +626,8 @@ class T5LayerSelfAttention(nn.Module): + attention_mask=None, + position_bias=None, + layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, + use_cache=False, + output_attentions=False, + ): +@@ -595,7 +637,8 @@ class T5LayerSelfAttention(nn.Module): + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +@@ -618,7 +661,8 @@ class T5LayerCrossAttention(nn.Module): + attention_mask=None, + position_bias=None, + layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, + use_cache=False, + query_length=None, + output_attentions=False, +@@ -630,7 +674,8 @@ class T5LayerCrossAttention(nn.Module): + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, +@@ -661,39 +706,34 @@ class T5Block(nn.Module): + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, +- past_key_value=None, ++ past_key=None, ++ past_value=None, ++ past_cross_key=None, ++ past_cross_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): +- if past_key_value is not None: +- if not self.is_decoder: +- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") +- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 +- +- if len(past_key_value) != expected_num_past_key_values: +- raise ValueError( +- f"There should be {expected_num_past_key_values} past states. " +- f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" +- f"Got {len(past_key_value)} past key / value states" +- ) +- +- self_attn_past_key_value = past_key_value[:2] +- cross_attn_past_key_value = past_key_value[2:] ++ if past_key is not None: ++ self_attn_past_key = past_key ++ self_attn_past_value = past_value ++ cross_attn_past_key = past_cross_key ++ cross_attn_past_value = past_cross_value + else: +- self_attn_past_key_value, cross_attn_past_key_value = None, None ++ self_attn_past_key, self_attn_past_value, cross_attn_past_key, cross_attn_past_value = None, None, None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, +- past_key_value=self_attn_past_key_value, ++ past_key=self_attn_past_key, ++ past_value=self_attn_past_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +- hidden_states, present_key_value_state = self_attention_outputs[:2] +- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights ++ hidden_states, present_key_state, present_value_state = self_attention_outputs[:3] ++ attention_outputs = self_attention_outputs[3:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: +@@ -706,22 +746,23 @@ class T5Block(nn.Module): + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: ++ + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here +- if present_key_value_state is not None: +- query_length = present_key_value_state[0].shape[2] ++ if present_key_state is not None: ++ query_length = present_key_state[0].shape[1] + else: + query_length = None +- + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, +- past_key_value=cross_attn_past_key_value, ++ past_key=cross_attn_past_key, ++ past_value=cross_attn_past_value, + query_length=query_length, +- use_cache=use_cache, ++ use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] +@@ -736,11 +777,9 @@ class T5Block(nn.Module): + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states +- if present_key_value_state is not None: +- present_key_value_state = present_key_value_state + cross_attention_outputs[1] +- ++ # cross_attn_past_key_values = cross_attention_outputs[1] + # Keep cross-attention outputs and relative position weights +- attention_outputs = attention_outputs + cross_attention_outputs[2:] ++ attention_outputs = attention_outputs + cross_attention_outputs[3:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) +@@ -757,7 +796,7 @@ class T5Block(nn.Module): + outputs = (hidden_states,) + + if use_cache: +- outputs = outputs + (present_key_value_state,) + attention_outputs ++ outputs = outputs + (present_key_state,) +(present_value_state,)+ attention_outputs + else: + outputs = outputs + attention_outputs + +@@ -897,11 +936,15 @@ class T5PreTrainedModel(PreTrainedModel): + + + class T5Stack(T5PreTrainedModel): +- def __init__(self, config, embed_tokens=None): ++ def __init__(self, config, 
embed_tokens=None,lm_head=None, encodecrosskey=None, encodecrossvalue=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder ++ self.lm_head=lm_head ++ self.encodecrosskey = encodecrosskey ++ self.encodecrossvalue = encodecrossvalue ++ self.model_dim = config.d_model + + self.block = nn.ModuleList( + [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] +@@ -966,16 +1009,48 @@ class T5Stack(T5PreTrainedModel): + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, device=None, dtype=None ++ ): ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ + def forward( + self, + input_ids=None, +- attention_mask=None, + encoder_hidden_states=None, ++ past_keys=None, ++ past_values=None, ++ past_cross_keys=None, ++ past_cross_values=None, + encoder_attention_mask=None, ++ attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, +- past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, +@@ -998,8 +1073,10 @@ class T5Stack(T5PreTrainedModel): + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: ++ + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) ++ input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: +@@ -1012,25 +1089,29 @@ class T5Stack(T5PreTrainedModel): + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape +- + # required mask seq length can be calculated via length of past +- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length ++ mask_seq_length = past_keys[0].shape[1] + seq_length if past_keys is not None else seq_length + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + # initialize past_key_values with `None` if past does not exist +- if past_key_values is None: +- past_key_values = [None] * len(self.block) +- ++ if not self.is_decoder: ++ past_keys = [None] * len(self.block) ++ past_values = [None] * len(self.block) ++ past_cross_keys = [None] * len(self.block) ++ past_cross_values = [None] * len(self.block) + if attention_mask is None: +- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) ++ print("aaaaaaaaaaaaaaaaa") ++ 
attention_mask = torch.zeros(batch_size, mask_seq_length, device=inputs_embeds.device) ++ attention_mask = attention_mask[:,None,None,:].expand(batch_size,1,mask_seq_length,mask_seq_length).bool() + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. +- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) +- ++ # extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) ++ extended_attention_mask = attention_mask ++ # print("extended_attention_mask=",extended_attention_mask) + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: +@@ -1040,7 +1121,7 @@ class T5Stack(T5PreTrainedModel): + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) +- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) ++ encoder_extended_attention_mask = encoder_attention_mask + else: + encoder_extended_attention_mask = None + +@@ -1054,7 +1135,8 @@ class T5Stack(T5PreTrainedModel): + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) +- present_key_value_states = () if use_cache else None ++ present_key_states = () if use_cache else None ++ present_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None +@@ -1062,8 +1144,8 @@ class T5Stack(T5PreTrainedModel): + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) +- +- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): ++ # for i, layer_module in enumerate(self.block): ++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel +@@ -1112,7 +1194,10 @@ class T5Stack(T5PreTrainedModel): + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, +- past_key_value=past_key_value, ++ past_key=past_key, ++ past_value=past_value, ++ past_cross_key=past_cross_key, ++ past_cross_value=past_cross_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) +@@ -1120,19 +1205,20 @@ class T5Stack(T5PreTrainedModel): + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: +- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] ++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:] + +- hidden_states, present_key_value_state = layer_outputs[:2] ++ hidden_states, present_key_state, present_value_state = layer_outputs[:3] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states 
(self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) +- position_bias = layer_outputs[2] ++ position_bias = layer_outputs[3] + if self.is_decoder and encoder_hidden_states is not None: +- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] ++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4] + # append next layer key value states + if use_cache: +- present_key_value_states = present_key_value_states + (present_key_value_state,) ++ present_key_states = present_key_states + present_key_state ++ present_value_states = present_value_states + present_value_state + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) +@@ -1146,31 +1232,158 @@ class T5Stack(T5PreTrainedModel): + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) +- hidden_states = self.dropout(hidden_states) ++ hidden_states = self.dropout(hidden_states).half() + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) ++ if self.config.tie_word_embeddings: ++ hidden_states = hidden_states * (self.model_dim ** -0.5) ++ lm_logits = self.lm_head(hidden_states) ++ return tuple((lm_logits, present_key_states, present_value_states)) + +- if not return_dict: +- return tuple( +- v +- for v in [ +- hidden_states, +- present_key_value_states, +- all_hidden_states, +- all_attentions, +- all_cross_attentions, +- ] +- if v is not None +- ) +- return BaseModelOutputWithPastAndCrossAttentions( +- last_hidden_state=hidden_states, +- past_key_values=present_key_value_states, +- hidden_states=all_hidden_states, +- attentions=all_attentions, +- cross_attentions=all_cross_attentions, ++ ++class T5Stack_Encoder(T5PreTrainedModel): ++ def __init__(self, config, embed_tokens=None, encodecrosskey=None, encodecrossvalue=None): ++ super().__init__(config) ++ self.embed_tokens = embed_tokens ++ self.is_decoder = config.is_decoder ++ self.encodecrosskey = encodecrosskey ++ self.encodecrossvalue = encodecrossvalue ++ self.model_dim = config.d_model ++ ++ self.block = nn.ModuleList( ++ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] ++ ) ++ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) ++ self.dropout = nn.Dropout(config.dropout_rate) ++ ++ # Initialize weights and apply final processing ++ self.post_init() ++ # Model parallel ++ self.model_parallel = False ++ self.device_map = None ++ self.gradient_checkpointing = False ++ ++ def get_input_embeddings(self): ++ return self.embed_tokens ++ ++ def set_input_embeddings(self, new_embeddings): ++ self.embed_tokens = new_embeddings ++ ++ ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, device=None, dtype=None ++ ): ++ extended_attention_mask = attention_mask[:,None,None,:].expand(input_shape[0],1,input_shape[1],input_shape[1]).bool() ++ extended_attention_mask = ~extended_attention_mask ++ return extended_attention_mask ++ ++ def forward( ++ self, ++ input_ids=None, ++ attention_mask=None, ++ head_mask=None, ++ cross_attn_head_mask=None, ++ use_cache=None, ++ output_attentions=None, ++ output_hidden_states=None, ++ return_dict=None, ++ ): ++ # Model parallel ++ use_cache = use_cache if use_cache is not None else self.config.use_cache ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ 
output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ input_shape = input_ids.size() ++ input_ids = input_ids.view(-1, input_shape[-1]) ++ ++ inputs_embeds = self.embed_tokens(input_ids) ++ ++ batch_size, seq_length = input_shape ++ # required mask seq length can be calculated via length of past ++ mask_seq_length = seq_length ++ ++ # initialize past_key_values with `None` if past does not exist ++ past_keys = [None] * len(self.block) ++ past_values = [None] * len(self.block) ++ past_cross_keys = [None] * len(self.block) ++ past_cross_values = [None] * len(self.block) ++ # print("attention_mask=",attention_mask) ++ if attention_mask is None: ++ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) ++ encoder_extended_attention_mask = None ++ # Prepare head mask if needed ++ head_mask = self.get_head_mask(head_mask, self.config.num_layers) ++ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) ++ present_key_states = () if use_cache else None ++ present_value_states = () if use_cache else None ++ all_hidden_states = () if output_hidden_states else None ++ all_attentions = () if output_attentions else None ++ all_cross_attentions = () if (output_attentions and self.is_decoder) else None ++ position_bias = None ++ encoder_decoder_position_bias = None ++ ++ hidden_states = self.dropout(inputs_embeds) ++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)): ++ layer_head_mask = head_mask[i] ++ cross_attn_layer_head_mask = cross_attn_head_mask[i] ++ if output_hidden_states: ++ all_hidden_states = all_hidden_states + (hidden_states,) ++ layer_outputs = layer_module( ++ hidden_states, ++ attention_mask=attention_mask, ++ position_bias=position_bias, ++ encoder_hidden_states=None, ++ encoder_attention_mask=encoder_extended_attention_mask, ++ encoder_decoder_position_bias=encoder_decoder_position_bias, ++ layer_head_mask=layer_head_mask, ++ cross_attn_layer_head_mask=cross_attn_layer_head_mask, ++ past_key=past_key, ++ past_value=past_value, ++ past_cross_key=past_cross_key, ++ past_cross_value=past_cross_value, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ ) ++ ++ # layer_outputs is a tuple with: ++ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) ++ if use_cache is False: ++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:] ++ ++ hidden_states, present_key_state, present_value_state = layer_outputs[:3] ++ ++ # We share the position biases between the layers - the first layer store them ++ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), ++ # (cross-attention position bias), (cross-attention weights) ++ position_bias = layer_outputs[3] ++ # append next layer key value states ++ if use_cache: ++ present_key_states = present_key_states + present_key_state ++ present_value_states = present_value_states + present_value_state ++ ++ if output_attentions: ++ all_attentions = all_attentions + (layer_outputs[3],) ++ if self.is_decoder: ++ all_cross_attentions = all_cross_attentions + (layer_outputs[5],) ++ ++ hidden_states = self.final_layer_norm(hidden_states) 
++ hidden_states = self.dropout(hidden_states).half() ++ ++ # Add last layer ++ if output_hidden_states: ++ all_hidden_states = all_hidden_states + (hidden_states,) ++ ++ if self.encodecrosskey: ++ cross_keys = self.encodecrosskey(hidden_states) ++ if self.encodecrossvalue: ++ cross_values = self.encodecrossvalue(hidden_states) ++ return tuple((hidden_states, cross_keys, cross_values)) + + + T5_START_DOCSTRING = r""" +@@ -1541,6 +1754,41 @@ class T5Model(T5PreTrainedModel): + ) + + ++class EncoderToCrossKey(nn.Module): ++ def __init__(self, cross_key, num_heads, d_kv): ++ super().__init__() ++ self.cross_key = cross_key ++ self.num_heads = num_heads ++ self.d_kv = d_kv ++ ++ ++ def forward(self, hidden_states): ++ batch_size = hidden_states.shape[0] ++ past_cross_keys = () ++ for i in range(len(self.cross_key)): ++ # past_cross_keys += (self.cross_key[i](hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1,2),) ++ past_cross_keys += (self.cross_key[i](hidden_states),) ++ return past_cross_keys ++ ++ ++class EncoderToCrossValue(nn.Module): ++ def __init__(self, cross_value, num_heads, d_kv): ++ super().__init__() ++ self.cross_value = cross_value ++ self.num_heads = num_heads ++ self.d_kv = d_kv ++ ++ ++ def forward(self, hidden_states): ++ batch_size = hidden_states.shape[0] ++ past_cross_values = () ++ for i in range(len(self.cross_value)): ++ # past_cross_values += (self.cross_value[i](hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1,2),) ++ past_cross_values += (self.cross_value[i](hidden_states),) ++ # print("aaa",past_cross_values[0].shape) ++ return past_cross_values ++ ++ + @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) + class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [ +@@ -1548,28 +1796,51 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + ] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + +- def __init__(self, config: T5Config): ++ def __init__(self, config: T5Config, encoder_path=None, decoder_path=None, device_id=0): + super().__init__(config) +- self.model_dim = config.d_model +- +- self.shared = nn.Embedding(config.vocab_size, config.d_model) +- +- encoder_config = copy.deepcopy(config) +- encoder_config.is_decoder = False +- encoder_config.use_cache = False +- encoder_config.is_encoder_decoder = False +- self.encoder = T5Stack(encoder_config, self.shared) +- +- decoder_config = copy.deepcopy(config) +- decoder_config.is_decoder = True +- decoder_config.is_encoder_decoder = False +- decoder_config.num_layers = config.num_decoder_layers +- self.decoder = T5Stack(decoder_config, self.shared) +- +- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) ++ self.encoder_path = encoder_path ++ self.decoder_path = decoder_path ++ self.is_mindie = False ++ if not self.encoder_path or not self.decoder_path: ++ self.model_dim = config.d_model ++ ++ self.shared = nn.Embedding(config.vocab_size, config.d_model) ++ ++ decoder_config = copy.deepcopy(config) ++ decoder_config.is_decoder = True ++ decoder_config.is_encoder_decoder = False ++ decoder_config.num_layers = config.num_decoder_layers ++ ++ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) ++ self.decoder = T5Stack(decoder_config, self.shared, self.lm_head) ++ ++ cross_key = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.k for i in range(config.num_decoder_layers)) ++ 
cross_value = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.v for i in range(config.num_decoder_layers)) ++ encodecrosskey = EncoderToCrossKey(cross_key, config.num_heads, config.d_kv) ++ encodecrossvalue = EncoderToCrossValue(cross_value, config.num_heads, config.d_kv) ++ ++ encoder_config = copy.deepcopy(config) ++ encoder_config.is_decoder = False ++ encoder_config.use_cache = False ++ encoder_config.is_encoder_decoder = False ++ self.encoder = T5Stack_Encoder(encoder_config, self.shared, encodecrosskey=encodecrosskey, encodecrossvalue=encodecrossvalue) ++ self.encoder_mindie = None ++ self.decoder_mindie = None ++ if self.encoder_path: ++ self.encoder_mindie = torch.jit.load(self.encoder_path) ++ self.is_mindie = True ++ if self.decoder_path: ++ self.decoder_mindie = torch.jit.load(self.decoder_path) ++ ++ self.stream = torch.npu.Stream(f"npu:{device_id}") ++ self.device_id = device_id ++ ++ ++ def get_device(self): ++ return f"npu:{self.device_id}" + + # Initialize weights and apply final processing +- self.post_init() ++ # self.post_init() + + # Model parallel + self.model_parallel = False +@@ -1637,25 +1908,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) +- def forward( +- self, +- input_ids: Optional[torch.LongTensor] = None, +- attention_mask: Optional[torch.FloatTensor] = None, +- decoder_input_ids: Optional[torch.LongTensor] = None, +- decoder_attention_mask: Optional[torch.BoolTensor] = None, +- head_mask: Optional[torch.FloatTensor] = None, +- decoder_head_mask: Optional[torch.FloatTensor] = None, +- cross_attn_head_mask: Optional[torch.Tensor] = None, +- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, +- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, +- inputs_embeds: Optional[torch.FloatTensor] = None, +- decoder_inputs_embeds: Optional[torch.FloatTensor] = None, +- labels: Optional[torch.LongTensor] = None, +- use_cache: Optional[bool] = None, +- output_attentions: Optional[bool] = None, +- output_hidden_states: Optional[bool] = None, +- return_dict: Optional[bool] = None, +- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ++ def forward(self,*args) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., +@@ -1687,113 +1940,36 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. 
+ ```""" +- use_cache = use_cache if use_cache is not None else self.config.use_cache +- return_dict = return_dict if return_dict is not None else self.config.use_return_dict +- +- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +- if head_mask is not None and decoder_head_mask is None: +- if self.config.num_layers == self.config.num_decoder_layers: +- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) +- decoder_head_mask = head_mask +- +- # Encode if needed (training, first prediction pass) +- if encoder_outputs is None: +- # Convert encoder inputs in embeddings if needed +- encoder_outputs = self.encoder( +- input_ids=input_ids, +- attention_mask=attention_mask, +- inputs_embeds=inputs_embeds, +- head_mask=head_mask, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) +- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): +- encoder_outputs = BaseModelOutput( +- last_hidden_state=encoder_outputs[0], +- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, +- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, +- ) +- +- hidden_states = encoder_outputs[0] +- +- if self.model_parallel: +- torch.cuda.set_device(self.decoder.first_device) +- +- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: +- # get decoder inputs from shifting lm labels to the right +- decoder_input_ids = self._shift_right(labels) +- +- # Set device for model parallelism +- if self.model_parallel: +- torch.cuda.set_device(self.decoder.first_device) +- hidden_states = hidden_states.to(self.decoder.first_device) +- if decoder_input_ids is not None: +- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) +- if attention_mask is not None: +- attention_mask = attention_mask.to(self.decoder.first_device) +- if decoder_attention_mask is not None: +- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) +- +- # Decode +- decoder_outputs = self.decoder( +- input_ids=decoder_input_ids, +- attention_mask=decoder_attention_mask, +- inputs_embeds=decoder_inputs_embeds, +- past_key_values=past_key_values, +- encoder_hidden_states=hidden_states, +- encoder_attention_mask=attention_mask, +- head_mask=decoder_head_mask, +- cross_attn_head_mask=cross_attn_head_mask, +- use_cache=use_cache, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) +- +- sequence_output = decoder_outputs[0] +- +- # Set device for model parallelism +- if self.model_parallel: +- torch.cuda.set_device(self.encoder.first_device) +- self.lm_head = self.lm_head.to(self.encoder.first_device) +- sequence_output = sequence_output.to(self.lm_head.weight.device) +- +- if self.config.tie_word_embeddings: +- # Rescale output before projecting on vocab +- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 +- sequence_output = sequence_output * (self.model_dim**-0.5) +- +- lm_logits = self.lm_head(sequence_output) +- +- loss = None +- if labels is not None: +- loss_fct = CrossEntropyLoss(ignore_index=-100) +- # move labels to correct device to enable PP +- labels = labels.to(lm_logits.device) +- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) +- # TODO(thom): Add z_loss 
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 +- +- if not return_dict: +- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs +- return ((loss,) + output) if loss is not None else output +- +- return Seq2SeqLMOutput( +- loss=loss, +- logits=lm_logits, +- past_key_values=decoder_outputs.past_key_values, +- decoder_hidden_states=decoder_outputs.hidden_states, +- decoder_attentions=decoder_outputs.attentions, +- cross_attentions=decoder_outputs.cross_attentions, +- encoder_last_hidden_state=encoder_outputs.last_hidden_state, +- encoder_hidden_states=encoder_outputs.hidden_states, +- encoder_attentions=encoder_outputs.attentions, +- ) ++ if self.is_mindie: ++ with torch.npu.stream(self.stream): # set stream ++ decoder_outputs = self.decoder_mindie.forward(*args) ++ self.stream.synchronize() # synchronize ++ else: ++ hidden_states = args[0] ++ past_cross_keys = args[1:self.config.num_decoder_layers+1] ++ past_cross_values = args[self.config.num_decoder_layers+1:2*self.config.num_decoder_layers+1] ++ past_keys= args[2*self.config.num_decoder_layers+1:3*self.config.num_decoder_layers+1] ++ past_values= args[3*self.config.num_decoder_layers+1:4*self.config.num_decoder_layers+1] ++ encoder_attention_mask = args[-3] ++ decoder_input_ids = args[-2] ++ decoder_attention_mask = args[-1] ++ decoder_outputs = self.decoder(input_ids=decoder_input_ids, ++ encoder_hidden_states=hidden_states, ++ past_keys=past_keys, ++ past_values=past_values, ++ past_cross_keys=past_cross_keys, ++ past_cross_values=past_cross_values, ++ encoder_attention_mask=encoder_attention_mask, ++ attention_mask=decoder_attention_mask) ++ return (decoder_outputs[0],decoder_outputs[1],decoder_outputs[2]) + + def prepare_inputs_for_generation( + self, + input_ids, +- past_key_values=None, ++ past_cross_keys=None, ++ past_cross_values=None, ++ past_keys=None, ++ past_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, +@@ -1804,8 +1980,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + **kwargs, + ): + # cut decoder_input_ids if past_key_values is used +- if past_key_values is not None: +- past_length = past_key_values[0][0].shape[2] ++ if past_keys is not None: ++ past_length = past_keys[0].shape[1] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: +@@ -1813,12 +1989,19 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 +- + input_ids = input_ids[:, remove_prefix_length:] + ++ batch_size, seq_length = input_ids.shape ++ # required mask seq length can be calculated via length of past ++ mask_seq_length = past_keys[0].shape[1] + seq_length if past_keys is not None else seq_length ++ decoder_attention_mask = torch.zeros(batch_size, mask_seq_length, device=input_ids.device) ++ decoder_attention_mask = decoder_attention_mask[:,None,None,:].expand(batch_size,1,mask_seq_length,mask_seq_length).bool() + return { + "decoder_input_ids": input_ids, +- "past_key_values": past_key_values, ++ "past_cross_keys":past_cross_keys, ++ "past_cross_values":past_cross_values, ++ "past_keys":past_keys, ++ "past_values":past_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, +@@ -1826,6 +2009,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + "decoder_attention_mask": decoder_attention_mask, + "cross_attn_head_mask": 
cross_attn_head_mask, + "use_cache": use_cache, ++ + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): +@@ -1861,6 +2045,440 @@ class T5ForConditionalGeneration(T5PreTrainedModel): + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + ++ def _prepare_encoder_decoder_kwargs_for_generation( ++ self, ++ inputs_tensor: torch.Tensor, ++ model_kwargs, ++ model_input_name, ++ generation_config, ++ ): ++ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] ++ encoder_kwargs = { ++ argument: value ++ for argument, value in model_kwargs.items() ++ if not any(argument.startswith(p) for p in irrelevant_prefix) ++ } ++ encoder_kwargs["output_attentions"] = generation_config.output_attentions ++ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states ++ model_input_name = model_input_name if model_input_name is not None else self.main_input_name ++ encoder_kwargs["return_dict"] = True ++ encoder_kwargs[model_input_name] = inputs_tensor ++ encoder_outputs = None ++ if self.is_mindie: ++ with torch.npu.stream(self.stream): # set stream ++ encoder_outputs=self.encoder_mindie.forward(encoder_kwargs["input_ids"],encoder_kwargs["attention_mask"]) ++ self.stream.synchronize() # synchronize ++ else: ++ encoder_outputs=self.encoder.forward(**encoder_kwargs) ++ model_kwargs["encoder_outputs"]={"last_hidden_state":encoder_outputs[0]} ++ model_kwargs["past_cross_keys"] = encoder_outputs[1] ++ model_kwargs["past_cross_values"] =encoder_outputs[2] ++ # print("model_kwargs=",model_kwargs) ++ return model_kwargs ++ ++ def _update_model_kwargs_for_generation( ++ self, ++ outputs, ++ model_kwargs, ++ is_encoder_decoder = False, ++ standardize_cache_format = False, ++ num_new_tokens = 1, ++ ): ++ # update past_key_values keeping its naming used in model code ++ cache_name, cache = self._extract_past_from_model_output( ++ outputs, standardize_cache_format=standardize_cache_format ++ ) ++ model_kwargs[cache_name] = cache ++ if "past_keys" in outputs: ++ past_keys = outputs.past_keys ++ model_kwargs["past_keys"] = past_keys ++ if "past_values" in outputs: ++ past_values = outputs.past_values ++ model_kwargs["past_values"] = past_values ++ # update decoder attention mask ++ if "decoder_attention_mask" in model_kwargs: ++ decoder_attention_mask = model_kwargs["decoder_attention_mask"] ++ model_kwargs["decoder_attention_mask"] = torch.cat( ++ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))], ++ dim=-1, ++ ) ++ return model_kwargs ++ ++ @torch.no_grad() ++ def generate( ++ self, ++ inputs = None, ++ generation_config = None, ++ logits_processor = None, ++ stopping_criteria = None, ++ prefix_allowed_tokens_fn = None, ++ assistant_model = None, ++ negative_prompt_ids = None, ++ negative_prompt_attention_mask = None, ++ **kwargs, ++ ): ++ # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call ++ self._validate_model_class() ++ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria ++ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) ++ self._validate_model_kwargs(model_kwargs.copy()) ++ ++ ++ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() ++ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() ++ ++ accepts_attention_mask = True ++ requires_attention_mask = "encoder_outputs" not in model_kwargs ++ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None ++ ++ # 3. Define model inputs ++ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( ++ inputs, generation_config.bos_token_id, model_kwargs ++ ) ++ batch_size = inputs_tensor.shape[0] ++ seq_len = inputs_tensor.shape[1] ++ device = inputs_tensor.device ++ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device) ++ ++ # 4. Define other model kwargs ++ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are ++ # generating the first new token or not, and we only want to use the embeddings for the first new token) ++ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": ++ model_kwargs["use_cache"] = True ++ else: ++ model_kwargs["use_cache"] = generation_config.use_cache ++ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: ++ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( ++ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id ++ ) ++ attention_mask = model_kwargs["attention_mask"] ++ attention_mask = attention_mask[:,None,None,:].expand(batch_size,1,seq_len,seq_len).bool() ++ model_kwargs["attention_mask"] = ~attention_mask ++ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: ++ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs` ++ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( ++ inputs_tensor, model_kwargs, model_input_name, generation_config ++ ) ++ ++ # 5. Prepare `input_ids` which will be used for auto-regressive generation ++ if self.config.is_encoder_decoder: ++ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( ++ batch_size=batch_size, ++ model_input_name=model_input_name, ++ model_kwargs=model_kwargs, ++ decoder_start_token_id=generation_config.decoder_start_token_id, ++ device=inputs_tensor.device, ++ ) ++ else: ++ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") ++ ++ if generation_config.token_healing: ++ input_ids = self.heal_tokens(input_ids, tokenizer) ++ ++ # 6. Prepare `max_length` depending on other stopping criteria. 
++ input_ids_length = input_ids.shape[-1] ++ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None ++ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None ++ generation_config = self._prepare_generated_length( ++ generation_config=generation_config, ++ has_default_max_length=has_default_max_length, ++ has_default_min_length=has_default_min_length, ++ model_input_name=model_input_name, ++ inputs_tensor=inputs_tensor, ++ input_ids_length=input_ids_length, ++ ) ++ ++ use_dynamic_cache_by_default = False ++ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: ++ raise ValueError( ++ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " ++ "Cache object) is unsupported. Please use only one of the two." ++ ) ++ elif generation_config.cache_implementation is not None: ++ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: ++ if generation_config.cache_implementation == "static" and not self._supports_static_cache: ++ raise ValueError( ++ "This model does not support `cache_implementation='static'`. Please check the following " ++ "issue: https://github.com/huggingface/transformers/issues/28981" ++ ) ++ model_kwargs["past_key_values"] = self._get_cache( ++ generation_config.cache_implementation, ++ getattr(generation_config, "num_beams", 1) * batch_size, ++ generation_config.max_length, ++ ) ++ elif generation_config.cache_implementation == "quantized": ++ if not self._supports_quantized_cache: ++ raise ValueError( ++ "This model does not support the quantized cache. If you want your model to support quantized " ++ "cache, please open an issue." ++ ) ++ ++ cache_config = ( ++ generation_config.cache_config ++ if generation_config.cache_config is not None ++ else QuantizedCacheConfig() ++ ) ++ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] ++ ++ if cache_config.backend == "quanto" and not is_quanto_available(): ++ raise ImportError( ++ "You need to install `quanto` in order to use KV cache quantization with quanto backend. " ++ "Please install it via with `pip install quanto`" ++ ) ++ elif cache_config.backend == "HQQ" and not is_hqq_available(): ++ raise ImportError( ++ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " ++ "Please install it via with `pip install hqq`" ++ ) ++ ++ model_kwargs["past_key_values"] = cache_class(cache_config) ++ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that ++ # keeps copying the cache thus using much more memory ++ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): ++ past = model_kwargs.get("past_key_values", None) ++ if past is None: ++ model_kwargs["past_key_values"] = DynamicCache() ++ use_dynamic_cache_by_default = True ++ elif isinstance(past, tuple): ++ model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past) ++ use_dynamic_cache_by_default = True ++ ++ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) ++ ++ # 7. determine generation mode ++ generation_mode = generation_config.get_generation_mode(assistant_model) ++ # 8. 
prepare distribution pre_processing samplers ++ prepared_logits_processor = self._get_logits_processor( ++ generation_config=generation_config, ++ input_ids_seq_length=input_ids_length, ++ encoder_input_ids=inputs_tensor, ++ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, ++ logits_processor=logits_processor, ++ device=inputs_tensor.device, ++ model_kwargs=model_kwargs, ++ negative_prompt_ids=negative_prompt_ids, ++ negative_prompt_attention_mask=negative_prompt_attention_mask, ++ ) ++ ++ # 9. prepare stopping criteria ++ prepared_stopping_criteria = self._get_stopping_criteria( ++ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ++ ) ++ ++ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): ++ # 11. prepare logits warper ++ prepared_logits_warper = ( ++ self._get_logits_warper(generation_config, device=input_ids.device) ++ if generation_config.do_sample ++ else None ++ ) ++ ++ # 12. expand input_ids with `num_return_sequences` additional sequences per batch ++ input_ids, model_kwargs = self._expand_inputs_for_generation( ++ input_ids=input_ids, ++ expand_size=generation_config.num_return_sequences, ++ is_encoder_decoder=self.config.is_encoder_decoder, ++ **model_kwargs, ++ ) ++ # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`) ++ result = self._sample( ++ input_ids, ++ logits_processor=prepared_logits_processor, ++ logits_warper=prepared_logits_warper, ++ stopping_criteria=prepared_stopping_criteria, ++ generation_config=generation_config, ++ **model_kwargs, ++ ) ++ return result ++ ++ def _sample( ++ self, ++ input_ids, ++ logits_processor, ++ stopping_criteria, ++ generation_config, ++ logits_warper = None, ++ **model_kwargs, ++ ): ++ # init values ++ pad_token_id = generation_config.pad_token_id ++ output_attentions = generation_config.output_attentions ++ output_hidden_states = generation_config.output_hidden_states ++ output_scores = generation_config.output_scores ++ output_logits = generation_config.output_logits ++ return_dict_in_generate = generation_config.return_dict_in_generate ++ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) ++ do_sample = generation_config.do_sample ++ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): ++ raise ValueError( ++ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is " ++ f"{logits_warper})." 
++ ) ++ ++ # init attention / hidden states / scores tuples ++ scores = () if (return_dict_in_generate and output_scores) else None ++ raw_logits = () if (return_dict_in_generate and output_logits) else None ++ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None ++ cross_attentions = () if (return_dict_in_generate and output_attentions) else None ++ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None ++ ++ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states ++ if return_dict_in_generate and self.config.is_encoder_decoder: ++ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None ++ encoder_hidden_states = ( ++ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ++ ) ++ ++ this_peer_finished = False ++ batch_size = input_ids.shape[0] ++ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) ++ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) ++ ++ # keep track of which sequences are already finished ++ if self.is_mindie or self.config.architectures[0]=="T5ForConditionalGeneration": ++ num_layers = self.config.num_layers ++ num_heads = self.config.num_heads ++ d_kv = self.config.d_kv ++ model_kwargs["past_keys"] = [torch.randn(batch_size, 0, num_heads*d_kv).half().npu() for _ in range(num_layers)] ++ model_kwargs["past_values"] = [torch.randn(batch_size, 0, num_heads*d_kv).half().npu() for _ in range(num_layers)] ++ ++ ++ while self._has_unfinished_sequences(this_peer_finished, False, device=input_ids.device): ++ # prepare model inputs ++ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) ++ model_args = [model_kwargs["encoder_outputs"]["last_hidden_state"]] ++ model_args.extend(model_kwargs["past_cross_keys"]) ++ model_args.extend(model_kwargs["past_cross_values"]) ++ model_args.extend(model_inputs["past_keys"]) ++ model_args.extend(model_inputs["past_values"]) ++ model_args.append(model_inputs["attention_mask"]) ++ model_args.append(model_inputs["decoder_input_ids"]) ++ model_args.append(model_inputs["decoder_attention_mask"]) ++ ++ # forward pass to get next token ++ outputs = self(*model_args) ++ outputs = Seq2SeqLMOutput(logits=outputs[0], ++ past_keys=outputs[1], ++ past_values=outputs[2]) ++ ++ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration ++ # (the clone itself is always small) ++ next_token_logits = outputs.logits[:, -1, :].clone() ++ ++ # pre-process distribution ++ next_token_scores = logits_processor(input_ids, next_token_logits) ++ if do_sample: ++ next_token_scores = logits_warper(input_ids, next_token_scores) ++ ++ # Store scores, attentions and hidden_states when required ++ if return_dict_in_generate: ++ if output_scores: ++ scores += (next_token_scores,) ++ if output_logits: ++ raw_logits += (next_token_logits,) ++ if output_attentions: ++ decoder_attentions += ( ++ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ++ ) ++ if self.config.is_encoder_decoder: ++ cross_attentions += (outputs.cross_attentions,) ++ ++ if output_hidden_states: ++ decoder_hidden_states += ( ++ (outputs.decoder_hidden_states,) ++ if self.config.is_encoder_decoder ++ else (outputs.hidden_states,) ++ ) ++ ++ # token selection ++ if do_sample: ++ probs = nn.functional.softmax(next_token_scores, dim=-1) ++ next_tokens = 
torch.multinomial(probs, num_samples=1).squeeze(1) ++ else: ++ next_tokens = torch.argmax(next_token_scores, dim=-1) ++ ++ # finished sentences should have their next token be a padding token ++ if has_eos_stopping_criteria: ++ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) ++ ++ # update generated ids, model inputs, and length for next step ++ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) ++ model_kwargs = self._update_model_kwargs_for_generation( ++ outputs, ++ model_kwargs, ++ is_encoder_decoder=self.config.is_encoder_decoder, ++ ) ++ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) ++ this_peer_finished = unfinished_sequences.max() == 0 ++ # This is needed to properly delete outputs.logits which may be very large for first iteration ++ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration ++ del outputs ++ return input_ids ++ ++ ++ @property ++ def device(self) -> torch.device: ++ """ ++ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same ++ device). ++ """ ++ return self.get_device() ++ ++ def get_extended_attention_mask( ++ self, attention_mask, input_shape, devic=None, dtype=None ++ ): ++ """ ++ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. ++ ++ Arguments: ++ attention_mask (`torch.Tensor`): ++ Mask with ones indicating tokens to attend to, zeros for tokens to ignore. ++ input_shape (`Tuple[int]`): ++ The shape of the input to the model. ++ ++ Returns: ++ `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. ++ """ ++ if dtype is None: ++ dtype = self.dtype ++ ++ if not (attention_mask.dim() == 2 and self.config.is_decoder): ++ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` ++ if device is not None: ++ warnings.warn( ++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning ++ ) ++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] ++ # ourselves in which case we just need to make it broadcastable to all heads. ++ if attention_mask.dim() == 3: ++ extended_attention_mask = attention_mask[:, None, :, :] ++ elif attention_mask.dim() == 2: ++ # Provided a padding mask of dimensions [batch_size, seq_length] ++ # - if the model is a decoder, apply a causal mask in addition to the padding mask ++ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] ++ if self.config.is_decoder: ++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( ++ input_shape, attention_mask, device ++ ) ++ else: ++ extended_attention_mask = attention_mask[:, None, None, :] ++ else: ++ raise ValueError( ++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" ++ ) ++ ++ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for ++ # masked positions, this operation will create a tensor which is 0.0 for ++ # positions we want to attend and the dtype's smallest value for masked positions. ++ # Since we are adding it to the raw scores before the softmax, this is ++ # effectively the same as removing these entirely. 
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility ++ #extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min ++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000 ++ return extended_attention_mask ++ ++ ++ + + @add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", +@@ -1878,6 +2496,9 @@ class T5EncoderModel(T5PreTrainedModel): + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) ++ self.decoder_mindie = torch.jit.load("encoder_model_path") ++ ++ self.stream = torch.npu.Stream(f"npu:{2}") + + # Initialize weights and apply final processing + self.post_init() +@@ -1966,17 +2587,21 @@ class T5EncoderModel(T5PreTrainedModel): + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" +- return_dict = return_dict if return_dict is not None else self.config.use_return_dict +- +- encoder_outputs = self.encoder( +- input_ids=input_ids, +- attention_mask=attention_mask, +- inputs_embeds=inputs_embeds, +- head_mask=head_mask, +- output_attentions=output_attentions, +- output_hidden_states=output_hidden_states, +- return_dict=return_dict, +- ) ++ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ # encoder_outputs = self.encoder( ++ # input_ids=input_ids, ++ # attention_mask=attention_mask, ++ # inputs_embeds=inputs_embeds, ++ # head_mask=head_mask, ++ # output_attentions=output_attentions, ++ # output_hidden_states=output_hidden_states, ++ # return_dict=return_dict, ++ # ) ++ attention_mask = attention_mask[:,None,None,:].expand(attention_mask.shape[0],1,attention_mask.shape[1],attention_mask.shape[1]).bool() ++ attention_mask = ~attention_mask ++ with torch.npu.stream(self.stream): # set stream ++ encoder_outputs = self.decoder_mindie.forward(input_ids,attention_mask) ++ self.stream.synchronize() # synchronize + + return encoder_outputs + diff --git a/MindIE/MindIE-Torch/built-in/T5/readme.md b/MindIE/MindIE-Torch/built-in/T5/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..a8cd519940cba37c301b7f3d1212fb7a7cb2ecfc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/T5/readme.md @@ -0,0 +1,162 @@ +# T5模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [模型推理](#section741711594517) + + + +# 概述 + + T5的全称为Text to Text Transfer Transformer,是谷歌提出的预训练语言模型领域的通用模型,该模型将所有自然语言问题都转化成文本到文本的形式,并用一个统一的模型解决.T5最核心的理念是:使用前缀任务声明及文本答案生成,统一所有自然语言处理任务的输入和输出。在此之前的几乎所有预训练语言模型,在下游任务微调过程中都需要添加非线性层,将模型的输出转化为任务指定的输出格式。T5不需要对模型做任何改动,只需要提供下游任务的微调数据;不需要添加任何非线性层,唯一需要做的就是在输入数据前加上任务声明前缀.T5将自然语言处理任务都转化成几乎一致的格式,即输入是带有任务前缀声明的文本序列,输出的文本序列是相应任务的结果 +权重下载:https://huggingface.co/collections/google/t5-release-65005e7c520f8d7b4d037918 + + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | input | batchsize x input_seq_len | FLOAT16 | NHWC | + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | output | batchsize x input_seq_len | INT32 | NTHWC | + + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 +- + | 配套 | 版本 | 备注 | + | ------------------------------------------------------------ |--------| ------------------------------------------------------------ | + | Python 
| 3.10.2 | - |
+  | torch                                                        | 2.1.0  | 导出pt模型所需版本                                             |
+  | torch_npu                                                    | 2.1.0  | 模型编译和推理所需版本                                         |
+
+
+# 快速上手
+
+
+1. 安装transformers 4.42.0版本。
+   ```bash
+   pip3 install transformers==4.42.0
+   ```
+
+2. 安装mindie包。mindie需要与torch_npu配合使用,请参考mindietorch与torch_npu的配套关系配置环境。
+
+   ```bash
+   # 安装mindie
+   chmod +x ./Ascend-mindie_xxx.run
+   ./Ascend-mindie_xxx.run --install
+   source /usr/local/Ascend/mindie/set_env.sh
+   ```
+
+3. 代码修改。在T5目录下执行命令:
+
+   ```bash
+   python T5_modeling_t5_patch.py --ascend_soc {Ascend910B4 or Ascend310P3}
+   ```
+
+4. 导出mindietorch模型
+
+300IDUO卡环境下:
+   ```bash
+   python export_t5.py --output_dir {output_path} --model_path {model_path} --max_batchsize {max_batchsize} --max_input_seq_len {max_input_seq_len} --device_id {device_id}
+   ```
+800IA2卡环境下:
+   ```bash
+   python export_t5_800IA2.py --output_dir {output_path} --model_path {model_path} --max_batchsize {max_batchsize} --max_input_seq_len {max_input_seq_len} --device_id {device_id}
+   ```
+参数说明:
+{output_path}:输出目录
+{model_path}:模型所在目录
+{max_batchsize}:推理过程中最大的batchsize
+{max_input_seq_len}:推理过程中最大的输入长度
+{device_id}:使用的npu device编号
+
+运行该命令后会自动生成优化后的encoder和decoder模型。
+
+5. 运行与性能测试
+导入环境变量:`export TORCH_AIE_NPU_CACHE_MAX_SIZE=32`
+```bash
+python main.py --hf_model_path {model_path} --encoder_aie_path {encoder_aie_path} --decoder_aie_path {decoder_aie_path} --device_id 2
+```
+性能测试:
+```bash
+python main.py --hf_model_path {model_path} --encoder_aie_path {encoder_aie_path} --decoder_aie_path {decoder_aie_path} --device_id 2 --performance
+```
+屏幕打印可以看到输入长度为512、输出长度为512时单batch下的吞吐。
+参数说明:
+{model_path}:模型所在目录
+{encoder_aie_path}:优化后的encoder模型路径,需具体到.pt文件
+{decoder_aie_path}:优化后的decoder模型路径,需具体到.pt文件
+{device_id}:使用的npu device编号
+
+6. 精度测试
+
+6.1 精度验收标准
+数据集:英文数据集任选一种测试,精度与GPU推理结果对比误差小于1%。
+
+6.2 精度测试方法
+
+6.2.1 安装mteb和sentence_transformers
+
+```bash
+pip install sentence_transformers==3.1.1
+pip install mteb
+```
+6.2.2 下载mteb数据集(如果机器可以连接外部网络可以跳过这步)
+下载链接:https://github.com/embeddings-benchmark/mteb
+
+6.2.3 修改mteb读取数据集的路径(如果机器可以连接外部网络可以跳过这步)
+例如,如果下载的是Banking77Classification数据集,则修改mteb python包里对应文件中的路径,例如
+D:\python3.9\Lib\site-packages\mteb\tasks\Classification\eng\Banking77Classification.py文件里的path路径,改为6.2.2下载的数据集的路径
+
+6.2.4 修改代码
+
+800IA2卡环境下:
+修改transformers包下modeling_t5.py中的T5EncoderModel类,将self.decoder_mindie的加载路径修改为编译好的encoder的路径。
+
+300IDUO卡环境下:
+修改transformers包下modeling_t5.py中的T5EncoderModel类,增加2行:
+```python
+self.encoder_mindie = torch.jit.load("encoder_model_path")
+self.stream = torch.npu.Stream(f"npu:{device_id}")
+```
+其中encoder_model_path为编译好的encoder的路径,device_id为当前设置的npu卡号;再将forward接口修改为:
+```python
+with torch.npu.stream(self.stream):  # set stream
+    encoder_outputs = self.encoder_mindie.forward(input_ids, attention_mask)
+self.stream.synchronize()  # synchronize
+return encoder_outputs
+```
+6.2.5 测试代码
+
+```python
+import torch
+import torch_npu
+import mindietorch
+import mteb
+from sentence_transformers import SentenceTransformer
+
+torch.npu.set_device(0)
+model_name = r"D:\downloads\T5-v2"
+model = SentenceTransformer(model_name, model_kwargs={"torch_dtype": torch.float16})
+tasks = mteb.get_tasks(tasks=["CLSClusteringP2P"])
+evaluation = mteb.MTEB(tasks=tasks)
+results = evaluation.run(model, output_folder=f"./{model_name}")
+```
+6.2.6 结果输出
+结果文件会输出到当前目录。
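+
+7. 端到端推理参考示例
+
+下面给出一个参考性的推理脚本草图,演示如何加载patch后的T5ForConditionalGeneration并配合编译好的encoder/decoder执行generate。其中模型目录、编译产物路径、device_id、max_length以及输入文本均为示意值(假设),并非仓库原文;实际的运行入口与参数解析请以仓库中的main.py为准。
+
+```python
+# 参考草图(非main.py原文):加载patch后的模型与mindie编译产物进行端到端推理
+import torch
+import torch_npu
+import mindietorch
+from transformers import AutoTokenizer, T5Config, T5ForConditionalGeneration
+
+device_id = 0                                                  # 示意值
+torch.npu.set_device(device_id)
+
+model_path = "./t5-small"                                      # HuggingFace模型目录(示意)
+encoder_aie_path = "./models/encoder/encoder_compiled.pt"      # 编译好的encoder(示意)
+decoder_aie_path = "./models/decoder/decoder_compiled.pt"      # 编译好的decoder(示意)
+
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+config = T5Config.from_pretrained(model_path)
+
+# patch后的构造函数接受encoder_path/decoder_path/device_id,
+# 两个路径都给出时会加载mindie编译模型并创建对应的npu stream
+model = T5ForConditionalGeneration(
+    config,
+    encoder_path=encoder_aie_path,
+    decoder_path=decoder_aie_path,
+    device_id=device_id,
+)
+
+text = "translate English to German: The house is wonderful."
+inputs = tokenizer(text, return_tensors="pt")
+outputs = model.generate(
+    inputs.input_ids.npu(),
+    attention_mask=inputs.attention_mask.npu(),
+    max_length=64,
+)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+
+该示例仅用于说明patch后接口的调用方式;性能与精度测试仍建议按上文步骤使用main.py与mteb流程。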