From ef42b5ba9730c2becb1fc0c47036a6f55a91b98c Mon Sep 17 00:00:00 2001
From: shikang
Date: Thu, 20 Mar 2025 11:16:14 +0800
Subject: [PATCH] add CosyVoice2 model file

---
 .../audio/CosyVoice/CosyVoice2/README.md      | 147 ++++++
 .../CosyVoice/CosyVoice2/diff_CosyVoice.patch | 218 ++++++++
 .../CosyVoice2/diff_transformers.patch        | 477 ++++++++++++++++++
 .../audio/CosyVoice/CosyVoice2/infer.py       |  54 ++
 .../audio/CosyVoice/CosyVoice2/modify_onnx.py |  32 ++
 .../CosyVoice/CosyVoice2/requirements.txt     |  33 ++
 6 files changed, 961 insertions(+)
 create mode 100755 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md
 create mode 100755 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch
 create mode 100755 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_transformers.patch
 create mode 100755 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py
 create mode 100755 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py
 create mode 100755 ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt

diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md
new file mode 100755
index 0000000000..479196012a
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md
@@ -0,0 +1,147 @@
+# CosyVoice2 Inference Guide
+
+- [CosyVoice2 Inference Guide](#cosyvoice2-inference-guide)
+- [Overview](#overview)
+- [Inference Environment Setup](#inference-environment-setup)
+- [Quick Start](#quick-start)
+  - [Obtain the Source Code](#obtain-the-source-code)
+  - [Model Inference](#model-inference)
+    - [1 Model Conversion](#1-model-conversion)
+    - [2 Run Inference](#2-run-inference)
+    - [3 Performance](#3-performance)
+
+******
+
+# Overview
+  CosyVoice is a large speech-generation model based on quantized speech coding. It tightly couples text understanding with speech generation: discrete speech tokens plus a large language model let it interpret arbitrary text and render it as natural, human-like audio. CosyVoice2 builds on the first generation by plugging a Qwen2 model into the LLM stage of CosyVoice, which accelerates inference.
+
+- Version information:
+  ```
+  url=https://github.com/FunAudioLLM/CosyVoice
+  commit_id=fd45708
+  model_name=Cosyvoice
+  ```
+
+# Inference Environment Setup
+- The model requires the following software stack.
+  **Table 1** Version compatibility
+
+  | Component | Version | Setup guide |
+  | --------- | ------- | ----------- |
+  | Firmware & drivers | 24.0.RC3 | [PyTorch inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN | 8.0.RC3 | includes the kernels and toolkit packages |
+  | Python | 3.8 | - |
+  | PyTorch | 2.4.0 | - |
+  | Ascend Extension for PyTorch | 2.4.0.post2 | - |
+  | Note: for Atlas 800I A2 and Atlas 300I DUO inference cards, choose the firmware/driver version that matches your CANN version. | \ | \ |
+
+
+# Quick Start
+
+## Obtain the Source Code
+
+1. Get the source code
+   ```
+   # get the CosyVoice source
+   git clone https://github.com/FunAudioLLM/CosyVoice
+   cd CosyVoice
+   git reset --hard fd45708
+   git submodule update --init --recursive
+   git apply ../diff_CosyVoice.patch
+   # get the transformers source
+   git clone https://github.com/huggingface/transformers.git
+   cd transformers
+   git checkout v4.37.0
+   git apply ../../diff_transformers.patch
+   ```
+
+2. Install dependencies
+   ```
+   pip3 install -r ../requirements.txt
+   apt-get install sox   # on CentOS: yum install sox
+   ```
+   Note: if WeTextProcessing fails to install, you can build and install it manually as follows
+   ```bash
+   # download and extract the openfst source package
+   wget https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.3.tar.gz
+   tar -xzf openfst-1.8.3.tar.gz
+   # enter the directory, then build and install
+   cd openfst-1.8.3
+   ./configure --enable-far --enable-mpdt --enable-pdt
+   make install
+   pip3 install WeTextProcessing==1.0.4.1
+   ```
+
+3. Install the msit tool
+
+   Install the benchmark and surgeon components as described in [msit](https://gitee.com/ascend/msit).
+
+4. Obtain the model weights
+
+   This guide uses CosyVoice2-0.5B as the example; adapt the steps yourself for other weights.
+
+   Download the weight folder from https://www.modelscope.cn/iic/CosyVoice2-0.5B and place it under the CosyVoice directory,
+
+   or fetch it with git (a Python alternative is sketched after this step):
+   ```
+   # clone the model with git; make sure git lfs is installed
+   git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git CosyVoice/CosyVoice2-0.5B
+   ```
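+
+   If git lfs is unavailable, a minimal sketch using the ModelScope Python SDK is shown below. It assumes the `modelscope` package from requirements.txt is installed; `snapshot_download` returns the local cache path, and copying it into `CosyVoice/CosyVoice2-0.5B` is only an illustration of the layout this guide expects, not a required step.
+   ```python
+   # Hypothetical SDK-based download of the CosyVoice2-0.5B weights
+   import shutil
+   from modelscope import snapshot_download
+
+   # download the weights into the local ModelScope cache and get the path
+   model_dir = snapshot_download('iic/CosyVoice2-0.5B')
+
+   # copy the weights to the location this guide expects (requires Python 3.8+)
+   shutil.copytree(model_dir, 'CosyVoice/CosyVoice2-0.5B', dirs_exist_ok=True)
+   ```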
+
+## Model Inference
+
+### 1 Model Conversion
+
+   The model weights ship with two ONNX models, flow.decoder.estimator.fp32.onnx and speech_tokenizer_v2.onnx. After adjusting their structure, use the `ATC` tool to convert the `.onnx` files into offline inference `.om` models.
+
+1. Modify the ONNX model structure
+
+   ```
+   python3 modify_onnx.py ${CosyVoice2-0.5B}
+   ```
+
+   ${CosyVoice2-0.5B} is the weight folder containing the ONNX models; change the folder name accordingly for other weights. The command generates the modified ONNX file speech_token_md.onnx in the CosyVoice2-0.5B directory.
+
+2. Convert the `ONNX` models to `OM` models with the `ATC` tool
+
+   Set up the environment variables
+
+   ```
+   source /usr/local/Ascend/ascend-toolkit/set_env.sh
+   ```
+
+   Run the ATC commands, filling ${soc_version} with the chip name reported by npu-smi info
+
+   ```
+   atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/speech_token_md.onnx --output ./${CosyVoice2-0.5B}/speech --input_shape="feats:1,128,-1;feats_length:1"
+   atc --framework=5 --soc_version=${soc_version} --model ./${CosyVoice2-0.5B}/flow.decoder.estimator.fp32.onnx --output ./${CosyVoice2-0.5B}/flow --input_shape="x:2,80,-1;mask:2,1,-1;mu:2,80,-1;t:2;spks:2,80;cond:2,80,-1"
+   ```
+   Two OM models, speech_{arch}.om and flow_{arch}.om, are generated in the CosyVoice2-0.5B weight directory.
+
+   Note: the {arch} suffix reflects the operating system and CPU architecture of the current host.
+
+### 2 Run Inference
+
+ 1. First, move infer.py into the CosyVoice directory
+
+
+ 2. Set the environment variables and run the inference command
+
+    ```
+    # NPU device ID to use, default 0
+    export ASCEND_RT_VISIBLE_DEVICES=0
+    export PYTHONPATH=third_party/Matcha-TTS:$PYTHONPATH
+    export PYTHONPATH=transformers/src:$PYTHONPATH
+    python3 infer.py --model_path=${CosyVoice2-0.5B}
+    ```
+    - --model_path: path to the weights
+    - --warm_up_times: number of warm-up runs, default 2
+    - --infer_count: number of inference loops, default 10
+
+    Inference begins with a warm-up phase that triggers the first-time compilation, which takes a while. After warm-up, inference runs, the result is saved to 'zero_shot_result.wav', and the performance metric is printed: the real-time factor (RTF), i.e. the average time needed to process 1 s of audio.
+
+### 3 Performance
+
+  | Model | Device | RTF (real-time factor) |
+  |------|------|------|
+  | cosyvoice | 800I A2 | 0.3s |
+
diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch
new file mode 100755
index 0000000000..ddf08819d1
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch
@@ -0,0 +1,218 @@
+From b14a21d1c50e45162c607782608311fe0fd65c35 Mon Sep 17 00:00:00 2001
+From: shikang
+Date: Thu, 20 Mar 2025 10:38:13 +0800
+Subject: [PATCH] add new patch
+
+---
+ cosyvoice/cli/cosyvoice.py      | 12 ++++++++++-
+ cosyvoice/cli/frontend.py       | 17 ++++++++++-----
+ cosyvoice/cli/model.py          | 37 ++++++++++++++++-----------------
+ cosyvoice/flow/flow_matching.py | 20 ++++++++++++------
+ cosyvoice/llm/llm.py            |  8 +++++--
+ 5 files changed, 61 insertions(+), 33 deletions(-)
+
+diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py
+index e2d62e2..d3ccbfc 100644
+--- a/cosyvoice/cli/cosyvoice.py
++++ b/cosyvoice/cli/cosyvoice.py
+@@ -13,11 +13,13 @@
+ # limitations under the License.
+ import os + import time ++import platform + from typing import Generator + from tqdm import tqdm + from hyperpyyaml import load_hyperpyyaml + from modelscope import snapshot_download + import torch ++from ais_bench.infer.interface import InferSession + from cosyvoice.cli.frontend import CosyVoiceFrontEnd + from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model + from cosyvoice.utils.file_utils import logging +@@ -26,7 +28,7 @@ from cosyvoice.utils.class_utils import get_model_type + + class CosyVoice: + +- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): ++ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False): + self.instruct = True if '-Instruct' in model_dir else False + self.model_dir = model_dir + self.fp16 = fp16 +@@ -57,6 +59,14 @@ class CosyVoice: + self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), + '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), + self.fp16) ++ if load_om: ++ arch = platform.machine() ++ system = platform.system().lower() ++ flow_om = InferSession(0, '{}/flow_{}_{}.om'.format(model_dir, system ,arch)) ++ speech_om = InferSession(0, '{}/speech_{}_{}.om'.format(model_dir, system ,arch)) ++ self.frontend.speech_om = speech_om ++ self.frontend.flow_om = flow_om ++ self.model.flow.decoder.flow_om = flow_om + del configs + + def list_available_spks(self): +diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py +index 6e10f00..0eb45a8 100644 +--- a/cosyvoice/cli/frontend.py ++++ b/cosyvoice/cli/frontend.py +@@ -71,6 +71,8 @@ class CosyVoiceFrontEnd: + self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True) + self.en_tn_model = EnNormalizer() + self.inflect_parser = inflect.engine() ++ self.speech_om = None ++ self.flow_om = None + + def _extract_text_token(self, text): + if isinstance(text, Generator): +@@ -92,11 +94,16 @@ class CosyVoiceFrontEnd: + def _extract_speech_token(self, speech): + assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s' + feat = whisper.log_mel_spectrogram(speech, n_mels=128) +- speech_token = self.speech_tokenizer_session.run(None, +- {self.speech_tokenizer_session.get_inputs()[0].name: +- feat.detach().cpu().numpy(), +- self.speech_tokenizer_session.get_inputs()[1].name: +- np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() ++ if torch.npu.is_available() and self.speech_om: ++ feed = [feat.detach().cpu().numpy(), np.array([feat.shape[2]], dtype=np.int32)] ++ speech_token = self.speech_om.infer(feed, mode='dymshape', custom_sizes=[100000000])[0].flatten().tolist() ++ self.flow_om.set_context() ++ else: ++ speech_token = self.speech_tokenizer_session.run(None, ++ {self.speech_tokenizer_session.get_inputs()[0].name: ++ feat.detach().cpu().numpy(), ++ self.speech_tokenizer_session.get_inputs()[1].name: ++ np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() + speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device) + speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) + return speech_token, speech_token_len +diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py +index 9ebf8cb..530a44d 100644 +--- a/cosyvoice/cli/model.py ++++ b/cosyvoice/cli/model.py +@@ -99,25 +99,24 @@ class CosyVoiceModel: + self.flow.decoder.estimator = self.flow.decoder.estimator_engine.create_execution_context() + + def llm_job(self, 
text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): +- with self.llm_context: +- if isinstance(text, Generator): +- assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' +- for i in self.llm.inference_bistream(text=text, +- prompt_text=prompt_text.to(self.device), +- prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), +- prompt_speech_token=llm_prompt_speech_token.to(self.device), +- prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), +- embedding=llm_embedding.to(self.device)): +- self.tts_speech_token_dict[uuid].append(i) +- else: +- for i in self.llm.inference(text=text.to(self.device), +- text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), +- prompt_text=prompt_text.to(self.device), +- prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), +- prompt_speech_token=llm_prompt_speech_token.to(self.device), +- prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), +- embedding=llm_embedding.to(self.device)): +- self.tts_speech_token_dict[uuid].append(i) ++ if isinstance(text, Generator): ++ assert isinstance(self, CosyVoice2Model), 'streaming input text is only implemented for CosyVoice2!' ++ for i in self.llm.inference_bistream(text=text, ++ prompt_text=prompt_text.to(self.device), ++ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_speech_token=llm_prompt_speech_token.to(self.device), ++ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), ++ embedding=llm_embedding.to(self.device)): ++ self.tts_speech_token_dict[uuid].append(i) ++ else: ++ for i in self.llm.inference(text=text.to(self.device), ++ text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_text=prompt_text.to(self.device), ++ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_speech_token=llm_prompt_speech_token.to(self.device), ++ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), ++ embedding=llm_embedding.to(self.device)): ++ self.tts_speech_token_dict[uuid].append(i) + self.llm_end_dict[uuid] = True + + def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0): +diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py +index 6a60f6d..457864c 100644 +--- a/cosyvoice/flow/flow_matching.py ++++ b/cosyvoice/flow/flow_matching.py +@@ -14,6 +14,7 @@ + import threading + import torch + import torch.nn.functional as F ++import numpy as np + from matcha.models.components.flow_matching import BASECFM + + +@@ -32,6 +33,7 @@ class ConditionalCFM(BASECFM): + # Just change the architecture of the estimator here + self.estimator = estimator + self.lock = threading.Lock() ++ self.flow_om = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): +@@ -105,12 +107,18 @@ class ConditionalCFM(BASECFM): + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond +- dphi_dt = self.forward_estimator( +- x_in, mask_in, +- mu_in, t_in, +- spks_in, +- cond_in +- ) ++ if torch.npu.is_available() and self.flow_om: ++ feed_list = [x_in, mask_in, mu_in, t_in, spks_in, cond_in] 
++ feed = [i.cpu().detach().numpy().astype(np.float32) for i in feed_list] ++ dphi_dt = self.flow_om.infer(feed, mode="dymshape", custom_sizes=10000000) ++ dphi_dt = torch.from_numpy(dphi_dt[0]).npu() ++ else: ++ dphi_dt = self.forward_estimator( ++ x_in, mask_in, ++ mu_in, t_in, ++ spks_in, ++ cond_in ++ ) + dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) + dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) + x = x + dt * dphi_dt +diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py +index bbd3305..c169e0f 100644 +--- a/cosyvoice/llm/llm.py ++++ b/cosyvoice/llm/llm.py +@@ -229,11 +229,12 @@ class Qwen2Encoder(torch.nn.Module): + super().__init__() + self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) + +- def forward_one_step(self, xs, masks, cache=None): ++ def forward_one_step(self, xs, masks, prompt_length, cache=None): + input_masks = masks[:, -1, :] + outs = self.model( +- inputs_embeds=xs, ++ inputs_embeds=xs.to(torch.float16), #模型支持fp16推理 + attention_mask=input_masks, ++ prompt_length=prompt_length, + output_hidden_states=True, + return_dict=True, + use_cache=True, +@@ -318,9 +319,12 @@ class Qwen2LM(TransformerLM): + # 5. step by step decode + out_tokens = [] + cache = None ++ input_length = lm_input.shape[1] + for i in range(max_len): ++ prompt_length = input_length + i + y_pred, cache = self.llm.forward_one_step(lm_input, + masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), ++ prompt_length=prompt_length, + cache=cache) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() +-- +2.21.0 + diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_transformers.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_transformers.patch new file mode 100755 index 0000000000..7e66d4cf6c --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_transformers.patch @@ -0,0 +1,477 @@ +From 61c406570556f621df2507690e54beb10a87a3cd Mon Sep 17 00:00:00 2001 +From: shikang +Date: Wed, 19 Mar 2025 17:03:08 +0800 +Subject: [PATCH] add patch + +--- + .../models/qwen2/modeling_qwen2.py | 253 ++++++++++++------ + 1 file changed, 173 insertions(+), 80 deletions(-) + +diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py +index f8290928a..2ad507294 100644 +--- a/src/transformers/models/qwen2/modeling_qwen2.py ++++ b/src/transformers/models/qwen2/modeling_qwen2.py +@@ -24,6 +24,7 @@ import warnings + from typing import List, Optional, Tuple, Union + + import torch ++import torch_npu + import torch.nn.functional as F + import torch.utils.checkpoint + from torch import nn +@@ -77,7 +78,7 @@ def _get_unpad_data(attention_mask): + ) + + +-# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 ++# Add/Norm融合算子 + class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ +@@ -87,12 +88,14 @@ class Qwen2RMSNorm(nn.Module): + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + +- def forward(self, hidden_states): +- input_dtype = hidden_states.dtype +- hidden_states = hidden_states.to(torch.float32) +- variance = hidden_states.pow(2).mean(-1, keepdim=True) +- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) +- return self.weight * 
hidden_states.to(input_dtype) ++ def forward(self, ++ hidden_states, ++ residual: Optional[torch.Tensor] = None): ++ if residual is None: ++ return torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0], hidden_states ++ else: ++ y, _, x = torch_npu.npu_add_rms_norm(residual, hidden_states, self.weight, self.variance_epsilon) ++ return y, x + + + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 +@@ -121,14 +124,14 @@ class Qwen2RotaryEmbedding(nn.Module): + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + +- def forward(self, x, seq_len=None): ++ def forward(self, x=None, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] +- if seq_len > self.max_seq_len_cached: +- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) ++ if x is None and seq_len is None: ++ return self.cos_cached, self.sin_cached + + return ( +- self.cos_cached[:seq_len].to(dtype=x.dtype), +- self.sin_cached[:seq_len].to(dtype=x.dtype), ++ self.cos_cached.to(dtype=x.dtype), ++ self.sin_cached.to(dtype=x.dtype), + ) + + +@@ -140,8 +143,8 @@ def rotate_half(x): + return torch.cat((-x2, x1), dim=-1) + + +-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): ++# cos/sin计算优化 ++def apply_rotary_pos_emb(q, k, cos, sin): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: +@@ -162,8 +165,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ +- cos = cos[position_ids].unsqueeze(unsqueeze_dim) +- sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed +@@ -633,13 +634,18 @@ class Qwen2SdpaAttention(Qwen2Attention): + SDPA API. 
+ """ + +- # Adapted from Qwen2Attention.forward ++ # 优化Attention部分逻辑,替换torch_npu算子 + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, ++ updated_kv_positions: Optional[torch.LongTensor] = None, ++ kv_padding_size: Optional[torch.LongTensor] = None, ++ actual_seq_len: Optional[list] = None, ++ rotary_emb_cos: Optional[torch.Tensor] = None, ++ rotary_emb_sin: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +@@ -664,48 +670,43 @@ class Qwen2SdpaAttention(Qwen2Attention): + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + +- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) +- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) +- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) +- +- kv_seq_len = key_states.shape[-2] +- if past_key_value is not None: +- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) +- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) +- +- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) +- +- if past_key_value is not None: +- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models +- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) +- +- key_states = repeat_kv(key_states, self.num_key_value_groups) +- value_states = repeat_kv(value_states, self.num_key_value_groups) +- +- if attention_mask is not None: +- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): +- raise ValueError( +- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" +- ) +- +- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, +- # Reference: https://github.com/pytorch/pytorch/issues/112577. +- if query_states.device.type == "cuda" and attention_mask is not None: +- query_states = query_states.contiguous() +- key_states = key_states.contiguous() +- value_states = value_states.contiguous() +- +- attn_output = torch.nn.functional.scaled_dot_product_attention( +- query_states, +- key_states, +- value_states, +- attn_mask=attention_mask, +- dropout_p=self.attention_dropout if self.training else 0.0, +- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+- is_causal=self.is_causal and attention_mask is None and q_len > 1, +- ) ++ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) ++ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) ++ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) ++ ++ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, ++ rotary_emb_cos.to(value_states.dtype), ++ rotary_emb_sin.to(value_states.dtype)) ++ ++ if use_cache and past_key_value is not None: ++ tmp_ids = updated_kv_positions.reshape(-1) ++ # format BSND, 1 means seq_len dim index ++ torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 1) ++ torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 1) ++ kv_states = past_key_value[self.layer_idx] if q_len == 1 else (key_states, value_states) ++ key_states = kv_states[0] ++ value_states = kv_states[1] ++ ++ ++ if q_len > 1: ++ attention_mask = torch.tril(torch.ones((bsz, q_len, q_len), device=query_states.device)).logical_not().to(torch.bool) ++ attn_output = torch_npu.npu_prompt_flash_attention(query_states, key_states.contiguous(), ++ value_states.contiguous(), num_heads=self.num_heads, ++ input_layout="BSND", ++ scale_value=1 / math.sqrt(self.num_heads), ++ pre_tokens=65535, next_tokens=0, ++ atten_mask=attention_mask, ++ num_key_value_heads=self.num_key_value_heads) ++ else: ++ attn_output = torch_npu.npu_incre_flash_attention(query_states, key_states.contiguous(), ++ value_states.contiguous(), num_heads=self.num_heads, ++ input_layout="BSND", ++ scale_value=1 / math.sqrt(self.num_heads), ++ atten_mask=None, ++ actual_seq_lengths=actual_seq_len, ++ kv_padding_size=kv_padding_size, ++ num_key_value_heads=self.num_key_value_heads) + +- attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) +@@ -739,9 +740,15 @@ class Qwen2DecoderLayer(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, ++ past_residual: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, ++ updated_kv_positions: Optional[torch.LongTensor] = None, ++ kv_padding_size: Optional[torch.LongTensor] = None, ++ actual_seq_len: Optional[list] = None, ++ rotary_emb_cos: Optional[torch.Tensor] = None, ++ rotary_emb_sin: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, +@@ -765,9 +772,8 @@ class Qwen2DecoderLayer(nn.Module): + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + +- residual = hidden_states + +- hidden_states = self.input_layernorm(hidden_states) ++ hidden_states, residual = self.input_layernorm(hidden_states, past_residual) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( +@@ -776,17 +782,18 @@ class Qwen2DecoderLayer(nn.Module): + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, ++ updated_kv_positions=updated_kv_positions, ++ kv_padding_size=kv_padding_size, ++ actual_seq_len=actual_seq_len, ++ rotary_emb_cos=rotary_emb_cos, ++ rotary_emb_sin=rotary_emb_sin, + use_cache=use_cache, + ) +- hidden_states = residual + hidden_states + +- # Fully Connected +- residual = 
hidden_states +- hidden_states = self.post_attention_layernorm(hidden_states) ++ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) +- hidden_states = residual + hidden_states + +- outputs = (hidden_states,) ++ outputs = (residual, hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) +@@ -926,6 +933,12 @@ class Qwen2Model(Qwen2PreTrainedModel): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size ++ self.max_position_embeddings = config.max_position_embeddings ++ ++ self.hidden_size = config.hidden_size ++ self.num_heads = config.num_attention_heads ++ self.head_dim = self.hidden_size // self.num_heads ++ self.rope_theta = config.rope_theta + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( +@@ -933,6 +946,11 @@ class Qwen2Model(Qwen2PreTrainedModel): + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ self.rotary_emb = Qwen2RotaryEmbedding( ++ self.head_dim, ++ max_position_embeddings=self.max_position_embeddings, ++ base=self.rope_theta, ++ ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing +@@ -944,6 +962,15 @@ class Qwen2Model(Qwen2PreTrainedModel): + def set_input_embeddings(self, value): + self.embed_tokens = value + ++ def _prepare_decoder_rotary_cos_sin(self, position_ids): ++ cos, sin = self.rotary_emb() ++ f_position_ids = position_ids.flatten() ++ cos = torch.index_select(cos, 0, f_position_ids) ++ sin = torch.index_select(sin, 0, f_position_ids) ++ cos = cos.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) ++ sin = sin.reshape(position_ids.size(0), position_ids.size(1), -1).unsqueeze(2) ++ return cos, sin ++ + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, +@@ -951,6 +978,9 @@ class Qwen2Model(Qwen2PreTrainedModel): + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, ++ updated_kv_positions: Optional[torch.LongTensor] = None, ++ kv_padding_size: Optional[torch.LongTensor] = None, ++ actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, +@@ -982,13 +1012,21 @@ class Qwen2Model(Qwen2PreTrainedModel): + ) + use_cache = False + +- past_key_values_length = 0 +- + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: ++ kv_shape = ( ++ batch_size, self.config.max_position_embeddings, ++ self.config.num_key_value_heads // self.config.world_size, ++ self.config.hidden_size // self.config.num_attention_heads) ++ past_key_values = () ++ for layer_idx in range(self.config.num_hidden_layers): ++ k_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) ++ v_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) ++ past_key_values += ((k_cache, v_cache),) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) +- past_key_values_length = past_key_values.get_usable_length(seq_length) ++ ++ past_key_values_length = self.max_position_embeddings if seq_length == 1 else 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else 
inputs_embeds.device +@@ -1017,12 +1055,7 @@ class Qwen2Model(Qwen2PreTrainedModel): + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. +- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( +- attention_mask, +- (batch_size, seq_length), +- inputs_embeds, +- past_key_values_length, +- ) ++ attention_mask = None # cosyvoice中没有多bs场景,不需要额外prepare mask + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( +@@ -1035,11 +1068,15 @@ class Qwen2Model(Qwen2PreTrainedModel): + + hidden_states = inputs_embeds + ++ # 此处处理cos, sin是为了不在每层里面重复计算 ++ rotary_emb_cos, rotary_emb_sin = self._prepare_decoder_rotary_cos_sin(position_ids) ++ + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + ++ residual = None + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) +@@ -1058,21 +1095,28 @@ class Qwen2Model(Qwen2PreTrainedModel): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, ++ past_residual=residual, + position_ids=position_ids, + past_key_value=past_key_values, ++ updated_kv_positions=updated_kv_positions, ++ kv_padding_size=kv_padding_size, ++ actual_seq_len=actual_seq_len, ++ rotary_emb_cos=rotary_emb_cos, ++ rotary_emb_sin=rotary_emb_sin, + output_attentions=output_attentions, + use_cache=use_cache, + ) + +- hidden_states = layer_outputs[0] ++ residual = layer_outputs[0] ++ hidden_states = layer_outputs[1] + + if use_cache: +- next_decoder_cache = layer_outputs[2 if output_attentions else 1] ++ next_decoder_cache = layer_outputs[3 if output_attentions else 2] + + if output_attentions: +- all_self_attns += (layer_outputs[1],) ++ all_self_attns += (layer_outputs[2],) + +- hidden_states = self.norm(hidden_states) ++ hidden_states, _ = self.norm(hidden_states, residual) + + # add hidden states from the last decoder layer + if output_hidden_states: +@@ -1130,8 +1174,8 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, ++ prompt_length: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, +- labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, +@@ -1163,6 +1207,19 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + ++ updated_kv_positions, past_key_values, position_ids, kv_padding_size, actual_seq_len = self.prepare_data(inputs_embeds, past_key_values, prompt_length) ++ ++ model_inputs = { ++ "inputs_embeds": inputs_embeds, ++ "past_key_values": past_key_values, ++ "position_ids": position_ids, ++ "kv_padding_size": kv_padding_size, ++ "actual_seq_len": actual_seq_len, ++ "attention_mask": attention_mask, ++ } ++ ++ self._mark_model_inputs_static(model_inputs) ++ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states +@@ -1175,6 +1232,9 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, ++ updated_kv_positions=updated_kv_positions, ++ kv_padding_size=kv_padding_size, ++ actual_seq_len=actual_seq_len, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, +@@ -1183,6 +1243,13 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): + ) + + hidden_states = outputs[0] ++ # 由于logits最后也只取[:,-1,:],相当于只取最新seq位置上的数据,l ++ # 所以在全量的最后线性层计算可以只对最新的seq位置做计算,降低计算量 ++ bs, seq, hidden = hidden_states.size() ++ if seq > 1: ++ gather_index = torch.ones(bs, dtype=torch.int64, device=hidden_states.device) * (seq - 1) ++ gather_index = gather_index.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, 1, hidden) ++ hidden_states = torch.gather(hidden_states, 1, gather_index) + logits = self.lm_head(hidden_states) + logits = logits.float() + +@@ -1210,6 +1277,32 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) ++ ++ def prepare_data(self, inputs_embeds, past_key_values, prompt_length): ++ bsz = inputs_embeds.shape[0] ++ seq_length = inputs_embeds.shape[1] ++ if past_key_values: ++ past_key_values = DynamicCache.from_legacy_cache(past_key_values) ++ if seq_length > 1: ++ updated_kv_positions = torch.zeros(bsz, dtype=torch.long, device=inputs_embeds.device) ++ else: ++ updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - 1) ++ position_ids = torch.tensor([prompt_length], device=inputs_embeds.device) ++ ++ # ifa Computational optimization inputs ++ kv_padding_size = torch.tensor(self.config.max_position_embeddings - prompt_length, device=inputs_embeds.device) ++ actual_seq_len = ([prompt_length]).cpu().tolist() ++ ++ return updated_kv_positions, past_key_values, position_ids, kv_padding_size, actual_seq_len ++ ++ def _mark_model_inputs_static(self, model_inputs): ++ for key, value in model_inputs.items(): ++ if key == "past_key_values" and value is not None: ++ for i in range(self.config.num_hidden_layers): ++ torch._dynamo.mark_static(value[i][0]) ++ torch._dynamo.mark_static(value[i][1]) ++ elif isinstance(value, torch.Tensor): ++ torch._dynamo.mark_static(value) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs +-- +2.21.0 + diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py new file mode 100755 index 0000000000..54fcfbfc94 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. 
+# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. + +import argparse +import torch +import torchaudio +import torch_npu +from torch_npu.contrib import transfer_to_npu +import torchair as tng +from torchair.configs.compiler_config import CompilerConfig +from cosyvoice.cli.cosyvoice import CosyVoice2 +from cosyvoice.utils.file_utils import load_wav + + +if __name__ == '__main__': + torch_npu.npu.set_compile_mode(jit_compile=False) + + parser = argparse.ArgumentParser(description="CosyVoice infer") + parser.add_argument("--model_path", type=str, help="model path") + parser.add_argument('--warm_up_times', default=2, type=int, help='warm up times') + parser.add_argument('--infer_count', default=10, type=int, help='infer loop count') + args = parser.parse_args() + + cosyvoice = CosyVoice2(args.model_path, load_om=True) + + # 设置torchair初始化参数 + config = CompilerConfig() + config.experimental_config.frozen_parameter = True + npu_backend = tng.get_npu_backend(compiler_config=config) + config.experimental_config.tiling_schedule_optimize = True + cosyvoice.model.llm.llm.model.model.eval() + cosyvoice.model.llm.llm.model.model.half() + cosyvoice.model.llm.llm.model.model = torch.compile(cosyvoice.model.llm.llm.model.model, dynamic=True, fullgraph=True, backend=npu_backend) + + # 输入数据加载 + prompt_txt = '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。' + out_txt = '希望你以后能够做的比我还好嘞。' + + prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) # 16000为当前语音用例的resample rate参数 + with torch.no_grad(): + print('warm up start') + for _ in range(args.warm_up_times): + next(cosyvoice.inference_zero_shot(prompt_txt, out_txt, prompt_speech_16k, stream=False)) + print('warm up end') + for _ in range(args.infer_count): + for j in cosyvoice.inference_zero_shot(prompt_txt, out_txt, prompt_speech_16k, stream=False): + torchaudio.save('zero_shot_result.wav', j['tts_speech'], cosyvoice.sample_rate) \ No newline at end of file diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py new file mode 100755 index 0000000000..d2a312fd33 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modify_onnx.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd +# [Software Name] is licensed under Mulan PSL v2. +# You can use this software according to the terms and conditions of the Mulan PSL v2. +# You may obtain a copy of Mulan PSL v2 at: +# http://license.coscl.org.cn/MulanPSL2 +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, +# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. +# See the Mulan PSL v2 for more details. 
+ +import os +import sys +from auto_optimizer import OnnxGraph + + +def modify_speech_token(speech_token): + ReduceMax_list = speech_token.get_nodes("ReduceMax") + for node in ReduceMax_list: + if node.attrs['keepdims'] == 0: + out_nodes = speech_token.get_next_nodes(node.outputs[0]) + for out_node in out_nodes: + if out_node.op_type == "Unsqueeze": + node.attrs['keepdims'] = 1 + speech_token.remove(out_node.name) + return speech_token + +if __name__ == '__main__': + input_path = sys.argv[1] + speech_token = OnnxGraph.parse(os.path.join(input_path, "speech_tokenizer_v2.onnx")) + # om图模式不支持无shape输出,reducemax需要修改keepdim + modify_onnx = modify_speech_token(speech_token) + modify_onnx.save(os.path.join(input_path, "speech_token_md.onnx")) \ No newline at end of file diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt new file mode 100755 index 0000000000..b2794ae2b9 --- /dev/null +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt @@ -0,0 +1,33 @@ +conformer==0.3.2 +deepspeed==0.14.2 +diffusers==0.27.2 +gdown==5.1.0 +gradio==4.32.2 +grpcio==1.57.0 +grpcio-tools==1.57.0 +huggingface-hub==0.23.5 +hydra-core==1.3.2 +HyperPyYAML==1.2.2 +inflect==7.3.1 +librosa==0.10.2 +lightning==2.2.4 +matplotlib==3.7.5 +modelscope==1.15.0 +networkx==3.1 +omegaconf==2.3.0 +onnx==1.16.0 +onnxruntime==1.16.0 +openai-whisper==20231117 +protobuf==4.25 +pydantic==2.7.0 +rich==13.7.1 +soundfile==0.12.1 +tensorboard==2.14.0 +torch==2.4.0 +torch_npu==2.4.0.post2 +torchaudio==2.4.0 +uvicorn==0.30.0 +wget==3.2 +fastapi==0.111.0 +fastapi-cli==0.0.4 +WeTextProcessing==1.0.3 \ No newline at end of file -- Gitee