From 2e553218926fd8790346de0a5f1cf81de29f1e3a Mon Sep 17 00:00:00 2001
From: brjiang
Date: Wed, 14 Aug 2024 16:00:57 +0800
Subject: [PATCH 01/16] Add Paraformer model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Paraformer/README.md       | 133 ++++++++
 .../built-in/audio/Paraformer/compile.py      |  99 ++++++
 .../built-in/audio/Paraformer/mindie.patch    | 288 ++++++++++++++++++
 .../audio/Paraformer/mindie_auto_model.py     | 278 +++++++++++++++++
 .../built-in/audio/Paraformer/mindie_cif.py   | 173 +++++++++++
 .../Paraformer/mindie_encoder_decoder.py      | 121 ++++++++
 .../audio/Paraformer/mindie_paraformer.py     | 251 +++++++++++++++
 .../built-in/audio/Paraformer/mindie_punc.py  | 199 ++++++++++++
 8 files changed, 1542 insertions(+)
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie.patch
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
new file mode 100644
index 0000000000..2bacbe5e0a
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -0,0 +1,133 @@
+# Paraformer Model Inference Guide
+
+- [Overview](#overview)
+- [Inference Environment](#inference-environment)
+- [Quick Start](#quick-start)
+  - [Getting the Source Code](#getting-the-source-code)
+  - [Model Inference](#model-inference)
+
+# Overview
+
+This project deploys the Paraformer model with mindietorch.
+
+- Model path:
+  ```bash
+  https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+  ```
+
+- Reference implementation:
+  ```bash
+  https://github.com/modelscope/FunASR
+  ```
+
+# Inference Environment
+
+- The model requires the following plugins and drivers.
+
+  **Table 1** Version compatibility
+
+  | Dependency | Version | Setup Guide |
+  | ------ | ------- | ------------ |
+  | Python | 3.10.13 | - |
+  | torch | 2.1.0+cpu | - |
+  | torchaudio | 2.1.0+cpu | - |
+  | CANN | 8.0.RC2 | - |
+  | MindIE | 1.0.RC2.B091 | - |
+
+# Quick Start
+## Getting the Source Code
+
+1. Install the MindIE package
+
+   ```bash
+   # install mindie
+   chmod +x ./Ascend-mindie_xxx.run
+   ./Ascend-mindie_xxx.run --install
+   source /usr/local/Ascend/mindie/set_env.sh
+   ```
+
+2. Install torch_npu
+
+   Install a matching torch_npu build as described in the [Ascend documentation](https://www.hiascend.com/document/detail/zh/mindie/10RC2/mindietorch/Torchdev/mindie_torch0017.html), and compile libtorch_npu_bridge.so manually.
+
+3. Get the FunASR source code
+
+   ```
+   git clone https://github.com/modelscope/FunASR.git
+   cd ./FunASR
+   git reset fdac68e1d09645c48adf540d6091b194bac71075 --hard
+   cd ..
+   ```
+
+4. Patch the FunASR source: first move the patch file into the FunASR project directory, then apply it to the code (if the patch fails to apply, make the changes manually)
+   ```
+   mv mindie.patch ./FunASR
+   cd ./FunASR
+   git apply mindie.patch --reject
+   cd ..
+   ```
+
+5. Download the model files
+
+   Download the [Paraformer](https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files) model files and save them under ./model
+
+   Download the [vad](https://modelscope.cn/models/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch/files) model files and save them under ./model_vad
+
+   Download the [punc](https://modelscope.cn/models/iic/punc_ct-transformer_cn-en-common-vocab471067-large/files) model files and save them under ./model_punc
+
+   The resulting directory layout is
+
+    ```
+    Paraformer
+    ├── FunASR
+    └── model
+        └── model.pt
+        └── config.yaml
+        └── ...
+    └── model_vad
+        └── model.pt
+        └── config.yaml
+        └── ...
+    └── model_punc
+        └── model.pt
+        └── config.yaml
+        └── ...
+    ```
+
+6. Install the FunASR dependencies
+   ```
+   sudo apt install ffmpeg
+   pip install jieba
+   ```
+
+## Model Inference
+
+1. Model compilation and sample test
+   Run the following command to compile the models
+   ```bash
+   python compile.py \
+       --model ./model \
+       --model_vad ./model_vad \
+       --model_punc ./model_punc \
+       --compiled_encoder ./compiled_model/compiled_encoder.ts \
+       --compiled_decoder ./compiled_model/compiled_decoder.ts \
+       --compiled_cif ./compiled_model/compiled_cif.ts \
+       --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.ts \
+       --compiled_punc ./compiled_model/compiled_punc.ts \
+       --batch_size 16 \
+       --sample_path ./model/example/asr_example.wav \
+       --skip_compile
+   ```
+
+   Parameter description:
+   - --model: path of the pretrained model
+   - --model_vad: path of the pretrained VAD model; set it to None to run without the VAD model
+   - --model_punc: path of the pretrained PUNC model; set it to None to run without the PUNC model
+   - --compiled_encoder: path of the compiled encoder
+   - --compiled_decoder: path of the compiled decoder
+   - --compiled_cif: path of the compiled cif function
+   - --compiled_cif_timestamp: path of the compiled cif_timestamp function
+   - --compiled_punc: path of the compiled punc model
+   - --batch_size: batch size used by the Paraformer model
+   - --sample_path: path of a sample audio file used to test the model
+   - --skip_compile: skip model compilation; once compilation has been done this flag can be passed to reuse the compiled models, but it must not be passed on first use
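+
+   For reference, a first run (which performs compilation) simply omits `--skip_compile`; a minimal sketch, relying on the default compiled_* paths from compile.py:
+
+   ```bash
+   python compile.py \
+       --model ./model \
+       --model_vad ./model_vad \
+       --model_punc ./model_punc \
+       --batch_size 16 \
+       --sample_path ./model/example/asr_example.wav
+   ```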
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
new file mode 100644
index 0000000000..5629dada48
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
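+
+# compile.py: unless --skip_compile is given, this script first compiles the
+# Paraformer sub-models (punc model, encoder, decoder, and the cif /
+# cif_timestamp functions) with mindietorch, then builds a MindieAutoModel
+# and benchmarks it on a sample audio clip, printing per-stage and average
+# latencies.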
+
+import argparse
+
+import torch
+import torch_npu
+import mindietorch
+
+from mindie_auto_model import MindieAutoModel
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--skip_compile", action="store_true",
+                        help="whether to skip compiling sub-models in Paraformer")
+    parser.add_argument("--model", default="./model",
+                        help="path of pretrained model")
+    parser.add_argument("--model_vad", default="./model_vad",
+                        help="path of pretrained vad model")
+    parser.add_argument("--model_punc", default="./model_punc",
+                        help="path of pretrained punc model")
+    parser.add_argument("--compiled_encoder", default="./compiled_model/compiled_encoder.ts",
+                        help="path to save compiled encoder")
+    parser.add_argument("--compiled_decoder", default="./compiled_model/compiled_decoder.ts",
+                        help="path to save compiled decoder")
+    parser.add_argument("--compiled_cif", default="./compiled_model/compiled_cif.ts",
+                        help="path to save compiled cif function")
+    parser.add_argument("--compiled_cif_timestamp", default="./compiled_model/compiled_cif_timestamp.ts",
+                        help="path to save compiled cif timestamp function")
+    parser.add_argument("--compiled_punc", default="./compiled_model/compiled_punc.ts",
+                        help="path to save compiled punc model")
+    parser.add_argument("--batch_size", default=16, type=int,
+                        help="batch size of paraformer model")
+    parser.add_argument("--sample_path", default="./audio/test_1.wav",
+                        help="path of sample audio")
+    args = parser.parse_args()
+
+    mindietorch.set_device(0)
+
+    # use mindietorch to compile sub-models in Paraformer
+    if not args.skip_compile:
+        print("Begin compiling sub-models")
+        MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc, compile_type="punc")
+        MindieAutoModel.export_model(model=args.model, compiled_encoder=args.compiled_encoder,
+                                     compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif,
+                                     compiled_cif_timestamp=args.compiled_cif_timestamp,
+                                     cif_interval=200, cif_timestamp_interval=500, compile_type="paraformer")
+        print("Finish compiling sub-models")
+    else:
+        print("Use existing compiled model")
+
+    # initialize auto model
+    # note: ncpu is the number of threads used for intra-op parallelism on the CPU,
+    # which affects the speed of the vad model
+    model = MindieAutoModel(model=args.model, vad_model=args.model_vad, punc_model=args.model_punc,
+                            compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder,
+                            compiled_cif=args.compiled_cif, compiled_cif_timestamp=args.compiled_cif_timestamp,
+                            compiled_punc=args.compiled_punc, paraformer_batch_size=args.batch_size,
+                            cif_interval=200, cif_timestamp_interval=500, ncpu=16)
+
+    # warm up
+    print("Begin warming up")
+    for i in range(5):
+        _ = model.generate(input=args.sample_path)
+
+    # run with sample audio
+    iteration_num = 100
+    print("Begin running with sample audio for {} iterations".format(iteration_num))
+
+    total_dict_time = {}
+    for i in range(iteration_num):
+        res, time_stats = model.generate(input=args.sample_path)
+
+        print("Iteration {} Model output: {}".format(i, res[0]["text"]))
+        print("Iteration {} Time consumption:".format(i))
+        print("    ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items()))
+        for key, value in time_stats.items():
+            if key not in total_dict_time:
+                total_dict_time[key] = float(value)
+            else:
+                total_dict_time[key] += float(value)
+
+    # display average time consumption
+    print("\nAverage time consumption")
+    for key, value in total_dict_time.items():
+        print(key, 
": {:.3f}s".format(float(value) / iteration_num)) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie.patch b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie.patch new file mode 100644 index 0000000000..23d3207807 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie.patch @@ -0,0 +1,288 @@ +diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py +index 01e6aaf6..2081ebb1 100644 +--- a/funasr/auto/auto_model.py ++++ b/funasr/auto/auto_model.py +@@ -277,9 +277,9 @@ class AutoModel: + asr_result_list = [] + num_samples = len(data_list) + disable_pbar = self.kwargs.get("disable_pbar", False) +- pbar = ( +- tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None +- ) ++ # pbar = ( ++ # tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None ++ # ) + time_speech_total = 0.0 + time_escape_total = 0.0 + for beg_idx in range(0, num_samples, batch_size): +@@ -311,20 +311,22 @@ class AutoModel: + speed_stats["batch_size"] = f"{len(results)}" + speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}" + description = f"{speed_stats}, " +- if pbar: +- pbar.update(1) +- pbar.set_description(description) ++ # if pbar: ++ # pbar.update(1) ++ # pbar.set_description(description) + time_speech_total += batch_data_time + time_escape_total += time_escape + +- if pbar: +- # pbar.update(1) +- pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") ++ # if pbar: ++ # # pbar.update(1) ++ # pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") + torch.cuda.empty_cache() + return asr_result_list + + def inference_with_vad(self, input, input_len=None, **cfg): + kwargs = self.kwargs ++ time_stats = {"input_speech_time": 0.0, "end_to_end_time": 0.0, "vad_time" : 0.0, ++ "paraformer_time": 0.0, "punc_time": 0.0} + # step.1: compute the vad model + deep_update(self.vad_kwargs, cfg) + beg_vad = time.time() +@@ -332,6 +334,7 @@ class AutoModel: + input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg + ) + end_vad = time.time() ++ time_stats["vad_time"] = end_vad - beg_vad + + # FIX(gcf): concat the vad clips for sense vocie model for better aed + if kwargs.get("merge_vad", False): +@@ -366,7 +369,7 @@ class AutoModel: + speech_lengths = len(speech) + n = len(vadsegments) + data_with_index = [(vadsegments[i], i) for i in range(n)] +- sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) ++ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0], reverse=True) + results_sorted = [] + + if not len(sorted_data): +@@ -377,7 +380,7 @@ class AutoModel: + if len(sorted_data) > 0 and len(sorted_data[0]) > 0: + batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]) + +- beg_idx = 0 ++ batch_size = 0 + beg_asr_total = time.time() + time_speech_total_per_sample = speech_lengths / 16000 + time_speech_total_all_samples += time_speech_total_per_sample +@@ -385,28 +388,19 @@ class AutoModel: + # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True) + + all_segments = [] +- max_len_in_batch = 0 +- end_idx = 1 +- for j, _ in enumerate(range(0, n)): +- # pbar_sample.update(1) +- sample_length = sorted_data[j][0][1] - sorted_data[j][0][0] +- potential_batch_length = max(max_len_in_batch, sample_length) * (j + 1 - beg_idx) +- # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0] +- if ( +- j < n - 1 +- and sample_length < batch_size_threshold_ms +- and 
potential_batch_length < batch_size +- ): +- max_len_in_batch = max(max_len_in_batch, sample_length) +- end_idx += 1 +- continue +- ++ batch_segments = kwargs["paraformer_batch_size"] ++ print("Num of vadsegments: {}, Paraformer batch_size: {}".format(n, batch_segments)) ++ loop_num = n // batch_segments if n % batch_segments == 0 else n // batch_segments + 1 ++ beg_idx = 0 ++ end_idx = batch_segments ++ for j in range(loop_num): + speech_j, speech_lengths_j = slice_padding_audio_samples( + speech, speech_lengths, sorted_data[beg_idx:end_idx] + ) +- results = self.inference( ++ results_batch = self.inference_with_asr( + speech_j, input_len=None, model=model, kwargs=kwargs, **cfg + ) ++ results = results_batch[0] + if self.spk_model is not None: + # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] + for _b in range(len(speech_j)): +@@ -425,8 +419,7 @@ class AutoModel: + ) + results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"] + beg_idx = end_idx +- end_idx += 1 +- max_len_in_batch = sample_length ++ end_idx += batch_segments + if len(results) < 1: + continue + results_sorted.extend(results) +@@ -478,6 +471,10 @@ class AutoModel: + if not len(result["text"].strip()): + continue + return_raw_text = kwargs.get("return_raw_text", False) ++ ++ end_paraformer = time.time() ++ time_stats["paraformer_time"] = end_paraformer - beg_asr_total ++ + # step.3 compute punc model + raw_text = None + if self.punc_model is not None: +@@ -489,6 +486,8 @@ class AutoModel: + if return_raw_text: + result["raw_text"] = raw_text + result["text"] = punc_res[0]["text"] ++ end_punc = time.time() ++ time_stats["punc_time"] = end_punc - end_paraformer + + # speaker embedding cluster after resorted + if self.spk_model is not None and kwargs.get("return_spk_res", True): +@@ -575,12 +574,13 @@ class AutoModel: + f"time_escape: {time_escape_total_per_sample:0.3f}" + ) + +- # end_total = time.time() +- # time_escape_total_all_samples = end_total - beg_total ++ end_total = time.time() ++ time_stats["end_to_end_time"] = end_total - beg_vad ++ time_stats["input_speech_time"] = time_speech_total_all_samples + # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, " + # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, " + # f"time_escape_all: {time_escape_total_all_samples:0.3f}") +- return results_ret_list ++ return results_ret_list, time_stats + + def export(self, input=None, **cfg): + """ +diff --git a/funasr/models/bicif_paraformer/cif_predictor.py b/funasr/models/bicif_paraformer/cif_predictor.py +index ca98cdc2..4d61ab85 100644 +--- a/funasr/models/bicif_paraformer/cif_predictor.py ++++ b/funasr/models/bicif_paraformer/cif_predictor.py +@@ -238,6 +238,7 @@ class CifPredictorV3(torch.nn.Module): + elif self.tail_threshold > 0.0: + hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask) + ++ return hidden, alphas, token_num + acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) + if target_length is None and self.tail_threshold > 0.0: + token_num_int = torch.max(token_num).type(torch.int32).item() +@@ -282,6 +283,7 @@ class CifPredictorV3(torch.nn.Module): + _token_num = alphas2.sum(-1) + if token_num is not None: + alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1)) ++ return alphas2 + # re-downsample + ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1) + ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4) +diff --git a/funasr/models/sanm/attention.py 
b/funasr/models/sanm/attention.py +index c7e8a8e0..0f1862ae 100644 +--- a/funasr/models/sanm/attention.py ++++ b/funasr/models/sanm/attention.py +@@ -275,13 +275,15 @@ class MultiHeadedAttentionSANM(nn.Module): + "inf" + ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) +- self.attn = torch.softmax(scores, dim=-1).masked_fill( +- mask, 0.0 +- ) # (batch, head, time1, time2) ++ # self.attn = torch.softmax(scores, dim=-1).masked_fill( ++ # mask, 0.0 ++ # ) # (batch, head, time1, time2) ++ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + else: +- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) ++ # self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) ++ attn = torch.softmax(scores, dim=-1) + +- p_attn = self.dropout(self.attn) ++ p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) +@@ -683,18 +685,22 @@ class MultiHeadedAttentionCrossAtt(nn.Module): + # logging.info( + # "scores: {}, mask_size: {}".format(scores.size(), mask.size())) + scores = scores.masked_fill(mask, min_value) +- self.attn = torch.softmax(scores, dim=-1).masked_fill( +- mask, 0.0 +- ) # (batch, head, time1, time2) ++ # self.attn = torch.softmax(scores, dim=-1).masked_fill( ++ # mask, 0.0 ++ # ) # (batch, head, time1, time2) ++ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) +- p_attn = self.dropout(self.attn) ++ attn = torch.softmax(scores, dim=-1) ++ # p_attn = self.dropout(self.attn) ++ p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + if ret_attn: +- return self.linear_out(x), self.attn # (batch, time1, d_model) ++ # return self.linear_out(x), self.attn # (batch, time1, d_model) ++ return self.linear_out(x), attn + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, x, memory, memory_mask, ret_attn=False): +diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py +index b2a442bd..c8044cee 100644 +--- a/funasr/models/sanm/encoder.py ++++ b/funasr/models/sanm/encoder.py +@@ -15,7 +15,7 @@ import torch.nn.functional as F + + import numpy as np + from funasr.train_utils.device_funcs import to_device +-from funasr.models.transformer.utils.nets_utils import make_pad_mask ++from funasr.models.transformer.utils.nets_utils import make_pad_mask, make_pad_mask_new + from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM + from funasr.models.transformer.embedding import ( + SinusoidalPositionEncoder, +@@ -374,7 +374,8 @@ class SANMEncoder(nn.Module): + Returns: + position embedded tensor and mask + """ +- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) ++ # masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) ++ masks = (~make_pad_mask_new(ilens)[:, None, :]).to(xs_pad.device) + xs_pad = xs_pad * self.output_size() ** 0.5 + if self.embed is None: + xs_pad = xs_pad +diff --git a/funasr/models/transformer/utils/nets_utils.py b/funasr/models/transformer/utils/nets_utils.py +index 29d23ee5..19693c9e 100644 +--- a/funasr/models/transformer/utils/nets_utils.py ++++ b/funasr/models/transformer/utils/nets_utils.py +@@ -218,6 +218,15 @@ def make_pad_mask(lengths, 
xs=None, length_dim=-1, maxlen=None): + return mask + + ++def make_pad_mask_new(lengths): ++ maxlen = lengths.max() ++ row_vector = torch.arange(0, maxlen, 1).to(lengths.device) ++ matrix = torch.unsqueeze(lengths, dim=-1) ++ mask = row_vector >= matrix ++ mask = mask.detach() ++ return mask ++ ++ + def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + +diff --git a/funasr/models/transformer/utils/repeat.py b/funasr/models/transformer/utils/repeat.py +index a44c1a01..0935d854 100644 +--- a/funasr/models/transformer/utils/repeat.py ++++ b/funasr/models/transformer/utils/repeat.py +@@ -28,8 +28,9 @@ class MultiSequential(torch.nn.Sequential): + """Repeat.""" + _probs = torch.empty(len(self)).uniform_() + for idx, m in enumerate(self): +- if not self.training or (_probs[idx] >= self.layer_drop_rate): +- args = m(*args) ++ # if not self.training or (_probs[idx] >= self.layer_drop_rate): ++ # args = m(*args) ++ args = m(*args) + return args + + diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py new file mode 100644 index 0000000000..8251f7533f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py @@ -0,0 +1,278 @@ +import sys +sys.path.append("./FunASR") + +import torch +import time +import logging + +from mindie_paraformer import MindieBiCifParaformer +from mindie_encoder_decoder import MindieEncoder, MindieDecoder +from mindie_punc import MindiePunc, MindieCTTransformer +from mindie_cif import MindieCifTimestamp, MindieCif +from funasr.auto.auto_model import AutoModel, download_model, tables, deep_update, \ + load_pretrained_model, prepare_data_iterator + + +class MindieAutoModel(AutoModel): + def __init__(self, **kwargs): + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + logging.basicConfig(level=log_level) + + if not kwargs.get("disable_log", True): + tables.print() + + kwargs["compile_type"] = "paraformer" + model, kwargs = self.build_model_with_mindie(**kwargs) + + # if vad_model is not None, build vad model else None + vad_model = kwargs.get("vad_model", None) + vad_kwargs = {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {}) + if vad_model is not None: + logging.info("Building VAD model.") + vad_kwargs["model"] = vad_model + vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master") + vad_kwargs["device"] = "cpu" + vad_model, vad_kwargs = self.build_model(**vad_kwargs) + + # if punc_model is not None, build punc model else None + punc_model = kwargs.get("punc_model", None) + punc_kwargs = {} if kwargs.get("punc_kwargs", {}) is None else kwargs.get("punc_kwargs", {}) + if punc_model is not None: + logging.info("Building punc model.") + punc_kwargs["model"] = punc_model + punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master") + punc_kwargs["device"] = "npu" + punc_kwargs["compile_type"] = "punc" + punc_kwargs["compiled_punc"] = kwargs["compiled_punc"] + punc_model, punc_kwargs = self.build_model_with_mindie(**punc_kwargs) + + self.kwargs = kwargs + self.model = model + self.vad_model = vad_model + self.vad_kwargs = vad_kwargs + self.punc_model = punc_model + self.punc_kwargs = punc_kwargs + self.spk_model = None + self.spk_kwargs = {} + self.model_path = kwargs.get("model_path") + + def generate(self, input, input_len=None, **cfg): + if self.vad_model is None: + return self.inference_with_asr(input, 
input_len=input_len, **cfg) + + else: + return self.inference_with_vad(input, input_len=input_len, **cfg) + + @staticmethod + def export_model(**kwargs): + # load model config + assert "model" in kwargs + if "model_conf" not in kwargs: + print("download models from model hub: {}".format(kwargs.get("hub", "ms"))) + kwargs = download_model(**kwargs) + + kwargs["batch_size"] = 1 + kwargs["device"] = "cpu" + + # build tokenizer + tokenizer = kwargs.get("tokenizer", None) + if tokenizer is not None: + tokenizer_class = tables.tokenizer_classes.get(tokenizer) + tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {})) + kwargs["token_list"] = ( + tokenizer.token_list if hasattr(tokenizer, "token_list") else None + ) + kwargs["token_list"] = ( + tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"] + ) + vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() + else: + vocab_size = -1 + kwargs["tokenizer"] = tokenizer + + # build frontend + frontend = kwargs.get("frontend", None) + kwargs["input_size"] = None + if frontend is not None: + frontend_class = tables.frontend_classes.get(frontend) + frontend = frontend_class(**kwargs.get("frontend_conf", {})) + kwargs["input_size"] = ( + frontend.output_size() if hasattr(frontend, "output_size") else None + ) + kwargs["frontend"] = frontend + + model_conf = {} + deep_update(model_conf, kwargs.get("model_conf", {})) + deep_update(model_conf, kwargs) + init_param = kwargs.get("init_param", None) + + if kwargs["compile_type"] == "punc": + model = MindiePunc(**model_conf, vocab_size=vocab_size) + model.eval() + print(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=model, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + MindiePunc.export_ts(model, kwargs["compiled_path"]) + else: + # compile encoder + encoder = MindieEncoder(**model_conf, vocab_size=vocab_size) + encoder.eval() + print(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=encoder, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + MindieEncoder.export_ts(encoder, kwargs["compiled_encoder"]) + + # compile decoder + decoder = MindieDecoder(**model_conf, vocab_size=vocab_size) + decoder.eval() + print(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=decoder, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + MindieDecoder.export_ts(decoder, kwargs["compiled_decoder"]) + + # compile cif + mindie_cif = MindieCif(decoder.predictor.threshold, kwargs["cif_interval"]) + mindie_cif.export_ts(kwargs["compiled_cif"]) + + # compile cif_timestamp + mindie_cif_timestamp = MindieCifTimestamp(decoder.predictor.threshold - 1e-4, kwargs["cif_timestamp_interval"]) + mindie_cif_timestamp.export_ts(kwargs["compiled_cif_timestamp"]) + + def build_model_with_mindie(self, **kwargs): + assert "model" in kwargs + if "model_conf" not in kwargs: + logging.info("download 
models from model hub: {}".format(kwargs.get("hub", "ms"))) + kwargs = download_model(**kwargs) + + torch.set_num_threads(kwargs.get("ncpu", 4)) + + # build tokenizer + tokenizer = kwargs.get("tokenizer", None) + if tokenizer is not None: + tokenizer_class = tables.tokenizer_classes.get(tokenizer) + tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {})) + kwargs["token_list"] = ( + tokenizer.token_list if hasattr(tokenizer, "token_list") else None + ) + kwargs["token_list"] = ( + tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"] + ) + vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() + else: + vocab_size = -1 + kwargs["tokenizer"] = tokenizer + + # build frontend + frontend = kwargs.get("frontend", None) + kwargs["input_size"] = None + if frontend is not None: + frontend_class = tables.frontend_classes.get(frontend) + frontend = frontend_class(**kwargs.get("frontend_conf", {})) + kwargs["input_size"] = ( + frontend.output_size() if hasattr(frontend, "output_size") else None + ) + kwargs["frontend"] = frontend + + # build model + model_conf = {} + deep_update(model_conf, kwargs.get("model_conf", {})) + deep_update(model_conf, kwargs) + + if kwargs["compile_type"] == "punc": + model = MindieCTTransformer(**model_conf, vocab_size=vocab_size) + else: + model = MindieBiCifParaformer(**model_conf, vocab_size=vocab_size) + + # init_param + init_param = kwargs.get("init_param", None) + logging.info(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=model, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + + model.to("npu") + + return model, kwargs + + def inference_with_asr(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg): + kwargs = self.kwargs if kwargs is None else kwargs + deep_update(kwargs, cfg) + model = self.model if model is None else model + model.eval() + + batch_size = kwargs.get("batch_size", 1) + + key_list, data_list = prepare_data_iterator( + input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key + ) + + time_stats = {"input_speech_time": 0.0, "end_to_end_time": 0.0, "pure_infer_time": 0.0, + "load_data": 0.0, "encoder": 0.0, "predictor": 0.0, "decoder": 0.0, + "predictor_timestamp": 0.0, "post_process": 0.0} + asr_result_list = [] + num_samples = len(data_list) + + for beg_idx in range(0, num_samples, batch_size): + end_idx = min(num_samples, beg_idx + batch_size) + data_batch = data_list[beg_idx:end_idx] + key_batch = key_list[beg_idx:end_idx] + batch = {"data_in": data_batch, "key": key_batch} + + if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank + batch["data_in"] = data_batch[0] + batch["data_lengths"] = input_len + + with torch.no_grad(): + time1 = time.perf_counter() + res = model.inference_with_npu(**batch, **kwargs) + time2 = time.perf_counter() + if isinstance(res, (list, tuple)): + results = res[0] if len(res) > 0 else [{"text": ""}] + meta_data = res[1] if len(res) > 1 else {} + + asr_result_list.extend(results) + + # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item() + batch_data_time = meta_data.get("batch_data_time", -1) + time_escape = time2 - time1 + + time_stats["load_data"] += 
meta_data.get("load_data", 0.0) + time_stats["encoder"] += meta_data.get("encoder", 0.0) + time_stats["predictor"] += meta_data.get("calc_predictor", 0.0) + time_stats["decoder"] += meta_data.get("decoder", 0.0) + time_stats["predictor_timestamp"] += meta_data.get("calc_predictor_timestamp", 0.0) + time_stats["post_process"] += meta_data.get("post_process", 0.0) + time_stats["end_to_end_time"] += time_escape + + time_stats["input_speech_time"] += batch_data_time + + time_stats["pure_infer_time"] = time_stats["end_to_end_time"] - time_stats["load_data"] + + return asr_result_list, time_stats \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py new file mode 100644 index 0000000000..71b761187d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py @@ -0,0 +1,173 @@ +import sys +sys.path.append("./FunASR") + +import torch +import mindietorch + +from mindie_paraformer import precision_eval + + +def cif(hidden, alphas, integrate, frame, threshold): + batch_size, len_time, hidden_size = hidden.size() + + # intermediate vars along time + list_fires = [] + list_frames = [] + + constant = torch.ones([batch_size], device=hidden.device) + for t in range(len_time): + alpha = alphas[:, t] + distribution_completion = constant - integrate + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where( + fire_place, integrate - constant, integrate + ) + cur = torch.where(fire_place, distribution_completion, alpha) + remainds = alpha - cur + + frame += cur[:, None] * hidden[:, t, :] + list_frames.append(frame) + frame = torch.where( + fire_place[:, None].repeat(1, hidden_size), remainds[:, None] * hidden[:, t, :], frame + ) + + fires = torch.stack(list_fires, 1) + frames = torch.stack(list_frames, 1) + + return fires, frames, integrate, frame + + +def cif_wo_hidden(alphas, integrate, threshold): + batch_size, len_time = alphas.size() + + list_fires = [] + + constant = torch.ones([batch_size], device=alphas.device) * threshold + + for t in range(len_time): + alpha = alphas[:, t] + + integrate += alpha + list_fires.append(integrate) + + fire_place = integrate >= threshold + integrate = torch.where( + fire_place, + integrate - constant, + integrate, + ) + + fire_list = [] + for i in range(0, len(list_fires), 500): + batch = list_fires[i:i + 500] + fire = torch.stack(batch, 1) + fire_list.append(fire) + + fires = torch.cat(fire_list, 1) + return fires, integrate + + +class MindieCif(torch.nn.Module): + def __init__(self, threshold, seq_len): + super().__init__() + self.threshold = threshold + self.seq_len = seq_len + + def forward(self, hidden, alphas, integrate, frame): + fires, frames, integrate_new, frame_new = cif(hidden, alphas, integrate, frame, self.threshold) + return fires, frames, integrate_new, frame_new + + def export_ts(self, path="./compiled_cif.ts"): + print("Begin trace cif!") + + input_shape1 = (1, self.seq_len, 512) + input_shape2 = (1, self.seq_len) + input_shape3 = (1, ) + input_shape4 = (1, 512) + + hidden = torch.randn(input_shape1, dtype=torch.float32) + alphas = torch.randn(input_shape2, dtype=torch.float32) + integrate = torch.randn(input_shape3, dtype=torch.float32) + frame = torch.randn(input_shape4, dtype=torch.float32) + compile_inputs = [mindietorch.Input(shape = input_shape1, dtype = torch.float32), + mindietorch.Input(shape = input_shape2, dtype = torch.float32), + mindietorch.Input(shape = 
input_shape3, dtype = torch.float32),
+                          mindietorch.Input(shape = input_shape4, dtype = torch.float32)]
+
+        export_model = torch.jit.trace(self, example_inputs=(hidden, alphas, integrate, frame))
+        print("Finish trace cif")
+
+        compiled_model = mindietorch.compile(
+            export_model,
+            inputs = compile_inputs,
+            precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
+            soc_version = "Ascend310P3",
+            ir = "ts"
+        )
+        print("mindietorch compile done !")
+        compiled_model.save(path)
+        # compiled_model = torch.jit.load(path)
+
+        print("start to check the precision of cif model.")
+        sample_hidden = torch.randn(input_shape1, dtype=torch.float32)
+        sample_alphas = torch.randn(input_shape2, dtype=torch.float32)
+        sample_integrate = torch.randn(input_shape3, dtype=torch.float32)
+        sample_frame = torch.randn(input_shape4, dtype=torch.float32)
+        mrt_res = compiled_model(sample_hidden.to("npu"), sample_alphas.to("npu"),
+                                 sample_integrate.to("npu"), sample_frame.to("npu"))
+        print("mindie infer done !")
+        ref_res = cif(sample_hidden, sample_alphas, sample_integrate, sample_frame, self.threshold)
+        print("torch infer done !")
+
+        precision_eval(mrt_res, ref_res)
+
+
+class MindieCifTimestamp(torch.nn.Module):
+    def __init__(self, threshold, seq_len):
+        super().__init__()
+        self.threshold = threshold
+        self.seq_len = seq_len
+
+    def forward(self, us_alphas, integrate):
+        us_peaks, integrate_new = cif_wo_hidden(us_alphas, integrate, self.threshold)
+
+        return us_peaks, integrate_new
+
+    def export_ts(self, path="./compiled_cif_timestamp.ts"):
+        print("Begin trace cif_timestamp!")
+
+        input_shape1 = (1, self.seq_len)
+        input_shape2 = (1, )
+
+        us_alphas = torch.randn(input_shape1, dtype=torch.float32)
+        integrate = torch.randn(input_shape2, dtype=torch.float32)
+        compile_inputs = [mindietorch.Input(shape = input_shape1, dtype = torch.float32),
+                          mindietorch.Input(shape = input_shape2, dtype = torch.float32)]
+
+        export_model = torch.jit.trace(self, example_inputs=(us_alphas, integrate))
+        print("Finish trace cif_timestamp")
+
+        compiled_model = mindietorch.compile(
+            export_model,
+            inputs = compile_inputs,
+            precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
+            soc_version = "Ascend310P3",
+            ir = "ts"
+        )
+        print("mindietorch compile done !")
+        compiled_model.save(path)
+        # compiled_model = torch.jit.load(path)
+
+        print("start to check the precision of cif_timestamp model.")
+        sample_input1 = torch.randn(input_shape1, dtype=torch.float32)
+        sample_input2 = torch.randn(input_shape2, dtype=torch.float32)
+        mrt_res = compiled_model(sample_input1.to("npu"), sample_input2.to("npu"))
+        print("mindie infer done !")
+        ref_res = cif_wo_hidden(sample_input1, sample_input2, self.threshold)
+        print("torch infer done !")
+
+        precision_eval(mrt_res, ref_res)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
new file mode 100644
index 0000000000..55ace9f3d4
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
@@ -0,0 +1,121 @@
+import sys
+sys.path.append("./FunASR")
+
+import torch
+import mindietorch
+
+from mindie_paraformer import precision_eval
+from funasr.models.bicif_paraformer.model import Paraformer
+
+
+class MindieEncoder(Paraformer):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, speech, speech_length):
+        encoder_out, encoder_out_lens = self.encode(speech, speech_length)
+
+        return encoder_out, encoder_out_lens
+
+    @staticmethod
+    def export_ts(encoder, path="./compiled_encoder.ts"):
+        print("Begin trace encoder!")
+
+        input_shape = (2, 50, 560)
+        min_shape = (-1, -1, 560)
+        max_shape = (-1, -1, 560)
+        input_speech = torch.randn(input_shape, dtype=torch.float32)
+        input_speech_lens = torch.tensor([50, 25], dtype=torch.int32)
+        compile_inputs = [mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.float32),
+                          mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32)]
+
+        export_model = torch.jit.trace(encoder, example_inputs=(input_speech, input_speech_lens))
+        print("Finish trace encoder")
+
+        compiled_model = mindietorch.compile(
+            export_model,
+            inputs = compile_inputs,
+            precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
+            soc_version = "Ascend310P3",
+            ir = "ts"
+        )
+        print("mindietorch compile done !")
+        compiled_model.save(path)
+        # compiled_model = torch.jit.load(path)
+
+        print("start to check the precision of encoder.")
+        sample_speech = torch.randn((4, 100, 560), dtype=torch.float32)
+        sample_speech_lens = torch.tensor([100, 50, 100, 25], dtype=torch.int32)
+        mrt_res = compiled_model(sample_speech.to("npu"), sample_speech_lens.to("npu"))
+        print("mindie infer done !")
+        ref_res = encoder(sample_speech, sample_speech_lens)
+        print("torch infer done !")
+
+        precision_eval(mrt_res, ref_res)
+
+
+class MindieDecoder(Paraformer):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, encoder_out, encoder_out_lens, sematic_embeds, pre_token_length):
+        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, pre_token_length)
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out
+
+    @staticmethod
+    def export_ts(decoder, path="./compiled_decoder.ts"):
+        print("Begin trace decoder!")
+
+        input_shape1 = (2, 939, 512)
+        min_shape1 = (-1, -1, 512)
+        max_shape1 = (-1, -1, 512)
+
+        input_shape2 = (2, 261, 512)
+        min_shape2 = (-1, -1, 512)
+        max_shape2 = (-1, -1, 512)
+
+        encoder_out = torch.randn(input_shape1, dtype=torch.float32)
+        encoder_out_lens = torch.tensor([939, 500], dtype=torch.int32)
+        sematic_embeds = torch.randn(input_shape2, dtype=torch.float32)
+        sematic_embeds_lens = torch.tensor([261, 100], dtype=torch.int32)
+
+        compile_inputs = [mindietorch.Input(min_shape = min_shape1, max_shape = max_shape1, dtype = torch.float32),
+                          mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32),
+                          mindietorch.Input(min_shape = min_shape2, max_shape = max_shape2, dtype = torch.float32),
+                          mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32)]
+
+        export_model = torch.jit.trace(decoder, example_inputs=(encoder_out, encoder_out_lens, sematic_embeds, sematic_embeds_lens))
+        print("Finish trace decoder")
+
+        compiled_model = mindietorch.compile(
+            export_model,
+            inputs = compile_inputs,
+            precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
+            soc_version = "Ascend310P3",
+            ir = "ts"
+        )
+        print("mindietorch compile done !")
+        compiled_model.save(path)
+        # compiled_model = torch.jit.load(path)
+
+        print("start to check the precision of decoder.")
+        sample_encoder = torch.randn((4, 150, 512), dtype=torch.float32)
+        sample_encoder_lens = torch.tensor([150, 100, 150, 50], dtype=torch.int32)
+        sample_sematic = torch.randn((4, 50, 512), dtype=torch.float32)
+        sample_sematic_lens = torch.tensor([50, 30, 50, 10], dtype=torch.int32)
+        mrt_res = compiled_model(sample_encoder.to("npu"), sample_encoder_lens.to("npu"), sample_sematic.to("npu"), sample_sematic_lens.to("npu"))
+        print("mindie infer done !")
+        ref_res = decoder(sample_encoder, sample_encoder_lens, sample_sematic, sample_sematic_lens)
+        print("torch infer done !")
+
+        precision_eval(mrt_res, ref_res)
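+
+
+# A note on the Input ranges above (inferred from how the shapes are used here,
+# not from mindietorch documentation): the axes given as -1 in min_shape /
+# max_shape are the ones exercised with different sizes between tracing
+# ((2, 50, 560)) and the precision check ((4, 100, 560)), i.e. batch and
+# sequence length are treated as dynamic, while the trailing feature dimension
+# (560 on the encoder side, 512 on the decoder side) stays fixed.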
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py
new file mode 100644
index 0000000000..bc4f38335b
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py
@@ -0,0 +1,251 @@
+import sys
+sys.path.append("./FunASR")
+
+import copy
+import time
+import torch
+import torch_npu
+import torch.nn.functional as F
+
+from funasr.models.bicif_paraformer.model import BiCifParaformer, load_audio_text_image_video, \
+    extract_fbank, Hypothesis, ts_prediction_lfr6_standard, postprocess_utils
+from funasr.models.transformer.utils.nets_utils import make_pad_mask_new
+
+
+COSINE_THRESHOLD = 0.999
+def cosine_similarity(gt_tensor, pred_tensor):
+    gt_tensor = gt_tensor.flatten().to(torch.float32)
+    pred_tensor = pred_tensor.flatten().to(torch.float32)
+    if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
+        if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
+            return 1.0
+    res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
+    res = res.cpu().detach().item()
+    return res
+
+
+def precision_eval(mrt_res, ref_res):
+    if not isinstance(mrt_res, (list, tuple)):
+        mrt_res = [mrt_res, ]
+    if not isinstance(ref_res, (list, tuple)):
+        ref_res = [ref_res, ]
+
+    com_res = True
+    for j, a in zip(mrt_res, ref_res):
+        res = cosine_similarity(j.to("cpu"), a)
+        print(res)
+        if res < COSINE_THRESHOLD:
+            com_res = False
+
+    if com_res:
+        print("Compare success ! The NPU model has the same output as the CPU model !")
+    else:
+        print("Compare failed ! The outputs of the NPU model are not the same as the CPU model outputs !")
+
+
+class MindieBiCifParaformer(BiCifParaformer):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.mindie_encoder = torch.jit.load(kwargs["compiled_encoder"])
+        self.mindie_decoder = torch.jit.load(kwargs["compiled_decoder"])
+        self.mindie_cif = torch.jit.load(kwargs["compiled_cif"])
+        self.mindie_cif_timestamp = torch.jit.load(kwargs["compiled_cif_timestamp"])
+
+    def inference_with_npu(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        stream = torch_npu.npu.current_stream()
+        # Step1: load input data
+        time1 = time.perf_counter()
+        meta_data = {}
+
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+        is_use_lm = (
+            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+        )
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+        audio_sample_list = load_audio_text_image_video(
+            data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)
+        )
+
+        speech, speech_lengths = extract_fbank(
+            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+        )
+        print("Input shape: ", speech.shape)
+        speech = speech.to("npu")
+        speech_lengths = speech_lengths.to("npu")
+        meta_data["batch_data_time"] = (
+            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        )
+
+        time2 = time.perf_counter()
+        meta_data["load_data"] = time2 - time1
+
+        # Step2: run with compiled encoder
+        encoder_out, encoder_out_lens = self.mindie_encoder(speech, speech_lengths)
+        encoder_out_lens = encoder_out_lens.to(torch.int32)
+
+        encoder_out_mask = (
+            ~make_pad_mask_new(encoder_out_lens)[:, None, :]
+        ).to(encoder_out.device)
+        hidden, alphas, pre_token_length = (
+            self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
+        )
+        pre_token_length = pre_token_length.round().to(torch.int32)
+        stream.synchronize()
+        time3 = time.perf_counter()
+        meta_data["encoder"] = time3 - time2
+
+
+        # Step3: divide the dynamic loop into multiple smaller loops for calculation,
+        # each with a number of iterations based on kwargs["cif_interval"]
+        batch_size, len_time, hidden_size = hidden.size()
+        loop_num = len_time // kwargs["cif_interval"] + 1
+        padding_len = loop_num * kwargs["cif_interval"]
+        padding_size = padding_len - len_time
+        padded_hidden = F.pad(hidden, (0, 0, 0, padding_size), "constant", 0)
+        padded_alphas = F.pad(alphas, (0, padding_size), "constant", 0)
+
+        fires_batch = []
+        frames_batch = []
+        for b in range(batch_size):
+            fires_list = []
+            frames_list = []
+            integrate = torch.zeros([1, ], device=hidden.device)
+            frame = torch.zeros([1, hidden_size], device=hidden.device)
+            for i in range(loop_num):
+                cur_hidden = padded_hidden[b : b + 1, i * kwargs["cif_interval"] : (i + 1) * kwargs["cif_interval"], :]
+                cur_alphas = padded_alphas[b : b + 1, i * kwargs["cif_interval"] : (i + 1) * kwargs["cif_interval"]]
+                cur_fires, cur_frames, integrate, frame = self.mindie_cif(cur_hidden, cur_alphas, integrate, frame)
+                fires_list.append(cur_fires)
+                frames_list.append(cur_frames)
+            fire = torch.cat(fires_list, 1)
+            frame = torch.cat(frames_list, 1)
+            fires_batch.append(fire)
+            frames_batch.append(frame)
+        fires = torch.cat(fires_batch, 0)
+        frames = torch.cat(frames_batch, 0)
+
+        list_ls = []
+        len_labels = torch.round(alphas.sum(-1)).int()
+        max_label_len = len_labels.max()
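+        # gather the fired frames for each sample (positions where the integrated
+        # alpha crossed the CIF threshold) and right-pad every sample to the
+        # longest label length in the batch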
+ for b in range(batch_size): + fire = fires[b, :len_time] + l = torch.index_select(frames[b, :len_time, :], 0, torch.nonzero(fire >= self.predictor.threshold).squeeze()) + pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device) + list_ls.append(torch.cat([l, pad_l], 0)) + acoustic_embeds = torch.stack(list_ls, 0) + token_num_int = torch.max(pre_token_length) + pre_acoustic_embeds = acoustic_embeds[:, :token_num_int, :] + + if torch.max(pre_token_length) < 1: + return [] + stream.synchronize() + time4 = time.perf_counter() + meta_data["calc_predictor"] = time4 - time3 + + + # Step4: run with compiled decoder + decoder_out = self.mindie_decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length) + stream.synchronize() + time5 = time.perf_counter() + meta_data["decoder"] = time5 - time4 + + + # Step5: divide dynamic loop into multiple smaller loops for calculation + # each with a number of iterations based on kwargs["cif_timestamp_interval"] + us_alphas = self.predictor.get_upsample_timestamp(encoder_out, encoder_out_mask, pre_token_length) + len_alphas = us_alphas.shape[1] + loop_num = len_alphas // kwargs["cif_timestamp_interval"] + 1 + padding_len = loop_num * kwargs["cif_timestamp_interval"] + padding_size = padding_len - len_alphas + padded_alphas = F.pad(us_alphas, (0, padding_size), "constant", 0) + + peak_batch = [] + for b in range(batch_size): + peak_list = [] + integrate_alphas = torch.zeros([1], device=alphas.device) + for i in range(loop_num): + cur_alphas = padded_alphas[b:b+1, i * kwargs["cif_timestamp_interval"] : (i + 1) * kwargs["cif_timestamp_interval"]] + peak, integrate_alphas = self.mindie_cif_timestamp(cur_alphas, integrate_alphas) + peak_list.append(peak) + us_peak = torch.cat(peak_list, 1)[:, :len_alphas] + peak_batch.append(us_peak) + us_peaks = torch.cat(peak_batch, 0) + + stream.synchronize() + time6 = time.perf_counter() + meta_data["calc_predictor_timestamp"] = time6 - time5 + + + # Step6: post process + results = [] + b, n, d = decoder_out.size() + for i in range(b): + x = encoder_out[i, : encoder_out_lens[i], :] + am_scores = decoder_out[i, : pre_token_length[i], :] + + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device) + + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + + for nbest_idx, hyp in enumerate(nbest_hyps): + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list( + filter( + lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int + ) + ) + + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + + _, timestamp = ts_prediction_lfr6_standard( + us_alphas[i][: encoder_out_lens[i] * 3], + us_peaks[i][: encoder_out_lens[i] * 3], + copy.copy(token), + vad_offset=kwargs.get("begin_time", 0), + ) + + text_postprocessed, time_stamp_postprocessed, word_lists = ( + postprocess_utils.sentence_postprocess(token, timestamp) + ) + + result_i = { + "key": key[i], + "text": text_postprocessed, + "timestamp": time_stamp_postprocessed, + } + + results.append(result_i) + + stream.synchronize() + time7 = time.perf_counter() + meta_data["post_process"] = time7 - time6 + + return results, meta_data 
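+
+
+# Note on the fixed-interval loops in Step3 and Step5 above: the compiled cif
+# graphs take inputs of one fixed length, so each variable-length sequence is
+# zero-padded up to a multiple of the interval and processed window by window,
+# with the running state (the integrate scalar, plus the partial frame for the
+# cif graph) threaded between windows. A worked example of the padding
+# arithmetic, assuming cif_interval=200 as configured in compile.py:
+#
+#     len_time = 939                # frames produced for one sample
+#     loop_num = 939 // 200 + 1     # -> 5 fixed-size windows
+#     padding  = 5 * 200 - 939      # -> 61 zero frames appended
+#
+# Padded frames carry zero alpha mass, so they never fire and the result
+# matches a single pass over the unpadded sequence.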
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py
new file mode 100644
index 0000000000..87c7635264
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py
@@ -0,0 +1,199 @@
+import sys
+sys.path.append("./FunASR")
+
+import torch
+import mindietorch
+
+from mindie_paraformer import precision_eval
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.ct_transformer.model import CTTransformer
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+
+
+class MindiePunc(CTTransformer):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, text, text_lengths):
+        y, _ = self.punc_forward(text, text_lengths)
+        _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+        punctuations = torch.squeeze(indices, dim=1)
+
+        return punctuations
+
+    @staticmethod
+    def export_ts(punc, path="./compiled_punc.ts"):
+        print("Begin trace punc!")
+
+        input_shape = (1, 20)
+        min_shape = (1, -1)
+        max_shape = (1, -1)
+        input_speech = torch.randint(1, 10, input_shape, dtype=torch.int32)
+        input_speech_lengths = torch.tensor([20, ], dtype=torch.int32)
+        compile_inputs = [mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.int32),
+                          mindietorch.Input(min_shape = (1, ), max_shape = (1, ), dtype = torch.int32)]
+
+        export_model = torch.jit.trace(punc, example_inputs=(input_speech, input_speech_lengths))
+        print("Finish trace punc")
+
+        compiled_model = mindietorch.compile(
+            export_model,
+            inputs = compile_inputs,
+            precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
+            soc_version = "Ascend310P3",
+            ir = "ts"
+        )
+        print("mindietorch compile done !")
+        compiled_model.save(path)
+        # compiled_model = torch.jit.load(path)
+
+        print("start to check the precision of punc model.")
+        sample_speech = torch.randint(1, 10, (1, 10), dtype=torch.int32)
+        sample_speech_lengths = torch.tensor([10, ], dtype=torch.int32)
+        mrt_res = compiled_model(sample_speech.to("npu"), sample_speech_lengths.to("npu"))
+        print("mindie infer done !")
+        ref_res = punc(sample_speech, sample_speech_lengths)
+        print("torch infer done !")
+
+        precision_eval(mrt_res, ref_res)
+
+
+class MindieCTTransformer(CTTransformer):
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.mindie_punc = torch.jit.load(kwargs["compiled_punc"])
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        assert len(data_in) == 1
+        text = load_audio_text_image_video(data_in, data_type=kwargs.get("data_type", "text"))[0]
+
+        split_size = kwargs.get("split_size", 20)
+
+        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
+        tokens_int = tokenizer.encode(tokens)
+
+        mini_sentences = split_to_mini_sentence(tokens, split_size)
+        mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+
+        mini_sentences_id = [torch.unsqueeze(torch.tensor(id, dtype=torch.int32), 0).to("npu") for id in mini_sentences_id]
+
+        cache_sent = []
+        cache_sent_id = torch.tensor([[]], dtype=torch.int32).to("npu")
+        new_mini_sentence = ""
+        cache_pop_trigger_limit = 200
+        results = []
+        meta_data = {}
+
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
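+            # prepend the tokens carried over from the previous mini-sentence so
+            # the unfinished sentence is re-scored together with the new chunk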
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = torch.cat([cache_sent_id, mini_sentence_id], dim=1)
+
+            text = mini_sentence_id
+            text_lengths = torch.tensor([text.shape[1], ], dtype=torch.int32, device=text.device)
+            punctuations = self.mindie_punc(text, text_lengths)
+
+            assert punctuations.size()[0] == len(mini_sentence)
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if (
+                        self.punc_list[punctuations[i]] == "。"
+                        or self.punc_list[punctuations[i]] == "?"
+                    ):
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                        last_comma_index = i
+
+                if (
+                    sentenceEnd < 0
+                    and len(mini_sentence) > cache_pop_trigger_limit
+                    and last_comma_index >= 0
+                ):
+                    # The sentence is too long; cut it off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.sentence_end_id
+                cache_sent = mini_sentence[sentenceEnd + 1 :]
+                cache_sent_id = mini_sentence_id[:, sentenceEnd + 1 :]
+                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                punctuations = punctuations[0 : sentenceEnd + 1]
+
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                if (
+                    i == 0
+                    or self.punc_list[punctuations[i - 1]] == "。"
+                    or self.punc_list[punctuations[i - 1]] == "?"
+                ) and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                if i > 0:
+                    if (
+                        len(mini_sentence[i][0].encode()) == 1
+                        and len(mini_sentence[i - 1][0].encode()) == 1
+                    ):
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == ",":
+                            punc_res = ","
+                        elif punc_res == "。":
+                            punc_res = "."
+                        elif punc_res == "?":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
+            new_mini_sentence += "".join(words_with_punc)
+            # Add Period for the end of the sentence
+            new_mini_sentence_out = new_mini_sentence
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                elif (
+                    new_mini_sentence[-1] != "。"
+                    and new_mini_sentence[-1] != "?"
+                    and len(new_mini_sentence[-1].encode()) != 1
+                ):
+                    new_mini_sentence_out = new_mini_sentence + "。"
+                    if len(punctuations):
+                        punctuations[-1] = 2
+                elif (
+                    new_mini_sentence[-1] != "."
+                    and new_mini_sentence[-1] != "?"
+                    and len(new_mini_sentence[-1].encode()) == 1
+                ):
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    if len(punctuations):
+                        punctuations[-1] = 2
+
+            result_i = {"key": key[0], "text": new_mini_sentence_out}
+            results.append(result_i)
+        return results, meta_data
-- 
Gitee

From 3db33ce4080b6eaf27fbd3e43832933d4f7ef847 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Mon, 19 Aug 2024 11:03:13 +0800
Subject: [PATCH 02/16] Remove torch_npu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Paraformer/README.md       | 16 ++--
 .../built-in/audio/Paraformer/compile.py      |  3 +-
 .../audio/Paraformer/mindie_auto_model.py     |  4 +-
 .../built-in/audio/Paraformer/mindie_cif.py   |  5 +-
 .../Paraformer/mindie_encoder_decoder.py      | 21 ++++-
 .../audio/Paraformer/mindie_paraformer.py     | 77 ++++++++-----------
 .../built-in/audio/Paraformer/mindie_punc.py  |  9 ++-
 7 files changed, 68 insertions(+), 67 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index 2bacbe5e0a..ebf7dfae4a 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -46,11 +46,7 @@
    source /usr/local/Ascend/mindie/set_env.sh
    ```
 
-2. Install torch_npu
-
-   Install a matching torch_npu build as described in the [Ascend documentation](https://www.hiascend.com/document/detail/zh/mindie/10RC2/mindietorch/Torchdev/mindie_torch0017.html), and compile libtorch_npu_bridge.so manually.
-
-3. Get the FunASR source code
+2. Get the FunASR source code
 
    ```
    git clone https://github.com/modelscope/FunASR.git
@@ -59,7 +55,7 @@
    cd ..
    ```
 
-4. Patch the FunASR source: first move the patch file into the FunASR project directory, then apply it to the code (if the patch fails to apply, make the changes manually)
+3. Patch the FunASR source: first move the patch file into the FunASR project directory, then apply it to the code (if the patch fails to apply, make the changes manually)
    ```
    mv mindie.patch ./FunASR
    cd ./FunASR
@@ -67,7 +63,7 @@
    cd ..
    ```
 
-5. Download the model files
+4. Download the model files
 
    Download the [Paraformer](https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files) model files and save them under ./model
 
@@ -94,13 +90,17 @@
        └── ...
    ```
 
-6. Install the FunASR dependencies
+5. Install the FunASR dependencies
    ```
    sudo apt install ffmpeg
    pip install jieba
    ```
 
 ## Model Inference
+0. Set the mindie memory pool limit to 12 GB by exporting the following environment variable
+   ```
+   export TORCH_AIE_NPU_CACHE_MAX_SIZE=12
+   ```
 1. Model compilation and sample test
    Run the following command to compile the models
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
index 5629dada48..224f4a11e1 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
@@ -17,7 +17,6 @@
 import argparse
 
 import torch
-import torch_npu
 import mindietorch
 
 from mindie_auto_model import MindieAutoModel
@@ -77,7 +76,7 @@ if __name__ == "__main__":
         _ = model.generate(input=args.sample_path)
 
     # run with sample audio
-    iteration_num = 100
+    iteration_num = 10
     print("Begin running with sample audio for {} iterations".format(iteration_num))
 
     total_dict_time = {}
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
index 8251f7533f..0e33b0b2e4 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
@@ -41,7 +41,7 @@ class MindieAutoModel(AutoModel):
             logging.info("Building punc model.")
             punc_kwargs["model"] = punc_model
             punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
-            punc_kwargs["device"] = "npu"
+            punc_kwargs["device"] = "cpu"
             punc_kwargs["compile_type"] = "punc"
             punc_kwargs["compiled_punc"] = kwargs["compiled_punc"]
             punc_model, punc_kwargs = self.build_model_with_mindie(**punc_kwargs)
@@ -217,8 +217,6 @@ class MindieAutoModel(AutoModel):
             excludes=kwargs.get("excludes", None),
         )
 
-        model.to("npu")
-
         return model, kwargs
 
     def inference_with_asr(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
index 71b761187d..71f18f6c03 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
@@ -79,7 +79,10 @@ class MindieCif(torch.nn.Module):
 
     def forward(self, hidden, alphas, integrate, frame):
         fires, frames, integrate_new, frame_new = cif(hidden, alphas, integrate, frame, self.threshold)
-        return fires, frames, integrate_new, frame_new
+
+        frame = torch.index_select(frames[0, :, :], 0, torch.nonzero(fires[0, :] >= self.threshold).squeeze(1))
+
+        return frame, integrate_new, frame_new
 
     def export_ts(self, path="./compiled_cif.ts"):
         print("Begin trace cif!")
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
index 55ace9f3d4..db8e59fa41 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
@@ -6,6 +6,7 @@ import mindietorch
 
 from mindie_paraformer import precision_eval
 from funasr.models.bicif_paraformer.model import Paraformer
+from funasr.models.transformer.utils.nets_utils import make_pad_mask_new
 
 
 class MindieEncoder(Paraformer):
@@ -19,7 +20,16 @@ class MindieEncoder(Paraformer):
     def forward(self, speech, speech_length):
         encoder_out, encoder_out_lens = self.encode(speech, speech_length)
 
-        return encoder_out, encoder_out_lens
+        encoder_out_lens = encoder_out_lens.to(torch.int32)
+
+        encoder_out_mask = (
+            ~make_pad_mask_new(encoder_out_lens)[:, None, :]
+        ).to(encoder_out.device)
+        hidden, alphas, pre_token_length = (
+            self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
+        )
+
+        return 
encoder_out, encoder_out_lens, hidden, alphas, pre_token_length @staticmethod def export_ts(encoder, path="./compiled_encoder.ts"): @@ -70,7 +80,14 @@ class MindieDecoder(Paraformer): decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, pre_token_length) decoder_out = decoder_outs[0] decoder_out = torch.log_softmax(decoder_out, dim=-1) - return decoder_out + + encoder_out_mask = ( + ~make_pad_mask_new(encoder_out_lens)[:, None, :] + ).to(encoder_out.device) + + us_alphas = self.predictor.get_upsample_timestamp(encoder_out, encoder_out_mask, pre_token_length) + + return decoder_out, us_alphas @staticmethod def export_ts(decoder, path="./compiled_decoder.ts"): diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py index bc4f38335b..73876484f0 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py @@ -4,7 +4,6 @@ sys.path.append("./FunASR") import copy import time import torch -import torch_npu import torch.nn.functional as F from funasr.models.bicif_paraformer.model import BiCifParaformer, load_audio_text_image_video, \ @@ -65,7 +64,6 @@ class MindieBiCifParaformer(BiCifParaformer): frontend=None, **kwargs, ): - stream = torch_npu.npu.current_stream() # Step1: load input data time1 = time.perf_counter() meta_data = {} @@ -85,27 +83,23 @@ class MindieBiCifParaformer(BiCifParaformer): audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend ) print("Input shape: ", speech.shape) - speech = speech.to("npu") - speech_lengths = speech_lengths.to("npu") meta_data["batch_data_time"] = ( speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 ) + speech = speech.to("npu") + speech_lengths = speech_lengths.to("npu") time2 = time.perf_counter() meta_data["load_data"] = time2 - time1 # Step2: run with compiled encoder - encoder_out, encoder_out_lens = self.mindie_encoder(speech, speech_lengths) - encoder_out_lens = encoder_out_lens.to(torch.int32) - - encoder_out_mask = ( - ~make_pad_mask_new(encoder_out_lens)[:, None, :] - ).to(encoder_out.device) - hidden, alphas, pre_token_length = ( - self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id) - ) + encoder_out, encoder_out_lens, hidden, alphas, pre_token_length = self.mindie_encoder(speech, speech_lengths) + + hidden = hidden.to("cpu") + alphas = alphas.to("cpu") + pre_token_length = pre_token_length.to("cpu") + pre_token_length = pre_token_length.round().to(torch.int32) - stream.synchronize() time3 = time.perf_counter() meta_data["encoder"] = time3 - time2 @@ -119,55 +113,43 @@ class MindieBiCifParaformer(BiCifParaformer): padded_hidden = F.pad(hidden, (0, 0, 0, padding_size), "constant", 0) padded_alphas = F.pad(alphas, (0, padding_size), "constant", 0) - fires_batch = [] + len_labels = torch.round(alphas.sum(-1)).int() + max_label_len = len_labels.max() + frames_batch = [] for b in range(batch_size): - fires_list = [] frames_list = [] - integrate = torch.zeros([1, ], device=hidden.device) - frame = torch.zeros([1, hidden_size], device=hidden.device) + integrate = torch.zeros([1, ]).to("npu") + frame = torch.zeros([1, hidden_size]).to("npu") for i in range(loop_num): cur_hidden = padded_hidden[b : b + 1, i * kwargs["cif_interval"] : (i + 1) * kwargs["cif_interval"], :] cur_alphas = padded_alphas[b : b + 1, i * kwargs["cif_interval"] : (i + 1) * kwargs["cif_interval"]] - 
cur_fires, cur_frames, integrate, frame = self.mindie_cif(cur_hidden, cur_alphas, integrate, frame) - fires_list.append(cur_fires) - frames_list.append(cur_frames) - fire = torch.cat(fires_list, 1) - frame = torch.cat(frames_list, 1) - fires_batch.append(fire) - frames_batch.append(frame) - fires = torch.cat(fires_batch, 0) - frames = torch.cat(frames_batch, 0) - - list_ls = [] - len_labels = torch.round(alphas.sum(-1)).int() - max_label_len = len_labels.max() - for b in range(batch_size): - fire = fires[b, :len_time] - l = torch.index_select(frames[b, :len_time, :], 0, torch.nonzero(fire >= self.predictor.threshold).squeeze()) - pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device) - list_ls.append(torch.cat([l, pad_l], 0)) - acoustic_embeds = torch.stack(list_ls, 0) + cur_frames, integrate, frame = self.mindie_cif(cur_hidden.to("npu"), cur_alphas.to("npu"), integrate, frame) + frames_list.append(cur_frames.to("cpu")) + frame = torch.cat(frames_list, 0) + pad_frame = torch.zeros([max_label_len - frame.size(0), hidden_size], device=hidden.device) + frames_batch.append(torch.cat([frame, pad_frame], 0)) + + acoustic_embeds = torch.stack(frames_batch, 0) token_num_int = torch.max(pre_token_length) pre_acoustic_embeds = acoustic_embeds[:, :token_num_int, :] if torch.max(pre_token_length) < 1: return [] - stream.synchronize() time4 = time.perf_counter() meta_data["calc_predictor"] = time4 - time3 # Step4: run with compiled decoder - decoder_out = self.mindie_decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length) - stream.synchronize() + decoder_out, us_alphas = self.mindie_decoder(encoder_out, encoder_out_lens, + pre_acoustic_embeds.to("npu"), pre_token_length.to("npu")) + us_alphas = us_alphas.to("cpu") time5 = time.perf_counter() meta_data["decoder"] = time5 - time4 # Step5: divide dynamic loop into multiple smaller loops for calculation # each with a number of iterations based on kwargs["cif_timestamp_interval"] - us_alphas = self.predictor.get_upsample_timestamp(encoder_out, encoder_out_mask, pre_token_length) len_alphas = us_alphas.shape[1] loop_num = len_alphas // kwargs["cif_timestamp_interval"] + 1 padding_len = loop_num * kwargs["cif_timestamp_interval"] @@ -177,25 +159,27 @@ class MindieBiCifParaformer(BiCifParaformer): peak_batch = [] for b in range(batch_size): peak_list = [] - integrate_alphas = torch.zeros([1], device=alphas.device) + integrate_alphas = torch.zeros([1]).to("npu") for i in range(loop_num): cur_alphas = padded_alphas[b:b+1, i * kwargs["cif_timestamp_interval"] : (i + 1) * kwargs["cif_timestamp_interval"]] - peak, integrate_alphas = self.mindie_cif_timestamp(cur_alphas, integrate_alphas) - peak_list.append(peak) + peak, integrate_alphas = self.mindie_cif_timestamp(cur_alphas.to("npu"), integrate_alphas) + peak_list.append(peak.to("cpu")) us_peak = torch.cat(peak_list, 1)[:, :len_alphas] peak_batch.append(us_peak) us_peaks = torch.cat(peak_batch, 0) - stream.synchronize() time6 = time.perf_counter() meta_data["calc_predictor_timestamp"] = time6 - time5 # Step6: post process + decoder_out = decoder_out.to("cpu") + us_alphas = us_alphas.to("cpu") + us_peaks = us_peaks.to("cpu") + encoder_out_lens = encoder_out_lens.to("cpu") results = [] b, n, d = decoder_out.size() for i in range(b): - x = encoder_out[i, : encoder_out_lens[i], :] am_scores = decoder_out[i, : pre_token_length[i], :] yseq = am_scores.argmax(dim=-1) @@ -244,7 +228,6 @@ class MindieBiCifParaformer(BiCifParaformer): results.append(result_i) - 
stream.synchronize() time7 = time.perf_counter() meta_data["post_process"] = time7 - time6 diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py index 87c7635264..4aa895f9f7 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py @@ -93,10 +93,10 @@ class MindieCTTransformer(CTTransformer): mini_sentences_id = split_to_mini_sentence(tokens_int, split_size) assert len(mini_sentences) == len(mini_sentences_id) - mini_sentences_id = [torch.unsqueeze(torch.tensor(id, dtype=torch.int32), 0).to("npu") for id in mini_sentences_id] + mini_sentences_id = [torch.unsqueeze(torch.tensor(id, dtype=torch.int32), 0) for id in mini_sentences_id] cache_sent = [] - cache_sent_id = torch.tensor([[]], dtype=torch.int32).to("npu") + cache_sent_id = torch.tensor([[]], dtype=torch.int32) new_mini_sentence = "" cache_pop_trigger_limit = 200 results = [] @@ -108,9 +108,10 @@ class MindieCTTransformer(CTTransformer): mini_sentence = cache_sent + mini_sentence mini_sentence_id = torch.cat([cache_sent_id, mini_sentence_id], dim=1) - text = mini_sentence_id - text_lengths = torch.tensor([text.shape[1], ], dtype=torch.int32, device=text.device) + text = mini_sentence_id.to("npu") + text_lengths = torch.tensor([text.shape[1], ], dtype=torch.int32).to("npu") punctuations = self.mindie_punc(text, text_lengths) + punctuations = punctuations.to("cpu") assert punctuations.size()[0] == len(mini_sentence) -- Gitee From 735c4e61be2aa5a95f8d98dfa7c5ef29ff723f5b Mon Sep 17 00:00:00 2001 From: brjiang Date: Mon, 19 Aug 2024 17:10:24 +0800 Subject: [PATCH 03/16] =?UTF-8?q?=E6=96=B0=E5=A2=9Etrace=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/audio/Paraformer/README.md | 22 ++- .../built-in/audio/Paraformer/compile.py | 4 +- .../audio/Paraformer/mindie_auto_model.py | 2 +- .../Paraformer/mindie_encoder_decoder.py | 27 ++-- .../audio/Paraformer/trace_decoder.py | 137 ++++++++++++++++++ 5 files changed, 177 insertions(+), 15 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md index ebf7dfae4a..e16b89fc5d 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md @@ -97,12 +97,28 @@ ``` ## 模型推理 -0. 设置mindie内存池上限为12G,执行如下命令设置环境变量 +1. 设置mindie内存池上限为12G,执行如下命令设置环境变量 ``` export TORCH_AIE_NPU_CACHE_MAX_SIZE=12 ``` -1. 模型编译及样本测试 +2. (可选) 若CPU为aarch64的架构,则在编译decoder时会出现RuntimeError: could not create a primitive descriptor for a matmul primitive,此时需要新创建一个Python环境(推荐使用conda创建),而后使用如下命令安装torch 2.2.1 + ``` + pip install torch==2.2.1 torchvision==0.17.1 --index-url https://download.pytorch.org/whl/cpu + ``` + + 而后,执行如下脚本将decoder序列化 + ``` + python trace_decoder.py --model ./model --traced_decoder ./compiled_model/traced_decoder.pt + ``` + 参数说明: + - --model:预训练模型路径 + - --traced_decoder: 序列化后的decoder保存路径 + + 该步骤仅获得一个序列化后的decoder模型,后续模型编译仍需回到原始环境 + + +3. 
模型编译及样本测试 执行下述命令进行模型编译 ```bash python compile.py \ @@ -114,6 +130,7 @@ --compiled_cif ./compiled_model/compiled_cif.pt \ --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \ --compiled_punc ./compiled_model/compiled_punc.ts \ + --traced_decoder ./compiled_model/traced_decoder.pt \ --batch_size 16 \ --sample_path ./model/example/asr_example.wav \ --skip_compile @@ -128,6 +145,7 @@ - --compiled_cif:编译后的cif函数的路径 - --compiled_cif_timestamp:编译后的cif_timestamp函数的路径 - --compiled_punc:编译后的punc的路径 + - --traced_decoder:预先序列化的decoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 - --batch_size:Paraformer模型所使用的batch_size - --sample_path:用于测试模型的样本音频路径 - --skip_compile:是否进行模型编译,若已经完成编译,后续使用可使用该参数跳过编译,若为第一次使用不能指定该参数 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py index 224f4a11e1..ebe823195c 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py @@ -42,6 +42,8 @@ if __name__ == "__main__": help="path to save compiled cif timestamp function") parser.add_argument("--compiled_punc", default="./compiled_model/compiled_punc.ts", help="path to save compiled punc model") + parser.add_argument("--traced_decoder", default=None, + help="path to save compiled punc model") parser.add_argument("--batch_size", default=16, type=int, help="batch size of paraformer model") parser.add_argument("--sample_path", default="./audio/test_1.wav", @@ -56,7 +58,7 @@ if __name__ == "__main__": MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc, compile_type="punc") MindieAutoModel.export_model(model=args.model, compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif, - compiled_cif_timestamp=args.compiled_cif_timestamp, + compiled_cif_timestamp=args.compiled_cif_timestamp, traced_decoder=args.traced_decoder, cif_interval=200, cif_timestamp_interval=500, compile_type="paraformer") print("Finish compiling sub-models") else: diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py index 0e33b0b2e4..93f1a6c65d 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py @@ -148,7 +148,7 @@ class MindieAutoModel(AutoModel): scope_map=kwargs.get("scope_map", []), excludes=kwargs.get("excludes", None), ) - MindieDecoder.export_ts(decoder, kwargs["compiled_decoder"]) + MindieDecoder.export_ts(decoder, kwargs["compiled_decoder"], kwargs["traced_decoder"]) # compile cif mindie_cif = MindieCif(decoder.predictor.threshold, kwargs["cif_interval"]) diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py index db8e59fa41..3a9b7436c5 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py @@ -1,4 +1,5 @@ import sys +import os sys.path.append("./FunASR") import torch @@ -90,7 +91,7 @@ class MindieDecoder(Paraformer): return decoder_out, us_alphas @staticmethod - def export_ts(decoder, path="./compiled_decoder.ts"): + def export_ts(decoder, path="./compiled_decoder.ts", traced_path="./traced_decoder.ts"): print("Begin trace decoder!") input_shape1 = (2, 939, 512) @@ -101,18 +102,22 @@ class 
MindieDecoder(Paraformer): min_shape2 = (-1, -1, 512) max_shape2 = (-1, -1, 512) - encoder_out = torch.randn(input_shape1, dtype=torch.float32) - encoder_out_lens = torch.tensor([939, 500], dtype=torch.int32) - sematic_embeds = torch.randn(input_shape2, dtype=torch.float32) - sematic_embeds_lens = torch.tensor([261, 100], dtype=torch.int32) + if traced_path is not None and os.path.exists(traced_path): + export_model = torch.load(traced_path) + print("load existing traced_decoder") + else: + encoder_out = torch.randn(input_shape1, dtype=torch.float32) + encoder_out_lens = torch.tensor([939, 500], dtype=torch.int32) + sematic_embeds = torch.randn(input_shape2, dtype=torch.float32) + sematic_embeds_lens = torch.tensor([261, 100], dtype=torch.int32) + + export_model = torch.jit.trace(decoder, example_inputs=(encoder_out, encoder_out_lens, sematic_embeds, sematic_embeds_lens)) + print("Finish trace decoder") compile_inputs = [mindietorch.Input(min_shape = min_shape1, max_shape = max_shape1, dtype = torch.float32), - mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32), - mindietorch.Input(min_shape = min_shape2, max_shape = max_shape2, dtype = torch.float32), - mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32)] - - export_model = torch.jit.trace(decoder, example_inputs=(encoder_out, encoder_out_lens, sematic_embeds, sematic_embeds_lens)) - print("Finish trace decoder") + mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32), + mindietorch.Input(min_shape = min_shape2, max_shape = max_shape2, dtype = torch.float32), + mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32)] compiled_model = mindietorch.compile( export_model, diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py new file mode 100644 index 0000000000..155a7912e7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
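+"""
+Standalone helper that serializes the Paraformer decoder on CPU.
+
+It rebuilds the pretrained decoder as ParaformerDecoder, traces it with
+torch.jit.trace on dummy inputs and saves the TorchScript module, so that
+compile.py can load the traced model instead of tracing it again (the
+workaround for the aarch64 matmul issue described in the README), e.g.:
+
+    python trace_decoder.py --model ./model \
+        --traced_decoder ./compiled_model/traced_decoder.pt
+"""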
+ +import argparse +import torch + +import sys +sys.path.append("./FunASR") + +from funasr.auto.auto_model import AutoModel, download_model, tables, deep_update, \ + load_pretrained_model, prepare_data_iterator +from funasr.models.bicif_paraformer.model import Paraformer +from funasr.models.transformer.utils.nets_utils import make_pad_mask_new + + +class ParaformerDecoder(Paraformer): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward(self, encoder_out, encoder_out_lens, sematic_embeds, pre_token_length): + decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, pre_token_length) + decoder_out = decoder_outs[0] + decoder_out = torch.log_softmax(decoder_out, dim=-1) + + encoder_out_mask = ( + ~make_pad_mask_new(encoder_out_lens)[:, None, :] + ).to(encoder_out.device) + + us_alphas = self.predictor.get_upsample_timestamp(encoder_out, encoder_out_mask, pre_token_length) + + return decoder_out, us_alphas + + def trace_model(decoder, path="./traced_decoder.ts"): + print("Begin trace decoder!") + + input_shape1 = (2, 939, 512) + input_shape2 = (2, 261, 512) + + encoder_out = torch.randn(input_shape1, dtype=torch.float32) + encoder_out_lens = torch.tensor([939, 500], dtype=torch.int32) + sematic_embeds = torch.randn(input_shape2, dtype=torch.float32) + sematic_embeds_lens = torch.tensor([261, 100], dtype=torch.int32) + + trace_model = torch.jit.trace(decoder, example_inputs=(encoder_out, encoder_out_lens, sematic_embeds, sematic_embeds_lens)) + trace_model.save(path) + print("Finish trace decoder") + + +class AutoModelDecoder(AutoModel): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @staticmethod + def trace_decoder(**kwargs): + # load model config + assert "model" in kwargs + if "model_conf" not in kwargs: + print("download models from model hub: {}".format(kwargs.get("hub", "ms"))) + kwargs = download_model(**kwargs) + + kwargs["batch_size"] = 1 + kwargs["device"] = "cpu" + + # build tokenizer + tokenizer = kwargs.get("tokenizer", None) + if tokenizer is not None: + tokenizer_class = tables.tokenizer_classes.get(tokenizer) + tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {})) + kwargs["token_list"] = ( + tokenizer.token_list if hasattr(tokenizer, "token_list") else None + ) + kwargs["token_list"] = ( + tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"] + ) + vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 + if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"): + vocab_size = tokenizer.get_vocab_size() + else: + vocab_size = -1 + kwargs["tokenizer"] = tokenizer + + # build frontend + frontend = kwargs.get("frontend", None) + kwargs["input_size"] = None + if frontend is not None: + frontend_class = tables.frontend_classes.get(frontend) + frontend = frontend_class(**kwargs.get("frontend_conf", {})) + kwargs["input_size"] = ( + frontend.output_size() if hasattr(frontend, "output_size") else None + ) + kwargs["frontend"] = frontend + + model_conf = {} + deep_update(model_conf, kwargs.get("model_conf", {})) + deep_update(model_conf, kwargs) + init_param = kwargs.get("init_param", None) + + decoder = ParaformerDecoder(**model_conf, vocab_size=vocab_size) + decoder.eval() + print(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=decoder, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", 
[]), + excludes=kwargs.get("excludes", None), + ) + ParaformerDecoder.trace_model(decoder, kwargs["traced_decoder"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="./model", + help="path of pretrained model") + parser.add_argument("--traced_decoder", default="./compiled_model/traced_decoder.ts", + help="path to save compiled decoder") + args = parser.parse_args() + + AutoModelDecoder.trace_decoder(model=args.model, traced_decoder=args.traced_decoder) \ No newline at end of file -- Gitee From dcf2d3ed2e4fe182fef59a93e1db26af1e374b60 Mon Sep 17 00:00:00 2001 From: brjiang Date: Mon, 19 Aug 2024 18:03:10 +0800 Subject: [PATCH 04/16] =?UTF-8?q?=E9=80=82=E9=85=8Dencoder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/audio/Paraformer/README.md | 16 ++++-- .../built-in/audio/Paraformer/compile.py | 7 ++- .../audio/Paraformer/mindie_auto_model.py | 2 +- .../Paraformer/mindie_encoder_decoder.py | 20 +++++--- .../audio/Paraformer/trace_decoder.py | 51 ++++++++++++++++++- 5 files changed, 80 insertions(+), 16 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md index e16b89fc5d..c4fb83e671 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md @@ -102,20 +102,24 @@ export TORCH_AIE_NPU_CACHE_MAX_SIZE=12 ``` -2. (可选) 若CPU为aarch64的架构,则在编译decoder时会出现RuntimeError: could not create a primitive descriptor for a matmul primitive,此时需要新创建一个Python环境(推荐使用conda创建),而后使用如下命令安装torch 2.2.1 +2. (可选) 若CPU为aarch64的架构,则在编译encoder和decoder时会出现RuntimeError: could not create a primitive descriptor for a matmul primitive,此时需要新创建一个Python环境(推荐使用conda创建),而后使用如下命令安装torch 2.2.1 ``` - pip install torch==2.2.1 torchvision==0.17.1 --index-url https://download.pytorch.org/whl/cpu + pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cpu ``` 而后,执行如下脚本将decoder序列化 - ``` - python trace_decoder.py --model ./model --traced_decoder ./compiled_model/traced_decoder.pt + ```bash + python trace_decoder.py \ + --model ./model \ + --traced_encoder ./compiled_model/traced_encoder.pt \ + --traced_decoder ./compiled_model/traced_decoder.pt ``` 参数说明: - --model:预训练模型路径 + - --traced_encoder: 序列化后的encoder保存路径 - --traced_decoder: 序列化后的decoder保存路径 - 该步骤仅获得一个序列化后的decoder模型,后续模型编译仍需回到原始环境 + 该步骤获得序列化后的encoder和decoder模型,后续模型编译仍需回到原始环境 3. 
模型编译及样本测试 @@ -130,6 +134,7 @@ --compiled_cif ./compiled_model/compiled_cif.pt \ --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \ --compiled_punc ./compiled_model/compiled_punc.ts \ + --traced_encoder ./compiled_model/traced_encoder.pt \ --traced_decoder ./compiled_model/traced_decoder.pt \ --batch_size 16 \ --sample_path ./model/example/asr_example.wav \ @@ -145,6 +150,7 @@ - --compiled_cif:编译后的cif函数的路径 - --compiled_cif_timestamp:编译后的cif_timestamp函数的路径 - --compiled_punc:编译后的punc的路径 + - --traced_encoder:预先序列化的encoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 - --traced_decoder:预先序列化的decoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 - --batch_size:Paraformer模型所使用的batch_size - --sample_path:用于测试模型的样本音频路径 diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py index ebe823195c..3610458379 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py @@ -42,8 +42,10 @@ if __name__ == "__main__": help="path to save compiled cif timestamp function") parser.add_argument("--compiled_punc", default="./compiled_model/compiled_punc.ts", help="path to save compiled punc model") + parser.add_argument("--traced_encoder", default=None, + help="path to save traced encoder model") parser.add_argument("--traced_decoder", default=None, - help="path to save compiled punc model") + help="path to save traced decoder model") parser.add_argument("--batch_size", default=16, type=int, help="batch size of paraformer model") parser.add_argument("--sample_path", default="./audio/test_1.wav", @@ -58,7 +60,8 @@ if __name__ == "__main__": MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc, compile_type="punc") MindieAutoModel.export_model(model=args.model, compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif, - compiled_cif_timestamp=args.compiled_cif_timestamp, traced_decoder=args.traced_decoder, + compiled_cif_timestamp=args.compiled_cif_timestamp, + traced_encoder=args.traced_encoder, traced_decoder=args.traced_decoder, cif_interval=200, cif_timestamp_interval=500, compile_type="paraformer") print("Finish compiling sub-models") else: diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py index 93f1a6c65d..b70627eb97 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py @@ -134,7 +134,7 @@ class MindieAutoModel(AutoModel): scope_map=kwargs.get("scope_map", []), excludes=kwargs.get("excludes", None), ) - MindieEncoder.export_ts(encoder, kwargs["compiled_encoder"]) + MindieEncoder.export_ts(encoder, kwargs["compiled_encoder"], kwargs["traced_encoder"]) # compile decoder decoder = MindieDecoder(**model_conf, vocab_size=vocab_size) diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py index 3a9b7436c5..57c17f21f3 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py @@ -33,19 +33,25 @@ class MindieEncoder(Paraformer): return encoder_out, encoder_out_lens, hidden, alphas, pre_token_length @staticmethod - def export_ts(encoder, path="./compiled_encoder.ts"): + def export_ts(encoder, 
path="./compiled_encoder.ts", traced_path="./traced_encoder.ts"): print("Begin trace encoder!") input_shape = (2, 50, 560) min_shape = (-1, -1, 560) max_shape = (-1, -1, 560) - input_speech = torch.randn(input_shape, dtype=torch.float32) - input_speech_lens = torch.tensor([50, 25], dtype=torch.int32) - compile_inputs = [mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.float32), - mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32)] + + + if traced_path is not None and os.path.exists(traced_path): + export_model = torch.load(traced_path) + print("load existing traced_encoder") + else: + input_speech = torch.randn(input_shape, dtype=torch.float32) + input_speech_lens = torch.tensor([50, 25], dtype=torch.int32) - export_model = torch.jit.trace(encoder, example_inputs=(input_speech, input_speech_lens)) - print("Finish trace encoder") + export_model = torch.jit.trace(encoder, example_inputs=(input_speech, input_speech_lens)) + + compile_inputs = [mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.float32), + mindietorch.Input(min_shape = (-1, ), max_shape = (-1, ), dtype = torch.int32)] compiled_model = mindietorch.compile( export_model, diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py index 155a7912e7..44bb0e00f8 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/trace_decoder.py @@ -26,6 +26,40 @@ from funasr.models.bicif_paraformer.model import Paraformer from funasr.models.transformer.utils.nets_utils import make_pad_mask_new +class ParaformerEncoder(Paraformer): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + def forward(self, speech, speech_length): + encoder_out, encoder_out_lens = self.encode(speech, speech_length) + + encoder_out_lens = encoder_out_lens.to(torch.int32) + + encoder_out_mask = ( + ~make_pad_mask_new(encoder_out_lens)[:, None, :] + ).to(encoder_out.device) + hidden, alphas, pre_token_length = ( + self.predictor(encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id) + ) + + return encoder_out, encoder_out_lens, hidden, alphas, pre_token_length + + def trace_model(encoder, path="./traced_encoder.ts"): + print("Begin trace encoder!") + + input_shape = (2, 50, 560) + input_speech = torch.randn(input_shape, dtype=torch.float32) + input_speech_lens = torch.tensor([50, 25], dtype=torch.int32) + + trace_model = torch.jit.trace(encoder, example_inputs=(input_speech, input_speech_lens)) + trace_model.save(path) + print("Finish trace encoder") + + class ParaformerDecoder(Paraformer): def __init__( self, @@ -112,6 +146,19 @@ class AutoModelDecoder(AutoModel): deep_update(model_conf, kwargs) init_param = kwargs.get("init_param", None) + encoder = ParaformerEncoder(**model_conf, vocab_size=vocab_size) + encoder.eval() + print(f"Loading pretrained params from {init_param}") + load_pretrained_model( + model=encoder, + path=init_param, + ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True), + oss_bucket=kwargs.get("oss_bucket", None), + scope_map=kwargs.get("scope_map", []), + excludes=kwargs.get("excludes", None), + ) + ParaformerEncoder.trace_model(encoder, kwargs["traced_encoder"]) + decoder = ParaformerDecoder(**model_conf, vocab_size=vocab_size) decoder.eval() print(f"Loading pretrained params from {init_param}") @@ -130,8 +177,10 @@ if __name__ == "__main__": parser = 
argparse.ArgumentParser()
     parser.add_argument("--model", default="./model",
                         help="path of pretrained model")
+    parser.add_argument("--traced_encoder", default="./compiled_model/traced_encoder.ts",
+                        help="path to save traced encoder")
     parser.add_argument("--traced_decoder", default="./compiled_model/traced_decoder.ts",
                         help="path to save compiled decoder")
     args = parser.parse_args()
 
-    AutoModelDecoder.trace_decoder(model=args.model, traced_decoder=args.traced_decoder)
\ No newline at end of file
+    AutoModelDecoder.trace_decoder(model=args.model, traced_encoder=args.traced_encoder, traced_decoder=args.traced_decoder)
\ No newline at end of file
-- 
Gitee

From 8507c27eb784c17b009cd64bd2f2778775b12682 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Tue, 20 Aug 2024 10:40:12 +0800
Subject: [PATCH 05/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=B2=BE=E5=BA=A6?=
 =?UTF-8?q?=E9=AA=8C=E8=AF=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py | 4 ++--
 .../built-in/audio/Paraformer/mindie_encoder_decoder.py     | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
index 71f18f6c03..309dc4a93a 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
@@ -123,7 +123,7 @@ class MindieCif(torch.nn.Module):
         mrt_res = compiled_model(sample_hidden.to("npu"), sample_alphas.to("npu"),
                                  sample_integrate.to("npu"), sample_frame.to("npu"))
         print("mindie infer done !")
-        ref_res = cif(sample_hidden, sample_alphas, sample_integrate, sample_frame, self.threshold)
+        ref_res = self.forward(sample_hidden, sample_alphas, sample_integrate, sample_frame)
         print("torch infer done !")
 
         precision_eval(mrt_res, ref_res)
@@ -170,7 +170,7 @@ class MindieCifTimestamp(torch.nn.Module):
         sample_input2 = torch.randn(input_shape2, dtype=torch.float32)
         mrt_res = compiled_model(sample_input1.to("npu"), sample_input2.to("npu"))
         print("mindie infer done !")
-        ref_res = cif_wo_hidden(sample_input1, sample_input2, self.threshold)
+        ref_res = self.forward(sample_input1, sample_input2)
         print("torch infer done !")
 
         precision_eval(mrt_res, ref_res)
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
index 57c17f21f3..4138447068 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
@@ -69,7 +69,7 @@ class MindieEncoder(Paraformer):
         sample_speech_lens = torch.tensor([100, 50, 100, 25], dtype=torch.int32)
         mrt_res = compiled_model(sample_speech.to("npu"), sample_speech_lens.to("npu"))
         print("mindie infer done !")
-        ref_res = encoder(sample_speech, sample_speech_lens)
+        ref_res = export_model(sample_speech, sample_speech_lens)
         print("torch infer done !")
 
         precision_eval(mrt_res, ref_res)
@@ -143,7 +143,7 @@ class MindieDecoder(Paraformer):
         sample_sematic_lens = torch.tensor([50, 30, 50, 10], dtype=torch.int32)
         mrt_res = compiled_model(sample_encoder.to("npu"), sample_encoder_lens.to("npu"), sample_sematic.to("npu"), sample_sematic_lens.to("npu"))
         print("mindie infer done !")
-        ref_res = decoder(sample_encoder, sample_encoder_lens, sample_sematic, sample_sematic_lens)
+        ref_res = export_model(sample_encoder, 
sample_encoder_lens, sample_sematic, sample_sematic_lens) print("torch infer done !") precision_eval(mrt_res, ref_res) \ No newline at end of file -- Gitee From 08d7d532b531eb49f5440585f3d385e807cf3615 Mon Sep 17 00:00:00 2001 From: brjiang Date: Wed, 21 Aug 2024 10:47:43 +0800 Subject: [PATCH 06/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/audio/Paraformer/README.md | 18 ++++++++++-------- .../audio/Paraformer/mindie_encoder_decoder.py | 12 ++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md index c4fb83e671..ef40dfa10a 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md @@ -1,4 +1,4 @@ -# stable-diffusionxl-controlnet模型-推理指导 +# Paraformer模型-推理指导 - [概述](#概述) - [推理环境准备](#推理环境准备) @@ -65,13 +65,13 @@ 4. 获取模型文件 -将[Paraformer](https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files)的模型文件下载到本地,并保存在./model文件夹下 + 将[Paraformer](https://modelscope.cn/models/iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/files)的模型文件下载到本地,并保存在./model文件夹下 -将[vad](https://modelscope.cn/models/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch/files)的模型文件下载到本地,并保存在./model_vad文件夹下 + 将[vad](https://modelscope.cn/models/iic/speech_fsmn_vad_zh-cn-16k-common-pytorch/files)的模型文件下载到本地,并保存在./model_vad文件夹下 -将[punc](https://modelscope.cn/models/iic/punc_ct-transformer_cn-en-common-vocab471067-large/files)的模型文件下载到本地,并保存在./model_punc文件夹下 + 将[punc](https://modelscope.cn/models/iic/punc_ct-transformer_cn-en-common-vocab471067-large/files)的模型文件下载到本地,并保存在./model_punc文件夹下 -目录结构如下所示 + 目录结构如下所示 ``` Paraformer @@ -93,7 +93,8 @@ 5. 安装Funasr的依赖 ``` sudo apt install ffmpeg - pip install jieba + pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cpu + pip install jieba omegaconf kaldiio librosa tqdm hydra-core six attrs psutil tornado ``` ## 模型推理 @@ -102,9 +103,10 @@ export TORCH_AIE_NPU_CACHE_MAX_SIZE=12 ``` -2. (可选) 若CPU为aarch64的架构,则在编译encoder和decoder时会出现RuntimeError: could not create a primitive descriptor for a matmul primitive,此时需要新创建一个Python环境(推荐使用conda创建),而后使用如下命令安装torch 2.2.1 +2. (可选) 若CPU为aarch64的架构,则在编译encoder和decoder时会出现RuntimeError: could not create a primitive descriptor for a matmul primitive,此时需要新创建一个Python环境(推荐使用conda创建),而后使用如下命令安装torch 2.2.1及相关依赖 ``` pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cpu + pip install omegaconf kaldiio librosa tqdm hydra-core six ``` 而后,执行如下脚本将decoder序列化 @@ -123,7 +125,7 @@ 3. 
模型编译及样本测试 - 执行下述命令进行模型编译 + 执行下述命令进行模型编译(如编译后的模型保存于compiled_model目录下,需要首先mkdir compiled_model) ```bash python compile.py \ --model ./model \ diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py index 4138447068..25282d55ac 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py @@ -69,10 +69,10 @@ class MindieEncoder(Paraformer): sample_speech_lens = torch.tensor([100, 50, 100, 25], dtype=torch.int32) mrt_res = compiled_model(sample_speech.to("npu"), sample_speech_lens.to("npu")) print("mindie infer done !") - ref_res = export_model(sample_speech, sample_speech_lens) - print("torch infer done !") + # ref_res = export_model(sample_speech, sample_speech_lens) + # print("torch infer done !") - precision_eval(mrt_res, ref_res) + # precision_eval(mrt_res, ref_res) class MindieDecoder(Paraformer): @@ -143,7 +143,7 @@ class MindieDecoder(Paraformer): sample_sematic_lens = torch.tensor([50, 30, 50, 10], dtype=torch.int32) mrt_res = compiled_model(sample_encoder.to("npu"), sample_encoder_lens.to("npu"), sample_sematic.to("npu"), sample_sematic_lens.to("npu")) print("mindie infer done !") - ref_res = export_model(sample_encoder, sample_encoder_lens, sample_sematic, sample_sematic_lens) - print("torch infer done !") + # ref_res = export_model(sample_encoder, sample_encoder_lens, sample_sematic, sample_sematic_lens) + # print("torch infer done !") - precision_eval(mrt_res, ref_res) \ No newline at end of file + # precision_eval(mrt_res, ref_res) \ No newline at end of file -- Gitee From 12778d9e1ff8775c05de10469ffcfeae142d7e7b Mon Sep 17 00:00:00 2001 From: brjiang Date: Tue, 27 Aug 2024 09:25:57 +0800 Subject: [PATCH 07/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E7=BC=96=E8=AF=91?= =?UTF-8?q?=E5=8F=8A=E6=B5=8B=E8=AF=95=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/audio/Paraformer/README.md | 38 +++++++-- .../built-in/audio/Paraformer/compile.py | 68 +++------------ .../built-in/audio/Paraformer/test.py | 85 +++++++++++++++++++ 3 files changed, 127 insertions(+), 64 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md index ef40dfa10a..f5d21ad229 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md @@ -124,7 +124,7 @@ 该步骤获得序列化后的encoder和decoder模型,后续模型编译仍需回到原始环境 -3. 模型编译及样本测试 +3. 
模型编译 执行下述命令进行模型编译(如编译后的模型保存于compiled_model目录下,需要首先mkdir compiled_model) ```bash python compile.py \ @@ -137,10 +137,35 @@ --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \ --compiled_punc ./compiled_model/compiled_punc.ts \ --traced_encoder ./compiled_model/traced_encoder.pt \ - --traced_decoder ./compiled_model/traced_decoder.pt \ + --traced_decoder ./compiled_model/traced_decoder.pt + ``` + + 参数说明: + - --model:预训练模型路径 + - --model_vad:VAD预训练模型路径,若不使用VAD模型则设置为None + - --model_punc:PUNC预训练模型路径,若不使用PUNC模型则设置为None + - --compiled_encoder:编译后的encoder模型的保存路径 + - --compiled_decoder:编译后的decoder模型的保存路径 + - --compiled_cif:编译后的cif函数的保存路径 + - --compiled_cif_timestamp:编译后的cif_timestamp函数的保存路径 + - --compiled_punc:编译后的punc的保存路径 + - --traced_encoder:预先序列化的encoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 + - --traced_decoder:预先序列化的decoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 + +4. 样本测试 + 执行下述命令进行音频样本测试 + ```bash + python test.py \ + --model ./model \ + --model_vad ./model_vad \ + --model_punc ./model_punc \ + --compiled_encoder ./compiled_model/compiled_encoder.pt \ + --compiled_decoder ./compiled_model/compiled_decoder.pt \ + --compiled_cif ./compiled_model/compiled_cif.pt \ + --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \ + --compiled_punc ./compiled_model/compiled_punc.ts \ --batch_size 16 \ - --sample_path ./model/example/asr_example.wav \ - --skip_compile + --sample_path ./model/example ``` 参数说明: @@ -152,8 +177,5 @@ - --compiled_cif:编译后的cif函数的路径 - --compiled_cif_timestamp:编译后的cif_timestamp函数的路径 - --compiled_punc:编译后的punc的路径 - - --traced_encoder:预先序列化的encoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 - - --traced_decoder:预先序列化的decoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 - --batch_size:Paraformer模型所使用的batch_size - - --sample_path:用于测试模型的样本音频路径 - - --skip_compile:是否进行模型编译,若已经完成编译,后续使用可使用该参数跳过编译,若为第一次使用不能指定该参数 \ No newline at end of file + - --sample_path:测试音频的路径或所在的文件夹路径,若为文件夹路径则会遍历文件夹下的所有音频文件 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py index 3610458379..6776879ac5 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py @@ -24,21 +24,19 @@ from mindie_auto_model import MindieAutoModel if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--skip_compile", action="store_true", - help="whether to skip compiling sub-models in Paraformer") parser.add_argument("--model", default="./model", help="path of pretrained model") parser.add_argument("--model_vad", default="./model_vad", help="path of pretrained vad model") parser.add_argument("--model_punc", default="./model_punc", help="path of pretrained punc model") - parser.add_argument("--compiled_encoder", default="./compiled_model/compiled_encoder.ts", + parser.add_argument("--compiled_encoder", default="./compiled_model/compiled_encoder.pt", help="path to save compiled encoder") - parser.add_argument("--compiled_decoder", default="./compiled_model/compiled_decoder.ts", + parser.add_argument("--compiled_decoder", default="./compiled_model/compiled_decoder.pt", help="path to save compiled decoder") - parser.add_argument("--compiled_cif", default="./compiled_model/compiled_cif.ts", + parser.add_argument("--compiled_cif", default="./compiled_model/compiled_cif.pt", help="path to save compiled cif function") - parser.add_argument("--compiled_cif_timestamp", default="./compiled_model/compiled_cif_timestamp.ts", + 
parser.add_argument("--compiled_cif_timestamp", default="./compiled_model/compiled_cif_timestamp.pt", help="path to save compiled cif timestamp function") parser.add_argument("--compiled_punc", default="./compiled_model/compiled_punc.ts", help="path to save compiled punc model") @@ -46,58 +44,16 @@ if __name__ == "__main__": help="path to save traced encoder model") parser.add_argument("--traced_decoder", default=None, help="path to save traced decoder model") - parser.add_argument("--batch_size", default=16, type=int, - help="batch size of paraformer model") - parser.add_argument("--sample_path", default="./audio/test_1.wav", - help="path of sample audio") args = parser.parse_args() mindietorch.set_device(0) # use mindietorch to compile sub-models in Paraformer - if not args.skip_compile: - print("Begin compiling sub-models") - MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc, compile_type="punc") - MindieAutoModel.export_model(model=args.model, compiled_encoder=args.compiled_encoder, - compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif, - compiled_cif_timestamp=args.compiled_cif_timestamp, - traced_encoder=args.traced_encoder, traced_decoder=args.traced_decoder, - cif_interval=200, cif_timestamp_interval=500, compile_type="paraformer") - print("Finish compiling sub-models") - else: - print("Use existing compiled model") - - # initialize auto model - # note: ncpu means the number of threads used for intraop parallelism on the CPU, which is relevant to speed of vad model - model = MindieAutoModel(model=args.model, vad_model=args.model_vad, punc_model=args.model_punc, - compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder, - compiled_cif=args.compiled_cif, compiled_cif_timestamp=args.compiled_cif_timestamp, - compiled_punc=args.compiled_punc, paraformer_batch_size=args.batch_size, - cif_interval=200, cif_timestamp_interval=500, ncpu=16) - - # warm up - print("Begin warming up") - for i in range(5): - _ = model.generate(input=args.sample_path) - - # run with sample audio - iteration_num = 10 - print("Begin runing with sample audio with {} iterations".format(iteration_num)) - - total_dict_time = {} - for i in range(iteration_num): - res, time_stats = model.generate(input=args.sample_path) - - print("Iteration {} Model output: {}".format(i, res[0]["text"])) - print("Iteration {} Time comsumption:".format(i)) - print(" ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items())) - for key, value in time_stats.items(): - if key not in total_dict_time: - total_dict_time[key] = float(value) - else: - total_dict_time[key] += float(value) - - # display average time consumption - print("\nAverage time comsumption") - for key, value in total_dict_time.items(): - print(key, ": {:.3f}s".format(float(value) / iteration_num)) \ No newline at end of file + print("Begin compiling sub-models") + MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc, compile_type="punc") + MindieAutoModel.export_model(model=args.model, compiled_encoder=args.compiled_encoder, + compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif, + compiled_cif_timestamp=args.compiled_cif_timestamp, + traced_encoder=args.traced_encoder, traced_decoder=args.traced_decoder, + cif_interval=200, cif_timestamp_interval=500, compile_type="paraformer") + print("Finish compiling sub-models") \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py 
b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py new file mode 100644 index 0000000000..226a2add1c --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse + +import mindietorch + +from mindie_auto_model import MindieAutoModel + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="./model", + help="path of pretrained model") + parser.add_argument("--model_vad", default="./model_vad", + help="path of pretrained vad model") + parser.add_argument("--model_punc", default="./model_punc", + help="path of pretrained punc model") + parser.add_argument("--compiled_encoder", default="./compiled_model/compiled_encoder.pt", + help="path to save compiled encoder") + parser.add_argument("--compiled_decoder", default="./compiled_model/compiled_decoder.pt", + help="path to save compiled decoder") + parser.add_argument("--compiled_cif", default="./compiled_model/compiled_cif.pt", + help="path to save compiled cif function") + parser.add_argument("--compiled_cif_timestamp", default="./compiled_model/compiled_cif_timestamp.pt", + help="path to save compiled cif timestamp function") + parser.add_argument("--compiled_punc", default="./compiled_model/compiled_punc.ts", + help="path to save compiled punc model") + parser.add_argument("--batch_size", default=16, type=int, + help="batch size of paraformer model") + parser.add_argument("--sample_path", default="./audio/", + help="directory or path of sample audio") + args = parser.parse_args() + + mindietorch.set_device(0) + + valid_extensions = ['.wav'] + audio_files = [] + + if os.path.isfile(args.sample_path): + if any(args.sample_path.endswith(ext) for ext in valid_extensions): + audio_files.append(args.sample_path) + elif os.path.isdir(args.sample_path): + for root, dirs, files in os.walk(args.sample_path): + for file in files: + if any(file.endswith(ext) for ext in valid_extensions): + audio_files.append(os.path.join(root, file)) + + if len(audio_files) == 0: + print("There is no valid wav file in sample_dir.") + else: + # initialize auto model + model = MindieAutoModel(model=args.model, vad_model=args.model_vad, punc_model=args.model_punc, + compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder, + compiled_cif=args.compiled_cif, compiled_cif_timestamp=args.compiled_cif_timestamp, + compiled_punc=args.compiled_punc, paraformer_batch_size=args.batch_size, + cif_interval=200, cif_timestamp_interval=500) + + # warm up + print("Begin warming up") + _ = model.generate(input=audio_files[0]) + print("Finish warming up") + + # iterate over sample_dir + for wav_file in audio_files: + print("Begin evaluating {}".format(wav_file)) + + res, time_stats = model.generate(input=wav_file) + print("Model output: {}".format(res[0]["text"])) + print("Time comsumption:") + 
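+            # time_stats holds the per-stage timings collected by
+            # MindieAutoModel.generate: load_data, encoder, predictor, decoder,
+            # predictor_timestamp, post_process, plus the end-to-end totals.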
print(" ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items())) \ No newline at end of file -- Gitee From d925caa36d1f8e3ecb0ee5e40093278326cfc2c3 Mon Sep 17 00:00:00 2001 From: brjiang Date: Tue, 3 Sep 2024 10:07:14 +0800 Subject: [PATCH 08/16] =?UTF-8?q?=E6=96=B0=E5=A2=9Etest=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/audio/Paraformer/README.md | 63 +++++++++++++----- .../audio/Paraformer/mindie_auto_model.py | 13 +++- .../Paraformer/mindie_encoder_decoder.py | 10 ++- .../audio/Paraformer/mindie_paraformer.py | 2 - .../audio/Paraformer/test_accuracy.py | 65 +++++++++++++++++++ .../{test.py => test_performance.py} | 37 ++++++----- 6 files changed, 151 insertions(+), 39 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py rename MindIE/MindIE-Torch/built-in/audio/Paraformer/{test.py => test_performance.py} (71%) diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md index f5d21ad229..993740636d 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md @@ -55,11 +55,10 @@ cd .. ``` -3. 修改Funasr的源码,先将patch文件移动至Funasr的工程路径下,而后将patch应用到代码中(若patch应用失败,则需要手动进行修改) +3. 修改Funasr的源码,将patch应用到代码中(若patch应用失败,则需要手动进行修改) ``` - mv mindie.patch ./FunASR cd ./FunASR - git apply mindie.patch --reject + git apply ../mindie.patch --ignore-whitespace cd .. ``` @@ -97,6 +96,9 @@ pip install jieba omegaconf kaldiio librosa tqdm hydra-core six attrs psutil tornado ``` +6. 下载测试数据集[AISHELL-1](https://www.aishelltech.com/kysjcp)并保存于任意路径 + + ## 模型推理 1. 设置mindie内存池上限为12G,执行如下命令设置环境变量 ``` @@ -125,6 +127,7 @@ 3. 模型编译 + 执行下述命令进行模型编译(如编译后的模型保存于compiled_model目录下,需要首先mkdir compiled_model) ```bash python compile.py \ @@ -152,30 +155,58 @@ - --traced_encoder:预先序列化的encoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 - --traced_decoder:预先序列化的decoder模型的路径,若并未执行第2步提前编译模型,则无需指定该参数 -4. 样本测试 - 执行下述命令进行音频样本测试 - ```bash - python test.py \ +4. 性能测试 + + 执行下述命令对于Paraformer进行性能测试,同时将推理结果保存于result_path中 + ``` + python test_performance.py \ --model ./model \ - --model_vad ./model_vad \ - --model_punc ./model_punc \ --compiled_encoder ./compiled_model/compiled_encoder.pt \ --compiled_decoder ./compiled_model/compiled_decoder.pt \ --compiled_cif ./compiled_model/compiled_cif.pt \ --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \ - --compiled_punc ./compiled_model/compiled_punc.ts \ - --batch_size 16 \ - --sample_path ./model/example + --batch_size 64 \ + --sample_path /path/to/AISHELL-1 + --result_path ./mindie_result.txt ``` 参数说明: - --model:预训练模型路径 - - --model_vad:VAD预训练模型路径,若不使用VAD模型则设置为None - - --model_punc:PUNC预训练模型路径,若不使用PUNC模型则设置为None - --compiled_encoder:编译后的encoder模型的路径 - --compiled_decoder:编译后的decoder模型的路径 - --compiled_cif:编译后的cif函数的路径 - --compiled_cif_timestamp:编译后的cif_timestamp函数的路径 - - --compiled_punc:编译后的punc的路径 - --batch_size:Paraformer模型所使用的batch_size - - --sample_path:测试音频的路径或所在的文件夹路径,若为文件夹路径则会遍历文件夹下的所有音频文件 \ No newline at end of file + - --sample_path:AISHELL-1测试音频所在路径 + - --result_path:测试音频的推理结果的保存路径 + +5. 
精度测试 + + 首先下载精度验证预训练模型[bert-base-chinese](https://huggingface.co/ckiplab/bert-base-chinese-ner/tree/main),并将预训练模型中的文件保存于任意路径 + + 而后,利用如下命令安装中文文本精度对比库bert-score + ``` + pip install bert-score + ``` + + 最后执行如下脚本进行精度验证 + ``` + python test_accuracy.py \ + --result_path ./mindie_result.txt \ + --ref_path /path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt + --bert_path /path/to/bert-base-chinese + ``` + + 参数说明: + - --result_path:测试音频的推理结果的保存路径 + - --ref_path:AISHELL-1测试音频的Ground Truth所在路径 + - --bert_path: 预训练bert-base-chinese模型的路径 + +## 模型精度及性能 + +该模型在310P3和910B4上的参考性能及精度如下所示 + +| NPU | batch_size | rtf_avg | precision | recall | f1-score | +|--------------|------------|---------|-----------|--------|----------| +| Ascend 310P3 | 16 | 217.175 | 0.9906 | 0.9907 | 0.9907 | +| Ascend 910B4 | 64 | 392.746 | 0.9906 | 0.9907 | 0.9907 | \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py index b70627eb97..ffa8fc9e86 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py @@ -4,6 +4,7 @@ sys.path.append("./FunASR") import torch import time import logging +from tqdm import tqdm from mindie_paraformer import MindieBiCifParaformer from mindie_encoder_decoder import MindieEncoder, MindieDecoder @@ -231,12 +232,14 @@ class MindieAutoModel(AutoModel): input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key ) - time_stats = {"input_speech_time": 0.0, "end_to_end_time": 0.0, "pure_infer_time": 0.0, + time_stats = {"rtf_avg": 0.0, "input_speech_time": 0.0, "end_to_end_time": 0.0, "pure_infer_time": 0.0, "load_data": 0.0, "encoder": 0.0, "predictor": 0.0, "decoder": 0.0, "predictor_timestamp": 0.0, "post_process": 0.0} asr_result_list = [] num_samples = len(data_list) + pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True) + for beg_idx in range(0, num_samples, batch_size): end_idx = min(num_samples, beg_idx + batch_size) data_batch = data_list[beg_idx:end_idx] @@ -270,7 +273,11 @@ class MindieAutoModel(AutoModel): time_stats["end_to_end_time"] += time_escape time_stats["input_speech_time"] += batch_data_time - - time_stats["pure_infer_time"] = time_stats["end_to_end_time"] - time_stats["load_data"] + + time_stats["pure_infer_time"] = time_stats["end_to_end_time"] - time_stats["load_data"] + time_stats["rtf_avg"] = time_stats["input_speech_time"] / time_stats["pure_infer_time"] + + pbar.update(batch_size) + pbar.set_description("rtf_avg:{:.3f}".format(time_stats["rtf_avg"])) return asr_result_list, time_stats \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py index 25282d55ac..91e5c53e7d 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py @@ -7,7 +7,15 @@ import mindietorch from mindie_paraformer import precision_eval from funasr.models.bicif_paraformer.model import Paraformer -from funasr.models.transformer.utils.nets_utils import make_pad_mask_new + + +def make_pad_mask_new(lengths): + maxlen = lengths.max() + row_vector = torch.arange(0, maxlen, 1).to(lengths.device) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector >= matrix + mask = mask.detach() + return mask class MindieEncoder(Paraformer): 
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py index 73876484f0..b47cad5b2c 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py @@ -8,7 +8,6 @@ import torch.nn.functional as F from funasr.models.bicif_paraformer.model import BiCifParaformer, load_audio_text_image_video, \ extract_fbank, Hypothesis, ts_prediction_lfr6_standard, postprocess_utils -from funasr.models.transformer.utils.nets_utils import make_pad_mask_new COSINE_THRESHOLD = 0.999 @@ -82,7 +81,6 @@ class MindieBiCifParaformer(BiCifParaformer): speech, speech_lengths = extract_fbank( audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend ) - print("Input shape: ", speech.shape) meta_data["batch_data_time"] = ( speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 ) diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py new file mode 100644 index 0000000000..8db1b5ce83 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import argparse
+
+import torch
+
+from bert_score import score
+
+
+def load_txt(file_name):
+    result = {}
+
+    with open(file_name, "r") as file:
+        for line in file:
+            parts = line.strip().split(maxsplit=1)
+
+            if len(parts) == 2:
+                result[parts[0]] = parts[1]
+
+    return result
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--result_path", default="./mindie_result.txt", type=str,
+                        help="directory or path of sample audio")
+    parser.add_argument("--ref_path", default="/path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt",
+                        type=str, help="path of ground truth transcript")
+    parser.add_argument("--bert_path", default="/path/to/bert-base-chinese")
+    args = parser.parse_args()
+
+    infer_result = load_txt(args.result_path)
+    ref_result = load_txt(args.ref_path)
+
+    infer_list = []
+    refer_list = []
+    for key, value in infer_result.items():
+        if key in ref_result:
+            infer_list.append(value)
+            refer_list.append(ref_result[key])
+
+    P, R, F1 = score(infer_list, refer_list, model_type=args.bert_path, lang="zh", verbose=True, num_layers=8)
+
+    precision = torch.mean(P)
+    recall = torch.mean(R)
+    f1_score = torch.mean(F1)
+
+    print("Precision: ", precision.item())
+    print("Recall: ", recall.item())
+    print("F1 Score: ", f1_score.item())
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
similarity index 71%
rename from MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py
rename to MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
index 226a2add1c..6da7ee9542 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
@@ -26,10 +26,6 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", default="./model",
                         help="path of pretrained model")
-    parser.add_argument("--model_vad", default="./model_vad",
-                        help="path of pretrained vad model")
-    parser.add_argument("--model_punc", default="./model_punc",
-                        help="path of pretrained punc model")
     parser.add_argument("--compiled_encoder", default="./compiled_model/compiled_encoder.pt",
                         help="path to save compiled encoder")
     parser.add_argument("--compiled_decoder", default="./compiled_model/compiled_decoder.pt",
@@ -38,12 +34,12 @@
                         help="path to save compiled cif function")
     parser.add_argument("--compiled_cif_timestamp", default="./compiled_model/compiled_cif_timestamp.pt",
                         help="path to save compiled cif timestamp function")
-    parser.add_argument("--compiled_punc", default="./compiled_model/compiled_punc.ts",
-                        help="path to save compiled punc model")
-    parser.add_argument("--batch_size", default=16, type=int,
+    parser.add_argument("--batch_size", default=64, type=int,
                         help="batch size of paraformer model")
-    parser.add_argument("--sample_path", default="./audio/",
+    parser.add_argument("--sample_path", default="./audio/", type=str,
                         help="directory or path of sample audio")
+    parser.add_argument("--result_path", default="./mindie_result.txt", type=str,
+                        help="path to save infer result")
     args = parser.parse_args()
 
     mindietorch.set_device(0)
@@ -60,26 +56,33 @@
                 if any(file.endswith(ext) for ext in valid_extensions):
                     audio_files.append(os.path.join(root, file))
 
+    # filter out wav files which are smaller than 1KB
+    audio_files = [file for file in audio_files if os.path.getsize(file) >= 1024]
+
     if len(audio_files) == 0:
         print("There is no valid wav file in sample_dir.")
     else:
         # initialize auto model
-        model = MindieAutoModel(model=args.model, vad_model=args.model_vad, punc_model=args.model_punc,
+        model = MindieAutoModel(model=args.model,
                                 compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder,
                                 compiled_cif=args.compiled_cif, compiled_cif_timestamp=args.compiled_cif_timestamp,
-                                compiled_punc=args.compiled_punc, paraformer_batch_size=args.batch_size,
+                                batch_size=args.batch_size,
                                 cif_interval=200, cif_timestamp_interval=500)
 
         # warm up
         print("Begin warming up")
-        _ = model.generate(input=audio_files[0])
+        for i in range(3):
+            _ = model.inference_with_asr(input=audio_files[0])
         print("Finish warming up")
 
         # iterate over sample_dir
-        for wav_file in audio_files:
-            print("Begin evaluating {}".format(wav_file))
+        print("Begin evaluating")
+
+        results, time_stats = model.inference_with_asr(input=audio_files)
+        print("Average RTF: {:.3f}".format(time_stats["rtf_avg"]))
+        print("Time consumption:")
+        print(" ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items() if key != "rtf_avg"))
 
-            res, time_stats = model.generate(input=wav_file)
-            print("Model output: {}".format(res[0]["text"]))
-            print("Time consumption:")
-            print(" ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items()))
\ No newline at end of file
+        with open(args.result_path, "w") as f:
+            for res in results:
+                f.write("{} {}\n".format(res["key"], res["text"]))
\ No newline at end of file
--
Gitee

From 8224dc0149adc0736928cef4279934b5d0c85285 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Thu, 5 Sep 2024 11:01:22 +0800
Subject: [PATCH 09/16] Add soc_version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Paraformer/README.md            |  4 +++-
 .../built-in/audio/Paraformer/compile.py           | 14 +++++++++-----
 .../built-in/audio/Paraformer/mindie_auto_model.py | 10 +++++-----
 .../built-in/audio/Paraformer/mindie_cif.py        |  8 ++++----
 .../audio/Paraformer/mindie_encoder_decoder.py     |  8 ++++----
 .../built-in/audio/Paraformer/mindie_punc.py       |  4 ++--
 6 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index 993740636d..72b6bf708b 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -140,7 +140,8 @@
     --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \
     --compiled_punc ./compiled_model/compiled_punc.ts \
     --traced_encoder ./compiled_model/traced_encoder.pt \
-    --traced_decoder ./compiled_model/traced_decoder.pt
+    --traced_decoder ./compiled_model/traced_decoder.pt \
+    --soc_version Ascend310P3
     ```
 
     Argument description:
@@ -154,6 +155,7 @@
    - --compiled_punc: path to save the compiled punc model
    - --traced_encoder: path of the pre-serialized encoder model; omit it if you did not trace the models ahead of time in step 2
    - --traced_decoder: path of the pre-serialized decoder model; omit it if you did not trace the models ahead of time in step 2
+   - --soc_version: model of the Ascend chip
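For orientation, every export_ts call in this patch follows the same pattern: trace the sub-model with TorchScript, then hand the trace to mindietorch.compile with the chosen --soc_version. A minimal sketch under stated assumptions: TinyNet and its shape are placeholders, and mindietorch.Input is assumed here as the way to declare input shapes (the real scripts build compile_inputs per sub-model); saving the compiled module with torch.jit.save is likewise an assumption.

```python
import torch
import mindietorch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0

# Trace on CPU, then compile for the target SoC.
traced = torch.jit.trace(TinyNet(), torch.randn(1, 16))
compiled = mindietorch.compile(
    traced,
    inputs=[mindietorch.Input((1, 16))],  # assumed Input API
    precision_policy=mindietorch.PrecisionPolicy.PREF_FP16,
    soc_version="Ascend310P3",            # match --soc_version
    ir="ts",
)
torch.jit.save(compiled, "./compiled_model/tiny.ts")  # assumed saveable as TorchScript
```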
 
 4. Performance test
 
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
index 6776879ac5..4a6a3529e2 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/compile.py
@@ -44,16 +44,20 @@
                         help="path to save traced encoder model")
     parser.add_argument("--traced_decoder", default=None,
                         help="path to save traced decoder model")
+    parser.add_argument("--soc_version", default="Ascend310P3", type=str,
+                        help="soc version of Ascend")
     args = parser.parse_args()
 
     mindietorch.set_device(0)
 
     # use mindietorch to compile sub-models in Paraformer
     print("Begin compiling sub-models")
-    MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc, compile_type="punc")
+    MindieAutoModel.export_model(model=args.model_punc, compiled_path=args.compiled_punc,
+                                 compile_type="punc", soc_version=args.soc_version)
     MindieAutoModel.export_model(model=args.model, compiled_encoder=args.compiled_encoder,
-                                compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif,
-                                compiled_cif_timestamp=args.compiled_cif_timestamp,
-                                traced_encoder=args.traced_encoder, traced_decoder=args.traced_decoder,
-                                cif_interval=200, cif_timestamp_interval=500, compile_type="paraformer")
+                                 compiled_decoder=args.compiled_decoder, compiled_cif=args.compiled_cif,
+                                 compiled_cif_timestamp=args.compiled_cif_timestamp,
+                                 traced_encoder=args.traced_encoder, traced_decoder=args.traced_decoder,
+                                 cif_interval=200, cif_timestamp_interval=500,
+                                 compile_type="paraformer", soc_version=args.soc_version)
     print("Finish compiling sub-models")
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
index ffa8fc9e86..fcfb347302 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_auto_model.py
@@ -121,7 +121,7 @@ class MindieAutoModel(AutoModel):
                 scope_map=kwargs.get("scope_map", []),
                 excludes=kwargs.get("excludes", None),
             )
-            MindiePunc.export_ts(model, kwargs["compiled_path"])
+            MindiePunc.export_ts(model, kwargs["compiled_path"], kwargs["soc_version"])
         else:
             # compile encoder
             encoder = MindieEncoder(**model_conf, vocab_size=vocab_size)
@@ -135,7 +135,7 @@ class MindieAutoModel(AutoModel):
                 scope_map=kwargs.get("scope_map", []),
                 excludes=kwargs.get("excludes", None),
             )
-            MindieEncoder.export_ts(encoder, kwargs["compiled_encoder"], kwargs["traced_encoder"])
+            MindieEncoder.export_ts(encoder, kwargs["compiled_encoder"], kwargs["traced_encoder"], kwargs["soc_version"])
 
             # compile decoder
             decoder = MindieDecoder(**model_conf, vocab_size=vocab_size)
@@ -149,15 +149,15 @@ class MindieAutoModel(AutoModel):
                 scope_map=kwargs.get("scope_map", []),
                 excludes=kwargs.get("excludes", None),
             )
-            MindieDecoder.export_ts(decoder, kwargs["compiled_decoder"], kwargs["traced_decoder"])
+            MindieDecoder.export_ts(decoder, kwargs["compiled_decoder"], kwargs["traced_decoder"], kwargs["soc_version"])
 
             # compile cif
             mindie_cif = MindieCif(decoder.predictor.threshold, kwargs["cif_interval"])
-            mindie_cif.export_ts(kwargs["compiled_cif"])
+            mindie_cif.export_ts(kwargs["compiled_cif"], kwargs["soc_version"])
 
             # compile cif_timestamp
             mindie_cif_timestamp = MindieCifTimestamp(decoder.predictor.threshold - 1e-4,
                                                       kwargs["cif_timestamp_interval"])
-            mindie_cif_timestamp.export_ts(kwargs["compiled_cif_timestamp"])
+            mindie_cif_timestamp.export_ts(kwargs["compiled_cif_timestamp"], kwargs["soc_version"])
 
     def build_model_with_mindie(self, **kwargs):
         assert "model" in kwargs
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
index 309dc4a93a..39b5fb3e8b 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_cif.py
@@ -84,7 +84,7 @@ class MindieCif(torch.nn.Module):
 
         return frame, integrate_new, frame_new
 
-    def export_ts(self, path="./compiled_cif.ts"):
+    def export_ts(self, path="./compiled_cif.ts", soc_version="Ascend310P3"):
         print("Begin trace cif!")
 
         input_shape1 = (1, self.seq_len, 512)
@@ -108,7 +108,7 @@ class MindieCif(torch.nn.Module):
             export_model,
             inputs = compile_inputs,
             precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
-            soc_version = "Ascend310P3",
+            soc_version = soc_version,
             ir = "ts"
         )
         print("mindietorch compile done !")
@@ -140,7 +140,7 @@ class MindieCifTimestamp(torch.nn.Module):
 
         return us_peaks, integrate_new
 
-    def export_ts(self, path="./compiled_cif_timestamp.ts"):
+    def export_ts(self, path="./compiled_cif_timestamp.ts", soc_version="Ascend310P3"):
         print("Begin trace cif_timestamp!")
 
         input_shape1 = (1, self.seq_len)
@@ -158,7 +158,7 @@ class MindieCifTimestamp(torch.nn.Module):
             export_model,
             inputs = compile_inputs,
             precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
-            soc_version = "Ascend310P3",
+            soc_version = soc_version,
             ir = "ts"
         )
         print("mindietorch compile done !")
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
index 91e5c53e7d..0457625208 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_encoder_decoder.py
@@ -41,7 +41,7 @@ class MindieEncoder(Paraformer):
         return encoder_out, encoder_out_lens, hidden, alphas, pre_token_length
 
     @staticmethod
-    def export_ts(encoder, path="./compiled_encoder.ts", traced_path="./traced_encoder.ts"):
+    def export_ts(encoder, path="./compiled_encoder.ts", traced_path="./traced_encoder.ts", soc_version="Ascend310P3"):
         print("Begin trace encoder!")
 
         input_shape = (2, 50, 560)
@@ -65,7 +65,7 @@ class MindieEncoder(Paraformer):
             export_model,
             inputs = compile_inputs,
             precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
-            soc_version = "Ascend310P3",
+            soc_version = soc_version,
             ir = "ts"
         )
         print("mindietorch compile done !")
@@ -105,7 +105,7 @@ class MindieDecoder(Paraformer):
 
         return decoder_out, us_alphas
 
     @staticmethod
-    def export_ts(decoder, path="./compiled_decoder.ts", traced_path="./traced_decoder.ts"):
+    def export_ts(decoder, path="./compiled_decoder.ts", traced_path="./traced_decoder.ts", soc_version="Ascend310P3"):
         print("Begin trace decoder!")
 
         input_shape1 = (2, 939, 512)
@@ -137,7 +137,7 @@ class MindieDecoder(Paraformer):
             export_model,
             inputs = compile_inputs,
             precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
-            soc_version = "Ascend310P3",
+            soc_version = soc_version,
             ir = "ts"
         )
         print("mindietorch compile done !")
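The compiled artifacts produced by these export_ts methods are consumed later by MindieAutoModel; roughly, a compiled module behaves like a TorchScript module fed NPU tensors. A sketch for the cif module, with loudly labeled assumptions: loading via torch.jit.load and the initial state shapes are guesses, while the input shapes follow the export shapes above with seq_len = cif_interval = 200.

```python
import torch
import mindietorch

mindietorch.set_device(0)

# Assumption: a module compiled with ir="ts" can be reloaded as TorchScript.
compiled_cif = torch.jit.load("./compiled_model/compiled_cif.pt")

hidden = torch.randn(1, 200, 512).to("npu")  # (1, seq_len, 512), as traced
alphas = torch.rand(1, 200).to("npu")        # (1, seq_len), as traced
integrate = torch.zeros(1).to("npu")         # assumed initial state shape
frame = torch.zeros(1, 512).to("npu")        # assumed initial state shape

# Matches the call pattern used in mindie_paraformer.py.
frames, integrate, frame = compiled_cif(hidden, alphas, integrate, frame)
```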
b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_punc.py
@@ -26,7 +26,7 @@ class MindiePunc(CTTransformer):
         return punctuations
 
     @staticmethod
-    def export_ts(punc, path="./compiled_punc.ts"):
+    def export_ts(punc, path="./compiled_punc.ts", soc_version="Ascend310P3"):
         print("Begin trace punc!")
 
         input_shape = (1, 20)
@@ -44,7 +44,7 @@ class MindiePunc(CTTransformer):
             export_model,
             inputs = compile_inputs,
             precision_policy = mindietorch.PrecisionPolicy.PREF_FP16,
-            soc_version = "Ascend310P3",
+            soc_version = soc_version,
             ir = "ts"
         )
         print("mindietorch compile done !")
--
Gitee

From d287277f02b91ff85410ec90f4ea88e965e334c8 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Thu, 5 Sep 2024 16:37:33 +0800
Subject: [PATCH 10/16] Update readme
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index 72b6bf708b..f99d758b54 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -111,7 +111,7 @@
     pip install omegaconf kaldiio librosa tqdm hydra-core six
     ```
 
-    Then run the following script to serialize the decoder
+    Then run the following script to serialize the decoder (if the serialized models are to be saved under the compiled_model directory, run mkdir compiled_model first)
     ```bash
     python trace_decoder.py \
     --model ./model \
@@ -123,7 +123,7 @@
    - --traced_encoder: path to save the serialized encoder
    - --traced_decoder: path to save the serialized decoder
 
-    This step produces the serialized encoder and decoder models; subsequent model compilation must still be done back in the original environment
+    This step produces the serialized encoder and decoder models; subsequent model compilation must still be done back in the pytorch 2.1.0 environment
--
Gitee

From 3c22e44be02a428481269ec19795d9ad1933047a Mon Sep 17 00:00:00 2001
From: brjiang
Date: Fri, 6 Sep 2024 11:19:04 +0800
Subject: [PATCH 11/16] Revise the accuracy test scheme
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Paraformer/README.md       | 38 ++++----
 .../audio/Paraformer/test_accuracy.py         | 87 +++++++++++++++----
 .../audio/Paraformer/test_performance.py      | 12 +--
 3 files changed, 95 insertions(+), 42 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index f99d758b54..c4c7e9d678 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -159,7 +159,7 @@
 4. Performance test
 
-    Run the following command to benchmark Paraformer; the inference results are also saved to result_path
+    Run the following command to benchmark Paraformer
     ```
     python test_performance.py \
     --model ./model \
@@ -169,7 +169,6 @@
     --compiled_cif ./compiled_model/compiled_cif.pt \
     --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \
     --batch_size 64 \
     --sample_path /path/to/AISHELL-1
-    --result_path ./mindie_result.txt
     ```
 
     Argument description:
@@ -179,7 +178,7 @@
    - --compiled_cif: path of the compiled cif function
    - --compiled_cif_timestamp: path of the compiled cif_timestamp function
    - --batch_size: batch size used by the Paraformer model
-   - --sample_path: path of the AISHELL-1 test audio
-   - --result_path: path to save the inference results of the test audio
+   - --sample_path: path of the audio to test; the model recursively searches this path for all audio files
 
 5. Accuracy test
-
-    First download the pretrained accuracy-validation model [bert-base-chinese](https://huggingface.co/ckiplab/bert-base-chinese-ner/tree/main) and save its files to any path
 
-    Then install the Chinese text accuracy comparison library bert-score with the following command
+    Install the text accuracy comparison library nltk with the following command
     ```
-    pip install bert-score
+    pip install nltk
     ```
 
     Finally, run the following script to validate accuracy
     ```
     python test_accuracy.py \
-    --result_path ./mindie_result.txt \
+    --model ./model \
+    --compiled_encoder ./compiled_model/compiled_encoder.pt \
+    --compiled_decoder ./compiled_model/compiled_decoder.pt \
+    --compiled_cif ./compiled_model/compiled_cif.pt \
+    --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \
+    --batch_size 64 \
+    --sample_path /path/to/AISHELL-1/wav/test \
+    --result_path ./aishell_test_result.txt \
     --ref_path /path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt
-    --bert_path /path/to/bert-base-chinese
     ```
 
     Argument description:
+   - --model: path of the pretrained model
+   - --compiled_encoder: path of the compiled encoder model
+   - --compiled_decoder: path of the compiled decoder model
+   - --compiled_cif: path of the compiled cif function
+   - --compiled_cif_timestamp: path of the compiled cif_timestamp function
+   - --batch_size: batch size used by the Paraformer model
+   - --sample_path: path of the AISHELL-1 test-set audio
    - --result_path: path to save the inference results of the test audio
    - --ref_path: path of the ground truth of the AISHELL-1 test audio
-   - --bert_path: path of the pretrained bert-base-chinese model
 
 ## Model Accuracy and Performance
 
 Reference performance and accuracy of this model on the 310P3 and 910B4 are shown below
 
-| NPU          | batch_size | rtf_avg | precision | recall | f1-score |
-|--------------|------------|---------|-----------|--------|----------|
-| Ascend 310P3 | 16         | 217.175 | 0.9906    | 0.9907 | 0.9907   |
-| Ascend 910B4 | 64         | 392.746 | 0.9906    | 0.9907 | 0.9907   |
\ No newline at end of file
+| NPU          | batch_size | rtx_avg | cer     | accuracy |
+|--------------|------------|---------|---------|----------|
+| Ascend 310P3 | 16         | 217.175 | 0.0198  | 0.9822   |
+| Ascend 910B4 | 64         | 392.746 | 0.0198  | 0.9822   |
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py
index 8db1b5ce83..a93996be58 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py
@@ -16,10 +16,13 @@
 
 import os
 import argparse
+from tqdm import tqdm
 
 import torch
+import mindietorch
 
-from bert_score import score
+from mindie_auto_model import MindieAutoModel
+from nltk.metrics.distance import edit_distance
 
 
 def load_txt(file_name):
@@ -37,29 +40,77 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--result_path", default="./mindie_result.txt", type=str,
+    parser.add_argument("--model", default="./model",
+                        help="path of pretrained model")
+    parser.add_argument("--compiled_encoder", default="./compiled_model/compiled_encoder.pt",
+                        help="path to save compiled encoder")
+    parser.add_argument("--compiled_decoder", default="./compiled_model/compiled_decoder.pt",
+                        help="path to save compiled decoder")
+    parser.add_argument("--compiled_cif", default="./compiled_model/compiled_cif.pt",
+                        help="path to save compiled cif function")
+    parser.add_argument("--compiled_cif_timestamp", default="./compiled_model/compiled_cif_timestamp.pt",
+                        help="path to save compiled cif timestamp function")
+    parser.add_argument("--batch_size", default=64, type=int,
+                        help="batch size of paraformer model")
+    parser.add_argument("--sample_path", default="/path/to/AISHELL-1/wav/test", type=str,
                         help="directory or path of sample audio")
+    parser.add_argument("--result_path", default="./aishell_test_result.txt", type=str,
+                        help="path to save infer result")
     parser.add_argument("--ref_path", default="/path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt",
default="/path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt", type=str, help="directory or path of sample audio") - parser.add_argument("--bert_path", default="/path/to/bert-base-chinese") args = parser.parse_args() - infer_result = load_txt(args.result_path) - ref_result = load_txt(args.ref_path) + valid_extensions = ['.wav'] + audio_files = [] - infer_list = [] - refer_list = [] - for key, value in infer_result.items(): - if key in ref_result: - infer_list.append(value) - refer_list.append(ref_result[key]) + if os.path.isfile(args.sample_path): + if any(args.sample_path.endswith(ext) for ext in valid_extensions): + audio_files.append(args.sample_path) + elif os.path.isdir(args.sample_path): + for root, dirs, files in os.walk(args.sample_path): + for file in files: + if any(file.endswith(ext) for ext in valid_extensions): + audio_files.append(os.path.join(root, file)) - P, R, F1 = score(infer_list, refer_list, model_type=args.bert_path, lang="zh", verbose=True, num_layers=8) + # filter out wav files which is smaller than 1KB + audio_files = [file for file in audio_files if os.path.getsize(file) >= 1024] - precision = torch.mean(P) - recall = torch.mean(R) - f1_score = torch.mean(F1) + if len(audio_files) == 0: + print("There is no valid wav file in sample_dir.") + else: + # initialize auto model + model = MindieAutoModel(model=args.model, + compiled_encoder=args.compiled_encoder, compiled_decoder=args.compiled_decoder, + compiled_cif=args.compiled_cif, compiled_cif_timestamp=args.compiled_cif_timestamp, + batch_size=args.batch_size, + cif_interval=200, cif_timestamp_interval=500) - print("Precision: ", precision.item()) - print("Recall: ", recall.item()) - print("F1 Score: ", f1_score.item()) \ No newline at end of file + # iterate over sample_dir + print("Begin evaluating") + results, time_stats = model.inference_with_asr(input=audio_files) + + with open(args.result_path, "w") as f: + for res in results: + f.write("{} {}\n".format(res["key"], res["text"])) + + infer_result = load_txt(args.result_path) + ref_result = load_txt(args.ref_path) + + infer_list = [] + refer_list = [] + for key, value in infer_result.items(): + if key in ref_result: + infer_list.append(value.replace(" ", "")) + refer_list.append(ref_result[key].replace(" ", "")) + + cer_total = 0 + step = 0 + for infer, refer in tqdm(zip(infer_list, refer_list)): + infer = [i for i in infer] + refer = [r for r in refer] + cer_total += edit_distance(infer, refer) / len(refer) + step += 1 + + cer = cer_total / step + accuracy = 1 - cer + print("character-errer-rate: {:.4f}, accuracy: {:.4f}".format(cer, accuracy)) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py index 6da7ee9542..ba97dc1db8 100644 --- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py +++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py @@ -36,10 +36,8 @@ if __name__ == "__main__": help="path to save compiled cif timestamp function") parser.add_argument("--batch_size", default=64, type=int, help="batch size of paraformer model") - parser.add_argument("--sample_path", default="./audio/", type=str, + parser.add_argument("--sample_path", default="/path/to/AISHELL-1", type=str, help="directory or path of sample audio") - parser.add_argument("--result_path", default="./mindie_result.txt", type=str, - help="path to save infer result") args = parser.parse_args() mindietorch.set_device(0) @@ -79,10 +77,6 
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
index 6da7ee9542..ba97dc1db8 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
@@ -36,10 +36,8 @@
                         help="path to save compiled cif timestamp function")
     parser.add_argument("--batch_size", default=64, type=int,
                         help="batch size of paraformer model")
-    parser.add_argument("--sample_path", default="./audio/", type=str,
+    parser.add_argument("--sample_path", default="/path/to/AISHELL-1", type=str,
                         help="directory or path of sample audio")
-    parser.add_argument("--result_path", default="./mindie_result.txt", type=str,
-                        help="path to save infer result")
     args = parser.parse_args()
 
     mindietorch.set_device(0)
@@ -79,10 +77,6 @@ if __name__ == "__main__":
         print("Begin evaluating")
 
         results, time_stats = model.inference_with_asr(input=audio_files)
-        print("Average RTF: {:.3f}".format(time_stats["rtf_avg"]))
+        print("Average RTX: {:.3f}".format(time_stats["rtf_avg"]))
         print("Time consumption:")
-        print(" ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items() if key != "rtf_avg"))
-
-        with open(args.result_path, "w") as f:
-            for res in results:
-                f.write("{} {}\n".format(res["key"], res["text"]))
\ No newline at end of file
+        print(" ".join(f"{key}: {value:.3f}s" for key, value in time_stats.items() if key != "rtf_avg"))
\ No newline at end of file
--
Gitee

From b9661347353df0000c7bc02e443969b66a5502b5 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Fri, 6 Sep 2024 16:23:20 +0800
Subject: [PATCH 12/16] Add PTA accuracy comparison method
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Paraformer/README.md       |  40 +++++--
 .../audio/Paraformer/test_accuracy_pta.py     | 111 ++++++++++++++++++
 2 files changed, 144 insertions(+), 7 deletions(-)
 create mode 100644 MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index c4c7e9d678..d6007ea81d 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -168,7 +168,7 @@
     --compiled_cif ./compiled_model/compiled_cif.pt \
     --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \
     --batch_size 64 \
-    --sample_path /path/to/AISHELL-1
+    --sample_path /path/to/AISHELL-1/wav/test
     ```
 
     Argument description:
@@ -178,7 +178,7 @@
    - --compiled_cif: path of the compiled cif function
    - --compiled_cif_timestamp: path of the compiled cif_timestamp function
    - --batch_size: batch size used by the Paraformer model
-   - --sample_path: path of the audio to test; the model recursively searches this path for all audio files
+   - --sample_path: path of the AISHELL-1 test-set audio; the model recursively searches this path for all audio files
 
 5. Accuracy test
 
@@ -212,11 +212,37 @@
    - --result_path: path to save the inference results of the test audio
    - --ref_path: path of the ground truth of the AISHELL-1 test audio
 
+6. Baseline accuracy test (torch_npu)
+    First run the following command to revert the FunASR modifications (if you later need mindie for compiled inference, re-modify the source code by following step 3 of the source-acquisition section)
+    ```
+    cd ./FunASR
+    git checkout -- .
+    cd ..
+    ```
+
+    Then install a suitable version of torch_npu and use the following command to measure torch_npu accuracy on the same dataset (running on Atlas 800I A2 is recommended; on Atlas 310I pro it takes too long)
+    ```
+    python test_accuracy_pta.py \
+    --model ./model \
+    --batch_size 64 \
+    --sample_path /path/to/AISHELL-1/wav/test \
+    --result_path ./aishell_test_result_pta.txt \
+    --ref_path /path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt
+    ```
+
+    Argument description:
+   - --model: path of the pretrained model
+   - --batch_size: batch size used by the Paraformer model
+   - --sample_path: path of the AISHELL-1 test-set audio
+   - --result_path: path to save the inference results of the test audio
+   - --ref_path: path of the ground truth of the AISHELL-1 test audio
+
+
 ## Model Accuracy and Performance
 
-Reference performance and accuracy of this model on the 310P3 and 910B4 are shown below
+Reference performance and accuracy of this model on the Atlas 310I pro and Atlas 800I A2 are shown below
 
-| NPU          | batch_size | rtx_avg | cer     | accuracy |
-|--------------|------------|---------|---------|----------|
-| Ascend 310P3 | 16         | 217.175 | 0.0198  | 0.9822   |
-| Ascend 910B4 | 64         | 392.746 | 0.0198  | 0.9822   |
\ No newline at end of file
+| NPU            | batch_size | rtx_avg | cer     | cer(torch_npu) |
+|----------------|------------|---------|---------|----------------|
+| Atlas 310I pro | 16         | 217.175 | 0.0198  | \              |
+| Atlas 800I A2  | 64         | 392.746 | 0.0198  | 0.0198         |
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py
new file mode 100644
index 0000000000..61a4fba976
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append("./FunASR")
+
+import os
+import argparse
+from tqdm import tqdm
+
+import torch
+import torch_npu
+
+from funasr.auto.auto_model import AutoModel
+from nltk.metrics.distance import edit_distance
+
+
+def load_txt(file_name):
+    result = {}
+
+    with open(file_name, "r") as file:
+        for line in file:
+            parts = line.strip().split(maxsplit=1)
+
+            if len(parts) == 2:
+                result[parts[0]] = parts[1]
+
+    return result
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default="./model",
+                        help="path of pretrained model")
+    parser.add_argument("--batch_size", default=64, type=int,
+                        help="batch size of paraformer model")
+    parser.add_argument("--sample_path", default="/path/to/AISHELL-1/wav/test", type=str,
+                        help="directory or path of sample audio")
+    parser.add_argument("--result_path", default="./aishell_test_result_pta.txt", type=str,
+                        help="path to save infer result")
+    parser.add_argument("--ref_path", default="/path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt",
+                        type=str, help="path of ground truth transcript")
+    args = parser.parse_args()
+
+    valid_extensions = ['.wav']
+    audio_files = []
+
+    if os.path.isfile(args.sample_path):
+        if any(args.sample_path.endswith(ext) for ext in valid_extensions):
+            audio_files.append(args.sample_path)
+    elif os.path.isdir(args.sample_path):
+        for root, dirs, files in os.walk(args.sample_path):
+            for file in files:
+                if any(file.endswith(ext) for ext in valid_extensions):
+                    audio_files.append(os.path.join(root, file))
+
+    # filter out wav files which are smaller than 1KB
+    audio_files = [file for file in audio_files if os.path.getsize(file) >= 1024]
+
+    if len(audio_files) == 0:
+        print("There is no valid wav file in sample_dir.")
+    else:
+        # initialize auto model
+        model = AutoModel(model=args.model)
+
+        model.model.to("npu")
+        model.kwargs["device"] = "npu"
+        model.kwargs["batch_size"] = args.batch_size
+
+        # iterate over sample_dir
+        print("Begin evaluating")
+        results, time_stats = model.inference_with_asr(input=audio_files)
+
+        with open(args.result_path, "w") as f:
+            for res in results:
+                f.write("{} {}\n".format(res["key"], res["text"]))
+
+        infer_result = load_txt(args.result_path)
+        ref_result = load_txt(args.ref_path)
+
+        infer_list = []
+        refer_list = []
+        for key, value in infer_result.items():
+            if key in ref_result:
+                infer_list.append(value.replace(" ", ""))
+                refer_list.append(ref_result[key].replace(" ", ""))
+
+        cer_total = 0
+        step = 0
+        for infer, refer in tqdm(zip(infer_list, refer_list)):
+            infer = [i for i in infer]
+            refer = [r for r in refer]
+            cer_total += edit_distance(infer, refer) / len(refer)
+            step += 1
+
+        cer = cer_total / step
+        accuracy = 1 - cer
+        print("character-error-rate: {:.4f}, accuracy: {:.4f}".format(cer, accuracy))
\ No newline at end of file
--
Gitee

From 1433094d4139523c4dcc959eae9b41d09cfa2270 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Mon, 9 Sep 2024 15:07:20 +0800
Subject: [PATCH 13/16] Add torch_npu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Paraformer/README.md       |  8 ++++++++
 .../audio/Paraformer/mindie_paraformer.py     | 20 +++++++++----------
 .../audio/Paraformer/test_accuracy.py         |  8 ++++++++
 .../audio/Paraformer/test_performance.py      |  9 +++++++++
 4 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index d6007ea81d..b34eb54c37 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -98,10 +98,13 @@
 
 6. Download the test dataset [AISHELL-1](https://www.aishelltech.com/kysjcp) and save it to any path
 
+7. Install a suitable version of torch_npu, and refer to the [Ascend documentation](https://www.hiascend.com/document/detail/zh/mindie/10RC2/mindietorch/Torchdev/mindie_torch0017.html) to make mindie and torch_npu compatible
+
 
 ## Model Inference
 1. Set the mindie memory pool limit to 12 GB by running the following command to set the environment variable
     ```
+    export INF_NAN_MODE_ENABLE=0
     export TORCH_AIE_NPU_CACHE_MAX_SIZE=12
     ```
 
@@ -169,6 +172,7 @@
     --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \
     --batch_size 64 \
     --sample_path /path/to/AISHELL-1/wav/test
+    --soc_version Ascend310P3
     ```
 
     Argument description:
@@ -179,6 +183,8 @@
    - --compiled_cif_timestamp: path of the compiled cif_timestamp function
    - --batch_size: batch size used by the Paraformer model
    - --sample_path: path of the AISHELL-1 test-set audio; the model recursively searches this path for all audio files
+   - --soc_version: model of the Ascend chip
+
 
 5. Accuracy test
 
@@ -199,6 +205,7 @@
     --sample_path /path/to/AISHELL-1/wav/test \
     --result_path ./aishell_test_result.txt \
     --ref_path /path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt
+    --soc_version Ascend310P3
     ```
 
     Argument description:
@@ -211,6 +218,7 @@
    - --sample_path: path of the AISHELL-1 test-set audio
    - --result_path: path to save the inference results of the test audio
    - --ref_path: path of the ground truth of the AISHELL-1 test audio
+   - --soc_version: model of the Ascend chip
 
 6. Baseline accuracy test (torch_npu)
     First run the following command to revert the FunASR modifications (if you later need mindie for compiled inference, re-modify the source code by following step 3 of the source-acquisition section)
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py
index b47cad5b2c..c090bc5b66 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/mindie_paraformer.py
@@ -93,9 +93,9 @@ class MindieBiCifParaformer(BiCifParaformer):
 
         # Step2: run with compiled encoder
         encoder_out, encoder_out_lens, hidden, alphas, pre_token_length = self.mindie_encoder(speech, speech_lengths)
-        hidden = hidden.to("cpu")
-        alphas = alphas.to("cpu")
-        pre_token_length = pre_token_length.to("cpu")
+        hidden = hidden.to(kwargs["mindie_device"])
+        alphas = alphas.to(kwargs["mindie_device"])
+        pre_token_length = pre_token_length.to(kwargs["mindie_device"])
         pre_token_length = pre_token_length.round().to(torch.int32)
 
         time3 = time.perf_counter()
@@ -123,7 +123,7 @@ class MindieBiCifParaformer(BiCifParaformer):
                 cur_hidden = padded_hidden[b : b + 1, i * kwargs["cif_interval"] : (i + 1) * kwargs["cif_interval"], :]
                 cur_alphas = padded_alphas[b : b + 1, i * kwargs["cif_interval"] : (i + 1) * kwargs["cif_interval"]]
                 cur_frames, integrate, frame = self.mindie_cif(cur_hidden.to("npu"), cur_alphas.to("npu"), integrate, frame)
-                frames_list.append(cur_frames.to("cpu"))
+                frames_list.append(cur_frames.to(kwargs["mindie_device"]))
             frame = torch.cat(frames_list, 0)
             pad_frame = torch.zeros([max_label_len - frame.size(0), hidden_size], device=hidden.device)
             frames_batch.append(torch.cat([frame, pad_frame], 0))
@@ -141,7 +141,7 @@ class MindieBiCifParaformer(BiCifParaformer):
         # Step4: run with compiled decoder
         decoder_out, us_alphas = self.mindie_decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds.to("npu"),
                                                      pre_token_length.to("npu"))
-        us_alphas = us_alphas.to("cpu")
+        us_alphas = us_alphas.to(kwargs["mindie_device"])
 
         time5 = time.perf_counter()
         meta_data["decoder"] = time5 - time4
@@ -161,7 +161,7 @@ class MindieBiCifParaformer(BiCifParaformer):
             for i in range(loop_num):
                 cur_alphas = padded_alphas[b:b+1, i * kwargs["cif_timestamp_interval"] : (i + 1) * kwargs["cif_timestamp_interval"]]
                 peak, integrate_alphas = self.mindie_cif_timestamp(cur_alphas.to("npu"), integrate_alphas)
-                peak_list.append(peak.to("cpu"))
+                peak_list.append(peak.to(kwargs["mindie_device"]))
             us_peak = torch.cat(peak_list, 1)[:, :len_alphas]
             peak_batch.append(us_peak)
         us_peaks = torch.cat(peak_batch, 0)
@@ -171,10 +171,10 @@ class MindieBiCifParaformer(BiCifParaformer):
 
         # Step6: post process
-        decoder_out = decoder_out.to("cpu")
-        us_alphas = us_alphas.to("cpu")
-        us_peaks = us_peaks.to("cpu")
-        encoder_out_lens = encoder_out_lens.to("cpu")
+        decoder_out = decoder_out.to(kwargs["mindie_device"])
+        us_alphas = us_alphas.to(kwargs["mindie_device"])
+        us_peaks = us_peaks.to(kwargs["mindie_device"])
+        encoder_out_lens = encoder_out_lens.to(kwargs["mindie_device"])
         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py
index a93996be58..f6e616c48a 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy.py
@@ -19,6 +19,7 @@
 from tqdm import tqdm
 
 import torch
+import torch_npu
 import mindietorch
 
 from mindie_auto_model import MindieAutoModel
@@ -58,6 +59,8 @@
                         help="path to save infer result")
     parser.add_argument("--ref_path", default="/path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt",
                         type=str, help="path of ground truth transcript")
+    parser.add_argument("--soc_version", default="Ascend310P3", type=str,
+                        help="soc version of Ascend")
     args = parser.parse_args()
 
     valid_extensions = ['.wav']
@@ -85,6 +88,11 @@
                                 batch_size=args.batch_size,
                                 cif_interval=200, cif_timestamp_interval=500)
 
+        if args.soc_version == "Ascend310P3":
+            model.kwargs["mindie_device"] = "cpu"
+        else:
+            model.kwargs["mindie_device"] = "npu"
+
         # iterate over sample_dir
         print("Begin evaluating")
         results, time_stats = model.inference_with_asr(input=audio_files)
diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
index ba97dc1db8..a8ddd718ed 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_performance.py
@@ -17,6 +17,8 @@
 import os
 import argparse
 
+import torch
+import torch_npu
 import mindietorch
 
 from mindie_auto_model import MindieAutoModel
@@ -38,6 +40,8 @@
                         help="batch size of paraformer model")
     parser.add_argument("--sample_path", default="/path/to/AISHELL-1", type=str,
                         help="directory or path of sample audio")
+    parser.add_argument("--soc_version", default="Ascend310P3", type=str,
+                        help="soc version of Ascend")
     args = parser.parse_args()
 
     mindietorch.set_device(0)
@@ -67,6 +71,11 @@
                                 batch_size=args.batch_size,
                                 cif_interval=200, cif_timestamp_interval=500)
 
+        if args.soc_version == "Ascend310P3":
+            model.kwargs["mindie_device"] = "cpu"
+        else:
+            model.kwargs["mindie_device"] = "npu"
+
         # warm up
         print("Begin warming up")
         for i in range(3):
--
Gitee
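The mindie_device changes above feed the fixed-window loops that drive the compiled cif modules: the time axis is padded to a multiple of the window size (cif_interval or cif_timestamp_interval), each window runs on the NPU while the integration state is threaded through the calls, and the padded tail is trimmed after concatenation. Stripped of batching, the pattern is roughly as below; run_window stands in for a compiled module and is a placeholder name.

```python
import torch
import torch.nn.functional as F

def run_in_windows(run_window, alphas, interval, state, out_device):
    # Pad the time axis to a multiple of `interval`, then process one
    # fixed-size window at a time, carrying `state` across windows.
    length = alphas.size(1)
    pad = (interval - length % interval) % interval
    alphas = F.pad(alphas, (0, pad))
    outs = []
    for i in range(alphas.size(1) // interval):
        window = alphas[:, i * interval:(i + 1) * interval].to("npu")
        out, state = run_window(window, state)
        outs.append(out.to(out_device))
    # Trim the padded tail so the output matches the original length.
    return torch.cat(outs, 1)[:, :length], state
```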
From b764c8b578e37685f0fd3a12765ffe4d61aa6547 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Mon, 9 Sep 2024 15:18:17 +0800
Subject: [PATCH 14/16] Update performance numbers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index b34eb54c37..3d89fb12f3 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -253,4 +253,4 @@
 | NPU            | batch_size | rtx_avg | cer     | cer(torch_npu) |
 |----------------|------------|---------|---------|----------------|
 | Atlas 310I pro | 16         | 217.175 | 0.0198  | \              |
-| Atlas 800I A2  | 64         | 392.746 | 0.0198  | 0.0198         |
\ No newline at end of file
+| Atlas 800I A2  | 64         | 461.775 | 0.0198  | 0.0198         |
\ No newline at end of file
--
Gitee

From 7ba7bbfa0342d2c7d0a46b6bef2df2f0d0ba8fe0 Mon Sep 17 00:00:00 2001
From: brjiang
Date: Fri, 13 Sep 2024 10:14:58 +0800
Subject: [PATCH 15/16] Fix issues in the PTA baseline script
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py
index 61a4fba976..fde7a1fb99 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/test_accuracy_pta.py
@@ -82,7 +82,7 @@ if __name__ == "__main__":
 
         # iterate over sample_dir
         print("Begin evaluating")
-        results, time_stats = model.inference_with_asr(input=audio_files)
+        results = model.inference(input=audio_files)
 
         with open(args.result_path, "w") as f:
             for res in results:
--
Gitee

From d53c5beee4e02f1e2d43c13c891226887003c31c Mon Sep 17 00:00:00 2001
From: brjiang
Date: Sat, 12 Oct 2024 17:03:45 +0800
Subject: [PATCH 16/16] Update readme
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
index 3d89fb12f3..bd0e21af97 100644
--- a/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
+++ b/MindIE/MindIE-Torch/built-in/audio/Paraformer/README.md
@@ -31,8 +31,8 @@
   | Python | 3.10.13 | - |
   | torch | 2.1.0+cpu | - |
   | torch_audio | 2.1.0+cpu | - |
-  | CANN | 8.0.RC2 | - |
-  | MindIE | 1.0.RC2.B091 | - |
+  | CANN | 8.0.RC3 | - |
+  | MindIE | 1.0.RC3 | - |
 
 # Quick Start
 ## Getting the Source
@@ -171,7 +171,7 @@
     --compiled_cif ./compiled_model/compiled_cif.pt \
     --compiled_cif_timestamp ./compiled_model/compiled_cif_timestamp.pt \
     --batch_size 64 \
-    --sample_path /path/to/AISHELL-1/wav/test
+    --sample_path /path/to/AISHELL-1/wav/test \
     --soc_version Ascend310P3
     ```
 
@@ -204,7 +204,7 @@
     --batch_size 64 \
     --sample_path /path/to/AISHELL-1/wav/test \
     --result_path ./aishell_test_result.txt \
-    --ref_path /path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt
+    --ref_path /path/to/AISHELL-1/transcript/aishell_transcript_v0.8.txt \
     --soc_version Ascend310P3
     ```
--
Gitee