From 5f38ccaa1833bf4284ed28645b37f4360b324aec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=B1=9F=E6=B1=9F?= Date: Fri, 30 May 2025 11:46:45 +0800 Subject: [PATCH 1/2] feat: cosyvoice2 support stream input inference --- .../audio/CosyVoice/CosyVoice2/README.md | 88 +++-- .../CosyVoice/CosyVoice2/diff_CosyVoice.patch | 370 +++++++++++++++--- .../audio/CosyVoice/CosyVoice2/infer.py | 72 +++- .../CosyVoice/CosyVoice2/modeling_qwen2.py | 116 +++++- 4 files changed, 538 insertions(+), 108 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 31740a8f02..8ac1f82660 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -13,7 +13,9 @@ ****** # 概述 -  ‌Co‌syVoice是一款基于语音量化编码的语音生成大模型,能够深度融合文本理解和语音生成,实现自然流畅的语音体验。它通过离散化编码和依托大模型技术,能够精准解析并诠释各类文本内容,将其转化为宛如真人般的自然语音‌。CosyVoice2在原始1的基础上,把QWEN2模型接入CosyVoice的LLM部分,实现了推理加速 +‌Co‌syVoice是一款基于语音量化编码的语音生成大模型,能够深度融合文本理解和语音生成,实现自然流畅的语音体验。它通过离散化编码和依托大模型技术,能够精准解析并诠释各类文本内容,将其转化为宛如真人般的自然语音‌。CosyVoice2在原始1的基础上,把QWEN2模型接入CosyVoice的LLM部分,实现了推理加速 + +还可以通过参数 `--in_stream`设置模型为 **流式输入模式**,支持逐字符或分块输入文本,显著降低长文本生成延迟,提升实时交互体验。 - 版本说明: ``` @@ -57,6 +59,20 @@ mv ../modeling_qwen2.py ./transformers/src/transformers/models/qwen2 ``` + 文件目录结构大致如下: + ```text + 📁 CosyVoice/ + ├── 📁 CosyVoice2/ + | |── 📄 diff_CosyVoice.patch + | |── 📄 modeling_qwen2.py + | |── 📁 CosyVoice + | |── 📁 cosyVoice源码文件 # cosyVoice的源码文件,此处不一一列举 + │ ├── 📁 CosyVoice-0.5B/ # 权重文件 + │ ├── 📁 transformers/ # transformers文件,里面有修改过的modeling_qwen2.py文件 + │ ├── 📄 infer.py # 推理脚本 + │ └── 📄 modify_onnx.py # 模型转换脚本 + ``` + 2. 安装依赖 ``` pip3 install -r ../requirements.txt @@ -64,37 +80,42 @@ ``` 注:如果遇到无法安装WeTextProcessing的场景,可以参考以下方法手动安装编译 ```bash - # 下载安装包并解压 - wget https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.3.tar.gz - # 进入目录后编译安装 - ./configure --enable-far --enable-mpdt --enable-pdt - make install - pip3 install WeTextProcessing==1.0.4.1 + # 下载安装包并解压 + wget https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.3.tar.gz + # 进入目录后编译安装 + ./configure --enable-far --enable-mpdt --enable-pdt + make -j$(nproc) + make install + # 确认动态库文件存在: + ls /usr/local/lib/libfstmpdtscript.so.26 + # 配置动态库路径 + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + sudo ldconfig + # 安装WeTextProcessing + pip3 install WeTextProcessing==1.0.4.1 ``` 3. 安装msit工具 - 参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgen组件。 - + 参考[msit](https://gitee.com/ascend/msit)安装工具中的benchmark和surgen组件。(未安装会提示 ais_bench 导入失败报错) + 4. 获取权重数据 本案例以CosyVoice2-0.5B为例,其他权重请自行适配。将下载下来的权重**放在CosyVoice目录下**。 - 因cosyvoice2在2025年4月底更新过一次代码权重,因此需要使用`snapshot_download`下载指定commit id的权重。 - ```python - from modelscope.hub.snapshot_download import snapshot_download - - model_id = "模型名称" # 例如 "iic/CosyVoice2-0.5B" - commit_id = "9bd5b08fc085bd93d3f8edb16b67295606290350" - - model_dir = snapshot_download( - model_id, - revision=commit_id, # 关键,传入commit_id - cache_dir="./my_model" - ) - print(f"模型下载至:{model_dir}") - ``` + 因cosyvoice2在2025年4月底更新过一次代码权重,因此下载时需要指定commit id,下载之前的权重。 + ```git + # 1. 克隆 + git clone https://www.modelscope.cn/iic/CosyVoice2-0.5B.git + cd CosyVoice2-0.5B + + # 2. 切换到目标 commit + git checkout 9bd5b08fc085bd93d3f8edb16b67295606290350 + + # 3. 拉取 LFS 大文件(如模型权重) + git lfs pull + ``` 本用例采用sft预训练音色推理,请额外下载spk权重放到权重目录下 ``` @@ -105,7 +126,7 @@ ### 1 模型转换 -   模型权重中提供了 flow.decoder.estimator.fp32.onnx和speech_tokenizer_v2.onnx两个onnx模型,对其进行结构修改后使用`ATC`工具将`.onnx`文件转为离线推理模型`.om`文件。 +模型权重中提供了 `flow.decoder.estimator.fp32.onnx`和`speech_tokenizer_v2.onnx`两个onnx模型,对其进行结构修改后使用`ATC`工具将`.onnx`文件转为离线推理模型`.om`文件。 1. 修改onnx模型结构 @@ -142,22 +163,27 @@ 2. 设置环境变量,执行推理命令 ``` - # 指定使用NPU ID,默认为0 + # 1. 指定使用NPU ID,默认为0 export ASCEND_RT_VISIBLE_DEVICES=0 + # 2. 设置环境变量 export PYTHONPATH=third_party/Matcha-TTS:$PYTHONPATH export PYTHONPATH=transformers/src:$PYTHONPATH - python3 infer.py --model_path=${CosyVoice2-0.5B} --stream + # 3. 执行推理脚本 + python3 infer.py --model_path=${CosyVoice2-0.5B} --out_stream ``` - --model_path: 权重路径 - --warm_up_times:warm up次数,默认为2 - --infer_count:循环推理次数,默认为20 - - --stream:是否执行流式推理 + - --stream_in:是否执行流式输入推理 + - --stream_out:是否执行流式输出推理 - 在推理开始后,首先会默认执行warm_up,目的是执行首次编译,首次编译时间较长,首次编译结束后,会在当前目录下生成.torchair_cache文件,后续推理无需重复编译,在warm_up结束后,会执行推理操作,并将推理结果保存在'sft_i.wav'中,并打屏性能数据:实时率(rtf),指的是平均1s时长的音频需要多少时间处理。 + 在推理开始后,首先会默认执行warm_up,目的是执行首次编译,首次编译时间较长,首次编译结束后,会在当前目录下生成.torchair_cache文件,后续推理无需重复编译,在warm_up结束后,会执行推理操作: + * 非流式输入:将推理结果保存在`sft_i.wav`中,并打屏性能数据:实时率(rtf),指的是平均1s时长的音频需要多少时间处理。 + * 流式输入:将推理结果保存在`stream_input_out_i.wav`文件中,并打屏性能数据:实时率(rtf) ### 3 性能数据 - |模型|芯片|rtf(实时率)| - |------|------|------| - |cosyvoice|800I A2|0.28s| + | 模型 |芯片|rtf(实时率)| + |-----------|------|------| + | cosyvoice |800I A2|0.28s| diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch index d9f3153693..caf3594b29 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch @@ -1,12 +1,13 @@ diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py -index e2d62e2..99f8463 100644 +index e2d62e2..dccea41 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py -@@ -13,11 +13,13 @@ +@@ -13,11 +13,14 @@ # limitations under the License. import os import time +import platform ++import datetime from typing import Generator from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml @@ -16,16 +17,16 @@ index e2d62e2..99f8463 100644 from cosyvoice.cli.frontend import CosyVoiceFrontEnd from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model from cosyvoice.utils.file_utils import logging -@@ -126,7 +128,7 @@ class CosyVoice: - +@@ -126,7 +129,7 @@ class CosyVoice: + class CosyVoice2(CosyVoice): - + - def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False): + def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, load_om=False): self.instruct = True if '-Instruct' in model_dir else False self.model_dir = model_dir self.fp16 = fp16 -@@ -155,6 +157,16 @@ class CosyVoice2(CosyVoice): +@@ -155,6 +158,16 @@ class CosyVoice2(CosyVoice): self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'), '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir), self.fp16) @@ -40,10 +41,30 @@ index e2d62e2..99f8463 100644 + self.model.flow.decoder.flow_om_static = flow_om_static + self.model.flow.decoder.flow_om = flow_om del configs - + def inference_instruct(self, *args, **kwargs): +@@ -171,3 +184,19 @@ class CosyVoice2(CosyVoice): + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() ++ ++ def inference_sft_streaming_input(self, tts_text, char_idx, spk_id, user_id, input_end, stream=False, speed=1.0, text_frontend=True): ++ for i in [tts_text]: ++ model_input = self.frontend.frontend_sft(i, spk_id) ++ model_input["user_id"] = user_id ++ model_input["input_end"] = input_end ++ model_input['char_idx'] = char_idx ++ ++ start_time = time.time() ++ # print('synthesis text {}'.format(i)) ++ for model_output in self.model.tts_streaming_input(**model_input, stream=stream, speed=speed): ++ speech_len = model_output['tts_speech'].shape[1] / self.sample_rate ++ print("finish 1 chunk inference ", datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')) ++ logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) ++ yield model_output ++ start_time = time.time() diff --git a/cosyvoice/cli/frontend.py b/cosyvoice/cli/frontend.py -index 6e10f00..0eb45a8 100644 +index 6e10f00..25ad767 100644 --- a/cosyvoice/cli/frontend.py +++ b/cosyvoice/cli/frontend.py @@ -71,6 +71,8 @@ class CosyVoiceFrontEnd: @@ -52,7 +73,7 @@ index 6e10f00..0eb45a8 100644 self.inflect_parser = inflect.engine() + self.speech_om = None + self.flow_om = None - + def _extract_text_token(self, text): if isinstance(text, Generator): @@ -92,11 +94,16 @@ class CosyVoiceFrontEnd: @@ -78,10 +99,25 @@ index 6e10f00..0eb45a8 100644 speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) return speech_token, speech_token_len diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py -index 9ebf8cb..af31dc9 100644 +index 9ebf8cb..064439c 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py -@@ -362,13 +362,18 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -314,6 +314,14 @@ class CosyVoice2Model(CosyVoiceModel): + self.llm_end_dict = {} + self.hift_cache_dict = {} + ++ # add for support streaming input ++ self.first_chunk_size = 20 ++ self.token_offset_dict = {} ++ self.prompt_text_dict = {} ++ self.prompt_speech_token_dict = {} ++ self.speech_feat_dict = {} ++ self.embedding_dict = {} ++ + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder +@@ -362,12 +370,17 @@ class CosyVoice2Model(CosyVoiceModel): with self.lock: self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False self.hift_cache_dict[this_uuid] = None @@ -91,7 +127,6 @@ index 9ebf8cb..af31dc9 100644 token_offset = 0 - while True: - time.sleep(0.1) -- if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: + # 删除线程操作,串行执行推理,加速首包时延 + for i in self.llm.inference(text=text.to(self.device), + text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device), @@ -101,11 +136,10 @@ index 9ebf8cb..af31dc9 100644 + prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), + embedding=llm_embedding.to(self.device)): + self.tts_speech_token_dict[this_uuid].append(i) -+ if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: + if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= self.token_hop_len + self.flow.pre_lookahead_len: this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, -@@ -379,10 +384,6 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -379,10 +392,6 @@ class CosyVoice2Model(CosyVoiceModel): finalize=False) token_offset += self.token_hop_len yield {'tts_speech': this_tts_speech.cpu()} @@ -116,7 +150,7 @@ index 9ebf8cb..af31dc9 100644 this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) this_tts_speech = self.token2wav(token=this_tts_speech_token, prompt_token=flow_prompt_speech_token, -@@ -393,6 +394,8 @@ class CosyVoice2Model(CosyVoiceModel): +@@ -393,6 +402,8 @@ class CosyVoice2Model(CosyVoiceModel): finalize=True) yield {'tts_speech': this_tts_speech.cpu()} else: @@ -125,6 +159,91 @@ index 9ebf8cb..af31dc9 100644 # deal with all tokens p.join() this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) +@@ -409,3 +420,83 @@ class CosyVoice2Model(CosyVoiceModel): + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + torch.cuda.empty_cache() ++ ++ def tts_streaming_input(self, text, char_idx, flow_embedding, llm_embedding=torch.zeros(0, 192), ++ prompt_text=torch.zeros(1, 0, dtype=torch.int32), ++ llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), ++ flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), ++ prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): ++ this_uuid = kwargs["user_id"] ++ if this_uuid not in self.tts_speech_token_dict: ++ self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False ++ self.hift_cache_dict[this_uuid] = None ++ self.token_offset_dict[this_uuid] = 0 ++ ++ self.prompt_text_dict[this_uuid] = prompt_text ++ self.prompt_speech_token_dict[this_uuid] = flow_prompt_speech_token ++ self.speech_feat_dict[this_uuid] = prompt_speech_feat ++ self.embedding_dict[this_uuid] = flow_embedding ++ else: ++ prompt_text = self.prompt_text_dict[this_uuid] ++ llm_prompt_speech_token = self.prompt_speech_token_dict[this_uuid] ++ flow_prompt_speech_token = self.prompt_speech_token_dict[this_uuid] ++ flow_embedding = self.embedding_dict[this_uuid] ++ llm_embedding = self.embedding_dict[this_uuid] ++ prompt_speech_feat = self.speech_feat_dict[this_uuid] ++ ++ for i in self.llm.inference_bistream_streaming_input(text=text, ++ char_idx=torch.tensor([char_idx]).to(self.device), ++ prompt_text=prompt_text.to(self.device), ++ prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device), ++ prompt_speech_token=llm_prompt_speech_token.to(self.device), ++ prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device), ++ embedding=llm_embedding.to(self.device), ++ uuid=this_uuid, input_end=kwargs['input_end']): ++ self.tts_speech_token_dict[this_uuid].append(i) ++ ++ assert stream is True, "output must be streaming" ++ ++ while True: ++ is_first_chunk_ready = (self.token_offset_dict[this_uuid] == 0 and len(self.tts_speech_token_dict[this_uuid]) >= self.first_chunk_size + self.flow.pre_lookahead_len) ++ is_next_chunk_ready = (self.token_offset_dict[this_uuid] > 0 and len(self.tts_speech_token_dict[this_uuid]) - self.token_offset_dict[this_uuid] >= self.token_hop_len + self.flow.pre_lookahead_len) ++ if is_first_chunk_ready or is_next_chunk_ready: ++ if self.token_offset_dict[this_uuid] == 0: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.first_chunk_size + self.flow.pre_lookahead_len]).unsqueeze(dim=0) ++ else: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:self.token_offset_dict[this_uuid] + self.token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0) # 0-53, 0-103, 0-153... ++ this_tts_speech = self.token2wav(token=this_tts_speech_token, ++ prompt_token=flow_prompt_speech_token, ++ prompt_feat=prompt_speech_feat, ++ embedding=flow_embedding, ++ uuid=this_uuid, ++ token_offset=self.token_offset_dict[this_uuid], ++ finalize=False) ++ if self.token_offset_dict[this_uuid] == 0: ++ self.token_offset_dict[this_uuid] += self.first_chunk_size ++ else: ++ self.token_offset_dict[this_uuid] += self.token_hop_len ++ yield {'tts_speech': this_tts_speech.cpu()} ++ # 是否需要退出循环(token 不够下一次推理) ++ if len(self.tts_speech_token_dict[this_uuid]) - self.token_offset_dict[this_uuid] < self.token_hop_len + self.flow.pre_lookahead_len: ++ break ++ ++ if kwargs['input_end'] is True: ++ this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0) ++ this_tts_speech = self.token2wav(token=this_tts_speech_token, ++ prompt_token=flow_prompt_speech_token, ++ prompt_feat=prompt_speech_feat, ++ embedding=flow_embedding, ++ uuid=this_uuid, ++ token_offset=self.token_offset_dict[this_uuid], ++ finalize=True) ++ yield {'tts_speech': this_tts_speech.cpu()} ++ ++ self.tts_speech_token_dict.pop(this_uuid) ++ self.llm_end_dict.pop(this_uuid) ++ self.hift_cache_dict.pop(this_uuid) ++ ++ self.token_offset_dict.pop(this_uuid) ++ self.prompt_text_dict.pop(this_uuid) ++ self.prompt_speech_token_dict.pop(this_uuid) ++ self.speech_feat_dict.pop(this_uuid) ++ self.embedding_dict.pop(this_uuid) +\ No newline at end of file diff --git a/cosyvoice/flow/flow_matching.py b/cosyvoice/flow/flow_matching.py index 6a60f6d..fbe7545 100644 --- a/cosyvoice/flow/flow_matching.py @@ -135,15 +254,15 @@ index 6a60f6d..fbe7545 100644 import torch.nn.functional as F +import numpy as np from matcha.models.components.flow_matching import BASECFM - - + + @@ -32,6 +33,8 @@ class ConditionalCFM(BASECFM): # Just change the architecture of the estimator here self.estimator = estimator self.lock = threading.Lock() + self.flow_om = None + self.flow_om_static = None - + @torch.inference_mode() def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): @@ -105,12 +108,26 @@ class ConditionalCFM(BASECFM): @@ -180,7 +299,7 @@ index 6a60f6d..fbe7545 100644 dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) x = x + dt * dphi_dt diff --git a/cosyvoice/hifigan/generator.py b/cosyvoice/hifigan/generator.py -index c47bf05..a7c8b37 100644 +index c47bf05..7f3e4ae 100644 --- a/cosyvoice/hifigan/generator.py +++ b/cosyvoice/hifigan/generator.py @@ -23,6 +23,7 @@ import torch.nn.functional as F @@ -190,17 +309,17 @@ index c47bf05..a7c8b37 100644 +from torch.nn.utils.parametrize import remove_parametrizations from torch.nn.utils.parametrizations import weight_norm from torch.distributions.uniform import Uniform - + @@ -99,8 +100,8 @@ class ResBlock(torch.nn.Module): - + def remove_weight_norm(self): for idx in range(len(self.convs1)): - remove_weight_norm(self.convs1[idx]) - remove_weight_norm(self.convs2[idx]) + remove_parametrizations(self.convs1[idx], "weight") + remove_parametrizations(self.convs2[idx], "weight") - - + + class SineGen(torch.nn.Module): @@ -319,14 +320,11 @@ class HiFTGenerator(nn.Module): def remove_weight_norm(self): @@ -219,41 +338,41 @@ index c47bf05..a7c8b37 100644 + remove_parametrizations(self.conv_post, 'weight') for l in self.source_resblocks: l.remove_weight_norm() - + @@ -346,9 +344,7 @@ class HiFTGenerator(nn.Module): self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) return inverse_transform - + - def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: - s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) - s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) + def decode(self, x: torch.Tensor, s_stft: torch.Tensor, index: torch.int) -> torch.Tensor: - + x = self.conv_pre(x) for i in range(self.num_upsamples): @@ -356,7 +352,7 @@ class HiFTGenerator(nn.Module): x = self.ups[i](x) - + if i == self.num_upsamples - 1: - x = self.reflection_pad(x) + x = torch.cat((x, x[:,:,-2:-1]), -1) - + # fusion si = self.source_downs[i](s_stft) @@ -373,12 +369,10 @@ class HiFTGenerator(nn.Module): - + x = F.leaky_relu(x) x = self.conv_post(x) - magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) - phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy + magnitude = torch.exp(x[:, :index, :]) + phase = torch.sin(x[:, index:, :]) # actually, sin is redundancy - + - x = self._istft(magnitude, phase) - x = torch.clamp(x, -self.audio_limit, self.audio_limit) - return x + return magnitude, phase - + def forward( self, @@ -407,5 +401,12 @@ class HiFTGenerator(nn.Module): @@ -271,13 +390,13 @@ index c47bf05..a7c8b37 100644 + generated_speech = torch.clamp(x, -self.audio_limit, self.audio_limit) return generated_speech, s diff --git a/cosyvoice/llm/llm.py b/cosyvoice/llm/llm.py -index bbd3305..ce82157 100644 +index bbd3305..7eb32ad 100644 --- a/cosyvoice/llm/llm.py +++ b/cosyvoice/llm/llm.py @@ -229,16 +229,17 @@ class Qwen2Encoder(torch.nn.Module): super().__init__() self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path) - + - def forward_one_step(self, xs, masks, cache=None): - input_masks = masks[:, -1, :] - outs = self.model( @@ -302,7 +421,23 @@ index bbd3305..ce82157 100644 xs = outs.hidden_states[-1] new_cache = outs.past_key_values return xs, new_cache -@@ -318,10 +319,17 @@ class Qwen2LM(TransformerLM): +@@ -283,6 +284,15 @@ class Qwen2LM(TransformerLM): + self.sampling = sampling + self.mix_ratio = mix_ratio + ++ # 5. added for support streaming input ++ self.prompt_speech_token_emb_dict = {} ++ self.lm_input_dict = {} ++ self.out_tokens_dict = {} ++ self.cache_dict = {} ++ self.text_cache_dict = {} ++ self.next_fill_index = {} ++ self.prompt_length = {} ++ + @torch.inference_mode() + def inference( + self, +@@ -318,9 +328,16 @@ class Qwen2LM(TransformerLM): # 5. step by step decode out_tokens = [] cache = None @@ -315,28 +450,172 @@ index bbd3305..ce82157 100644 + masks = None y_pred, cache = self.llm.forward_one_step(lm_input, - masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool), -- cache=cache) + masks=masks, + prompt_length=prompt_length, -+ cache=cache) + cache=cache) logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item() - if top_ids == self.speech_token_size: -@@ -331,7 +339,7 @@ class Qwen2LM(TransformerLM): +@@ -331,7 +348,7 @@ class Qwen2LM(TransformerLM): # in stream mode, yield token one by one yield top_ids out_tokens.append(top_ids) - lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() - + @torch.inference_mode() def inference_bistream( +@@ -432,3 +449,144 @@ class Qwen2LM(TransformerLM): + # in stream mode, yield token one by one + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) ++ ++ @torch.inference_mode() ++ def inference_bistream_streaming_input( ++ self, ++ text: torch.Tensor, ++ char_idx: torch.Tensor, ++ prompt_text: torch.Tensor, ++ prompt_text_len: torch.Tensor, ++ prompt_speech_token: torch.Tensor, ++ prompt_speech_token_len: torch.Tensor, ++ embedding: torch.Tensor, ++ uuid: str, ++ input_end: bool, ++ sampling: int = 25, ++ max_token_text_ratio: float = 20, ++ min_token_text_ratio: float = 2, ++ ) -> Generator[torch.Tensor, None, None]: ++ ++ def build_causal_mask(query_len, key_len, devices): ++ num_past = key_len - query_len ++ assert num_past >= 0 ++ causal_mask = torch.triu(torch.ones((query_len, query_len), device=devices), diagonal=1).to(torch.bool) ++ left_padding = torch.zeros((query_len, num_past), dtype=torch.bool, device=devices) ++ full_masks = torch.cat([left_padding, causal_mask], dim=-1) ++ return full_masks.unsqueeze(0) ++ ++ device = prompt_text.device ++ ++ if uuid not in self.cache_dict: ++ sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) ++ if prompt_speech_token_len != 0: ++ self.prompt_speech_token_emb_dict[uuid] = self.speech_embedding(prompt_speech_token) ++ else: ++ self.prompt_speech_token_emb_dict[uuid] = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device) ++ ++ self.lm_input_dict[uuid] = torch.concat([sos_eos_emb], dim=1) # [1,1,896] ++ ++ self.out_tokens_dict[uuid] = [] ++ self.cache_dict[uuid] = None ++ ++ self.text_cache_dict[uuid] = self.llm.model.model.embed_tokens(prompt_text) # [1, prompt_text, 896] ++ self.next_fill_index[uuid] = -1 ++ self.prompt_length[uuid] = 0 ++ ++ text_emb = self.llm.model.model.embed_tokens(text) ++ ++ for i in range(text_emb.size(1)): ++ self.text_cache_dict[uuid] = torch.concat([self.text_cache_dict[uuid], text_emb[:, i].unsqueeze(1)], dim=1) ++ index = 0 ++ while self.prompt_speech_token_emb_dict[uuid].size(1) != 0: ++ if self.text_cache_dict[uuid].size(1) >= self.mix_ratio[0]: ++ lm_input_text, lm_input_speech = self.text_cache_dict[uuid][:, :self.mix_ratio[0]], self.prompt_speech_token_emb_dict[uuid][:, :self.mix_ratio[1]] ++ index += 1 ++ logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1))) ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], lm_input_text, lm_input_speech], dim=1) ++ self.text_cache_dict[uuid], self.prompt_speech_token_emb_dict[uuid] = self.text_cache_dict[uuid][:, self.mix_ratio[0]:], self.prompt_speech_token_emb_dict[uuid][:, self.mix_ratio[1]:] ++ else: ++ break ++ ++ if self.prompt_speech_token_emb_dict[uuid].size(1) == 0: # 文本token数量多于音频token,混合完以后,剩余文本token,开始解码 ++ # 若上一次解码的 token 是 fill_token,说明 LLM 想要更多 text token ++ # 或者首次预测时,还没开始解码,out_tokens_dict 为空 ++ if ((len(self.out_tokens_dict[uuid]) != 0 and self.out_tokens_dict[uuid][-1] == self.speech_token_size + 2) ++ or (len(self.out_tokens_dict[uuid]) == 0 and self.lm_input_dict[uuid].size(1) == 1)): ++ # token数量够了 ++ if self.text_cache_dict[uuid].size(1) >= self.mix_ratio[0]: ++ lm_input_text = self.text_cache_dict[uuid][:, :self.mix_ratio[0]] # 抽出5个token ++ if len(self.out_tokens_dict[uuid]) != 0 and self.out_tokens_dict[uuid][-1] == self.speech_token_size + 2: # 预测出filling token,前面cache已经缓存,当前直接输入即可 ++ self.lm_input_dict[uuid] = lm_input_text ++ else: # sft刚开始预测,需要和sos token拼接在一起 ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], lm_input_text], dim=1) ++ self.text_cache_dict[uuid] = self.text_cache_dict[uuid][:, self.mix_ratio[0]:] ++ else: ++ continue ++ ++ while True: ++ self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] ++ seq_len = self.prompt_length[uuid] ++ if self.lm_input_dict[uuid].shape[1] > 1: ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, ++ self.lm_input_dict[uuid].device) ++ else: ++ masks = None ++ y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], ++ masks=masks, ++ prompt_length=seq_len, ++ cache=self.cache_dict[uuid]) ++ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) ++ # 判断是否生成 filling_token: ++ if self.next_fill_index[uuid] != -1 and len(self.out_tokens_dict[uuid]) == self.next_fill_index[uuid]: ++ top_ids = self.speech_token_size + 2 # 该预测filling token了 ++ self.next_fill_index[uuid] += (self.mix_ratio[1] + 1) # 找到下一个filling token的位置 ++ else: ++ top_ids = self.sampling_ids(logp.squeeze(dim=0), self.out_tokens_dict[uuid], sampling, ignore_eos=True).item() ++ # 特殊 token 处理, fill_token → 中断预测、等待新文本 token。 ++ if top_ids == self.speech_token_size + 2: ++ self.next_fill_index[uuid] = len(self.out_tokens_dict[uuid]) + self.mix_ratio[1] + 1 # -1 > 30 ++ self.out_tokens_dict[uuid].append(top_ids) ++ if top_ids >= self.speech_token_size: ++ if top_ids == self.speech_token_size + 2: # 预测到了filling token, break掉迎接新的文本token ++ break ++ else: ++ raise ValueError('should not get token {}'.format(top_ids)) ++ yield top_ids ++ self.lm_input_dict[uuid] = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() ++ ++ if input_end: ++ # 3. final decode 文本全部送完,进行最后的解码。 ++ task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) ++ self.lm_input_dict[uuid] = torch.concat([self.lm_input_dict[uuid], self.text_cache_dict[uuid], task_id_emb, self.prompt_speech_token_emb_dict[uuid]], dim=1) ++ logging.info('no more text token, decode until met eos') ++ while True: ++ self.prompt_length[uuid] += self.lm_input_dict[uuid].shape[1] ++ seq_len = self.prompt_length[uuid] ++ if self.lm_input_dict[uuid].shape[1] > 1: ++ masks = build_causal_mask(self.lm_input_dict[uuid].shape[1], seq_len, self.lm_input_dict[uuid].device) ++ else: ++ masks = None ++ y_pred, self.cache_dict[uuid] = self.llm.forward_one_step(self.lm_input_dict[uuid], ++ masks=masks, ++ prompt_length=seq_len, ++ cache=self.cache_dict[uuid]) ++ logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) ++ top_ids = self.sampling_ids(logp.squeeze(dim=0), self.out_tokens_dict[uuid], sampling, ignore_eos=False).item() ++ self.out_tokens_dict[uuid].append(top_ids) ++ if top_ids >= self.speech_token_size: ++ if top_ids == self.speech_token_size: ++ break ++ else: ++ raise ValueError('should not get token {}'.format(top_ids)) ++ # in stream mode, yield token one by one ++ yield top_ids ++ self.lm_input_dict[uuid] = self.speech_embedding.weight[top_ids].reshape(1, 1, -1).detach().clone() ++ ++ # this user is done ++ self.prompt_speech_token_emb_dict.pop(uuid) ++ self.lm_input_dict.pop(uuid) ++ self.out_tokens_dict.pop(uuid) ++ self.cache_dict.pop(uuid) ++ self.text_cache_dict.pop(uuid) ++ self.next_fill_index.pop(uuid) +\ No newline at end of file diff --git a/cosyvoice/utils/common.py b/cosyvoice/utils/common.py -index 3e61a8c..f6c4346 100644 +index 3e61a8c..d316b92 100644 --- a/cosyvoice/utils/common.py +++ b/cosyvoice/utils/common.py @@ -107,12 +107,33 @@ def init_weights(m, mean=0.0, std=0.01): - + # Repetition Aware Sampling in VALL-E 2 def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): - top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) @@ -345,7 +624,7 @@ index 3e61a8c..f6c4346 100644 if rep_num >= win_size * tau_r: top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) return top_ids - + +def dst_sampling(weighted_scores, top_p=0.8, top_k=25): + + sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) @@ -367,9 +646,6 @@ index 3e61a8c..f6c4346 100644 + top_ids = selected_indices[selected_prob.multinomial(1, replacement=True)] + + return top_ids - + def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): prob, indices = [], [] --- -2.37.1 - diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py index 8743a684fe..012ade6500 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py @@ -9,6 +9,7 @@ # See the Mulan PSL v2 for more details. import argparse +from tqdm import tqdm import torch import torchaudio import torch_npu @@ -19,14 +20,61 @@ from cosyvoice.cli.cosyvoice import CosyVoice2 from cosyvoice.utils.file_utils import load_wav +def no_stream_input_inference(args, cosyvoice, prompt_txt): + with torch.no_grad(): + print('warm up start') + for _ in range(args.warm_up_times): + for _ in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): + pass + print('warm up end') + for _ in range(args.infer_count): + for i, j in enumerate(cosyvoice.inference_sft(prompt_txt[0], '中文女', stream=args.stream_out)): + torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) + + +def stream_input_inference(args, cosyvoice, prompt_txt): + infer_res = [torch.tensor([]) for _ in range(args.infer_count)] + + with torch.no_grad(): + print("warm up start") + for w_step in range(args.warm_up_times): + print(f"第{w_step+1}/{args.warm_up_times}轮预热:↓↓↓") + print(f"curr prompt text:{prompt_txt[w_step % len(prompt_txt)]}") + for char_idx, char in enumerate(prompt_txt[w_step % len(prompt_txt)]): + if char_idx == len(prompt_txt[w_step % len(prompt_txt)]) - 1: + for _ in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): + pass + else: + for _ in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): + pass + print("warm up end") + + print("inference start") + for i_step in range(args.infer_count): + print(f"第{i_step}/{args.infer_count}轮推理:↓↓↓") + print(f"curr prompt text:{prompt_txt[w_step % len(prompt_txt)]}") + for char_idx, char in enumerate(prompt_txt[i_step % len(prompt_txt)]): + if char_idx == len(prompt_txt[i_step % len(prompt_txt)]) - 1: + for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): + infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + + else: + for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): + infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + + print(f"save out wav file ...") + for i_step in tqdm(range(args.infer_count)): + torchaudio.save(f"stream_input_out_{i_step+1}.wav", infer_res[i_step], 24000) + if __name__ == '__main__': torch_npu.npu.set_compile_mode(jit_compile=False) - parser = argparse.ArgumentParser(description="CosyVoice infer") + parser = argparse.ArgumentParser(description="CosyVoice2 infer") parser.add_argument("--model_path", type=str, help="model path") parser.add_argument('--warm_up_times', default=2, type=int, help='warm up times') parser.add_argument('--infer_count', default=20, type=int, help='infer loop count') - parser.add_argument('--stream', action="store_true", help='stream infer') + parser.add_argument('--stream_in', action="store_true", help='stream input infer') + parser.add_argument('--stream_out', action="store_true", help='stream output infer') args = parser.parse_args() cosyvoice = CosyVoice2(args.model_path, load_om=True, fp16=True) @@ -41,15 +89,15 @@ if __name__ == '__main__': npu_backend = tng.get_npu_backend(compiler_config=config) cosyvoice.model.hift.decode = torch.compile(cosyvoice.model.hift.decode, dynamic=True, fullgraph=True, backend=npu_backend) - # 输入数据加载 - prompt_txt = '收到好友从远方寄来的生日礼物,那份意外的惊喜和深深的祝福,让我心中充满了甜蜜的快乐,笑容如花儿般绽放。' + prompt_txt = [ + '收到好友从远方寄来的生日礼物,那份意外的惊喜和深深的祝福,让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', + '全球每年有超过一百三十五万人,因吸烟而死亡' + ] - with torch.no_grad(): - print('warm up start') - for _ in range(args.warm_up_times): - next(cosyvoice.inference_sft(prompt_txt, '中文女', stream=args.stream)) - print('warm up end') - for _ in range(args.infer_count): - for i, j in enumerate(cosyvoice.inference_sft(prompt_txt, '中文女', stream=args.stream)): - torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate) \ No newline at end of file + # 普通输入(非流式输入) + if not args.stream_in: + no_stream_input_inference(args, cosyvoice, prompt_txt) + # 流式输入 + else: + stream_input_inference(args, cosyvoice, prompt_txt) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py index 0bdfdfcb06..75cc4dab84 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py @@ -225,24 +225,37 @@ class Qwen2SdpaAttention(Qwen2Attention): tmp_ids = updated_kv_positions.reshape(-1) torch_npu.scatter_update_(past_key_value.key_cache[self.layer_idx], tmp_ids, key_states, 1) torch_npu.scatter_update_(past_key_value.value_cache[self.layer_idx], tmp_ids, value_states, 1) - kv_states = past_key_value[self.layer_idx] if q_len == 1 else (key_states, value_states) - key_states = kv_states[0] - value_states = kv_states[1] - + # 流式输入场景,decode阶段 + if q_len == 1: + key_states = past_key_value[self.layer_idx][0] + value_states = past_key_value[self.layer_idx][1] + # 流式输入场景,首次prefill阶段 + elif q_len == actual_seq_len[0]: + key_states = key_states + value_states = value_states + # 流式输入场景,后续prefill阶段 + elif q_len < actual_seq_len[0]: + key_states = past_key_value.key_cache[self.layer_idx][:, :actual_seq_len[0]] + value_states = past_key_value.value_cache[self.layer_idx][:, :actual_seq_len[0]] if q_len > 1: # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 - attn_output = torch_npu.npu_prompt_flash_attention(query_states, key_states.contiguous(), - value_states.contiguous(), num_heads=self.num_heads, + attn_output = torch_npu.npu_prompt_flash_attention(query_states, + key_states.contiguous(), + value_states.contiguous(), + num_heads=self.num_heads, input_layout="BSND", scale_value=1 / math.sqrt(self.head_dim), pre_tokens=65535, next_tokens=0, atten_mask=attention_mask, + sparse_mode=1, num_key_value_heads=self.num_key_value_heads) else: # decode阶段利用IFA自定义算子执行计算,qkv的sequence都为1,该算子采用tiling下沉,视为静态算子,支持整图下发 - attn_output = torch_npu.npu_incre_flash_attention(query_states, key_states.contiguous(), - value_states.contiguous(), num_heads=self.num_heads, + attn_output = torch_npu.npu_incre_flash_attention(query_states, + key_states.contiguous(), + value_states.contiguous(), + num_heads=self.num_heads, input_layout="BSND", scale_value=1 / math.sqrt(self.head_dim), atten_mask=None, @@ -416,9 +429,11 @@ class Qwen2Model(Qwen2PreTrainedModel): config.experimental_config.frozen_parameter = True # tiling下沉,主要针对IFA算子,使其算子tiling操作在AICPU上执行 config.experimental_config.tiling_schedule_optimize = True + # torchair的cache编译,保证模型编译cache文件,避免重复推理 self.cached_decode = tng.inference.cache_compile(self.decode, config=config) - self.cached_prefill = tng.inference.cache_compile(self.prefill, config=config) + self.cached_prefill_1 = tng.inference.cache_compile(self.prefill_1, config=config) # 用于首次prefill,无kv_cache + self.cached_prefill_2 = tng.inference.cache_compile(self.prefill_2, config=config) # 用于后续prefill,有kv_cache def get_input_embeddings(self): return self.embed_tokens @@ -451,9 +466,12 @@ class Qwen2Model(Qwen2PreTrainedModel): return_dict: Optional[bool] = None, lm_head: Optional[object] = None ) -> Union[Tuple, BaseModelOutputWithPast]: - # prefill和decode需要编译为两个不同的模型 - if inputs_embeds.size(1) > 1: - return self.cached_prefill( + # prefill_1, prefill_2和decode需要编译为3个不同的模型 + seq_len = inputs_embeds.size(1) + + # 流式输入场景,首次prefill + if seq_len > 1 and seq_len == actual_seq_len[0]: + return self.cached_prefill_1( input_ids, attention_mask, position_ids, @@ -468,6 +486,26 @@ class Qwen2Model(Qwen2PreTrainedModel): return_dict, lm_head ) + + # 流式输入场景,后续prefill + elif 1 < seq_len < actual_seq_len[0]: + return self.cached_prefill_2( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + kv_padding_size, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + + # 流式输入场景,decode else: return self.cached_decode( input_ids, @@ -517,7 +555,7 @@ class Qwen2Model(Qwen2PreTrainedModel): lm_head ) - def prefill( + def prefill_1( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -549,6 +587,38 @@ class Qwen2Model(Qwen2PreTrainedModel): lm_head ) + def prefill_2( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + updated_kv_positions: Optional[torch.LongTensor] = None, + kv_padding_size: Optional[torch.LongTensor] = None, + actual_seq_len: Optional[list] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + lm_head: Optional[object] = None + ): + return self._forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + updated_kv_positions, + kv_padding_size, + actual_seq_len, + inputs_embeds, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + lm_head + ) + def _forward( self, @@ -593,7 +663,7 @@ class Qwen2Model(Qwen2PreTrainedModel): kv_shape = ( batch_size, self.config.max_position_embeddings, self.config.num_key_value_heads, - self.config.hidden_size // self.config.num_attention_heads) + self.config.hidden_size // self.config.num_attention_heads) # (1, 32768, 2, 64) past_key_values = () for _ in range(self.config.num_hidden_layers): k_cache = torch.zeros(kv_shape, dtype=inputs_embeds.dtype, device=inputs_embeds.device) @@ -601,14 +671,14 @@ class Qwen2Model(Qwen2PreTrainedModel): past_key_values += ((k_cache, v_cache),) past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = self.max_position_embeddings if seq_length == 1 else 0 + past_key_values_length = self.max_position_embeddings if actual_seq_len[0] > inputs_embeds.shape[1] else 0 if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) # tensor([0, 1, 2, 3, 4, 5], device='npu:0') else: position_ids = position_ids.view(-1, seq_length).long() @@ -818,12 +888,22 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): seq_length = inputs_embeds.shape[1] if past_key_values: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if seq_length > 1: + + # 流式输入场景,首次prefill + if seq_length > 1 and prompt_length == inputs_embeds.shape[1]: updated_kv_positions = torch.zeros(bsz, dtype=torch.long, device=inputs_embeds.device) position_ids = None + # 流式输入场景,后续prefill + elif seq_length > 1 and prompt_length > inputs_embeds.shape[1]: + # updated_kv_positions,kv_cache需要更新的起始位置 + updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - inputs_embeds.shape[1]) + tmp_head = prompt_length - inputs_embeds.shape[1] + tmp_tail = prompt_length + position_ids = torch.arange(tmp_head, tmp_tail, dtype=torch.long, device=inputs_embeds.device) + # 流式输入场景,decode else: updated_kv_positions = torch.ones(bsz, dtype=torch.long, device=inputs_embeds.device) * (prompt_length - 1) - position_ids = torch.tensor([prompt_length], device=inputs_embeds.device) + position_ids = torch.tensor([prompt_length - 1], device=inputs_embeds.device) # ifa Computational optimization inputs kv_padding_size = torch.tensor(self.config.max_position_embeddings - prompt_length, device=inputs_embeds.device) -- Gitee From f384272eb4cc5cfbb98df8313b8a303f9563f9ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=B1=9F=E6=B1=9F?= Date: Wed, 4 Jun 2025 09:25:13 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix:=E6=A0=B9=E6=8D=AE=E6=A3=80=E8=A7=86?= =?UTF-8?q?=E6=84=8F=E8=A7=81=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../audio/CosyVoice/CosyVoice2/README.md | 8 +++- .../CosyVoice/CosyVoice2/diff_CosyVoice.patch | 4 +- .../audio/CosyVoice/CosyVoice2/infer.py | 41 ++++++++++--------- .../CosyVoice/CosyVoice2/modeling_qwen2.py | 14 ++++--- .../CosyVoice/CosyVoice2/requirements.txt | 2 +- 5 files changed, 40 insertions(+), 29 deletions(-) diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md index 8ac1f82660..48d2b3a600 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/README.md @@ -40,6 +40,12 @@ # 快速上手 +## 获取本仓源码 +``` +git clone https://gitee.com/ascend/ModelZoo-PyTorch.git +cd ModelZoo-PyTorch/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2 +``` + ## 获取源码 1. 获取`PyTorch`源码 @@ -169,7 +175,7 @@ export PYTHONPATH=third_party/Matcha-TTS:$PYTHONPATH export PYTHONPATH=transformers/src:$PYTHONPATH # 3. 执行推理脚本 - python3 infer.py --model_path=${CosyVoice2-0.5B} --out_stream + python3 infer.py --model_path=${CosyVoice2-0.5B} --stream_out ``` - --model_path: 权重路径 - --warm_up_times:warm up次数,默认为2 diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch index caf3594b29..f5f1b9016c 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/diff_CosyVoice.patch @@ -99,7 +99,7 @@ index 6e10f00..25ad767 100644 speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device) return speech_token, speech_token_len diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py -index 9ebf8cb..064439c 100644 +index 9ebf8cb..dc8b4d2 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -314,6 +314,14 @@ class CosyVoice2Model(CosyVoiceModel): @@ -169,7 +169,7 @@ index 9ebf8cb..064439c 100644 + llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), + flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), + prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs): -+ this_uuid = kwargs["user_id"] ++ this_uuid = kwargs.get("user_id", "AscendDefaultUser") + if this_uuid not in self.tts_speech_token_dict: + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False + self.hift_cache_dict[this_uuid] = None diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py index 012ade6500..972cb18522 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/infer.py @@ -33,34 +33,37 @@ def no_stream_input_inference(args, cosyvoice, prompt_txt): def stream_input_inference(args, cosyvoice, prompt_txt): + + def inference_step(step, mode): + times = args.warm_up_times if mode == "warmup" else args.infer_count + print(f"第{step + 1}/{times}轮 {mode}:↓↓↓") + print(f"curr prompt text:{prompt_txt[step % len(prompt_txt)]}") + for char_idx, char in enumerate(prompt_txt[step % len(prompt_txt)]): + if char_idx == len(prompt_txt[step % len(prompt_txt)]) - 1: + for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): + if mode == "warmup": + pass + else: + infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + else: + for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): + if mode == "warmup": + pass + else: + infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + infer_res = [torch.tensor([]) for _ in range(args.infer_count)] with torch.no_grad(): print("warm up start") for w_step in range(args.warm_up_times): - print(f"第{w_step+1}/{args.warm_up_times}轮预热:↓↓↓") - print(f"curr prompt text:{prompt_txt[w_step % len(prompt_txt)]}") - for char_idx, char in enumerate(prompt_txt[w_step % len(prompt_txt)]): - if char_idx == len(prompt_txt[w_step % len(prompt_txt)]) - 1: - for _ in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): - pass - else: - for _ in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): - pass + inference_step(w_step, mode="warmup") print("warm up end") print("inference start") for i_step in range(args.infer_count): - print(f"第{i_step}/{args.infer_count}轮推理:↓↓↓") - print(f"curr prompt text:{prompt_txt[w_step % len(prompt_txt)]}") - for char_idx, char in enumerate(prompt_txt[i_step % len(prompt_txt)]): - if char_idx == len(prompt_txt[i_step % len(prompt_txt)]) - 1: - for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=True, stream=args.stream_out)): - infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) - - else: - for _, j in enumerate(cosyvoice.inference_sft_streaming_input(char, char_idx, "中文女", user_id="AscendUser", input_end=False, stream=args.stream_out)): - infer_res[i_step] = torch.cat((infer_res[i_step], j['tts_speech']), dim=1) + inference_step(i_step, mode="inference") + print("inference end") print(f"save out wav file ...") for i_step in tqdm(range(args.infer_count)): diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py index 75cc4dab84..fd33b1f674 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/modeling_qwen2.py @@ -237,6 +237,8 @@ class Qwen2SdpaAttention(Qwen2Attention): elif q_len < actual_seq_len[0]: key_states = past_key_value.key_cache[self.layer_idx][:, :actual_seq_len[0]] value_states = past_key_value.value_cache[self.layer_idx][:, :actual_seq_len[0]] + else: + raise ValueError(f"Unexpected q_len: {q_len}, actual_seq_len[0]: {actual_seq_len[0]}") if q_len > 1: # prefill阶段利用PFA自定义算子执行计算,因为bs为1,mask固定为下三角全为0上三角全为负无穷的倒三角mask矩阵 @@ -432,8 +434,8 @@ class Qwen2Model(Qwen2PreTrainedModel): # torchair的cache编译,保证模型编译cache文件,避免重复推理 self.cached_decode = tng.inference.cache_compile(self.decode, config=config) - self.cached_prefill_1 = tng.inference.cache_compile(self.prefill_1, config=config) # 用于首次prefill,无kv_cache - self.cached_prefill_2 = tng.inference.cache_compile(self.prefill_2, config=config) # 用于后续prefill,有kv_cache + self.cached_first_prefill = tng.inference.cache_compile(self.first_prefill, config=config) # 用于首次prefill,无kv_cache + self.cached_next_prefill = tng.inference.cache_compile(self.next_prefill, config=config) # 用于后续prefill,有kv_cache def get_input_embeddings(self): return self.embed_tokens @@ -471,7 +473,7 @@ class Qwen2Model(Qwen2PreTrainedModel): # 流式输入场景,首次prefill if seq_len > 1 and seq_len == actual_seq_len[0]: - return self.cached_prefill_1( + return self.cached_first_prefill( input_ids, attention_mask, position_ids, @@ -489,7 +491,7 @@ class Qwen2Model(Qwen2PreTrainedModel): # 流式输入场景,后续prefill elif 1 < seq_len < actual_seq_len[0]: - return self.cached_prefill_2( + return self.cached_next_prefill( input_ids, attention_mask, position_ids, @@ -555,7 +557,7 @@ class Qwen2Model(Qwen2PreTrainedModel): lm_head ) - def prefill_1( + def first_prefill( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -587,7 +589,7 @@ class Qwen2Model(Qwen2PreTrainedModel): lm_head ) - def prefill_2( + def next_prefill( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, diff --git a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt index b2794ae2b9..c3cdf647ef 100755 --- a/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt +++ b/ACL_PyTorch/built-in/audio/CosyVoice/CosyVoice2/requirements.txt @@ -30,4 +30,4 @@ uvicorn==0.30.0 wget==3.2 fastapi==0.111.0 fastapi-cli==0.0.4 -WeTextProcessing==1.0.3 \ No newline at end of file +WeTextProcessing==1.0.4.1 \ No newline at end of file -- Gitee