diff --git a/ACL_PyTorch/contrib/audio/WeNet/.keep b/ACL_PyTorch/contrib/audio/WeNet/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ACL_PyTorch/contrib/audio/WeNet/README.md b/ACL_PyTorch/contrib/audio/WeNet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..65868f7b0201f5800287f5728658ac962352e941
--- /dev/null
+++ b/ACL_PyTorch/contrib/audio/WeNet/README.md
@@ -0,0 +1,84 @@
+# WeNet Model PyTorch Offline Inference Guide
+
+## 1 Environment Setup
+
+1. Install the required dependencies. The test environment may already have some of these libraries installed in different versions, so this command is not recommended when testing manually:
+
+```
+pip3 install -r requirements.txt
+```
+
+2. Obtain, modify, and install the open-source model code:
+
+```
+git clone https://github.com/wenet-e2e/wenet.git
+cd wenet
+git reset 9c4e305bcc24a06932f6a65c8147429d8406cc63 --hard
+```
+
+3. Download the network weights and export ONNX
+
+Download the archive from http://mobvoi-speech-public.ufile.ucloud.cn/public/wenet/aishell/20210601_u2pp_conformer_exp.tar.gz, extract it, and place the extracted files under wenet/examples/aishell/s0/exp/conformer_u2 (create this folder if it does not exist).
+
+Copy the provided diff file to the wenet root directory.
+Run patch -p1 < wenet_onnx.diff to adapt the code for ONNX export.
+Copy the provided export_onnx.py and static.py to wenet/wenet/bin/.
+Copy the provided slice_helper.py and acl_net.py to the wenet/wenet/transformer folder.
+Copy the provided .sh scripts (except static_encoder.sh and static_decoder.sh) to wenet/examples/aishell/s0/.
+Run bash export_onnx.sh exp/conformer_u2/train.yaml exp/conformer_u2/final.pt; the exported ONNX files are written to the onnx folder in the current directory.
+
+4. Modify the ONNX models
+
+First install the graph-modification tool om_gener (https://gitee.com/liurf_hw/om_gener), then modify the graphs with the following commands:
+
+python3 adaptdecoder.py generates decoder_final.onnx
+
+python3 adaptnoflashencoder.py generates no_flash_encoder_revise.onnx
+
+5. Dataset download:
+
+   In wenet/examples/aishell/s0/, run bash run.sh --stage -1 --stop_stage -1 to download the dataset.
+
+   Run bash run.sh --stage 0 --stop_stage 0 to process the dataset.
+
+   Run bash run.sh --stage 1 --stop_stage 1 to process the dataset.
+
+   Run bash run.sh --stage 2 --stop_stage 2 to process the dataset.
+
+   Run bash run.sh --stage 3 --stop_stage 3 to process the dataset.
+
+## 2 Offline Inference
+
+Convert ONNX to OM:
+Copy static_encoder.sh and static_decoder.sh to the s0/onnx folder.
+
+```
+bash static_encoder.sh
+bash static_decoder.sh
+```
+
+Accuracy test:
+
+First run export ASCEND_GLOBAL_LOG_LEVEL=3, and set the files loaded by self.encoder_ascend and self.decoder_ascend in acc.diff to the paths of the statically converted encoder and decoder OM models.
+
+```
+git checkout .
+patch -p1 < acc.diff
+cd examples/aishell/s0/
+bash run_static.sh
+```
+
+Performance: the FPS figure is in the last line of wenet/examples/aishell/s0/exp/conformer/test_attention_rescoring/text.
+
+
+GPU performance test:
+
+```
+git checkout .
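+# Note: gpu_fps.diff only adds timing instrumentation: the patched encoder and decoder
+# forward passes return their wall-clock time and recognize.py prints mean_time / mean_fps,
+# so the FPS reported here reflects model execution time per utterance.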
+patch -p1 < gpu_fps.diff
+cd examples/aishell/s0/
+
+# Before running, set average_checkpoint to false and decode_modes to attention_rescoring in run.sh, and set ctc_weight=0.3 and reverse_weight=0.3 in stage 5
+bash run.sh --stage 5 --stop_stage 5
+```
+The FPS figure is in the last line of wenet/examples/aishell/s0/exp/conformer/test_attention_rescoring/text.
\ No newline at end of file
diff --git a/ACL_PyTorch/contrib/audio/WeNet/acc.diff b/ACL_PyTorch/contrib/audio/WeNet/acc.diff
new file mode 100644
index 0000000000000000000000000000000000000000..b4d2d5ea3c3acf37998d2a0c8bd940ac21416ccf
--- /dev/null
+++ b/ACL_PyTorch/contrib/audio/WeNet/acc.diff
@@ -0,0 +1,358 @@
+diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
+index 4f0ff39..4ce97a4 100644
+--- a/wenet/dataset/dataset.py
++++ b/wenet/dataset/dataset.py
+@@ -27,7 +27,7 @@ import torchaudio.sox_effects as sox_effects
+ import yaml
+ from PIL import Image
+ from PIL.Image import BICUBIC
+-from torch.nn.utils.rnn import pad_sequence
++#from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import Dataset, DataLoader
+
+ import wenet.dataset.kaldi_io as kaldi_io
+@@ -36,7 +36,69 @@ from wenet.utils.common import IGNORE_ID
+
+ torchaudio.set_audio_backend("sox_io")
+
++def _pad_sequence(sequences, batch_first=False, padding_value=0, mul_shape = None):
++    r"""Pad a list of variable length Tensors with ``padding_value``
++
++    ``pad_sequence`` stacks a list of Tensors along a new dimension,
++    and pads them to equal length. For example, if the input is list of
++    sequences with size ``L x *`` and if batch_first is False, and ``T x B x *``
++    otherwise.
++
++    `B` is batch size. It is equal to the number of elements in ``sequences``.
++    `T` is length of the longest sequence.
++    `L` is length of the sequence.
++    `*` is any number of trailing dimensions, including none.
++
++    Example:
++        >>> from torch.nn.utils.rnn import pad_sequence
++        >>> a = torch.ones(25, 300)
++        >>> b = torch.ones(22, 300)
++        >>> c = torch.ones(15, 300)
++        >>> pad_sequence([a, b, c]).size()
++        torch.Size([25, 3, 300])
++
++    Note:
++        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
++        where `T` is the length of the longest sequence. This function assumes
++        trailing dimensions and type of all the Tensors in sequences are same.
++
++    Arguments:
++        sequences (list[Tensor]): list of variable length sequences.
++        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
++            ``T x B x *`` otherwise
++        padding_value (float, optional): value for padded elements. Default: 0.
+
++    Returns:
++        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
++        Tensor of size ``B x T x *`` otherwise
++    """
++
++    # assuming trailing dimensions and type of all the Tensors
++    # in sequences are same and fetching those from sequences[0]
++
++    max_size = sequences[0].size()
++    trailing_dims = max_size[1:]
++
++    max_len = max([s.size(0) for s in sequences])
++    if mul_shape is not None:
++        for in_shape in mul_shape:
++            if max_len < in_shape:
++                max_len = in_shape
++                break
++    if batch_first:
++        out_dims = (len(sequences), max_len) + trailing_dims
++    else:
++        out_dims = (max_len, len(sequences)) + trailing_dims
++
++    out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value)
++    for i, tensor in enumerate(sequences):
++        length = tensor.size(0)
++        # use index notation to prevent duplicate references to the tensor
++        if batch_first:
++            out_tensor[i, :length, ...] = tensor
++        else:
++            out_tensor[:length, i, ...]
= tensor ++ return out_tensor + def _spec_augmentation(x, + warp_for_time=False, + num_t_mask=2, +@@ -187,6 +249,7 @@ def _extract_feature(batch, speed_perturb, wav_distortion_conf, + Returns: + (keys, feats, labels) + """ ++ + keys = [] + feats = [] + lengths = [] +@@ -331,13 +394,14 @@ class CollateFunc(object): + self.spec_sub = spec_sub + self.spec_sub_conf = spec_sub_conf + ++ ++ + def __call__(self, batch): + assert (len(batch) == 1) + if self.raw_wav: + keys, xs, ys = _extract_feature(batch[0], self.speed_perturb, + self.wav_distortion_conf, + self.feature_extraction_conf) +- + else: + keys, xs, ys = _load_feature(batch[0]) + +@@ -359,27 +423,31 @@ class CollateFunc(object): + if self.spec_aug: + xs = [_spec_augmentation(x, **self.spec_aug_conf) for x in xs] + +- # padding +- xs_lengths = torch.from_numpy( +- np.array([x.shape[0] for x in xs], dtype=np.int32)) ++ + + # pad_sequence will FAIL in case xs is empty ++ mul_shape = [262, 326, 390, 454, 518, 582, 646, 710, 774, 838, 902, 966, 1028, 1284, 1478] + if len(xs) > 0: +- xs_pad = pad_sequence([torch.from_numpy(x).float() for x in xs], +- True, 0) ++ xs_pad = _pad_sequence([torch.from_numpy(x).float() for x in xs], ++ True, 0, mul_shape) + else: + xs_pad = torch.Tensor(xs) ++ # padding ++ xs_lengths = torch.from_numpy( ++ np.array([x.shape[0] for x in xs_pad], dtype=np.int32)) ++ + if train_flag: + ys_lengths = torch.from_numpy( + np.array([y.shape[0] for y in ys], dtype=np.int32)) + if len(ys) > 0: +- ys_pad = pad_sequence([torch.from_numpy(y).int() for y in ys], ++ ys_pad = _pad_sequence([torch.from_numpy(y).int() for y in ys], + True, IGNORE_ID) + else: + ys_pad = torch.Tensor(ys) + else: + ys_pad = None + ys_lengths = None ++ + return keys, xs_pad, ys_pad, xs_lengths, ys_lengths + + +@@ -430,7 +498,6 @@ class AudioDataset(Dataset): + """ + assert batch_type in ['static', 'dynamic'] + data = [] +- + # Open in utf8 mode since meet encoding problem + with codecs.open(data_file, 'r', encoding='utf-8') as f: + for line in f: +diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py +index 73990fa..50358ca 100644 +--- a/wenet/transformer/asr_model.py ++++ b/wenet/transformer/asr_model.py +@@ -32,8 +32,74 @@ from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add, + reverse_pad_list) + from wenet.utils.mask import (make_pad_mask, mask_finished_preds, + mask_finished_scores, subsequent_mask) ++from wenet.transformer.acl_net import Net ++import time ++import acl ++ ++def _pad_sequence(sequences, batch_first=False, padding_value=0, mul_shape = None): ++ r"""Pad a list of variable length Tensors with ``padding_value`` ++ ++ ``pad_sequence`` stacks a list of Tensors along a new dimension, ++ and pads them to equal length. For example, if the input is list of ++ sequences with size ``L x *`` and if batch_first is False, and ``T x B x *`` ++ otherwise. ++ ++ `B` is batch size. It is equal to the number of elements in ``sequences``. ++ `T` is length of the longest sequence. ++ `L` is length of the sequence. ++ `*` is any number of trailing dimensions, including none. ++ ++ Example: ++ >>> from torch.nn.utils.rnn import pad_sequence ++ >>> a = torch.ones(25, 300) ++ >>> b = torch.ones(22, 300) ++ >>> c = torch.ones(15, 300) ++ >>> pad_sequence([a, b, c]).size() ++ torch.Size([25, 3, 300]) ++ ++ Note: ++ This function returns a Tensor of size ``T x B x *`` or ``B x T x *`` ++ where `T` is the length of the longest sequence. 
This function assumes ++ trailing dimensions and type of all the Tensors in sequences are same. ++ ++ Arguments: ++ sequences (list[Tensor]): list of variable length sequences. ++ batch_first (bool, optional): output will be in ``B x T x *`` if True, or in ++ ``T x B x *`` otherwise ++ padding_value (float, optional): value for padded elements. Default: 0. ++ ++ Returns: ++ Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``. ++ Tensor of size ``B x T x *`` otherwise ++ """ ++ ++ # assuming trailing dimensions and type of all the Tensors ++ # in sequences are same and fetching those from sequences[0] ++ ++ max_size = sequences[0].size() ++ trailing_dims = max_size[1:] ++ ++ max_len = max([s.size(0) for s in sequences]) ++ if mul_shape is not None: ++ for in_shape in mul_shape: ++ if max_len < in_shape: ++ max_len = in_shape ++ break + +- ++ if batch_first: ++ out_dims = (len(sequences), max_len) + trailing_dims ++ else: ++ out_dims = (max_len, len(sequences)) + trailing_dims ++ ++ out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) ++ for i, tensor in enumerate(sequences): ++ length = tensor.size(0) ++ # use index notation to prevent duplicate references to the tensor ++ if batch_first: ++ out_tensor[i, :length, ...] = tensor ++ else: ++ out_tensor[:length, i, ...] = tensor ++ return out_tensor + class ASRModel(torch.nn.Module): + """CTC-attention hybrid Encoder-Decoder model""" + def __init__( +@@ -60,6 +126,13 @@ class ASRModel(torch.nn.Module): + self.reverse_weight = reverse_weight + + self.encoder = encoder ++ self.device_id = 0 ++ ret = acl.init() ++ ret = acl.rt.set_device(self.device_id) ++ context, ret = acl.rt.create_context(self.device_id) ++ self.encoder_ascend = Net(model_path="/home/zry2/wenet/examples/aishell/s0/onnx/encoder_fendang_262_1478_static.om", device_id=self.device_id) ++ self.decoder_ascend = Net(model_path="/home/zry2/wenet/examples/aishell/s0/onnx/decoder_fendang.om", device_id=self.device_id) ++ self.encoder_out_shape = [] + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( +@@ -168,13 +241,21 @@ class ASRModel(torch.nn.Module): + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: +- encoder_out, encoder_mask = self.encoder( +- speech, +- speech_lengths, +- decoding_chunk_size=decoding_chunk_size, +- num_decoding_left_chunks=num_decoding_left_chunks +- ) # (B, maxlen, encoder_dim) +- return encoder_out, encoder_mask ++ st = time.time() ++ ++ # encoder_out, encoder_mask = self.encoder( ++ # speech, ++ # speech_lengths, ++ # decoding_chunk_size=decoding_chunk_size, ++ # num_decoding_left_chunks=num_decoding_left_chunks ++ # ) # (B, maxlen, encoder_dim) ++ speech = speech.numpy() ++ speech_lengths = speech_lengths.numpy().astype("int32") ++ dims1 = {'dimCount': 4, 'name': '', 'dims': [1, speech.shape[1], 80, 1]} ++ y, exe_time = self.encoder_ascend([speech, speech_lengths], dims = dims1) ++ encoder_out = torch.from_numpy(y[0]) ++ encoder_mask = torch.from_numpy(y[1]) ++ return encoder_out, encoder_mask, exe_time + + def recognize( + self, +@@ -361,13 +442,17 @@ class ASRModel(torch.nn.Module): + assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size + # 1. 
Encoder forward and get CTC score +- encoder_out, encoder_mask = self._forward_encoder( ++ encoder_out, encoder_mask, encoder_t = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) ++ mul_shape = [96, 144, 384] ++ ++ encoder_out = _pad_sequence(encoder_out, True, 0, mul_shape) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) ++ + ctc_probs = ctc_probs.squeeze(0) + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + cur_hyps = [(tuple(), (0.0, -float('inf')))] +@@ -409,7 +494,7 @@ class ASRModel(torch.nn.Module): + reverse=True) + cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] +- return hyps, encoder_out ++ return hyps, encoder_out, encoder_t + + def ctc_prefix_beam_search( + self, +@@ -485,7 +570,7 @@ class ASRModel(torch.nn.Module): + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size +- hyps, encoder_out = self._ctc_prefix_beam_search( ++ hyps, encoder_out, encoder_t = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + +@@ -510,9 +595,19 @@ class ASRModel(torch.nn.Module): + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) +- decoder_out, r_decoder_out, _ = self.decoder( +- encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, +- reverse_weight) # (beam_size, max_hyps_len, vocab_size) ++ ++ encoder_out = encoder_out.numpy() ++ encoder_mask = encoder_mask.numpy() ++ hyps_pad = hyps_pad.numpy() ++ hyps_lens = hyps_lens.numpy().astype("int32") ++ r_hyps_pad = r_hyps_pad.numpy() ++ dims2 = {'dimCount': 11, 'name': '', 'dims': [10, encoder_out.shape[1], 256, 10, 1, encoder_out.shape[1], 10, r_hyps_pad.shape[1], 10, 10, r_hyps_pad.shape[1]]} ++ ++ y, exe_time = self.decoder_ascend([encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad], dims=dims2) ++ batch_t = encoder_t + exe_time ++ decoder_out = torch.from_numpy(y[0]) ++ r_decoder_out = torch.from_numpy(y[1]) ++ + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = decoder_out.cpu().numpy() + # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a +@@ -539,7 +634,7 @@ class ASRModel(torch.nn.Module): + if score > best_score: + best_score = score + best_index = i +- return hyps[best_index][0] ++ return hyps[best_index][0], batch_t + + @torch.jit.export + def subsampling_rate(self) -> int: +diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py +index e342ed4..c8e18d5 100644 +--- a/wenet/transformer/encoder.py ++++ b/wenet/transformer/encoder.py +@@ -157,6 +157,8 @@ class BaseEncoder(torch.nn.Module): + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) ++ ++ + for layer in self.encoders: + xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: diff --git a/ACL_PyTorch/contrib/audio/WeNet/acl_net.py b/ACL_PyTorch/contrib/audio/WeNet/acl_net.py new file mode 100644 index 0000000000000000000000000000000000000000..df83b075a48af642bd682a0c1c1228350d95cb55 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/acl_net.py @@ -0,0 +1,392 @@ +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. 
+# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + + +import numpy as np +import acl +import functools +import time +import os + +# error code +ACL_ERROR_NONE = 0 + +# memory malloc code +ACL_MEM_MALLOC_HUGE_FIRST = 0 +ACL_MEM_MALLOC_HUGE_ONLY = 1 +ACL_MEM_MALLOC_NORMAL_ONLY = 2 + +# memory copy code +ACL_MEMCPY_HOST_TO_HOST = 0 +ACL_MEMCPY_HOST_TO_DEVICE = 1 +ACL_MEMCPY_DEVICE_TO_HOST = 2 +ACL_MEMCPY_DEVICE_TO_DEVICE = 3 + +# format +ACL_FORMAT_NCHW = 0 +ACL_DTYPE = { + 0: 'float32', + 1: 'float16', + 2: 'int8', + 3: 'int32', + 4: 'uint8', + 6: 'int16', + 7: 'uint16', + 8: 'uint32', + 9: 'int64', + 10: 'uint64', + 11: 'float64', + 12: 'bool', +} + +ACL_DTYPE_INDEX = { + 'float32': 0, + 'float16': 1, + 'int8': 2, + 'int32': 3, + 'uint8': 4, + 'int16': 6, + 'uint16': 7, + 'uint32': 8, + 'int64': 9, + 'uint64': 10, + 'float64': 11, + 'bool': 12, +} + + +def check_ret(message, ret): + if ret != ACL_ERROR_NONE: + raise Exception("{} failed ret = {}".format(message, ret)) + + +def check_input_type(input_type, model_input_type): + for i in range(len(input_type)): + if ACL_DTYPE_INDEX.get(input_type[i]) != model_input_type[i]: + raise Exception("real input {} input_type:{} model_input_type:{} not same".format(i, input_type[i], \ + ACL_DTYPE.get(model_input_type[i]))) + + +class Net(object): + def __init__(self, model_path, device_id, check_input=False, output_data_shape=None): + self.check_input = check_input + self.dynamic = False + self.device_id = device_id + self.model_path = os.getcwd()+'/onnx/'+model_path.split('/')[-1] + self.model_id = None + # if self.ascend_mbatch_shape_data = True, the model is static with multi input shape + self.ascend_mbatch_shape_data = False + self.input_data_type = [] + self.model_input_data_type = [] + self.model_input_data_format = [] + self.model_output_data_type = [] + self.output_data_shape = output_data_shape + self.output_shape = [] + self.buffer_method = { + "in": acl.mdl.get_input_size_by_index, + "out": 
acl.mdl.get_output_size_by_index, + "outhost": acl.mdl.get_output_size_by_index + } + + self.input_data = [] + self.output_data = [] + self.output_data_host = [] + self.model_desc = None + self.load_input_dataset = None + self.load_output_dataset = None + self.input_size = None + self.output_size = None + self.exe_t = 0 + self._init_resource() + + def __call__(self, ori_data, dims=None): + return self.forward(ori_data, dims) + + def __del__(self): + ret = acl.mdl.unload(self.model_id) + check_ret("acl.mdl.unload", ret) + if self.model_desc: + acl.mdl.destroy_desc(self.model_desc) + self.model_desc = None + if not self.dynamic: + self._release_data_buffer() + + def _release_data_buffer(self): + while self.input_data: + item = self.input_data.pop() + ret = acl.rt.free(item["buffer"]) + check_ret("acl.rt.free", ret) + + while self.output_data: + item = self.output_data.pop() + ret = acl.rt.free(item["buffer"]) + check_ret("acl.rt.free", ret) + + while self.output_data_host: + item = self.output_data_host.pop() + ret = acl.rt.free_host(item["buffer"]) + check_ret("acl.rt.free_host", ret) + + def _init_resource(self): + # load_model + self.model_id, ret = acl.mdl.load_from_file(self.model_path) + check_ret("acl.mdl.load_from_file", ret) + + self.model_desc = acl.mdl.create_desc() + self._get_model_info() + + def _get_model_info(self): + ret = acl.mdl.get_desc(self.model_desc, self.model_id) + check_ret("acl.mdl.get_desc", ret) + self.input_size = acl.mdl.get_num_inputs(self.model_desc) + # get the input format, data_type and get the model static or not + for i in range(self.input_size): + data_type = acl.mdl.get_input_data_type(self.model_desc, i) + self.model_input_data_type.append(data_type) + data_format = acl.mdl.get_input_format(self.model_desc, i) + self.model_input_data_format.append(data_format) + dims_input, ret = acl.mdl.get_input_dims(self.model_desc, i) + # check if the model has ascend_mbatch_shape_data + if i == self.input_size - 1 and dims_input["name"] == "ascend_mbatch_shape_data": + self.dynamic = False + self.ascend_mbatch_shape_data = True + elif -1 in dims_input["dims"]: + self.dynamic = True + self.output_size = acl.mdl.get_num_outputs(self.model_desc) + for j in range(self.output_size): + data_type = acl.mdl.get_output_data_type(self.model_desc, j) + self.model_output_data_type.append(data_type) + dims_output, ret = acl.mdl.get_output_dims(self.model_desc, j) + if -1 in dims_output["dims"]: + self.dynamic = True + if self.output_data_shape is None and self.dynamic: + self.output_data_shape = 500000000 + if not self.dynamic: + self._prepare_data_buffer_in() + self._prepare_data_buffer_out() + self._prepare_data_buffer_host() + + def _gen_data_buffer(self, size, des, data=None): + func = self.buffer_method[des] + for i in range(size): + if not self.dynamic: + temp_buffer_size = func(self.model_desc, i) + else: + if des == "in": + input_size = np.prod(np.array(data[i]).shape) + temp_buffer_size = Net.gen_data_size(input_size, dtype=ACL_DTYPE.get(self.model_input_data_type[i])) + elif des == "out": + temp_buffer_size = Net.gen_data_size(data, dtype=ACL_DTYPE.get(self.model_output_data_type[i])) + + temp_buffer, ret = acl.rt.malloc(temp_buffer_size, ACL_MEM_MALLOC_HUGE_FIRST) + check_ret("acl.rt.malloc", ret) + acl.rt.memset(temp_buffer, temp_buffer_size, 0, temp_buffer_size) + if des == "in": + self.input_data.append({"buffer": temp_buffer, + "size": temp_buffer_size}) + elif des == "out": + self.output_data.append({"buffer": temp_buffer, + "size": temp_buffer_size}) + + def 
_gen_dataset_output_host(self, size, des, data=None): + func = self.buffer_method[des] + for i in range(size): + if not self.dynamic: + temp_buffer_size = func(self.model_desc, i) + else: + temp_buffer_size = Net.gen_data_size(data, ACL_DTYPE.get(self.model_output_data_type[i])) + temp_buffer, ret = acl.rt.malloc_host(temp_buffer_size) + check_ret("acl.rt.malloc_host", ret) + + self.output_data_host.append({"buffer": temp_buffer, + "size": temp_buffer_size}) + + def _data_interaction(self, dataset, policy=ACL_MEMCPY_HOST_TO_DEVICE): + + temp_data_buffer = self.input_data \ + if policy == ACL_MEMCPY_HOST_TO_DEVICE \ + else self.output_data + if len(dataset) == 0 and policy == ACL_MEMCPY_DEVICE_TO_HOST: + dataset = self.output_data_host + for i in range(len(dataset)): + if policy == ACL_MEMCPY_HOST_TO_DEVICE: + ptr = acl.util.numpy_to_ptr(dataset[i]) + if self.ascend_mbatch_shape_data: + malloc_size = dataset[i].size * dataset[i].itemsize + else: + malloc_size = temp_data_buffer[i]["size"] + ret = acl.rt.memcpy(temp_data_buffer[i]["buffer"], malloc_size, ptr, malloc_size, policy) + check_ret("acl.rt.memcpy", ret) + + else: + ptr = dataset[i]["buffer"] + ret = acl.rt.memcpy(ptr, temp_data_buffer[i]["size"], temp_data_buffer[i]["buffer"], + temp_data_buffer[i]["size"], policy) + check_ret("acl.rt.memcpy", ret) + + def _gen_dataset(self, type_str="input", input_shapes=None): + dataset = acl.mdl.create_dataset() + temp_dataset = None + if type_str == "in": + self.load_input_dataset = dataset + temp_dataset = self.input_data + + else: + self.load_output_dataset = dataset + temp_dataset = self.output_data + + for i, item in enumerate(temp_dataset): + data = acl.create_data_buffer(item["buffer"], item["size"]) + if data is None: + ret = acl.destroy_data_buffer(dataset) + check_ret("acl.destroy_data_buffer", ret) + + _, ret = acl.mdl.add_dataset_buffer(dataset, data) + if ret != ACL_ERROR_NONE: + ret = acl.destroy_data_buffer(dataset) + check_ret("acl.destroy_data_buffer", ret) + + if type_str == "in" and not self.ascend_mbatch_shape_data: + # set dynamic dataset tensor desc + input_shape = input_shapes[i] + input_desc = acl.create_tensor_desc(self.model_input_data_type[i], input_shape, + self.model_input_data_format[i]) + dataset, ret = acl.mdl.set_dataset_tensor_desc(dataset, input_desc, i) + if ret != ACL_ERROR_NONE: + ret = acl.destroy_data_buffer(dataset) + check_ret("acl.destroy_data_buffer", ret) + + def _data_from_host_to_device(self, images): + self._data_interaction(images, ACL_MEMCPY_HOST_TO_DEVICE) + input_shapes = [list(data.shape) for data in images] + self._gen_dataset("in", input_shapes) + self._gen_dataset("out") + + def _data_from_device_to_host(self, input_data, output_shape): + res = [] + self._data_interaction(res, ACL_MEMCPY_DEVICE_TO_HOST) + output = self.get_result(self.output_data_host, input_data, output_shape) + return output + + def _get_output_shape(self): + output_shape = [] + num = acl.mdl.get_dataset_num_buffers(self.load_output_dataset) + for output_index in range(num): + if self.dynamic: + outpu_desc = acl.mdl.get_dataset_tensor_desc(self.load_output_dataset, output_index) + temp_output_shape = [] + dim_nums = acl.get_tensor_desc_num_dims(outpu_desc) + for i in range(dim_nums): + dim, ret = acl.get_tensor_desc_dim_v2(outpu_desc, i) + temp_output_shape.append(dim) + output_shape.append(temp_output_shape) + else: + dims, ret = acl.mdl.get_cur_output_dims(self.model_desc, output_index) + data_shape = dims.get("dims") + output_shape.append(data_shape) + + return 
output_shape + + def _destroy_databuffer(self): + for dataset in [self.load_input_dataset, self.load_output_dataset]: + if not dataset: + continue + + num = acl.mdl.get_dataset_num_buffers(dataset) + for i in range(num): + data_buf = acl.mdl.get_dataset_buffer(dataset, i) + if data_buf: + ret = acl.destroy_data_buffer(data_buf) + check_ret("acl.destroy_data_buffer", ret) + ret = acl.mdl.destroy_dataset(dataset) + check_ret("acl.mdl.destroy_dataset", ret) + + def _prepare_data_buffer_in(self, input_data=None): + self._gen_data_buffer(self.input_size, des="in", data=input_data) + + def _prepare_data_buffer_out(self, input_data=None): + self._gen_data_buffer(self.output_size, des="out", data=input_data) + + def _prepare_data_buffer_host(self, input_data=None): + self._gen_dataset_output_host(self.output_size, des="outhost", data=input_data) + + def forward(self, input_data, dims=None): + if not isinstance(input_data, (list, tuple)): + input_data = [input_data] + if self.check_input: + self.input_data_type = [] + for data in input_data: + self.input_data_type.append(str(data.dtype)) + check_input_type(self.input_data_type, self.model_input_data_type) + if self.dynamic: + self._prepare_data_buffer_in(input_data) + self._prepare_data_buffer_out(self.output_data_shape) + self._prepare_data_buffer_host(self.output_data_shape) + self._data_from_host_to_device(input_data) + + if self.ascend_mbatch_shape_data: + if dims is None: + raise Exception("the model is static multi shape model, dims can not be None") + index, ret = acl.mdl.get_input_index_by_name(self.model_desc, 'ascend_mbatch_shape_data') + ret = acl.mdl.set_input_dynamic_dims(self.model_id, self.load_input_dataset, index, dims) + check_ret("acl.mdl.set_input_dynamic_dims", ret) + st = time.time() + ret = acl.mdl.execute(self.model_id, self.load_input_dataset, self.load_output_dataset) + self.exe_t = time.time() - st + check_ret("acl.mdl.execute", ret) + # get output shape + output_shape = self._get_output_shape() + self._destroy_databuffer() + result = self._data_from_device_to_host(input_data=input_data, output_shape=output_shape) + if self.dynamic: + self._release_data_buffer() + return result + + def get_result(self, output_data, data, output_shape): + dataset = [] + for i in range(len(output_data)): + # fix dynamic batch size + data_type = acl.mdl.get_output_data_type(self.model_desc, i) + data_len = functools.reduce(lambda x, y: x * y, output_shape[i]) + ftype = np.dtype(ACL_DTYPE.get(data_type)) + size = output_data[i]["size"] + ptr = output_data[i]["buffer"] + data = acl.util.ptr_to_numpy(ptr, (size,), 1) + np_array = np.frombuffer(bytearray(data[:data_len * ftype.itemsize]), dtype=ftype, count=data_len) + np_array = np_array.reshape(output_shape[i]) + dataset.append(np_array) + return dataset, self.exe_t * 1000 + + @staticmethod + def gen_data_size(size, dtype): + dtype = np.dtype(dtype) + return int(size * dtype.itemsize) diff --git a/ACL_PyTorch/contrib/audio/WeNet/adaptdecoder.py b/ACL_PyTorch/contrib/audio/WeNet/adaptdecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff39c7bbb54fa0cb981bf35eaa97ee9083519ea --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/adaptdecoder.py @@ -0,0 +1,91 @@ +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. 
+# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ +from gener_core.mod_modify.onnx_graph import OXGraph +from gener_core.mod_modify.onnx_node import OXNode +from gener_core.mod_modify.interface import AttrType as AT + +mod = OXGraph("decoder.onnx") +Expand_lists = mod.get_nodes_by_optype("Expand") +for i in range(len(Expand_lists)): + now_expand = mod.get_node(Expand_lists[i]) + cast_node = mod.add_new_node(now_expand.name + "_cast", "Cast", + {"to": (AT.INT, 6) + }) + Expand_first_input_now = mod.get_node(now_expand.input_name[0]) + now_expand.set_input_node(0, [cast_node]) + cast_node.set_input_node(0, [Expand_first_input_now]) + +Less_lists = mod.get_nodes_by_optype("Less") +for i in range(len(Less_lists)): + now_expand = mod.get_node(Less_lists[i]) + cast_node = mod.add_new_node(now_expand.name + "_cast", "Cast", + {"to": (AT.INT, 6) + }) + Expand_second_input_now = mod.get_node(now_expand.input_name[1]) + now_expand.set_input_node(1, [cast_node]) + cast_node.set_input_node(0, [Expand_second_input_now]) + +Greater_lists = mod.get_nodes_by_optype("Greater") +for greater_node in Greater_lists: + now_expand = mod.get_node(greater_node) + cast_node = mod.add_new_node(now_expand.name + "_cast", "Cast", + {"to": (AT.INT, 6) + }) + Expand_second_input_now = mod.get_node(now_expand.input_name[1]) + now_expand.set_input_node(1, [cast_node]) + cast_node.set_input_node(0, [Expand_second_input_now]) + +not_change_cast = [] +Range_lists = mod.get_nodes_by_optype("Range") +for range_node in Range_lists: + now_expand = mod.get_node(range_node) + Expand_first_input_now = mod.get_node(now_expand.input_name[1]) + not_change_cast.append(Expand_first_input_now.name) + +to = 6 +Cast = mod.get_nodes_by_optype("Cast") +for cast_node in Cast: + now_Cast = mod.get_node(cast_node) + if now_Cast.get_attr("to", AT.INT) == 7 and now_Cast.name not in not_change_cast: + now_Cast.set_attr({"to": (AT.INT, to)}) + +Equal = mod.get_nodes_by_optype("Equal") +for equal_node in 
Equal: + now_equal = mod.get_node(equal_node) + now_ends = mod.get_node(now_equal.input_name[1]) + if now_ends.op_type in ("Initializer", "Constant") and now_ends.const_value.dtype == "int64": + print("now_ends.dtype:", now_ends.const_value.dtype) + val = now_ends.const_value.astype("int32") + now_ends.set_const_value(val) + +mod.save_new_model("decoder_final.onnx") + diff --git a/ACL_PyTorch/contrib/audio/WeNet/adaptnoflashencoder.py b/ACL_PyTorch/contrib/audio/WeNet/adaptnoflashencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1675b1d6d6ce9327a4066297817b1914f47610 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/adaptnoflashencoder.py @@ -0,0 +1,81 @@ +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ============================================================================ + +from gener_core.mod_modify.onnx_graph import OXGraph +from gener_core.mod_modify.onnx_node import OXNode +from gener_core.mod_modify.interface import AttrType as AT +import numpy as np + +mod = OXGraph("no_flash_encoder.onnx") +Expand_lists = mod.get_nodes_by_optype("Less") +for i in range(len(Expand_lists)): + now_expand = mod.get_node(Expand_lists[i]) + cast_node = mod.add_new_node(now_expand.name + "_cast", "Cast", + {"to": (AT.INT, 6) + }) + Expand_first_input_now = mod.get_node(now_expand.input_name[1]) + now_expand.set_input_node(1, [cast_node]) + cast_node.set_input_node(0, [Expand_first_input_now]) + +Equal = mod.get_nodes_by_optype("Equal") +for equal_node in Equal: + now_equal = mod.get_node(equal_node) + now_ends = mod.get_node(now_equal.input_name[1]) + if now_ends.op_type in ("Initializer", "Constant") and now_ends.const_value.dtype == "int64": + print("now_ends.dtype:", now_ends.const_value.dtype) + val = now_ends.const_value.astype("int32") + now_ends.set_const_value(val) + +Expand_lists = ["Expand_20"] +for expand_node in Expand_lists: + now_expand = mod.get_node(expand_node) + cast_node = mod.add_new_node(now_expand.name + "_cast", "Cast", + {"to": (AT.INT, 6) + }) + Expand_first_input_now = mod.get_node(now_expand.input_name[0]) + now_expand.set_input_node(0, [cast_node]) + cast_node.set_input_node(0, [Expand_first_input_now]) + +not_change_cast = [] +Range_lists = mod.get_nodes_by_optype("Range") +for range_node in Range_lists: + now_expand = mod.get_node(range_node) + Expand_first_input_now = mod.get_node(now_expand.input_name[1]) + not_change_cast.append(Expand_first_input_now.name) + +to = 6 +Cast = mod.get_nodes_by_optype("Cast") +for i in range(len(Cast)): + now_Cast = mod.get_node(Cast[i]) + if now_Cast.get_attr("to", AT.INT) == 7 and now_Cast.name not in not_change_cast: + now_Cast.set_attr({"to": (AT.INT, to)}) +mod.save_new_model("no_flash_encoder_revise.onnx") diff --git a/ACL_PyTorch/contrib/audio/WeNet/env.sh b/ACL_PyTorch/contrib/audio/WeNet/env.sh new file mode 100644 index 0000000000000000000000000000000000000000..33243beff29297566d7412a08452ad6657a8d6fa --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/env.sh @@ -0,0 +1,9 @@ +export install_path=/usr/local/Ascend/ascend-toolkit/latest +export PATH=${install_path}/atc/bin:${install_path}/bin:${install_path}/atc/ccec_compiler/bin:$PATH +export LD_LIBRARY_PATH=${install_path}/lib64:${install_path}/atc/lib64:${install_path}/acllib/lib64:${install_path}/compiler/lib64/plugin/opskernel:${install_path}/compiler/lib64/plugin/nnengine:$LD_LIBRARY_PATH +export PYTHONPATH=${install_path}/latest/python/site-packages:${install_path}/opp/op_impl/built-in/ai_core/tbe:${install_path}/atc/python/site-packages:${install_path}/pyACL/python/site-packages/acl:$PYTHONPATH +export ASCEND_AICPU_PATH=${install_path} +export ASCEND_OPP_PATH=${install_path}/opp +export TOOLCHAIN_HOME=${install_path}/toolkit +export ASCEND_AUTOML_PATH=${install_path}/tools +export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/:${LD_LIBRARY_PATH} diff --git a/ACL_PyTorch/contrib/audio/WeNet/export_onnx.py b/ACL_PyTorch/contrib/audio/WeNet/export_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b4cff965da7a4ed61a199a10ebfa256a56a6909a --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/export_onnx.py @@ -0,0 +1,139 @@ +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. 
+# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + + +from __future__ import print_function + +import argparse +import os + +import torch +import onnx, onnxruntime +import yaml +import numpy as np + +from wenet.transformer.asr_model import init_asr_model +from wenet.transformer.decoder import TransformerDecoder, BiTransformerDecoder +from wenet.utils.checkpoint import load_checkpoint + + +def to_numpy(xx): + return xx.detach().numpy() + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_onnx_file', required=True, help='output onnx file') + args = parser.parse_args() + # No need gpu for model export + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + model = init_asr_model(configs) + print(model) + + load_checkpoint(model, args.checkpoint) + # Export jit torch script model + + model.eval() + + #export the none flash model + encoder = model.encoder + xs = torch.randn(1, 131, 80, requires_grad=False) + xs_lens = torch.tensor([131], dtype=torch.int32) + onnx_encoder_path = os.path.join(args.output_onnx_file, 'no_flash_encoder.onnx') + torch.onnx.export(encoder, + (xs, xs_lens), + onnx_encoder_path, + export_params=True, + opset_version=11, + do_constant_folding=True, + input_names=['xs_input', 'xs_input_lens'], + output_names=['xs_output', 'masks_output'], + dynamic_axes={'xs_input': [1], 'xs_input_lens': [0], + 'xs_output': [1], 'masks_output': [2]}, + verbose=True + ) + onnx_model = onnx.load(onnx_encoder_path) + onnx.checker.check_model(onnx_model) + print("encoder onnx_model check pass!") + + ort_session = onnxruntime.InferenceSession(onnx_encoder_path) + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(xs), + 
ort_session.get_inputs()[1].name: to_numpy(xs_lens), + } + ort_outs = ort_session.run(None, ort_inputs) + y1, y2 = encoder(xs, xs_lens) + print("Exported no flash encoder model has been tested with ONNXRuntime, and the result looks good!") + + + #export decoder onnx + + decoder = model.decoder + decoder.set_onnx_mode(True) + onnx_decoder_path = os.path.join(args.output_onnx_file, 'decoder.onnx') + memory = torch.randn(10, 131, 256) + memory_mask = torch.ones(10, 1, 131).bool() + ys_in_pad = torch.randint(0, 4232, (10, 50)).long() + ys_in_lens = torch.tensor([13, 13, 13, 13, 13, 13, 13, 13, 50, 13], dtype=torch.int32) + r_ys_in_pad = torch.randint(0, 4232, (10, 50)).long() + + if isinstance(decoder, TransformerDecoder): + torch.onnx.export(decoder, + (memory, memory_mask, ys_in_pad, ys_in_lens), + onnx_decoder_path, + export_params=True, + opset_version=12, + do_constant_folding=True, + input_names=['memory', 'memory_mask', 'ys_in_pad', 'ys_in_lens'], + output_names=['l_x', 'r_x'], + dynamic_axes={'memory': [1], 'memory_mask':[2], 'ys_in_pad':[1], + 'ys_in_lens': [0]}, + verbose=True + ) + elif isinstance(decoder, BiTransformerDecoder): + print("BI mode") + torch.onnx.export(decoder, + (memory, memory_mask, ys_in_pad, ys_in_lens, r_ys_in_pad), + onnx_decoder_path, + export_params=True, + opset_version=11, + do_constant_folding=True, + input_names=['memory', 'memory_mask', 'ys_in_pad', 'ys_in_lens', 'r_ys_in_pad'], + output_names=['l_x', 'r_x', 'olens'], + dynamic_axes={'memory': [1], 'memory_mask':[2], 'ys_in_pad':[1], + 'ys_in_lens': [0], 'r_ys_in_pad':[1]}, + verbose=True + ) + diff --git a/ACL_PyTorch/contrib/audio/WeNet/export_onnx.sh b/ACL_PyTorch/contrib/audio/WeNet/export_onnx.sh new file mode 100644 index 0000000000000000000000000000000000000000..bebe2a07ef04b7fb044fdb57dcac8bc1337e3894 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/export_onnx.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +. ./path.sh || exit 1; + +yaml_path=$1 +decode_checkpoint=$2 + +mkdir onnx +python3 wenet/bin/export_onnx.py \ + --config $yaml_path \ + --checkpoint $decode_checkpoint \ + --output_onnx_file onnx diff --git a/ACL_PyTorch/contrib/audio/WeNet/gpu_fps.diff b/ACL_PyTorch/contrib/audio/WeNet/gpu_fps.diff new file mode 100644 index 0000000000000000000000000000000000000000..be0582d4d4769d21dafe511e7690ccd72b4eea39 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/gpu_fps.diff @@ -0,0 +1,163 @@ +diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py +--- a/wenet/transformer/asr_model.py ++++ b/wenet/transformer/asr_model.py +@@ -168,13 +168,13 @@ + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: +- encoder_out, encoder_mask = self.encoder( ++ encoder_out, encoder_mask, encoder_t = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) +- return encoder_out, encoder_mask ++ return encoder_out, encoder_mask, encoder_t + + def recognize( + self, +@@ -361,7 +361,7 @@ + assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size + # 1. 
Encoder forward and get CTC score +- encoder_out, encoder_mask = self._forward_encoder( ++ encoder_out, encoder_mask, encoder_t = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) +@@ -409,7 +409,7 @@ + reverse=True) + cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] +- return hyps, encoder_out ++ return hyps, encoder_out, encoder_t + + def ctc_prefix_beam_search( + self, +@@ -485,7 +485,7 @@ + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size +- hyps, encoder_out = self._ctc_prefix_beam_search( ++ hyps, encoder_out, encoder_t = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + +@@ -510,7 +510,7 @@ + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) +- decoder_out, r_decoder_out, _ = self.decoder( ++ decoder_out, r_decoder_out, _, decoder_t = self.decoder( + encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, + reverse_weight) # (beam_size, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) +@@ -539,7 +539,7 @@ + if score > best_score: + best_score = score + best_index = i +- return hyps[best_index][0] ++ return hyps[best_index][0], encoder_t+decoder_t + + @torch.jit.export + def subsampling_rate(self) -> int: + +diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py +--- a/wenet/transformer/decoder.py ++++ b/wenet/transformer/decoder.py +@@ -13,6 +13,7 @@ + from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward + from wenet.utils.mask import (subsequent_mask, make_pad_mask) + ++import time + + class TransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. 
+@@ -252,13 +253,14 @@ + if use_output_layer is True, + olens: (batch, ) + """ ++ st = time.time() + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = torch.tensor(0.0) + if reverse_weight > 0.0: + r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, + ys_in_lens) +- return l_x, r_x, olens ++ return l_x, r_x, olens, time.time()-st + + def forward_one_step( + self, + +diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py +--- a/wenet/transformer/encoder.py ++++ b/wenet/transformer/encoder.py +@@ -26,6 +26,7 @@ + from wenet.utils.mask import make_pad_mask + from wenet.utils.mask import add_optional_chunk_mask + ++import time + + class BaseEncoder(torch.nn.Module): + def __init__( +@@ -146,6 +147,7 @@ + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ ++ st = time.time() + masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) +@@ -164,7 +166,7 @@ + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later +- return xs, masks ++ return xs, masks, time.time()-st + + def forward_chunk( + self, + +diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py +--- a/wenet/bin/recognize.py ++++ b/wenet/bin/recognize.py +@@ -139,8 +139,11 @@ + model = model.to(device) + + model.eval() ++ total_t = 0 ++ total_batch = 0 + with torch.no_grad(), open(args.result_file, 'w') as fout: + for batch_idx, batch in enumerate(test_data_loader): ++ total_batch += 1 + keys, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) +@@ -177,7 +180,7 @@ + hyps = [hyp] + elif args.mode == 'attention_rescoring': + assert (feats.size(0) == 1) +- hyp = model.attention_rescoring( ++ hyp, exe_t = model.attention_rescoring( + feats, + feats_lengths, + args.beam_size, +@@ -187,6 +190,8 @@ + simulate_streaming=args.simulate_streaming, + reverse_weight=args.reverse_weight) + hyps = [hyp] ++ total_t += exe_t ++ print(exe_t) + for i, key in enumerate(keys): + content = '' + for w in hyps[i]: +@@ -195,3 +200,7 @@ + content += char_dict[w] + logging.info('{} {}'.format(key, content)) + fout.write('{} {}\n'.format(key, content)) ++ print("mean_fps: ", 1/(total_t/total_batch)) ++ print("mean_time: ", total_t/total_batch) ++ fout.write("mean_time: "+str(total_t/total_batch)) ++ fout.write("mean_fps: "+str(1/(total_t/total_batch))) diff --git a/ACL_PyTorch/contrib/audio/WeNet/requirements.txt b/ACL_PyTorch/contrib/audio/WeNet/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a4e8c3570cf18cb61706d5dcf65e1c0e918168bc --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/requirements.txt @@ -0,0 +1,9 @@ +torch==1.9.0 +onnx==1.10.0 +onnxruntime==1.8.1 +torchaudio==0.9.0 +sympy +pyyaml +decorator +typeguard +pillow \ No newline at end of file diff --git a/ACL_PyTorch/contrib/audio/WeNet/run_static.sh b/ACL_PyTorch/contrib/audio/WeNet/run_static.sh new file mode 100644 index 0000000000000000000000000000000000000000..b1360409285e69454c8d4fd8b7051ccc8dbba5b0 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/run_static.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +. 
./path.sh || exit 1; + +# Use this to control how many gpu you use, It's 1-gpu training if you specify +# just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +# The NCCL_SOCKET_IFNAME variable specifies which IP interface to use for nccl +# communication. More details can be found in +# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html +# export NCCL_SOCKET_IFNAME=ens4f1 +export NCCL_DEBUG=INFO +stage=5 # start from 0 if you need to start from data preparation +stop_stage=5 +# The num of nodes or machines used for multi-machine training +# Default 1 for single machine/node +# NFS will be needed if you want run multi-machine training +num_nodes=1 +# The rank of each node or machine, range from 0 to num_nodes -1 +# The first node/machine sets node_rank 0, the second one sets node_rank 1 +# the third one set node_rank 2, and so on. Default 0 +node_rank=0 +# data +data=/export/data/asr-data/OpenSLR/33/ +data_url=www.openslr.org/resources/33 + +nj=16 +feat_dir=raw_wav +dict=data/dict/lang_char.txt + +train_set=train +# Optional train_config +# 1. conf/train_transformer.yaml: Standard transformer +# 2. conf/train_conformer.yaml: Standard conformer +# 3. conf/train_unified_conformer.yaml: Unified dynamic chunk causal conformer +# 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer +# 5. conf/train_conformer_no_pos.yaml: Conformer without relative positional encoding +# 6. conf/train_u2++_conformer.yaml: U2++ conformer +# 7. conf/train_u2++_transformer.yaml: U2++ transformer +train_config=conf/train_conformer.yaml +cmvn=true +dir=exp/conformer +checkpoint= + +# use average_checkpoint will get better result +average_checkpoint=false +decode_checkpoint=$dir/final.pt +average_num=30 +decode_modes="attention_rescoring" + +. tools/parse_options.sh || exit 1; + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # Test model, please specify the model you want to test by --checkpoint + if [ ${average_checkpoint} == true ]; then + decode_checkpoint=$dir/avg_${average_num}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python3 wenet/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path $dir \ + --num ${average_num} \ + --val_best + fi + # Specify decoding_chunk_size if it's a unified dynamic chunk trained model + # -1 for full chunk + decoding_chunk_size= + ctc_weight=0.3 + reverse_weight=0.3 + for mode in ${decode_modes}; do + { + test_dir=$dir/test_${mode} + mkdir -p $test_dir + python3 wenet/bin/static.py --gpu -1 \ + --mode $mode \ + --config $dir/train.yaml \ + --test_data $feat_dir/test/format.data \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --dict $dict \ + --ctc_weight $ctc_weight \ + --reverse_weight $reverse_weight \ + --result_file $test_dir/text \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + python3 tools/compute-wer.py --char=1 --v=1 \ + $feat_dir/test/text $test_dir/text > $test_dir/wer + } & + done + wait + +fi + diff --git a/ACL_PyTorch/contrib/audio/WeNet/slice_helper.py b/ACL_PyTorch/contrib/audio/WeNet/slice_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..737a110713d50b86ba3fc4db0a7e7527714bdb79 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/slice_helper.py @@ -0,0 +1,67 @@ +# BSD 3-Clause License +# +# Copyright (c) 2017 xxxx +# All rights reserved. 
+# Copyright 2021 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + +import torch + +@torch.jit.script +def slice_helper(x, offset): + return x[:, -offset: , : ] + +@torch.jit.script +def slice_helper2(x: torch.Tensor, start: torch.Tensor, end: torch.Tensor): + start = start.long() + end = end.long() + return x[:, start:end] + +@torch.jit.script +def slice_helper3(x, start): + return x[:, start:] + +@torch.jit.script +def get_item(x): + item = x.detach().item() + output = torch.tensor(item) + return output + +@torch.jit.script +def get_next_cache_start(required_cache_size: torch.Tensor, xs: torch.Tensor): + next_cache_start = 0 + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = xs.size(1) + else: + if xs.size(1) - required_cache_size < 0: + next_cache_start = 0 + else: + next_cache_start = xs.size(1) - required_cache_size + return torch.tensor(next_cache_start, dtype=torch.int64) diff --git a/ACL_PyTorch/contrib/audio/WeNet/static.py b/ACL_PyTorch/contrib/audio/WeNet/static.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1836ba8901bd3a77c00efbab4bddaa246fad5b --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/static.py @@ -0,0 +1,203 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
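+
+# Summary (added): static.py mirrors wenet/bin/recognize.py for NPU offline inference; in
+# attention_rescoring mode it accumulates the per-utterance execution time returned by the
+# patched model and appends an FPS line to the result file.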
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader + +from wenet.dataset.dataset import AudioDataset, CollateFunc +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--beam_size', + type=int, + default=10, + help='beam size for search') + parser.add_argument('--penalty', + type=float, + default=0.0, + help='length penalty') + parser.add_argument('--result_file', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=16, + help='asr result file') + parser.add_argument('--mode', + choices=[ + 'attention', 'ctc_greedy_search', + 'ctc_prefix_beam_search', 'attention_rescoring' + ], + default='attention', + help='decoding mode') + parser.add_argument('--ctc_weight', + type=float, + default=0.0, + help='ctc weight for attention rescoring decode mode') + parser.add_argument('--decoding_chunk_size', + type=int, + default=-1, + help='''decoding chunk size, + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here''') + parser.add_argument('--num_decoding_left_chunks', + type=int, + default=-1, + help='number of left chunks for decoding') + parser.add_argument('--simulate_streaming', + action='store_true', + help='simulate streaming inference') + parser.add_argument('--reverse_weight', + type=float, + default=0.0, + help='''right to left weight for attention rescoring + decode mode''') + args = parser.parse_args() + print(args) + total_t = 0 + total_batch = 0 + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring' + ] and args.batch_size > 1: + logging.fatal( + 'decoding mode {} must be running with batch_size == 1'.format( + args.mode)) + sys.exit(1) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + raw_wav = configs['raw_wav'] + # Init dataset and data loader + # Init dataset and data loader + test_collate_conf = copy.deepcopy(configs['collate_conf']) + test_collate_conf['spec_aug'] = False + test_collate_conf['spec_sub'] = False + test_collate_conf['feature_dither'] = False + test_collate_conf['speed_perturb'] = False + if raw_wav: + test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 + test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav) + dataset_conf = configs.get('dataset_conf', {}) + dataset_conf['batch_size'] = args.batch_size + dataset_conf['batch_type'] = 'static' + dataset_conf['sort'] = False + test_dataset = AudioDataset(args.test_data, + **dataset_conf, + raw_wav=raw_wav) + test_data_loader = DataLoader(test_dataset, + collate_fn=test_collate_func, + shuffle=False, + batch_size=1, + num_workers=0) + + # Init asr model from configs + model = 
init_asr_model(configs) + + # Load dict + char_dict = {} + with open(args.dict, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.result_file, 'w') as fout: + for batch_idx, batch in enumerate(test_data_loader): + total_batch += 1 + print("batch_idx:", batch_idx) + keys, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + target_lengths = target_lengths.to(device) + if args.mode == 'attention': + hyps = model.recognize( + feats, + feats_lengths, + beam_size=args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + elif args.mode == 'ctc_greedy_search': + hyps = model.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + # ctc_prefix_beam_search and attention_rescoring only return one + # result in List[int], change it to List[List[int]] for compatible + # with other batch decoding mode + elif args.mode == 'ctc_prefix_beam_search': + assert (feats.size(0) == 1) + hyp = model.ctc_prefix_beam_search( + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + hyps = [hyp] + elif args.mode == 'attention_rescoring': + assert (feats.size(0) == 1) + hyp = model.attention_rescoring( + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + ctc_weight=args.ctc_weight, + simulate_streaming=args.simulate_streaming, + reverse_weight=args.reverse_weight) + total_t += hyp[1] + hyps = [hyp] + for i, key in enumerate(keys): + content = '' + for w in hyps[i][0]: + if w == eos: + break + content += char_dict[w] + logging.info('{} {}'.format(key, content)) + fout.write('{} {}\n'.format(key, content)) + fout.write('FPS:{}\n'.format(1000/(total_t/total_batch))) diff --git a/ACL_PyTorch/contrib/audio/WeNet/static_decoder.sh b/ACL_PyTorch/contrib/audio/WeNet/static_decoder.sh new file mode 100644 index 0000000000000000000000000000000000000000..94d3c789519898a6f6f397a10b42973d4007752c --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/static_decoder.sh @@ -0,0 +1,13 @@ +atc --model=decoder_final.onnx --framework=5 --output=decoder_fendang --input_format=ND \ +--input_shape="memory:10,-1,256;memory_mask:10,1,-1;ys_in_pad:10,-1;ys_in_lens:10;r_ys_in_pad:10,-1" --log=error \ +--dynamic_dims="96,96,3,3;96,96,4,4;96,96,5,5;96,96,6,6;96,96,7,7;96,96,8,8;96,96,9,9;96,96,10,10;96,96,11,11;\ +96,96,12,12;96,96,13,13;96,96,14,14;96,96,15,15;96,96,16,16;96,96,17,17;96,96,18,18;96,96,19,19;96,96,20,20;\ +96,96,21,21;96,96,22,22;96,96,23,23;144,144,6,6;144,144,7,7;144,144,8,8;144,144,9,9;144,144,10,10;144,144,11,11;\ +144,144,12,12;144,144,13,13;144,144,14,14;144,144,15,15;144,144,16,16;144,144,17,17;144,144,18,18;144,144,19,19;\ 
+144,144,20,20;144,144,21,21;144,144,22,22;144,144,23,23;144,144,24,24;144,144,25,25;144,144,26,26;144,144,27,27;\ +144,144,28,28;384,384,9,9;384,384,10,10;384,384,11,11;384,384,12,12;384,384,13,13;384,384,14,14;384,384,15,15;\ +384,384,16,16;384,384,17,17;384,384,18,18;384,384,19,19;384,384,20,20;384,384,21,21;384,384,22,22;384,384,23,23;\ +384,384,24,24;384,384,25,25;384,384,26,26;384,384,27,27;384,384,28,28;384,384,29,29;384,384,30,30;384,384,31,31;\ +384,384,32,32;384,384,33,33;384,384,34,34;384,384,35,35;384,384,36,36;384,384,37,37;384,384,38,38;384,384,39,39;384,384,40,40;384,384,41,41;" \ +--soc_version=Ascend310 + diff --git a/ACL_PyTorch/contrib/audio/WeNet/static_encoder.sh b/ACL_PyTorch/contrib/audio/WeNet/static_encoder.sh new file mode 100644 index 0000000000000000000000000000000000000000..e1b635c0d946a6bdc9a5acdb43ada130cbac3d04 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/static_encoder.sh @@ -0,0 +1,4 @@ +atc --model=no_flash_encoder_revise.onnx --framework=5 --output=encoder_fendang_262_1478_static --input_format=ND \ +--input_shape="xs_input:1,-1,80;xs_input_lens:1" --log=error \ +--dynamic_dims="262;326;390;454;518;582;646;710;774;838;902;966;1028;1284;1478" \ +--soc_version=Ascend310 diff --git a/ACL_PyTorch/contrib/audio/WeNet/wenet_onnx.diff b/ACL_PyTorch/contrib/audio/WeNet/wenet_onnx.diff new file mode 100644 index 0000000000000000000000000000000000000000..9d954caedcedd69e9b6034db160c8fbbb8a53737 --- /dev/null +++ b/ACL_PyTorch/contrib/audio/WeNet/wenet_onnx.diff @@ -0,0 +1,794 @@ +diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py +index 73990fa..68c8299 100644 +--- a/wenet/transformer/asr_model.py ++++ b/wenet/transformer/asr_model.py +@@ -245,7 +245,7 @@ class ASRModel(torch.nn.Module): + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) +- # 2.3 Second beam prune: select topk score with history ++ # 2.3 Seconde beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) +@@ -570,13 +570,12 @@ class ASRModel(torch.nn.Module): + def forward_encoder_chunk( + self, + xs: torch.Tensor, +- offset: int, +- required_cache_size: int, ++ offset: torch.Tensor, ++ required_cache_size: torch.Tensor, + subsampling_cache: Optional[torch.Tensor] = None, +- elayers_output_cache: Optional[List[torch.Tensor]] = None, +- conformer_cnn_cache: Optional[List[torch.Tensor]] = None, +- ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], +- List[torch.Tensor]]: ++ elayers_output_cache: Optional[torch.Tensor] = None, ++ conformer_cnn_cache: Optional[torch.Tensor] = None, ++ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. 
+ +@@ -675,6 +674,10 @@ class ASRModel(torch.nn.Module): + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + return decoder_out, r_decoder_out + ++ @torch.jit.export ++ def test(self,) -> str: ++ return "test" ++ + + def init_asr_model(configs): + if configs['cmvn_file'] is not None: +diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py +index f41f7e4..40c1a57 100644 +--- a/wenet/transformer/decoder.py ++++ b/wenet/transformer/decoder.py +@@ -57,8 +57,7 @@ class TransformerDecoder(torch.nn.Module): + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), +- PositionalEncoding(attention_dim, positional_dropout_rate), +- ) ++ PositionalEncoding(attention_dim, positional_dropout_rate)) + else: + raise ValueError(f"only 'embed' is supported: {input_layer}") + +@@ -81,6 +80,10 @@ class TransformerDecoder(torch.nn.Module): + concat_after, + ) for _ in range(self.num_blocks) + ]) ++ self.onnx_mode = False ++ ++ def set_onnx_mode(self, onnx_mode=False): ++ self.onnx_mode = onnx_mode + + def forward( + self, +@@ -111,13 +114,15 @@ class TransformerDecoder(torch.nn.Module): + tgt = ys_in_pad + + # tgt_mask: (B, 1, L) +- tgt_mask = (~make_pad_mask(ys_in_lens).unsqueeze(1)).to(tgt.device) ++ tgt_mask = (~make_pad_mask(ys_in_lens, ys_in_pad).unsqueeze(1)).to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), + device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) +- tgt_mask = tgt_mask & m +- x, _ = self.embed(tgt) ++ # tgt_mask = tgt_mask & m ++ tgt_mask = torch.mul(tgt_mask, m) ++ x = self.embed[0](tgt) ++ x, _ = self.embed[1](x, onnx_mode=self.onnx_mode) + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, + memory_mask) +@@ -225,6 +230,13 @@ class BiTransformerDecoder(torch.nn.Module): + self_attention_dropout_rate, src_attention_dropout_rate, + input_layer, use_output_layer, normalize_before, concat_after) + ++ self.onnx_mode = False ++ ++ def set_onnx_mode(self, onnx_mode=False): ++ self.onnx_mode = onnx_mode ++ self.left_decoder.set_onnx_mode(onnx_mode) ++ self.right_decoder.set_onnx_mode(onnx_mode) ++ + def forward( + self, + memory: torch.Tensor, +@@ -252,6 +264,7 @@ class BiTransformerDecoder(torch.nn.Module): + if use_output_layer is True, + olens: (batch, ) + """ ++ reverse_weight = 0.3 + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = torch.tensor(0.0) +diff --git a/wenet/transformer/decoder_layer.py b/wenet/transformer/decoder_layer.py +index 25bb281..59dd174 100644 +--- a/wenet/transformer/decoder_layer.py ++++ b/wenet/transformer/decoder_layer.py +@@ -17,7 +17,7 @@ class DecoderLayer(nn.Module): + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. +- src_attn (torch.nn.Module): Inter-attention module instance. ++ src_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. 
+@@ -61,7 +61,8 @@ class DecoderLayer(nn.Module): + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor, +- cache: Optional[torch.Tensor] = None ++ cache: Optional[torch.Tensor] = None, ++ onnx_mode: Optional[bool] = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute decoded features. + +diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py +index a47afd9..0a6794c 100644 +--- a/wenet/transformer/embedding.py ++++ b/wenet/transformer/embedding.py +@@ -9,6 +9,7 @@ import math + from typing import Tuple + + import torch ++from wenet.transformer.slice_helper import slice_helper2 + + + class PositionalEncoding(torch.nn.Module): +@@ -45,7 +46,8 @@ class PositionalEncoding(torch.nn.Module): + + def forward(self, + x: torch.Tensor, +- offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: ++ offset: torch.Tensor = torch.tensor(0), ++ onnx_mode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: +@@ -56,13 +58,21 @@ class PositionalEncoding(torch.nn.Module): + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ +- assert offset + x.size(1) < self.max_len ++ # assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) +- pos_emb = self.pe[:, offset:offset + x.size(1)] ++ # pos_emb = self.pe[:, offset:offset + x.size(1)] ++ if onnx_mode: ++ pos_emb = slice_helper2(self.pe, offset, offset + x.size(1)) ++ else: ++ pos_emb = self.pe[:, offset:offset + x.size(1)] + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + +- def position_encoding(self, offset: int, size: int) -> torch.Tensor: ++ def position_encoding(self, ++ offset: torch.Tensor, ++ size: torch.Tensor, ++ onnx_mode: bool = False, ++ ) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! +@@ -79,7 +89,12 @@ class PositionalEncoding(torch.nn.Module): + torch.Tensor: Corresponding encoding + """ + assert offset + size < self.max_len +- return self.dropout(self.pe[:, offset:offset + size]) ++ if onnx_mode: ++ # pe = torch.cat([self.pe[:, [0]], slice_helper2(self.pe, offset, offset + size - 1)], dim=1) ++ pe = slice_helper2(self.pe, offset, offset + size) ++ else: ++ pe = self.pe[:, offset:offset + size] ++ return self.dropout(pe) + + + class RelPositionalEncoding(PositionalEncoding): +@@ -96,7 +111,8 @@ class RelPositionalEncoding(PositionalEncoding): + + def forward(self, + x: torch.Tensor, +- offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: ++ offset: torch.Tensor, ++ onnx_mode: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). +@@ -104,10 +120,16 @@ class RelPositionalEncoding(PositionalEncoding): + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). 
+ """ +- assert offset + x.size(1) < self.max_len ++ # assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) + x = x * self.xscale +- pos_emb = self.pe[:, offset:offset + x.size(1)] ++ if onnx_mode: ++ # end = offset.item() + x.size(1) ++ # pos_emb = torch.index_select(self.pe, 1, torch.tensor(range(x.size(1)))) ++ pos_emb = slice_helper2(self.pe, offset, offset + x.size(1)) ++ # pos_emb = slice_helper3(pos_emb, x.size(1)) ++ else: ++ pos_emb = self.pe[:, offset:offset + x.size(1)] + return self.dropout(x), self.dropout(pos_emb) + + +diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py +index e342ed4..9b4f968 100644 +--- a/wenet/transformer/encoder.py ++++ b/wenet/transformer/encoder.py +@@ -6,6 +6,8 @@ + """Encoder definition.""" + from typing import Tuple, List, Optional + ++import numpy as np ++import onnxruntime + import torch + from typeguard import check_argument_types + +@@ -18,6 +20,7 @@ from wenet.transformer.embedding import NoPositionalEncoding + from wenet.transformer.encoder_layer import TransformerEncoderLayer + from wenet.transformer.encoder_layer import ConformerEncoderLayer + from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward ++from wenet.transformer.slice_helper import slice_helper3, get_next_cache_start + from wenet.transformer.subsampling import Conv2dSubsampling4 + from wenet.transformer.subsampling import Conv2dSubsampling6 + from wenet.transformer.subsampling import Conv2dSubsampling8 +@@ -26,6 +29,8 @@ from wenet.utils.common import get_activation + from wenet.utils.mask import make_pad_mask + from wenet.utils.mask import add_optional_chunk_mask + ++def to_numpy(x): ++ return x.detach().numpy() + + class BaseEncoder(torch.nn.Module): + def __init__( +@@ -116,10 +121,14 @@ class BaseEncoder(torch.nn.Module): + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk ++ self.onnx_mode = False + + def output_size(self) -> int: + return self._output_size + ++ def set_onnx_mode(self, onnx_mode=False): ++ self.onnx_mode = onnx_mode ++ + def forward( + self, + xs: torch.Tensor, +@@ -130,7 +139,7 @@ class BaseEncoder(torch.nn.Module): + """Embed positions in tensor. + + Args: +- xs: padded input tensor (B, T, D) ++ xs: padded input tensor (B, L, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. 
+@@ -141,16 +150,18 @@ class BaseEncoder(torch.nn.Module): + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: +- encoder output tensor xs, and subsampled masks +- xs: padded output tensor (B, T' ~= T/subsample_rate, D) +- masks: torch.Tensor batch padding mask after subsample +- (B, 1, T' ~= T/subsample_rate) ++ encoder output tensor, lens and mask + """ +- masks = ~make_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T) ++ decoding_chunk_size = 1 ++ num_decoding_left_chunks = 1 ++ self.use_dynamic_chunk = False ++ self.use_dynamic_left_chunk = False ++ self.static_chunk_size = 0 ++ masks = ~make_pad_mask(xs_lens, xs).unsqueeze(1) # (B, 1, L) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) +- mask_pad = masks # (B, 1, T/subsample_rate) ++ mask_pad = masks + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, +@@ -169,13 +180,12 @@ class BaseEncoder(torch.nn.Module): + def forward_chunk( + self, + xs: torch.Tensor, +- offset: int, +- required_cache_size: int, ++ offset_tensor: torch.Tensor = torch.tensor(0), ++ required_cache_size_tensor: torch.Tensor = torch.tensor(0), + subsampling_cache: Optional[torch.Tensor] = None, +- elayers_output_cache: Optional[List[torch.Tensor]] = None, +- conformer_cnn_cache: Optional[List[torch.Tensor]] = None, +- ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], +- List[torch.Tensor]]: ++ elayers_output_cache: Optional[torch.Tensor] = None, ++ conformer_cnn_cache: Optional[torch.Tensor] = None, ++ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: +@@ -199,6 +209,7 @@ class BaseEncoder(torch.nn.Module): + List[torch.Tensor]: conformer cnn cache + + """ ++ required_cache_size_tensor = torch.tensor(-1) + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, +@@ -208,30 +219,53 @@ class BaseEncoder(torch.nn.Module): + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) +- xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) ++ # if self.onnx_mode: ++ # offset_tensor = offset_tensor - torch.tensor(1) ++ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset_tensor, self.onnx_mode) + if subsampling_cache is not None: + cache_size = subsampling_cache.size(1) + xs = torch.cat((subsampling_cache, xs), dim=1) + else: + cache_size = 0 +- pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1)) +- if required_cache_size < 0: +- next_cache_start = 0 +- elif required_cache_size == 0: +- next_cache_start = xs.size(1) ++ # if self.onnx_mode: ++ # cache_size = cache_size - 1 ++ # if self.onnx_mode: ++ # # subsampling_cache append dummy var, remove it here ++ # xs = xs[:, 1:, :] ++ # cache_size = cache_size - 1 ++ if isinstance(xs.size(1), int): ++ xs_size_1 = torch.tensor(xs.size(1)) + else: +- next_cache_start = max(xs.size(1) - required_cache_size, 0) +- r_subsampling_cache = xs[:, next_cache_start:, :] ++ xs_size_1 = xs.size(1).clone().detach() ++ pos_emb = self.embed.position_encoding(offset_tensor - cache_size, ++ xs_size_1, ++ self.onnx_mode) ++ next_cache_start = get_next_cache_start(required_cache_size_tensor, xs) ++ r_subsampling_cache = slice_helper3(xs, next_cache_start) ++ # if self.onnx_mode: ++ # next_cache_start_1 = get_next_cache_start(required_cache_size_tensor, xs) ++ # r_subsampling_cache = slice_helper3(xs, next_cache_start_1) ++ # else: ++ # required_cache_size = 
required_cache_size_tensor.detach().item() ++ # if required_cache_size < 0: ++ # next_cache_start = 0 ++ # elif required_cache_size == 0: ++ # next_cache_start = xs.size(1) ++ # else: ++ # next_cache_start = max(xs.size(1) - required_cache_size, 0) ++ # r_subsampling_cache = xs[:, next_cache_start:, :] + # Real mask for transformer/conformer layers + masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool) + masks = masks.unsqueeze(1) +- r_elayers_output_cache = [] +- r_conformer_cnn_cache = [] ++ r_elayers_output_cache = None ++ r_conformer_cnn_cache = None + for i, layer in enumerate(self.encoders): + if elayers_output_cache is None: + attn_cache = None + else: + attn_cache = elayers_output_cache[i] ++ # if self.onnx_mode and attn_cache is not None: ++ # attn_cache = attn_cache[:, 1:, :] + if conformer_cnn_cache is None: + cnn_cache = None + else: +@@ -240,13 +274,32 @@ class BaseEncoder(torch.nn.Module): + masks, + pos_emb, + output_cache=attn_cache, +- cnn_cache=cnn_cache) +- r_elayers_output_cache.append(xs[:, next_cache_start:, :]) +- r_conformer_cnn_cache.append(new_cnn_cache) ++ cnn_cache=cnn_cache, ++ onnx_mode=self.onnx_mode) ++ if self.onnx_mode: ++ layer_output_cache = slice_helper3(xs, next_cache_start) ++ else: ++ layer_output_cache = xs[:, next_cache_start:, :] ++ if i == 0: ++ r_elayers_output_cache = layer_output_cache.unsqueeze(0) ++ r_conformer_cnn_cache = new_cnn_cache.unsqueeze(0) ++ else: ++ # r_elayers_output_cache.append(xs[:, next_cache_start:, :]) ++ r_elayers_output_cache = torch.cat((r_elayers_output_cache, layer_output_cache.unsqueeze(0)), 0) ++ # r_conformer_cnn_cache.append(new_cnn_cache) ++ r_conformer_cnn_cache = torch.cat((r_conformer_cnn_cache, new_cnn_cache.unsqueeze(0)), 0) + if self.normalize_before: + xs = self.after_norm(xs) +- +- return (xs[:, cache_size:, :], r_subsampling_cache, ++ if self.onnx_mode: ++ cache_size = cache_size - 1 ++ if isinstance(cache_size, int): ++ cache_size_1 = torch.tensor(cache_size) ++ else: ++ cache_size_1 = cache_size.clone().detach() ++ output = slice_helper3(xs, cache_size_1) ++ else: ++ output = xs[:, cache_size:, :] ++ return (output, r_subsampling_cache, + r_elayers_output_cache, r_conformer_cnn_cache) + + def forward_chunk_by_chunk( +@@ -290,24 +343,54 @@ class BaseEncoder(torch.nn.Module): + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + subsampling_cache: Optional[torch.Tensor] = None +- elayers_output_cache: Optional[List[torch.Tensor]] = None +- conformer_cnn_cache: Optional[List[torch.Tensor]] = None ++ elayers_output_cache: Optional[torch.Tensor] = None ++ conformer_cnn_cache: Optional[torch.Tensor] = None + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks ++ print("required_cache_size:", required_cache_size) ++ encoder_session = onnxruntime.InferenceSession("onnx/encoder.onnx") ++ ++ subsampling_cache_onnx = torch.zeros(1, 1, 256, requires_grad=False) ++ elayers_output_cache_onnx = torch.zeros(12, 1, 1, 256, requires_grad=False) ++ conformer_cnn_cache_onnx = torch.zeros(12, 1, 256, 7, requires_grad=False) + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] ++ ++ if offset > 0: ++ offset = offset - 1 + (y, subsampling_cache, elayers_output_cache, +- conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset, +- required_cache_size, ++ conformer_cnn_cache) = 
self.forward_chunk(chunk_xs, torch.tensor(offset), ++ torch.tensor(required_cache_size), + subsampling_cache, + elayers_output_cache, + conformer_cnn_cache) +- outputs.append(y) ++ ++ offset = offset + 1 ++ encoder_inputs = { ++ encoder_session.get_inputs()[0].name: chunk_xs.numpy(), ++ encoder_session.get_inputs()[1].name: np.array(offset), ++ encoder_session.get_inputs()[2].name: subsampling_cache_onnx.numpy(), ++ encoder_session.get_inputs()[3].name: elayers_output_cache_onnx.numpy(), ++ encoder_session.get_inputs()[4].name: conformer_cnn_cache_onnx.numpy(), ++ } ++ ort_outs = encoder_session.run(None, encoder_inputs) ++ y_onnx, subsampling_cache_onnx, elayers_output_cache_onnx, conformer_cnn_cache_onnx = \ ++ torch.from_numpy(ort_outs[0][:, 1:, :]), torch.from_numpy(ort_outs[1]), \ ++ torch.from_numpy(ort_outs[2]), torch.from_numpy(ort_outs[3]) ++ ++ np.testing.assert_allclose(to_numpy(y), ort_outs[0][:, 1:, :], rtol=1e-03, atol=1e-03) ++ np.testing.assert_allclose(to_numpy(subsampling_cache), ort_outs[1][:, 1:, :], rtol=1e-03, atol=1e-03) ++ np.testing.assert_allclose(to_numpy(elayers_output_cache), ort_outs[2][:, :, 1:, :], rtol=1e-03, atol=1e-03) ++ np.testing.assert_allclose(to_numpy(conformer_cnn_cache), ort_outs[3], rtol=1e-03, atol=1e-03) ++ ++ outputs.append(y_onnx) ++ # outputs.append(y) + offset += y.size(1) ++ # break + ys = torch.cat(outputs, 1) + masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool) + masks = masks.unsqueeze(1) +diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py +index db8696d..0be079c 100644 +--- a/wenet/transformer/encoder_layer.py ++++ b/wenet/transformer/encoder_layer.py +@@ -9,6 +9,7 @@ from typing import Optional, Tuple + + import torch + from torch import nn ++from wenet.transformer.slice_helper import slice_helper + + + class TransformerEncoderLayer(nn.Module): +@@ -53,6 +54,9 @@ class TransformerEncoderLayer(nn.Module): + # concat_linear may be not used in forward fuction, + # but will be saved in the *.pt + self.concat_linear = nn.Linear(size + size, size) ++ ++ def set_onnx_mode(self, onnx_mode=False): ++ self.onnx_mode = onnx_mode + + def forward( + self, +@@ -92,9 +96,14 @@ class TransformerEncoderLayer(nn.Module): + assert output_cache.size(2) == self.size + assert output_cache.size(1) < x.size(1) + chunk = x.size(1) - output_cache.size(1) +- x_q = x[:, -chunk:, :] +- residual = residual[:, -chunk:, :] +- mask = mask[:, -chunk:, :] ++ if self.onnx_mode: ++ x_q = slice_helper(x, chunk) ++ residual = slice_helper(residual, chunk) ++ mask = slice_helper(mask, chunk) ++ else: ++ x_q = x[:, -chunk:, :] ++ residual = residual[:, -chunk:, :] ++ mask = mask[:, -chunk:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) +@@ -184,6 +193,7 @@ class ConformerEncoderLayer(nn.Module): + mask_pad: Optional[torch.Tensor] = None, + output_cache: Optional[torch.Tensor] = None, + cnn_cache: Optional[torch.Tensor] = None, ++ onnx_mode: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + +@@ -193,7 +203,6 @@ class ConformerEncoderLayer(nn.Module): + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. +- (#batch, 1,time) + output_cache (torch.Tensor): Cache tensor of the output + (#batch, time2, size), time2 < time in x. 
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer +@@ -202,6 +211,14 @@ class ConformerEncoderLayer(nn.Module): + torch.Tensor: Mask tensor (#batch, time). + """ + ++ if onnx_mode: ++ x = x[:, 1:, :] ++ mask = mask[:, :, 1:] ++ # pos_emb_ = pos_emb[:, 1:, :] ++ pos_emb_ = pos_emb[:, :-1, :] ++ else: ++ pos_emb_ = pos_emb ++ + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x +@@ -223,12 +240,26 @@ class ConformerEncoderLayer(nn.Module): + assert output_cache.size(0) == x.size(0) + assert output_cache.size(2) == self.size + assert output_cache.size(1) < x.size(1) +- chunk = x.size(1) - output_cache.size(1) +- x_q = x[:, -chunk:, :] +- residual = residual[:, -chunk:, :] +- mask = mask[:, -chunk:, :] + +- x_att = self.self_attn(x_q, x, x, mask, pos_emb) ++ # chunk = x.size(1) - output_cache.size(1) ++ if onnx_mode: ++ chunk = x.size(1) - output_cache.size(1) + 1 ++ if isinstance(chunk, int): ++ chunk_1 = torch.tensor(chunk) ++ else: ++ chunk_1 = chunk.clone().detach() ++ # chunk = torch.tensor(chunk) ++ # print(type(chunk)) ++ x_q = slice_helper(x, chunk_1) ++ residual = slice_helper(residual, chunk_1) ++ mask = slice_helper(mask, chunk_1) ++ else: ++ chunk = x.size(1) - output_cache.size(1) ++ x_q = x[:, -chunk:, :] ++ residual = residual[:, -chunk:, :] ++ mask = mask[:, -chunk:, :] ++ ++ x_att = self.self_attn(x_q, x, x, mask, pos_emb_) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) +diff --git a/wenet/transformer/subsampling.py b/wenet/transformer/subsampling.py +index b890f70..a978424 100644 +--- a/wenet/transformer/subsampling.py ++++ b/wenet/transformer/subsampling.py +@@ -16,8 +16,11 @@ class BaseSubsampling(torch.nn.Module): + self.right_context = 0 + self.subsampling_rate = 1 + +- def position_encoding(self, offset: int, size: int) -> torch.Tensor: +- return self.pos_enc.position_encoding(offset, size) ++ def position_encoding(self, ++ offset: torch.Tensor, ++ size: torch.Tensor, ++ onnx_mode: bool = False) -> torch.Tensor: ++ return self.pos_enc.position_encoding(offset, size, onnx_mode) + + + class LinearNoSubsampling(BaseSubsampling): +@@ -89,16 +92,17 @@ class Conv2dSubsampling4(BaseSubsampling): + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: +- # (kernel_size - 1) * frame_rate_of_this_layer ++ # (kernel_size - 1) / 2 * stride * frame_rate_of_this_layer + self.subsampling_rate = 4 +- # 6 = (3 - 1) * 1 + (3 - 1) * 2 ++ # 6 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, +- offset: int = 0 ++ offset: torch.Tensor = torch.tensor(0), ++ onnx_mode: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. 
+ +@@ -118,7 +122,7 @@ class Conv2dSubsampling4(BaseSubsampling): + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) +- x, pos_emb = self.pos_enc(x, offset) ++ x, pos_emb = self.pos_enc(x, offset, onnx_mode) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + +@@ -143,9 +147,9 @@ class Conv2dSubsampling6(BaseSubsampling): + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class +- # 10 = (3 - 1) * 1 + (5 - 1) * 2 ++ # 14 = (3 - 1) / 2 * 2 * 1 + (5 - 1) / 2 * 3 * 2 + self.subsampling_rate = 6 +- self.right_context = 10 ++ self.right_context = 14 + + def forward( + self, +@@ -198,7 +202,7 @@ class Conv2dSubsampling8(BaseSubsampling): + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 +- # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 ++ # 14 = (3 - 1) / 2 * 2 * 1 + (3 - 1) / 2 * 2 * 2 + (3 - 1) / 2 * 2 * 4 + self.right_context = 14 + + def forward( +diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py +index c2bb50f..d23bd95 100644 +--- a/wenet/utils/mask.py ++++ b/wenet/utils/mask.py +@@ -5,6 +5,15 @@ + + import torch + ++def tril_onnx(x, diagonal: torch.Tensor = torch.tensor(0)): ++ m,n = x.shape[0], x.shape[1] ++ arange = torch.arange(n, device = x.device) ++ mask = arange.expand(m, n) ++ mask_maker = torch.arange(m, device = x.device).unsqueeze(-1) ++ if diagonal: ++ mask_maker = mask_maker + diagonal ++ mask = mask <= mask_maker ++ return mask * x + + def subsequent_mask( + size: int, +@@ -35,13 +44,17 @@ def subsequent_mask( + [1, 1, 0], + [1, 1, 1]] + """ +- ret = torch.ones(size, size, device=device, dtype=torch.bool) +- return torch.tril(ret, out=ret) ++ # ret = torch.ones(size, size, device=device, dtype=torch.bool) ++ # return torch.tril(ret, out=ret) ++ # to export onnx, we change the code as follows ++ ret = torch.ones(size, size, device=device) ++ #return torch.tril(ret, out=ret) ++ return tril_onnx(ret) + + + def subsequent_chunk_mask( +- size: int, +- chunk_size: int, ++ size: torch.tensor(0), ++ chunk_size: torch.tensor(0), + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), + ) -> torch.Tensor: +@@ -67,6 +80,18 @@ def subsequent_chunk_mask( + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) ++ row_index = torch.arange(size, device = device) ++ index = row_index.expand(size, size) ++ expand_size = torch.ones((size), device = device)*size ++ #expand_size = expand_size.long() ++ if num_left_chunks < 0: ++ start1 = torch.tensor(0) ++ else: ++ start1 = torch.max((torch.floor_divide(row_index, chunk_size)-num_left_chunks).float()*chunk_size, torch.tensor(0.0)).long().view(size,1) ++ ending = torch.min((torch.floor_divide(row_index, chunk_size)+1).float()*chunk_size, expand_size.float()).long().view(size,1) ++ ret[torch.where(index < ending)] = True ++ ret[torch.where(index < start1)] = False ++ ''' + for i in range(size): + if num_left_chunks < 0: + start = 0 +@@ -74,6 +99,8 @@ def subsequent_chunk_mask( + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True ++ print("ret:", ret) ++ ''' + return ret + + +@@ -107,18 +134,18 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: +- max_len = xs.size(1) ++ max_len = xs.shape[1] + if decoding_chunk_size < 0: + chunk_size = 
max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: +- chunk_size = decoding_chunk_size ++ chunk_size = torch.tensor(decoding_chunk_size) + num_left_chunks = num_decoding_left_chunks + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. +- chunk_size = torch.randint(1, max_len, (1, )).item() ++ chunk_size = torch.randint(1, max_len, (1, )) + num_left_chunks = -1 + if chunk_size > max_len // 2: + chunk_size = max_len +@@ -128,14 +155,14 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = torch.randint(0, max_left_chunks, + (1, )).item() +- chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, ++ chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks +- chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, ++ chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, + num_left_chunks, + xs.device) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) +@@ -145,7 +172,7 @@ def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor, + return chunk_masks + + +-def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: ++def make_pad_mask(lengths: torch.Tensor, xs: torch.Tensor) -> torch.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. +@@ -162,8 +189,11 @@ def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ +- batch_size = int(lengths.size(0)) +- max_len = int(lengths.max().item()) ++ # batch_size = int(lengths.size(0)) ++ # max_len = int(lengths.max().item()) ++ # to export the decoder onnx and avoid the constant fold ++ batch_size = xs.shape[0] ++ max_len = xs.shape[1] + seq_range = torch.arange(0, + max_len, + dtype=torch.int64,