diff --git a/MindIE/MindIE-Torch/built-in/MT5/MT5_modeling_patch.py b/MindIE/MindIE-Torch/built-in/MT5/MT5_modeling_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a6ec861384f827a00ab99fb20ac0f24f8bc65d
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/MT5_modeling_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import transformers
+
+
+def main():
+ transformers_path = transformers.__path__
+ transformers_version = transformers.__version__
+
+    assert transformers_version == '4.42.0', "expected transformers==4.42.0"
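+    # modeling_mt5.patch was generated against transformers 4.42.0 (hence the pinned version above);
+    # `patch -p0` rewrites the installed modeling_mt5.py in place.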
+ os.system(f'patch -p0 {transformers_path[0]}/models/mt5/modeling_mt5.py modeling_mt5.patch')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/MT5/export_mt5.py b/MindIE/MindIE-Torch/built-in/MT5/export_mt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..138728fc16ba0360c6efa07508fc78f6731f5d9f
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/export_mt5.py
@@ -0,0 +1,192 @@
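+# Trace and compile the MT5 encoder and decoder with MindIE-Torch (mindietorch) for Ascend NPU.
+# Example invocation (paths are illustrative):
+#   python3 export_mt5.py --model_path ./MT5-Small --output_dir ./models \
+#       --max_batchsize 1 --max_input_seq_len 256 --device_id 0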
+import argparse
+import math
+import os
+
+import torch
+import torch_npu  # registers the Ascend NPU ("npu") device backend for torch
+import mindietorch
+from transformers import MT5ForConditionalGeneration
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./models",
+ help="save dir"
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="./MT5-Small",
+ help="T5 model path"
+ )
+ parser.add_argument(
+ "--max_batchsize",
+ type=int,
+ default=1,
+ help="max batchsize when running"
+ )
+
+ parser.add_argument(
+ "--max_input_seq_len",
+ type=int,
+ default=256,
+ help="max input_sequence length when running"
+ )
+
+
+ parser.add_argument(
+ "--device_id",
+ type=int,
+ default=0,
+ help="npu device id"
+ )
+ return parser.parse_args()
+
+
+class TextEncoderExport(torch.nn.Module):
+ def __init__(self, textencoder_model):
+ super(TextEncoderExport, self).__init__()
+ self.textencoder_model = textencoder_model
+
+ def forward(self, input_ids):
+ return self.textencoder_model(input_ids=input_ids)
+
+class TextDecoderExport(torch.nn.Module):
+ def __init__(self, textdecoder_model):
+ super(TextDecoderExport, self).__init__()
+ self.textdecoder_model = textdecoder_model
+
+    def forward(self, *args):
+        return self.textdecoder_model(*args)
+
+def export_textencoder(args, model, save_dir, batch_size):
+ encoder_path = os.path.join(save_dir, "encoder")
+ if not os.path.exists(encoder_path):
+        os.makedirs(encoder_path, mode=0o750)
+ traced_path = os.path.join(encoder_path, "encoder.pt")
+ compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+ if not os.path.exists(traced_path):
+ text_encoder = model.encoder
+        dummy_input = torch.ones([1, 128], dtype=torch.int64).npu()
+ encoder = TextEncoderExport(text_encoder)
+ encoder.eval()
+ torch.jit.trace(encoder, dummy_input, strict=False).save(traced_path)
+ if not os.path.exists(compiled_path):
+ traced_model = torch.jit.load(traced_path).eval()
+
+        inputs0 = [mindietorch.Input(min_shape=(1, 1),
+                                     max_shape=(args.max_batchsize, args.max_input_seq_len),
+                                     dtype=torch.int64)]
+ print("compiling encoder")
+ compiled_model = mindietorch.compile(
+ traced_model,
+ inputs=inputs0,
+ allow_tensor_replace_int=True,
+ require_full_compilation=False,
+ truncate_long_and_double=True,
+ precision_policy=mindietorch.PrecisionPolicy.FP16,
+ soc_version="Ascend910B4",
+ optimization_level=0
+ )
+ compiled_model.save(compiled_path)
+
+def export_textdecoder(args, model, save_dir, batch_size):
+ decoder_path = os.path.join(save_dir, "decoder")
+ if not os.path.exists(decoder_path):
+        os.makedirs(decoder_path, mode=0o750)
+ traced_path = os.path.join(decoder_path, "decoder.pt")
+ compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+ if not os.path.exists(traced_path):
+ text_decoder = model
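+        # The patched MT5 decoder (see modeling_mt5.patch) takes positional inputs in this order:
+        # encoder hidden states, num_layers cross-attention keys, num_layers cross-attention values,
+        # num_layers self-attention keys, num_layers self-attention values,
+        # encoder attention mask, decoder input ids.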
+ all_past_keys = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_values = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_cross_keys = [torch.randn([1, 16, model.config.num_heads * model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_cross_values = [torch.randn([1, 16, model.config.num_heads * model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+ dummy_input = [torch.randn(1, 16, model.config.d_model).to(torch.float16).npu()]
+ dummy_input.extend(all_past_cross_keys)
+ dummy_input.extend(all_past_cross_values)
+ dummy_input.extend(all_past_keys)
+ dummy_input.extend(all_past_values)
+        dummy_input.append(torch.ones(1, 16).npu())
+ dummy_input.append(torch.ones([1, 1], dtype=torch.int64).npu())
+ decoder = TextDecoderExport(text_decoder).npu()
+ decoder.eval()
+        torch.jit.trace(decoder, dummy_input, strict=False).save(traced_path)
+ if not os.path.exists(compiled_path):
+ traced_model = torch.jit.load(traced_path).eval()
+ print("compiling decoder")
+        input_info = [mindietorch.Input(min_shape=(1, 1, model.config.d_model),
+                                        max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+                                        dtype=mindietorch.dtype.FLOAT16)]
+        past_cross_key_infos = [mindietorch.Input(min_shape=(1, 1, model.config.d_model),
+                                                  max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+                                                  dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        past_cross_value_infos = [mindietorch.Input(min_shape=(1, 1, model.config.d_model),
+                                                    max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+                                                    dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        past_key_infos = [mindietorch.Input(min_shape=(1, model.config.num_heads, 0, model.config.d_kv),
+                                            max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv),
+                                            dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        past_value_infos = [mindietorch.Input(min_shape=(1, model.config.num_heads, 0, model.config.d_kv),
+                                              max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv),
+                                              dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+        decoder_input_ids_info = [mindietorch.Input(min_shape=(1, 1),
+                                                    max_shape=(args.max_batchsize, 1),
+                                                    dtype=mindietorch.dtype.INT64)]
+        encoder_attention_mask_info = [mindietorch.Input(min_shape=(1, 1),
+                                                         max_shape=(args.max_batchsize, args.max_input_seq_len),
+                                                         dtype=mindietorch.dtype.INT64)]
+ input_info.extend(past_cross_key_infos)
+ input_info.extend(past_cross_value_infos)
+ input_info.extend(past_key_infos)
+ input_info.extend(past_value_infos)
+ input_info.extend(encoder_attention_mask_info)
+ input_info.extend(decoder_input_ids_info)
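+        # Output buffer sizes in MB for the compiled decoder: 2 * num_layers FP16 self-attention
+        # key/value outputs (2 bytes per element) plus one FP32 logits buffer (4 bytes per element).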
+ buffer = []
+        for _ in range(2 * model.config.num_layers):
+ buffer.append(math.ceil((args.max_batchsize * args.max_input_seq_len * model.config.d_model * 2) / 1024 / 1024))
+ buffer_size0 = math.ceil((args.max_batchsize * 1 * model.config.vocab_size * 4) / 1024 / 1024)
+ buffer.append(buffer_size0)
+ print("buffer=",buffer)
+ compiled_model = mindietorch.compile(
+ traced_model,
+ inputs=input_info,
+ allow_tensor_replace_int=True,
+ require_full_compilation=False,
+ truncate_long_and_double=True,
+ precision_policy=mindietorch.PrecisionPolicy.FP16,
+ soc_version="Ascend910B4",
+ default_buffer_size_vec=buffer,
+ optimization_level=0
+ )
+ compiled_model.save(compiled_path)
+
+
+def main():
+ args = parse_arguments()
+ device_id = args.device_id
+ save_dir = args.output_dir
+ torch.npu.set_device(device_id)
+ batch_size = 1
+ model = MT5ForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.float).npu()
+ encoder_path = os.path.join(save_dir, "encoder")
+ compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+ if not os.path.exists(compiled_path):
+ export_textencoder(args, model, save_dir, batch_size)
+ print("export encoder_model done!")
+
+ decoder_path = os.path.join(save_dir, "decoder")
+ compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+ if not os.path.exists(compiled_path):
+ export_textdecoder(args, model, save_dir, batch_size)
+ print("export decoder_model done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/MT5/modeling_mt5.patch b/MindIE/MindIE-Torch/built-in/MT5/modeling_mt5.patch
new file mode 100644
index 0000000000000000000000000000000000000000..a5afef98e2a5de7d41dc57313605becde8835002
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/modeling_mt5.patch
@@ -0,0 +1,1557 @@
+diff --git a/modeling_mt5.py b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/mt5/modeling_mt5.py
+index 1336b9196..5b94d69c7 100644
+--- a/modeling_mt5.py
++++ b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/mt5/modeling_mt5.py
+@@ -19,22 +19,26 @@ import math
+ import os
+ import warnings
+ from typing import List, Optional, Tuple, Union
+-
++from dataclasses import dataclass
+ import torch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
++# mindietorch is required so that the compiled MindIE encoder/decoder modules can be loaded and executed
++import mindietorch
++
++
++
+
+ from ...activations import ACT2FN
+ from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+- Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ Seq2SeqQuestionAnsweringModelOutput,
+ Seq2SeqSequenceClassifierOutput,
+ TokenClassifierOutput,
+ )
+-from ...modeling_utils import PreTrainedModel
++from ...modeling_utils import PreTrainedModel, ModuleUtilsMixin
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+ from ...utils import (
+ DUMMY_INPUTS,
+@@ -47,8 +51,44 @@ from ...utils import (
+ )
+ from ...utils.model_parallel_utils import assert_device_map, get_device_map
+ from .configuration_mt5 import MT5Config
++from transformers.generation.logits_process import LogitsProcessorList
++from transformers.generation.stopping_criteria import StoppingCriteriaList
++from transformers.generation.configuration_utils import GenerationMode
++from transformers.utils.generic import ModelOutput
+
+
++@dataclass
++class Seq2SeqLMOutput(ModelOutput):
++ """
++ Base class for model's outputs, with potential hidden states and attentions.
++
++ Args:
++ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
++ Sequence of hidden-states at the output of the last layer of the model.
++ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
++ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
++ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
++
++ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
++ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
++ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
++ sequence_length)`.
++
++ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
++ heads.
++ """
++ loss: Optional[torch.FloatTensor] = None
++ logits: torch.FloatTensor = None
++ past_keys: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ past_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
++ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
++ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
++ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "MT5Config"
+@@ -323,7 +363,10 @@ class MT5Attention(nn.Module):
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
++ past_cross_key=None,
++ past_cross_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+@@ -339,17 +382,15 @@ class MT5Attention(nn.Module):
+
+ real_seq_length = seq_length
+
+- if past_key_value is not None:
+- if len(past_key_value) != 2:
+- raise ValueError(
+- f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+- )
+- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
++ if past_key is not None:
++ real_seq_length += past_key.shape[2] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+ def shape(states):
+ """projection"""
++            # reshape from (batch_size, seq_length, inner_dim)
++            # to (batch_size, n_heads, seq_length, dim_per_head)
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+ def unshape(states):
+@@ -368,16 +409,17 @@ class MT5Attention(nn.Module):
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
++ past_key_value = shape(past_key_value)
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, key_length, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+- elif past_key_value.shape[2] != key_value_states.shape[1]:
+- # checking that the `sequence_length` of the `past_key_value` is the same as
+- # the provided `key_value_states` to support prefix tuning
+- # cross-attn
+- # (batch_size, n_heads, seq_length, dim_per_head)
+- hidden_states = shape(proj_layer(key_value_states))
++ # elif past_key_value.shape[2] != key_value_states.shape[1]:
++ # # checking that the `sequence_length` of the `past_key_value` is the same as
++ # # the provided `key_value_states` to support prefix tuning
++ # # cross-attn
++ # # (batch_size, n_heads, seq_length, dim_per_head)
++ # hidden_states = shape(proj_layer(key_value_states))
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+@@ -388,10 +430,10 @@ class MT5Attention(nn.Module):
+
+ # get key/value states
+ key_states = project(
+- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
++ hidden_states, self.k, key_value_states, past_key if past_key is not None else None
+ )
+ value_states = project(
+- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
++ hidden_states, self.v, key_value_states, past_value if past_value is not None else None
+ )
+
+ # compute scores
+@@ -411,7 +453,7 @@ class MT5Attention(nn.Module):
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+- if past_key_value is not None:
++ if past_key is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+ if mask is not None:
+@@ -439,14 +481,124 @@ class MT5Attention(nn.Module):
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
+ attn_output = self.o(attn_output)
+
+- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+- outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+-
++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None
++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None
++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,)
++
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
++class MT5SelfAttention(MT5Attention):
++ def __init__(self, config: MT5Config, has_relative_attention_bias=False):
++ super().__init__(config, has_relative_attention_bias)
++
++ def forward(
++ self,
++ hidden_states,
++ mask=None,
++ position_bias=None,
++ past_key=None,
++ past_value=None,
++ layer_head_mask=None,
++ use_cache=False,
++ output_attentions=False,
++ ):
++ """
++        Self-attention with an optional cached key/value state passed via `past_key` / `past_value`.
++ """
++ # Input is (batch_size, seq_length, dim)
++ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
++        # past_key is (batch_size, n_heads, q_len - 1, dim_per_head)
++ batch_size, seq_length = hidden_states.shape[:2]
++
++ real_seq_length = seq_length
++
++ if past_key is not None:
++ real_seq_length += past_key.shape[2]
++ key_length = real_seq_length
++ def shape(states):
++ """projection"""
++ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
++
++ def unshape(states):
++ """reshape"""
++ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
++
++ def project(hidden_states, proj_layer, past_key_value):
++ """projects hidden states correctly to key/query states"""
++ if past_key_value is None:
++ # cross-attn
++ # (batch_size, n_heads, seq_length, dim_per_head)
++ hidden_states = shape(proj_layer(hidden_states))
++
++ if past_key_value is not None:
++ hidden_states = shape(proj_layer(hidden_states))
++ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
++ return hidden_states
++
++ # get query states
++ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
++
++ # get key/value states
++ key_states = project(
++ hidden_states, self.k, past_key if past_key is not None else None
++ )
++ value_states = project(
++ hidden_states, self.v, past_value if past_value is not None else None
++ )
++ # compute scores
++ scores = torch.matmul(
++ query_states, key_states.transpose(3, 2)
++ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
++ if position_bias is None:
++ if not self.has_relative_attention_bias:
++ position_bias = torch.zeros(
++ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
++ )
++ if self.gradient_checkpointing and self.training:
++ position_bias.requires_grad = True
++ else:
++ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
++
++ # if key and values are already calculated
++ # we want only the last query position bias
++ if past_key is not None:
++ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
++ if mask is not None:
++ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
++
++ if self.pruned_heads:
++ mask = torch.ones(position_bias.shape[1])
++ mask[list(self.pruned_heads)] = 0
++ position_bias_masked = position_bias[:, mask.bool()]
++ else:
++ position_bias_masked = position_bias
++ scores += position_bias_masked
++ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
++ scores
++ ) # (batch_size, n_heads, seq_length, key_length)
++ attn_weights = nn.functional.dropout(
++ attn_weights, p=self.dropout, training=self.training
++ ) # (batch_size, n_heads, seq_length, key_length)
++
++ # Mask heads if we want to
++ if layer_head_mask is not None:
++ attn_weights = attn_weights * layer_head_mask
++
++ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
++ attn_output = self.o(attn_output)
++
++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None
++ present_value_state = (value_states.half(), ) if (self.is_decoder and use_cache) else None
++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,)
++ if output_attentions:
++ outputs = outputs + (attn_weights,)
++ return outputs
++
+ # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5
+ class MT5LayerSelfAttention(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+@@ -461,7 +613,8 @@ class MT5LayerSelfAttention(nn.Module):
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+@@ -471,7 +624,8 @@ class MT5LayerSelfAttention(nn.Module):
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+@@ -495,7 +649,8 @@ class MT5LayerCrossAttention(nn.Module):
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+@@ -507,7 +662,8 @@ class MT5LayerCrossAttention(nn.Module):
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+@@ -539,39 +695,34 @@ class MT5Block(nn.Module):
+ encoder_decoder_position_bias=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
++ past_cross_key=None,
++ past_cross_value=None,
+ use_cache=False,
+ output_attentions=False,
+ return_dict=True,
+ ):
+- if past_key_value is not None:
+- if not self.is_decoder:
+- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
+- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+-
+- if len(past_key_value) != expected_num_past_key_values:
+- raise ValueError(
+- f"There should be {expected_num_past_key_values} past states. "
+- f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+- f"Got {len(past_key_value)} past key / value states"
+- )
+-
+- self_attn_past_key_value = past_key_value[:2]
+- cross_attn_past_key_value = past_key_value[2:]
++ if past_key is not None:
++ self_attn_past_key = past_key
++ self_attn_past_value = past_value
++ cross_attn_past_key = past_cross_key
++ cross_attn_past_value = past_cross_value
+ else:
+- self_attn_past_key_value, cross_attn_past_key_value = None, None
++ self_attn_past_key, self_attn_past_value, cross_attn_past_key, cross_attn_past_value = None, None, None, None
+
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=self_attn_past_key_value,
++ past_key=self_attn_past_key,
++ past_value=self_attn_past_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+- hidden_states, present_key_value_state = self_attention_outputs[:2]
+- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
++ hidden_states, present_key_state, present_value_state = self_attention_outputs[:3]
++ attention_outputs = self_attention_outputs[3:] # Keep self-attention outputs and relative position weights
+
+ # clamp inf values to enable fp16 training
+ if hidden_states.dtype == torch.float16:
+@@ -586,8 +737,8 @@ class MT5Block(nn.Module):
+ if do_cross_attention:
+ # the actual query length is unknown for cross attention
+ # if using past key value states. Need to inject it here
+- if present_key_value_state is not None:
+- query_length = present_key_value_state[0].shape[2]
++ if present_key_state is not None:
++ query_length = present_key_state[0].shape[2]
+ else:
+ query_length = None
+
+@@ -597,7 +748,8 @@ class MT5Block(nn.Module):
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ layer_head_mask=cross_attn_layer_head_mask,
+- past_key_value=cross_attn_past_key_value,
++ past_key=cross_attn_past_key,
++ past_value=cross_attn_past_value,
+ query_length=query_length,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+@@ -614,11 +766,9 @@ class MT5Block(nn.Module):
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ # Combine self attn and cross attn key value states
+- if present_key_value_state is not None:
+- present_key_value_state = present_key_value_state + cross_attention_outputs[1]
+-
++ # cross_attn_past_key_values = cross_attention_outputs[1]
+ # Keep cross-attention outputs and relative position weights
+- attention_outputs = attention_outputs + cross_attention_outputs[2:]
++ attention_outputs = attention_outputs + cross_attention_outputs[3:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states)
+@@ -635,7 +785,7 @@ class MT5Block(nn.Module):
+ outputs = (hidden_states,)
+
+ if use_cache:
+- outputs = outputs + (present_key_value_state,) + attention_outputs
++            outputs = outputs + (present_key_state,) + (present_value_state,) + attention_outputs
+ else:
+ outputs = outputs + attention_outputs
+
+@@ -884,11 +1034,15 @@ class MT5PreTrainedModel(PreTrainedModel):
+
+ # Copied from transformers.models.t5.modeling_t5.T5Stack with T5->MT5
+ class MT5Stack(MT5PreTrainedModel):
+- def __init__(self, config, embed_tokens=None):
++ def __init__(self, config, embed_tokens=None,lm_head=None, encodecrosskey=None, encodecrossvalue=None):
+ super().__init__(config)
+
+ self.embed_tokens = embed_tokens
+ self.is_decoder = config.is_decoder
++ self.lm_head=lm_head
++ self.encodecrosskey = encodecrosskey
++ self.encodecrossvalue = encodecrossvalue
++ self.model_dim = config.d_model
+
+ self.block = nn.ModuleList(
+ [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+@@ -953,20 +1107,63 @@ class MT5Stack(MT5PreTrainedModel):
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
++ def invert_attention_mask(self, encoder_attention_mask):
++ if encoder_attention_mask.dim() == 3:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
++ if encoder_attention_mask.dim() == 2:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
++
++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000
++
++ return encoder_extended_attention_mask
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
+ def forward(
+ self,
+ input_ids=None,
+- attention_mask=None,
+ encoder_hidden_states=None,
++ past_keys=None,
++ past_values=None,
++ past_cross_keys=None,
++ past_cross_values=None,
+ encoder_attention_mask=None,
++ attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+- past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
++ **model_kwargs
+ ):
+ # Model parallel
+ if self.model_parallel:
+@@ -985,8 +1182,10 @@ class MT5Stack(MT5PreTrainedModel):
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
++
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
++ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+@@ -999,18 +1198,19 @@ class MT5Stack(MT5PreTrainedModel):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+-
+ # required mask seq length can be calculated via length of past
+- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
++ mask_seq_length = past_keys[0].shape[2] + seq_length if past_keys is not None else seq_length
+
+ if use_cache is True:
+ if not self.is_decoder:
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+
+ # initialize past_key_values with `None` if past does not exist
+- if past_key_values is None:
+- past_key_values = [None] * len(self.block)
+-
++ if not self.is_decoder:
++ past_keys = [None] * len(self.block)
++ past_values = [None] * len(self.block)
++ past_cross_keys = [None] * len(self.block)
++ past_cross_values = [None] * len(self.block)
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+@@ -1041,7 +1241,8 @@ class MT5Stack(MT5PreTrainedModel):
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+- present_key_value_states = () if use_cache else None
++ present_key_states = () if use_cache else None
++ present_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+@@ -1049,8 +1250,8 @@ class MT5Stack(MT5PreTrainedModel):
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+-
+- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
++        # iterate over the layers together with their per-layer self- and cross-attention caches
++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+@@ -1099,7 +1300,10 @@ class MT5Stack(MT5PreTrainedModel):
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
++ past_cross_key=past_cross_key,
++ past_cross_value=past_cross_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+@@ -1107,19 +1311,20 @@ class MT5Stack(MT5PreTrainedModel):
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if use_cache is False:
+- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
++            layer_outputs = layer_outputs[:1] + (None,) + (None,) + layer_outputs[1:]
+
+- hidden_states, present_key_value_state = layer_outputs[:2]
++ hidden_states, present_key_state, present_value_state = layer_outputs[:3]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+- position_bias = layer_outputs[2]
++ position_bias = layer_outputs[3]
+ if self.is_decoder and encoder_hidden_states is not None:
+- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4]
+ # append next layer key value states
+ if use_cache:
+- present_key_value_states = present_key_value_states + (present_key_value_state,)
++ present_key_states = present_key_states + present_key_state
++ present_value_states = present_value_states + present_value_state
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+@@ -1133,7 +1338,7 @@ class MT5Stack(MT5PreTrainedModel):
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.final_layer_norm(hidden_states)
+- hidden_states = self.dropout(hidden_states)
++ hidden_states = self.dropout(hidden_states).half()
+
+ # Add last layer
+ if output_hidden_states:
+@@ -1151,13 +1356,216 @@ class MT5Stack(MT5PreTrainedModel):
+ ]
+ if v is not None
+ )
+- return BaseModelOutputWithPastAndCrossAttentions(
+- last_hidden_state=hidden_states,
+- past_key_values=present_key_value_states,
+- hidden_states=all_hidden_states,
+- attentions=all_attentions,
+- cross_attentions=all_cross_attentions,
++ if not self.is_decoder:
++ cross_keys = None
++ cross_values = None
++ if self.encodecrosskey:
++ cross_keys = self.encodecrosskey(hidden_states)
++ if self.encodecrossvalue:
++ cross_values = self.encodecrossvalue(hidden_states)
++ return tuple((hidden_states, cross_keys, cross_values))
++ lm_logits = None
++ if self.is_decoder:
++ if self.config.tie_word_embeddings:
++ hidden_states = hidden_states * (self.model_dim ** -0.5)
++ lm_logits = self.lm_head(hidden_states)
++ return tuple((lm_logits, present_key_states, present_value_states))
++
++
++class MT5Stack_Encoder(MT5PreTrainedModel):
++ def __init__(self, config, embed_tokens=None, encodecrosskey=None, encodecrossvalue=None):
++ super().__init__(config)
++ self.embed_tokens = embed_tokens
++ self.is_decoder = config.is_decoder
++ self.encodecrosskey = encodecrosskey
++ self.encodecrossvalue = encodecrossvalue
++ self.model_dim = config.d_model
++
++ self.block = nn.ModuleList(
++ [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+ )
++ self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
++ self.dropout = nn.Dropout(config.dropout_rate)
++
++ # Initialize weights and apply final processing
++ self.post_init()
++ # Model parallel
++ self.model_parallel = False
++ self.device_map = None
++ self.gradient_checkpointing = False
++
++ def get_input_embeddings(self):
++ return self.embed_tokens
++
++ def set_input_embeddings(self, new_embeddings):
++ self.embed_tokens = new_embeddings
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
++ def forward(
++ self,
++ input_ids=None,
++ attention_mask=None,
++ head_mask=None,
++ cross_attn_head_mask=None,
++ use_cache=None,
++ output_attentions=None,
++ output_hidden_states=None,
++ return_dict=None,
++ **model_kwargs
++ ):
++ # Model parallel
++ use_cache = use_cache if use_cache is not None else self.config.use_cache
++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
++ output_hidden_states = (
++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
++ )
++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
++
++ input_shape = input_ids.size()
++ input_ids = input_ids.view(-1, input_shape[-1])
++
++ inputs_embeds = self.embed_tokens(input_ids)
++
++ batch_size, seq_length = input_shape
++ # required mask seq length can be calculated via length of past
++ mask_seq_length = seq_length
++
++ # initialize past_key_values with `None` if past does not exist
++ past_keys = [None] * len(self.block)
++ past_values = [None] * len(self.block)
++ past_cross_keys = [None] * len(self.block)
++ past_cross_values = [None] * len(self.block)
++ if attention_mask is None:
++ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
++
++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
++ # ourselves in which case we just need to make it broadcastable to all heads.
++ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
++
++ # If a 2D or 3D attention mask is provided for the cross-attention
++ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
++
++ encoder_extended_attention_mask = None
++
++ # Prepare head mask if needed
++ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
++ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
++ present_key_states = () if use_cache else None
++ present_value_states = () if use_cache else None
++ all_hidden_states = () if output_hidden_states else None
++ all_attentions = () if output_attentions else None
++ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
++ position_bias = None
++ encoder_decoder_position_bias = None
++
++ hidden_states = self.dropout(inputs_embeds)
++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)):
++ layer_head_mask = head_mask[i]
++ cross_attn_layer_head_mask = cross_attn_head_mask[i]
++ if output_hidden_states:
++ all_hidden_states = all_hidden_states + (hidden_states,)
++
++ layer_outputs = layer_module(
++ hidden_states,
++ attention_mask=extended_attention_mask,
++ position_bias=position_bias,
++ encoder_hidden_states=None,
++ encoder_attention_mask=encoder_extended_attention_mask,
++ encoder_decoder_position_bias=encoder_decoder_position_bias,
++ layer_head_mask=layer_head_mask,
++ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
++ past_key=past_key,
++ past_value=past_value,
++ past_cross_key=past_cross_key,
++ past_cross_value=past_cross_value,
++ use_cache=use_cache,
++ output_attentions=output_attentions,
++ )
++
++ # layer_outputs is a tuple with:
++ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
++ if use_cache is False:
++                layer_outputs = layer_outputs[:1] + (None,) + (None,) + layer_outputs[1:]
++
++ hidden_states, present_key_state, present_value_state = layer_outputs[:3]
++
++ # We share the position biases between the layers - the first layer store them
++ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
++ # (cross-attention position bias), (cross-attention weights)
++ position_bias = layer_outputs[3]
++ if self.is_decoder and encoder_hidden_states is not None:
++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4]
++ # append next layer key value states
++ if use_cache:
++ present_key_states = present_key_states + present_key_state
++ present_value_states = present_value_states + present_value_state
++
++ if output_attentions:
++ all_attentions = all_attentions + (layer_outputs[3],)
++ if self.is_decoder:
++ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
++
++ # Model Parallel: If it's the last layer for that device, put things on the next device
++ if self.model_parallel:
++ for k, v in self.device_map.items():
++ if i == v[-1] and "cuda:" + str(k) != self.last_device:
++ hidden_states = hidden_states.to("cuda:" + str(k + 1))
++
++ hidden_states = self.final_layer_norm(hidden_states)
++ hidden_states = self.dropout(hidden_states).half()
++
++ # Add last layer
++ if output_hidden_states:
++ all_hidden_states = all_hidden_states + (hidden_states,)
++
++ if not return_dict:
++ return tuple(
++ v
++ for v in [
++ hidden_states,
++                    (present_key_states, present_value_states) if use_cache else None,
++ all_hidden_states,
++ all_attentions,
++ all_cross_attentions,
++ ]
++ if v is not None
++ )
++        # The encoder stack keeps no self-attention cache; it returns only hidden states plus the cross-attention keys/values below.
++ if not self.is_decoder:
++ cross_keys = None
++ cross_values = None
++ if self.encodecrosskey:
++ cross_keys = self.encodecrosskey(hidden_states)
++ if self.encodecrossvalue:
++ cross_values = self.encodecrossvalue(hidden_states)
++ return tuple((hidden_states, cross_keys, cross_values))
+
+
+ MT5_START_DOCSTRING = r"""
+@@ -1549,6 +1957,39 @@ class MT5Model(MT5PreTrainedModel):
+ )
+
+
++class EncoderToCrossKey(nn.Module):
++ def __init__(self, cross_key, num_heads, d_kv):
++ super().__init__()
++ self.cross_key = cross_key
++ self.num_heads = num_heads
++ self.d_kv = d_kv
++
++
++ def forward(self, hidden_states):
++ batch_size = hidden_states.shape[0]
++ past_cross_keys = ()
++ for i in range(len(self.cross_key)):
++ past_cross_keys += (self.cross_key[i](hidden_states),)
++        # past_cross_keys holds one tensor per decoder layer: the cross-attention key projection
++        # of the encoder output, shaped (batch_size, encoder_seq_len, num_heads * d_kv)
++ return past_cross_keys
++
++
++class EncoderToCrossValue(nn.Module):
++ def __init__(self, cross_value, num_heads, d_kv):
++ super().__init__()
++ self.cross_value = cross_value
++ self.num_heads = num_heads
++ self.d_kv = d_kv
++
++
++ def forward(self, hidden_states):
++ batch_size = hidden_states.shape[0]
++ past_cross_values = ()
++ for i in range(len(self.cross_value)):
++ past_cross_values += (self.cross_value[i](hidden_states),)
++ return past_cross_values
++
+ @add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING)
+ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+ r"""
+@@ -1573,33 +2014,52 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+ # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
+- def __init__(self, config: MT5Config):
++ def __init__(self, config: MT5Config, encoder_path=None, decoder_path=None, device_id=0):
+ super().__init__(config)
+- self.model_dim = config.d_model
+-
+- self.shared = nn.Embedding(config.vocab_size, config.d_model)
+-
+- encoder_config = copy.deepcopy(config)
+- encoder_config.is_decoder = False
+- encoder_config.use_cache = False
+- encoder_config.is_encoder_decoder = False
+- self.encoder = MT5Stack(encoder_config, self.shared)
+-
+- decoder_config = copy.deepcopy(config)
+- decoder_config.is_decoder = True
+- decoder_config.is_encoder_decoder = False
+- decoder_config.num_layers = config.num_decoder_layers
+- self.decoder = MT5Stack(decoder_config, self.shared)
+-
+- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
++ self.encoder_path = encoder_path
++ self.decoder_path = decoder_path
++ self.is_mindie = False
++ if not self.encoder_path or not self.decoder_path:
++ self.model_dim = config.d_model
++
++ self.shared = nn.Embedding(config.vocab_size, config.d_model)
++
++ decoder_config = copy.deepcopy(config)
++ decoder_config.is_decoder = True
++ decoder_config.is_encoder_decoder = False
++ decoder_config.num_layers = config.num_decoder_layers
++ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
++ self.decoder = MT5Stack(decoder_config, self.shared, self.lm_head)
++ cross_key = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.k for i in range(config.num_decoder_layers))
++ cross_value = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.v for i in range(config.num_decoder_layers))
++ encodecrosskey = EncoderToCrossKey(cross_key, config.num_heads, config.d_kv)
++ encodecrossvalue = EncoderToCrossValue(cross_value, config.num_heads, config.d_kv)
++ encoder_config = copy.deepcopy(config)
++ encoder_config.is_decoder = False
++ encoder_config.use_cache = False
++ encoder_config.is_encoder_decoder = False
++ self.encoder = MT5Stack_Encoder(encoder_config, self.shared, encodecrosskey=encodecrosskey, encodecrossvalue=encodecrossvalue)
++ self.encoder_mindie = None
++ self.decoder_mindie = None
++ if self.encoder_path:
++ self.encoder_mindie = torch.jit.load(self.encoder_path)
++ self.is_mindie = True
++ if self.decoder_path:
++ self.decoder_mindie = torch.jit.load(self.decoder_path)
++ self.stream = torch.npu.Stream(f"npu:{device_id}")
++ self.device_id = device_id
+
+ # Initialize weights and apply final processing
+- self.post_init()
++ if not self.is_mindie:
++ self.post_init()
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
++ def get_device(self):
++ return f"npu:{self.device_id}"
++
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
+ # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize
+ def parallelize(self, device_map=None):
+@@ -1666,25 +2126,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+ @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with T5->MT5, t5->mt5
+- def forward(
+- self,
+- input_ids: Optional[torch.LongTensor] = None,
+- attention_mask: Optional[torch.FloatTensor] = None,
+- decoder_input_ids: Optional[torch.LongTensor] = None,
+- decoder_attention_mask: Optional[torch.BoolTensor] = None,
+- head_mask: Optional[torch.FloatTensor] = None,
+- decoder_head_mask: Optional[torch.FloatTensor] = None,
+- cross_attn_head_mask: Optional[torch.Tensor] = None,
+- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+- inputs_embeds: Optional[torch.FloatTensor] = None,
+- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+- labels: Optional[torch.LongTensor] = None,
+- use_cache: Optional[bool] = None,
+- output_attentions: Optional[bool] = None,
+- output_hidden_states: Optional[bool] = None,
+- return_dict: Optional[bool] = None,
+- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
++ def forward(self,*args) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
+@@ -1716,114 +2158,37 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ >>> # studies have shown that owning a dog is good for you.
+ ```"""
+- use_cache = use_cache if use_cache is not None else self.config.use_cache
+- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+-
+- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+- if head_mask is not None and decoder_head_mask is None:
+- if self.config.num_layers == self.config.num_decoder_layers:
+- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+- decoder_head_mask = head_mask
+-
+- # Encode if needed (training, first prediction pass)
+- if encoder_outputs is None:
+- # Convert encoder inputs in embeddings if needed
+- encoder_outputs = self.encoder(
+- input_ids=input_ids,
+- attention_mask=attention_mask,
+- inputs_embeds=inputs_embeds,
+- head_mask=head_mask,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
+- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+- encoder_outputs = BaseModelOutput(
+- last_hidden_state=encoder_outputs[0],
+- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+- )
+-
+- hidden_states = encoder_outputs[0]
+-
+- if self.model_parallel:
+- torch.cuda.set_device(self.decoder.first_device)
+-
+- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+- # get decoder inputs from shifting lm labels to the right
+- decoder_input_ids = self._shift_right(labels)
+-
+- # Set device for model parallelism
+- if self.model_parallel:
+- torch.cuda.set_device(self.decoder.first_device)
+- hidden_states = hidden_states.to(self.decoder.first_device)
+- if decoder_input_ids is not None:
+- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+- if attention_mask is not None:
+- attention_mask = attention_mask.to(self.decoder.first_device)
+- if decoder_attention_mask is not None:
+- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+-
+- # Decode
+- decoder_outputs = self.decoder(
+- input_ids=decoder_input_ids,
+- attention_mask=decoder_attention_mask,
+- inputs_embeds=decoder_inputs_embeds,
+- past_key_values=past_key_values,
+- encoder_hidden_states=hidden_states,
+- encoder_attention_mask=attention_mask,
+- head_mask=decoder_head_mask,
+- cross_attn_head_mask=cross_attn_head_mask,
+- use_cache=use_cache,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
+-
+- sequence_output = decoder_outputs[0]
+-
+- # Set device for model parallelism
+- if self.model_parallel:
+- torch.cuda.set_device(self.encoder.first_device)
+- self.lm_head = self.lm_head.to(self.encoder.first_device)
+- sequence_output = sequence_output.to(self.lm_head.weight.device)
+-
+- if self.config.tie_word_embeddings:
+- # Rescale output before projecting on vocab
+- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+- sequence_output = sequence_output * (self.model_dim**-0.5)
+-
+- lm_logits = self.lm_head(sequence_output)
++ if self.is_mindie:
++ with torch.npu.stream(self.stream): # set stream
++ decoder_outputs = self.decoder_mindie.forward(*args)
++ self.stream.synchronize() # synchronize
++ else:
++ hidden_states = args[0]
++ past_cross_keys = args[1:self.config.num_decoder_layers+1]
++ past_cross_values = args[self.config.num_decoder_layers+1:2*self.config.num_decoder_layers+1]
++            past_keys = args[2*self.config.num_decoder_layers+1:3*self.config.num_decoder_layers+1]
++            past_values = args[3*self.config.num_decoder_layers+1:4*self.config.num_decoder_layers+1]
++ encoder_attention_mask = args[-2]
++ decoder_input_ids = args[-1]
++ decoder_outputs = self.decoder(input_ids=decoder_input_ids,
++ encoder_hidden_states=hidden_states,
++ past_keys=past_keys,
++ past_values=past_values,
++ past_cross_keys=past_cross_keys,
++ past_cross_values=past_cross_values,
++ encoder_attention_mask=encoder_attention_mask)
++
+
+ loss = None
+- if labels is not None:
+- loss_fct = CrossEntropyLoss(ignore_index=-100)
+- # move labels to correct device to enable PP
+- labels = labels.to(lm_logits.device)
+- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+-
+- if not return_dict:
+- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+- return ((loss,) + output) if loss is not None else output
+-
+- return Seq2SeqLMOutput(
+- loss=loss,
+- logits=lm_logits,
+- past_key_values=decoder_outputs.past_key_values,
+- decoder_hidden_states=decoder_outputs.hidden_states,
+- decoder_attentions=decoder_outputs.attentions,
+- cross_attentions=decoder_outputs.cross_attentions,
+- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+- encoder_hidden_states=encoder_outputs.hidden_states,
+- encoder_attentions=encoder_outputs.attentions,
+- )
++        return (decoder_outputs[0], decoder_outputs[1], decoder_outputs[2])
+
+- # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+- past_key_values=None,
++ past_cross_keys=None,
++ past_cross_values=None,
++ past_keys=None,
++ past_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+@@ -1834,8 +2199,8 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+- if past_key_values is not None:
+- past_length = past_key_values[0][0].shape[2]
++ if past_keys is not None:
++ past_length = past_keys[0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+@@ -1848,7 +2213,10 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+
+ return {
+ "decoder_input_ids": input_ids,
+- "past_key_values": past_key_values,
++ "past_cross_keys":past_cross_keys,
++ "past_cross_values":past_cross_values,
++ "past_keys":past_keys,
++ "past_values":past_values,
+ "encoder_outputs": encoder_outputs,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+@@ -1893,6 +2261,419 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+ return reordered_decoder_past
+
++ def _prepare_encoder_decoder_kwargs_for_generation(
++ self,
++ inputs_tensor: torch.Tensor,
++ model_kwargs,
++ model_input_name,
++ generation_config,
++ ):
++ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
++ encoder_kwargs = {
++ argument: value
++ for argument, value in model_kwargs.items()
++ if not any(argument.startswith(p) for p in irrelevant_prefix)
++ }
++ encoder_kwargs["output_attentions"] = generation_config.output_attentions
++ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
++ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
++ encoder_kwargs["return_dict"] = True
++ encoder_kwargs[model_input_name] = inputs_tensor
++        # run the compiled MindIE encoder on its own NPU stream and wait for completion
++        # before handing hidden states and cross-attention keys/values back to generation
++ with torch.npu.stream(self.stream): # set stream
++ encoder_outputs=self.encoder_mindie.forward(encoder_kwargs["input_ids"])
++ self.stream.synchronize() # synchronize
++ model_kwargs["encoder_outputs"]={"last_hidden_state":encoder_outputs[0]}
++ model_kwargs["past_cross_keys"] = encoder_outputs[1]
++ model_kwargs["past_cross_values"] =encoder_outputs[2]
++ return model_kwargs
++
++ def _update_model_kwargs_for_generation(
++ self,
++ outputs,
++ model_kwargs,
++ is_encoder_decoder = False,
++ standardize_cache_format = False,
++ num_new_tokens = 1,
++ ):
++ # update past_key_values keeping its naming used in model code
++ cache_name, cache = self._extract_past_from_model_output(
++ outputs, standardize_cache_format=standardize_cache_format
++ )
++ model_kwargs[cache_name] = cache
++ if "past_keys" in outputs:
++ past_keys = outputs.past_keys
++ model_kwargs["past_keys"] = past_keys
++ if "past_values" in outputs:
++ past_values = outputs.past_values
++ model_kwargs["past_values"] = past_values
++ # update decoder attention mask
++ if "decoder_attention_mask" in model_kwargs:
++ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
++ model_kwargs["decoder_attention_mask"] = torch.cat(
++ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
++ dim=-1,
++ )
++ return model_kwargs
++
++ @torch.no_grad()
++ def generate(
++ self,
++ inputs = None,
++ generation_config = None,
++ logits_processor = None,
++ stopping_criteria = None,
++ prefix_allowed_tokens_fn = None,
++ assistant_model = None,
++ negative_prompt_ids = None,
++ negative_prompt_attention_mask = None,
++ **kwargs,
++ ):
++ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
++ import time
++ start_time = time.time()
++ self._validate_model_class()
++ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
++ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
++ self._validate_model_kwargs(model_kwargs.copy())
++
++
++ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
++ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
++
++ accepts_attention_mask = True
++ requires_attention_mask = "encoder_outputs" not in model_kwargs
++ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
++
++ # 3. Define model inputs
++ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
++ inputs, generation_config.bos_token_id, model_kwargs
++ )
++ batch_size = inputs_tensor.shape[0]
++
++ device = inputs_tensor.device
++ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
++
++ # 4. Define other model kwargs
++ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
++ # generating the first new token or not, and we only want to use the embeddings for the first new token)
++ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
++ model_kwargs["use_cache"] = True
++ else:
++ model_kwargs["use_cache"] = generation_config.use_cache
++ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
++ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
++ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
++ )
++
++ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
++ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
++ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
++ inputs_tensor, model_kwargs, model_input_name, generation_config
++ )
++
++ # 5. Prepare `input_ids` which will be used for auto-regressive generation
++ if self.config.is_encoder_decoder:
++ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
++ batch_size=batch_size,
++ model_input_name=model_input_name,
++ model_kwargs=model_kwargs,
++ decoder_start_token_id=generation_config.decoder_start_token_id,
++ device=inputs_tensor.device,
++ )
++ else:
++ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
++
++ if generation_config.token_healing:
++ input_ids = self.heal_tokens(input_ids, tokenizer)
++
++ # 6. Prepare `max_length` depending on other stopping criteria.
++ input_ids_length = input_ids.shape[-1]
++ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
++ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
++ generation_config = self._prepare_generated_length(
++ generation_config=generation_config,
++ has_default_max_length=has_default_max_length,
++ has_default_min_length=has_default_min_length,
++ model_input_name=model_input_name,
++ inputs_tensor=inputs_tensor,
++ input_ids_length=input_ids_length,
++ )
++
++ use_dynamic_cache_by_default = False
++ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
++ raise ValueError(
++ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
++ "Cache object) is unsupported. Please use only one of the two."
++ )
++ elif generation_config.cache_implementation is not None:
++ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
++ if generation_config.cache_implementation == "static" and not self._supports_static_cache:
++ raise ValueError(
++ "This model does not support `cache_implementation='static'`. Please check the following "
++ "issue: https://github.com/huggingface/transformers/issues/28981"
++ )
++ model_kwargs["past_key_values"] = self._get_cache(
++ generation_config.cache_implementation,
++ getattr(generation_config, "num_beams", 1) * batch_size,
++ generation_config.max_length,
++ )
++ elif generation_config.cache_implementation == "quantized":
++ if not self._supports_quantized_cache:
++ raise ValueError(
++ "This model does not support the quantized cache. If you want your model to support quantized "
++ "cache, please open an issue."
++ )
++
++ cache_config = (
++ generation_config.cache_config
++ if generation_config.cache_config is not None
++ else QuantizedCacheConfig()
++ )
++ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
++
++ if cache_config.backend == "quanto" and not is_quanto_available():
++ raise ImportError(
++ "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
++ "Please install it via with `pip install quanto`"
++ )
++ elif cache_config.backend == "HQQ" and not is_hqq_available():
++ raise ImportError(
++ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
++ "Please install it via with `pip install hqq`"
++ )
++
++ model_kwargs["past_key_values"] = cache_class(cache_config)
++ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
++ # keeps copying the cache thus using much more memory
++ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
++ past = model_kwargs.get("past_key_values", None)
++ if past is None:
++ model_kwargs["past_key_values"] = DynamicCache()
++ use_dynamic_cache_by_default = True
++ elif isinstance(past, tuple):
++ model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
++ use_dynamic_cache_by_default = True
++
++ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
++
++ # 7. determine generation mode
++ generation_mode = generation_config.get_generation_mode(assistant_model)
++ # 8. prepare distribution pre_processing samplers
++ prepared_logits_processor = self._get_logits_processor(
++ generation_config=generation_config,
++ input_ids_seq_length=input_ids_length,
++ encoder_input_ids=inputs_tensor,
++ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
++ logits_processor=logits_processor,
++ device=inputs_tensor.device,
++ model_kwargs=model_kwargs,
++ negative_prompt_ids=negative_prompt_ids,
++ negative_prompt_attention_mask=negative_prompt_attention_mask,
++ )
++
++ # 9. prepare stopping criteria
++ prepared_stopping_criteria = self._get_stopping_criteria(
++ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
++ )
++
++ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
++ # 11. prepare logits warper
++ prepared_logits_warper = (
++ self._get_logits_warper(generation_config, device=input_ids.device)
++ if generation_config.do_sample
++ else None
++ )
++
++ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
++ input_ids, model_kwargs = self._expand_inputs_for_generation(
++ input_ids=input_ids,
++ expand_size=generation_config.num_return_sequences,
++ is_encoder_decoder=self.config.is_encoder_decoder,
++ **model_kwargs,
++ )
++ # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
++ result = self._sample(
++ input_ids,
++ logits_processor=prepared_logits_processor,
++ logits_warper=prepared_logits_warper,
++ stopping_criteria=prepared_stopping_criteria,
++ generation_config=generation_config,
++ **model_kwargs,
++ )
++ return result
++
++ def _sample(
++ self,
++ input_ids,
++ logits_processor,
++ stopping_criteria,
++ generation_config,
++ logits_warper = None,
++ **model_kwargs,
++ ):
++ # init values
++ pad_token_id = generation_config.pad_token_id
++ output_attentions = generation_config.output_attentions
++ output_hidden_states = generation_config.output_hidden_states
++ output_scores = generation_config.output_scores
++ output_logits = generation_config.output_logits
++ return_dict_in_generate = generation_config.return_dict_in_generate
++ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
++ do_sample = generation_config.do_sample
++ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
++ raise ValueError(
++ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
++ f"{logits_warper})."
++ )
++
++ # init attention / hidden states / scores tuples
++ scores = () if (return_dict_in_generate and output_scores) else None
++ raw_logits = () if (return_dict_in_generate and output_logits) else None
++ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
++ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
++ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
++
++ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
++ if return_dict_in_generate and self.config.is_encoder_decoder:
++ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
++ encoder_hidden_states = (
++ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
++ )
++
++ this_peer_finished = False
++ batch_size = input_ids.shape[0]
++ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
++ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
++
++ # keep track of which sequences are already finished
++ if self.is_mindie or self.config.architectures[0]=="MT5ForConditionalGeneration":
++ num_layers = self.config.num_layers
++ num_heads = self.config.num_heads
++ d_kv = self.config.d_kv
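++ # the self-attention key/value caches start empty (sequence length 0) so the compiled decoder can run its first step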
++ model_kwargs["past_keys"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)]
++ model_kwargs["past_values"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)]
++
++
++ while self._has_unfinished_sequences(this_peer_finished, False, device=input_ids.device):
++ # prepare model inputs
++ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
++ model_args = [model_kwargs["encoder_outputs"]["last_hidden_state"]]
++ model_args.extend(model_kwargs["past_cross_keys"])
++ model_args.extend(model_kwargs["past_cross_values"])
++ model_args.extend(model_inputs["past_keys"])
++ model_args.extend(model_inputs["past_values"])
++ model_args.append(model_inputs["attention_mask"])
++ model_args.append(model_inputs["decoder_input_ids"])
++
++ # forward pass to get next token
++ outputs = self(*model_args)
++ outputs = Seq2SeqLMOutput(logits=outputs[0],
++ past_keys=outputs[1],
++ past_values=outputs[2])
++
++ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
++ # (the clone itself is always small)
++ next_token_logits = outputs.logits[:, -1, :].clone()
++
++ # pre-process distribution
++ next_token_scores = logits_processor(input_ids, next_token_logits)
++ if do_sample:
++ next_token_scores = logits_warper(input_ids, next_token_scores)
++
++ # Store scores, attentions and hidden_states when required
++ if return_dict_in_generate:
++ if output_scores:
++ scores += (next_token_scores,)
++ if output_logits:
++ raw_logits += (next_token_logits,)
++ if output_attentions:
++ decoder_attentions += (
++ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
++ )
++ if self.config.is_encoder_decoder:
++ cross_attentions += (outputs.cross_attentions,)
++
++ if output_hidden_states:
++ decoder_hidden_states += (
++ (outputs.decoder_hidden_states,)
++ if self.config.is_encoder_decoder
++ else (outputs.hidden_states,)
++ )
++
++ # token selection
++ if do_sample:
++ probs = nn.functional.softmax(next_token_scores, dim=-1)
++ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
++ else:
++ next_tokens = torch.argmax(next_token_scores, dim=-1)
++
++ # finished sentences should have their next token be a padding token
++ if has_eos_stopping_criteria:
++ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
++
++ # update generated ids, model inputs, and length for next step
++ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
++ model_kwargs = self._update_model_kwargs_for_generation(
++ outputs,
++ model_kwargs,
++ is_encoder_decoder=self.config.is_encoder_decoder,
++ )
++ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
++ this_peer_finished = unfinished_sequences.max() == 0
++ # This is needed to properly delete outputs.logits which may be very large for first iteration
++ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
++ del outputs
++ return input_ids
++
++ def invert_attention_mask(self, encoder_attention_mask):
++ if encoder_attention_mask.dim() == 3:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
++ if encoder_attention_mask.dim() == 2:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
++
++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000
++
++ return encoder_extended_attention_mask
++
++ @property
++ def device(self) -> torch.device:
++ """
++ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
++ device).
++ """
++ return self.get_device()
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
+
+ @add_start_docstrings(
+ "The bare MT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
diff --git a/MindIE/MindIE-Torch/built-in/MT5/readme.md b/MindIE/MindIE-Torch/built-in/MT5/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..4862823913f94b04d6601e366d109b757012f524
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/readme.md
@@ -0,0 +1,96 @@
+# MT5 Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+ - [Input and Output Data](#section540883920406)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+ - [Model Inference](#section741711594517)
+
+
+
+# Overview
+
+ T5 (Text-to-Text Transfer Transformer) is a model architecture, or rather a paradigm for solving NLP tasks: every task, such as classification, similarity scoring, or text generation, is cast into a single text-to-text framework. MT5 is the multilingual variant of T5.
+ Weights download: https://huggingface.co/collections/google/mt5-release-65005f1a520f8d7b4d039509
+
+## Input and Output Data
+
+- Input data
+
+ | Input | Shape | Data Type | Data Layout |
+ | -------- | -------- | -------- | ------------ |
+ | input | batchsize x input_seq_len | INT64 | NHWC |
+
+
+- Output data
+
+ | Output | Shape | Data Type | Data Layout |
+ | -------- | -------- | -------- | ------------ |
+ | output | batchsize x output_seq_len | INT32 | NTHWC |
+
+
+# Inference Environment Setup
+
+- The model requires the following plugins and drivers.
+
+ **Table 1** Version compatibility
+
+ | Component | Version | Notes |
+ | ------------------------------------------------------------ |--------| ------------------------------------------------------------ |
+ | Python | 3.10.2 | - |
+ | torch | 2.1.0 | version required to export the pt model |
+ | torch_npu | 2.1.0 | version required for model compilation and inference |
+
+
+# Quick Start
+
+
+1. Install transformers 4.42.0.
+ ```bash
+ pip3 install transformers==4.42.0
+ ```
+
+2. Install the MindIE package. It must be used together with torch_npu; configure the environment according to the mindietorch and torch_npu version compatibility.
+
+ ```bash
+ # install MindIE
+ chmod +x ./Ascend-mindie_xxx.run
+ ./Ascend-mindie_xxx.run --install
+ source /usr/local/Ascend/mindie/set_env.sh
+ ```
+
+3. Modify the code (run this in the MT5 directory).
+
+ Execute the following command:
+ ```bash
+ python MT5_modeling_patch.py
+ ```
+4. Export the mindietorch models
+ ```bash
+ python export_mt5.py --output_dir {output_path} --model_path {model_path} --max_batchsize {max_batchsize} --max_input_seq_len {max_input_seq_len} --device_id {device_id}
+ ```
+Parameter description:
+- {output_path}: output directory
+- {model_path}: directory containing the model
+- {max_batchsize}: maximum batch size used during inference
+- {max_input_seq_len}: maximum input sequence length used during inference
+- {device_id}: which NPU device to use
+
+Running this command generates the optimized encoder and decoder models automatically.
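+
+The snippet below is a minimal usage sketch of the exported models (it is not one of the delivered scripts). It assumes the transformers patch from step 3 has been applied, so that `MT5ForConditionalGeneration` accepts the `encoder_path`, `decoder_path` and `device_id` arguments used in `test_mt5.py`; the model directory and compiled `.pt` paths are hypothetical examples that you should adapt to your own output directory.
+
+ ```python
+ import torch
+ import torch_npu
+ import mindietorch
+ from transformers import AutoTokenizer, MT5Config, MT5ForConditionalGeneration
+
+ torch.npu.set_device(0)
+ tokenizer = AutoTokenizer.from_pretrained("./MT5-Small")
+ config = MT5Config.from_pretrained("./MT5-Small")
+
+ # the patched constructor wraps the compiled MindIE encoder and decoder
+ model = MT5ForConditionalGeneration(
+     config=config,
+     encoder_path="./models/encoder/encoder_compiled.pt",
+     decoder_path="./models/decoder/decoder_compiled.pt",
+     device_id=0,
+ ).half().npu()
+
+ # input: token ids of shape batchsize x input_seq_len (int64)
+ text = ["MindIE-Torch accelerates MT5 inference on Ascend NPUs."]
+ input_ids = tokenizer(text, return_tensors="pt", padding=True).input_ids
+
+ # output: generated token ids
+ outputs = model.generate(input_ids.npu(), max_new_tokens=128)
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+ ```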
+
+5. Accuracy test
+
+ ```bash
+ python test_mt5.py --hf_model_path {model_path} --encoder_aie_path {encoder_aie_path} --decoder_aie_path {decoder_aie_path} --device_id {device_id}
+ ```
+
+Parameter description:
+- {model_path}: directory containing the model
+- {encoder_aie_path}: path to the optimized encoder model, down to the .pt file
+- {decoder_aie_path}: path to the optimized decoder model, down to the .pt file
+- {device_id}: which NPU device to use
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/MT5/test_mt5.py b/MindIE/MindIE-Torch/built-in/MT5/test_mt5.py
new file mode 100644
index 0000000000000000000000000000000000000000..92717df66f9e0df305d6a1fd9df3a71dcb9fb2f5
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/MT5/test_mt5.py
@@ -0,0 +1,48 @@
+import torch
+import time
+import argparse
+import torch_npu
+from transformers import MT5ForConditionalGeneration, AutoTokenizer, MT5Config
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--hf_model_path", type=str, required=True)
+
+ parser.add_argument("--encoder_aie_path", type=str, required=True)
+ parser.add_argument("--decoder_aie_path", type=str, required=True)
+
+ parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+ args = parser.parse_args()
+ return args
+
+def main():
+ args = parse_args()
+ torch.npu.set_device(args.device_id)
+ model = MT5ForConditionalGeneration.from_pretrained(args.hf_model_path, torch_dtype=torch.float16).npu()
+ encoder = model.encoder
+ decoder = model.decoder
+ encoder_input = torch.randint(0,2000,(8,10), dtype=torch.int64).npu()
+ t5_config = MT5Config.from_pretrained(args.hf_model_path)
+
+ encoder_output = encoder(encoder_input)[0]
+ model = MT5ForConditionalGeneration(config=t5_config,
+ encoder_path=args.encoder_aie_path,
+ decoder_path=args.decoder_aie_path,
+ device_id=args.device_id).half().npu()
+
+ encoder_mindie = model.encoder_mindie
+ decoder_mindie = model.decoder_mindie
+ mindie_stream = model.stream
+ with torch.npu.stream(mindie_stream): # set stream
+ mindie_encoder_output = encoder_mindie(encoder_input)[0]
+ mindie_stream.synchronize() # synchronize
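+ # accuracy check: cosine similarity between the eager NPU encoder output and the MindIE-compiled encoder output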
+ if (torch.cosine_similarity(encoder_output.cpu().flatten(), mindie_encoder_output.cpu().flatten(),dim=0)) < 0.99:
+ print("encoder precision failed")
+ else:
+ print("test OK")
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/MindIE/MindIE-Torch/built-in/T5/T5_modeling_t5_patch.py b/MindIE/MindIE-Torch/built-in/T5/T5_modeling_t5_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6733e6904cdc4acd5ef4589a380fee0f71c9447
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/T5_modeling_t5_patch.py
@@ -0,0 +1,40 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import transformers
+import argparse
+
+
+def main(args):
+ transformers_path = transformers.__path__
+ transformers_version = transformers.__version__
+
+ assert transformers_version == '4.42.0', "transformers==4.42.0 is required"
+ if args.ascend_soc == "Ascend910B4":
+ os.system(f'patch -p0 {transformers_path[0]}/models/t5/modeling_t5.py modeling_t5_800IA2.patch')
+ elif args.ascend_soc == "Ascend310P3":
+ os.system(f'patch -p0 {transformers_path[0]}/models/t5/modeling_t5.py modeling_t5.patch')
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ascend_soc", type=str, default="Ascend910B4",required=True)
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
diff --git a/MindIE/MindIE-Torch/built-in/T5/export_t5.py b/MindIE/MindIE-Torch/built-in/T5/export_t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c67b7c7efa4bc037fcf5622c8c4610944a5b088
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/export_t5.py
@@ -0,0 +1,194 @@
+
+import torch
+import torch_npu
+import argparse
+import os
+import math
+import mindietorch
+from transformers import T5ForConditionalGeneration
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./models",
+ help="save dir"
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="./T5-Small",
+ help="T5 model path"
+ )
+ parser.add_argument(
+ "--max_batchsize",
+ type=int,
+ default=1,
+ help="max batchsize when running"
+ )
+
+ parser.add_argument(
+ "--max_input_seq_len",
+ type=int,
+ default=256,
+ help="max input_sequence length when running"
+ )
+
+
+ parser.add_argument(
+ "--device_id",
+ type=int,
+ default=0,
+ help="npu device id"
+ )
+ return parser.parse_args()
+
+
+class TextEncoderExport(torch.nn.Module):
+ def __init__(self, textencoder_model):
+ super(TextEncoderExport, self).__init__()
+ self.textencoder_model = textencoder_model
+
+ def forward(self, input_ids,attention_mask):
+ return self.textencoder_model(input_ids=input_ids, attention_mask=attention_mask)
+
+class TextDecoderExport(torch.nn.Module):
+ def __init__(self, textdecoder_model):
+ super(TextDecoderExport, self).__init__()
+ self.textdecoder_model = textdecoder_model
+
+ def forward(self,
+ *args):
+ return self.textdecoder_model(*args)
+
+def export_textencoder(args, model, save_dir, batch_size):
+ encoder_path = os.path.join(save_dir, "encoder")
+ if not os.path.exists(encoder_path):
+ os.makedirs(encoder_path, mode=0o640)
+ traced_path = os.path.join(encoder_path, "encoder.pt")
+ compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+ if not os.path.exists(traced_path):
+ text_encoder = model.encoder
+ dummy_input = (
+ torch.ones([1, 128], dtype=torch.int64).npu(),
+ torch.ones([1, 128], dtype=torch.int64).npu()
+ )
+ encoder = TextEncoderExport(text_encoder)
+ encoder.eval()
+ torch.jit.trace(encoder, dummy_input, strict=False).save(traced_path)
+ if not os.path.exists(compiled_path):
+ traced_model = torch.jit.load(traced_path).eval()
+
+ inputs0 = []
+ inputs0.append(mindietorch.Input(min_shape = (1,1), max_shape= (args.max_batchsize, args.max_input_seq_len), dtype=torch.int64))
+ inputs0.append(mindietorch.Input(min_shape = (1,1), max_shape= (args.max_batchsize, args.max_input_seq_len), dtype=torch.int64))
+ print("compiling encoder")
+ compiled_model = mindietorch.compile(
+ traced_model,
+ inputs=inputs0,
+ allow_tensor_replace_int=True,
+ require_full_compilation=False,
+ truncate_long_and_double=True,
+ precision_policy=mindietorch.PrecisionPolicy.FP16,
+ soc_version="Ascend310P3",
+ optimization_level=0
+ )
+ compiled_model.save(compiled_path)
+
+def export_textdecoder(args, model, save_dir, batch_size):
+ decoder_path = os.path.join(save_dir, "decoder")
+ if not os.path.exists(decoder_path):
+ os.makedirs(decoder_path, mode=0o640)
+ traced_path = os.path.join(decoder_path, "decoder.pt")
+ compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+ if not os.path.exists(traced_path):
+ text_decoder = model
+ all_past_keys = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_values = [torch.randn([1, model.config.num_heads, 1, model.config.d_kv]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_cross_keys = [torch.randn([1, 16, model.config.d_model]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_cross_values = [torch.randn([1, 16, model.config.d_model]).to(torch.float16).npu()] * model.config.num_layers
+ dummy_input = [torch.randn(1, 16, model.config.d_model).to(torch.float16).npu()]
+ dummy_input.extend(all_past_cross_keys)
+ dummy_input.extend(all_past_cross_values)
+ dummy_input.extend(all_past_keys)
+ dummy_input.extend(all_past_values)
+ dummy_input.append(torch.ones(1,16).npu())
+ dummy_input.append(torch.ones([1, 1], dtype=torch.int64).npu())
+ decoder = TextDecoderExport(text_decoder).npu()
+ decoder.eval()
+ torch.jit.trace(decoder, dummy_input,strict=False).save(traced_path)
+ if not os.path.exists(compiled_path):
+ traced_model = torch.jit.load(traced_path).eval()
+ print("compiling decoder")
+ input_info = [mindietorch.Input(min_shape =(1, 1, model.config.d_model),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+ dtype=mindietorch.dtype.FLOAT16)]
+ past_cross_key_infos = [mindietorch.Input(min_shape =(1, 1, model.config.d_model),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ past_cross_value_infos = [mindietorch.Input(min_shape =(1, 1, model.config.d_model),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ past_key_infos = [mindietorch.Input(min_shape =(1, model.config.num_heads, 0, model.config.d_kv),
+ max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ past_value_infos = [mindietorch.Input(min_shape =(1, model.config.num_heads, 0, model.config.d_kv),
+ max_shape=(args.max_batchsize, model.config.num_heads, args.max_input_seq_len, model.config.d_kv),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ decoder_input_ids_info = [mindietorch.Input(min_shape =(1, 1),
+ max_shape = (args.max_batchsize,1),
+ dtype=mindietorch.dtype.INT64)]
+ encoder_attention_mask_info = [mindietorch.Input(min_shape =(1, 1),
+ max_shape = (args.max_batchsize,args.max_input_seq_len),
+ dtype=mindietorch.dtype.INT64)]
+ input_info.extend(past_cross_key_infos)
+ input_info.extend(past_cross_value_infos)
+ input_info.extend(past_key_infos)
+ input_info.extend(past_value_infos)
+ input_info.extend(encoder_attention_mask_info)
+ input_info.extend(decoder_input_ids_info)
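+ # Output buffer size hints (MiB) passed via default_buffer_size_vec: one fp16 buffer (2 bytes per element)
+ # of batch x max_input_seq_len x d_model for each present key and value (2 * num_layers in total),
+ # plus one fp32 buffer (4 bytes per element) of batch x 1 x vocab_size for the logits.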
+ buffer = []
+ for _ in range(2*model.config.num_layers):
+ buffer.append(math.ceil((args.max_batchsize * args.max_input_seq_len * model.config.d_model * 2) / 1024 / 1024))
+ buffer_size0 = math.ceil((args.max_batchsize * 1 * model.config.vocab_size * 4) / 1024 / 1024)
+ buffer.append(buffer_size0)
+ print("buffer=",buffer)
+ compiled_model = mindietorch.compile(
+ traced_model,
+ inputs=input_info,
+ allow_tensor_replace_int=True,
+ require_full_compilation=False,
+ truncate_long_and_double=True,
+ precision_policy=mindietorch.PrecisionPolicy.FP16,
+ soc_version="Ascend310P3",
+ default_buffer_size_vec=buffer,
+ optimization_level=0
+ )
+ compiled_model.save(compiled_path)
+
+
+def main():
+ args = parse_arguments()
+ device_id = args.device_id
+ save_dir = args.output_dir
+ torch.npu.set_device(device_id)
+ batch_size = 1
+ model = T5ForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.float).npu()
+ encoder_path = os.path.join(save_dir, "encoder")
+ compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+ if not os.path.exists(compiled_path):
+ export_textencoder(args, model, save_dir, batch_size)
+ print("export encoder_model done!")
+
+ decoder_path = os.path.join(save_dir, "decoder")
+ compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+ if not os.path.exists(compiled_path):
+ export_textdecoder(args, model, save_dir, batch_size)
+ print("export decoder_model done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/T5/export_t5_800IA2.py b/MindIE/MindIE-Torch/built-in/T5/export_t5_800IA2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e150e8e93a644539df463321f279fdd6929d7312
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/export_t5_800IA2.py
@@ -0,0 +1,202 @@
+
+import torch
+import torch_npu
+import argparse
+import os
+import math
+import mindietorch
+from transformers import T5ForConditionalGeneration
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./models",
+ help="save dir"
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default="./T5-Small",
+ help="T5 model path"
+ )
+ parser.add_argument(
+ "--max_batchsize",
+ type=int,
+ default=1,
+ help="max batchsize when running"
+ )
+
+ parser.add_argument(
+ "--max_input_seq_len",
+ type=int,
+ default=256,
+ help="max input_sequence length when running"
+ )
+
+
+ parser.add_argument(
+ "--device_id",
+ type=int,
+ default=0,
+ help="npu device id"
+ )
+ return parser.parse_args()
+
+
+class TextEncoderExport(torch.nn.Module):
+ def __init__(self, textencoder_model):
+ super(TextEncoderExport, self).__init__()
+ self.textencoder_model = textencoder_model
+
+ def forward(self, input_ids,attention_mask):
+ return self.textencoder_model(input_ids=input_ids,attention_mask=attention_mask)
+
+class TextDecoderExport(torch.nn.Module):
+ def __init__(self, textdecoder_model):
+ super(TextDecoderExport, self).__init__()
+ self.textdecoder_model = textdecoder_model
+
+ def forward(self,
+ *args):
+ return self.textdecoder_model(*args)
+
+def export_textencoder(args, model, save_dir, batch_size):
+ encoder_path = os.path.join(save_dir, "encoder")
+ if not os.path.exists(encoder_path):
+ os.makedirs(encoder_path, mode=0o640)
+ traced_path = os.path.join(encoder_path, "encoder.pt")
+ compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+ if not os.path.exists(traced_path):
+ text_encoder = model.encoder
+ dummy_input = (
+ torch.ones([1, 128], dtype=torch.int64).npu(),
+ torch.ones([1, 1,128,128], dtype=torch.bool).npu()
+ )
+ encoder = TextEncoderExport(text_encoder)
+ encoder.eval()
+ torch.jit.trace(encoder, dummy_input, strict=False).save(traced_path)
+ if not os.path.exists(compiled_path):
+ traced_model = torch.jit.load(traced_path).eval()
+
+ inputs0 = []
+ inputs0.append(mindietorch.Input(min_shape = (1,1), max_shape= (args.max_batchsize, args.max_input_seq_len), dtype=torch.int64))
+ inputs0.append(mindietorch.Input(min_shape = (1,1,1,1), max_shape= (args.max_batchsize, 1,args.max_input_seq_len,args.max_input_seq_len), dtype=torch.bool))
+ print("compiling encoder")
+ compiled_model = mindietorch.compile(
+ traced_model,
+ inputs=inputs0,
+ allow_tensor_replace_int=True,
+ require_full_compilation=False,
+ truncate_long_and_double=True,
+ precision_policy=mindietorch.PrecisionPolicy.FP16,
+ soc_version="Ascend910B4",
+ optimization_level=0
+ )
+ compiled_model.save(compiled_path)
+
+def export_textdecoder(args, model, save_dir, batch_size):
+ decoder_path = os.path.join(save_dir, "decoder")
+ if not os.path.exists(decoder_path):
+ os.makedirs(decoder_path, mode=0o640)
+ traced_path = os.path.join(decoder_path, "decoder.pt")
+ compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+ if not os.path.exists(traced_path):
+ text_decoder = model
+ all_past_keys = [torch.randn([1, 1, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_values = [torch.randn([1, 1, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_cross_keys = [torch.randn([1, 16, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers
+ all_past_cross_values = [torch.randn([1, 16, model.config.d_kv*model.config.num_heads]).to(torch.float16).npu()] * model.config.num_layers
+ dummy_input = [torch.randn(1, 16, model.config.d_kv*model.config.num_heads).to(torch.float16).npu()]
+ dummy_input.extend(all_past_cross_keys)
+ dummy_input.extend(all_past_cross_values)
+ dummy_input.extend(all_past_keys)
+ dummy_input.extend(all_past_values)
+ # encoder_attention_mask
+ dummy_input.append(torch.ones((1,1,16,16),dtype=torch.bool).npu())
+ # decoder_input_ids
+ dummy_input.append(torch.ones([1, 1], dtype=torch.int64).npu())
+ # decoder_attention_mask
+ dummy_input.append(torch.ones([1, 1, 1, 1], dtype=torch.bool).npu())
+ decoder = TextDecoderExport(text_decoder).npu()
+ decoder.eval()
+ torch.jit.trace(decoder, dummy_input,strict=False).save(traced_path)
+ if not os.path.exists(compiled_path):
+ traced_model = torch.jit.load(traced_path).eval()
+ print("compiling decoder")
+ input_info = [mindietorch.Input(min_shape =(1, 1, model.config.d_model),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_model),
+ dtype=mindietorch.dtype.FLOAT16)]
+ past_cross_key_infos = [mindietorch.Input(min_shape =(1, 1, model.config.num_heads*model.config.d_kv),
+ max_shape=(args.max_batchsize,args.max_input_seq_len, model.config.num_heads*model.config.d_kv),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ past_cross_value_infos = [mindietorch.Input(min_shape =(1, 1, model.config.d_kv*model.config.num_heads),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_kv*model.config.num_heads),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ past_key_infos = [mindietorch.Input(min_shape =(1, 0, model.config.d_kv*model.config.num_heads),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_kv*model.config.num_heads),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ past_value_infos = [mindietorch.Input(min_shape =(1, 0, model.config.d_kv*model.config.num_heads),
+ max_shape=(args.max_batchsize, args.max_input_seq_len, model.config.d_kv*model.config.num_heads),
+ dtype=mindietorch.dtype.FLOAT16)] * model.config.num_layers
+ decoder_input_ids_info = [mindietorch.Input(min_shape =(1, 1),
+ max_shape = (args.max_batchsize,1),
+ dtype=mindietorch.dtype.INT64)]
+ encoder_attention_mask_info = [mindietorch.Input(min_shape =(1, 1,1, 1),
+ max_shape = (args.max_batchsize, 1, args.max_input_seq_len,args.max_input_seq_len),
+ dtype=mindietorch.dtype.BOOL)]
+ decoder_attention_mask_info = [mindietorch.Input(min_shape =(1, 1,1,1),
+ max_shape = (args.max_batchsize,1,args.max_input_seq_len,args.max_input_seq_len),
+ dtype=mindietorch.dtype.BOOL)]
+ input_info.extend(past_cross_key_infos)
+ input_info.extend(past_cross_value_infos)
+ input_info.extend(past_key_infos)
+ input_info.extend(past_value_infos)
+ input_info.extend(encoder_attention_mask_info)
+ input_info.extend(decoder_input_ids_info)
+ input_info.extend(decoder_attention_mask_info)
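+ # Output buffer size hints (MiB) passed via default_buffer_size_vec: one fp16 buffer (2 bytes per element)
+ # of batch x max_input_seq_len x d_model for each present key and value (2 * num_layers in total),
+ # plus one fp32 buffer (4 bytes per element) of batch x 1 x vocab_size for the logits.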
+ buffer = []
+ for _ in range(2*model.config.num_layers):
+ buffer.append(math.ceil((args.max_batchsize * args.max_input_seq_len * model.config.d_model * 2) / 1024 / 1024))
+ buffer_size0 = math.ceil((args.max_batchsize * 1 * model.config.vocab_size * 4) / 1024 / 1024)
+ buffer.append(buffer_size0)
+ print("buffer=",buffer)
+ compiled_model = mindietorch.compile(
+ traced_model,
+ inputs=input_info,
+ allow_tensor_replace_int=True,
+ require_full_compilation=False,
+ truncate_long_and_double=True,
+ precision_policy=mindietorch.PrecisionPolicy.FP16,
+ soc_version="Ascend910B4",
+ default_buffer_size_vec=buffer,
+ optimization_level=0
+ )
+ compiled_model.save(compiled_path)
+
+
+def main():
+ args = parse_arguments()
+ device_id = args.device_id
+ save_dir = args.output_dir
+ torch.npu.set_device(device_id)
+ batch_size = 1
+ model = T5ForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.float).npu()
+ encoder_path = os.path.join(save_dir, "encoder")
+ compiled_path = os.path.join(encoder_path, "encoder_compiled.pt")
+ if not os.path.exists(compiled_path):
+ export_textencoder(args, model, save_dir, batch_size)
+ print("export encoder_model done!")
+
+ decoder_path = os.path.join(save_dir, "decoder")
+ compiled_path = os.path.join(decoder_path, "decoder_compiled.pt")
+ if not os.path.exists(compiled_path):
+ export_textdecoder(args, model, save_dir, batch_size)
+ print("export decoder_model done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/MindIE/MindIE-Torch/built-in/T5/main.py b/MindIE/MindIE-Torch/built-in/T5/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e20f1e05ed7c8458fcdebb882105277d445c809
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/main.py
@@ -0,0 +1,51 @@
+import torch
+import torch_npu
+import mindietorch
+import time
+import argparse
+from transformers import T5ForConditionalGeneration, AutoTokenizer, T5Config
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--hf_model_path", type=str, required=True)
+
+ parser.add_argument("--encoder_aie_path", type=str, required=True)
+ parser.add_argument("--decoder_aie_path", type=str, required=True)
+
+ parser.add_argument("--device_id", type=int, help="NPU device id", default=0)
+
+ parser.add_argument("--performance", action="store_true")
+
+ args = parser.parse_args()
+ return args
+
+def main():
+ args = parse_args()
+
+ torch.npu.set_device(args.device_id)
+ tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path)
+ text = ["今年2月26日,阿富汗塔里班的最高领秀下令销毁全国范围内所有“非伊斯兰“的古文化遗产,其中包括矗立于巴米扬的世高(大界最约58米)的立式佛像。"]
+ t5_config = T5Config.from_pretrained(args.hf_model_path)
+ # model = T5ForConditionalGeneration.from_pretrained(args.hf_model_path).half().npu()
+ model = T5ForConditionalGeneration(config=t5_config,
+ encoder_path=args.encoder_aie_path,
+ decoder_path=args.decoder_aie_path,
+ device_id=args.device_id).half().npu()
+ input_ids = tokenizer(text, return_tensors = "pt", padding=True).input_ids
+ if args.performance:
+ input_ids = torch.randint(0,32000,(1,512))
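+ # warm-up generate() call; the timing below measures only the second call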
+ outputs = model.generate(input_ids.npu(),max_new_tokens=512)
+ print("token length : ", input_ids.shape)
+ start_time = time.time()
+
+ outputs = model.generate(input_ids.npu(),max_new_tokens=512)
+ inference_time = time.time()-start_time
+ print("time_cost=", inference_time)
+ print("output token length : ", outputs[0].shape[0])
+ print("throught output is : ", outputs[0].shape[0] / inference_time)
+ if not args.performance:
+ print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/T5/modeling_t5.patch b/MindIE/MindIE-Torch/built-in/T5/modeling_t5.patch
new file mode 100644
index 0000000000000000000000000000000000000000..15f81df2a4a590bc64353f172b06f03a9371377f
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/modeling_t5.patch
@@ -0,0 +1,1635 @@
+diff --git a/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5_origin.py b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py
+index 224769f..24f868b 100644
+--- a/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5_origin.py
++++ b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py
+@@ -19,8 +19,10 @@ import math
+ import os
+ import warnings
+ from typing import List, Optional, Tuple, Union
+-
++from dataclasses import dataclass
+ import torch
++import torch_npu
++import mindietorch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+@@ -28,13 +30,12 @@ from ...activations import ACT2FN
+ from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+- Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ Seq2SeqQuestionAnsweringModelOutput,
+ Seq2SeqSequenceClassifierOutput,
+ TokenClassifierOutput,
+ )
+-from ...modeling_utils import PreTrainedModel
++from ...modeling_utils import PreTrainedModel,ModuleUtilsMixin
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
+ from ...utils import (
+ DUMMY_INPUTS,
+@@ -47,7 +48,43 @@ from ...utils import (
+ )
+ from ...utils.model_parallel_utils import assert_device_map, get_device_map
+ from .configuration_t5 import T5Config
++from transformers.generation.logits_process import LogitsProcessorList
++from transformers.generation.stopping_criteria import StoppingCriteriaList
++from transformers.generation.configuration_utils import GenerationMode
++from transformers.utils.generic import ModelOutput
++
++
++@dataclass
++class Seq2SeqLMOutput(ModelOutput):
++ """
++ Base class for model's outputs, with potential hidden states and attentions.
+
++ Args:
++ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
++ Sequence of hidden-states at the output of the last layer of the model.
++ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
++ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
++ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
++
++ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
++ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
++ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
++ sequence_length)`.
++
++ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
++ heads.
++ """
++ loss: Optional[torch.FloatTensor] = None
++ logits: torch.FloatTensor = None
++ past_keys: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ past_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
++ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
++ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
++ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+ logger = logging.get_logger(__name__)
+
+@@ -448,7 +485,10 @@ class T5Attention(nn.Module):
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
++ past_cross_key=None,
++ past_cross_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+@@ -464,12 +504,8 @@ class T5Attention(nn.Module):
+
+ real_seq_length = seq_length
+
+- if past_key_value is not None:
+- if len(past_key_value) != 2:
+- raise ValueError(
+- f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+- )
+- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
++ if past_key is not None:
++ real_seq_length += past_key.shape[2] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+@@ -493,16 +529,17 @@ class T5Attention(nn.Module):
+ hidden_states = shape(proj_layer(key_value_states))
+
+ if past_key_value is not None:
++ past_key_value = shape(past_key_value)
+ if key_value_states is None:
+ # self-attn
+ # (batch_size, n_heads, key_length, dim_per_head)
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+- elif past_key_value.shape[2] != key_value_states.shape[1]:
+- # checking that the `sequence_length` of the `past_key_value` is the same as
+- # the provided `key_value_states` to support prefix tuning
+- # cross-attn
+- # (batch_size, n_heads, seq_length, dim_per_head)
+- hidden_states = shape(proj_layer(key_value_states))
++ # elif past_key_value.shape[2] != key_value_states.shape[1]:
++ # # checking that the `sequence_length` of the `past_key_value` is the same as
++ # # the provided `key_value_states` to support prefix tuning
++ # # cross-attn
++ # # (batch_size, n_heads, seq_length, dim_per_head)
++ # hidden_states = shape(proj_layer(key_value_states))
+ else:
+ # cross-attn
+ hidden_states = past_key_value
+@@ -513,17 +550,16 @@ class T5Attention(nn.Module):
+
+ # get key/value states
+ key_states = project(
+- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
++ hidden_states, self.k, key_value_states, past_key if past_key is not None else None
+ )
+ value_states = project(
+- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
++ hidden_states, self.v, key_value_states, past_value if past_value is not None else None
+ )
+-
++ # torch.ops.mindie.flash_attention_plugin(query_states, key_states, value_states,)
+ # compute scores
+ scores = torch.matmul(
+ query_states, key_states.transpose(3, 2)
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+-
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+@@ -536,7 +572,7 @@ class T5Attention(nn.Module):
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+- if past_key_value is not None:
++ if past_key is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+ if mask is not None:
+@@ -548,7 +584,6 @@ class T5Attention(nn.Module):
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+-
+ scores += position_bias_masked
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+ scores
+@@ -564,18 +599,131 @@ class T5Attention(nn.Module):
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
+ attn_output = self.o(attn_output)
+
+- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+- outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None
++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None
++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,)
++
++ if output_attentions:
++ outputs = outputs + (attn_weights,)
++ return outputs
++
++
++class T5SelfAttention(T5Attention):
++ def __init__(self, config: T5Config, has_relative_attention_bias=False):
++ super().__init__(config, has_relative_attention_bias)
++
++ def forward(
++ self,
++ hidden_states,
++ mask=None,
++ position_bias=None,
++ past_key=None,
++ past_value=None,
++ layer_head_mask=None,
++ use_cache=False,
++ output_attentions=False,
++ ):
++ """
++ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
++ """
++ # Input is (batch_size, seq_length, dim)
++ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
++ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
++ batch_size, seq_length = hidden_states.shape[:2]
++
++ real_seq_length = seq_length
++
++ if past_key is not None:
++ real_seq_length += past_key.shape[2]
++ key_length = real_seq_length
++ def shape(states):
++ """projection"""
++ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
++
++ def unshape(states):
++ """reshape"""
++ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
++
++ def project(hidden_states, proj_layer, past_key_value):
++ """projects hidden states correctly to key/query states"""
++ if past_key_value is None:
++ # cross-attn
++ # (batch_size, n_heads, seq_length, dim_per_head)
++ hidden_states = shape(proj_layer(hidden_states))
+
++ if past_key_value is not None:
++ hidden_states = shape(proj_layer(hidden_states))
++ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
++ return hidden_states
++
++ # get query states
++ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
++
++ # get key/value states
++ key_states = project(
++ hidden_states, self.k, past_key if past_key is not None else None
++ )
++ value_states = project(
++ hidden_states, self.v, past_value if past_value is not None else None
++ )
++ # compute scores
++ scores = torch.matmul(
++ query_states, key_states.transpose(3, 2)
++ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
++ if position_bias is None:
++ if not self.has_relative_attention_bias:
++ position_bias = torch.zeros(
++ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
++ )
++ if self.gradient_checkpointing and self.training:
++ position_bias.requires_grad = True
++ else:
++ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
++
++ # if key and values are already calculated
++ # we want only the last query position bias
++ if past_key is not None:
++ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
++ if mask is not None:
++ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
++
++ if self.pruned_heads:
++ mask = torch.ones(position_bias.shape[1])
++ mask[list(self.pruned_heads)] = 0
++ position_bias_masked = position_bias[:, mask.bool()]
++ else:
++ position_bias_masked = position_bias
++ scores += position_bias_masked
++ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
++ scores
++ ) # (batch_size, n_heads, seq_length, key_length)
++ attn_weights = nn.functional.dropout(
++ attn_weights, p=self.dropout, training=self.training
++ ) # (batch_size, n_heads, seq_length, key_length)
++
++ # Mask heads if we want to
++ if layer_head_mask is not None:
++ attn_weights = attn_weights * layer_head_mask
++
++ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
++ attn_output = self.o(attn_output)
++
++ # present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
++ present_key_state = (key_states.half(), ) if (self.is_decoder and use_cache) else None
++ present_value_state = (value_states.half(), ) if (self.is_decoder and use_cache) else None
++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,)
+ if output_attentions:
+ outputs = outputs + (attn_weights,)
+ return outputs
+
+
++
++
+ class T5LayerSelfAttention(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+- self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
++ self.SelfAttention = T5SelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+@@ -585,7 +733,8 @@ class T5LayerSelfAttention(nn.Module):
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+@@ -595,7 +744,8 @@ class T5LayerSelfAttention(nn.Module):
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+@@ -618,7 +768,8 @@ class T5LayerCrossAttention(nn.Module):
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+@@ -630,7 +781,8 @@ class T5LayerCrossAttention(nn.Module):
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+@@ -661,39 +813,34 @@ class T5Block(nn.Module):
+ encoder_decoder_position_bias=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
++ past_cross_key=None,
++ past_cross_value=None,
+ use_cache=False,
+ output_attentions=False,
+ return_dict=True,
+ ):
+- if past_key_value is not None:
+- if not self.is_decoder:
+- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
+- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+-
+- if len(past_key_value) != expected_num_past_key_values:
+- raise ValueError(
+- f"There should be {expected_num_past_key_values} past states. "
+- f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+- f"Got {len(past_key_value)} past key / value states"
+- )
+-
+- self_attn_past_key_value = past_key_value[:2]
+- cross_attn_past_key_value = past_key_value[2:]
++ if past_key is not None:
++ self_attn_past_key = past_key
++ self_attn_past_value = past_value
++ cross_attn_past_key = past_cross_key
++ cross_attn_past_value = past_cross_value
+ else:
+- self_attn_past_key_value, cross_attn_past_key_value = None, None
++ self_attn_past_key, self_attn_past_value, cross_attn_past_key, cross_attn_past_value = None, None, None, None
+
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=self_attn_past_key_value,
++ past_key=self_attn_past_key,
++ past_value=self_attn_past_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+- hidden_states, present_key_value_state = self_attention_outputs[:2]
+- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
++ hidden_states, present_key_state, present_value_state = self_attention_outputs[:3]
++ attention_outputs = self_attention_outputs[3:] # Keep self-attention outputs and relative position weights
+
+ # clamp inf values to enable fp16 training
+ if hidden_states.dtype == torch.float16:
+@@ -706,22 +853,23 @@ class T5Block(nn.Module):
+
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+ if do_cross_attention:
++
+ # the actual query length is unknown for cross attention
+ # if using past key value states. Need to inject it here
+- if present_key_value_state is not None:
+- query_length = present_key_value_state[0].shape[2]
++ if present_key_state is not None:
++ query_length = present_key_state[0].shape[2]
+ else:
+ query_length = None
+-
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ layer_head_mask=cross_attn_layer_head_mask,
+- past_key_value=cross_attn_past_key_value,
++ past_key=cross_attn_past_key,
++ past_value=cross_attn_past_value,
+ query_length=query_length,
+- use_cache=use_cache,
++ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = cross_attention_outputs[0]
+@@ -736,11 +884,9 @@ class T5Block(nn.Module):
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ # Combine self attn and cross attn key value states
+- if present_key_value_state is not None:
+- present_key_value_state = present_key_value_state + cross_attention_outputs[1]
+-
++            # cross-attention key/value states are precomputed from the encoder output, so there is nothing to merge into the self-attention cache here
+ # Keep cross-attention outputs and relative position weights
+- attention_outputs = attention_outputs + cross_attention_outputs[2:]
++ attention_outputs = attention_outputs + cross_attention_outputs[3:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states)
+@@ -757,7 +903,7 @@ class T5Block(nn.Module):
+ outputs = (hidden_states,)
+
+ if use_cache:
+- outputs = outputs + (present_key_value_state,) + attention_outputs
++            outputs = outputs + (present_key_state,) + (present_value_state,) + attention_outputs
+ else:
+ outputs = outputs + attention_outputs
+
+@@ -897,11 +1043,15 @@ class T5PreTrainedModel(PreTrainedModel):
+
+
+ class T5Stack(T5PreTrainedModel):
+- def __init__(self, config, embed_tokens=None):
++    def __init__(self, config, embed_tokens=None, lm_head=None, encodecrosskey=None, encodecrossvalue=None):
+ super().__init__(config)
+
+ self.embed_tokens = embed_tokens
+ self.is_decoder = config.is_decoder
++        self.lm_head = lm_head
++ self.encodecrosskey = encodecrosskey
++ self.encodecrossvalue = encodecrossvalue
++ self.model_dim = config.d_model
+
+ self.block = nn.ModuleList(
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+@@ -966,20 +1116,63 @@ class T5Stack(T5PreTrainedModel):
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
++ def invert_attention_mask(self, encoder_attention_mask):
++ if encoder_attention_mask.dim() == 3:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
++ if encoder_attention_mask.dim() == 2:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
++
++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000
++
++ return encoder_extended_attention_mask
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
+ def forward(
+ self,
+ input_ids=None,
+- attention_mask=None,
+ encoder_hidden_states=None,
++ past_keys=None,
++ past_values=None,
++ past_cross_keys=None,
++ past_cross_values=None,
+ encoder_attention_mask=None,
++ attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+- past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
++ **model_kwargs
+ ):
+ # Model parallel
+ if self.model_parallel:
+@@ -998,8 +1191,10 @@ class T5Stack(T5PreTrainedModel):
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
++
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
++ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+@@ -1012,18 +1207,19 @@ class T5Stack(T5PreTrainedModel):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+-
+ # required mask seq length can be calculated via length of past
+- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
++ mask_seq_length = past_keys[0].shape[2] + seq_length if past_keys is not None else seq_length
+
+ if use_cache is True:
+ if not self.is_decoder:
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+
+ # initialize past_key_values with `None` if past does not exist
+- if past_key_values is None:
+- past_key_values = [None] * len(self.block)
+-
++ if not self.is_decoder:
++ past_keys = [None] * len(self.block)
++ past_values = [None] * len(self.block)
++ past_cross_keys = [None] * len(self.block)
++ past_cross_values = [None] * len(self.block)
+ if attention_mask is None:
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+@@ -1054,7 +1250,8 @@ class T5Stack(T5PreTrainedModel):
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+- present_key_value_states = () if use_cache else None
++ present_key_states = () if use_cache else None
++ present_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+@@ -1062,8 +1259,8 @@ class T5Stack(T5PreTrainedModel):
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+-
+- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
++        # iterate over the layers together with their per-layer self- and cross-attention caches
++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+@@ -1112,7 +1309,10 @@ class T5Stack(T5PreTrainedModel):
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
++ past_cross_key=past_cross_key,
++ past_cross_value=past_cross_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+@@ -1120,19 +1320,20 @@ class T5Stack(T5PreTrainedModel):
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if use_cache is False:
+- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
++                layer_outputs = layer_outputs[:1] + (None,) + (None,) + layer_outputs[1:]
+
+- hidden_states, present_key_value_state = layer_outputs[:2]
++ hidden_states, present_key_state, present_value_state = layer_outputs[:3]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+- position_bias = layer_outputs[2]
++ position_bias = layer_outputs[3]
+ if self.is_decoder and encoder_hidden_states is not None:
+- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4]
+ # append next layer key value states
+ if use_cache:
+- present_key_value_states = present_key_value_states + (present_key_value_state,)
++ present_key_states = present_key_states + present_key_state
++ present_value_states = present_value_states + present_value_state
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+@@ -1146,7 +1347,7 @@ class T5Stack(T5PreTrainedModel):
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.final_layer_norm(hidden_states)
+- hidden_states = self.dropout(hidden_states)
++ hidden_states = self.dropout(hidden_states).half()
+
+ # Add last layer
+ if output_hidden_states:
+@@ -1164,13 +1365,216 @@ class T5Stack(T5PreTrainedModel):
+ ]
+ if v is not None
+ )
+- return BaseModelOutputWithPastAndCrossAttentions(
+- last_hidden_state=hidden_states,
+- past_key_values=present_key_value_states,
+- hidden_states=all_hidden_states,
+- attentions=all_attentions,
+- cross_attentions=all_cross_attentions,
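++        # flat return contract of the patched stack: the encoder branch returns
++        # (hidden_states, cross_keys, cross_values), the decoder branch returns
++        # (lm_logits, present_key_states, present_value_states)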
++ if not self.is_decoder:
++ cross_keys = None
++ cross_values = None
++ if self.encodecrosskey:
++ cross_keys = self.encodecrosskey(hidden_states)
++ if self.encodecrossvalue:
++ cross_values = self.encodecrossvalue(hidden_states)
++            return (hidden_states, cross_keys, cross_values)
++ lm_logits = None
++ if self.is_decoder:
++ if self.config.tie_word_embeddings:
++ hidden_states = hidden_states * (self.model_dim ** -0.5)
++ lm_logits = self.lm_head(hidden_states)
++            return (lm_logits, present_key_states, present_value_states)
++
++
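++# Encoder-only copy of T5Stack: the decoder-specific branches are removed so the encoder can be
++# traced and compiled on its own, separately from the decoder.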
++class T5Stack_Encoder(T5PreTrainedModel):
++ def __init__(self, config, embed_tokens=None, encodecrosskey=None, encodecrossvalue=None):
++ super().__init__(config)
++ self.embed_tokens = embed_tokens
++ self.is_decoder = config.is_decoder
++ self.encodecrosskey = encodecrosskey
++ self.encodecrossvalue = encodecrossvalue
++ self.model_dim = config.d_model
++
++ self.block = nn.ModuleList(
++ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+ )
++ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
++ self.dropout = nn.Dropout(config.dropout_rate)
++
++ # Initialize weights and apply final processing
++ self.post_init()
++ # Model parallel
++ self.model_parallel = False
++ self.device_map = None
++ self.gradient_checkpointing = False
++
++ def get_input_embeddings(self):
++ return self.embed_tokens
++
++ def set_input_embeddings(self, new_embeddings):
++ self.embed_tokens = new_embeddings
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
++ def forward(
++ self,
++ input_ids=None,
++ attention_mask=None,
++ head_mask=None,
++ cross_attn_head_mask=None,
++ use_cache=None,
++ output_attentions=None,
++ output_hidden_states=None,
++ return_dict=None,
++ **model_kwargs
++ ):
++ # Model parallel
++ use_cache = use_cache if use_cache is not None else self.config.use_cache
++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
++ output_hidden_states = (
++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
++ )
++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
++
++ input_shape = input_ids.size()
++ input_ids = input_ids.view(-1, input_shape[-1])
++
++ inputs_embeds = self.embed_tokens(input_ids)
++
++ batch_size, seq_length = input_shape
++ # required mask seq length can be calculated via length of past
++ mask_seq_length = seq_length
++
++ # initialize past_key_values with `None` if past does not exist
++ past_keys = [None] * len(self.block)
++ past_values = [None] * len(self.block)
++ past_cross_keys = [None] * len(self.block)
++ past_cross_values = [None] * len(self.block)
++ if attention_mask is None:
++ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
++
++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
++ # ourselves in which case we just need to make it broadcastable to all heads.
++ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
++
++ # If a 2D or 3D attention mask is provided for the cross-attention
++ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
++
++ encoder_extended_attention_mask = None
++
++ # Prepare head mask if needed
++ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
++ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
++ present_key_states = () if use_cache else None
++ present_value_states = () if use_cache else None
++ all_hidden_states = () if output_hidden_states else None
++ all_attentions = () if output_attentions else None
++ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
++ position_bias = None
++ encoder_decoder_position_bias = None
++
++ hidden_states = self.dropout(inputs_embeds)
++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)):
++ layer_head_mask = head_mask[i]
++ cross_attn_layer_head_mask = cross_attn_head_mask[i]
++ if output_hidden_states:
++ all_hidden_states = all_hidden_states + (hidden_states,)
++
++ layer_outputs = layer_module(
++ hidden_states,
++ attention_mask=extended_attention_mask,
++ position_bias=position_bias,
++ encoder_hidden_states=None,
++ encoder_attention_mask=encoder_extended_attention_mask,
++ encoder_decoder_position_bias=encoder_decoder_position_bias,
++ layer_head_mask=layer_head_mask,
++ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
++ past_key=past_key,
++ past_value=past_value,
++ past_cross_key=past_cross_key,
++ past_cross_value=past_cross_value,
++ use_cache=use_cache,
++ output_attentions=output_attentions,
++ )
++
++ # layer_outputs is a tuple with:
++ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
++ if use_cache is False:
++                layer_outputs = layer_outputs[:1] + (None,) + (None,) + layer_outputs[1:]
++
++ hidden_states, present_key_state, present_value_state = layer_outputs[:3]
++
++ # We share the position biases between the layers - the first layer store them
++ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
++ # (cross-attention position bias), (cross-attention weights)
++ position_bias = layer_outputs[3]
++ if self.is_decoder and encoder_hidden_states is not None:
++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4]
++ # append next layer key value states
++ if use_cache:
++ present_key_states = present_key_states + present_key_state
++ present_value_states = present_value_states + present_value_state
++
++ if output_attentions:
++ all_attentions = all_attentions + (layer_outputs[3],)
++ if self.is_decoder:
++ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
++
++ # Model Parallel: If it's the last layer for that device, put things on the next device
++ if self.model_parallel:
++ for k, v in self.device_map.items():
++ if i == v[-1] and "cuda:" + str(k) != self.last_device:
++ hidden_states = hidden_states.to("cuda:" + str(k + 1))
++
++ hidden_states = self.final_layer_norm(hidden_states)
++ hidden_states = self.dropout(hidden_states).half()
++
++ # Add last layer
++ if output_hidden_states:
++ all_hidden_states = all_hidden_states + (hidden_states,)
++
++ if not return_dict:
++ return tuple(
++ v
++ for v in [
++ hidden_states,
++                    present_key_states,
++                    present_value_states,
++ all_hidden_states,
++ all_attentions,
++ all_cross_attentions,
++ ]
++ if v is not None
++ )
++        # the encoder stack returns no self-attention cache; it returns the hidden states plus the precomputed cross-attention keys/values below
++ if not self.is_decoder:
++ cross_keys = None
++ cross_values = None
++ if self.encodecrosskey:
++ cross_keys = self.encodecrosskey(hidden_states)
++ if self.encodecrossvalue:
++ cross_values = self.encodecrossvalue(hidden_states)
++            return (hidden_states, cross_keys, cross_values)
+
+
+ T5_START_DOCSTRING = r"""
+@@ -1541,6 +1945,38 @@ class T5Model(T5PreTrainedModel):
+ )
+
+
++class EncoderToCrossKey(nn.Module):
++ def __init__(self, cross_key, num_heads, d_kv):
++ super().__init__()
++ self.cross_key = cross_key
++ self.num_heads = num_heads
++ self.d_kv = d_kv
++
++
++ def forward(self, hidden_states):
++ batch_size = hidden_states.shape[0]
++ past_cross_keys = ()
++ for i in range(len(self.cross_key)):
++ past_cross_keys += (self.cross_key[i](hidden_states),)
++ return past_cross_keys
++
++
++class EncoderToCrossValue(nn.Module):
++ def __init__(self, cross_value, num_heads, d_kv):
++ super().__init__()
++ self.cross_value = cross_value
++ self.num_heads = num_heads
++ self.d_kv = d_kv
++
++
++ def forward(self, hidden_states):
++ batch_size = hidden_states.shape[0]
++ past_cross_values = ()
++ for i in range(len(self.cross_value)):
++ past_cross_values += (self.cross_value[i](hidden_states),)
++ return past_cross_values
++
++
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
+ class T5ForConditionalGeneration(T5PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [
+@@ -1548,28 +1984,51 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ ]
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+- def __init__(self, config: T5Config):
++ def __init__(self, config: T5Config, encoder_path=None, decoder_path=None, device_id=0):
+ super().__init__(config)
+- self.model_dim = config.d_model
+-
+- self.shared = nn.Embedding(config.vocab_size, config.d_model)
+-
+- encoder_config = copy.deepcopy(config)
+- encoder_config.is_decoder = False
+- encoder_config.use_cache = False
+- encoder_config.is_encoder_decoder = False
+- self.encoder = T5Stack(encoder_config, self.shared)
+-
+- decoder_config = copy.deepcopy(config)
+- decoder_config.is_decoder = True
+- decoder_config.is_encoder_decoder = False
+- decoder_config.num_layers = config.num_decoder_layers
+- self.decoder = T5Stack(decoder_config, self.shared)
+-
+- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
++ self.encoder_path = encoder_path
++ self.decoder_path = decoder_path
++ self.is_mindie = False
++ if not self.encoder_path or not self.decoder_path:
++ self.model_dim = config.d_model
++
++ self.shared = nn.Embedding(config.vocab_size, config.d_model)
++
++ decoder_config = copy.deepcopy(config)
++ decoder_config.is_decoder = True
++ decoder_config.is_encoder_decoder = False
++ decoder_config.num_layers = config.num_decoder_layers
++
++ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
++ self.decoder = T5Stack(decoder_config, self.shared, self.lm_head)
++
++ cross_key = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.k for i in range(config.num_decoder_layers))
++ cross_value = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.v for i in range(config.num_decoder_layers))
++ encodecrosskey = EncoderToCrossKey(cross_key, config.num_heads, config.d_kv)
++ encodecrossvalue = EncoderToCrossValue(cross_value, config.num_heads, config.d_kv)
++
++ encoder_config = copy.deepcopy(config)
++ encoder_config.is_decoder = False
++ encoder_config.use_cache = False
++ encoder_config.is_encoder_decoder = False
++ self.encoder = T5Stack_Encoder(encoder_config, self.shared, encodecrosskey=encodecrosskey, encodecrossvalue=encodecrossvalue)
++ self.encoder_mindie = None
++ self.decoder_mindie = None
++ if self.encoder_path:
++ self.encoder_mindie = torch.jit.load(self.encoder_path)
++ self.is_mindie = True
++ if self.decoder_path:
++ self.decoder_mindie = torch.jit.load(self.decoder_path)
++
++ self.stream = torch.npu.Stream(f"npu:{device_id}")
++ self.device_id = device_id
++
++
++ def get_device(self):
++ return f"npu:{self.device_id}"
+
+ # Initialize weights and apply final processing
+- self.post_init()
++        # self.post_init()  # skipped here; weights are expected to come from the pretrained checkpoint or the compiled MindIE modules
+
+ # Model parallel
+ self.model_parallel = False
+@@ -1637,25 +2096,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+- def forward(
+- self,
+- input_ids: Optional[torch.LongTensor] = None,
+- attention_mask: Optional[torch.FloatTensor] = None,
+- decoder_input_ids: Optional[torch.LongTensor] = None,
+- decoder_attention_mask: Optional[torch.BoolTensor] = None,
+- head_mask: Optional[torch.FloatTensor] = None,
+- decoder_head_mask: Optional[torch.FloatTensor] = None,
+- cross_attn_head_mask: Optional[torch.Tensor] = None,
+- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+- inputs_embeds: Optional[torch.FloatTensor] = None,
+- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+- labels: Optional[torch.LongTensor] = None,
+- use_cache: Optional[bool] = None,
+- output_attentions: Optional[bool] = None,
+- output_hidden_states: Optional[bool] = None,
+- return_dict: Optional[bool] = None,
+- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
++    def forward(self, *args) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
+@@ -1687,113 +2128,37 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ >>> # studies have shown that owning a dog is good for you.
+ ```"""
+- use_cache = use_cache if use_cache is not None else self.config.use_cache
+- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+-
+- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+- if head_mask is not None and decoder_head_mask is None:
+- if self.config.num_layers == self.config.num_decoder_layers:
+- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+- decoder_head_mask = head_mask
+-
+- # Encode if needed (training, first prediction pass)
+- if encoder_outputs is None:
+- # Convert encoder inputs in embeddings if needed
+- encoder_outputs = self.encoder(
+- input_ids=input_ids,
+- attention_mask=attention_mask,
+- inputs_embeds=inputs_embeds,
+- head_mask=head_mask,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
+- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+- encoder_outputs = BaseModelOutput(
+- last_hidden_state=encoder_outputs[0],
+- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+- )
+-
+- hidden_states = encoder_outputs[0]
+-
+- if self.model_parallel:
+- torch.cuda.set_device(self.decoder.first_device)
+-
+- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+- # get decoder inputs from shifting lm labels to the right
+- decoder_input_ids = self._shift_right(labels)
+-
+- # Set device for model parallelism
+- if self.model_parallel:
+- torch.cuda.set_device(self.decoder.first_device)
+- hidden_states = hidden_states.to(self.decoder.first_device)
+- if decoder_input_ids is not None:
+- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+- if attention_mask is not None:
+- attention_mask = attention_mask.to(self.decoder.first_device)
+- if decoder_attention_mask is not None:
+- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+-
+- # Decode
+- decoder_outputs = self.decoder(
+- input_ids=decoder_input_ids,
+- attention_mask=decoder_attention_mask,
+- inputs_embeds=decoder_inputs_embeds,
+- past_key_values=past_key_values,
+- encoder_hidden_states=hidden_states,
+- encoder_attention_mask=attention_mask,
+- head_mask=decoder_head_mask,
+- cross_attn_head_mask=cross_attn_head_mask,
+- use_cache=use_cache,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
+-
+- sequence_output = decoder_outputs[0]
+-
+- # Set device for model parallelism
+- if self.model_parallel:
+- torch.cuda.set_device(self.encoder.first_device)
+- self.lm_head = self.lm_head.to(self.encoder.first_device)
+- sequence_output = sequence_output.to(self.lm_head.weight.device)
+-
+- if self.config.tie_word_embeddings:
+- # Rescale output before projecting on vocab
+- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+- sequence_output = sequence_output * (self.model_dim**-0.5)
+-
+- lm_logits = self.lm_head(sequence_output)
++ if self.is_mindie:
++ with torch.npu.stream(self.stream): # set stream
++ decoder_outputs = self.decoder_mindie.forward(*args)
++ self.stream.synchronize() # synchronize
++ else:
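++            # unpack the flat positional layout shared with the compiled decoder:
++            # encoder hidden states, per-layer cross keys, per-layer cross values,
++            # per-layer self-attention keys/values, encoder attention mask, decoder input ids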
++ hidden_states = args[0]
++            past_cross_keys = args[1:self.config.num_decoder_layers + 1]
++            past_cross_values = args[self.config.num_decoder_layers + 1:2 * self.config.num_decoder_layers + 1]
++            past_keys = args[2 * self.config.num_decoder_layers + 1:3 * self.config.num_decoder_layers + 1]
++            past_values = args[3 * self.config.num_decoder_layers + 1:4 * self.config.num_decoder_layers + 1]
++ encoder_attention_mask = args[-2]
++ decoder_input_ids = args[-1]
++ decoder_outputs = self.decoder(input_ids=decoder_input_ids,
++ encoder_hidden_states=hidden_states,
++ past_keys=past_keys,
++ past_values=past_values,
++ past_cross_keys=past_cross_keys,
++ past_cross_values=past_cross_values,
++ encoder_attention_mask=encoder_attention_mask)
++
+
+ loss = None
+- if labels is not None:
+- loss_fct = CrossEntropyLoss(ignore_index=-100)
+- # move labels to correct device to enable PP
+- labels = labels.to(lm_logits.device)
+- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+-
+- if not return_dict:
+- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+- return ((loss,) + output) if loss is not None else output
+-
+- return Seq2SeqLMOutput(
+- loss=loss,
+- logits=lm_logits,
+- past_key_values=decoder_outputs.past_key_values,
+- decoder_hidden_states=decoder_outputs.hidden_states,
+- decoder_attentions=decoder_outputs.attentions,
+- cross_attentions=decoder_outputs.cross_attentions,
+- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+- encoder_hidden_states=encoder_outputs.hidden_states,
+- encoder_attentions=encoder_outputs.attentions,
+- )
++        return (decoder_outputs[0], decoder_outputs[1], decoder_outputs[2])
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+- past_key_values=None,
++ past_cross_keys=None,
++ past_cross_values=None,
++ past_keys=None,
++ past_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+@@ -1804,8 +2169,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+- if past_key_values is not None:
+- past_length = past_key_values[0][0].shape[2]
++ if past_keys is not None:
++ past_length = past_keys[0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+@@ -1813,12 +2178,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+-
+ input_ids = input_ids[:, remove_prefix_length:]
+
+ return {
+ "decoder_input_ids": input_ids,
+- "past_key_values": past_key_values,
++            "past_cross_keys": past_cross_keys,
++            "past_cross_values": past_cross_values,
++            "past_keys": past_keys,
++            "past_values": past_values,
+ "encoder_outputs": encoder_outputs,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+@@ -1826,6 +2193,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ "decoder_attention_mask": decoder_attention_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
++
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+@@ -1861,6 +2229,459 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+ return reordered_decoder_past
+
++ def _prepare_encoder_decoder_kwargs_for_generation(
++ self,
++ inputs_tensor: torch.Tensor,
++ model_kwargs,
++ model_input_name,
++ generation_config,
++ ):
++ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
++ encoder_kwargs = {
++ argument: value
++ for argument, value in model_kwargs.items()
++ if not any(argument.startswith(p) for p in irrelevant_prefix)
++ }
++ encoder_kwargs["output_attentions"] = generation_config.output_attentions
++ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
++ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
++ encoder_kwargs["return_dict"] = True
++ encoder_kwargs[model_input_name] = inputs_tensor
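++        # the patched encoder returns (hidden_states, past_cross_keys, past_cross_values);
++        # the cross keys/values are cached in model_kwargs and reused at every decoding step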
++ if self.is_mindie:
++ with torch.npu.stream(self.stream): # set stream
++                encoder_outputs = self.encoder_mindie.forward(encoder_kwargs["input_ids"], encoder_kwargs["attention_mask"])
++ self.stream.synchronize() # synchronize
++ else:
++            encoder_outputs = self.encoder.forward(**encoder_kwargs)
++        model_kwargs["encoder_outputs"] = {"last_hidden_state": encoder_outputs[0]}
++        model_kwargs["past_cross_keys"] = encoder_outputs[1]
++        model_kwargs["past_cross_values"] = encoder_outputs[2]
++ return model_kwargs
++
++ def _update_model_kwargs_for_generation(
++ self,
++ outputs,
++ model_kwargs,
++        is_encoder_decoder=False,
++        standardize_cache_format=False,
++        num_new_tokens=1,
++ ):
++ # update past_key_values keeping its naming used in model code
++ cache_name, cache = self._extract_past_from_model_output(
++ outputs, standardize_cache_format=standardize_cache_format
++ )
++ model_kwargs[cache_name] = cache
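++        # carry the split self-attention caches forward so prepare_inputs_for_generation can
++        # derive the past length from past_keys on the next step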
++ if "past_keys" in outputs:
++ past_keys = outputs.past_keys
++ model_kwargs["past_keys"] = past_keys
++ if "past_values" in outputs:
++ past_values = outputs.past_values
++ model_kwargs["past_values"] = past_values
++ # update decoder attention mask
++ if "decoder_attention_mask" in model_kwargs:
++ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
++ model_kwargs["decoder_attention_mask"] = torch.cat(
++ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
++ dim=-1,
++ )
++ return model_kwargs
++
++ @torch.no_grad()
++ def generate(
++ self,
++        inputs=None,
++        generation_config=None,
++        logits_processor=None,
++        stopping_criteria=None,
++        prefix_allowed_tokens_fn=None,
++        assistant_model=None,
++        negative_prompt_ids=None,
++        negative_prompt_attention_mask=None,
++ **kwargs,
++ ):
++ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
++ self._validate_model_class()
++ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
++ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
++ self._validate_model_kwargs(model_kwargs.copy())
++
++
++ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
++ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
++
++ accepts_attention_mask = True
++ requires_attention_mask = "encoder_outputs" not in model_kwargs
++ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
++
++ # 3. Define model inputs
++ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
++ inputs, generation_config.bos_token_id, model_kwargs
++ )
++ batch_size = inputs_tensor.shape[0]
++
++ device = inputs_tensor.device
++ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
++
++ # 4. Define other model kwargs
++ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
++ # generating the first new token or not, and we only want to use the embeddings for the first new token)
++ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
++ model_kwargs["use_cache"] = True
++ else:
++ model_kwargs["use_cache"] = generation_config.use_cache
++ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
++ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
++ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
++ )
++
++ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
++ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
++ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
++ inputs_tensor, model_kwargs, model_input_name, generation_config
++ )
++
++ # 5. Prepare `input_ids` which will be used for auto-regressive generation
++ if self.config.is_encoder_decoder:
++ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
++ batch_size=batch_size,
++ model_input_name=model_input_name,
++ model_kwargs=model_kwargs,
++ decoder_start_token_id=generation_config.decoder_start_token_id,
++ device=inputs_tensor.device,
++ )
++ else:
++ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
++
++ if generation_config.token_healing:
++ input_ids = self.heal_tokens(input_ids, tokenizer)
++
++ # 6. Prepare `max_length` depending on other stopping criteria.
++ input_ids_length = input_ids.shape[-1]
++ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
++ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
++ generation_config = self._prepare_generated_length(
++ generation_config=generation_config,
++ has_default_max_length=has_default_max_length,
++ has_default_min_length=has_default_min_length,
++ model_input_name=model_input_name,
++ inputs_tensor=inputs_tensor,
++ input_ids_length=input_ids_length,
++ )
++
++ use_dynamic_cache_by_default = False
++ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
++ raise ValueError(
++ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
++ "Cache object) is unsupported. Please use only one of the two."
++ )
++ elif generation_config.cache_implementation is not None:
++ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
++ if generation_config.cache_implementation == "static" and not self._supports_static_cache:
++ raise ValueError(
++ "This model does not support `cache_implementation='static'`. Please check the following "
++ "issue: https://github.com/huggingface/transformers/issues/28981"
++ )
++ model_kwargs["past_key_values"] = self._get_cache(
++ generation_config.cache_implementation,
++ getattr(generation_config, "num_beams", 1) * batch_size,
++ generation_config.max_length,
++ )
++ elif generation_config.cache_implementation == "quantized":
++ if not self._supports_quantized_cache:
++ raise ValueError(
++ "This model does not support the quantized cache. If you want your model to support quantized "
++ "cache, please open an issue."
++ )
++
++ cache_config = (
++ generation_config.cache_config
++ if generation_config.cache_config is not None
++ else QuantizedCacheConfig()
++ )
++ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
++
++ if cache_config.backend == "quanto" and not is_quanto_available():
++ raise ImportError(
++ "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
++ "Please install it via with `pip install quanto`"
++ )
++ elif cache_config.backend == "HQQ" and not is_hqq_available():
++ raise ImportError(
++ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
++ "Please install it via with `pip install hqq`"
++ )
++
++ model_kwargs["past_key_values"] = cache_class(cache_config)
++ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
++ # keeps copying the cache thus using much more memory
++ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
++ past = model_kwargs.get("past_key_values", None)
++ if past is None:
++ model_kwargs["past_key_values"] = DynamicCache()
++ use_dynamic_cache_by_default = True
++ elif isinstance(past, tuple):
++ model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
++ use_dynamic_cache_by_default = True
++
++ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
++
++ # 7. determine generation mode
++ generation_mode = generation_config.get_generation_mode(assistant_model)
++ # 8. prepare distribution pre_processing samplers
++ prepared_logits_processor = self._get_logits_processor(
++ generation_config=generation_config,
++ input_ids_seq_length=input_ids_length,
++ encoder_input_ids=inputs_tensor,
++ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
++ logits_processor=logits_processor,
++ device=inputs_tensor.device,
++ model_kwargs=model_kwargs,
++ negative_prompt_ids=negative_prompt_ids,
++ negative_prompt_attention_mask=negative_prompt_attention_mask,
++ )
++
++ # 9. prepare stopping criteria
++ prepared_stopping_criteria = self._get_stopping_criteria(
++ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
++ )
++
++ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
++ # 11. prepare logits warper
++ prepared_logits_warper = (
++ self._get_logits_warper(generation_config, device=input_ids.device)
++ if generation_config.do_sample
++ else None
++ )
++
++ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
++ input_ids, model_kwargs = self._expand_inputs_for_generation(
++ input_ids=input_ids,
++ expand_size=generation_config.num_return_sequences,
++ is_encoder_decoder=self.config.is_encoder_decoder,
++ **model_kwargs,
++ )
++ # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
++ result = self._sample(
++ input_ids,
++ logits_processor=prepared_logits_processor,
++ logits_warper=prepared_logits_warper,
++ stopping_criteria=prepared_stopping_criteria,
++ generation_config=generation_config,
++ **model_kwargs,
++ )
++ return result
++
++ def _sample(
++ self,
++ input_ids,
++ logits_processor,
++ stopping_criteria,
++ generation_config,
++        logits_warper=None,
++ **model_kwargs,
++ ):
++ # init values
++ pad_token_id = generation_config.pad_token_id
++ output_attentions = generation_config.output_attentions
++ output_hidden_states = generation_config.output_hidden_states
++ output_scores = generation_config.output_scores
++ output_logits = generation_config.output_logits
++ return_dict_in_generate = generation_config.return_dict_in_generate
++ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
++ do_sample = generation_config.do_sample
++ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
++ raise ValueError(
++ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
++ f"{logits_warper})."
++ )
++
++ # init attention / hidden states / scores tuples
++ scores = () if (return_dict_in_generate and output_scores) else None
++ raw_logits = () if (return_dict_in_generate and output_logits) else None
++ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
++ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
++ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
++
++ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
++ if return_dict_in_generate and self.config.is_encoder_decoder:
++ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
++ encoder_hidden_states = (
++ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
++ )
++
++ this_peer_finished = False
++ batch_size = input_ids.shape[0]
++ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
++ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
++
++        # seed empty self-attention key/value caches so the first decoding step uses the same
++        # flat decoder signature as every later step
++        if self.is_mindie or self.config.architectures[0] == "T5ForConditionalGeneration":
++ num_layers = self.config.num_layers
++ num_heads = self.config.num_heads
++ d_kv = self.config.d_kv
++ model_kwargs["past_keys"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)]
++ model_kwargs["past_values"] = [torch.randn(batch_size, num_heads, 0, d_kv).half().npu() for _ in range(num_layers)]
++
++
++ while self._has_unfinished_sequences(this_peer_finished, False, device=input_ids.device):
++ # prepare model inputs
++ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
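++            # pack the flat argument list in the order the decoder forward expects: encoder
++            # hidden states, cross keys, cross values, self-attention keys, self-attention
++            # values, encoder attention mask, decoder input ids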
++ model_args = [model_kwargs["encoder_outputs"]["last_hidden_state"]]
++ model_args.extend(model_kwargs["past_cross_keys"])
++ model_args.extend(model_kwargs["past_cross_values"])
++ model_args.extend(model_inputs["past_keys"])
++ model_args.extend(model_inputs["past_values"])
++ model_args.append(model_inputs["attention_mask"])
++ model_args.append(model_inputs["decoder_input_ids"])
++
++ # forward pass to get next token
++ outputs = self(*model_args)
++ outputs = Seq2SeqLMOutput(logits=outputs[0],
++ past_keys=outputs[1],
++ past_values=outputs[2])
++
++ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
++ # (the clone itself is always small)
++ next_token_logits = outputs.logits[:, -1, :].clone()
++
++ # pre-process distribution
++ next_token_scores = logits_processor(input_ids, next_token_logits)
++ if do_sample:
++ next_token_scores = logits_warper(input_ids, next_token_scores)
++
++ # Store scores, attentions and hidden_states when required
++ if return_dict_in_generate:
++ if output_scores:
++ scores += (next_token_scores,)
++ if output_logits:
++ raw_logits += (next_token_logits,)
++ if output_attentions:
++ decoder_attentions += (
++ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
++ )
++ if self.config.is_encoder_decoder:
++ cross_attentions += (outputs.cross_attentions,)
++
++ if output_hidden_states:
++ decoder_hidden_states += (
++ (outputs.decoder_hidden_states,)
++ if self.config.is_encoder_decoder
++ else (outputs.hidden_states,)
++ )
++
++ # token selection
++ if do_sample:
++ probs = nn.functional.softmax(next_token_scores, dim=-1)
++ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
++ else:
++ next_tokens = torch.argmax(next_token_scores, dim=-1)
++
++ # finished sentences should have their next token be a padding token
++ if has_eos_stopping_criteria:
++ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
++
++ # update generated ids, model inputs, and length for next step
++ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
++ model_kwargs = self._update_model_kwargs_for_generation(
++ outputs,
++ model_kwargs,
++ is_encoder_decoder=self.config.is_encoder_decoder,
++ )
++ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
++ this_peer_finished = unfinished_sequences.max() == 0
++ # This is needed to properly delete outputs.logits which may be very large for first iteration
++ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
++ del outputs
++ return input_ids
++
++ def invert_attention_mask(self, encoder_attention_mask):
++ """
++ Invert an attention mask (e.g., switches 0. and 1.).
++
++ Args:
++ encoder_attention_mask (`torch.Tensor`): An attention mask.
++
++ Returns:
++ `torch.Tensor`: The inverted attention mask.
++ """
++ if encoder_attention_mask.dim() == 3:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
++ if encoder_attention_mask.dim() == 2:
++ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
++ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
++ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
++ # /transformer/transformer_layers.py#L270
++ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
++ # encoder_extended_attention_mask.transpose(-1, -2))
++ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
++ #encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
++ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1000
++
++ return encoder_extended_attention_mask
++
++ @property
++ def device(self) -> torch.device:
++ """
++ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
++ device).
++ """
++ return self.get_device()
++
++ def get_extended_attention_mask(
++        self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ """
++ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
++
++ Arguments:
++ attention_mask (`torch.Tensor`):
++ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
++ input_shape (`Tuple[int]`):
++ The shape of the input to the model.
++
++ Returns:
++ `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
++ """
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
++ # ourselves in which case we just need to make it broadcastable to all heads.
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ # Provided a padding mask of dimensions [batch_size, seq_length]
++ # - if the model is a decoder, apply a causal mask in addition to the padding mask
++ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++
++ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
++ # masked positions, this operation will create a tensor which is 0.0 for
++ # positions we want to attend and the dtype's smallest value for masked positions.
++ # Since we are adding it to the raw scores before the softmax, this is
++ # effectively the same as removing these entirely.
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ #extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
++
++
+
+ @add_start_docstrings(
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+@@ -1967,7 +2788,6 @@ class T5EncoderModel(T5PreTrainedModel):
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+-
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
diff --git a/MindIE/MindIE-Torch/built-in/T5/modeling_t5_800IA2.patch b/MindIE/MindIE-Torch/built-in/T5/modeling_t5_800IA2.patch
new file mode 100644
index 0000000000000000000000000000000000000000..664b4359cea4c27a50785ab38844644e14a0d3a1
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/modeling_t5_800IA2.patch
@@ -0,0 +1,1594 @@
+diff --git a/modeling_t5_origin.py b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py
+index 224769fdf..4f9ffd74f 100644
+--- a/modeling_t5_origin.py
++++ b/usr/local/python3.10.2/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py
+@@ -19,22 +19,26 @@ import math
+ import os
+ import warnings
+ from typing import List, Optional, Tuple, Union
+-
++from dataclasses import dataclass
+ import torch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
++# import torch_npu
++# import mindietorch
++
++
++
+
+ from ...activations import ACT2FN
+ from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPastAndCrossAttentions,
+- Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+ Seq2SeqQuestionAnsweringModelOutput,
+ Seq2SeqSequenceClassifierOutput,
+ TokenClassifierOutput,
+ )
+-from ...modeling_utils import PreTrainedModel
++from ...modeling_utils import PreTrainedModel, ModuleUtilsMixin
+ from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
+ from ...utils import (
+ DUMMY_INPUTS,
+@@ -47,7 +51,43 @@ from ...utils import (
+ )
+ from ...utils.model_parallel_utils import assert_device_map, get_device_map
+ from .configuration_t5 import T5Config
++from transformers.generation.logits_process import LogitsProcessorList
++from transformers.generation.stopping_criteria import StoppingCriteriaList
++from transformers.generation.configuration_utils import GenerationMode
++from transformers.utils.generic import ModelOutput
++
++
++@dataclass
++class Seq2SeqLMOutput(ModelOutput):
++ """
++    Output of the patched decoder: the upstream `Seq2SeqLMOutput` extended with split `past_keys`/`past_values` caches.
+
++ Args:
++        loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
++            Language modeling loss, if labels were provided.
++        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
++            Prediction scores of the language modeling head.
++        past_keys (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True`):
++            Per-layer self-attention key states, kept as plain tensors so they can be fed back to the
++            decoder on the next step.
++        past_values (`tuple(torch.FloatTensor)`, *optional*, returned when `use_cache=True`):
++            Per-layer self-attention value states, matching `past_keys`.
++        past_key_values, decoder_hidden_states, decoder_attentions, cross_attentions,
++        encoder_last_hidden_state, encoder_hidden_states, encoder_attentions:
++            Same meaning as in the upstream `Seq2SeqLMOutput`.
++ """
++ loss: Optional[torch.FloatTensor] = None
++ logits: torch.FloatTensor = None
++ past_keys: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ past_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
++ decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
++ decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++ cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
++ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
++ encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
++ encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+ logger = logging.get_logger(__name__)
+
+@@ -360,6 +400,7 @@ class T5Attention(nn.Module):
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+ self.pruned_heads = set()
+ self.gradient_checkpointing = False
++ self.lay_out = "BSH"
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+@@ -448,7 +489,10 @@ class T5Attention(nn.Module):
+ mask=None,
+ key_value_states=None,
+ position_bias=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
++ past_cross_key=None,
++ past_cross_value=None,
+ layer_head_mask=None,
+ query_length=None,
+ use_cache=False,
+@@ -464,81 +508,86 @@ class T5Attention(nn.Module):
+
+ real_seq_length = seq_length
+
+- if past_key_value is not None:
+- if len(past_key_value) != 2:
+- raise ValueError(
+- f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+- )
+- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
++ if past_key is not None:
++ real_seq_length += past_key.shape[1] if query_length is None else query_length
+
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
++        # BSH layout: key/value states arrive precomputed as past_key/past_value (cross-attention cache)
++ query_states = self.q(hidden_states)
++ key_states = past_key
++ value_states = past_value
++        attn_output = torch.ops.aie.flash_attention(query_states, key_states, value_states, self.n_heads, attn_mask=mask)
++ # mask = mask.expand(3,1,16,mask.shape[3]).bool()
++ # attn_output = torch_npu.npu_prompt_flash_attention(query_states,key_states,value_states,atten_mask=mask,num_heads=self.n_heads,input_layout="BSH")
++ attn_output = self.o(attn_output)
++        present_key_state = (key_states.half(),) if (self.is_decoder and use_cache) else None
++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None
++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,)
++ return outputs
++
++
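++# Self-attention path split out of T5Attention: keys/values stay in BSH layout so the running
++# cache can be concatenated along the sequence dimension (dim=1) at each step.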
++class T5SelfAttention(T5Attention):
++ def __init__(self, config: T5Config, has_relative_attention_bias=False):
++ super().__init__(config, has_relative_attention_bias)
+
+- def shape(states):
+- """projection"""
+- return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
++ def forward(
++ self,
++ hidden_states,
++ mask=None,
++ position_bias=None,
++ past_key=None,
++ past_value=None,
++ layer_head_mask=None,
++ use_cache=False,
++ output_attentions=False,
++ ):
++ """
++ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
++ """
++ # Input is (batch_size, seq_length, dim)
++ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
++ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
++ batch_size, seq_length = hidden_states.shape[:2]
+
+- def unshape(states):
+- """reshape"""
+- return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
++ real_seq_length = seq_length
++
++ if past_key is not None:
++ real_seq_length += past_key.shape[1]
++ key_length = real_seq_length
+
+- def project(hidden_states, proj_layer, key_value_states, past_key_value):
++ def project(hidden_states, proj_layer, past_key_value):
+ """projects hidden states correctly to key/query states"""
+- if key_value_states is None:
+- # self-attn
+- # (batch_size, n_heads, seq_length, dim_per_head)
+- hidden_states = shape(proj_layer(hidden_states))
+- elif past_key_value is None:
+- # cross-attn
+- # (batch_size, n_heads, seq_length, dim_per_head)
+- hidden_states = shape(proj_layer(key_value_states))
++ if past_key_value is None:
++ hidden_states = proj_layer(hidden_states)
+
+ if past_key_value is not None:
+- if key_value_states is None:
+- # self-attn
+- # (batch_size, n_heads, key_length, dim_per_head)
+- hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+- elif past_key_value.shape[2] != key_value_states.shape[1]:
+- # checking that the `sequence_length` of the `past_key_value` is the same as
+- # the provided `key_value_states` to support prefix tuning
+- # cross-attn
+- # (batch_size, n_heads, seq_length, dim_per_head)
+- hidden_states = shape(proj_layer(key_value_states))
+- else:
+- # cross-attn
+- hidden_states = past_key_value
++ hidden_states = proj_layer(hidden_states)
++ hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
+ return hidden_states
+
+ # get query states
+- query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
+-
++ query_states = self.q(hidden_states)
+ # get key/value states
+ key_states = project(
+- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
++ hidden_states, self.k, past_key
+ )
+ value_states = project(
+- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
++ hidden_states, self.v, past_value
+ )
+-
+- # compute scores
+- scores = torch.matmul(
+- query_states, key_states.transpose(3, 2)
+- ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+-
+ if position_bias is None:
+ if not self.has_relative_attention_bias:
+ position_bias = torch.zeros(
+- (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
++ (1, self.n_heads, real_seq_length, key_length), device=query_states.device, dtype=query_states.dtype
+ )
+ if self.gradient_checkpointing and self.training:
+ position_bias.requires_grad = True
+ else:
+- position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
++ position_bias = self.compute_bias(real_seq_length, key_length, device=query_states.device)
+
+ # if key and values are already calculated
+ # we want only the last query position bias
+- if past_key_value is not None:
++ if past_key is not None:
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+-
+ if mask is not None:
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
+
+@@ -548,34 +597,26 @@ class T5Attention(nn.Module):
+ position_bias_masked = position_bias[:, mask.bool()]
+ else:
+ position_bias_masked = position_bias
+-
+- scores += position_bias_masked
+- attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+- scores
+- ) # (batch_size, n_heads, seq_length, key_length)
+- attn_weights = nn.functional.dropout(
+- attn_weights, p=self.dropout, training=self.training
+- ) # (batch_size, n_heads, seq_length, key_length)
+-
+- # Mask heads if we want to
+- if layer_head_mask is not None:
+- attn_weights = attn_weights * layer_head_mask
+-
+- attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
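++ # Fused self-attention: the relative position bias goes in via `pse` and the
++ # boolean mask via `attn_mask`, replacing the matmul/softmax path removed above.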
++ attn_output = torch.ops.aie.flash_attention(query_states, key_states, value_states, self.n_heads, pse=position_bias_masked, attn_mask=mask)
++ attn_output = self.o(attn_output)
++ present_key_state = (key_states.half(),) if (self.is_decoder and use_cache) else None
++ present_value_state = (value_states.half(),) if (self.is_decoder and use_cache) else None
++ outputs = (attn_output,) + (present_key_state,) + (present_value_state,) + (position_bias,)
++ return outputs
+
+- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+- outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+- if output_attentions:
+- outputs = outputs + (attn_weights,)
+- return outputs
+
+
+ class T5LayerSelfAttention(nn.Module):
+ def __init__(self, config, has_relative_attention_bias=False):
+ super().__init__()
+- self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
++ self.SelfAttention = T5SelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+@@ -585,7 +626,8 @@ class T5LayerSelfAttention(nn.Module):
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+@@ -595,7 +637,8 @@ class T5LayerSelfAttention(nn.Module):
+ mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+@@ -618,7 +661,8 @@ class T5LayerCrossAttention(nn.Module):
+ attention_mask=None,
+ position_bias=None,
+ layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
+ use_cache=False,
+ query_length=None,
+ output_attentions=False,
+@@ -630,7 +674,8 @@ class T5LayerCrossAttention(nn.Module):
+ key_value_states=key_value_states,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
+ use_cache=use_cache,
+ query_length=query_length,
+ output_attentions=output_attentions,
+@@ -661,39 +706,34 @@ class T5Block(nn.Module):
+ encoder_decoder_position_bias=None,
+ layer_head_mask=None,
+ cross_attn_layer_head_mask=None,
+- past_key_value=None,
++ past_key=None,
++ past_value=None,
++ past_cross_key=None,
++ past_cross_value=None,
+ use_cache=False,
+ output_attentions=False,
+ return_dict=True,
+ ):
+- if past_key_value is not None:
+- if not self.is_decoder:
+- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
+- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+-
+- if len(past_key_value) != expected_num_past_key_values:
+- raise ValueError(
+- f"There should be {expected_num_past_key_values} past states. "
+- f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+- f"Got {len(past_key_value)} past key / value states"
+- )
+-
+- self_attn_past_key_value = past_key_value[:2]
+- cross_attn_past_key_value = past_key_value[2:]
++ if past_key is not None:
++ self_attn_past_key = past_key
++ self_attn_past_value = past_value
++ cross_attn_past_key = past_cross_key
++ cross_attn_past_value = past_cross_value
+ else:
+- self_attn_past_key_value, cross_attn_past_key_value = None, None
++ self_attn_past_key, self_attn_past_value, cross_attn_past_key, cross_attn_past_value = None, None, None, None
+
+ self_attention_outputs = self.layer[0](
+ hidden_states,
+ attention_mask=attention_mask,
+ position_bias=position_bias,
+ layer_head_mask=layer_head_mask,
+- past_key_value=self_attn_past_key_value,
++ past_key=self_attn_past_key,
++ past_value=self_attn_past_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+- hidden_states, present_key_value_state = self_attention_outputs[:2]
+- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
++ hidden_states, present_key_state, present_value_state = self_attention_outputs[:3]
++ attention_outputs = self_attention_outputs[3:] # Keep self-attention outputs and relative position weights
+
+ # clamp inf values to enable fp16 training
+ if hidden_states.dtype == torch.float16:
+@@ -706,22 +746,23 @@ class T5Block(nn.Module):
+
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+ if do_cross_attention:
++
+ # the actual query length is unknown for cross attention
+ # if using past key value states. Need to inject it here
+- if present_key_value_state is not None:
+- query_length = present_key_value_state[0].shape[2]
++ if present_key_state is not None:
++ query_length = present_key_state[0].shape[1]
+ else:
+ query_length = None
+-
+ cross_attention_outputs = self.layer[1](
+ hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ position_bias=encoder_decoder_position_bias,
+ layer_head_mask=cross_attn_layer_head_mask,
+- past_key_value=cross_attn_past_key_value,
++ past_key=cross_attn_past_key,
++ past_value=cross_attn_past_value,
+ query_length=query_length,
+- use_cache=use_cache,
++ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = cross_attention_outputs[0]
+@@ -736,11 +777,9 @@ class T5Block(nn.Module):
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ # Combine self attn and cross attn key value states
+- if present_key_value_state is not None:
+- present_key_value_state = present_key_value_state + cross_attention_outputs[1]
+-
++ # cross_attn_past_key_values = cross_attention_outputs[1]
+ # Keep cross-attention outputs and relative position weights
+- attention_outputs = attention_outputs + cross_attention_outputs[2:]
++ attention_outputs = attention_outputs + cross_attention_outputs[3:]
+
+ # Apply Feed Forward layer
+ hidden_states = self.layer[-1](hidden_states)
+@@ -757,7 +796,7 @@ class T5Block(nn.Module):
+ outputs = (hidden_states,)
+
+ if use_cache:
+- outputs = outputs + (present_key_value_state,) + attention_outputs
++ outputs = outputs + (present_key_state,) + (present_value_state,) + attention_outputs
+ else:
+ outputs = outputs + attention_outputs
+
+@@ -897,11 +936,15 @@ class T5PreTrainedModel(PreTrainedModel):
+
+
+ class T5Stack(T5PreTrainedModel):
+- def __init__(self, config, embed_tokens=None):
++ def __init__(self, config, embed_tokens=None, lm_head=None, encodecrosskey=None, encodecrossvalue=None):
+ super().__init__(config)
+
+ self.embed_tokens = embed_tokens
+ self.is_decoder = config.is_decoder
++ self.lm_head = lm_head
++ self.encodecrosskey = encodecrosskey
++ self.encodecrossvalue = encodecrossvalue
++ self.model_dim = config.d_model
+
+ self.block = nn.ModuleList(
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+@@ -966,16 +1009,48 @@ class T5Stack(T5PreTrainedModel):
+ def set_input_embeddings(self, new_embeddings):
+ self.embed_tokens = new_embeddings
+
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
+ def forward(
+ self,
+ input_ids=None,
+- attention_mask=None,
+ encoder_hidden_states=None,
++ past_keys=None,
++ past_values=None,
++ past_cross_keys=None,
++ past_cross_values=None,
+ encoder_attention_mask=None,
++ attention_mask=None,
+ inputs_embeds=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+- past_key_values=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+@@ -998,8 +1073,10 @@ class T5Stack(T5PreTrainedModel):
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
++
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
++ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+@@ -1012,25 +1089,29 @@ class T5Stack(T5PreTrainedModel):
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ batch_size, seq_length = input_shape
+-
+ # required mask seq length can be calculated via length of past
+- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
++ mask_seq_length = past_keys[0].shape[1] + seq_length if past_keys is not None else seq_length
+
+ if use_cache is True:
+ if not self.is_decoder:
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
+
+ # initialize past_key_values with `None` if past does not exist
+- if past_key_values is None:
+- past_key_values = [None] * len(self.block)
+-
++ if not self.is_decoder:
++ past_keys = [None] * len(self.block)
++ past_values = [None] * len(self.block)
++ past_cross_keys = [None] * len(self.block)
++ past_cross_values = [None] * len(self.block)
+ if attention_mask is None:
+- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
++ print("aaaaaaaaaaaaaaaaa")
++ attention_mask = torch.zeros(batch_size, mask_seq_length, device=inputs_embeds.device)
++ attention_mask = attention_mask[:,None,None,:].expand(batch_size,1,mask_seq_length,mask_seq_length).bool()
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+-
++ # The mask built above is already expanded to (batch, 1, q_len, k_len), so the
++ # original get_extended_attention_mask call is bypassed.
++ extended_attention_mask = attention_mask
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.is_decoder and encoder_hidden_states is not None:
+@@ -1040,7 +1121,7 @@ class T5Stack(T5PreTrainedModel):
+ encoder_attention_mask = torch.ones(
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+ )
+- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
++ encoder_extended_attention_mask = encoder_attention_mask
+ else:
+ encoder_extended_attention_mask = None
+
+@@ -1054,7 +1135,8 @@ class T5Stack(T5PreTrainedModel):
+ # Prepare head mask if needed
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+- present_key_value_states = () if use_cache else None
++ present_key_states = () if use_cache else None
++ present_value_states = () if use_cache else None
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+@@ -1062,8 +1144,8 @@ class T5Stack(T5PreTrainedModel):
+ encoder_decoder_position_bias = None
+
+ hidden_states = self.dropout(inputs_embeds)
+-
+- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)):
+ layer_head_mask = head_mask[i]
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
+ # Model parallel
+@@ -1112,7 +1194,10 @@ class T5Stack(T5PreTrainedModel):
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
+ layer_head_mask=layer_head_mask,
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+- past_key_value=past_key_value,
++ past_key=past_key,
++ past_value=past_value,
++ past_cross_key=past_cross_key,
++ past_cross_value=past_cross_value,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+@@ -1120,19 +1205,20 @@ class T5Stack(T5PreTrainedModel):
+ # layer_outputs is a tuple with:
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+ if use_cache is False:
+- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
++ layer_outputs = layer_outputs[:1] + (None,) + (None,) + layer_outputs[1:]
+
+- hidden_states, present_key_value_state = layer_outputs[:2]
++ hidden_states, present_key_state, present_value_state = layer_outputs[:3]
+
+ # We share the position biases between the layers - the first layer store them
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+ # (cross-attention position bias), (cross-attention weights)
+- position_bias = layer_outputs[2]
++ position_bias = layer_outputs[3]
+ if self.is_decoder and encoder_hidden_states is not None:
+- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
++ encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 4]
+ # append next layer key value states
+ if use_cache:
+- present_key_value_states = present_key_value_states + (present_key_value_state,)
++ present_key_states = present_key_states + present_key_state
++ present_value_states = present_value_states + present_value_state
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[3],)
+@@ -1146,31 +1232,158 @@ class T5Stack(T5PreTrainedModel):
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.final_layer_norm(hidden_states)
+- hidden_states = self.dropout(hidden_states)
++ hidden_states = self.dropout(hidden_states).half()
+
+ # Add last layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
++ if self.config.tie_word_embeddings:
++ hidden_states = hidden_states * (self.model_dim ** -0.5)
++ lm_logits = self.lm_head(hidden_states)
++ return tuple((lm_logits, present_key_states, present_value_states))
+
+- if not return_dict:
+- return tuple(
+- v
+- for v in [
+- hidden_states,
+- present_key_value_states,
+- all_hidden_states,
+- all_attentions,
+- all_cross_attentions,
+- ]
+- if v is not None
+- )
+- return BaseModelOutputWithPastAndCrossAttentions(
+- last_hidden_state=hidden_states,
+- past_key_values=present_key_value_states,
+- hidden_states=all_hidden_states,
+- attentions=all_attentions,
+- cross_attentions=all_cross_attentions,
++
++class T5Stack_Encoder(T5PreTrainedModel):
++ def __init__(self, config, embed_tokens=None, encodecrosskey=None, encodecrossvalue=None):
++ super().__init__(config)
++ self.embed_tokens = embed_tokens
++ self.is_decoder = config.is_decoder
++ self.encodecrosskey = encodecrosskey
++ self.encodecrossvalue = encodecrossvalue
++ self.model_dim = config.d_model
++
++ self.block = nn.ModuleList(
++ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
++ )
++ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
++ self.dropout = nn.Dropout(config.dropout_rate)
++
++ # Initialize weights and apply final processing
++ self.post_init()
++ # Model parallel
++ self.model_parallel = False
++ self.device_map = None
++ self.gradient_checkpointing = False
++
++ def get_input_embeddings(self):
++ return self.embed_tokens
++
++ def set_input_embeddings(self, new_embeddings):
++ self.embed_tokens = new_embeddings
++
++
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ extended_attention_mask = attention_mask[:, None, None, :].expand(input_shape[0], 1, input_shape[1], input_shape[1]).bool()
++ extended_attention_mask = ~extended_attention_mask
++ return extended_attention_mask
++
++ def forward(
++ self,
++ input_ids=None,
++ attention_mask=None,
++ head_mask=None,
++ cross_attn_head_mask=None,
++ use_cache=None,
++ output_attentions=None,
++ output_hidden_states=None,
++ return_dict=None,
++ ):
++ # Model parallel
++ use_cache = use_cache if use_cache is not None else self.config.use_cache
++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
++ output_hidden_states = (
++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
++
++ input_shape = input_ids.size()
++ input_ids = input_ids.view(-1, input_shape[-1])
++
++ inputs_embeds = self.embed_tokens(input_ids)
++
++ batch_size, seq_length = input_shape
++ # required mask seq length can be calculated via length of past
++ mask_seq_length = seq_length
++
++ # initialize past_key_values with `None` if past does not exist
++ past_keys = [None] * len(self.block)
++ past_values = [None] * len(self.block)
++ past_cross_keys = [None] * len(self.block)
++ past_cross_values = [None] * len(self.block)
++ # print("attention_mask=",attention_mask)
++ if attention_mask is None:
++ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
++ encoder_extended_attention_mask = None
++ # Prepare head mask if needed
++ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
++ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
++ present_key_states = () if use_cache else None
++ present_value_states = () if use_cache else None
++ all_hidden_states = () if output_hidden_states else None
++ all_attentions = () if output_attentions else None
++ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
++ position_bias = None
++ encoder_decoder_position_bias = None
++
++ hidden_states = self.dropout(inputs_embeds)
++ for i, (layer_module, past_key, past_value, past_cross_key, past_cross_value) in enumerate(zip(self.block, past_keys, past_values, past_cross_keys, past_cross_values)):
++ layer_head_mask = head_mask[i]
++ cross_attn_layer_head_mask = cross_attn_head_mask[i]
++ if output_hidden_states:
++ all_hidden_states = all_hidden_states + (hidden_states,)
++ layer_outputs = layer_module(
++ hidden_states,
++ attention_mask=attention_mask,
++ position_bias=position_bias,
++ encoder_hidden_states=None,
++ encoder_attention_mask=encoder_extended_attention_mask,
++ encoder_decoder_position_bias=encoder_decoder_position_bias,
++ layer_head_mask=layer_head_mask,
++ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
++ past_key=past_key,
++ past_value=past_value,
++ past_cross_key=past_cross_key,
++ past_cross_value=past_cross_value,
++ use_cache=use_cache,
++ output_attentions=output_attentions,
++ )
++
++ # layer_outputs is a tuple with:
++ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
++ if use_cache is False:
++ layer_outputs = layer_outputs[:1] + (None,) +(None,) + layer_outputs[1:]
++
++ hidden_states, present_key_state, present_value_state = layer_outputs[:3]
++
++ # We share the position biases between the layers - the first layer store them
++ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
++ # (cross-attention position bias), (cross-attention weights)
++ position_bias = layer_outputs[3]
++ # append next layer key value states
++ if use_cache:
++ present_key_states = present_key_states + present_key_state
++ present_value_states = present_value_states + present_value_state
++
++ if output_attentions:
++ all_attentions = all_attentions + (layer_outputs[3],)
++ if self.is_decoder:
++ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
++
++ hidden_states = self.final_layer_norm(hidden_states)
++ hidden_states = self.dropout(hidden_states).half()
++
++ # Add last layer
++ if output_hidden_states:
++ all_hidden_states = all_hidden_states + (hidden_states,)
++
++ cross_keys, cross_values = None, None
++ if self.encodecrosskey:
++ cross_keys = self.encodecrosskey(hidden_states)
++ if self.encodecrossvalue:
++ cross_values = self.encodecrossvalue(hidden_states)
++ return tuple((hidden_states, cross_keys, cross_values))
+
+
+ T5_START_DOCSTRING = r"""
+@@ -1541,6 +1754,41 @@ class T5Model(T5PreTrainedModel):
+ )
+
+
++class EncoderToCrossKey(nn.Module):
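++ # Applies every decoder layer's cross-attention key projection to the encoder
++ # output once, producing the per-layer past_cross_keys consumed by the decoder.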
++ def __init__(self, cross_key, num_heads, d_kv):
++ super().__init__()
++ self.cross_key = cross_key
++ self.num_heads = num_heads
++ self.d_kv = d_kv
++
++
++ def forward(self, hidden_states):
++ batch_size = hidden_states.shape[0]
++ past_cross_keys = ()
++ for i in range(len(self.cross_key)):
++ past_cross_keys += (self.cross_key[i](hidden_states),)
++ return past_cross_keys
++
++
++class EncoderToCrossValue(nn.Module):
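++ # Same as EncoderToCrossKey, but for the cross-attention value projections.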
++ def __init__(self, cross_value, num_heads, d_kv):
++ super().__init__()
++ self.cross_value = cross_value
++ self.num_heads = num_heads
++ self.d_kv = d_kv
++
++
++ def forward(self, hidden_states):
++ batch_size = hidden_states.shape[0]
++ past_cross_values = ()
++ for i in range(len(self.cross_value)):
++ past_cross_values += (self.cross_value[i](hidden_states),)
++ # print("aaa",past_cross_values[0].shape)
++ return past_cross_values
++
++
+ @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
+ class T5ForConditionalGeneration(T5PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [
+@@ -1548,28 +1796,51 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ ]
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+- def __init__(self, config: T5Config):
++ def __init__(self, config: T5Config, encoder_path=None, decoder_path=None, device_id=0):
+ super().__init__(config)
+- self.model_dim = config.d_model
+-
+- self.shared = nn.Embedding(config.vocab_size, config.d_model)
+-
+- encoder_config = copy.deepcopy(config)
+- encoder_config.is_decoder = False
+- encoder_config.use_cache = False
+- encoder_config.is_encoder_decoder = False
+- self.encoder = T5Stack(encoder_config, self.shared)
+-
+- decoder_config = copy.deepcopy(config)
+- decoder_config.is_decoder = True
+- decoder_config.is_encoder_decoder = False
+- decoder_config.num_layers = config.num_decoder_layers
+- self.decoder = T5Stack(decoder_config, self.shared)
+-
+- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
++ self.encoder_path = encoder_path
++ self.decoder_path = decoder_path
++ self.is_mindie = False
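++ # If no compiled MindIE modules are supplied, the patched eager encoder/decoder
++ # stacks are built below; otherwise the TorchScript modules are loaded and used.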
++ if not self.encoder_path or not self.decoder_path:
++ self.model_dim = config.d_model
++
++ self.shared = nn.Embedding(config.vocab_size, config.d_model)
++
++ decoder_config = copy.deepcopy(config)
++ decoder_config.is_decoder = True
++ decoder_config.is_encoder_decoder = False
++ decoder_config.num_layers = config.num_decoder_layers
++
++ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
++ self.decoder = T5Stack(decoder_config, self.shared, self.lm_head)
++
++ cross_key = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.k for i in range(config.num_decoder_layers))
++ cross_value = nn.ModuleList(self.decoder.block[i].layer[1].EncDecAttention.v for i in range(config.num_decoder_layers))
++ encodecrosskey = EncoderToCrossKey(cross_key, config.num_heads, config.d_kv)
++ encodecrossvalue = EncoderToCrossValue(cross_value, config.num_heads, config.d_kv)
++
++ encoder_config = copy.deepcopy(config)
++ encoder_config.is_decoder = False
++ encoder_config.use_cache = False
++ encoder_config.is_encoder_decoder = False
++ self.encoder = T5Stack_Encoder(encoder_config, self.shared, encodecrosskey=encodecrosskey, encodecrossvalue=encodecrossvalue)
++ self.encoder_mindie = None
++ self.decoder_mindie = None
++ if self.encoder_path:
++ self.encoder_mindie = torch.jit.load(self.encoder_path)
++ self.is_mindie = True
++ if self.decoder_path:
++ self.decoder_mindie = torch.jit.load(self.decoder_path)
++
++ self.stream = torch.npu.Stream(f"npu:{device_id}")
++ self.device_id = device_id
++
++
++ def get_device(self):
++ return f"npu:{self.device_id}"
+
+ # Initialize weights and apply final processing
+- self.post_init()
++ # self.post_init()
+
+ # Model parallel
+ self.model_parallel = False
+@@ -1637,25 +1908,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+
+ @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+- def forward(
+- self,
+- input_ids: Optional[torch.LongTensor] = None,
+- attention_mask: Optional[torch.FloatTensor] = None,
+- decoder_input_ids: Optional[torch.LongTensor] = None,
+- decoder_attention_mask: Optional[torch.BoolTensor] = None,
+- head_mask: Optional[torch.FloatTensor] = None,
+- decoder_head_mask: Optional[torch.FloatTensor] = None,
+- cross_attn_head_mask: Optional[torch.Tensor] = None,
+- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+- inputs_embeds: Optional[torch.FloatTensor] = None,
+- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+- labels: Optional[torch.LongTensor] = None,
+- use_cache: Optional[bool] = None,
+- output_attentions: Optional[bool] = None,
+- output_hidden_states: Optional[bool] = None,
+- return_dict: Optional[bool] = None,
+- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
++ def forward(self, *args) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
+@@ -1687,113 +1940,36 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ >>> # studies have shown that owning a dog is good for you.
+ ```"""
+- use_cache = use_cache if use_cache is not None else self.config.use_cache
+- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+-
+- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+- if head_mask is not None and decoder_head_mask is None:
+- if self.config.num_layers == self.config.num_decoder_layers:
+- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+- decoder_head_mask = head_mask
+-
+- # Encode if needed (training, first prediction pass)
+- if encoder_outputs is None:
+- # Convert encoder inputs in embeddings if needed
+- encoder_outputs = self.encoder(
+- input_ids=input_ids,
+- attention_mask=attention_mask,
+- inputs_embeds=inputs_embeds,
+- head_mask=head_mask,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
+- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+- encoder_outputs = BaseModelOutput(
+- last_hidden_state=encoder_outputs[0],
+- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+- )
+-
+- hidden_states = encoder_outputs[0]
+-
+- if self.model_parallel:
+- torch.cuda.set_device(self.decoder.first_device)
+-
+- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+- # get decoder inputs from shifting lm labels to the right
+- decoder_input_ids = self._shift_right(labels)
+-
+- # Set device for model parallelism
+- if self.model_parallel:
+- torch.cuda.set_device(self.decoder.first_device)
+- hidden_states = hidden_states.to(self.decoder.first_device)
+- if decoder_input_ids is not None:
+- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+- if attention_mask is not None:
+- attention_mask = attention_mask.to(self.decoder.first_device)
+- if decoder_attention_mask is not None:
+- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+-
+- # Decode
+- decoder_outputs = self.decoder(
+- input_ids=decoder_input_ids,
+- attention_mask=decoder_attention_mask,
+- inputs_embeds=decoder_inputs_embeds,
+- past_key_values=past_key_values,
+- encoder_hidden_states=hidden_states,
+- encoder_attention_mask=attention_mask,
+- head_mask=decoder_head_mask,
+- cross_attn_head_mask=cross_attn_head_mask,
+- use_cache=use_cache,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
+-
+- sequence_output = decoder_outputs[0]
+-
+- # Set device for model parallelism
+- if self.model_parallel:
+- torch.cuda.set_device(self.encoder.first_device)
+- self.lm_head = self.lm_head.to(self.encoder.first_device)
+- sequence_output = sequence_output.to(self.lm_head.weight.device)
+-
+- if self.config.tie_word_embeddings:
+- # Rescale output before projecting on vocab
+- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+- sequence_output = sequence_output * (self.model_dim**-0.5)
+-
+- lm_logits = self.lm_head(sequence_output)
+-
+- loss = None
+- if labels is not None:
+- loss_fct = CrossEntropyLoss(ignore_index=-100)
+- # move labels to correct device to enable PP
+- labels = labels.to(lm_logits.device)
+- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+-
+- if not return_dict:
+- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+- return ((loss,) + output) if loss is not None else output
+-
+- return Seq2SeqLMOutput(
+- loss=loss,
+- logits=lm_logits,
+- past_key_values=decoder_outputs.past_key_values,
+- decoder_hidden_states=decoder_outputs.hidden_states,
+- decoder_attentions=decoder_outputs.attentions,
+- cross_attentions=decoder_outputs.cross_attentions,
+- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+- encoder_hidden_states=encoder_outputs.hidden_states,
+- encoder_attentions=encoder_outputs.attentions,
+- )
++ if self.is_mindie:
++ with torch.npu.stream(self.stream): # set stream
++ decoder_outputs = self.decoder_mindie.forward(*args)
++ self.stream.synchronize() # synchronize
++ else:
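++ # Eager fallback: unpack the flat positional args into encoder hidden states,
++ # per-layer cross K/V, per-layer self-attention K/V, and the mask/id tensors.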
++ hidden_states = args[0]
++ past_cross_keys = args[1:self.config.num_decoder_layers + 1]
++ past_cross_values = args[self.config.num_decoder_layers + 1:2 * self.config.num_decoder_layers + 1]
++ past_keys = args[2 * self.config.num_decoder_layers + 1:3 * self.config.num_decoder_layers + 1]
++ past_values = args[3 * self.config.num_decoder_layers + 1:4 * self.config.num_decoder_layers + 1]
++ encoder_attention_mask = args[-3]
++ decoder_input_ids = args[-2]
++ decoder_attention_mask = args[-1]
++ decoder_outputs = self.decoder(input_ids=decoder_input_ids,
++ encoder_hidden_states=hidden_states,
++ past_keys=past_keys,
++ past_values=past_values,
++ past_cross_keys=past_cross_keys,
++ past_cross_values=past_cross_values,
++ encoder_attention_mask=encoder_attention_mask,
++ attention_mask=decoder_attention_mask)
++ return (decoder_outputs[0], decoder_outputs[1], decoder_outputs[2])
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+- past_key_values=None,
++ past_cross_keys=None,
++ past_cross_values=None,
++ past_keys=None,
++ past_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+@@ -1804,8 +1980,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ **kwargs,
+ ):
+ # cut decoder_input_ids if past_key_values is used
+- if past_key_values is not None:
+- past_length = past_key_values[0][0].shape[2]
++ if past_keys is not None:
++ past_length = past_keys[0].shape[1]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+@@ -1813,12 +1989,19 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+-
+ input_ids = input_ids[:, remove_prefix_length:]
+
++ batch_size, seq_length = input_ids.shape
++ # required mask seq length can be calculated via length of past
++ mask_seq_length = past_keys[0].shape[1] + seq_length if past_keys is not None else seq_length
++ decoder_attention_mask = torch.zeros(batch_size, mask_seq_length, device=input_ids.device)
++ decoder_attention_mask = decoder_attention_mask[:, None, None, :].expand(batch_size, 1, mask_seq_length, mask_seq_length).bool()
+ return {
+ "decoder_input_ids": input_ids,
+- "past_key_values": past_key_values,
++ "past_cross_keys":past_cross_keys,
++ "past_cross_values":past_cross_values,
++ "past_keys":past_keys,
++ "past_values":past_values,
+ "encoder_outputs": encoder_outputs,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+@@ -1826,6 +2009,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ "decoder_attention_mask": decoder_attention_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
++
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+@@ -1861,6 +2045,440 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
+ return reordered_decoder_past
+
++ def _prepare_encoder_decoder_kwargs_for_generation(
++ self,
++ inputs_tensor: torch.Tensor,
++ model_kwargs,
++ model_input_name,
++ generation_config,
++ ):
++ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
++ encoder_kwargs = {
++ argument: value
++ for argument, value in model_kwargs.items()
++ if not any(argument.startswith(p) for p in irrelevant_prefix)
++ }
++ encoder_kwargs["output_attentions"] = generation_config.output_attentions
++ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
++ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
++ encoder_kwargs["return_dict"] = True
++ encoder_kwargs[model_input_name] = inputs_tensor
++ encoder_outputs = None
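++ # The patched encoder returns (last_hidden_state, per-layer cross keys, per-layer
++ # cross values); the cross K/V are cached in model_kwargs so the decoder never
++ # recomputes them during generation.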
++ if self.is_mindie:
++ with torch.npu.stream(self.stream): # set stream
++ encoder_outputs = self.encoder_mindie.forward(encoder_kwargs["input_ids"], encoder_kwargs["attention_mask"])
++ self.stream.synchronize() # synchronize
++ else:
++ encoder_outputs = self.encoder.forward(**encoder_kwargs)
++ model_kwargs["encoder_outputs"] = {"last_hidden_state": encoder_outputs[0]}
++ model_kwargs["past_cross_keys"] = encoder_outputs[1]
++ model_kwargs["past_cross_values"] = encoder_outputs[2]
++ return model_kwargs
++
++ def _update_model_kwargs_for_generation(
++ self,
++ outputs,
++ model_kwargs,
++ is_encoder_decoder = False,
++ standardize_cache_format = False,
++ num_new_tokens = 1,
++ ):
++ # update past_key_values keeping its naming used in model code
++ cache_name, cache = self._extract_past_from_model_output(
++ outputs, standardize_cache_format=standardize_cache_format
++ )
++ model_kwargs[cache_name] = cache
++ if "past_keys" in outputs:
++ past_keys = outputs.past_keys
++ model_kwargs["past_keys"] = past_keys
++ if "past_values" in outputs:
++ past_values = outputs.past_values
++ model_kwargs["past_values"] = past_values
++ # update decoder attention mask
++ if "decoder_attention_mask" in model_kwargs:
++ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
++ model_kwargs["decoder_attention_mask"] = torch.cat(
++ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
++ dim=-1,
++ )
++ return model_kwargs
++
++ @torch.no_grad()
++ def generate(
++ self,
++ inputs = None,
++ generation_config = None,
++ logits_processor = None,
++ stopping_criteria = None,
++ prefix_allowed_tokens_fn = None,
++ assistant_model = None,
++ negative_prompt_ids = None,
++ negative_prompt_attention_mask = None,
++ **kwargs,
++ ):
++ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
++ self._validate_model_class()
++ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
++ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
++ self._validate_model_kwargs(model_kwargs.copy())
++
++
++ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
++ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
++
++ accepts_attention_mask = True
++ requires_attention_mask = "encoder_outputs" not in model_kwargs
++ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
++
++ # 3. Define model inputs
++ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
++ inputs, generation_config.bos_token_id, model_kwargs
++ )
++ batch_size = inputs_tensor.shape[0]
++ seq_len = inputs_tensor.shape[1]
++ device = inputs_tensor.device
++ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
++
++ # 4. Define other model kwargs
++ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
++ # generating the first new token or not, and we only want to use the embeddings for the first new token)
++ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
++ model_kwargs["use_cache"] = True
++ else:
++ model_kwargs["use_cache"] = generation_config.use_cache
++ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
++ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
++ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
++ )
++ attention_mask = model_kwargs["attention_mask"]
++ attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
++ model_kwargs["attention_mask"] = ~attention_mask
++ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
++ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
++ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
++ inputs_tensor, model_kwargs, model_input_name, generation_config
++ )
++
++ # 5. Prepare `input_ids` which will be used for auto-regressive generation
++ if self.config.is_encoder_decoder:
++ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
++ batch_size=batch_size,
++ model_input_name=model_input_name,
++ model_kwargs=model_kwargs,
++ decoder_start_token_id=generation_config.decoder_start_token_id,
++ device=inputs_tensor.device,
++ )
++ else:
++ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
++
++ if generation_config.token_healing:
++ input_ids = self.heal_tokens(input_ids, tokenizer)
++
++ # 6. Prepare `max_length` depending on other stopping criteria.
++ input_ids_length = input_ids.shape[-1]
++ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
++ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
++ generation_config = self._prepare_generated_length(
++ generation_config=generation_config,
++ has_default_max_length=has_default_max_length,
++ has_default_min_length=has_default_min_length,
++ model_input_name=model_input_name,
++ inputs_tensor=inputs_tensor,
++ input_ids_length=input_ids_length,
++ )
++
++ use_dynamic_cache_by_default = False
++ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
++ raise ValueError(
++ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
++ "Cache object) is unsupported. Please use only one of the two."
++ )
++ elif generation_config.cache_implementation is not None:
++ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
++ if generation_config.cache_implementation == "static" and not self._supports_static_cache:
++ raise ValueError(
++ "This model does not support `cache_implementation='static'`. Please check the following "
++ "issue: https://github.com/huggingface/transformers/issues/28981"
++ )
++ model_kwargs["past_key_values"] = self._get_cache(
++ generation_config.cache_implementation,
++ getattr(generation_config, "num_beams", 1) * batch_size,
++ generation_config.max_length,
++ )
++ elif generation_config.cache_implementation == "quantized":
++ if not self._supports_quantized_cache:
++ raise ValueError(
++ "This model does not support the quantized cache. If you want your model to support quantized "
++ "cache, please open an issue."
++ )
++
++ cache_config = (
++ generation_config.cache_config
++ if generation_config.cache_config is not None
++ else QuantizedCacheConfig()
++ )
++ cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
++
++ if cache_config.backend == "quanto" and not is_quanto_available():
++ raise ImportError(
++ "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
++ "Please install it via with `pip install quanto`"
++ )
++ elif cache_config.backend == "HQQ" and not is_hqq_available():
++ raise ImportError(
++ "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
++ "Please install it via with `pip install hqq`"
++ )
++
++ model_kwargs["past_key_values"] = cache_class(cache_config)
++ # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
++ # keeps copying the cache thus using much more memory
++ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
++ past = model_kwargs.get("past_key_values", None)
++ if past is None:
++ model_kwargs["past_key_values"] = DynamicCache()
++ use_dynamic_cache_by_default = True
++ elif isinstance(past, tuple):
++ model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
++ use_dynamic_cache_by_default = True
++
++ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
++
++ # 7. determine generation mode
++ generation_mode = generation_config.get_generation_mode(assistant_model)
++ # 8. prepare distribution pre_processing samplers
++ prepared_logits_processor = self._get_logits_processor(
++ generation_config=generation_config,
++ input_ids_seq_length=input_ids_length,
++ encoder_input_ids=inputs_tensor,
++ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
++ logits_processor=logits_processor,
++ device=inputs_tensor.device,
++ model_kwargs=model_kwargs,
++ negative_prompt_ids=negative_prompt_ids,
++ negative_prompt_attention_mask=negative_prompt_attention_mask,
++ )
++
++ # 9. prepare stopping criteria
++ prepared_stopping_criteria = self._get_stopping_criteria(
++ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
++ )
++
++ if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
++ # 11. prepare logits warper
++ prepared_logits_warper = (
++ self._get_logits_warper(generation_config, device=input_ids.device)
++ if generation_config.do_sample
++ else None
++ )
++
++ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
++ input_ids, model_kwargs = self._expand_inputs_for_generation(
++ input_ids=input_ids,
++ expand_size=generation_config.num_return_sequences,
++ is_encoder_decoder=self.config.is_encoder_decoder,
++ **model_kwargs,
++ )
++ # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
++ result = self._sample(
++ input_ids,
++ logits_processor=prepared_logits_processor,
++ logits_warper=prepared_logits_warper,
++ stopping_criteria=prepared_stopping_criteria,
++ generation_config=generation_config,
++ **model_kwargs,
++ )
++ return result
++
++ def _sample(
++ self,
++ input_ids,
++ logits_processor,
++ stopping_criteria,
++ generation_config,
++ logits_warper = None,
++ **model_kwargs,
++ ):
++ # init values
++ pad_token_id = generation_config.pad_token_id
++ output_attentions = generation_config.output_attentions
++ output_hidden_states = generation_config.output_hidden_states
++ output_scores = generation_config.output_scores
++ output_logits = generation_config.output_logits
++ return_dict_in_generate = generation_config.return_dict_in_generate
++ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
++ do_sample = generation_config.do_sample
++ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
++ raise ValueError(
++ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
++ f"{logits_warper})."
++ )
++
++ # init attention / hidden states / scores tuples
++ scores = () if (return_dict_in_generate and output_scores) else None
++ raw_logits = () if (return_dict_in_generate and output_logits) else None
++ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
++ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
++ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
++
++ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
++ if return_dict_in_generate and self.config.is_encoder_decoder:
++ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
++ encoder_hidden_states = (
++ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
++ )
++
++ this_peer_finished = False
++ batch_size = input_ids.shape[0]
++ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
++ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
++
++ # keep track of which sequences are already finished
++ if self.is_mindie or self.config.architectures[0] == "T5ForConditionalGeneration":
++ num_layers = self.config.num_layers
++ num_heads = self.config.num_heads
++ d_kv = self.config.d_kv
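++ # Seed the self-attention KV caches with zero-length BSH tensors so that the
++ # first decoding step concatenates against an empty past.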
++ model_kwargs["past_keys"] = [torch.randn(batch_size, 0, num_heads*d_kv).half().npu() for _ in range(num_layers)]
++ model_kwargs["past_values"] = [torch.randn(batch_size, 0, num_heads*d_kv).half().npu() for _ in range(num_layers)]
++
++
++ while self._has_unfinished_sequences(this_peer_finished, False, device=input_ids.device):
++ # prepare model inputs
++ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
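++ # Flatten everything into the positional layout expected by forward(): encoder
++ # hidden states, per-layer cross keys/values, per-layer self-attention keys/values,
++ # then encoder mask, decoder input ids and decoder mask.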
++ model_args = [model_kwargs["encoder_outputs"]["last_hidden_state"]]
++ model_args.extend(model_kwargs["past_cross_keys"])
++ model_args.extend(model_kwargs["past_cross_values"])
++ model_args.extend(model_inputs["past_keys"])
++ model_args.extend(model_inputs["past_values"])
++ model_args.append(model_inputs["attention_mask"])
++ model_args.append(model_inputs["decoder_input_ids"])
++ model_args.append(model_inputs["decoder_attention_mask"])
++
++ # forward pass to get next token
++ outputs = self(*model_args)
++ outputs = Seq2SeqLMOutput(logits=outputs[0],
++ past_keys=outputs[1],
++ past_values=outputs[2])
++
++ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
++ # (the clone itself is always small)
++ next_token_logits = outputs.logits[:, -1, :].clone()
++
++ # pre-process distribution
++ next_token_scores = logits_processor(input_ids, next_token_logits)
++ if do_sample:
++ next_token_scores = logits_warper(input_ids, next_token_scores)
++
++ # Store scores, attentions and hidden_states when required
++ if return_dict_in_generate:
++ if output_scores:
++ scores += (next_token_scores,)
++ if output_logits:
++ raw_logits += (next_token_logits,)
++ if output_attentions:
++ decoder_attentions += (
++ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
++ )
++ if self.config.is_encoder_decoder:
++ cross_attentions += (outputs.cross_attentions,)
++
++ if output_hidden_states:
++ decoder_hidden_states += (
++ (outputs.decoder_hidden_states,)
++ if self.config.is_encoder_decoder
++ else (outputs.hidden_states,)
++ )
++
++ # token selection
++ if do_sample:
++ probs = nn.functional.softmax(next_token_scores, dim=-1)
++ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
++ else:
++ next_tokens = torch.argmax(next_token_scores, dim=-1)
++
++ # finished sentences should have their next token be a padding token
++ if has_eos_stopping_criteria:
++ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
++
++ # update generated ids, model inputs, and length for next step
++ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
++ model_kwargs = self._update_model_kwargs_for_generation(
++ outputs,
++ model_kwargs,
++ is_encoder_decoder=self.config.is_encoder_decoder,
++ )
++ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
++ this_peer_finished = unfinished_sequences.max() == 0
++ # This is needed to properly delete outputs.logits which may be very large for first iteration
++ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
++ del outputs
++ return input_ids
++
++
++ @property
++ def device(self) -> torch.device:
++ """
++ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
++ device).
++ """
++ return self.get_device()
++
++ def get_extended_attention_mask(
++ self, attention_mask, input_shape, device=None, dtype=None
++ ):
++ """
++ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
++
++ Arguments:
++ attention_mask (`torch.Tensor`):
++ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
++ input_shape (`Tuple[int]`):
++ The shape of the input to the model.
++
++ Returns:
++ `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
++ """
++ if dtype is None:
++ dtype = self.dtype
++
++ if not (attention_mask.dim() == 2 and self.config.is_decoder):
++ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
++ if device is not None:
++ warnings.warn(
++ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
++ )
++ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
++ # ourselves in which case we just need to make it broadcastable to all heads.
++ if attention_mask.dim() == 3:
++ extended_attention_mask = attention_mask[:, None, :, :]
++ elif attention_mask.dim() == 2:
++ # Provided a padding mask of dimensions [batch_size, seq_length]
++ # - if the model is a decoder, apply a causal mask in addition to the padding mask
++ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
++ if self.config.is_decoder:
++ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
++ input_shape, attention_mask, device
++ )
++ else:
++ extended_attention_mask = attention_mask[:, None, None, :]
++ else:
++ raise ValueError(
++ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
++ )
++
++ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
++ # masked positions, this operation will create a tensor which is 0.0 for
++ # positions we want to attend and the dtype's smallest value for masked positions.
++ # Since we are adding it to the raw scores before the softmax, this is
++ # effectively the same as removing these entirely.
++ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
++ #extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
++ extended_attention_mask = (1.0 - extended_attention_mask) * -1000
++ return extended_attention_mask
++
++
++
+
+ @add_start_docstrings(
+ "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
+@@ -1878,6 +2496,9 @@ class T5EncoderModel(T5PreTrainedModel):
+ encoder_config.use_cache = False
+ encoder_config.is_encoder_decoder = False
+ self.encoder = T5Stack(encoder_config, self.shared)
++ self.decoder_mindie = torch.jit.load("encoder_model_path")
++
++ self.stream = torch.npu.Stream(f"npu:{2}")
+
+ # Initialize weights and apply final processing
+ self.post_init()
+@@ -1966,17 +2587,21 @@ class T5EncoderModel(T5PreTrainedModel):
+ >>> outputs = model(input_ids=input_ids)
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```"""
+- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+-
+- encoder_outputs = self.encoder(
+- input_ids=input_ids,
+- attention_mask=attention_mask,
+- inputs_embeds=inputs_embeds,
+- head_mask=head_mask,
+- output_attentions=output_attentions,
+- output_hidden_states=output_hidden_states,
+- return_dict=return_dict,
+- )
++ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
++ # encoder_outputs = self.encoder(
++ # input_ids=input_ids,
++ # attention_mask=attention_mask,
++ # inputs_embeds=inputs_embeds,
++ # head_mask=head_mask,
++ # output_attentions=output_attentions,
++ # output_hidden_states=output_hidden_states,
++ # return_dict=return_dict,
++ # )
++ attention_mask = attention_mask[:, None, None, :].expand(attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1]).bool()
++ attention_mask = ~attention_mask
++ with torch.npu.stream(self.stream): # set stream
++ encoder_outputs = self.decoder_mindie.forward(input_ids,attention_mask)
++ self.stream.synchronize() # synchronize
+
+ return encoder_outputs
+
diff --git a/MindIE/MindIE-Torch/built-in/T5/readme.md b/MindIE/MindIE-Torch/built-in/T5/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..a8cd519940cba37c301b7f3d1212fb7a7cb2ecfc
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/T5/readme.md
@@ -0,0 +1,162 @@
+# T5 Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Model Inference](#section741711594517)
+
+
+
+# Overview
+
+  T5 (short for Text-to-Text Transfer Transformer) is a general-purpose pre-trained language model proposed by Google. It casts every natural language problem into a text-to-text form and solves them all with one unified model. The core idea of T5 is to unify the inputs and outputs of all NLP tasks through task-prefix declarations and text answer generation. Almost every earlier pre-trained language model required an extra non-linear layer during downstream fine-tuning to map the model output into the task-specific output format. T5 needs no changes to the model and no additional layers: all that is required is fine-tuning data for the downstream task, with a task declaration prefix prepended to the input. In this way T5 turns NLP tasks into a nearly uniform format, where the input is a text sequence carrying a task prefix and the output text sequence is the result of that task.
+Weights download: https://huggingface.co/collections/google/t5-release-65005e7c520f8d7b4d037918
+
+
+## Input and Output Data
+
+- Input data
+
+  | Input  | Shape                     | Data Type | Format |
+  | ------ | ------------------------- | --------- | ------ |
+  | input  | batchsize x input_seq_len | FLOAT16   | NHWC   |
+
+
+- Output data
+
+  | Output | Shape                     | Data Type | Format |
+  | ------ | ------------------------- | --------- | ------ |
+  | output | batchsize x input_seq_len | INT32     | NTHWC  |
+
+
+# Inference Environment Setup
+
+- This model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Component | Version | Notes                                                 |
+  | --------- | ------- | ----------------------------------------------------- |
+  | Python    | 3.10.2  | -                                                     |
+  | torch     | 2.1.0   | Version required to export the pt model               |
+  | torch_npu | 2.1.0   | Version required for model compilation and inference  |
+
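+A quick optional sanity check that the installed stack matches Table 1 (this command is only a convenience suggestion, not part of the delivery):
+
+```bash
+# Both versions should report 2.1.0 as listed in Table 1
+python3 -c "import torch, torch_npu; print(torch.__version__, torch_npu.__version__)"
+```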
+
+# Quick Start
+
+
+1. Install transformers 4.42.0.
+ ```bash
+ pip3 install transformers==4.42.0
+ ```
+
+2. Install the MindIE package. It must be used together with torch_npu; configure the environment according to the mindietorch and torch_npu compatibility documentation.
+
+ ```bash
+   # Install MindIE
+ chmod +x ./Ascend-mindie_xxx.run
+ ./Ascend-mindie_xxx.run --install
+ source /usr/local/Ascend/mindie/set_env.sh
+ ```
+
+3. Apply the code modification in the T5 directory.
+
+   Run the following command:
+ ```bash
+ python T5_modeling_t5_patch.py --ascend_soc {Ascend910B4 or Ascend310P3}
+ ```
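+   If you want to be able to restore the unpatched file later, a minimal optional sketch for backing it up first (the backup file name is just an example):
+   ```bash
+   # Locate the installed transformers package and keep a copy of the original modeling_t5.py
+   TF_DIR=$(python3 -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
+   cp "$TF_DIR/models/t5/modeling_t5.py" "$TF_DIR/models/t5/modeling_t5.py.bak"
+   ```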
+4. Export the mindietorch models.
+
+   On a 300IDUO card:
+   ```bash
+   python export_t5.py --output_dir {output_path} --model_path {model_path} --max_batchsize {max_batchsize} --max_input_seq_len {max_input_seq_len} --device_id {device_id}
+   ```
+   On an 800IA2 card:
+   ```bash
+   python export_t5_800IA2.py --output_dir {output_path} --model_path {model_path} --max_batchsize {max_batchsize} --max_input_seq_len {max_input_seq_len} --device_id {device_id}
+   ```
+   Parameter description:
+   - {output_path}: output directory
+   - {model_path}: directory containing the model
+   - {max_batchsize}: maximum batch size during inference
+   - {max_input_seq_len}: maximum input sequence length during inference
+   - {device_id}: which NPU device to use
+
+   Running the command automatically generates the optimized encoder and decoder models; see the example invocation below.
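+
+   For example, assuming the downloaded weights live in ./T5-Small (the paths and values below are only an illustration):
+   ```bash
+   python export_t5.py --output_dir ./models --model_path ./T5-Small --max_batchsize 1 --max_input_seq_len 256 --device_id 0
+   ```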
+
+5. Run and test performance.
+
+   Set the environment variable: `export TORCH_AIE_NPU_CACHE_MAX_SIZE=32`
+   ```bash
+   python main.py --hf_model_path {model_path} --encoder_aie_path {encoder_aie_path} --decoder_aie_path {decoder_aie_path} --device_id 2
+   ```
+   Performance test:
+   ```bash
+   python main.py --hf_model_path {model_path} --encoder_aie_path {encoder_aie_path} --decoder_aie_path {decoder_aie_path} --device_id 2 --performance
+   ```
+   The console output shows the single-batch throughput for an input length of 512 and an output length of 512.
+
+   Parameter description:
+   - {model_path}: directory containing the model
+   - {encoder_aie_path}: path to the optimized encoder model, down to the .pt file
+   - {decoder_aie_path}: path to the optimized decoder model, down to the .pt file
+   - {device_id}: which NPU device to use
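+
+   As an illustration, assuming the export step wrote the compiled models under ./models (the exact file names are an assumption; use the paths actually produced on your machine):
+   ```bash
+   export TORCH_AIE_NPU_CACHE_MAX_SIZE=32
+   python main.py --hf_model_path ./T5-Small --encoder_aie_path ./models/encoder/encoder_compiled.pt --decoder_aie_path ./models/decoder/decoder_compiled.pt --device_id 2
+   ```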
+
+6. Accuracy test.
+
+6.1 Acceptance criterion
+
+Dataset: pick one English dataset for testing; the accuracy error compared against the GPU inference result must be below 1%.
+
+6.2 Accuracy test procedure
+
+6.2.1 Install mteb and sentence_transformers
+
+```bash
+pip install sentence_transformers==3.1.1
+pip install mteb
+```
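+Optionally confirm the installed versions (just a convenience check, not part of the delivery):
+```bash
+python3 -c "from importlib.metadata import version; print(version('sentence-transformers'), version('mteb'))"
+```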
+6.2.2 Download the mteb datasets (skip this step if the machine has internet access)
+
+Download link: https://github.com/embeddings-benchmark/mteb
+
+6.2.3 Point mteb at the local dataset paths (skip this step if the machine has internet access)
+
+For example, if the Banking77Classification dataset was downloaded, edit the corresponding task file inside the installed mteb package, e.g. change the path in
+D:\python3.9\Lib\site-packages\mteb\tasks\Classification\eng\Banking77Classification.py to the dataset path downloaded in 6.2.2. A helper for locating the file is sketched below.
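+A small optional helper for finding the installed package and the file to edit (paths vary by environment; the grep target is just this example's dataset):
+```bash
+MTEB_DIR=$(python3 -c "import mteb, os; print(os.path.dirname(mteb.__file__))")
+grep -rl "Banking77Classification" "$MTEB_DIR/tasks"
+```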
+
+6.2.4 Modify the code
+
+On an 800IA2 card:
+In the T5EncoderModel class of modeling_t5.py in the installed transformers package, change the path loaded by self.decoder_mindie to the path of the compiled encoder.
+
+On a 300IDUO card:
+In the T5EncoderModel class of modeling_t5.py in the installed transformers package, add the following two lines,
+```python
+self.encoder_mindie = torch.jit.load("encoder_model_path")
+self.stream = torch.npu.Stream(f"npu:{device_id}")
+```
+where encoder_model_path is the path of the compiled encoder and device_id is the NPU card currently in use, and then change the forward method to
+```python
+with torch.npu.stream(self.stream):  # set stream
+    encoder_outputs = self.encoder_mindie.forward(input_ids, attention_mask)
+self.stream.synchronize()  # synchronize
+return encoder_outputs
+```
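+
+Depending on how the encoder was compiled, the 2-D attention mask may also need to be converted to the boolean [batch, 1, seq_len, seq_len] layout used by the patch in this repository before the call; a sketch of that conversion:
+```python
+# attention_mask: [batch, seq_len], 1 = token to attend to, 0 = padding
+attention_mask = attention_mask[:, None, None, :].expand(
+    attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1]
+).bool()
+attention_mask = ~attention_mask  # after inversion, True marks positions to ignore
+```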
+6.2.5 Test code
+
+```python
+import torch
+import torch_npu
+import mindietorch
+import mteb
+from sentence_transformers import SentenceTransformer
+torch.npu.set_device(0)
+model_name = r"D:\downloads\T5-v2"  # raw string so the backslashes in the Windows path are kept literally
+model = SentenceTransformer(model_name, model_kwargs={"torch_dtype": torch.float16})
+tasks = mteb.get_tasks(tasks=["CLSClusteringP2P"])
+evaluation = mteb.MTEB(tasks=tasks)
+results = evaluation.run(model, output_folder=f"./{model_name}")
+```
+6.2.6 Result output
+The result files are written to the current directory.