diff --git a/ACL_PyTorch/built-in/audio/Paraformer/README.md b/ACL_PyTorch/built-in/audio/Paraformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b780f40df2a25d517de5dd7903766a5fe2e1e0cd
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/Paraformer/README.md
@@ -0,0 +1,88 @@
+# Paraformer(TorchAir)-推理指导
+
+- [概述](#概述)
+- [推理环境准备](#推理环境准备)
+- [快速上手](#快速上手)
+  - [获取源码](#获取源码)
+  - [模型推理](#模型推理)
+- [模型推理性能&精度](#模型推理性能&精度)
+
+******
+
+# 概述
+Paraformer是达摩院语音团队提出的一种高效的非自回归端到端语音识别框架。本项目为Paraformer中文通用语音识别模型，采用工业级数万小时的标注音频进行模型训练，保证了模型的通用识别效果。模型可以被应用于语音输入法、语音导航、智能会议纪要等场景。
+
+- 版本说明:
+  ```
+  url=https://github.com/modelscope/FunASR
+  commit_id=77db489a8f9d1ff0771bfaea55cbeedfc77aac77
+  model_name=Paraformer
+  ```
+
+# 推理环境准备
+- 该模型需要以下插件与驱动
+  **表 1** 版本配套表
+
+  | 配套 | 版本 | 环境准备指导 |
+  | ------------------------------------------------------------ |--------------| ------------------------------------------------------------ |
+  | 固件与驱动 | 24.0.0 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 8.0.0 | - |
+  | Python | 3.10 | - |
+  | PyTorch | 2.1.0 | - |
+  | Ascend Extension PyTorch | 2.1.0.post11 | - |
+  | 说明：Atlas 800I A2/Atlas 300I Pro 推理卡请以CANN版本选择实际固件与驱动版本。 | \ | \ |
+
+
+# 快速上手
+
+## 获取源码
+
+1. 安装依赖
+   ```bash
+   pip3 install -r ../requirements.txt
+   ```
+
+2. 获取模型仓源码
+   ```bash
+   git clone https://github.com/modelscope/FunASR.git
+   cd FunASR
+   git reset --hard 77db489a
+   git apply ../adapt-torchair.patch
+   pip3 install -e ./
+   ```
+
+## 模型推理
+1. 下载模型权重
+
+   本文档以[Paraformer-large-热词版模型](https://www.modelscope.cn/models/iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404)为例，请自行下载权重文件夹，或通过git下载
+   ```bash
+   git lfs install
+   git clone https://www.modelscope.cn/iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404.git
+   ```
+
+2. 执行推理命令
+
+   ```bash
+   python3 infer_air.py \
+       --model_path=./speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404 \
+       --data=speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav \
+       --hotwords="魔搭" \
+       --device=0 \
+       --loop=10
+   ```
+   - 参数说明
+     - model_path:模型权重路径
+     - data:输入音频文件路径，默认为权重目录下的example/asr_example.wav
+     - hotwords:热词，默认为"魔搭"
+     - device:npu芯片id，默认为0
+     - loop:性能测试的循环次数，默认为10
+
+   推理脚本以带热词的语音识别为例，推理后将打屏识别结果和模型性能
+
+# 模型推理性能&精度
+以Paraformer-large-热词版模型为例：
+
+| 模型 | 硬件 | 端到端性能 |
+|------------|---------|------------|
+| Paraformer | 800I A2 | 1.4 s/data |
+
diff --git a/ACL_PyTorch/built-in/audio/Paraformer/adapt-torchair.patch b/ACL_PyTorch/built-in/audio/Paraformer/adapt-torchair.patch
new file mode 100644
index 0000000000000000000000000000000000000000..3ae2b506bfe6aaed9c1ce2094b42dd5827c98e5c
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/Paraformer/adapt-torchair.patch
@@ -0,0 +1,199 @@
+---
+ .../models/contextual_paraformer/decoder.py   | 14 ++++---
+ funasr/models/contextual_paraformer/model.py  | 13 ++++++-
+ funasr/models/paraformer/model.py             |  3 +-
+ funasr/models/sanm/attention.py               | 38 ++++++++++++++++++-
+ funasr/models/sanm/encoder.py                 |  4 +-
+ 5 files changed, 59 insertions(+), 13 deletions(-)
+
+diff --git a/funasr/models/contextual_paraformer/decoder.py b/funasr/models/contextual_paraformer/decoder.py
+index ba2ce9ad..be36352f 100644
+--- a/funasr/models/contextual_paraformer/decoder.py
++++ b/funasr/models/contextual_paraformer/decoder.py
+@@ -254,10 +254,11 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
+     def forward(
+         self,
+         hs_pad: torch.Tensor,
+-        hlens: torch.Tensor,
++        memory_mask: torch.Tensor,
+         ys_in_pad: torch.Tensor,
+-        ys_in_lens: torch.Tensor,
++        tgt_mask: torch.Tensor,
+         contextual_info: torch.Tensor,
++        contextual_mask: torch.Tensor,
+         clas_scale: float = 1.0,
+         return_hidden: bool = False,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+@@ -279,18 +280,19 @@ class ContextualParaformerDecoder(ParaformerSANMDecoder):
+             olens: (batch, )
+         """
+         tgt = ys_in_pad
+-        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
++        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+ 
+         memory = hs_pad
+-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
++        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ 
+         x = tgt
+         x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
+         _, _, x_self_attn, x_src_attn = self.last_decoder(x, tgt_mask, memory, memory_mask)
+ 
+         # contextual paraformer related
+-        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
+-        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
++        # contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
++        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
++        contextual_mask = contextual_mask.unsqueeze(1).eq(0).expand(-1, -1, x_self_attn.size(1), -1)
+         cx, tgt_mask, _, _, _ = self.bias_decoder(
+             x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask
+         )
+diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
+index fd882202..f5a16854 100644
+--- a/funasr/models/contextual_paraformer/model.py
++++ b/funasr/models/contextual_paraformer/model.py
+@@ -18,6 +18,7 @@ from distutils.version import LooseVersion
+ 
+ from funasr.register import tables
+ from funasr.utils import postprocess_utils
++from funasr.models.scama import utils as myutils
+ from funasr.metrics.compute_acc import th_accuracy
+ from funasr.models.paraformer.model import Paraformer
+ from funasr.utils.datadir_writer import DatadirWriter
+@@ -328,12 +329,20 @@ class ContextualParaformer(Paraformer):
+             _, (h_n, _) = self.bias_encoder(hw_embed)
+             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+ 
++        # build mask before model
++        tgt_mask = myutils.sequence_mask(ys_pad_lens, device=encoder_out.device)[:, :, None]
++        memory_mask = myutils.sequence_mask(encoder_out_lens, device=encoder_out.device)[:, None, :]
++        memory_mask = memory_mask.unsqueeze(1).eq(0).expand(-1, -1, sematic_embeds.size(1), -1)
++        contextual_length = torch.Tensor([hw_embed.shape[1]]).int().repeat(encoder_out.shape[0])
++        contextual_mask = myutils.sequence_mask(contextual_length, device=encoder_out.device)[:, None, :]
++
+         decoder_outs = self.decoder(
+             encoder_out,
+-            encoder_out_lens,
++            memory_mask,
+             sematic_embeds,
+-            ys_pad_lens,
++            tgt_mask,
+             contextual_info=hw_embed,
++            contextual_mask=contextual_mask,
+             clas_scale=clas_scale,
+         )
+ 
+diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
+index 85967af3..935c6e77 100644
+--- a/funasr/models/paraformer/model.py
++++ b/funasr/models/paraformer/model.py
+@@ -259,7 +259,8 @@ class Paraformer(torch.nn.Module):
+             speech, speech_lengths = self.normalize(speech, speech_lengths)
+ 
+         # Forward encoder
+-        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
++        mask = (~make_pad_mask(speech_lengths)[:, None, :]).to(speech.device)
++        encoder_out, encoder_out_lens, _ = self.encoder(speech, mask)
+         if isinstance(encoder_out, tuple):
+             encoder_out = encoder_out[0]
+ 
+diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
+index 47d60cb6..34ae46b5 100644
+--- a/funasr/models/sanm/attention.py
++++ b/funasr/models/sanm/attention.py
+@@ -10,6 +10,7 @@ import math
+ 
+ import numpy
+ import torch
++import torch_npu
+ from torch import nn
+ from typing import Optional, Tuple
+ 
+@@ -289,7 +290,7 @@ class MultiHeadedAttentionSANM(nn.Module):
+ 
+         return self.linear_out(x)  # (batch, time1, d_model)
+ 
+-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
++    def forward_ori(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+         """Compute scaled dot product attention.
+ 
+         Args:
+@@ -310,6 +311,23 @@ class MultiHeadedAttentionSANM(nn.Module):
+         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+         return att_outs + fsmn_memory
+ 
++    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
++        q_k_v = self.linear_q_k_v(x)
++        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
++        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
++        attn_out = torch_npu.npu_prompt_flash_attention(
++            q.to(torch.float16),
++            k.to(torch.float16),
++            v.to(torch.float16),
++            scale_value=self.d_k ** (-0.5),
++            atten_mask=mask.unsqueeze(1).eq(0).repeat(1, 1, q.size(1), 1),
++            input_layout='BSH',
++            num_heads=self.h,
++            sparse_mode=1,
++        ).to(q.dtype)
++        att_outs = self.linear_out(attn_out)
++        return att_outs + fsmn_memory
++
+     def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+         """Compute scaled dot product attention.
+ 
+@@ -697,7 +715,7 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
+             return self.linear_out(x), attn  # (batch, time1, d_model)
+         return self.linear_out(x)  # (batch, time1, d_model)
+ 
+-    def forward(self, x, memory, memory_mask, ret_attn=False):
++    def forward_ori(self, x, memory, memory_mask, ret_attn=False):
+         """Compute scaled dot product attention.
+ 
+         Args:
+@@ -716,6 +734,22 @@ class MultiHeadedAttentionCrossAtt(nn.Module):
+         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+         return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn)
+ 
++    def forward(self, x, memory, memory_mask, ret_attn=False):
++        q = self.linear_q(x)
++        k_v = self.linear_k_v(memory)
++        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
++        attn_out = torch_npu.npu_prompt_flash_attention(
++            q.to(torch.float16),
++            k.to(torch.float16),
++            v.to(torch.float16),
++            scale_value=self.d_k ** (-0.5),
++            atten_mask=memory_mask,
++            input_layout='BSH',
++            num_heads=self.h,
++            sparse_mode=1,
++        ).to(q.dtype)
++        return self.linear_out(attn_out)
++
+     def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+         """Compute scaled dot product attention.
+ 
+diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
+index 0d39ca74..b062e9d8 100644
+--- a/funasr/models/sanm/encoder.py
++++ b/funasr/models/sanm/encoder.py
+@@ -361,7 +361,7 @@ class SANMEncoder(nn.Module):
+     def forward(
+         self,
+         xs_pad: torch.Tensor,
+-        ilens: torch.Tensor,
++        masks: torch.Tensor,
+         prev_states: torch.Tensor = None,
+         ctc: CTC = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+@@ -374,7 +374,7 @@ class SANMEncoder(nn.Module):
+         Returns:
+             position embedded tensor and mask
+         """
+-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
++        # masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+         xs_pad = xs_pad * self.output_size() ** 0.5
+         if self.embed is None:
+             xs_pad = xs_pad
+-- 
diff --git a/ACL_PyTorch/built-in/audio/Paraformer/infer_air.py b/ACL_PyTorch/built-in/audio/Paraformer/infer_air.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32df398a391355f570225295a7e586c36cf93cf
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/Paraformer/infer_air.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd
+# [Software Name] is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+#          http://license.coscl.org.cn/MulanPSL2
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+
+import argparse
+from time import time
+
+import torch
+import torch_npu
+import torchair as tng
+from torchair import CompilerConfig
+
+from funasr import AutoModel
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Paraformer Inference")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        required=True,
+        help="Path of Pretrained Weight"
+    )
+    parser.add_argument(
+        '--data',
+        type=str,
+        default="speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav",
+        help='Path of wav file'
+    )
+    parser.add_argument(
+        '--hotwords',
+        type=str,
+        default="魔搭",
+        help='Hotword.'
+    )
+    parser.add_argument('--device', type=int, default=0, help='Npu Device Id')
+    parser.add_argument('--loop', default=10, type=int, help='Loop Times')
+    args = parser.parse_args()
+
+    # adapt torchair
+    device = torch.device('npu:{}'.format(args.device))
+    model = AutoModel(model=args.model_path, disable_update=True)
+    model.model.to(device)
+
+    config = CompilerConfig()
+    config.experimental_config.frozen_parameter = True
+    npu_backend = tng.get_npu_backend(compiler_config=config)
+    model.model.encoder = torch.compile(model.model.encoder, dynamic=True, fullgraph=True, backend=npu_backend)
+    tng.use_internal_format_weight(model.model.encoder)
+    model.model.decoder = torch.compile(model.model.decoder, dynamic=True, fullgraph=True, backend=npu_backend)
+    tng.use_internal_format_weight(model.model.decoder)
+
+    # compile and precision test
+    print('start compiling...')
+    res = model.generate(input=args.data, hotword=args.hotwords, device=device)
+    print('result:', res)
+
+    # test model performance
+    ct = 0
+    for i in range(args.loop):
+        torch.npu.synchronize()
+        st = time()
+        model.generate(input=args.data, hotword=args.hotwords, device=device)
+        torch.npu.synchronize()
+        if i:
+            ct += time() - st
+    print(f'e2e performance: {(args.loop - 1) / ct} data/s')
+    
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/audio/Paraformer/requirements.txt b/ACL_PyTorch/built-in/audio/Paraformer/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9487226c4c5839581a63c88bd240ef04478d2241
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/Paraformer/requirements.txt
@@ -0,0 +1,4 @@
+torch==2.1.0
+torch-npu==2.1.0.post11
+numpy==1.26
+torchaudio