From c903bf1ac6ee2163b84cc94335b58e6a2c2b66cd Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Fri, 18 Apr 2025 16:34:46 +0800 Subject: [PATCH 1/3] add_comfyui --- .../Stable-Audio-Open-1.0/.gitattributes | 35 + .../Stable-Audio-Open-1.0/README.md | 165 ++++ .../inference_stableaudio.py | 161 ++++ .../Stable-Audio-Open-1.0/prompts/prompts.txt | 6 + .../Stable-Audio-Open-1.0/requirements.txt | 6 + .../stableaudio/__init__.py | 4 + .../stableaudio/layers/__init__.py | 4 + .../stableaudio/layers/attention.py | 373 ++++++++ .../stableaudio/layers/linear.py | 99 ++ .../stableaudio/models/__init__.py | 1 + .../models/stable_audio_transformer.py | 431 +++++++++ .../stableaudio/pipeline/__init__.py | 1 + .../pipeline/pipeline_stable_audio.py | 754 +++++++++++++++ .../stableaudio/schedulers/__init__.py | 1 + .../scheduling_cosine_dpmsolver_multistep.py | 574 +++++++++++ .../schedulers/scheduling_dpmsolver_sde.py | 71 ++ .../schedulers/scheduling_utils.py | 893 ++++++++++++++++++ 17 files changed, 3579 insertions(+) create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/README.md create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes b/MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes new file mode 100644 index 0000000000..a6344aac8c --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot 
filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md b/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md new file mode 100644 index 0000000000..2c536a6a3e --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md @@ -0,0 +1,165 @@ +--- +pipeline_tag: text-to-audio +frameworks: + - PyTorch +license: apache-2.0 +library_name: openmind +hardwares: + - NPU +language: + - en +--- +## 一、准备运行环境 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +### 1.1 获取CANN&MindIE安装包&环境准备 +- 设备支持 +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 +- [Atlas 800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32) +- [Atlas 300I Duo](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17) +- [环境准备指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html) + +### 1.2 CANN安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构,{soc}表示昇腾AI处理器的版本。 +chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run +chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run +# 校验软件包安装文件的一致性和完整性 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --check +./Ascend-cann-kernels-{soc}_{version}_linux.run --check +# 安装 +./Ascend-cann-toolkit_{version}_linux-{arch}.run --install +./Ascend-cann-kernels-{soc}_{version}_linux.run --install + +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh +``` + + +### 1.3 MindIE安装 +```shell +# 增加软件包可执行权限,{version}表示软件版本号,{arch}表示CPU架构。 +chmod +x ./Ascend-mindie_${version}_linux-${arch}.run +./Ascend-mindie_${version}_linux-${arch}.run --check + +# 方式一:默认路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install +# 设置环境变量 +cd /usr/local/Ascend/mindie && source set_env.sh + +# 方式二:指定路径安装 +./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath} +# 设置环境变量 +cd ${AieInstallPath}/mindie && source set_env.sh +``` + +### 1.4 Torch_npu安装 +安装pytorch框架 版本2.1.0 +[安装包下载](https://download.pytorch.org/whl/cpu/torch/) + +使用pip安装 +```shell +# {version}表示软件版本号,{arch}表示CPU架构。 +pip install torch-${version}-cp310-cp310-linux_${arch}.whl +``` +下载 pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +```shell +tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz +# 解压后,会有whl包 +pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl +``` + +### 1.5 安装mindspeed +```shell +# 下载mindspeed源码仓: +git clone https://gitee.com/ascend/MindSpeed.git +# 执行如下命令进行安装: +pip install -e MindSpeed +``` + +## 二、下载本仓库 + +### 2.1 下载到本地 +```shell +git clone https://modelers.cn/MindIE/stable_audio_open_1.0.git +``` + +### 2.2 依赖安装 
+```bash +pip3 install -r requirements.txt +apt-get update +apt-get install libsndfile1 +``` + +## 三、Stable-Audio-Open-1.0 使用 + +### 3.1 权重及配置文件说明 +stable-audio-open-1.0权重链接: +```shell +https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main +``` + +### 3.2 单卡功能测试 +设置权重路径 +```shell +model_base='./stable-audio-open-1.0' +``` +执行命令: +```shell +# 不使用DiTCache策略 +python3 inference_stableaudio.py \ + --model ${model_base} \ + --prompt_file ./prompts/prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --save_dir ./results \ + --seed 1 \ + --device 0 + +# 使用DiTCache策略 +python3 inference_stableaudio.py \ + --model ${model_base} \ + --prompt_file ./prompts/prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --save_dir ./results \ + --seed 1 \ + --device 0 \ + --use_cache +``` +参数说明: +- --model:模型权重路径。 +- --prompt_file:提示词文件。 +- --num_inference_steps: 语音生成迭代次数。 +- --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 +- --save_dir:生成语音的存放目录。 +- --seed:设置随机种子,不指定时默认使用随机种子。 +- --device:推理设备ID。 +- --use_cache: 【可选】使用DiTCache策略。 + +执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + +### 3.2 模型推理性能 + +性能参考下列数据。 + +| 硬件形态 | 迭代次数 | 平均耗时(w/o DiTCache)| 平均耗时(with DiTCache)| +| :------: |:----:|:----:|:----:| +| Atlas 800I A2(8*32G) | 100 | 5.645s | 4.991s | + +## 五、优化指南 +本模型使用的优化手段如下: +- 等价优化:FA、RoPE、Linear +- 算法优化:DiTcache + + +## 声明 +- 本代码仓提到的数据集和模型仅作为示例,这些数据集和模型仅供您用于非商业目的,如您使用这些数据集和模型来完成示例,请您特别注意应遵守对应数据集和模型的License,如您因使用数据集或模型而产生侵权纠纷,华为不承担任何责任。 +- 如您在使用本代码仓的过程中,发现任何问题(包括但不限于功能问题、合规问题),请在本代码仓提交issue,我们将及时审视并解答。 \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py new file mode 100644 index 0000000000..d7b49fb36b --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py @@ -0,0 +1,161 @@ +import time +import json +import os +import argparse + +import torch +import torch_npu +import soundfile as sf +from safetensors.torch import load_file + +from stableaudio import ( + StableAudioPipeline, + AutoencoderOobleck, +) +from mindiesd import CacheConfig, CacheAgent + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts/prompts.txt", + help="The prompts file to guide audio generation.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt or prompts to guide what to not include in audio generation.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="The number of denoising steps. 
More denoising steps usually lead to a higher quality audio at the expense of slower inference.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="./stable-audio-open-1.0",
+        help="The path of stable-audio-open-1.0.",
+    )
+    parser.add_argument(
+        "--audio_end_in_s",
+        nargs='+',
+        default=[10],
+        help="Audio end time in seconds, one value per prompt (default 10s).",
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0,
+        help="NPU device id.",
+    )
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="./results",
+        help="Path to save result audio files.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=-1,
+        help="Random seed; -1 (default) means no manual seed is set.",
+    )
+    parser.add_argument(
+        "--use_cache",
+        action="store_true",
+        help="Enable the DiT cache strategy.",
+    )
+    parser = add_attentioncache_args(parser)
+    return parser.parse_args()
+
+
+def add_attentioncache_args(parser: argparse.ArgumentParser):
+    group = parser.add_argument_group(title="Attention Cache args")
+    group.add_argument("--use_attentioncache", action='store_true')
+    group.add_argument("--attentioncache_ratio", type=float, default=1.4)
+    group.add_argument("--attentioncache_interval", type=int, default=5)
+    group.add_argument("--start_step", type=int, default=60)
+    group.add_argument("--end_step", type=int, default=97)
+
+    return parser
+
+
+def main():
+    args = parse_arguments()
+    save_dir = args.save_dir
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    npu_stream = torch_npu.npu.Stream()
+    torch_npu.npu.set_device(args.device)
+    if args.seed != -1:
+        torch.manual_seed(args.seed)
+    latents = torch.randn(1, 64, 1024, dtype=torch.float16, device="cpu").to("npu")
+
+    with open(os.path.join(args.model, "vae", "config.json")) as f:
+        vae_config = json.load(f)
+    vae = AutoencoderOobleck.from_config(vae_config)
+    vae.load_state_dict(load_file(os.path.join(args.model, "vae", "diffusion_pytorch_model.safetensors")))
+
+    pipe = StableAudioPipeline.from_pretrained(args.model, vae=vae)
+    pipe.to(torch.float16).to("npu")
+
+    # attach one shared cache agent to every DiT block of the pipeline's transformer
+    transformer = pipe.transformer
+    if args.use_attentioncache:
+        config = CacheConfig(
+            method="attention_cache",
+            blocks_count=len(transformer.transformer_blocks),
+            steps_count=args.num_inference_steps,
+            step_start=args.start_step,
+            step_interval=args.attentioncache_interval,
+            step_end=args.end_step
+        )
+    else:
+        config = CacheConfig(
+            method="attention_cache",
+            blocks_count=len(transformer.transformer_blocks),
+            steps_count=args.num_inference_steps
+        )
+    cache = CacheAgent(config)
+    for block in transformer.transformer_blocks:
+        block.cache = cache
+
+    total_time = 0
+    prompts_num = 0
+    average_time = 0
+    skip = 2
+    with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f:
+        for i, prompt in enumerate(f):
+            with torch.no_grad():
+                audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0
+                npu_stream.synchronize()
+                begin = time.time()
+                audio = pipe(
+                    prompt=prompt,
+                    negative_prompt=args.negative_prompt,
+                    num_inference_steps=args.num_inference_steps,
+                    latents=latents,
+                    audio_end_in_s=audio_end_in_s,
+                    use_cache=args.use_cache,
+                ).audios
+                npu_stream.synchronize()
+                end = time.time()
+                if i > skip - 1:
+                    total_time += end - begin
+                prompts_num = i + 1
+                output = audio[0].T.float().cpu().numpy()
+                file_path = os.path.join(args.save_dir, f"audio_by_prompt{prompts_num}.wav")
+                sf.write(file_path, output, pipe.vae.sampling_rate)
+    if prompts_num > skip:
+        average_time = total_time / (prompts_num - skip)
+    else:
+        raise ValueError("Infer average time skips the first two 
prompts, ensure that prompts.txt \ + contains more than three prompts") + print(f"Infer average time: {average_time:.3f}s\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt b/MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt new file mode 100644 index 0000000000..523a9bcbcf --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt @@ -0,0 +1,6 @@ +Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP. +Uplifting acoustic loop. 120 BPM. +Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM. +Warm arpeggios on an analog synthesizer with a gradually rising filter cutoff and a reverb tail. +Blackbird song, summer, dusk in the forest. +Rock beat played in a treated studio, session drumming on an acoustic kit. \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt b/MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt new file mode 100644 index 0000000000..6353806d3f --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt @@ -0,0 +1,6 @@ +torch==2.1.0 +torchsde==0.2.6 +diffusers==0.30.0 +transformers==4.40.0 +soundfile==0.12.1 +torch_npu==2.1.0.post6 \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py new file mode 100644 index 0000000000..f484d3fceb --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py @@ -0,0 +1,4 @@ +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from .pipeline import StableAudioPipeline, StableAudioProjectionModel +from .models import StableAudioDiTModel, ModelMixin +from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler, SchedulerMixin \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py new file mode 100644 index 0000000000..f2f044203d --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py @@ -0,0 +1,4 @@ +from .attention import ( + Attention, + StableAudioAttnProcessor2_0, +) \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py new file mode 100644 index 0000000000..aec5aa7ca0 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py @@ -0,0 +1,373 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
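+"""NPU-adapted attention layer for Stable Audio Open.
+
+Self- and cross-attention share a fused ``QKVLinear`` projection, the score
+computation is dispatched to the MindIE SD fused kernel via ``attention_forward``,
+and rotary embeddings are applied with ``rotary_position_embedding``.
+
+A minimal shape-level sketch (the sizes, the ``npu`` device and the fp16 cast
+below are illustrative assumptions, not values fixed by this file; weights are
+left uninitialized, so only the wiring and output shapes are meaningful):
+
+    import torch
+    import torch_npu  # registers the "npu" device
+    from stableaudio.layers import Attention, StableAudioAttnProcessor2_0
+
+    attn = Attention(query_dim=1536, heads=24, dim_head=64,
+                     processor=StableAudioAttnProcessor2_0()).half().to("npu")
+    x = torch.randn(1, 1024, 1536, dtype=torch.float16, device="npu")
+    out = attn(x)  # self-attention path, output shape (1, 1024, 1536)
+"""
+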
+import inspect +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch_npu +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import logging +from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 +from mindiesd import attention_forward, rotary_position_embedding +from .linear import QKVLinear + +logger = logging.get_logger(__name__) + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. 
If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. 
Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_qkv = QKVLinear( + attention_dim=query_dim, + hidden_size=self.inner_dim, + qkv_bias=self.use_bias, + cross_attention_dim=cross_attention_dim, + cross_hidden_size=self.inner_kv_dim, + ) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + query, key, value = attn.to_qkv(hidden_states, encoder_hidden_states) + + head_dim = attn.dim_head + kv_heads = attn.inner_kv_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
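+            # e.g. the Stable Audio DiT cross-attention uses heads=24 and kv_heads=12,
+            # so each key/value head is tiled heads_per_kv_head = 24 // 12 = 2 times,
+            # giving (batch, seq_len, 24, head_dim) tensors that match the query layout.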
+ heads_per_kv_head = attn.heads // kv_heads + key = key.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1) + key = key.view(batch_size, sequence_length, attn.heads, head_dim) + value = value.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1) + value = value.view(batch_size, sequence_length, attn.heads, head_dim) + + # Apply RoPE if needed + if rotary_emb is not None: + cos, sin = rotary_emb + query_to_rotate, query_unrotated = torch.chunk(query, 2, 3) + query_rotated = rotary_position_embedding(query_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True) + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = torch.chunk(key, 2, 3) + key_rotated = rotary_position_embedding(key_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True) + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + hidden_states = attention_forward( + query, key, value, + opt_mode="manual", + op_type="fused_attn_score", + layout="BSND" + ) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py new file mode 100644 index 0000000000..3e8ee84fc8 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
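+"""Fused QKV projection used by the Stable Audio attention layers.
+
+For self-attention the query/key/value projections share a single
+``[attention_dim, 3 * hidden_size]`` weight; for cross-attention the query keeps
+its own weight while key/value share a ``[cross_attention_dim, 2 * cross_hidden_size]``
+weight, so one matmul (or ``addmm``) replaces three separate linears.
+
+Shape-level sketch (sizes are illustrative assumptions; parameters are created
+with ``torch.empty`` and are meant to be overwritten by checkpoint weights):
+
+    import torch
+    from stableaudio.layers.linear import QKVLinear
+
+    # self-attention: one fused weight, q/k/v all of hidden_size
+    proj = QKVLinear(attention_dim=1536, hidden_size=1536, qkv_bias=False)
+    q, k, v = proj(torch.randn(2, 8, 1536))        # each (2, 8, 1536)
+
+    # cross-attention: separate q weight, fused k/v weight applied to the context
+    xproj = QKVLinear(attention_dim=1536, hidden_size=1536, qkv_bias=False,
+                      cross_attention_dim=768, cross_hidden_size=768)
+    q, k, v = xproj(torch.randn(2, 8, 1536), torch.randn(2, 77, 768))
+    # q: (2, 8, 1536), k and v: (2, 77, 768)
+"""
+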
+ + +import torch +import torch.nn as nn +import torch_npu + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None, + device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + + self.cross_attention_dim = cross_attention_dim + self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + if cross_attention_dim is None: + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + else: + self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs)) + self.kv_weight = nn.Parameter( + torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs)) + + if self.qkv_bias: + self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs)) + self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs)) + + def forward(self, hidden_states, encoder_hidden_states=None): + + if self.cross_attention_dim is None: + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + + batch, seqlen, _ = hidden_states.shape + qkv_shape = (batch, seqlen, 3, -1) + qkv = qkv.view(qkv_shape) + q, k, v = qkv.unbind(2) + + else: + if not self.qkv_bias: + q = torch.matmul(hidden_states, self.q_weight) + kv = torch.matmul(encoder_hidden_states, self.kv_weight) + else: + q = torch.addmm( + self.q_bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.q_weight, + beta=1, + alpha=1 + ) + kv = torch.addmm( + self.kv_bias, + encoder_hidden_states.view( + encoder_hidden_states.size(0) * encoder_hidden_states.size(1), + encoder_hidden_states.size(2)), + self.kv_weight, + beta=1, + alpha=1 + ) + + batch, q_seqlen, _ = hidden_states.shape + q = q.view(batch, q_seqlen, -1) + + batch, kv_seqlen, _ = encoder_hidden_states.shape + kv_shape = (batch, kv_seqlen, 2, -1) + + kv = kv.view(kv_shape) + k, v = kv.unbind(2) + + return q, k, v diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py new file mode 100644 index 0000000000..c4170f5af4 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py @@ -0,0 +1 @@ +from .stable_audio_transformer import StableAudioDiTModel, ModelMixin \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py new file mode 100644 index 0000000000..aeb06b2ade --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py @@ -0,0 +1,431 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Union +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging + +from ..layers.attention import ( + Attention, + StableAudioAttnProcessor2_0, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.cache = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.cache.apply( + self.attn1, + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. 
+ global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. + cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + + self.cache_block_start = 11 + self.cache_step_interval = 2 + self.cache_num_blocks = 9 + self.cache_step_start = 5 + + self.num_layers = num_layers + + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.init_dtype = self.dtype + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + step_id, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 
+ global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. + rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.init_dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + if not use_cache or (use_cache and step_id < self.cache_step_start): + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=0, + end_id=self.num_layers, + ) + else: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=0, + end_id=self.cache_block_start, + ) + + cache_end = np.minimum(self.cache_block_start + self.cache_num_blocks, self.num_layers) + hidden_states_pre_cache = hidden_states.clone() + if (step_id - self.cache_step_start) % self.cache_step_interval == 0: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=self.cache_block_start, + end_id=cache_end, + ) + self.delta_cache = hidden_states - hidden_states_pre_cache + else: + hidden_states = hidden_states_pre_cache + self.delta_cache + + if cache_end < self.num_layers: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + 
rotary_embedding=rotary_embedding, + start_id=cache_end, + end_id=self.num_layers, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) + + def _transformer_blocks_forward(self, hidden_states, encoder_hidden_states, rotary_embedding, start_id, end_id): + for block in self.transformer_blocks[start_id: end_id]: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_embedding=rotary_embedding, + ) + return hidden_states + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) ->None: + for i in range(self.num_layers): + self_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + self_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + self_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + self_qkv_weight = torch.cat([self_q_weight, self_k_weight, self_v_weight], dim=0).transpose(0, 1).contiguous() + state_dict[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = self_qkv_weight + + cross_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_q.weight", None) + cross_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_k.weight", None) + cross_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_v.weight", None) + cross_q_weight = cross_q_weight.transpose(0, 1).contiguous() + cross_kv_weight = torch.cat([cross_k_weight, cross_v_weight], dim=0).transpose(0, 1).contiguous() + state_dict[f"transformer_blocks.{i}.attn2.to_qkv.q_weight"] = cross_q_weight + state_dict[f"transformer_blocks.{i}.attn2.to_qkv.kv_weight"] = cross_kv_weight \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py new file mode 100644 index 0000000000..a283da8ecd --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py @@ -0,0 +1 @@ +from .pipeline_stable_audio import StableAudioPipeline, StableAudioProjectionModel \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py new file mode 100644 index 0000000000..3abd5ecc9d --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py @@ -0,0 +1,754 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
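+"""Stable Audio Open text-to-audio pipeline adapted for Ascend NPU.
+
+Compared with the upstream diffusers pipeline, the denoiser is the NPU-optimised
+``StableAudioDiTModel`` and the sampling loop can drive its DiT cache. Minimal
+sketch (the local weight directory, fp16 cast and ``use_cache`` call pattern
+mirror inference_stableaudio.py and are assumptions of this sketch, not
+requirements of the class; that script also builds the VAE from its config
+explicitly before calling ``from_pretrained``):
+
+    import torch
+    import torch_npu  # registers the "npu" device
+    from stableaudio import StableAudioPipeline
+
+    pipe = StableAudioPipeline.from_pretrained("./stable-audio-open-1.0")
+    pipe = pipe.to(torch.float16).to("npu")
+    audio = pipe(
+        "warm analog synth arpeggio",
+        num_inference_steps=100,
+        audio_end_in_s=10.0,
+        use_cache=True,   # DiT-cache switch forwarded to the transformer
+    ).audios
+"""
+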
+ +import inspect +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from diffusers.utils import ( + logging, + replace_example_docstring, +) + +from ..models import StableAudioDiTModel +from ..schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "stabilityai/stable-audio-open-1.0" + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. + scheduler ([`CosineDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. 
+ """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + transformer: StableAudioDiTModel, + scheduler: CosineDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: List[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. 
Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: + raise ValueError( + "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." + ) + + if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: + raise ValueError( + f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." + "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " + ) + + def prepare_latents( + self, + batch_size, + num_channels_vae, + sample_size, + dtype, + device, + generator, + latents=None, + initial_audio_waveforms=None, + num_waveforms_per_prompt=None, + audio_channels=None, + ): + shape = (batch_size, num_channels_vae, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_end_in_s: Optional[float] = None, + audio_start_in_s: Optional[float] = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + initial_audio_waveforms: Optional[torch.Tensor] = None, + initial_audio_sampling_rate: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + output_type: Optional[str] = "pt", + use_cache: Optional[bool] = False, + ): + r""" + The call function to the pipeline for generation. 
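+
+        Internally the call mirrors the numbered comments in the body: convert the requested audio length to
+        a latent length, validate inputs, encode the prompt and duration conditioning, prepare latents and
+        rotary embeddings, run the denoising loop with the scheduler, and finally decode the latents with the
+        VAE into waveforms.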
+ + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. 
The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16) + sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16) + rotary_embedding = (cos, sin) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + i, + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + use_cache=use_cache, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py new file mode 100644 index 0000000000..5bad3d9b15 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py @@ -0,0 +1 @@ +from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py new file mode 100644 index 0000000000..d4e7805d13 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -0,0 +1,574 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
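+
+# Reading aid (restating the preconditioning implemented below; `sigma_data` comes from the config):
+#   t      = (2 / pi) * atan(sigma)                                   (precondition_noise)
+#   c_in   = 1 / sqrt(sigma**2 + sigma_data**2)                       (precondition_inputs)
+#   c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)               (precondition_outputs)
+#   c_out  = +/- sigma * sigma_data / sqrt(sigma**2 + sigma_data**2)  (sign depends on `prediction_type`)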
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler + + +class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021). + This scheduler was used in Stable Audio Open [1]. + + [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.3): + Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. + sigma_max (`float`, *optional*, defaults to 500): + Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. + sigma_data (`float`, *optional*, defaults to 1.0): + The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. + sigma_schedule (`str`, *optional*, defaults to `exponential`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. + prediction_type (`str`, defaults to `v_prediction`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
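+
+    A minimal usage sketch (the `denoiser` call stands in for the transformer used by the pipeline; shapes
+    and the step count are illustrative):
+
+    ```py
+    >>> import torch
+    >>> scheduler = CosineDPMSolverMultistepScheduler()
+    >>> scheduler.set_timesteps(num_inference_steps=50)
+    >>> sample = torch.randn(1, 64, 1024) * scheduler.init_noise_sigma
+    >>> for t in scheduler.timesteps:
+    ...     model_input = scheduler.scale_model_input(sample, t)
+    ...     noise_pred = denoiser(model_input, t)  # hypothetical denoising model
+    ...     sample = scheduler.step(noise_pred, t, sample).prev_sample
+    ```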
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.3, + sigma_max: float = 500, + sigma_data: float = 1.0, + sigma_schedule: str = "exponential", + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "v_prediction", + rho: float = 7.0, + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + return sigma.atan() / math.pi * 2 + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
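+
+        A sketch of the resulting schedule sizes:
+
+        ```py
+        >>> scheduler.set_timesteps(num_inference_steps=100)
+        >>> scheduler.timesteps.shape   # 100 preconditioned timesteps
+        torch.Size([100])
+        >>> scheduler.sigmas.shape      # the 100 sigmas plus the appended final sigma
+        torch.Size([101])
+        ```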
+ """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # if a noise sampler is used, reinitialise it + self.noise_sampler = None + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. 
DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if noise is None: + raise ValueError("noise must not be None") + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
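+
+        For the default `midpoint` solver type this evaluates (restating the body below, with
+        `h = lambda_t - lambda_s0`, `D0 = m0` and `D1 = (m0 - m1) / r0`):
+
+        ```py
+        x_t = (sigma_t / sigma_s0) * torch.exp(-h) * sample \
+            + alpha_t * (1 - torch.exp(-2.0 * h)) * (D0 + 0.5 * D1) \
+            + sigma_t * torch.sqrt(1.0 - torch.exp(-2.0 * h)) * noise
+        ```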
+ """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + # sde-dpmsolver++ + if noise is None: + raise ValueError("noise must not be None") + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
+ + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.noise_sampler is None: + seed = None + if generator is not None: + seed = ( + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( + model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device + ) + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * 
timesteps.shape[0]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py
new file mode 100644
index 0000000000..8533b082f5
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from .scheduling_utils import BrownianTree
+
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get("w0", torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2**63 - 1, []).item()
+        self.batched = True
+        try:
+            if len(seed) != x.shape[0]:
+                raise ValueError(f"len(seed) should equal x.shape[0] ({x.shape[0]}), but got {len(seed)}.")
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        self.trees = [BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
+            with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
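+
+    A minimal usage sketch (shapes and sigma values are illustrative):
+
+    ```py
+    >>> import torch
+    >>> x = torch.randn(2, 64, 1024)
+    >>> sampler = BrownianTreeNoiseSampler(x, sigma_min=0.3, sigma_max=500, seed=0)
+    >>> noise = sampler(500.0, 250.0)  # Brownian increment scaled to unit variance, same shape as x
+    ```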
+ """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py new file mode 100644 index 0000000000..e4734b052f --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py @@ -0,0 +1,893 @@ + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +import trampoline + +import numpy as np +import torch + +from torchsde._brownian import brownian_base +from torchsde.settings import LEVY_AREA_APPROXIMATIONS +from torchsde.types import Scalar, Optional, Tuple, Union, Tensor + +_rsqrt3 = 1 / math.sqrt(3) +_r12 = 1 / 12 + + +def _randn(size, dtype, device, seed): + generator = torch.Generator(device).manual_seed(int(seed)) + return torch.randn(size, dtype=dtype, device=device, generator=generator) + + +def _is_scalar(x): + return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1) + + +def _assert_floating_tensor(name, tensor): + if not torch.is_tensor(tensor): + raise ValueError(f"{name}={tensor} should be a Tensor.") + if not tensor.is_floating_point(): + raise ValueError(f"{name}={tensor} should be floating point.") + + +def _check_tensor_info(*tensors, size, dtype, device): + """Check if sizes, dtypes, and devices of input tensors all match prescribed values.""" + tensors = list(filter(torch.is_tensor, tensors)) + + if dtype is None and len(tensors) == 0: + dtype = torch.get_default_dtype() + if device is None and len(tensors) == 0: + device = torch.device("cpu") + + sizes = [] if size is None else [size] + sizes += [t.shape for t in tensors] + + dtypes = [] if dtype is None else [dtype] + dtypes += [t.dtype for t in tensors] + + devices = [] if device is None else [device] + devices += [t.device for t in tensors] + + if len(sizes) == 0: + raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.") + + if not all(i == sizes[0] for i in sizes): + raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.") + if not all(i == dtypes[0] for i in dtypes): + raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.") + if not all(i == devices[0] for i in devices): + raise ValueError("Multiple devices found. Make sure `device` and `W` or `H` are consistent.") + + # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes. 
+ return tuple(sizes[0]), dtypes[0], devices[0] + + +def _davie_foster_approximation(W, H, h, levy_area_approximation, get_noise): + if levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.none, LEVY_AREA_APPROXIMATIONS.space_time): + return None + elif W.ndimension() in (0, 1): + # If we have zero or one dimensions then treat the scalar / single dimension we have as batch, so that the + # Brownian motion is one dimensional and the Levy area is zero. + return torch.zeros_like(W) + else: + # Davie's approximation to the Levy area from space-time Levy area + A = H.unsqueeze(-1) * W.unsqueeze(-2) - W.unsqueeze(-1) * H.unsqueeze(-2) + noise = get_noise() + noise = noise - noise.transpose(-1, -2) # noise is skew symmetric of variance 2 + if levy_area_approximation == LEVY_AREA_APPROXIMATIONS.foster: + # Foster's additional correction to Davie's approximation + tenth_h = 0.1 * h + H_squared = H ** 2 + std = (tenth_h * (tenth_h + H_squared.unsqueeze(-1) + H_squared.unsqueeze(-2))).sqrt() + else: # davie approximation + std = math.sqrt(_r12 * h ** 2) + a_tilde = std * noise + A += a_tilde + return A + + +def _H_to_U(W: torch.Tensor, H: torch.Tensor, h: float) -> torch.Tensor: + return h * (.5 * W + H) + + +class _EmptyDict: + def __setitem__(self, key, value): + pass + + def __getitem__(self, item): + raise KeyError + + +class _LRUDict(dict): + def __init__(self, max_size): + super().__init__() + self._max_size = max_size + self._keys = [] + + def __setitem__(self, key, value): + if key in self: + self._keys.remove(key) + elif len(self) >= self._max_size: + del self[self._keys.pop(0)] + super().__setitem__(key, value) + self._keys.append(key) + + +class _Interval: + # Intervals correspond to some subinterval of the overall interval [t0, t1]. + # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left + # and right subintervals, which partition the parent interval. + + __slots__ = ( + # These are the things that every interval has + '_start', + '_end', + '_parent', + '_is_left', + '_top', + # These are the things that intervals which are parents also have + '_midway', + '_spawn_key', + '_depth', + '_W_seed', + '_H_seed', + '_left_a_seed', + '_right_a_seed', + '_left_child', + '_right_child') + + def __init__(self, start, end, parent, is_left, top): + self._start = top._round(start) # the left hand edge of the interval + self._end = top._round(end) # the right hand edge of the interval + self._parent = parent # our parent interval + self._is_left = is_left # are we the left or right child of our parent + self._top = top # the top-level BrownianInterval, where we cache certain state + self._midway = None # The point at which we split between left and right subintervals + + ######################################## + # Calculate increments and levy area # + ######################################## + # + # This is a little bit convoluted, so here's an explanation. + # + # The entry point is _increment_and_levy_area, below. This immediately calls _increment_and_space_time_levy_area, + # applies the space-time to full Levy area correction, and then returns. + # + # _increment_and_space_time_levy_area in turn calls a central LRU cache, as (later on) we'll need the increment and + # space-time Levy area of the parent interval to compute our own increment and space-time Levy area, and it's likely + # that our parent exists in the cache, as if we're being queried then our parent was probably queried recently as + # well. 
+ # (The top-level BrownianInterval overrides _increment_and_space_time_levy_area to return its own increment and + # space-time Levy area, effectively holding them permanently in the cache.) + # + # If the request isn't found in the LRU cache then it computes it from its parent. + # Now it turns out that the size of our increment and space-time Levy area is really most naturally thought of as a + # property of our parent: it depends on our parent's increment, space-time Levy area, and whether we are the left or + # right interval within our parent. So _increment_and_space_time_levy_area in turn checks if we are on the + # left or right of our parent and does most of the computation using the parent's attributes. + + def _increment_and_levy_area(self): + W, H = trampoline.trampoline(self._increment_and_space_time_levy_area()) + A = _davie_foster_approximation(W, H, self._end - self._start, self._top._levy_area_approximation, + self._randn_levy) + return W, H, A + + def _increment_and_space_time_levy_area(self): + try: + return self._top._increment_and_space_time_levy_area_cache[self] + except KeyError: + parent = self._parent + + W, H = yield parent._increment_and_space_time_levy_area() + h_reciprocal = 1 / (parent._end - parent._start) + left_diff = parent._midway - parent._start + right_diff = parent._end - parent._midway + + if self._top._have_H: + left_diff_squared = left_diff ** 2 + right_diff_squared = right_diff ** 2 + left_diff_cubed = left_diff * left_diff_squared + right_diff_cubed = right_diff * right_diff_squared + + v = 0.5 * math.sqrt(left_diff * right_diff / (left_diff_cubed + right_diff_cubed)) + + a = v * left_diff_squared * h_reciprocal + b = v * right_diff_squared * h_reciprocal + c = v * _rsqrt3 + + X1 = parent._randn(parent._W_seed) + X2 = parent._randn(parent._H_seed) + + third_coeff = 2 * (a * left_diff + b * right_diff) * h_reciprocal + + if self._is_left: + first_coeff = left_diff * h_reciprocal + second_coeff = 6 * first_coeff * right_diff * h_reciprocal + out_W = first_coeff * W + second_coeff * H + third_coeff * X1 + out_H = first_coeff ** 2 * H - a * X1 + c * right_diff * X2 + else: + first_coeff = right_diff * h_reciprocal + second_coeff = 6 * first_coeff * left_diff * h_reciprocal + out_W = first_coeff * W - second_coeff * H - third_coeff * X1 + out_H = first_coeff ** 2 * H - b * X1 - c * left_diff * X2 + else: + # Don't compute space-time Levy area unless we need to + + mean = left_diff * h_reciprocal * W + var = left_diff * right_diff * h_reciprocal + noise = parent._randn(parent._W_seed) + left_W = mean + math.sqrt(var) * noise + + if self._is_left: + out_W = left_W + else: + out_W = W - left_W + out_H = None + + self._top._increment_and_space_time_levy_area_cache[self] = (out_W, out_H) + return out_W, out_H + + def _randn(self, seed): + # We generate random noise deterministically wrt some seed; this seed is determined by the generator. + # This means that if we drop out of the cache, then we'll create the same random noise next time, as we still + # have the generator. 
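+        # (The seeds themselves come from the numpy `SeedSequence` spawned in `_split_exact`, keyed on the
+        # top-level entropy plus this interval's spawn_key/depth, which is what makes regeneration reproducible.)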
+ size = self._top._size + return _randn(size, self._top._dtype, self._top._device, seed) + + def _a_seed(self): + return self._parent._left_a_seed if self._is_left else self._parent._right_a_seed + + def _randn_levy(self): + size = (*self._top._size, *self._top._size[-1:]) + return _randn(size, self._top._dtype, self._top._device, self._a_seed()) + + ######################################## + # Locate an interval in the hierarchy # + ######################################## + # + # The other important piece of this construction is a way to locate any given interval within the binary tree + # hierarchy. (This is typically the slightly slower part, actually, so if you want to speed things up then this is + # the bit to target.) + # + # loc finds the interval [ta, tb] - and creates it in the appropriate place (as a child of some larger interval) if + # it doesn't already exist. As in principle we may request an interval that covers multiple existing intervals, then + # in fact the interval [ta, tb] is returned as an ordered list of existing subintervals. + # + # It calls _loc, which operates recursively. See _loc for more details on how the search works. + + def _loc(self, ta, tb): + out = [] + ta = self._top._round(ta) + tb = self._top._round(tb) + trampoline.trampoline(self._loc_inner(ta, tb, out)) + return out + + def _loc_inner(self, ta, tb, out): + # Expect to have ta < tb + + # First, we (this interval) only have jurisdiction over [self._start, self._end]. So if we're asked for + # something outside of that then we pass the buck up to our parent, who is strictly larger. + if ta < self._start or tb > self._end: + raise trampoline.TailCall(self._parent._loc_inner(ta, tb, out)) + + # If it's us that's being asked for, then we add ourselves on to out and return. + if ta == self._start and tb == self._end: + out.append(self) + return + + # If we've got this far then we know that it's an interval that's within our jurisdiction, and that it's not us. + # So next we check if it's up to us to figure out, or up to our children. + if self._midway is None: + # It's up to us. Create subintervals (_split) if appropriate. + if ta == self._start: + self._split(tb) + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + # implies ta > self._start + self._split(ta) + # Query our (newly created) right_child: if tb == self._end then our right child will be the result, and it + # will tell us so. But if tb < self._end then our right_child will need to make another split of its own. + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + + # If we're here then we have children: self._midway is not None + if tb <= self._midway: + # Strictly our left_child's problem + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + if ta >= self._midway: + # Strictly our right_child's problem + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + # It's a problem for both of our children: the requested interval overlaps our midpoint. Call the left_child + # first (to append to out in the correct order), then call our right child. 
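+        # For example (illustrative numbers): if self covers [0, 1] with self._midway == 0.5, then a
+        # query _loc_inner(0.2, 0.7, out) reaches this point and is answered as the left child's
+        # [0.2, 0.5] piece followed by the right child's [0.5, 0.7] piece.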
+ # (Implies ta < self._midway < tb) + yield self._left_child._loc_inner(ta, self._midway, out) + raise trampoline.TailCall(self._right_child._loc_inner(self._midway, tb, out)) + + def _set_spawn_key_and_depth(self): + self._spawn_key = 2 * self._parent._spawn_key + (0 if self._is_left else 1) + self._depth = self._parent._depth + 1 + + def _split(self, midway): + if self._top._halfway_tree: + self._split_exact(0.5 * (self._end + self._start)) + # self._midway is now the rounded halfway point. + if midway > self._midway: + self._right_child._split(midway) + elif midway < self._midway: + self._left_child._split(midway) + else: + self._split_exact(midway) + + def _split_exact(self, midway): # Create two children + self._midway = self._top._round(midway) + # Use splittable PRNGs to generate noise. + self._set_spawn_key_and_depth() + generator = np.random.SeedSequence(entropy=self._top._entropy, + spawn_key=(self._spawn_key, self._depth), + pool_size=self._top._pool_size) + self._W_seed, self._H_seed, self._left_a_seed, self._right_a_seed = generator.generate_state(4) + + self._left_child = _Interval(start=self._start, + end=midway, + parent=self, + is_left=True, + top=self._top) + self._right_child = _Interval(start=midway, + end=self._end, + parent=self, + is_left=False, + top=self._top) + + +class BrownianInterval(brownian_base.BaseBrownian, _Interval): + """Brownian interval with fixed entropy. + + Computes increments (and optionally Levy area). + + To use: + >>> bm = BrownianInterval(t0=0.0, t1=1.0, size=(4, 1), device='cuda') + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + __slots__ = ( + # Inputs + '_size', + '_dtype', + '_device', + '_entropy', + '_levy_area_approximation', + '_dt', + '_tol', + '_pool_size', + '_cache_size', + '_halfway_tree', + # Quantisation + '_round', + # Caching, searching and computing values + '_increment_and_space_time_levy_area_cache', + '_last_interval', + '_have_H', + '_have_A', + '_w_h', + '_top_a_seed', + # Dependency tree creation + '_average_dt', + '_tree_dt', + '_num_evaluations' + ) + + def __init__(self, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + entropy: Optional[int] = None, + dt: Optional[Scalar] = None, + tol: Scalar = 0., + pool_size: int = 8, + cache_size: Optional[int] = 45, + halfway_tree: bool = False, + levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none, + W: Optional[Tensor] = None, + H: Optional[Tensor] = None): + """Initialize the Brownian interval. + + Args: + t0 (float or Tensor): Initial time. + t1 (float or Tensor): Terminal time. + size (tuple of int): The shape of each Brownian sample. + If zero dimensional represents a scalar Brownian motion. + If one dimensional represents a batch of scalar Brownian motions. + If >two dimensional the last dimension represents the size of a + a multidimensional Brownian motion, and all previous dimensions + represent batch dimensions. + dtype (torch.dtype): The dtype of each Brownian sample. + Defaults to the PyTorch default. + device (str or torch.device): The device of each Brownian sample. + Defaults to the CPU. + entropy (int): Global seed, defaults to `None` for random entropy. + levy_area_approximation (str): Whether to also approximate Levy + area. Defaults to 'none'. 
Valid options are 'none', + 'space-time', 'davie' or 'foster', corresponding to different + approximation types. + This is needed for some higher-order SDE solvers. + dt (float or Tensor): The expected average step size of the SDE + solver. Set it if you know it (e.g. when using a fixed-step + solver); else it will be estimated from the first few queries. + This is used to set up the data structure such that it is + efficient to query at these intervals. + tol (float or Tensor): What tolerance to resolve the Brownian motion + to. Must be non-negative. Defaults to zero, i.e. floating point + resolution. Usually worth setting in conjunction with + `halfway_tree`, below. + pool_size (int): Size of the pooled entropy. If you care about + statistical randomness then increasing this will help (but will + slow things down). + cache_size (int): How big a cache of recent calculations to use. + (As new calculations depend on old calculations, this speeds + things up dramatically, rather than recomputing things.) + Set this to `None` to use an infinite cache, which will be fast + but memory inefficient. + halfway_tree (bool): Whether the dependency tree (the internal data + structure) should be the dyadic tree. Defaults to `False`. + Normally, the sample path is determined by both `entropy`, + _and_ the locations and order of the query points. Setting this + to `True` will make it deterministic with respect to just + `entropy`; however this is much slower. + W (Tensor): The increment of the Brownian motion over the interval + [t0, t1]. Will be generated randomly if not provided. + H (Tensor): The space-time Levy area of the Brownian motion over the + interval [t0, t1]. Will be generated randomly if not provided. + """ + + ##################################### + # Check and normalise inputs # + ##################################### + + if not _is_scalar(t0): + raise ValueError('Initial time t0 should be a float or 0-d torch.Tensor.') + if not _is_scalar(t1): + raise ValueError('Terminal time t1 should be a float or 0-d torch.Tensor.') + if dt is not None and not _is_scalar(dt): + raise ValueError('Expected average time step dt should be a float or 0-d torch.Tensor.') + + if t0 > t1: + raise ValueError(f'Initial time {t0} should be less than terminal time {t1}.') + t0 = float(t0) + t1 = float(t1) + if dt is not None: + dt = float(dt) + + if halfway_tree: + if tol <= 0.: + raise ValueError("`tol` should be positive.") + if dt is not None: + raise ValueError("`dt` is not used and should be set to `None` if `halfway_tree` is True.") + else: + if tol < 0.: + raise ValueError("`tol` should be non-negative.") + + size, dtype, device = _check_tensor_info(W, H, size=size, dtype=dtype, device=device) + + # Let numpy dictate randomness, so we have fewer seeds to set for reproducibility. 
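+        # (np.random.SeedSequence is what turns `entropy` into per-interval seeds below: e.g.
+        # np.random.SeedSequence(entropy=42, pool_size=8).generate_state(3) always returns the same
+        # three uint32 values, which is how W, H and the Levy-area noise stay reproducible.)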
+ if entropy is None: + entropy = np.random.randint(0, 2 ** 31 - 1) + + if levy_area_approximation not in LEVY_AREA_APPROXIMATIONS: + raise ValueError(f"`levy_area_approximation` must be one of {LEVY_AREA_APPROXIMATIONS}, but got " + f"'{levy_area_approximation}'.") + + ##################################### + # Record inputs # + ##################################### + + self._size = size + self._dtype = dtype + self._device = device + self._entropy = entropy + self._levy_area_approximation = levy_area_approximation + self._dt = dt + self._tol = tol + self._pool_size = pool_size + self._cache_size = cache_size + self._halfway_tree = halfway_tree + + ##################################### + # A miscellany of other things # + ##################################### + + # We keep a cache of recent queries, and their results. This is very important for speed, so that we don't + # recurse all the way up to the top every time we have a query. + if cache_size is None: + self._increment_and_space_time_levy_area_cache = {} + elif cache_size == 0: + self._increment_and_space_time_levy_area_cache = _EmptyDict() + else: + self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size) + + # We keep track of the most recently queried interval, and start searching for the next interval from that + # element of the binary tree. This is because subsequent queries are likely to be near the most recent query. + self._last_interval = self + + # Precompute these as we don't want to spend lots of time checking strings in hot loops. + self._have_H = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.space_time, + LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + self._have_A = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + + # If we like we can quantise what level we want to compute the Brownian motion to. + if math.isclose(tol, 0., rel_tol=1e-5): + def round_func(x): + return x + else: + ndigits = -int(math.log10(tol)) + + def round_func(x): + return round(x, ndigits) + + self._round = round_func + + # Initalise as _Interval. + # (Must come after _round but before _w_h) + super(BrownianInterval, self).__init__(start=t0, + end=t1, + parent=None, + is_left=None, + top=self) + + # Set the global increment and space-time Levy area + generator = np.random.SeedSequence(entropy=entropy, pool_size=pool_size) + initial_W_seed, initial_H_seed, top_a_seed = generator.generate_state(3) + if W is None: + W = self._randn(initial_W_seed) * math.sqrt(t1 - t0) + else: + _assert_floating_tensor('W', W) + if H is None: + H = self._randn(initial_H_seed) * math.sqrt((t1 - t0) / 12) + else: + _assert_floating_tensor('H', H) + self._w_h = (W, H) + self._top_a_seed = top_a_seed + + if not self._halfway_tree: + # We create a binary tree dependency between the points. If we don't do this then the forward pass is still + # efficient at O(N), but we end up with a dependency chain stretching along the interval [t0, t1], making + # the backward pass O(N^2). By setting up a dependency tree of depth relative to `dt` and `cache_size` we + # can instead make both directions O(N log N). + self._average_dt = 0 + self._tree_dt = t1 - t0 + self._num_evaluations = -100 # start off with a warmup period to get a decent estimate of the average + if dt is not None: + # Create the dependency tree based on the supplied hint `dt`. + self._create_dependency_tree(dt) + # If dt is None, then create the dependency tree based on observed statistics of query points. 
(In __call__) + + # Effectively permanently store our increment and space-time Levy area in the cache. + def _increment_and_space_time_levy_area(self): + return self._w_h + yield # make it a generator + + def _a_seed(self): + return self._top_a_seed + + def _set_spawn_key_and_depth(self): + self._spawn_key = 0 + self._depth = 0 + + def __call__(self, ta, tb=None, return_U=False, return_A=False): + if tb is None: + warnings.warn(f"{self.__class__.__name__} is optimised for interval-based queries, not point evaluation.") + ta, tb = self._start, ta + tb_name = 'ta' + else: + tb_name = 'tb' + ta = float(ta) + tb = float(tb) + if ta < self._start: + warnings.warn(f"Should have ta>=t0 but got ta={ta} and t0={self._start}.") + ta = self._start + if tb < self._start: + warnings.warn(f"Should have {tb_name}>=t0 but got {tb_name}={tb} and t0={self._start}.") + tb = self._start + if ta > self._end: + warnings.warn(f"Should have ta<=t1 but got ta={ta} and t1={self._end}.") + ta = self._end + if tb > self._end: + warnings.warn(f"Should have {tb_name}<=t1 but got {tb_name}={tb} and t1={self._end}.") + tb = self._end + if ta > tb: + raise RuntimeError(f"Query times ta={ta:.3f} and tb={tb:.3f} must respect ta <= tb.") + + if ta == tb: + W = torch.zeros(self._size, dtype=self._dtype, device=self._device) + H = None + A = None + if self._have_H: + H = torch.zeros(self._size, dtype=self._dtype, device=self._device) + if self._have_A: + size = (*self._size, *self._size[-1:]) # not self._size[-1] as that may not exist + A = torch.zeros(size, dtype=self._dtype, device=self._device) + else: + if self._dt is None and not self._halfway_tree: + self._num_evaluations += 1 + # We start off with "negative" num evaluations, to give us a small warm-up period at the start. + if self._num_evaluations > 0: + # Compute average step size so far + dt = tb - ta + self._average_dt = (dt + self._average_dt * (self._num_evaluations - 1)) / self._num_evaluations + if self._average_dt < 0.5 * self._tree_dt: + # If 'dt' wasn't specified, then check the average interval length against the size of the + # bottom of the dependency tree. If we're below halfway then refine the tree by splitting all + # the bottom pieces into two. + self._create_dependency_tree(dt) + + # Find the intervals that correspond to the query. We start our search at the last interval we accessed in + # the binary tree, as it's likely that the next query will come nearby. + intervals = self._last_interval._loc(ta, tb) + # Ideally we'd keep track of intervals[0] on the backward pass. Practically speaking len(intervals) tends to + # be 1 or 2 almost always so this isn't a huge deal. + self._last_interval = intervals[-1] + + W, H, A = intervals[0]._increment_and_levy_area() + if len(intervals) > 1: + # If we have multiple intervals then add up their increments and Levy areas. + + for interval in intervals[1:]: + Wi, Hi, Ai = interval._increment_and_levy_area() + if self._have_H: + # Aggregate H: + # Given s < u < t, then + # H_{s,t} = (term1 + term2) / (t - s) + # where + # term1 = (t - u) * (H_{u, t} + W_{s, u} / 2) + # term2 = (u - s) * (H_{s, u} - W_{u, t} / 2) + term1 = (interval._end - interval._start) * (Hi + 0.5 * W) + term2 = (interval._start - ta) * (H - 0.5 * Wi) + H = (term1 + term2) / (interval._end - ta) + if self._have_A and len(self._size) not in (0, 1): + # If len(self._size) in (0, 1) then we treat our scalar / single dimension as a batch + # dimension, so we have zero Levy area. 
(And these unsqueezes will result in a tensor of shape + # (batch, batch) which is wrong.) + + # Let B_{x, y} = \int_x^y W^1_{s,u} dW^2_u. + # Then + # B_{s, t} = \int_s^t W^1_{s,u} dW^2_u + # = \int_s^v W^1_{s,u} dW^2_u + \int_v^t W^1_{s,v} dW^2_u + \int_v^t W^1_{v,u} dW^2_u + # = B_{s, v} + W^1_{s, v} W^2_{v, t} + B_{v, t} + # + # A is now the antisymmetric part of B, which gives the formula below. + A = A + Ai + 0.5 * (W.unsqueeze(-1) * Wi.unsqueeze(-2) - Wi.unsqueeze(-1) * W.unsqueeze(-2)) + W = W + Wi + + U = None + if self._have_H: + U = _H_to_U(W, H, tb - ta) + + if return_U: + if return_A: + return W, U, A + else: + return W, U + else: + if return_A: + return W, A + else: + return W + + def _create_dependency_tree(self, dt): + # For safety we take a min with 100: if people take very large cache sizes then this would then break the + # logarithmic into linear, which causes RecursionErrors. + if self._cache_size is None: # cache_size=None corresponds to infinite cache. + cache_size = 100 + else: + cache_size = min(self._cache_size, 100) + + self._tree_dt = min(self._tree_dt, dt) + # Rationale: We are prepared to hold `cache_size` many things in memory, so when making steps of size `dt` + # then we can afford to have the intervals at the bottom of our binary tree be of size `dt * cache_size`. + # For safety we then make this a bit smaller by multiplying by 0.8. + piece_length = self._tree_dt * cache_size * 0.8 + + def _set_points(interval): + start = interval._start + end = interval._end + if end - start > piece_length: + midway = (end + start) / 2 + interval._loc(start, midway) + _set_points(interval._left_child) + _set_points(interval._right_child) + + _set_points(self) + + def __repr__(self): + if self._dt is None: + dt = None + else: + dt = f"{self._dt:.3f}" + return (f"{self.__class__.__name__}(" + f"t0={self._start:.3f}, " + f"t1={self._end:.3f}, " + f"size={self._size}, " + f"dtype={self._dtype}, " + f"device={repr(self._device)}, " + f"entropy={self._entropy}, " + f"dt={dt}, " + f"tol={self._tol}, " + f"pool_size={self._pool_size}, " + f"cache_size={self._cache_size}, " + f"levy_area_approximation={repr(self._levy_area_approximation)}" + f")") + + def display_binary_tree(self): + stack = [(self, 0)] + out = [] + while stack: + elem, depth = stack.pop() + out.append(" " * depth + f"({elem._start}, {elem._end})") + if elem._midway is not None: + stack.append((elem._right_child, depth + 1)) + stack.append((elem._left_child, depth + 1)) + print("\n".join(out)) + + @property + def shape(self): + return self._size + + @property + def dtype(self): + return self._dtype + + @property + def device(self): + return self._device + + @property + def entropy(self): + return self._entropy + + @property + def levy_area_approximation(self): + return self._levy_area_approximation + + @property + def dt(self): + return self._dt + + @property + def tol(self): + return self._tol + + @property + def pool_size(self): + return self._pool_size + + @property + def cache_size(self): + return self._cache_size + + @property + def halfway_tree(self): + return self._halfway_tree + + def size(self): + return self._size + + +class BrownianTree(brownian_base.BaseBrownian): + """Brownian tree with fixed entropy. + + Useful when the map from entropy -> Brownian motion shouldn't depend on the + locations and order of the query points. (As the usual BrownianInterval + does - note that BrownianTree is slower as a result though.) 
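+
+    (Internally this is just a BrownianInterval constructed with `halfway_tree=True`,
+    so for a fixed `entropy` querying, say, bm(0.3) before or after bm(0.7) gives the
+    same values; when both `w0` and `w1` are supplied, the pinned increment `w1 - w0`
+    is passed through as `W`.)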
+ + To use: + >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1)) + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + def __init__(self, t0: Scalar, + w0: Tensor, + t1: Optional[Scalar] = None, + w1: Optional[Tensor] = None, + entropy: Optional[int] = None, + tol: float = 1e-6, + pool_size: int = 24, + cache_depth: int = 9, + safety: Optional[float] = None): + """Initialize the Brownian tree. + + The random value generation process exploits the parallel random number paradigm and uses + `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`). + + Arguments: + t0: Initial time. + w0: Initial state. + t1: Terminal time. + w1: Terminal state. + entropy: Global seed, defaults to `None` for random entropy. + tol: Error tolerance before the binary search is terminated; the search depth ~ log2(tol). + pool_size: Size of the pooled entropy. This parameter affects the query speed significantly. + cache_depth: Unused; deprecated. + safety: Unused; deprecated. + """ + + if t1 is None: + t1 = t0 + 1 + if w1 is None: + W = None + else: + W = w1 - w0 + self._w0 = w0 + self._interval = BrownianInterval(t0=t0, + t1=t1, + size=w0.shape, + dtype=w0.dtype, + device=w0.device, + entropy=entropy, + tol=tol, + pool_size=pool_size, + halfway_tree=True, + W=W) + super(BrownianTree, self).__init__() + + def __call__(self, t, tb=None, return_U=False, return_A=False): + # Deliberately called t rather than ta, for backward compatibility + out = self._interval(t, tb, return_U=return_U, return_A=return_A) + if tb is None and not return_U and not return_A: + out = out + self._w0 + return out + + def __repr__(self): + return f"{self.__class__.__name__}(interval={self._interval})" + + @property + def dtype(self): + return self._interval.dtype + + @property + def device(self): + return self._interval.device + + @property + def shape(self): + return self._interval.shape + + @property + def levy_area_approximation(self): + return self._interval.levy_area_approximation + + +def brownian_interval_like(y: Tensor, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + **kwargs): + """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor.""" + size = y.shape if size is None else size + dtype = y.dtype if dtype is None else dtype + device = y.device if device is None else device + return brownian_interval.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs) \ No newline at end of file -- Gitee From df8b961d15f0613ded5c0794083a30f85bef742c Mon Sep 17 00:00:00 2001 From: zhoufan2956 Date: Fri, 18 Apr 2025 16:49:59 +0800 Subject: [PATCH 2/3] add_comfyui --- .../Stable-Audio-Open-1.0/README.md | 11 - .../StableAudio_Comfyui_Plugin/__init__.py | 3 + .../StableAudio_Comfyui_Plugin/node.py | 133 +++ .../requirements.txt | 6 + .../stableaudio/__init__.py | 4 + .../stableaudio/layers/__init__.py | 4 + .../stableaudio/layers/attention.py | 373 ++++++++ .../stableaudio/layers/linear.py | 99 ++ .../stableaudio/models/__init__.py | 1 + .../models/stable_audio_transformer.py | 431 +++++++++ .../stableaudio/pipeline/__init__.py | 1 + .../pipeline/pipeline_stable_audio.py | 754 +++++++++++++++ .../stableaudio/schedulers/__init__.py | 1 + .../scheduling_cosine_dpmsolver_multistep.py | 574 +++++++++++ .../schedulers/scheduling_dpmsolver_sde.py | 71 ++ 
 .../schedulers/scheduling_utils.py            | 893 ++++++++++++++++++
 16 files changed, 3348 insertions(+), 11 deletions(-)
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py
 create mode 100644 MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py

diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md b/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md
index 2c536a6a3e..2c094cfee2 100644
--- a/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md
@@ -1,14 +1,3 @@
----
-pipeline_tag: text-to-audio
-frameworks:
-  - PyTorch
-license: apache-2.0
-library_name: openmind
-hardwares:
-  - NPU
-language:
-  - en
----
 ## 一、准备运行环境
 
 **表 1** 版本配套表
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py
new file mode 100644
index 0000000000..98c2e997a2
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py
@@ -0,0 +1,3 @@
+from .node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
new file mode 100644
index 0000000000..42d5735145
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
@@ -0,0 +1,133 @@
+# Put this in the custom_nodes folder, and put your mindiesd files in ComfyUI/models/mindiesd/ (you will have to create the directory).
+
+import torch
+import torch_npu
+import os
+import json
+from safetensors.torch import load_file
+
+import comfy.model_base
+import comfy.model_management
+import comfy.model_patcher
+import comfy.supported_models
+import 
comfy.supported_models_base +import comfy.sd +from comfy import utils +import folder_paths +from folder_paths import get_folder_paths, get_filename_list_ + +from .stableaudio import ( + StableAudioPipeline, + StableAudioDiTModel, + AutoencoderOobleck, +) +from mindiesd import CacheConfig, CacheAgent +from diffusers.models.embeddings import get_1d_rotary_pos_embed + + +MODEL_FOLDER_PATH = folder_paths.get_folder_paths("mindiesd")[0] +PIPELINE_FOLDER_PATH = os.path.join(MODEL_FOLDER_PATH, "stable-audio-open-1.0") +UNET_FOLDER_PATH = os.path.join(PIPELINE_FOLDER_PATH, "transformer") + + +class UnetAdapter(torch.nn.Module): + def __init__(self, engine_path, use_attentioncache, start_step, attentioncache_interval, end_step): + super().__init__() + + self.device = comfy.model_management.get_torch_device() + self.dtype = torch.float16 + print("os.path.dirname(engine_path)",os.path.dirname(engine_path)) + self.unet_model = StableAudioDiTModel.from_pretrained(os.path.dirname(engine_path), + torch_dtype=self.dtype) + transformer = self.unet_model + if use_attentioncache: + config = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.transformer_blocks), + steps_count=100, + step_start=start_step, + step_interval=attentioncache_interval, + step_end=end_step + ) + else: + config = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.transformer_blocks), + steps_count=100 + ) + cache = CacheAgent(config) + for block in transformer.transformer_blocks: + block.cache = cache + + # self.unet_model = self.unet_model.eval() + self.unet_model = self.unet_model.to(self.device).eval() + self.unet_model.config.addition_embed_type = None + + # freqs + rotary_embedding = get_1d_rotary_pos_embed( + 32, + 1025, + use_real=True, + repeat_interleave_real=False, + ) + cos = rotary_embedding[0][None, :, None, :].to(self.device).to(self.dtype) + sin = rotary_embedding[1][None, :, None, :].to(self.device).to(self.dtype) + self.rotary_embedding = (cos, sin) + + def __call__(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): + global_embed = kwargs.get("global_embed", None) + noise_pred = self.unet_model( + 0, + x, + timesteps, + context, + global_embed.unsqueeze(1), + rotary_embedding=self.rotary_embedding, + return_dict=False, + )[0] + return noise_pred + + +class StableAudioUnetLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": {"unet_name": (folder_paths.get_filename_list("mindiesd"), {"tooltip": "The name of the transformer (model) to load."}), + # "ckpt_path": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), + "use_attentioncache": ("BOOLEAN", {"default": True, "tooltip": "Use AGB Cache."}), + "start_step": ("INT", {"default": 60, "min": 10, "max": 10000, "tooltip": ""}), + "attentioncache_interval": ("INT", {"default": 5, "min": 1, "max": 10, "tooltip": ""}), + "end_step": ("INT", {"default": 97, "min": 1, "max": 10000, "tooltip": ""}), + }} + CATEGORY = "MindIE SD/StableAudio" + + def load_unet(self, unet_name, use_attentioncache, start_step, attentioncache_interval, end_step): + + unet_path = os.path.join(MODEL_FOLDER_PATH, unet_name) + unet = UnetAdapter(unet_path, use_attentioncache, start_step, attentioncache_interval, end_step) + conf = comfy.supported_models.StableAudio( {"audio_model": "dit1.0"}) + conf.unet_config["disable_unet_model_creation"] = True + conf.unet_config = { + "audio_model": "dit1.0", + } + + project_path = 
os.path.join("projection_model", "diffusion_pytorch_model.safetensors")
+        project_model = load_file(os.path.join(PIPELINE_FOLDER_PATH, project_path))
+        seconds_start_embedder_weights = utils.state_dict_prefix_replace(project_model, {"start_number_conditioner.time_positional_": "embedder."}, filter_keys=True)
+        seconds_total_embedder_weights = utils.state_dict_prefix_replace(project_model, {"end_number_conditioner.time_positional_": "embedder."}, filter_keys=True)
+
+        model = comfy.model_base.StableAudio1(conf, seconds_start_embedder_weights, seconds_total_embedder_weights)
+
+        model.diffusion_model = unet
+        model.memory_required = lambda *args, **kwargs: 0  # report zero memory so inputs are always passed in as large a batch as possible
+
+        return (comfy.model_patcher.ModelPatcher(model,
+                load_device=comfy.model_management.get_torch_device(),
+                offload_device=comfy.model_management.unet_offload_device()),)
+
+    # Required by the ComfyUI custom-node contract: the method named in FUNCTION is invoked and
+    # must return a tuple matching RETURN_TYPES.
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+
+NODE_CLASS_MAPPINGS = {
+    "StableAudioUnetLoader": StableAudioUnetLoader,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "StableAudioUnetLoader": "MindIE SD StableAudio Unet Loader",
+}
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt
new file mode 100644
index 0000000000..d9137f6285
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt
@@ -0,0 +1,6 @@
+torch==2.1.0
+torchsde==0.2.6
+diffusers==0.30.0
+transformers==4.40.0
+soundfile==0.12.1
+safetensors
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py
new file mode 100644
index 0000000000..f484d3fceb
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py
@@ -0,0 +1,4 @@
+from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck
+from .pipeline import StableAudioPipeline, StableAudioProjectionModel
+from .models import StableAudioDiTModel, ModelMixin
+from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler, SchedulerMixin
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py
new file mode 100644
index 0000000000..f2f044203d
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py
@@ -0,0 +1,4 @@
+from .attention import (
+    Attention,
+    StableAudioAttnProcessor2_0,
+)
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py
new file mode 100644
index 0000000000..aec5aa7ca0
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py
@@ -0,0 +1,373 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch_npu +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import logging +from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 +from mindiesd import attention_forward, rotary_position_embedding +from .linear import QKVLinear + +logger = logging.get_logger(__name__) + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. 
+ residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + pre_only=False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." 
+ ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_qkv = QKVLinear( + attention_dim=query_dim, + hidden_size=self.inner_dim, + qkv_bias=self.use_bias, + cross_attention_dim=cross_attention_dim, + cross_hidden_size=self.inner_kv_dim, + ) + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + if qk_norm is not None and added_kv_proj_dim is not None: + if qk_norm == "fp32_layer_norm": + self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) + elif qk_norm == "rms_norm": + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + self.norm_added_q = None + self.norm_added_k = None + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. 
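+
+        Example (illustrative): ``attn.set_processor(StableAudioAttnProcessor2_0())`` swaps in
+        the NPU-oriented processor defined later in this file.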
+ """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + query, key, value = attn.to_qkv(hidden_states, encoder_hidden_states) + + head_dim = attn.dim_head + kv_heads = attn.inner_kv_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
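+            # (For example, with the Stable Audio DiT cross-attention defaults of 24 query heads and
+            # 12 key/value heads, heads_per_kv_head == 2, so each key/value head is expanded into two
+            # copies to match the 24 query heads.)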
+ heads_per_kv_head = attn.heads // kv_heads + key = key.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1) + key = key.view(batch_size, sequence_length, attn.heads, head_dim) + value = value.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1) + value = value.view(batch_size, sequence_length, attn.heads, head_dim) + + # Apply RoPE if needed + if rotary_emb is not None: + cos, sin = rotary_emb + query_to_rotate, query_unrotated = torch.chunk(query, 2, 3) + query_rotated = rotary_position_embedding(query_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True) + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = torch.chunk(key, 2, 3) + key_rotated = rotary_position_embedding(key_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True) + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + hidden_states = attention_forward( + query, key, value, + opt_mode="manual", + op_type="fused_attn_score", + layout="BSND" + ) + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py new file mode 100644 index 0000000000..3e8ee84fc8 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
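+#
+# A rough sketch of what QKVLinear below provides (shapes only, illustrative):
+#
+#   self-attention:   q, k, v = QKVLinear(dim, hidden)(x)
+#                     # x: (B, S, dim) -> q, k, v each (B, S, hidden)
+#   cross-attention:  q, k, v = QKVLinear(dim, hidden, cross_attention_dim=c_dim)(x, context)
+#                     # q: (B, S, hidden) from x; k, v: (B, S_c, hidden) from context (B, S_c, c_dim)
+#
+# i.e. the separate to_q / to_k / to_v projections of the reference diffusers Attention are fused
+# into a single matmul (or one q matmul plus one fused kv matmul for cross-attention), presumably
+# to cut down the number of small GEMMs on the NPU.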
+ + +import torch +import torch.nn as nn +import torch_npu + + +class QKVLinear(nn.Module): + def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None, + device=None, dtype=None): + super(QKVLinear, self).__init__() + self.attention_dim = attention_dim + self.hidden_size = hidden_size + + self.cross_attention_dim = cross_attention_dim + self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size + self.qkv_bias = qkv_bias + + factory_kwargs = {"device": device, "dtype": dtype} + + if cross_attention_dim is None: + self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs)) + if self.qkv_bias: + self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs)) + else: + self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs)) + self.kv_weight = nn.Parameter( + torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs)) + + if self.qkv_bias: + self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs)) + self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs)) + + def forward(self, hidden_states, encoder_hidden_states=None): + + if self.cross_attention_dim is None: + if not self.qkv_bias: + qkv = torch.matmul(hidden_states, self.weight) + else: + qkv = torch.addmm( + self.bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.weight, + beta=1, + alpha=1 + ) + + batch, seqlen, _ = hidden_states.shape + qkv_shape = (batch, seqlen, 3, -1) + qkv = qkv.view(qkv_shape) + q, k, v = qkv.unbind(2) + + else: + if not self.qkv_bias: + q = torch.matmul(hidden_states, self.q_weight) + kv = torch.matmul(encoder_hidden_states, self.kv_weight) + else: + q = torch.addmm( + self.q_bias, + hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)), + self.q_weight, + beta=1, + alpha=1 + ) + kv = torch.addmm( + self.kv_bias, + encoder_hidden_states.view( + encoder_hidden_states.size(0) * encoder_hidden_states.size(1), + encoder_hidden_states.size(2)), + self.kv_weight, + beta=1, + alpha=1 + ) + + batch, q_seqlen, _ = hidden_states.shape + q = q.view(batch, q_seqlen, -1) + + batch, kv_seqlen, _ = encoder_hidden_states.shape + kv_shape = (batch, kv_seqlen, 2, -1) + + kv = kv.view(kv_shape) + k, v = kv.unbind(2) + + return q, k, v diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py new file mode 100644 index 0000000000..c4170f5af4 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py @@ -0,0 +1 @@ +from .stable_audio_transformer import StableAudioDiTModel, ModelMixin \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py new file mode 100644 index 0000000000..aeb06b2ade --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py @@ -0,0 +1,431 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional, Union +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import logging + +from ..layers.attention import ( + Attention, + StableAudioAttnProcessor2_0, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. 
Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + self.cache = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.cache.apply( + self.attn1, + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. 
+ global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. + cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + + self.cache_block_start = 11 + self.cache_step_interval = 2 + self.cache_num_blocks = 9 + self.cache_step_start = 5 + + self.num_layers = num_layers + + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.init_dtype = self.dtype + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + def forward( + self, + step_id, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 
+ global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. + rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.init_dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + if not use_cache or (use_cache and step_id < self.cache_step_start): + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=0, + end_id=self.num_layers, + ) + else: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=0, + end_id=self.cache_block_start, + ) + + cache_end = np.minimum(self.cache_block_start + self.cache_num_blocks, self.num_layers) + hidden_states_pre_cache = hidden_states.clone() + if (step_id - self.cache_step_start) % self.cache_step_interval == 0: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + rotary_embedding=rotary_embedding, + start_id=self.cache_block_start, + end_id=cache_end, + ) + self.delta_cache = hidden_states - hidden_states_pre_cache + else: + hidden_states = hidden_states_pre_cache + self.delta_cache + + if cache_end < self.num_layers: + hidden_states = self._transformer_blocks_forward( + hidden_states=hidden_states, + encoder_hidden_states=cross_attention_hidden_states, + 
rotary_embedding=rotary_embedding, + start_id=cache_end, + end_id=self.num_layers, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) + + def _transformer_blocks_forward(self, hidden_states, encoder_hidden_states, rotary_embedding, start_id, end_id): + for block in self.transformer_blocks[start_id: end_id]: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_embedding=rotary_embedding, + ) + return hidden_states + + def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) ->None: + for i in range(self.num_layers): + self_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None) + self_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None) + self_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None) + self_qkv_weight = torch.cat([self_q_weight, self_k_weight, self_v_weight], dim=0).transpose(0, 1).contiguous() + state_dict[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = self_qkv_weight + + cross_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_q.weight", None) + cross_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_k.weight", None) + cross_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_v.weight", None) + cross_q_weight = cross_q_weight.transpose(0, 1).contiguous() + cross_kv_weight = torch.cat([cross_k_weight, cross_v_weight], dim=0).transpose(0, 1).contiguous() + state_dict[f"transformer_blocks.{i}.attn2.to_qkv.q_weight"] = cross_q_weight + state_dict[f"transformer_blocks.{i}.attn2.to_qkv.kv_weight"] = cross_kv_weight \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py new file mode 100644 index 0000000000..a283da8ecd --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py @@ -0,0 +1 @@ +from .pipeline_stable_audio import StableAudioPipeline, StableAudioProjectionModel \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py new file mode 100644 index 0000000000..3abd5ecc9d --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py @@ -0,0 +1,754 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from diffusers.utils import ( + logging, + replace_example_docstring, +) + +from ..models import StableAudioDiTModel +from ..schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "stabilityai/stable-audio-open-1.0" + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. + scheduler ([`CosineDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. 
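+
+    A construction sketch from individually loaded components (illustrative only; the subfolder names assume the
+    standard `stabilityai/stable-audio-open-1.0` repository layout, and the import paths assume the exports of this
+    plugin package):
+
+    ```py
+    >>> import torch
+    >>> from transformers import T5EncoderModel, T5Tokenizer
+    >>> from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck
+    >>> from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
+    >>> from stableaudio.models import StableAudioDiTModel
+    >>> from stableaudio.schedulers import CosineDPMSolverMultistepScheduler
+    >>> from stableaudio.pipeline import StableAudioPipeline
+
+    >>> repo = "stabilityai/stable-audio-open-1.0"
+    >>> pipe = StableAudioPipeline(
+    ...     vae=AutoencoderOobleck.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float16),
+    ...     text_encoder=T5EncoderModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16),
+    ...     projection_model=StableAudioProjectionModel.from_pretrained(repo, subfolder="projection_model", torch_dtype=torch.float16),
+    ...     tokenizer=T5Tokenizer.from_pretrained(repo, subfolder="tokenizer"),
+    ...     transformer=StableAudioDiTModel.from_pretrained(repo, subfolder="transformer", torch_dtype=torch.float16),
+    ...     scheduler=CosineDPMSolverMultistepScheduler.from_pretrained(repo, subfolder="scheduler"),
+    ... )
+    ```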
+ """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + transformer: StableAudioDiTModel, + scheduler: CosineDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: List[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. 
Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: + raise ValueError( + "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." + ) + + if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: + raise ValueError( + f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." + "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " + ) + + def prepare_latents( + self, + batch_size, + num_channels_vae, + sample_size, + dtype, + device, + generator, + latents=None, + initial_audio_waveforms=None, + num_waveforms_per_prompt=None, + audio_channels=None, + ): + shape = (batch_size, num_channels_vae, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_end_in_s: Optional[float] = None, + audio_start_in_s: Optional[float] = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + initial_audio_waveforms: Optional[torch.Tensor] = None, + initial_audio_sampling_rate: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + output_type: Optional[str] = "pt", + use_cache: Optional[bool] = False, + ): + r""" + The call function to the pipeline for generation. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. 
The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16) + sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16) + rotary_embedding = (cos, sin) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + i, + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + use_cache=use_cache, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py new file mode 100644 index 0000000000..5bad3d9b15 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py @@ -0,0 +1 @@ +from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py new file mode 100644 index 0000000000..d4e7805d13 --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -0,0 +1,574 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler + + +class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021). + This scheduler was used in Stable Audio Open [1]. + + [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.3): + Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. + sigma_max (`float`, *optional*, defaults to 500): + Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. + sigma_data (`float`, *optional*, defaults to 1.0): + The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. + sigma_schedule (`str`, *optional*, defaults to `exponential`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. + prediction_type (`str`, defaults to `v_prediction`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
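+
+    A minimal usage sketch (illustrative only, using the Stable Audio Open defaults listed above):
+
+    ```py
+    >>> from stableaudio.schedulers import CosineDPMSolverMultistepScheduler
+    >>> scheduler = CosineDPMSolverMultistepScheduler(sigma_min=0.3, sigma_max=500, sigma_data=1.0)
+    >>> scheduler.set_timesteps(num_inference_steps=100)
+    >>> # timesteps are the noise-preconditioned sigmas, t = atan(sigma) * 2 / pi, so they lie in (0, 1)
+    >>> bool(scheduler.timesteps.min() > 0) and bool(scheduler.timesteps.max() < 1)
+    True
+    ```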
+ """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.3, + sigma_max: float = 500, + sigma_data: float = 1.0, + sigma_schedule: str = "exponential", + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "v_prediction", + rho: float = 7.0, + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
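+
+        Note that `set_timesteps` resets the begin index, so it is expected to be called first, e.g. (illustrative):
+
+        ```py
+        >>> scheduler.set_timesteps(num_inference_steps=100)
+        >>> scheduler.set_begin_index(25)  # start denoising from the 26th timestep (img2img-style use)
+        ```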
+ """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + return sigma.atan() / math.pi * 2 + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
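+
+        For example (illustrative):
+
+        ```py
+        >>> scheduler.set_timesteps(num_inference_steps=50, device="cpu")
+        >>> scheduler.timesteps.shape
+        torch.Size([50])
+        >>> scheduler.sigmas.shape  # one extra trailing entry holds the final sigma (zero by default)
+        torch.Size([51])
+        ```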
+ """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # if a noise sampler is used, reinitialise it + self.noise_sampler = None + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. 
DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if noise is None: + raise ValueError("noise must not be None") + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
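+
+        As a paraphrase of the code below (for orientation, not an independent derivation): with data predictions
+        `m0`, `m1`, `D0 = m0`, `D1 = (m0 - m1) / r0`, `r0 = h_0 / h`, `h = lambda_t - lambda_s0`,
+        `h_0 = lambda_s0 - lambda_s1`, and `alpha_t = 1` because inputs are pre-scaled, the default "midpoint"
+        SDE-DPMSolver++ update computed here is
+
+            x_t = (sigma_t / sigma_s0) * exp(-h) * sample
+                  + (1 - exp(-2 * h)) * D0
+                  + 0.5 * (1 - exp(-2 * h)) * D1
+                  + sigma_t * sqrt(1 - exp(-2 * h)) * noise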
+ """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + # sde-dpmsolver++ + if noise is None: + raise ValueError("noise must not be None") + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. 
+ + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.noise_sampler is None: + seed = None + if generator is not None: + seed = ( + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( + model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device + ) + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * 
timesteps.shape[0]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py
new file mode 100644
index 0000000000..8533b082f5
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from .scheduling_utils import BrownianTree
+
+
+class BatchedBrownianTree:
+    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+    def __init__(self, x, t0, t1, seed=None, **kwargs):
+        t0, t1, self.sign = self.sort(t0, t1)
+        w0 = kwargs.get("w0", torch.zeros_like(x))
+        if seed is None:
+            seed = torch.randint(0, 2**63 - 1, []).item()
+        self.batched = True
+        try:
+            if len(seed) != x.shape[0]:
+                raise ValueError(f"len(seed) should equal x.shape[0], but got {len(seed)}")
+            w0 = w0[0]
+        except TypeError:
+            seed = [seed]
+            self.batched = False
+        self.trees = [BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+    @staticmethod
+    def sort(a, b):
+        return (a, b, 1) if a < b else (b, a, -1)
+
+    def __call__(self, t0, t1):
+        t0, t1, sign = self.sort(t0, t1)
+        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+        return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+    """A noise sampler backed by a torchsde.BrownianTree.
+
+    Args:
+        x (Tensor): The tensor whose shape, device and dtype to use to generate
+            random samples.
+        sigma_min (float): The low end of the valid interval.
+        sigma_max (float): The high end of the valid interval.
+        seed (int or List[int]): The random seed. If a list of seeds is
+            supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
+            with its own seed.
+        transform (callable): A function that maps sigma to the sampler's
+            internal timestep.
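+
+    Example (an illustrative sketch only; the tensor shape and sigma values below are
+        arbitrary placeholders, not values required by this plugin):
+        >>> x = torch.randn(2, 64, 1024)
+        >>> sampler = BrownianTreeNoiseSampler(x, sigma_min=0.3, sigma_max=500.0, seed=42)
+        >>> eps = sampler(14.6, 11.2)  # unit-variance Brownian increment with the same shape as x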
+ """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() \ No newline at end of file diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py new file mode 100644 index 0000000000..e4734b052f --- /dev/null +++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py @@ -0,0 +1,893 @@ + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +import trampoline + +import numpy as np +import torch + +from torchsde._brownian import brownian_base +from torchsde.settings import LEVY_AREA_APPROXIMATIONS +from torchsde.types import Scalar, Optional, Tuple, Union, Tensor + +_rsqrt3 = 1 / math.sqrt(3) +_r12 = 1 / 12 + + +def _randn(size, dtype, device, seed): + generator = torch.Generator(device).manual_seed(int(seed)) + return torch.randn(size, dtype=dtype, device=device, generator=generator) + + +def _is_scalar(x): + return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1) + + +def _assert_floating_tensor(name, tensor): + if not torch.is_tensor(tensor): + raise ValueError(f"{name}={tensor} should be a Tensor.") + if not tensor.is_floating_point(): + raise ValueError(f"{name}={tensor} should be floating point.") + + +def _check_tensor_info(*tensors, size, dtype, device): + """Check if sizes, dtypes, and devices of input tensors all match prescribed values.""" + tensors = list(filter(torch.is_tensor, tensors)) + + if dtype is None and len(tensors) == 0: + dtype = torch.get_default_dtype() + if device is None and len(tensors) == 0: + device = torch.device("cpu") + + sizes = [] if size is None else [size] + sizes += [t.shape for t in tensors] + + dtypes = [] if dtype is None else [dtype] + dtypes += [t.dtype for t in tensors] + + devices = [] if device is None else [device] + devices += [t.device for t in tensors] + + if len(sizes) == 0: + raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.") + + if not all(i == sizes[0] for i in sizes): + raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.") + if not all(i == dtypes[0] for i in dtypes): + raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.") + if not all(i == devices[0] for i in devices): + raise ValueError("Multiple devices found. 
Make sure `device` and `W` or `H` are consistent.") + + # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes. + return tuple(sizes[0]), dtypes[0], devices[0] + + +def _davie_foster_approximation(W, H, h, levy_area_approximation, get_noise): + if levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.none, LEVY_AREA_APPROXIMATIONS.space_time): + return None + elif W.ndimension() in (0, 1): + # If we have zero or one dimensions then treat the scalar / single dimension we have as batch, so that the + # Brownian motion is one dimensional and the Levy area is zero. + return torch.zeros_like(W) + else: + # Davie's approximation to the Levy area from space-time Levy area + A = H.unsqueeze(-1) * W.unsqueeze(-2) - W.unsqueeze(-1) * H.unsqueeze(-2) + noise = get_noise() + noise = noise - noise.transpose(-1, -2) # noise is skew symmetric of variance 2 + if levy_area_approximation == LEVY_AREA_APPROXIMATIONS.foster: + # Foster's additional correction to Davie's approximation + tenth_h = 0.1 * h + H_squared = H ** 2 + std = (tenth_h * (tenth_h + H_squared.unsqueeze(-1) + H_squared.unsqueeze(-2))).sqrt() + else: # davie approximation + std = math.sqrt(_r12 * h ** 2) + a_tilde = std * noise + A += a_tilde + return A + + +def _H_to_U(W: torch.Tensor, H: torch.Tensor, h: float) -> torch.Tensor: + return h * (.5 * W + H) + + +class _EmptyDict: + def __setitem__(self, key, value): + pass + + def __getitem__(self, item): + raise KeyError + + +class _LRUDict(dict): + def __init__(self, max_size): + super().__init__() + self._max_size = max_size + self._keys = [] + + def __setitem__(self, key, value): + if key in self: + self._keys.remove(key) + elif len(self) >= self._max_size: + del self[self._keys.pop(0)] + super().__setitem__(key, value) + self._keys.append(key) + + +class _Interval: + # Intervals correspond to some subinterval of the overall interval [t0, t1]. + # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left + # and right subintervals, which partition the parent interval. + + __slots__ = ( + # These are the things that every interval has + '_start', + '_end', + '_parent', + '_is_left', + '_top', + # These are the things that intervals which are parents also have + '_midway', + '_spawn_key', + '_depth', + '_W_seed', + '_H_seed', + '_left_a_seed', + '_right_a_seed', + '_left_child', + '_right_child') + + def __init__(self, start, end, parent, is_left, top): + self._start = top._round(start) # the left hand edge of the interval + self._end = top._round(end) # the right hand edge of the interval + self._parent = parent # our parent interval + self._is_left = is_left # are we the left or right child of our parent + self._top = top # the top-level BrownianInterval, where we cache certain state + self._midway = None # The point at which we split between left and right subintervals + + ######################################## + # Calculate increments and levy area # + ######################################## + # + # This is a little bit convoluted, so here's an explanation. + # + # The entry point is _increment_and_levy_area, below. This immediately calls _increment_and_space_time_levy_area, + # applies the space-time to full Levy area correction, and then returns. 
+ # + # _increment_and_space_time_levy_area in turn calls a central LRU cache, as (later on) we'll need the increment and + # space-time Levy area of the parent interval to compute our own increment and space-time Levy area, and it's likely + # that our parent exists in the cache, as if we're being queried then our parent was probably queried recently as + # well. + # (The top-level BrownianInterval overrides _increment_and_space_time_levy_area to return its own increment and + # space-time Levy area, effectively holding them permanently in the cache.) + # + # If the request isn't found in the LRU cache then it computes it from its parent. + # Now it turns out that the size of our increment and space-time Levy area is really most naturally thought of as a + # property of our parent: it depends on our parent's increment, space-time Levy area, and whether we are the left or + # right interval within our parent. So _increment_and_space_time_levy_area in turn checks if we are on the + # left or right of our parent and does most of the computation using the parent's attributes. + + def _increment_and_levy_area(self): + W, H = trampoline.trampoline(self._increment_and_space_time_levy_area()) + A = _davie_foster_approximation(W, H, self._end - self._start, self._top._levy_area_approximation, + self._randn_levy) + return W, H, A + + def _increment_and_space_time_levy_area(self): + try: + return self._top._increment_and_space_time_levy_area_cache[self] + except KeyError: + parent = self._parent + + W, H = yield parent._increment_and_space_time_levy_area() + h_reciprocal = 1 / (parent._end - parent._start) + left_diff = parent._midway - parent._start + right_diff = parent._end - parent._midway + + if self._top._have_H: + left_diff_squared = left_diff ** 2 + right_diff_squared = right_diff ** 2 + left_diff_cubed = left_diff * left_diff_squared + right_diff_cubed = right_diff * right_diff_squared + + v = 0.5 * math.sqrt(left_diff * right_diff / (left_diff_cubed + right_diff_cubed)) + + a = v * left_diff_squared * h_reciprocal + b = v * right_diff_squared * h_reciprocal + c = v * _rsqrt3 + + X1 = parent._randn(parent._W_seed) + X2 = parent._randn(parent._H_seed) + + third_coeff = 2 * (a * left_diff + b * right_diff) * h_reciprocal + + if self._is_left: + first_coeff = left_diff * h_reciprocal + second_coeff = 6 * first_coeff * right_diff * h_reciprocal + out_W = first_coeff * W + second_coeff * H + third_coeff * X1 + out_H = first_coeff ** 2 * H - a * X1 + c * right_diff * X2 + else: + first_coeff = right_diff * h_reciprocal + second_coeff = 6 * first_coeff * left_diff * h_reciprocal + out_W = first_coeff * W - second_coeff * H - third_coeff * X1 + out_H = first_coeff ** 2 * H - b * X1 - c * left_diff * X2 + else: + # Don't compute space-time Levy area unless we need to + + mean = left_diff * h_reciprocal * W + var = left_diff * right_diff * h_reciprocal + noise = parent._randn(parent._W_seed) + left_W = mean + math.sqrt(var) * noise + + if self._is_left: + out_W = left_W + else: + out_W = W - left_W + out_H = None + + self._top._increment_and_space_time_levy_area_cache[self] = (out_W, out_H) + return out_W, out_H + + def _randn(self, seed): + # We generate random noise deterministically wrt some seed; this seed is determined by the generator. + # This means that if we drop out of the cache, then we'll create the same random noise next time, as we still + # have the generator. 
+ size = self._top._size + return _randn(size, self._top._dtype, self._top._device, seed) + + def _a_seed(self): + return self._parent._left_a_seed if self._is_left else self._parent._right_a_seed + + def _randn_levy(self): + size = (*self._top._size, *self._top._size[-1:]) + return _randn(size, self._top._dtype, self._top._device, self._a_seed()) + + ######################################## + # Locate an interval in the hierarchy # + ######################################## + # + # The other important piece of this construction is a way to locate any given interval within the binary tree + # hierarchy. (This is typically the slightly slower part, actually, so if you want to speed things up then this is + # the bit to target.) + # + # loc finds the interval [ta, tb] - and creates it in the appropriate place (as a child of some larger interval) if + # it doesn't already exist. As in principle we may request an interval that covers multiple existing intervals, then + # in fact the interval [ta, tb] is returned as an ordered list of existing subintervals. + # + # It calls _loc, which operates recursively. See _loc for more details on how the search works. + + def _loc(self, ta, tb): + out = [] + ta = self._top._round(ta) + tb = self._top._round(tb) + trampoline.trampoline(self._loc_inner(ta, tb, out)) + return out + + def _loc_inner(self, ta, tb, out): + # Expect to have ta < tb + + # First, we (this interval) only have jurisdiction over [self._start, self._end]. So if we're asked for + # something outside of that then we pass the buck up to our parent, who is strictly larger. + if ta < self._start or tb > self._end: + raise trampoline.TailCall(self._parent._loc_inner(ta, tb, out)) + + # If it's us that's being asked for, then we add ourselves on to out and return. + if ta == self._start and tb == self._end: + out.append(self) + return + + # If we've got this far then we know that it's an interval that's within our jurisdiction, and that it's not us. + # So next we check if it's up to us to figure out, or up to our children. + if self._midway is None: + # It's up to us. Create subintervals (_split) if appropriate. + if ta == self._start: + self._split(tb) + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + # implies ta > self._start + self._split(ta) + # Query our (newly created) right_child: if tb == self._end then our right child will be the result, and it + # will tell us so. But if tb < self._end then our right_child will need to make another split of its own. + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + + # If we're here then we have children: self._midway is not None + if tb <= self._midway: + # Strictly our left_child's problem + raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out)) + if ta >= self._midway: + # Strictly our right_child's problem + raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out)) + # It's a problem for both of our children: the requested interval overlaps our midpoint. Call the left_child + # first (to append to out in the correct order), then call our right child. 
+ # (Implies ta < self._midway < tb) + yield self._left_child._loc_inner(ta, self._midway, out) + raise trampoline.TailCall(self._right_child._loc_inner(self._midway, tb, out)) + + def _set_spawn_key_and_depth(self): + self._spawn_key = 2 * self._parent._spawn_key + (0 if self._is_left else 1) + self._depth = self._parent._depth + 1 + + def _split(self, midway): + if self._top._halfway_tree: + self._split_exact(0.5 * (self._end + self._start)) + # self._midway is now the rounded halfway point. + if midway > self._midway: + self._right_child._split(midway) + elif midway < self._midway: + self._left_child._split(midway) + else: + self._split_exact(midway) + + def _split_exact(self, midway): # Create two children + self._midway = self._top._round(midway) + # Use splittable PRNGs to generate noise. + self._set_spawn_key_and_depth() + generator = np.random.SeedSequence(entropy=self._top._entropy, + spawn_key=(self._spawn_key, self._depth), + pool_size=self._top._pool_size) + self._W_seed, self._H_seed, self._left_a_seed, self._right_a_seed = generator.generate_state(4) + + self._left_child = _Interval(start=self._start, + end=midway, + parent=self, + is_left=True, + top=self._top) + self._right_child = _Interval(start=midway, + end=self._end, + parent=self, + is_left=False, + top=self._top) + + +class BrownianInterval(brownian_base.BaseBrownian, _Interval): + """Brownian interval with fixed entropy. + + Computes increments (and optionally Levy area). + + To use: + >>> bm = BrownianInterval(t0=0.0, t1=1.0, size=(4, 1), device='cuda') + >>> bm(0., 0.5) + tensor([[ 0.0733], + [-0.5692], + [ 0.1872], + [-0.3889]], device='cuda:0') + """ + + __slots__ = ( + # Inputs + '_size', + '_dtype', + '_device', + '_entropy', + '_levy_area_approximation', + '_dt', + '_tol', + '_pool_size', + '_cache_size', + '_halfway_tree', + # Quantisation + '_round', + # Caching, searching and computing values + '_increment_and_space_time_levy_area_cache', + '_last_interval', + '_have_H', + '_have_A', + '_w_h', + '_top_a_seed', + # Dependency tree creation + '_average_dt', + '_tree_dt', + '_num_evaluations' + ) + + def __init__(self, + t0: Optional[Scalar] = 0., + t1: Optional[Scalar] = 1., + size: Optional[Tuple[int, ...]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[str, torch.device]] = None, + entropy: Optional[int] = None, + dt: Optional[Scalar] = None, + tol: Scalar = 0., + pool_size: int = 8, + cache_size: Optional[int] = 45, + halfway_tree: bool = False, + levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none, + W: Optional[Tensor] = None, + H: Optional[Tensor] = None): + """Initialize the Brownian interval. + + Args: + t0 (float or Tensor): Initial time. + t1 (float or Tensor): Terminal time. + size (tuple of int): The shape of each Brownian sample. + If zero dimensional represents a scalar Brownian motion. + If one dimensional represents a batch of scalar Brownian motions. + If >two dimensional the last dimension represents the size of a + a multidimensional Brownian motion, and all previous dimensions + represent batch dimensions. + dtype (torch.dtype): The dtype of each Brownian sample. + Defaults to the PyTorch default. + device (str or torch.device): The device of each Brownian sample. + Defaults to the CPU. + entropy (int): Global seed, defaults to `None` for random entropy. + levy_area_approximation (str): Whether to also approximate Levy + area. Defaults to 'none'. 
Valid options are 'none', + 'space-time', 'davie' or 'foster', corresponding to different + approximation types. + This is needed for some higher-order SDE solvers. + dt (float or Tensor): The expected average step size of the SDE + solver. Set it if you know it (e.g. when using a fixed-step + solver); else it will be estimated from the first few queries. + This is used to set up the data structure such that it is + efficient to query at these intervals. + tol (float or Tensor): What tolerance to resolve the Brownian motion + to. Must be non-negative. Defaults to zero, i.e. floating point + resolution. Usually worth setting in conjunction with + `halfway_tree`, below. + pool_size (int): Size of the pooled entropy. If you care about + statistical randomness then increasing this will help (but will + slow things down). + cache_size (int): How big a cache of recent calculations to use. + (As new calculations depend on old calculations, this speeds + things up dramatically, rather than recomputing things.) + Set this to `None` to use an infinite cache, which will be fast + but memory inefficient. + halfway_tree (bool): Whether the dependency tree (the internal data + structure) should be the dyadic tree. Defaults to `False`. + Normally, the sample path is determined by both `entropy`, + _and_ the locations and order of the query points. Setting this + to `True` will make it deterministic with respect to just + `entropy`; however this is much slower. + W (Tensor): The increment of the Brownian motion over the interval + [t0, t1]. Will be generated randomly if not provided. + H (Tensor): The space-time Levy area of the Brownian motion over the + interval [t0, t1]. Will be generated randomly if not provided. + """ + + ##################################### + # Check and normalise inputs # + ##################################### + + if not _is_scalar(t0): + raise ValueError('Initial time t0 should be a float or 0-d torch.Tensor.') + if not _is_scalar(t1): + raise ValueError('Terminal time t1 should be a float or 0-d torch.Tensor.') + if dt is not None and not _is_scalar(dt): + raise ValueError('Expected average time step dt should be a float or 0-d torch.Tensor.') + + if t0 > t1: + raise ValueError(f'Initial time {t0} should be less than terminal time {t1}.') + t0 = float(t0) + t1 = float(t1) + if dt is not None: + dt = float(dt) + + if halfway_tree: + if tol <= 0.: + raise ValueError("`tol` should be positive.") + if dt is not None: + raise ValueError("`dt` is not used and should be set to `None` if `halfway_tree` is True.") + else: + if tol < 0.: + raise ValueError("`tol` should be non-negative.") + + size, dtype, device = _check_tensor_info(W, H, size=size, dtype=dtype, device=device) + + # Let numpy dictate randomness, so we have fewer seeds to set for reproducibility. 
+ if entropy is None: + entropy = np.random.randint(0, 2 ** 31 - 1) + + if levy_area_approximation not in LEVY_AREA_APPROXIMATIONS: + raise ValueError(f"`levy_area_approximation` must be one of {LEVY_AREA_APPROXIMATIONS}, but got " + f"'{levy_area_approximation}'.") + + ##################################### + # Record inputs # + ##################################### + + self._size = size + self._dtype = dtype + self._device = device + self._entropy = entropy + self._levy_area_approximation = levy_area_approximation + self._dt = dt + self._tol = tol + self._pool_size = pool_size + self._cache_size = cache_size + self._halfway_tree = halfway_tree + + ##################################### + # A miscellany of other things # + ##################################### + + # We keep a cache of recent queries, and their results. This is very important for speed, so that we don't + # recurse all the way up to the top every time we have a query. + if cache_size is None: + self._increment_and_space_time_levy_area_cache = {} + elif cache_size == 0: + self._increment_and_space_time_levy_area_cache = _EmptyDict() + else: + self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size) + + # We keep track of the most recently queried interval, and start searching for the next interval from that + # element of the binary tree. This is because subsequent queries are likely to be near the most recent query. + self._last_interval = self + + # Precompute these as we don't want to spend lots of time checking strings in hot loops. + self._have_H = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.space_time, + LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + self._have_A = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.davie, + LEVY_AREA_APPROXIMATIONS.foster) + + # If we like we can quantise what level we want to compute the Brownian motion to. + if math.isclose(tol, 0., rel_tol=1e-5): + def round_func(x): + return x + else: + ndigits = -int(math.log10(tol)) + + def round_func(x): + return round(x, ndigits) + + self._round = round_func + + # Initalise as _Interval. + # (Must come after _round but before _w_h) + super(BrownianInterval, self).__init__(start=t0, + end=t1, + parent=None, + is_left=None, + top=self) + + # Set the global increment and space-time Levy area + generator = np.random.SeedSequence(entropy=entropy, pool_size=pool_size) + initial_W_seed, initial_H_seed, top_a_seed = generator.generate_state(3) + if W is None: + W = self._randn(initial_W_seed) * math.sqrt(t1 - t0) + else: + _assert_floating_tensor('W', W) + if H is None: + H = self._randn(initial_H_seed) * math.sqrt((t1 - t0) / 12) + else: + _assert_floating_tensor('H', H) + self._w_h = (W, H) + self._top_a_seed = top_a_seed + + if not self._halfway_tree: + # We create a binary tree dependency between the points. If we don't do this then the forward pass is still + # efficient at O(N), but we end up with a dependency chain stretching along the interval [t0, t1], making + # the backward pass O(N^2). By setting up a dependency tree of depth relative to `dt` and `cache_size` we + # can instead make both directions O(N log N). + self._average_dt = 0 + self._tree_dt = t1 - t0 + self._num_evaluations = -100 # start off with a warmup period to get a decent estimate of the average + if dt is not None: + # Create the dependency tree based on the supplied hint `dt`. + self._create_dependency_tree(dt) + # If dt is None, then create the dependency tree based on observed statistics of query points. 
(In __call__) + + # Effectively permanently store our increment and space-time Levy area in the cache. + def _increment_and_space_time_levy_area(self): + return self._w_h + yield # make it a generator + + def _a_seed(self): + return self._top_a_seed + + def _set_spawn_key_and_depth(self): + self._spawn_key = 0 + self._depth = 0 + + def __call__(self, ta, tb=None, return_U=False, return_A=False): + if tb is None: + warnings.warn(f"{self.__class__.__name__} is optimised for interval-based queries, not point evaluation.") + ta, tb = self._start, ta + tb_name = 'ta' + else: + tb_name = 'tb' + ta = float(ta) + tb = float(tb) + if ta < self._start: + warnings.warn(f"Should have ta>=t0 but got ta={ta} and t0={self._start}.") + ta = self._start + if tb < self._start: + warnings.warn(f"Should have {tb_name}>=t0 but got {tb_name}={tb} and t0={self._start}.") + tb = self._start + if ta > self._end: + warnings.warn(f"Should have ta<=t1 but got ta={ta} and t1={self._end}.") + ta = self._end + if tb > self._end: + warnings.warn(f"Should have {tb_name}<=t1 but got {tb_name}={tb} and t1={self._end}.") + tb = self._end + if ta > tb: + raise RuntimeError(f"Query times ta={ta:.3f} and tb={tb:.3f} must respect ta <= tb.") + + if ta == tb: + W = torch.zeros(self._size, dtype=self._dtype, device=self._device) + H = None + A = None + if self._have_H: + H = torch.zeros(self._size, dtype=self._dtype, device=self._device) + if self._have_A: + size = (*self._size, *self._size[-1:]) # not self._size[-1] as that may not exist + A = torch.zeros(size, dtype=self._dtype, device=self._device) + else: + if self._dt is None and not self._halfway_tree: + self._num_evaluations += 1 + # We start off with "negative" num evaluations, to give us a small warm-up period at the start. + if self._num_evaluations > 0: + # Compute average step size so far + dt = tb - ta + self._average_dt = (dt + self._average_dt * (self._num_evaluations - 1)) / self._num_evaluations + if self._average_dt < 0.5 * self._tree_dt: + # If 'dt' wasn't specified, then check the average interval length against the size of the + # bottom of the dependency tree. If we're below halfway then refine the tree by splitting all + # the bottom pieces into two. + self._create_dependency_tree(dt) + + # Find the intervals that correspond to the query. We start our search at the last interval we accessed in + # the binary tree, as it's likely that the next query will come nearby. + intervals = self._last_interval._loc(ta, tb) + # Ideally we'd keep track of intervals[0] on the backward pass. Practically speaking len(intervals) tends to + # be 1 or 2 almost always so this isn't a huge deal. + self._last_interval = intervals[-1] + + W, H, A = intervals[0]._increment_and_levy_area() + if len(intervals) > 1: + # If we have multiple intervals then add up their increments and Levy areas. + + for interval in intervals[1:]: + Wi, Hi, Ai = interval._increment_and_levy_area() + if self._have_H: + # Aggregate H: + # Given s < u < t, then + # H_{s,t} = (term1 + term2) / (t - s) + # where + # term1 = (t - u) * (H_{u, t} + W_{s, u} / 2) + # term2 = (u - s) * (H_{s, u} - W_{u, t} / 2) + term1 = (interval._end - interval._start) * (Hi + 0.5 * W) + term2 = (interval._start - ta) * (H - 0.5 * Wi) + H = (term1 + term2) / (interval._end - ta) + if self._have_A and len(self._size) not in (0, 1): + # If len(self._size) in (0, 1) then we treat our scalar / single dimension as a batch + # dimension, so we have zero Levy area. 
(And these unsqueezes will result in a tensor of shape + # (batch, batch) which is wrong.) + + # Let B_{x, y} = \int_x^y W^1_{s,u} dW^2_u. + # Then + # B_{s, t} = \int_s^t W^1_{s,u} dW^2_u + # = \int_s^v W^1_{s,u} dW^2_u + \int_v^t W^1_{s,v} dW^2_u + \int_v^t W^1_{v,u} dW^2_u + # = B_{s, v} + W^1_{s, v} W^2_{v, t} + B_{v, t} + # + # A is now the antisymmetric part of B, which gives the formula below. + A = A + Ai + 0.5 * (W.unsqueeze(-1) * Wi.unsqueeze(-2) - Wi.unsqueeze(-1) * W.unsqueeze(-2)) + W = W + Wi + + U = None + if self._have_H: + U = _H_to_U(W, H, tb - ta) + + if return_U: + if return_A: + return W, U, A + else: + return W, U + else: + if return_A: + return W, A + else: + return W + + def _create_dependency_tree(self, dt): + # For safety we take a min with 100: if people take very large cache sizes then this would then break the + # logarithmic into linear, which causes RecursionErrors. + if self._cache_size is None: # cache_size=None corresponds to infinite cache. + cache_size = 100 + else: + cache_size = min(self._cache_size, 100) + + self._tree_dt = min(self._tree_dt, dt) + # Rationale: We are prepared to hold `cache_size` many things in memory, so when making steps of size `dt` + # then we can afford to have the intervals at the bottom of our binary tree be of size `dt * cache_size`. + # For safety we then make this a bit smaller by multiplying by 0.8. + piece_length = self._tree_dt * cache_size * 0.8 + + def _set_points(interval): + start = interval._start + end = interval._end + if end - start > piece_length: + midway = (end + start) / 2 + interval._loc(start, midway) + _set_points(interval._left_child) + _set_points(interval._right_child) + + _set_points(self) + + def __repr__(self): + if self._dt is None: + dt = None + else: + dt = f"{self._dt:.3f}" + return (f"{self.__class__.__name__}(" + f"t0={self._start:.3f}, " + f"t1={self._end:.3f}, " + f"size={self._size}, " + f"dtype={self._dtype}, " + f"device={repr(self._device)}, " + f"entropy={self._entropy}, " + f"dt={dt}, " + f"tol={self._tol}, " + f"pool_size={self._pool_size}, " + f"cache_size={self._cache_size}, " + f"levy_area_approximation={repr(self._levy_area_approximation)}" + f")") + + def display_binary_tree(self): + stack = [(self, 0)] + out = [] + while stack: + elem, depth = stack.pop() + out.append(" " * depth + f"({elem._start}, {elem._end})") + if elem._midway is not None: + stack.append((elem._right_child, depth + 1)) + stack.append((elem._left_child, depth + 1)) + print("\n".join(out)) + + @property + def shape(self): + return self._size + + @property + def dtype(self): + return self._dtype + + @property + def device(self): + return self._device + + @property + def entropy(self): + return self._entropy + + @property + def levy_area_approximation(self): + return self._levy_area_approximation + + @property + def dt(self): + return self._dt + + @property + def tol(self): + return self._tol + + @property + def pool_size(self): + return self._pool_size + + @property + def cache_size(self): + return self._cache_size + + @property + def halfway_tree(self): + return self._halfway_tree + + def size(self): + return self._size + + +class BrownianTree(brownian_base.BaseBrownian): + """Brownian tree with fixed entropy. + + Useful when the map from entropy -> Brownian motion shouldn't depend on the + locations and order of the query points. (As the usual BrownianInterval + does - note that BrownianTree is slower as a result though.) 
+
+    To use:
+    >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1))
+    >>> bm(0., 0.5)
+    tensor([[ 0.0733],
+            [-0.5692],
+            [ 0.1872],
+            [-0.3889]])
+    """
+
+    def __init__(self, t0: Scalar,
+                 w0: Tensor,
+                 t1: Optional[Scalar] = None,
+                 w1: Optional[Tensor] = None,
+                 entropy: Optional[int] = None,
+                 tol: float = 1e-6,
+                 pool_size: int = 24,
+                 cache_depth: int = 9,
+                 safety: Optional[float] = None):
+        """Initialize the Brownian tree.
+
+        The random value generation process exploits the parallel random number paradigm and uses
+        `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`).
+
+        Arguments:
+            t0: Initial time.
+            w0: Initial state.
+            t1: Terminal time.
+            w1: Terminal state.
+            entropy: Global seed, defaults to `None` for random entropy.
+            tol: Error tolerance before the binary search is terminated; the search depth ~ log2(tol).
+            pool_size: Size of the pooled entropy. This parameter affects the query speed significantly.
+            cache_depth: Unused; deprecated.
+            safety: Unused; deprecated.
+        """
+
+        if t1 is None:
+            t1 = t0 + 1
+        if w1 is None:
+            W = None
+        else:
+            W = w1 - w0
+        self._w0 = w0
+        self._interval = BrownianInterval(t0=t0,
+                                          t1=t1,
+                                          size=w0.shape,
+                                          dtype=w0.dtype,
+                                          device=w0.device,
+                                          entropy=entropy,
+                                          tol=tol,
+                                          pool_size=pool_size,
+                                          halfway_tree=True,
+                                          W=W)
+        super(BrownianTree, self).__init__()
+
+    def __call__(self, t, tb=None, return_U=False, return_A=False):
+        # Deliberately called t rather than ta, for backward compatibility
+        out = self._interval(t, tb, return_U=return_U, return_A=return_A)
+        if tb is None and not return_U and not return_A:
+            out = out + self._w0
+        return out
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(interval={self._interval})"
+
+    @property
+    def dtype(self):
+        return self._interval.dtype
+
+    @property
+    def device(self):
+        return self._interval.device
+
+    @property
+    def shape(self):
+        return self._interval.shape
+
+    @property
+    def levy_area_approximation(self):
+        return self._interval.levy_area_approximation
+
+
+def brownian_interval_like(y: Tensor,
+                           t0: Optional[Scalar] = 0.,
+                           t1: Optional[Scalar] = 1.,
+                           size: Optional[Tuple[int, ...]] = None,
+                           dtype: Optional[torch.dtype] = None,
+                           device: Optional[Union[str, torch.device]] = None,
+                           **kwargs):
+    """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor."""
+    size = y.shape if size is None else size
+    dtype = y.dtype if dtype is None else dtype
+    device = y.device if device is None else device
+    return BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs)
\ No newline at end of file
--
Gitee

From 99e000a7c23f5bdc51c479a0fbd29fbd0442df7a Mon Sep 17 00:00:00 2001
From: zhoufan2956
Date: Fri, 18 Apr 2025 16:53:42 +0800
Subject: [PATCH 3/3] add_comfyui

---
 .../Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
index 42d5735145..91aa35eac8 100644
--- a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
@@ -36,7 +36,6 @@ class UnetAdapter(torch.nn.Module):
         self.device = comfy.model_management.get_torch_device()
         self.dtype = torch.float16
 
-        print("os.path.dirname(engine_path)",os.path.dirname(engine_path))
         self.unet_model = 
StableAudioDiTModel.from_pretrained(os.path.dirname(engine_path), torch_dtype=self.dtype) transformer = self.unet_model @@ -92,7 +91,6 @@ class StableAudioUnetLoader: @classmethod def INPUT_TYPES(s): return {"required": {"unet_name": (folder_paths.get_filename_list("mindiesd"), {"tooltip": "The name of the transformer (model) to load."}), - # "ckpt_path": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), "use_attentioncache": ("BOOLEAN", {"default": True, "tooltip": "Use AGB Cache."}), "start_step": ("INT", {"default": 60, "min": 10, "max": 10000, "tooltip": ""}), "attentioncache_interval": ("INT", {"default": 5, "min": 1, "max": 10, "tooltip": ""}), -- Gitee