diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/PyanNet.patch b/MindIE/MindIE-Torch/built-in/audio/WhisperX/PyanNet.patch new file mode 100644 index 0000000000000000000000000000000000000000..3d29088aba8809e2960adb45e24e233e10be3b53 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/PyanNet.patch @@ -0,0 +1,87 @@ +--- PyanNet.py 2024-10-28 11:06:38.577263005 +0800 ++++ whisperX_update/PyanNet.py 2024-10-28 10:53:18.000000000 +0800 +@@ -97,27 +97,27 @@ + del multi_layer_lstm["monolithic"] + self.lstm = nn.LSTM(60, **multi_layer_lstm) + +- else: +- num_layers = lstm["num_layers"] +- if num_layers > 1: +- self.dropout = nn.Dropout(p=lstm["dropout"]) +- +- one_layer_lstm = dict(lstm) +- one_layer_lstm["num_layers"] = 1 +- one_layer_lstm["dropout"] = 0.0 +- del one_layer_lstm["monolithic"] +- +- self.lstm = nn.ModuleList( +- [ +- nn.LSTM( +- 60 +- if i == 0 +- else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), +- **one_layer_lstm +- ) +- for i in range(num_layers) +- ] +- ) ++ # else: ++ # num_layers = lstm["num_layers"] ++ # if num_layers > 1: ++ # self.dropout = nn.Dropout(p=lstm["dropout"]) ++ ++ # one_layer_lstm = dict(lstm) ++ # one_layer_lstm["num_layers"] = 1 ++ # one_layer_lstm["dropout"] = 0.0 ++ # del one_layer_lstm["monolithic"] ++ ++ # self.lstm = nn.ModuleList( ++ # [ ++ # nn.LSTM( ++ # 60 ++ # if i == 0 ++ # else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), ++ # **one_layer_lstm ++ # ) ++ # for i in range(num_layers) ++ # ] ++ # ) + + if linear["num_layers"] < 1: + return +@@ -171,19 +171,22 @@ + + outputs = self.sincnet(waveforms) + +- if self.hparams.lstm["monolithic"]: +- outputs, _ = self.lstm( +- rearrange(outputs, "batch feature frame -> batch frame feature") +- ) +- else: +- outputs = rearrange(outputs, "batch feature frame -> batch frame feature") +- for i, lstm in enumerate(self.lstm): +- outputs, _ = lstm(outputs) +- if i + 1 < self.hparams.lstm["num_layers"]: +- outputs = self.dropout(outputs) ++ # if self.hparams.lstm["monolithic"]: ++ outputs, _ = self.lstm( ++ rearrange(outputs, "batch feature frame -> batch frame feature") ++ ) ++ # else: ++ # outputs = rearrange(outputs, "batch feature frame -> batch frame feature") ++ # for i, lstm in enumerate(self.lstm): ++ # outputs, _ = lstm(outputs) ++ # if i + 1 < self.hparams.lstm["num_layers"]: ++ # outputs = self.dropout(outputs) ++ ++ # if self.hparams.linear["num_layers"] > 0: ++ # for linear in self.linear: ++ # outputs = F.leaky_relu(linear(outputs)) + +- if self.hparams.linear["num_layers"] > 0: +- for linear in self.linear: +- outputs = F.leaky_relu(linear(outputs)) ++ outputs = F.leaky_relu(self.linear[0](outputs)) ++ outputs = F.leaky_relu(self.linear[1](outputs)) + + return self.activation(self.classifier(outputs)) diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_300IPro_whisper.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_300IPro_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..b5dbdeb9ee2e8769f6455fad48c93e1fabf0d69f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_300IPro_whisper.py @@ -0,0 +1,162 @@ +import os +import argparse +import time +import math +import torch +import mindietorch +from mindietorch._enums import dtype +from modeling_whisper_300IPro import MindieWhisperForConditionalGeneration +from utils import CompileInfo + +def compile_encoder(model : MindieWhisperForConditionalGeneration, + args : argparse, + compile_info : CompileInfo): + encoder = model.get_encoder() + + 
class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))] + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + print(f"Compile encoder success, saved in {save_file}") + +def compile_prefill_decoder(model : MindieWhisperForConditionalGeneration, + args : argparse, compile_info : CompileInfo): + print("Start compiling prefill_decoder.") + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + +def compile_incre_decoder(args : argparse, compile_info : CompileInfo): + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, is_use_ifa=True).to("cpu") + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + ] * compile_info.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + 
mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16 + )] * compile_info.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), # input ids + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)), + mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)] # actual sq len + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + +def compile_scatter_update(args, compile_info): + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + + bs = args.bs + self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + encoder_past_key_value = torch.randn([bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + indices = torch.tensor([0] * bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=update_states.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, encoder_attn_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") + print("compile scatter success.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=32, help="please provide batch_size, default:32.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-machine_type', type=str, choices=["300IPro", "800IA2"], default="800A2") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path).to("cpu") + device = f"npu:{args.device_id}" + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + if not args.save_path: + raise ValueError("Please provide the directory 
where the compiled model saved.") + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + mindietorch.set_device(args.device_id) + compile_scatter_update(args, CompileInfo) + compile_encoder(mindie_whisper, args, CompileInfo) + compile_prefill_decoder(mindie_whisper, args, CompileInfo) + compile_incre_decoder(args, CompileInfo) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_800IA2_whisper.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_800IA2_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..018199cf89f5b114d0038d37b1bb041cf9257c84 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_800IA2_whisper.py @@ -0,0 +1,158 @@ +import os +import argparse +import time +import math +import torch +import mindietorch +from mindietorch._enums import dtype +from utils import CompileInfo +from modeling_whisper_800IA2 import MindieWhisperForConditionalGeneration + +def compile_encoder(model, args : argparse, compile_info : CompileInfo): + encoder = model.get_encoder() + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))] + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + print(f"Compile encoder success, saved in {save_file}") + +def compile_prefill_decoder(model, args : argparse, compile_info : CompileInfo): + print("Start compiling prefill_decoder.") + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + +def compile_incre_decoder(args : argparse, compile_info : CompileInfo): + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, is_use_ifa=True).to("cpu") + + + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = 
torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + ] * compile_info.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16 + )] * compile_info.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), # input ids + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)), + mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)] # actual sq len + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + +def compile_scatter_update(args, compile_info): + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + + bs = args.bs + self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + encoder_past_key_value = torch.randn([bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + indices = torch.tensor([0] * bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=update_states.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, 
inputs=[encoder_attn_info, indices_info, encoder_attn_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") + print("compile scatter success.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-device_id', type=int, default=0) + args = parser.parse_args() + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path).to("cpu") + device = f"npu:{args.device_id}" + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + if not args.save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + mindietorch.set_device(args.device_id) + compile_scatter_update(args, CompileInfo) + compile_encoder(mindie_whisper, args, CompileInfo) + compile_prefill_decoder(mindie_whisper, args, CompileInfo) + compile_incre_decoder(args, CompileInfo) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_vad.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_vad.py new file mode 100644 index 0000000000000000000000000000000000000000..7baa6f9e2f8eeb79bda979f56b2ebcbe82be3701 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_vad.py @@ -0,0 +1,68 @@ +import os +import argparse + +import torch +import torch.nn as nn +from pyannote.audio import Model + +import mindietorch + +SAMPLE_RATE = 16000 +WIND_SIZE = 5 +MIN_BATCH_SIZE = 1 +MAX_BATCH_SIZE = 32 +CHANNEL = 1 + + +def trace_vad(model_dir, traced_model_dir): + vad_model = Model.from_pretrained(model_dir, use_auth_token=None) + chunks = torch.randn(MAX_BATCH_SIZE, CHANNEL, WIND_SIZE * SAMPLE_RATE) + torch.jit.save(vad_model.to_torchscript(method="trace", example_inputs=chunks), traced_model_dir) + + +def compile_vad(traced_model_dir, compiled_model_dir, soc_version): + traced_model = torch.jit.load(traced_model_dir) + traced_model.eval() + + min_shape = (MIN_BATCH_SIZE, CHANNEL, WIND_SIZE * SAMPLE_RATE) + max_shape = (MAX_BATCH_SIZE, CHANNEL, WIND_SIZE * SAMPLE_RATE) + mie_inputs = [] + mie_inputs.append(mindietorch.Input(min_shape=min_shape, max_shape=max_shape)) + + compiled_module = mindietorch.compile( + traced_model, + inputs=mie_inputs, + precision_policy=mindietorch.PrecisionPolicy.FP16, + truncate_long_and_double=True, + require_full_compilation=False, + allow_tensor_replace_int=True, + torch_executed_ops=[], + soc_version=soc_version, + optimization_level=0) + + torch.jit.save(compiled_module, compiled_model_dir) + print(f"save {compiled_model_dir} success.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-vad_model_path', type=str, required=True, help="please provide vad model path.") + parser.add_argument('-soc_version', type=str, required=True, help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models 
save dir.") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + device = f"npu:{args.device_id}" + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + mindietorch.set_device(args.device_id) + + vad_model_dir = os.path.join(args.vad_model_path, "whisperx-vad-segmentation.bin") + vad_traced_model_dir = os.path.join(args.save_path, "vad_traced_model.pt") + vad_compiled_model_dir = os.path.join(args.save_path, "mindie_vad.ts") + + trace_vad(vad_model_dir, vad_traced_model_dir) + + compile_vad(vad_traced_model_dir, vad_compiled_model_dir, args.soc_version) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_whisper.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..64bce685a5bcc1409a8d6cc9b359c1efbe6c4595 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/compile_whisper.py @@ -0,0 +1,169 @@ +import os +import argparse +import time +import math +import torch +import mindietorch +from mindietorch._enums import dtype +from utils import CompileInfo + +def compile_encoder(model, args : argparse, compile_info : CompileInfo): + encoder = model.get_encoder() + class Encoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + return self.model(input_features=input_features, return_dict=False) + + input_features = torch.randn([args.bs, compile_info.mel_feature_size, compile_info.max_frames]) + encoder_traced = torch.jit.trace(Encoder(encoder), (input_features)) + input_info = [mindietorch.Input(shape=(args.bs, compile_info.mel_feature_size, compile_info.max_frames))] + compiled = mindietorch.compile(encoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + ) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[0]}{args.bs}.ts") + torch.jit.save(compiled, save_file) + print(f"Compile encoder success, saved in {save_file}") + +def compile_prefill_decoder(model, args : argparse, compile_info : CompileInfo): + print("Start compiling prefill_decoder.") + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + prefill_decoder_traced = torch.jit.trace(model.eval(), (decoder_input_ids, encoder_outputs)) + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size))] + prefill_decoder_compiled = mindietorch.compile(prefill_decoder_traced, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version) + + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[1]}{args.bs}.ts") + torch.jit.save(prefill_decoder_compiled, save_file) + print(f"Compile prefill_decoder success, saved in {save_file}.") + +def compile_incre_decoder(args : argparse, compile_info : CompileInfo): + class Decoder(torch.nn.Module): + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, *args): + return self.model.forward(*args)[0] + + if args.machine_type == "300IPro": + from modeling_whisper_800IA2 import MindieWhisperForConditionalGeneration + mindie_whisper = 
MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + is_use_ifa=True).to("cpu") + else: + from modeling_whisper_300IPro import MindieWhisperForConditionalGeneration + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, + is_use_ifa=True).to("cpu") + + decoder = Decoder(mindie_whisper) + print("Start compiling decoder.") + encoder_outputs = torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.hidden_size]) + decoder_input_ids = torch.randint(1, 4, (args.bs, 1)) + actual_seq_len = torch.ones((args.bs)) + all_past_key_value = [torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]), + torch.randn([args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + ] * compile_info.layer_nums + traced_args = [decoder_input_ids, encoder_outputs, actual_seq_len] + traced_args.extend(all_past_key_value) + traced_decoder = torch.jit.trace(decoder, traced_args) + # BSND + key_value_infos = [ + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16), + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size), + dtype=dtype.FLOAT16 + )] * compile_info.layer_nums + input_info = [mindietorch.Input(shape=(args.bs, 1), dtype=dtype.INT64), # input ids + mindietorch.Input(shape=(args.bs, compile_info.encoder_seq_len, compile_info.hidden_size)), + mindietorch.Input(shape=(args.bs,), dtype=dtype.INT64)] # actual sq len + + input_info.extend(key_value_infos) + float_size = 4 + voc_size = 51866 + buffer_size = math.ceil((args.bs * 1 * voc_size * float_size) / 1024 / 1024) + print(f"Set {buffer_size}/MB for output.") + compiled_decoder = mindietorch.compile(traced_decoder, + inputs=input_info, + precision_policy=mindietorch.PrecisionPolicy.FP16, + soc_version=args.soc_version, + default_buffer_size_vec=[buffer_size]) + save_file = os.path.join(args.save_path, f"{compile_info.prefix_name[2]}{args.bs}.ts") + torch.jit.save(compiled_decoder, save_file) + + print(f"Compile whisper_decoder success, saved in {save_file}.") + +def compile_scatter_update(args, compile_info): + class MindieScatter(torch.nn.Module): + def forward(self, past_key_value, indices, update_states): + out = torch.ops.aie.scatter_update(past_key_value, indices, update_states, 1) + return out + + bs = args.bs + self_past_key_value = torch.randn([bs, compile_info.max_decode_step, compile_info.head_num, compile_info.head_size]) + encoder_past_key_value = torch.randn([bs, compile_info.encoder_seq_len, compile_info.head_num, compile_info.head_size]) + indices = torch.tensor([0] * bs) + update_states = torch.randn([bs, 1, 20, 64]) + traced = torch.jit.trace(MindieScatter(), (self_past_key_value, indices, update_states)) + + self_attn_info = mindietorch.Input(shape=self_past_key_value.shape, dtype=mindietorch.dtype.FLOAT16) + encoder_attn_info = mindietorch.Input(shape=encoder_past_key_value.shape, 
dtype=mindietorch.dtype.FLOAT16) + indices_info = mindietorch.Input(shape=indices.shape, dtype=mindietorch.dtype.INT64) + update_states_info = mindietorch.Input(shape=update_states.shape, dtype=mindietorch.dtype.FLOAT16) + + compile_self = mindietorch.compile(traced, inputs=[self_attn_info, indices_info, update_states_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[3]}{args.bs}.ts") + + compile_self = mindietorch.compile(traced, inputs=[encoder_attn_info, indices_info, encoder_attn_info], + soc_version=args.soc_version) + torch.jit.save(compile_self, f"{args.save_path}/{compile_info.prefix_name[4]}{args.bs}.ts") + print("compile scatter success.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-device_id', type=int, default=0) + parser.add_argument('-machine_type', type=str, required=True, choices=["300IPro", "800IA2"]) + + args = parser.parse_args() + if args.machine_type == "300IPro": + from modeling_whisper_800IA2 import MindieWhisperForConditionalGeneration + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path).to("cpu") + else: + from modeling_whisper_300IPro import MindieWhisperForConditionalGeneration + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path).to("cpu") + device = f"npu:{args.device_id}" + print("Start compiling Mindie-Whisper, it will take some time, please wait.") + if not args.save_path: + raise ValueError("Please provide the directory where the compiled model saved.") + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + print(f"Directory {args.save_path} created.") + else: + print(f"Directory {args.save_path} already exists.") + mindietorch.set_device(args.device_id) + compile_scatter_update(args, CompileInfo) + compile_encoder(mindie_whisper, args, CompileInfo) + compile_prefill_decoder(mindie_whisper, args, CompileInfo) + compile_incre_decoder(args, CompileInfo) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/diarize.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/diarize.py new file mode 100644 index 0000000000000000000000000000000000000000..eb85b19d9558f3b848ca6f3308112925bf0e53b5 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/diarize.py @@ -0,0 +1,72 @@ +from typing import Optional, Union +import numpy as np +import pandas as pd +from pyannote.audio import Pipeline +import torch + +SAMPLE_RATE = 16000 + + +class DiarizationPipeline: + def __init__( + self, + model_name="pyannote/speaker-diarization-3.1", + use_auth_token=None, + device: Optional[Union[str, torch.device]] = "cpu", + ): + if isinstance(device, str): + device = torch.device(device) + self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device) + + def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None): + audio_data = { + 'waveform': torch.from_numpy(audio[None, :]), + 'sample_rate': SAMPLE_RATE + } + segments = self.model(audio_data, num_speakers=num_speakers, 
min_speakers=min_speakers, max_speakers=max_speakers) + diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) + diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) + diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) + return diarize_df + + +def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False): + transcript_segments = transcript_result["segments"] + for seg in transcript_segments: + # assign speaker to segment (if any) + diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'], seg['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start']) + # remove no hit, otherwise we look for closest (even negative intersection...) + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] + else: + dia_tmp = diarize_df + if len(dia_tmp) > 0: + # sum over speakers + speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + seg["speaker"] = speaker + + # assign speaker to words + if 'words' in seg: + for word in seg['words']: + if 'start' in word: + diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(diarize_df['start'], word['start']) + diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'], word['start']) + # remove no hit + if not fill_nearest: + dia_tmp = diarize_df[diarize_df['intersection'] > 0] + else: + dia_tmp = diarize_df + if len(dia_tmp) > 0: + # sum over speakers + speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0] + word["speaker"] = speaker + + return transcript_result + + +class Segment: + def __init__(self, start, end, speaker=None): + self.start = start + self.end = end + self.speaker = speaker diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/inference.patch b/MindIE/MindIE-Torch/built-in/audio/WhisperX/inference.patch new file mode 100644 index 0000000000000000000000000000000000000000..fde2cdc390c112ce50e0e169ec6cd37f595eadbd --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/inference.patch @@ -0,0 +1,33 @@ +--- inference.py 2024-10-28 11:12:54.877267615 +0800 ++++ whisperX_update/inference.py 2024-11-11 12:10:57.696130346 +0800 +@@ -94,6 +94,7 @@ + device: torch.device = None, + batch_size: int = 32, + use_auth_token: Union[Text, None] = None, ++ ts_model_path: str = None + ): + # ~~~~ model ~~~~~ + +@@ -115,6 +116,10 @@ + self.model.eval() + self.model.to(self.device) + ++ self.ts_model = None ++ if ts_model_path is not None: ++ self.ts_model = torch.jit.load(ts_model_path) ++ + specifications = self.model.specifications + + # ~~~~ sliding window ~~~~~ +@@ -214,7 +219,10 @@ + + with torch.inference_mode(): + try: +- outputs = self.model(chunks.to(self.device)) ++ if self.ts_model is not None: ++ outputs = self.ts_model(chunks.contiguous().to("npu")).to(self.device) ++ else: ++ outputs = self.model(chunks.to(self.device)) + except RuntimeError as exception: + if is_oom_error(exception): + raise MemoryError( diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/modeling_whisper_300IPro.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/modeling_whisper_300IPro.py new file mode 100644 index 0000000000000000000000000000000000000000..a4b1a30b57350c89e11f38cf05c2b66463a69948 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/modeling_whisper_300IPro.py @@ -0,0 
+1,773 @@ +import time +import argparse +import os +import copy +import math +import warnings +import librosa +from typing import Optional, Tuple, Union, List, Dict +from collections import OrderedDict + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers import WhisperForConditionalGeneration +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, WhisperEncoderLayer, \ + WhisperModel, WhisperDecoderLayer, WhisperAttention +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria +from transformers.generation.utils import GenerationMixin +#import torch_npu +import mindietorch +from mindietorch._enums import dtype +from utils import CompileInfo + + +class MindiePFA(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: bool = None + ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config + ) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # cross_attn + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + # cross attn + if ( + is_cross_attention + and past_key_value is not None + # past_key_value layout is BSND + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k, v + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + + # self attn + elif past_key_value is not None: + # reuse k v + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous().transpose(1, 2) + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states.transpose(1, 2), + value=value_states.transpose(1, 2), + num_head=self.num_heads, + scale=self.scaling, + layout="BNSD", + type="FA_HIGH_PERF" + ) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + +class 
MindieIFA(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None + ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config + ) + + # BSND + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: torch.Tensor, + actual_seq_len: torch.Tensor, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self attn + assert past_key_value is not None, \ + "Current operation is incre_flash_attention, past_key_value is required." + bsz, tgt_len, _ = hidden_states.size() + assert tgt_len == 1, \ + "Current operation is incre_flash_attention, query's seq length should be equal to 1." + query_states = self.q_proj(hidden_states) + + key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16) + values_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16) + + indices = actual_seq_len - 1 + past_key_cache, past_value_cache = past_key_value[0], past_key_value[1] + + past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1) + past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, values_states, axis=1) + + past_key_value = (past_key_cache, past_value_cache) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + actual_seq_lengths=actual_seq_len, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + + +class MindieIFA2(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None + ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config + ) + + # BSND + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert key_value_states is not None + assert past_key_value is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + # cross attn + assert past_key_value[0].shape[1] == key_value_states.shape[1] + key_states = past_key_value[0] + value_states = past_key_value[1] + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # B S N D: + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # import pdb; pdb.set_trace() + # + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=key_states, + value=value_states, + 
num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config, is_use_ifa): + super().__init__(config) + self.embed_dim = config.d_model + if is_use_ifa: + self.self_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindieIFA2( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config) + else: + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + + self.encoder_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config) + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + actual_seq_len, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + actual_seq_len=actual_seq_len + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config, is_use_ifa): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperDecoderLayer(config, is_use_ifa) for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + actual_seq_len, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + inputs_embeds = 
self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + # B S N D todo + past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0 + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + hidden_states = inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + past_key_value = (past_key_values[4 * idx], past_key_values[4 * idx + 1]) \ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4 * idx + 2], past_key_values[4 * idx + 3]) \ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + actual_seq_len=actual_seq_len, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config, is_use_ifa): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config, is_use_ifa) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values: Optional[torch.Tensor] = None, + actual_seq_len: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + actual_seq_len=actual_seq_len + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config, is_use_ifa=False): + super().__init__(config) + self.model = MindieWhisperModel(config, is_use_ifa) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None 
+ self.save_path = None + self.batch_size = -1 + self.encoder_seq_len = 1500 + self.file_prefix_names = CompileInfo.prefix_name + self.past_key_value = [] + + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + self.batch_size = batch_size + + for _ in range(32): + self.past_key_value.append( torch.ones([self.batch_size, 448, 20, 64], dtype=torch.float16).to("npu")) + self.past_key_value.append(torch.ones([self.batch_size, 448, 20, 64], dtype=torch.float16).to("npu")) + self.past_key_value.append(torch.ones([self.batch_size, 1500, 20, 64], dtype=torch.float16).to("npu")) + self.past_key_value.append(torch.ones([self.batch_size, 1500, 20, 64], dtype=torch.float16).to("npu")) + if not self.has_load: + + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[0]}{batch_size}.ts success.") + + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[2]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[3]}{batch_size}.ts success.") + + self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[4]}{batch_size}.ts success.") + self.has_load = True + print("load compile models success.") + + + def forward(self, *args, **kwargs): + if len(args) not in (2, 131): + raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + if len(args) == 131: + actual_seq_len = args[2] + past_key_values = args[3:] + else: + past_key_values = None + actual_seq_len = None + outputs = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + actual_seq_len=actual_seq_len, + use_cache=True, + return_dict=False, + input_features=None + ) + lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." 
+ ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu") + + this_peer_finished = False # used by synced_gpus only + kv_actual_step = 1 + indices = torch.tensor([0] * input_ids.shape[0]).to("npu") + is_first_step = True + while True: + + model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_step, **model_kwargs) + args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] + if is_first_step: + outputs = self.mindie_decoder_prefill(*args) + for idx in range(32): + self.self_attn_scatter(self.past_key_value[4*idx], indices, outputs[1 + 4*idx]) + self.self_attn_scatter(self.past_key_value[4*idx + 1], indices, outputs[1 + 4*idx + 1]) + self.encoder_attn_scatter(self.past_key_value[4*idx + 2], indices, outputs[1 + 4*idx + 2]) + self.encoder_attn_scatter(self.past_key_value[4*idx + 3], indices, outputs[1 + 4*idx + 3]) + is_first_step = False + else: + kv_actual_step += 1 + args.append(torch.tensor([kv_actual_step] * input_ids.shape[0]).to("npu")) + args.extend(self.past_key_value) + outputs = self.mindie_decoder(*args) + if isinstance(outputs, list): + next_token_logits = outputs[0].to("cpu")[:, -1, :] + else: + next_token_logits = outputs.to("cpu")[:, -1, :] + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + return input_ids + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, + 
return_timestamps=None, + **kwargs, + ): + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + num_segment_frames = 3000 + assert input_features.shape[-1] == num_segment_frames, "Only support 30s speech." + encoder_outputs = self.mindie_encoder(input_features.to("npu"))[0] + kwargs["encoder_outputs"] = encoder_outputs + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs + ) + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. 
+ decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device) + + return decoder_input_ids.to("npu"), model_kwargs + + def _check_save_path(self, save_path, batch_size): + file_list = os.listdir(save_path) + expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names] + for file in expected_files: + if file not in file_list: + raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + if_first_step, + encoder_outputs=None, + **kwargs + ): + if not if_first_step: + decoder_input_ids_shape = decoder_input_ids.shape + remove_prefix_length = decoder_input_ids_shape[1] - 1 + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + return { + "encoder_outputs": encoder_outputs, + "decoder_input_ids": decoder_input_ids + } +if __name__=="__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=8, help="please provide batch_size, default:8.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-device_id', type=int, default=0) + args = parser.parse_args() + mindietorch.set_device(args.device_id) + mindie_whipser = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, is_use_ifa=False).to("cpu") + mindie_whipser.compile(args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/modeling_whisper_800IA2.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/modeling_whisper_800IA2.py new file mode 100644 index 0000000000000000000000000000000000000000..d23df369e6a467e15067a90162571a1d54e61005 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/modeling_whisper_800IA2.py @@ -0,0 +1,824 @@ +import time +import argparse +import os +import copy +import math +import warnings +import librosa +from typing import Optional, Tuple, Union, List, Dict +from collections import OrderedDict + +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers import 
WhisperForConditionalGeneration +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, WhisperEncoderLayer, \ + WhisperModel, WhisperDecoderLayer, WhisperAttention +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria +from transformers.generation.utils import GenerationMixin +import mindietorch +from mindietorch._enums import dtype +from utils import CompileInfo + + +class MindiePFA(WhisperAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: bool = None): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # cross_attn + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + if ( + is_cross_attention + and past_key_value is not None + # past_key_value layout is BSND + and past_key_value[0].shape[1] == key_value_states.shape[1] + ): + # reuse k, v + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + # self attn + elif past_key_value is not None: + # reuse k v + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=1) + value_states = torch.cat([past_key_value[1], value_states], dim=1) + else: + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # B S N D: + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous().transpose(1, 2) + + # B S N D + attn_output = torch.ops.aie.flash_attention( + query=query_states, + key=key_states.transpose(1, 2), + value=value_states.transpose(1, 2), + num_head=self.num_heads, + scale=self.scaling, + layout="BNSD", + type="PFA") + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + + +class MindieIFA(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config) + + # BSND + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, 
self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: torch.Tensor, + actual_seq_len: torch.Tensor, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self attn + assert past_key_value is not None + bsz, tgt_len, _ = hidden_states.size() + assert tgt_len == 1 + query_states = self.q_proj(hidden_states) + + key_states = self._shape(self.k_proj(hidden_states), -1, bsz).to(torch.float16) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz).to(torch.float16) + + indices = actual_seq_len - 1 + past_key_cache, past_value_cache = past_key_value[0], past_key_value[1] + + past_key_cache = torch.ops.aie.scatter_update(past_key_cache, indices, key_states, axis=1) + past_value_cache = torch.ops.aie.scatter_update(past_value_cache, indices, value_states, axis=1) + + past_key_value = (past_key_cache, past_value_cache) + + # B S N D + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=past_key_cache, + value=past_value_cache, + actual_seq_lengths=actual_seq_len, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + + +class MindieIFA2(WhisperAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config=None + ): + super().__init__( + embed_dim, + num_heads, + dropout, + is_decoder, + bias, + is_causal, + config + ) + + # BSND + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert key_value_states is not None + assert past_key_value is not None + bsz, tgt_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + # cross attn + assert past_key_value[0].shape[1] == key_value_states.shape[1] + key_states = past_key_value[0] + value_states = past_key_value[1] + + if self.is_decoder: + past_key_value = (key_states.to(torch.float16), value_states.to(torch.float16)) + + # B S N D: + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous() + # import pdb; pdb.set_trace() + # + # B S N D + attn_output = torch.ops.aie.incre_flash_attention( + query=query_states, + key=key_states, + value=value_states, + num_head=self.num_heads, + scale=self.scaling, + layout="BSND") + + attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output, _, past_key_value + + +class MindieWhisperDecoderLayer(WhisperDecoderLayer): + + def __init__(self, config, is_use_ifa): + super().__init__(config) + self.embed_dim = config.d_model + if is_use_ifa: + self.self_attn = MindieIFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + 
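# cross-attention reuses the encoder K/V cache written at prefill, so decoding only runs the incremental-attention variant below +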
self.encoder_attn = MindieIFA2( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config) + else: + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + is_decoder=True + ) + self.encoder_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + is_decoder=True, + config=config + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states, + layer_head_mask, + cross_attn_layer_head_mask, + past_key_value, + cross_attn_past_key_value, + output_attentions, + use_cache, + actual_seq_len, + encoder_attention_mask=None + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + self_attn_past_key_value = past_key_value if past_key_value is not None else None + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + actual_seq_len=actual_seq_len + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + hidden_states, _, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = hidden_states.reshape(-1, self.embed_dim) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.fc2(hidden_states) + hidden_states = hidden_states.reshape(-1, 1, self.embed_dim) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MindieWhisperDecoder(WhisperDecoder): + + def __init__(self, config, is_use_ifa): + super().__init__(config) + self.layers = nn.ModuleList( + [MindieWhisperDecoderLayer(config, is_use_ifa) for _ in range(config.decoder_layers)]) + self.config = config + + def forward( + self, + input_ids, + encoder_hidden_states, + past_key_values, + actual_seq_len, + is_use_ifa=False, + use_cache=True, + attention_mask=None): + if input_ids is None: + raise ValueError("You have to specify either decoder_input_ids") + inputs_embeds = self.embed_tokens(input_ids) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + # B S N D todo + past_key_values_length = past_key_values[0].shape[2] if past_key_values is not None else 0 + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + hidden_states = 
inputs_embeds + positions + + past_key_value_cache = [] + for idx, decoder_layer in enumerate(self.layers): + past_key_value = (past_key_values[4 * idx], past_key_values[4 * idx + 1]) \ + if past_key_values is not None else None + cross_past_key_value = (past_key_values[4 * idx + 2], past_key_values[4 * idx + 3]) \ + if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + actual_seq_len=actual_seq_len, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=past_key_value, + cross_attn_past_key_value=cross_past_key_value, + output_attentions=None, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + past_key_value_cache.extend(layer_outputs[1]) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states, past_key_value_cache + + +class MindieWhisperEncoderLayer(WhisperEncoderLayer): + def __init__(self, config): + super().__init__(config) + self.self_attn = MindiePFA( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class MindieWhisperEncoder(WhisperEncoder): + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([MindieWhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + + +class MindieWhisperModel(WhisperModel): + + def __init__(self, config, is_use_ifa): + super().__init__(config) + self.decoder = MindieWhisperDecoder(config, is_use_ifa) + self.encoder = MindieWhisperEncoder(config) + + def forward( + self, + encoder_outputs, + decoder_input_ids, + past_key_values, + actual_seq_len, + use_cache: Optional[bool] = False, + return_dict: Optional[bool] = True, + input_features: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None + ) -> List[torch.Tensor]: + if input_features is None and encoder_outputs is None: + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `forward`.") + + if encoder_outputs is None: + input_features = self._mask_input_features(input_features, attention_mask=attention_mask) + + encoder_outputs = self.encoder( + input_features, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=return_dict, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_outputs, + past_key_values=past_key_values, + use_cache=use_cache, + actual_seq_len=actual_seq_len + ) + return decoder_outputs + + +class MindieWhisperForConditionalGeneration(WhisperForConditionalGeneration, GenerationMixin): + + def __init__(self, config, is_use_ifa=False): + super().__init__(config) + self.model = MindieWhisperModel(config, is_use_ifa) + self.has_load = False + self.has_compile = False + self.mindie_encoder = None + self.mindie_decoder_prefill = None + self.mindie_decoder = None + self.save_path = None + self.batch_size = -1 + self.encoder_seq_len = 1500 + self.init_encoder_seq_len = self.encoder_seq_len - 1 + self.file_prefix_names = CompileInfo.prefix_name + self.generate_cost = 0 + self.encoder_cost = 0 + self.decoder_cost = 0 + self.prefill_decoder_cost = 0 + self.decoder_cnt = 0 + self.inner_cpu = 0 + self.generate_cnt = 0 + self.past_key_value = [] + def load_mindie_models(self, save_path, batch_size): + if not (save_path and batch_size): + raise ValueError(f"Please provide batch_size and the directory where the compiled models saved,\ + but found 
save_path is {save_path}, batch_size is{batch_size}.") + self._check_save_path(save_path, batch_size) + + for _ in range(32): + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.max_decode_step, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + self.past_key_value.append( + torch.ones([batch_size, CompileInfo.encoder_seq_len, CompileInfo.head_num, CompileInfo.head_size], + dtype=torch.float16).to("npu") + ) + print("init past key value cache success.") + + if not self.has_load: + self.mindie_encoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[0]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[0]}{batch_size}.ts success.") + + self.mindie_decoder_prefill = torch.jit.load(f"{save_path}/{self.file_prefix_names[1]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[1]}{batch_size}.ts success.") + + self.mindie_decoder = torch.jit.load(f"{save_path}/{self.file_prefix_names[2]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[2]}{batch_size}.ts success.") + + self.self_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[3]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[3]}{batch_size}.ts success.") + + self.encoder_attn_scatter = torch.jit.load(f"{save_path}/{self.file_prefix_names[4]}{batch_size}.ts") + print(f"load {save_path}/{self.file_prefix_names[4]}{batch_size}.ts success.") + self.has_load = True + else: + print("Mindie whisper has already load.") + def forward(self, *args, **kwargs): + if len(args) not in (2, 131): + raise ValueError(f"The args length of forward can only be 2 or 131, but got {len(args)}") + decoder_input_ids = args[0] + encoder_outputs = args[1] + if len(args) == 131: + actual_seq_len = args[2] + past_key_values = args[3:] + else: + past_key_values = None + actual_seq_len = None + + outputs = self.model( + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + actual_seq_len=actual_seq_len, + use_cache=True, + return_dict=False, + input_features=None + ) + lm_logits = self.proj_out(outputs[0]) + return [lm_logits] + outputs[1] + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + **model_kwargs): + to = time.time() + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + print( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead." 
+ ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to("cpu") if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device="cpu") + + this_peer_finished = False # used by synced_gpus only + + is_first_step = True + decoder_step = 0 + bs = input_ids.shape[0] + indices = torch.tensor([0] * bs, dtype=torch.int64).to("npu") + t1 = time.time() + kv_actual_step = 1 + stream = mindietorch.current_stream() + while True: + + model_inputs = self.prepare_inputs_for_generation(input_ids, is_first_step, **model_kwargs) + args = [model_inputs["decoder_input_ids"].contiguous().to("npu"), model_inputs["encoder_outputs"]] + if is_first_step: + outputs = self.mindie_decoder_prefill(*args) + stream.synchronize() + for idx in range(32): + self.self_attn_scatter(self.past_key_value[4 * idx], indices, outputs[1 + 4 * idx]) + self.self_attn_scatter(self.past_key_value[4 * idx + 1], indices, outputs[1 + 4 * idx + 1]) + self.encoder_attn_scatter(self.past_key_value[4 * idx + 2], indices, outputs[1 + 4 * idx + 2]) + self.encoder_attn_scatter(self.past_key_value[4 * idx + 3], indices, outputs[1 + 4 * idx + 3]) + stream.synchronize() + is_first_step = False + else: + kv_actual_step += 1 + args.append(torch.tensor([kv_actual_step] * bs).to("npu")) + args.extend(self.past_key_value) + outputs = self.mindie_decoder(*args) + stream.synchronize() + decoder_step += 1 + self.decoder_cnt += 1 + if synced_gpus and this_peer_finished: + continue + if isinstance(outputs, list): + next_token_logits = outputs[0].to("cpu")[:, -1, :] + else: + next_token_logits = outputs.to("cpu")[:, -1, :] + stream.synchronize() + # pre-process distribution + next_tokens_scores = logits_processor(input_ids.to("cpu"), next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids.to("cpu"), next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + 
break + + return input_ids + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, + return_timestamps=None, + return_segments=False, + attention_mask=None, + time_precision=0.02, + return_token_timestamps=None, + **kwargs, + ): + input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + num_segment_frames = 3000 + if input_features is not None: + total_input_frames = input_features.shape[-1] + + is_shortform = total_input_frames <= num_segment_frames + assert is_shortform is True + if return_timestamps is True: + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError("you trying ....") + generation_config.return_timestamps = return_timestamps + elif not is_shortform: + if return_timestamps is False: + raise ValueError("yo has passed more than 3000") + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError("yo has passed more than 3000") + generation_config.return_timestamps = True + else: + generation_config.return_timestamps = False + + if is_shortform: + encoder_outputs = self.mindie_encoder(input_features.to("npu"))[0] + kwargs["encoder_outputs"] = encoder_outputs + outputs = GenerationMixin.generate( + self, + input_features, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_dict_in_generate=return_dict_in_generate, + **kwargs + ).to("cpu") + return outputs + + def _update_model_kwargs_for_generation( + self, + outputs, + model_inputs, + ): + return model_inputs + + def _maybe_initialize_input_ids_for_generation( + self, + inputs: Optional[torch.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.LongTensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + if self.config.is_encoder_decoder and encoder_outputs is not None: + # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding + shape = encoder_outputs.size()[:-1] + input_ids = torch.ones(shape, dtype=torch.long, device="cpu") * -100 + return input_ids.to("npu") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, torch.Tensor): + batch_size = value.shape[0] + break + return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id + + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + model_input_name: str, + model_kwargs: Dict[str, torch.Tensor], + decoder_start_token_id: int = None, + bos_token_id: int = None, + device: torch.device = None, + ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. 
To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif "input_ids" in model_kwargs and model_input_name != "input_ids": + decoder_input_ids = model_kwargs.pop("input_ids") + else: + decoder_input_ids = None + + # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that. + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + if device is None: + device = self.device + decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device="cpu") * decoder_start_token_id + + # no user input -> use decoder_start_token_id as decoder_input_ids + if decoder_input_ids is None: + decoder_input_ids = decoder_input_ids_start + # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token + elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): + pass + # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust + # decoder_attention_mask if provided) + elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): + decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + decoder_attention_mask = torch.cat( + (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask), + dim=-1, + ) + model_kwargs["decoder_attention_mask"] = decoder_attention_mask.to(self.device) + + return decoder_input_ids.to("npu"), model_kwargs + + def _check_save_path(self, save_path, batch_size): + file_list = os.listdir(save_path) + expected_files = [file + f"{batch_size}.ts" for file in self.file_prefix_names] + for file in expected_files: + if file not in file_list: + raise ValueError(f"Expected file name is {file}, but can't be found in path: {save_path}") + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + is_first_step, + past_key_values=None, + encoder_outputs=None, + **kwargs + ): + if not is_first_step: + decoder_input_ids_shape = decoder_input_ids.shape + remove_prefix_length = decoder_input_ids_shape[1] - 1 + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-bs', type=int, default=16, help="please provide batch_size, default:16.") + parser.add_argument('-soc_version', type=str, required=True, + help="please provide soc_version.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-device_id', type=int, default=0) + args = parser.parse_args() + mindietorch.set_device(args.device_id) + mindie_whisper = MindieWhisperForConditionalGeneration.from_pretrained(args.model_path, is_use_ifa=False).to("cpu") + mindie_whisper.compile(args) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/patch_apply.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/patch_apply.py new 
file mode 100644 index 0000000000000000000000000000000000000000..d3ad680e7da43e94a742282e2b59c31c4aee985e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/patch_apply.py @@ -0,0 +1,18 @@ +import os +import sys +import pyannote.audio + + +def main(): + pyannote_audio_path = pyannote.audio.__path__ + pyannote_audio_version = pyannote.audio.__version__ + + if pyannote_audio_version != '3.1.1': + sys.exit("Expectation pyannote.audio==3.1.1") + os.system(f'patch -p0 {pyannote_audio_path[0]}/models/segmentation/PyanNet.py PyanNet.patch') + os.system(f'patch -p0 {pyannote_audio_path[0]}/models/blocks/sincnet.py sincnet.patch') + os.system(f'patch -p0 {pyannote_audio_path[0]}/pipelines/voice_activity_detection.py voice_activity_detection.patch') + os.system(f'patch -p0 {pyannote_audio_path[0]}/core/inference.py inference.patch') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/pipeline.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..bb044c01273dafb4f55f6204ca1ad520f00a1253 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/pipeline.py @@ -0,0 +1,369 @@ +import os +import subprocess +from functools import lru_cache +from typing import Union, Optional +import time +import argparse + +import torch +import torch.nn.functional as F +import mindietorch +import librosa +import tokenizers +import numpy as np +from transformers import Pipeline +from transformers.pipelines.pt_utils import PipelineIterator +from transformers import WhisperProcessor + +from vad import load_vad_model, merge_chunks + + +def exact_div(x1, x2): + if x1 % x2 != 0: + raise ValueError("x1 is not divisible by x2") + return x1 // x2 + + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + try: + # Launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI to be installed. + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", + "0", + "-i", + file, + "-f", + "s16le", + "-ac", + "1", + "-acodec", + "pcm_s16le", + "-ar", + str(sr), + "-", + ] + out = subprocess.run(cmd, capture_output=True, check=True).stdout + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
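+    Accepts either a torch tensor or a NumPy array; `axis` selects the dimension that is padded or trimmed.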
+ """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int) -> torch.Tensor: + mel_filters_path = os.path.join(os.path.dirname(__file__), "mel_filters.npz") + if not os.path.exists(mel_filters_path): + np.savez_compressed( + "mel_filters.npz", + mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), + ) + if n_mels not in [80, 128]: + raise ValueError(f"Unsupported n_mels: {n_mels}") + with np.load( + os.path.join(os.path.dirname(__file__), "mel_filters.npz") + ) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters("cpu", n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if log_spec.shape[-1] != 3000: + raise ValueError("log_spec shape is not as expected") + return log_spec + + +from tokenizer import Tokenizer + + +class MindiePipeline(Pipeline): + def __init__(self, whisper_model_path, vad_model_path, device, save_path, batch_size, **kwargs): + self.model = MindieWhisperForConditionalGeneration.from_pretrained(whisper_model_path).to("cpu") + self.device = torch.device("cpu") + if not isinstance(self.model, MindieWhisperForConditionalGeneration): + raise ValueError(f"Please provide MindieWhisperForConditionalGeneration, found {type(self.model)}") + + if not (save_path and batch_size): + raise ValueError(f"Please provide compiled model save path and batch_size.") + + self.model.load_mindie_models(save_path, batch_size) + print("start load vad") + + tokenizer_path = os.path.join(whisper_model_path, "tokenizer.json") + self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_path) + self.tokenizer = Tokenizer(self.hf_tokenizer, 
multilingual=True, task="transcribe", language="zh") + + default_vad_options = { + "vad_onset": 0.500, + "vad_offset": 0.363 + } + + vad_ts_model_path = os.path.join(save_path, "mindie_vad.ts") + if not os.path.exists(vad_ts_model_path): + raise ValueError(f"Expect file name is {vad_ts_model_path}, but can`t be found in path: {save_path}") + + self.vad_model = load_vad_model(vad_model_path, torch.device("cpu"), vad_ts_model_path, **default_vad_options) + print("load vad success") + + self._batch_size = batch_size + self._num_workers = 1 + self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs) + self.call_count = 0 + self.framework = "pt" + + super(Pipeline, self).__init__() + + self._vad_params = { + "vad_onset": 0.500, + "vad_offset": 0.363 + } + self.vad_cost = 0 + + + def _sanitize_parameters(self, **kwargs): + return {}, {}, {} + + def preprocess(self, audio): + model_n_mels = 128 + audio = audio['inputs'] + features = log_mel_spectrogram( + audio, + n_mels=model_n_mels if model_n_mels is not None else 80, + padding=N_SAMPLES - len(audio), + device="cpu" + ) + return {'inputs': features} + + def _forward(self, model_inputs, **generate_kwargs): + generate_kwargs["input_features"] = model_inputs["inputs"] + print("call forward") + tokens = self.model.generate(attention_mask=None, **generate_kwargs) + + tokens_batch = [x for x in tokens] + each_token_num = [len(x) for x in tokens_batch] + + def decode_batch(tokens) -> str: + res = [] + for tk in tokens: + res.append([token for token in tk if token < self.tokenizer.eot]) + return self.tokenizer.tokenizer.decode_batch(res) + + text = decode_batch(tokens_batch) + return {'text': text} + + def forward(self, model_inputs, **forward_params): + model_outputs = self._forward(model_inputs, **forward_params) + return model_outputs + + def postprocess(self, model_outputs): + return model_outputs + + def get_iterator( + self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params + ): + dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) + if "TOKENIZERS_PARALLELISM" not in os.environ: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + def stack(items): + return {'inputs': torch.stack([x['inputs'] for x in items])} + + dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack) + model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) + final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) + + return final_iterator + + def pad_segment(self, vad_segments, batch_size): + need_pad_num = batch_size - len(vad_segments) % batch_size + paded_segments = vad_segments + for _ in range(need_pad_num): + paded_segments.append(vad_segments[-1]) + return paded_segments + + def transcribe( + self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, chunk_size=30 + ): + segments = [] + + def data(total_audio, segments): + for seg in segments: + f1 = int(seg['start'] * SAMPLE_RATE) + f2 = int(seg['end'] * SAMPLE_RATE) + yield {'inputs': total_audio[f1:f2]} + + vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE}) + vad_segments = merge_chunks( + vad_segments, + chunk_size, + onset=self._vad_params.get("vad_onset", 0.500), + offset=self._vad_params.get("vad_offset", 0.363), + ) + total_segments = len(vad_segments) + + if total_segments 
% batch_size != 0: + vad_segments = self.pad_segment(vad_segments, batch_size) + + for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)): + text = out['text'] + if batch_size in [0, 1, None]: + text = text[0] + segments.append( + { + "text": text, + "start": round(vad_segments[idx]['start'], 3), + "end": round(vad_segments[idx]['end'], 3) + } + ) + + return segments + + +if __name__ == "__main__": + print("start here") + + parser = argparse.ArgumentParser() + parser.add_argument('-whisper_model_path', type=str, required=True, help="please provide model path.") + parser.add_argument('-vad_model_path', type=str, required=True) + parser.add_argument('-machine_type', type=str, required=True, choices=["300IPro", "800IA2"]) + parser.add_argument('-audio_path', type=str, required=True) + parser.add_argument('-bs', type=int, default=16, help="please provide batch_size, default:8.") + parser.add_argument('-save_path', type=str, default="compiled_models", help="compiled models save dir.") + parser.add_argument('-device_id', type=int, default=0) + + args = parser.parse_args() + if args.machine_type == "800IA2": + from modeling_whisper_800IA2 import MindieWhisperForConditionalGeneration + elif args.machine_type == "300IPro": + from modeling_whisper_300IPro import MindieWhisperForConditionalGeneration + else: + raise ValueError("machine type is not supported.") + inference_device = f"npu:{args.device_id}" + + mindie_pipe = MindiePipeline(args.whisper_model_path, args.vad_model_path, inference_device, args.save_path, args.bs) + + mindietorch.set_device(args.device_id) + + audio_path = args.audio_path + inp = [] + + infer_audio = load_audio(audio_path) + print(f"load audio success.") + + y, audio_sr = librosa.load(audio_path) + duration_seconds = librosa.get_duration(y=y, sr=audio_sr) + print(f"duration_seconds {duration_seconds}") + + inp.extend(infer_audio) + inp = np.array(inp) + + predicted_ids = mindie_pipe.transcribe(inp, batch_size=args.bs) + + t0 = time.time() + predicted_ids = mindie_pipe.transcribe(inp, batch_size=args.bs) + print(f"trascription {predicted_ids}") + t1 = time.time() + + print(f"speech_duration/s: {duration_seconds}") + print(f"E2E cost {t1 - t0}") + print(f"perfomence {duration_seconds / (t1 - t0)}") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/readme.md b/MindIE/MindIE-Torch/built-in/audio/WhisperX/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..c0d2f120af7ab6d287aa3cbc7a2d0f38a2e37f33 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/readme.md @@ -0,0 +1,146 @@ +# WhisperX推理指导 + +- [WhisperX推理指导](#whisperx推理指导) +- [概述](#概述) +- [推理环境准备](#推理环境准备) +- [快速上手](#快速上手) + - [获取源码](#获取源码) + - [模型编译](#模型编译) + - [模型推理](#模型推理) + +# 概述 + +该工程使用mindietorch部署WhisperX模型 + + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + |-----------| ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0+cpu | - | + | torch_audio | 2.1.0+cpu | - | + | CANN | 8.0.B023 | - | + | MindIE | 1.0.B030 | - | + +# 快速上手 +## 获取源码 + +1. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + + +2. Whisper large V3模型权重下载路径: + ```bash + https://huggingface.co/openai/whisper-large-v3/tree/main + ``` + 将权重文件存放至当前目录下的model_path文件夹,请先创建改文件夹。 + + +3. 
WhisperX中VAD模型权重下载路径: + ```bash + https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin + ``` + 并修改文件名为`whisperx-vad-segmentation.bin` + + +4. 安装依赖 + ``` + pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cpu + pip3 install nltk + pip3 install librosa + pip3 install transformers==4.36.0 + pip3 install numpy==1.24.0 + pip3 install ml-dtypes + pip3 install cloudpickle + pip3 install pyannote.audio==3.1.1 + ``` + 同时需要保证环境中已安装libsndfile1和ffmpeg库 + +## 模型编译 +1.1 300IPro环境执行如下命令: +``` +python3 compile_300IPro_whisper.py \ +-model_path ./model_path \ +-soc_version Ascend310P3 +``` + 参数说明: + - -model_path:预训练模型路径,必选。 + - -bs:batch_size, 默认值为32, 可选。 + - -save_path: 编译好的模型的保存目录,可选。默认值为"./compiled_models"。 + - -device_id: 选择模型运行的卡编号,默认值0,可选。 + - -soc_version: 芯片类型,必选。 + 约束说明: + 1. 当前暂不支持动态batch,batch_size改变后,需要重新编图。 + +1.2 800IA2环境执行如下命令: +``` +python3 compile_800IA2_whisper.py \ +-model_path ./model_path \ +-soc_version Ascend910B4 +``` + 参数说明: + - -model_path:预训练模型路径,必选。 + - -bs:batch_size, 默认值为16, 可选。 + - -save_path: 编译好的模型的保存目录,可选。默认值为"./compiled_models"。 + - -device_id: 选择模型运行的卡编号,默认值0,可选。 + - -soc_version: 芯片类型,必选。 + 约束说明: + 1. 当前暂不支持动态batch,batch_size改变后,需要重新编图。 + +2.1 应用VAD模型补丁 +在编译VAD模型前需要先打补丁,使用如下命令: +``` +python3 patch_apply.py +python3 remove_script.py +``` + +2.2 VAD模型编译 + ``` + python3 compile_vad.py \ + -vad_model_path /vad_model_path \ + -soc_version soc_version + ``` + + 参数说明: + - -vad_model_path:VAD预训练模型路径,必选。 + - -save_path: 编译好的模型的保存目录,可选,默认值"./compiled_models"。 + - -device_id: 选择模型运行的卡编号,默认值0,可选。 + - -soc_version: 芯片类型,必选。 + + 注:VAD模型编译的保存路径需要和Whisper-large-V3模型编译保存路径一致 + + +## 模型推理 +1. 设置mindie内存池上限为32,执行如下命令设置环境变量。内存池设置过小,内存重复申请和释放会影响性能。 + ``` + export TORCH_AIE_NPU_CACHE_MAX_SIZE=32 + ``` + +2. 
模型推理 + ``` + python3 pipeline.py \ + -whisper_model_path /whisper_model_path \ + -vad_model_path /vad_model_path \ + -machine_type machine_type \ + -audio_path /audio_path + ``` + + 参数说明: + - -model_path:预训练模型路径,必选。 + - -bs:batch_size, 默认值为16, 可选。针对300IPro需要设置成32。 + - -save_path: 编译好的模型的保存文件,可选,默认值"./compiled_models"。 + - -device_id: 选在模型运行的卡编号,默认值0,可选。 + - -machine_type: 机器类型,必选。支持800IA2和300IPro + + diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/remove_script.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/remove_script.py new file mode 100644 index 0000000000000000000000000000000000000000..559aabdeee3b2e9f8216f4587eefd0bcb629e1db --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/remove_script.py @@ -0,0 +1,26 @@ +import os +import stat + +import asteroid_filterbanks + + +def add_comment_before_function(file_path, target_function): + with open(file_path, 'r') as file: + lines = file.readlines() + + for i, line in enumerate(lines): + if i > 0 and target_function in line: + lines[i - 1] = '#' + lines[i - 1] + + flags = os.O_WRONLY | os.O_TRUNC + mode = stat.S_IWUSR | stat.S_IRUSR + with os.fdopen(os.open(file_path, flags, mode), 'w') as file: + for line in lines: + file.write(line) + + +if __name__ == "__main__": + asteroid_filterbanks_path = asteroid_filterbanks.__path__ + enc_dec_path = f'{asteroid_filterbanks_path[0]}/enc_dec.py' + function_str = 'def multishape_conv1d(' + add_comment_before_function(enc_dec_path, function_str) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/sincnet.patch b/MindIE/MindIE-Torch/built-in/audio/WhisperX/sincnet.patch new file mode 100644 index 0000000000000000000000000000000000000000..5b4d1664cbd1eed84b1a09505d3928aa4fed70ee --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/sincnet.patch @@ -0,0 +1,37 @@ +--- sincnet.py 2024-10-28 11:12:11.613267085 +0800 ++++ whisperX_update/sincnet.py 2024-10-28 10:55:46.000000000 +0800 +@@ -80,16 +80,26 @@ + + outputs = self.wav_norm1d(waveforms) + +- for c, (conv1d, pool1d, norm1d) in enumerate( +- zip(self.conv1d, self.pool1d, self.norm1d) +- ): ++ # for c, (conv1d, pool1d, norm1d) in enumerate( ++ # zip(self.conv1d, self.pool1d, self.norm1d) ++ # ): + +- outputs = conv1d(outputs) ++ # outputs = conv1d(outputs) + +- # https://github.com/mravanelli/SincNet/issues/4 +- if c == 0: +- outputs = torch.abs(outputs) ++ # # https://github.com/mravanelli/SincNet/issues/4 ++ # if c == 0: ++ # outputs = torch.abs(outputs) + +- outputs = F.leaky_relu(norm1d(pool1d(outputs))) ++ # outputs = F.leaky_relu(norm1d(pool1d(outputs))) ++ ++ outputs = self.conv1d[0](outputs) ++ outputs = torch.abs(outputs) ++ outputs = F.leaky_relu(self.norm1d[0](self.pool1d[0](outputs))) ++ ++ outputs = self.conv1d[1](outputs) ++ outputs = F.leaky_relu(self.norm1d[1](self.pool1d[1](outputs))) ++ ++ outputs = self.conv1d[2](outputs) ++ outputs = F.leaky_relu(self.norm1d[2](self.pool1d[2](outputs))) + + return outputs diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/tokenizer.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d6cc08db3887e4127f744df39de1bf94e5bdd9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/tokenizer.py @@ -0,0 +1,317 @@ +import string + +from functools import cached_property +from typing import List, Optional, Tuple + +import tokenizers + + +class Tokenizer: + """Simple wrapper around a tokenizers.Tokenizer.""" + + def __init__( + self, + tokenizer: 
tokenizers.Tokenizer, + multilingual: bool, + task: Optional[str] = None, + language: Optional[str] = None, + ): + self.tokenizer = tokenizer + + if multilingual: + if task not in _TASKS: + raise ValueError( + "'%s' is not a valid task (accepted tasks: %s)" + % (task, ", ".join(_TASKS)) + ) + + if language not in _LANGUAGE_CODES: + raise ValueError( + "'%s' is not a valid language code (accepted language codes: %s)" + % (language, ", ".join(_LANGUAGE_CODES)) + ) + + self.task = self.tokenizer.token_to_id("<|%s|>" % task) + self.language = self.tokenizer.token_to_id("<|%s|>" % language) + self.language_code = language + else: + self.task = None + self.language = None + self.language_code = "en" + + @cached_property + def transcribe(self) -> int: + return self.tokenizer.token_to_id("<|transcribe|>") + + @cached_property + def translate(self) -> int: + return self.tokenizer.token_to_id("<|translate|>") + + @cached_property + def sot(self) -> int: + return self.tokenizer.token_to_id("<|startoftranscript|>") + + @cached_property + def sot_lm(self) -> int: + return self.tokenizer.token_to_id("<|startoflm|>") + + @cached_property + def sot_prev(self) -> int: + return self.tokenizer.token_to_id("<|startofprev|>") + + @cached_property + def eot(self) -> int: + return self.tokenizer.token_to_id("<|endoftext|>") + + @cached_property + def no_timestamps(self) -> int: + return self.tokenizer.token_to_id("<|notimestamps|>") + + @property + def timestamp_begin(self) -> int: + return self.no_timestamps + 1 + + @property + def sot_sequence(self) -> List[int]: + sequence = [self.sot] + + if self.language is not None: + sequence.append(self.language) + + if self.task is not None: + sequence.append(self.task) + + return sequence + + def encode(self, text: str) -> List[int]: + return self.tokenizer.encode(text, add_special_tokens=False).ids + + def decode(self, tokens: List[int]) -> str: + text_tokens = [token for token in tokens if token < self.eot] + return self.tokenizer.decode(text_tokens) + + def decode_with_timestamps(self, tokens: List[int]) -> str: + outputs = [[]] + + for token in tokens: + if token >= self.timestamp_begin: + timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" + outputs.append(timestamp) + outputs.append([]) + else: + outputs[-1].append(token) + + return "".join( + [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] + ) + + @cached_property + def non_speech_tokens(self) -> Tuple[int]: + """ + Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech + annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. + + - ♪♪♪ + - ( SPEAKING FOREIGN LANGUAGE ) + - [DAVID] Hey there, + + keeping basic punctuations like commas, periods, question marks, exclamation points, etc. + """ + symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') + symbols += ( + "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() + ) + + # symbols that may be a single token or multiple tokens depending on the tokenizer. + # In case they're multiple tokens, suppress the first token, which is safe because: + # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress + # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 
+        miscellaneous = set("♩♪♫♬♭♮♯")
+        if not all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous):
+            raise ValueError("The miscellaneous set contains invalid characters")
+
+        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
+        result = {self.encode(" -")[0], self.encode(" '")[0]}
+        for symbol in symbols + list(miscellaneous):
+            for tokens in [
+                self.encode(symbol),
+                self.encode(" " + symbol),
+            ]:
+                if len(tokens) == 1 or symbol in miscellaneous:
+                    result.add(tokens[0])
+
+        return tuple(sorted(result))
+
+    def split_to_word_tokens(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
+            # These languages don't typically use spaces, so it is difficult to split words
+            # without morpheme analysis. Here, we instead split words at any
+            # position where the tokens are decoded as valid unicode points
+            return self.split_tokens_on_unicode(tokens)
+
+        return self.split_tokens_on_spaces(tokens)
+
+    def split_tokens_on_unicode(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
+        words = []
+        word_tokens = []
+        current_tokens = []
+        unicode_offset = 0
+
+        for token in tokens:
+            current_tokens.append(token)
+            decoded = self.decode_with_timestamps(current_tokens)
+
+            try:
+                replacement_char_index = decoded.index(replacement_char)
+                replacement_char_index += unicode_offset
+            except ValueError:
+                replacement_char_index = None
+
+            if replacement_char_index is None or (
+                replacement_char_index < len(decoded_full)
+                and decoded_full[replacement_char_index] == replacement_char
+            ):
+                words.append(decoded)
+                word_tokens.append(current_tokens)
+                current_tokens = []
+                unicode_offset += len(decoded)
+
+        return words, word_tokens
+
+    def split_tokens_on_spaces(
+        self, tokens: List[int]
+    ) -> Tuple[List[str], List[List[int]]]:
+        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
+        words = []
+        word_tokens = []
+
+        for subword, subword_tokens in zip(subwords, subword_tokens_list):
+            special = subword_tokens[0] >= self.eot
+            with_space = subword.startswith(" ")
+            punctuation = subword.strip() in string.punctuation
+            should_append = special or with_space or punctuation or len(words) == 0
+
+            if should_append:
+                words.append(subword)
+                word_tokens.append(subword_tokens)
+            else:
+                words[-1] = words[-1] + subword
+                word_tokens[-1].extend(subword_tokens)
+
+        return words, word_tokens
+
+
+_TASKS = (
+    "transcribe",
+    "translate",
+)
+
+_LANGUAGE_CODES = (
+    "af",
+    "am",
+    "ar",
+    "as",
+    "az",
+    "ba",
+    "be",
+    "bg",
+    "bn",
+    "bo",
+    "br",
+    "bs",
+    "ca",
+    "cs",
+    "cy",
+    "da",
+    "de",
+    "el",
+    "en",
+    "es",
+    "et",
+    "eu",
+    "fa",
+    "fi",
+    "fo",
+    "fr",
+    "gl",
+    "gu",
+    "ha",
+    "haw",
+    "he",
+    "hi",
+    "hr",
+    "ht",
+    "hu",
+    "hy",
+    "id",
+    "is",
+    "it",
+    "ja",
+    "jw",
+    "ka",
+    "kk",
+    "km",
+    "kn",
+    "ko",
+    "la",
+    "lb",
+    "ln",
+    "lo",
+    "lt",
+    "lv",
+    "mg",
+    "mi",
+    "mk",
+    "ml",
+    "mn",
+    "mr",
+    "ms",
+    "mt",
+    "my",
+    "ne",
+    "nl",
+    "nn",
+    "no",
+    "oc",
+    "pa",
+    "pl",
+    "ps",
+    "pt",
+    "ro",
+    "ru",
+    "sa",
+    "sd",
+    "si",
+    "sk",
+    "sl",
+    "sn",
+    "so",
+    "sq",
+    "sr",
+    "su",
+    "sv",
+    "sw",
+    "ta",
+    "te",
+    "tg",
+    "th",
+    "tk",
+    "tl",
+    "tr",
+    "tt",
+    "uk",
+    "ur",
+    "uz",
+    "vi",
+    "yi",
+    "yo",
+    "zh",
+    "yue",
+)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/utils.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d4464195e2997a3dcf4346d51e68405e9be10cd
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/utils.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass, field
+from typing import List
+
+@dataclass
+class CompileInfo:
+    prefix_name = ["mindie_whisper_encoder_bs",
+                   "mindie_decoder_prefill_bs",
+                   "mindie_whisper_decoder_bs",
+                   "mindie_self_scatter_bs",
+                   "mindie_encoder_scatter_bs"]
+    mel_feature_size = 128
+    max_frames = 3000
+    max_decode_step = 448
+    head_num = 20
+    head_size = 64
+    encoder_seq_len = 1500
+    hidden_size = 1280
+    layer_nums = 32
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/vad.py b/MindIE/MindIE-Torch/built-in/audio/WhisperX/vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..35fd026ada7673edbdde5415c03c9d3eeb6253c8
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/vad.py
@@ -0,0 +1,287 @@
+import hashlib
+import os
+import urllib
+from typing import Callable, Optional, Text, Union
+
+import numpy as np
+import pandas as pd
+import torch
+from pyannote.audio import Model
+from pyannote.audio.core.io import AudioFile
+from pyannote.audio.pipelines import VoiceActivityDetection
+from pyannote.audio.pipelines.utils import PipelineModel
+from pyannote.core import Annotation, Segment, SlidingWindowFeature
+from tqdm import tqdm
+
+from diarize import Segment as SegmentX
+
+
+def load_vad_model(model_dir, device, ts_model_path, vad_onset=0.500, vad_offset=0.363):
+    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
+
+    vad_model = Model.from_pretrained(model_fp, use_auth_token=None)
+    hyperparameters = {"onset": vad_onset,
+                       "offset": vad_offset,
+                       "min_duration_on": 0.1,
+                       "min_duration_off": 0.1}
+    vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device), ts_model_path=ts_model_path)
+    vad_pipeline.instantiate(hyperparameters)
+
+    return vad_pipeline
+
+
+class Binarize:
+    """Binarize detection scores using hysteresis thresholding, with a min-cut operation
+    to ensure no segments are longer than max_duration.
+
+    Parameters
+    ----------
+    onset : float, optional
+        Onset threshold. Defaults to 0.5.
+    offset : float, optional
+        Offset threshold. Defaults to `onset`.
+    min_duration_on : float, optional
+        Remove active regions shorter than that many seconds. Defaults to 0s.
+    min_duration_off : float, optional
+        Fill inactive regions shorter than that many seconds. Defaults to 0s.
+    pad_onset : float, optional
+        Extend active regions by moving their start time by that many seconds.
+        Defaults to 0s.
+    pad_offset : float, optional
+        Extend active regions by moving their end time by that many seconds.
+        Defaults to 0s.
+    max_duration: float
+        The maximum length of an active segment; segments longer than this are divided
+        at the timestamp with the lowest score.
+    Reference
+    ---------
+    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
+    RNN-based Voice Activity Detection", InterSpeech 2015.
+
+    Modified by Max Bain to include WhisperX's min-cut operation
+    https://arxiv.org/abs/2303.00747
+
+    Pyannote-audio
+    """
+
+    def __init__(
+        self,
+        onset: float = 0.5,
+        offset: Optional[float] = None,
+        min_duration_on: float = 0.0,
+        min_duration_off: float = 0.0,
+        pad_onset: float = 0.0,
+        pad_offset: float = 0.0,
+        max_duration: float = float('inf')
+    ):
+
+        super().__init__()
+
+        self.onset = onset
+        self.offset = offset or onset
+
+        self.pad_onset = pad_onset
+        self.pad_offset = pad_offset
+
+        self.min_duration_on = min_duration_on
+        self.min_duration_off = min_duration_off
+
+        self.max_duration = max_duration
+
+    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
+        """Binarize detection scores
+        Parameters
+        ----------
+        scores : SlidingWindowFeature
+            Detection scores.
+        Returns
+        -------
+        active : Annotation
+            Binarized scores.
+        """
+
+        num_frames, num_classes = scores.data.shape
+        frames = scores.sliding_window
+        timestamps = [frames[i].middle for i in range(num_frames)]
+
+        # annotation meant to store 'active' regions
+        active = Annotation()
+        for k, k_scores in enumerate(scores.data.T):
+
+            label = k if scores.labels is None else scores.labels[k]
+
+            # initial state
+            start = timestamps[0]
+            is_active = k_scores[0] > self.onset
+            curr_scores = [k_scores[0]]
+            curr_timestamps = [start]
+            t = start
+            for t, y in zip(timestamps[1:], k_scores[1:]):
+                # currently active
+                if is_active:
+                    curr_duration = t - start
+                    if curr_duration > self.max_duration:
+                        search_after = len(curr_scores) // 2
+                        # divide segment
+                        min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
+                        min_score_t = curr_timestamps[min_score_div_idx]
+                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
+                        active[region, k] = label
+                        start = curr_timestamps[min_score_div_idx]
+                        curr_scores = curr_scores[min_score_div_idx + 1:]
+                        curr_timestamps = curr_timestamps[min_score_div_idx + 1:]
+                    # switching from active to inactive
+                    elif y < self.offset:
+                        region = Segment(start - self.pad_onset, t + self.pad_offset)
+                        active[region, k] = label
+                        start = t
+                        is_active = False
+                        curr_scores = []
+                        curr_timestamps = []
+                    curr_scores.append(y)
+                    curr_timestamps.append(t)
+                # currently inactive
+                else:
+                    # switching from inactive to active
+                    if y > self.onset:
+                        start = t
+                        is_active = True
+
+            # if active at the end, add final region
+            if is_active:
+                region = Segment(start - self.pad_onset, t + self.pad_offset)
+                active[region, k] = label
+
+        # because of padding, some active regions might be overlapping: merge them.
+        # also: fill same speaker gaps shorter than min_duration_off
+        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
+            if self.max_duration < float("inf"):
+                raise NotImplementedError("This would break current max_duration param")
+            active = active.support(collar=self.min_duration_off)
+
+        # remove tracks shorter than min_duration_on
+        if self.min_duration_on > 0:
+            for segment, track in list(active.itertracks()):
+                if segment.duration < self.min_duration_on:
+                    del active[segment, track]
+
+        return active
+
+
+class VoiceActivitySegmentation(VoiceActivityDetection):
+    def __init__(
+        self,
+        segmentation: PipelineModel = "pyannote/segmentation",
+        fscore: bool = False,
+        use_auth_token: Union[Text, None] = None,
+        ts_model_path: str = None,
+        **inference_kwargs,
+    ):
+
+        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, ts_model_path=ts_model_path, **inference_kwargs)
+
+    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
+        """Apply voice activity detection
+
+        Parameters
+        ----------
+        file : AudioFile
+            Processed file.
+        hook : callable, optional
+            Hook called after each major step of the pipeline with the following
+            signature: hook("step_name", step_artefact, file=file)
+
+        Returns
+        -------
+        speech : Annotation
+            Speech regions.
+        """
+
+        # setup hook (e.g. for debugging purposes)
+        hook = self.setup_hook(file, hook=hook)
+
+        # apply segmentation model (only if needed)
+        # output shape is (num_chunks, num_frames, 1)
+        if self.training:
+            if self.CACHED_SEGMENTATION in file:
+                segmentations = file[self.CACHED_SEGMENTATION]
+            else:
+                segmentations = self._segmentation(file)
+                file[self.CACHED_SEGMENTATION] = segmentations
+        else:
+            segmentations: SlidingWindowFeature = self._segmentation(file)
+
+        return segmentations
+
+
+def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
+
+    active = Annotation()
+    for k, vad_t in enumerate(vad_arr):
+        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
+        active[region, k] = 1
+
+
+    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
+        active = active.support(collar=min_duration_off)
+
+    # remove tracks shorter than min_duration_on
+    if min_duration_on > 0:
+        for segment, track in list(active.itertracks()):
+            if segment.duration < min_duration_on:
+                del active[segment, track]
+
+    active = active.for_json()
+    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
+    return active_segs
+
+
+def merge_chunks(
+    segments,
+    chunk_size,
+    onset: float = 0.5,
+    offset: Optional[float] = None,
+):
+    """
+    Merge operation described in the WhisperX paper
+    """
+    curr_end = 0
+    merged_segments = []
+    seg_idxs = []
+    speaker_idxs = []
+
+    if chunk_size <= 0:
+        raise ValueError("chunk_size must be greater than 0")
+
+    binarize = Binarize(max_duration=chunk_size, onset=onset, offset=offset)
+    segments = binarize(segments)
+    segments_list = []
+    for speech_turn in segments.get_timeline():
+        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
+
+    if len(segments_list) == 0:
+        print("No active speech found in audio")
+        return []
+    # assert segments_list, "segments_list is empty."
+    # Make sure the starting point is the start of the segment.
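+    # Greedily pack consecutive speech turns into chunks of at most chunk_size seconds,
+    # starting a new chunk whenever the next turn would overflow the current one.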
+    curr_start = segments_list[0].start
+
+    for seg in segments_list:
+        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
+            merged_segments.append({
+                "start": curr_start,
+                "end": curr_end,
+                "segments": seg_idxs,
+            })
+            curr_start = seg.start
+            seg_idxs = []
+            speaker_idxs = []
+        curr_end = seg.end
+        seg_idxs.append((seg.start, seg.end))
+        speaker_idxs.append(seg.speaker)
+    # add final
+    merged_segments.append({
+        "start": curr_start,
+        "end": curr_end,
+        "segments": seg_idxs,
+    })
+    return merged_segments
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/WhisperX/voice_activity_detection.patch b/MindIE/MindIE-Torch/built-in/audio/WhisperX/voice_activity_detection.patch
new file mode 100644
index 0000000000000000000000000000000000000000..2fc635269b97b5503c43c9fe65b4de30f2a5b38b
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/WhisperX/voice_activity_detection.patch
@@ -0,0 +1,20 @@
+--- voice_activity_detection.py 2024-11-11 12:06:17.300126910 +0800
++++ whisperX_update/voice_activity_detection.py 2024-11-13 15:31:35.374394983 +0800
+@@ -112,6 +112,7 @@
+         segmentation: PipelineModel = "pyannote/segmentation",
+         fscore: bool = False,
+         use_auth_token: Union[Text, None] = None,
++        ts_model_path: str = None,
+         **inference_kwargs,
+     ):
+         super().__init__()
+@@ -125,6 +126,9 @@
+         inference_kwargs["pre_aggregation_hook"] = lambda scores: np.max(
+             scores, axis=-1, keepdims=True
+         )
++        if ts_model_path is not None:
++            inference_kwargs["ts_model_path"] = ts_model_path
++
+         self._segmentation = Inference(model, **inference_kwargs)
+
+         if model.specifications.powerset: