diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes b/MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md b/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2c094cfee21e3a0f9adc579d78ae900ec9d0c58f
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/README.md
@@ -0,0 +1,154 @@
+## 1. Prepare the Runtime Environment
+
+ **Table 1** Version compatibility
+
+ | Component | Version | Setup guide |
+ | ----- | ----- |-----|
+ | Python | 3.10.2 | - |
+ | torch | 2.1.0 | - |
+
+### 1.1 Obtain the CANN & MindIE Packages and Prepare the Environment
+- Supported devices
+Atlas 800I A2 inference server: 1 card supported
+Atlas 300I Duo inference card: 1 card supported
+- [Atlas 800I A2](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=4&model=32)
+- [Atlas 300I Duo](https://www.hiascend.com/developer/download/community/result?module=pt+ie+cann&product=2&model=17)
+- [Environment setup guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/softwareinst/instg/instg_0001.html)
+
+### 1.2 Install CANN
+```shell
+# Make the packages executable. {version} is the package version, {arch} the CPU architecture, {soc} the Ascend AI processor version.
+chmod +x ./Ascend-cann-toolkit_{version}_linux-{arch}.run
+chmod +x ./Ascend-cann-kernels-{soc}_{version}_linux.run
+# Verify the consistency and integrity of the installation packages
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --check
+./Ascend-cann-kernels-{soc}_{version}_linux.run --check
+# Install
+./Ascend-cann-toolkit_{version}_linux-{arch}.run --install
+./Ascend-cann-kernels-{soc}_{version}_linux.run --install
+
+# Set environment variables
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+```
+
+
+### 1.3 Install MindIE
+```shell
+# Make the package executable. {version} is the package version, {arch} the CPU architecture.
+chmod +x ./Ascend-mindie_${version}_linux-${arch}.run
+./Ascend-mindie_${version}_linux-${arch}.run --check
+
+# Option 1: install to the default path
+./Ascend-mindie_${version}_linux-${arch}.run --install
+# Set environment variables
+cd /usr/local/Ascend/mindie && source set_env.sh
+
+# Option 2: install to a custom path
+./Ascend-mindie_${version}_linux-${arch}.run --install-path=${AieInstallPath}
+# Set environment variables
+cd ${AieInstallPath}/mindie && source set_env.sh
+```
+
+### 1.4 Install torch_npu
+Install the PyTorch framework, version 2.1.0.
+[Package download](https://download.pytorch.org/whl/cpu/torch/)
+
+Install with pip:
+```shell
+# {version} is the package version, {arch} the CPU architecture.
+pip install torch-${version}-cp310-cp310-linux_${arch}.whl
+```
+Download pytorch_v{pytorchversion}_py{pythonversion}.tar.gz (the torch_npu release package):
+```shell
+tar -xzvf pytorch_v{pytorchversion}_py{pythonversion}.tar.gz
+# After extraction, install the torch_npu wheel it contains
+pip install torch_npu-{pytorchversion}.xxxx.{arch}.whl
+```
+
+### 1.5 Install MindSpeed
+```shell
+# Clone the MindSpeed source repository:
+git clone https://gitee.com/ascend/MindSpeed.git
+# Install it with:
+pip install -e MindSpeed
+```
+
+## 2. Download This Repository
+
+### 2.1 Clone to Local
+```shell
+git clone https://modelers.cn/MindIE/stable_audio_open_1.0.git
+```
+
+### 2.2 Install Dependencies
+```bash
+pip3 install -r requirements.txt
+apt-get update
+apt-get install libsndfile1
+```
+
+## 3. Using Stable-Audio-Open-1.0
+
+### 3.1 Weights and Configuration Files
+The stable-audio-open-1.0 weights are available at:
+```shell
+https://huggingface.co/stabilityai/stable-audio-open-1.0/tree/main
+```
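+
+If the weights are not yet on the local machine, one option is to fetch the full repository with the `huggingface_hub` Python package. This is only a sketch of one possible approach; it assumes `huggingface_hub` is installed and that access to the gated repository has been granted (log in with `huggingface-cli login` if required):
+```python
+# Minimal download sketch (assumes huggingface_hub is installed and repo access is granted).
+from huggingface_hub import snapshot_download
+
+snapshot_download(
+    repo_id="stabilityai/stable-audio-open-1.0",
+    local_dir="./stable-audio-open-1.0",
+)
+```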
+
+### 3.2 Single-Card Functional Test
+Set the weight path:
+```shell
+model_base='./stable-audio-open-1.0'
+```
+Run the following commands:
+```shell
+# Without the DiTCache strategy
+python3 inference_stableaudio.py \
+ --model ${model_base} \
+ --prompt_file ./prompts/prompts.txt \
+ --num_inference_steps 100 \
+ --audio_end_in_s 10 10 47 \
+ --save_dir ./results \
+ --seed 1 \
+ --device 0
+
+# With the DiTCache strategy
+python3 inference_stableaudio.py \
+ --model ${model_base} \
+ --prompt_file ./prompts/prompts.txt \
+ --num_inference_steps 100 \
+ --audio_end_in_s 10 10 47 \
+ --save_dir ./results \
+ --seed 1 \
+ --device 0 \
+ --use_cache
+```
+Parameter description:
+- --model: path to the model weights.
+- --prompt_file: file containing the prompts.
+- --num_inference_steps: number of denoising iterations for audio generation.
+- --audio_end_in_s: duration of the generated audio in seconds; defaults to 10 s if not specified.
+- --save_dir: directory where the generated audio files are saved.
+- --seed: random seed; if not specified, a random seed is used.
+- --device: inference device ID.
+- --use_cache: [optional] enable the DiTCache strategy.
+
+After the run completes, the generated audio files are written to `./results` in the same order as the prompts in the prompt file, and the inference time is printed to the terminal.
+
+### 3.3 Inference Performance
+
+Reference performance numbers:
+
+| Hardware | Iterations | Average latency (w/o DiTCache) | Average latency (with DiTCache) |
+| :------: |:----:|:----:|:----:|
+| Atlas 800I A2(8*32G) | 100 | 5.645s | 4.991s |
+
+## 4. Optimization Guide
+The following optimizations are applied to this model:
+- Equivalent-transformation optimizations: FA, RoPE, Linear
+- Algorithmic optimization: DiTCache (see the sketch below)
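+
+The DiTCache idea, as implemented in `StableAudioDiTModel.forward` in this repository, is to recompute a span of middle transformer blocks only on some denoising steps and to re-apply their cached residual on the steps in between. The sketch below is illustrative only; the function and argument names are made up for this example, while the default values mirror the cache settings in that file.
+```python
+def dit_forward_with_cache(blocks, x, step_id, state,
+                           block_start=11, num_blocks=9, step_start=5, step_interval=2):
+    """Illustrative DiTCache sketch: `blocks` is a list of callables, `state` a dict."""
+    num_layers = len(blocks)
+    if step_id < step_start:
+        # Warm-up steps: run every block normally.
+        for block in blocks:
+            x = block(x)
+        return x
+
+    for block in blocks[:block_start]:  # leading blocks always run
+        x = block(x)
+
+    cache_end = min(block_start + num_blocks, num_layers)
+    if (step_id - step_start) % step_interval == 0:
+        before = x.clone()
+        for block in blocks[block_start:cache_end]:  # recompute the cached span
+            x = block(x)
+        state["delta"] = x - before  # remember its contribution
+    else:
+        x = x + state["delta"]  # reuse the cached contribution
+
+    for block in blocks[cache_end:]:  # trailing blocks always run
+        x = block(x)
+    return x
+```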
+
+
+## Statement
+- The datasets and models mentioned in this repository are provided only as examples and only for non-commercial use. If you use them to run the examples, please take particular care to comply with the licenses of the corresponding datasets and models; Huawei assumes no liability for any infringement disputes arising from your use of these datasets or models.
+- If you encounter any problem while using this repository (including but not limited to functional or compliance issues), please file an issue in this repository; we will review and respond in a timely manner.
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98c2e997a24c2774b358e37e928f9e0bf707fe2a
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/__init__.py
@@ -0,0 +1,3 @@
+from .node import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..91aa35eac853594a7752879b637fbedf05db831d
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/node.py
@@ -0,0 +1,131 @@
+#Put this in the custom_nodes folder, put your mindiesd files in ComfyUI/models/mindiesd/ (you will have to create the directory)
+
+import torch
+import torch_npu
+import os
+import json
+from safetensors.torch import load_file
+
+import comfy.model_base
+import comfy.model_management
+import comfy.model_patcher
+import comfy.supported_models
+import comfy.supported_models_base
+import comfy.sd
+from comfy import utils
+import folder_paths
+from folder_paths import get_folder_paths, get_filename_list_
+
+from .stableaudio import (
+ StableAudioPipeline,
+ StableAudioDiTModel,
+ AutoencoderOobleck,
+)
+from mindiesd import CacheConfig, CacheAgent
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+
+
+MODEL_FOLDER_PATH = folder_paths.get_folder_paths("mindiesd")[0]
+PIPELINE_FOLDER_PATH = os.path.join(MODEL_FOLDER_PATH, "stable-audio-open-1.0")
+UNET_FOLDER_PATH = os.path.join(PIPELINE_FOLDER_PATH, "transformer")
+
+
+class UnetAdapter(torch.nn.Module):
+ def __init__(self, engine_path, use_attentioncache, start_step, attentioncache_interval, end_step):
+ super().__init__()
+
+ self.device = comfy.model_management.get_torch_device()
+ self.dtype = torch.float16
+ self.unet_model = StableAudioDiTModel.from_pretrained(os.path.dirname(engine_path),
+ torch_dtype=self.dtype)
+ transformer = self.unet_model
+ if use_attentioncache:
+ config = CacheConfig(
+ method="attention_cache",
+ blocks_count=len(transformer.transformer_blocks),
+ steps_count=100,
+ step_start=start_step,
+ step_interval=attentioncache_interval,
+ step_end=end_step
+ )
+ else:
+ config = CacheConfig(
+ method="attention_cache",
+ blocks_count=len(transformer.transformer_blocks),
+ steps_count=100
+ )
+ cache = CacheAgent(config)
+ for block in transformer.transformer_blocks:
+ block.cache = cache
+
+ # self.unet_model = self.unet_model.eval()
+ self.unet_model = self.unet_model.to(self.device).eval()
+ self.unet_model.config.addition_embed_type = None
+
+        # Precompute rotary position embeddings: dim = attention_head_dim // 2 = 32,
+        # sequence length 1025 = 1024 latent frames plus the 1 prepended global-embedding token.
+ rotary_embedding = get_1d_rotary_pos_embed(
+ 32,
+ 1025,
+ use_real=True,
+ repeat_interleave_real=False,
+ )
+ cos = rotary_embedding[0][None, :, None, :].to(self.device).to(self.dtype)
+ sin = rotary_embedding[1][None, :, None, :].to(self.device).to(self.dtype)
+ self.rotary_embedding = (cos, sin)
+
+ def __call__(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
+ global_embed = kwargs.get("global_embed", None)
+ noise_pred = self.unet_model(
+ 0,
+ x,
+ timesteps,
+ context,
+ global_embed.unsqueeze(1),
+ rotary_embedding=self.rotary_embedding,
+ return_dict=False,
+ )[0]
+ return noise_pred
+
+
+class StableAudioUnetLoader:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"unet_name": (folder_paths.get_filename_list("mindiesd"), {"tooltip": "The name of the transformer (model) to load."}),
+ "use_attentioncache": ("BOOLEAN", {"default": True, "tooltip": "Use AGB Cache."}),
+ "start_step": ("INT", {"default": 60, "min": 10, "max": 10000, "tooltip": ""}),
+ "attentioncache_interval": ("INT", {"default": 5, "min": 1, "max": 10, "tooltip": ""}),
+ "end_step": ("INT", {"default": 97, "min": 1, "max": 10000, "tooltip": ""}),
+ }}
+    CATEGORY = "MindIE SD/StableAudio"
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+
+ def load_unet(self, unet_name, use_attentioncache, start_step, attentioncache_interval, end_step):
+
+ unet_path = os.path.join(MODEL_FOLDER_PATH, unet_name)
+ unet = UnetAdapter(unet_path, use_attentioncache, start_step, attentioncache_interval, end_step)
+        conf = comfy.supported_models.StableAudio({"audio_model": "dit1.0"})
+        conf.unet_config = {
+            "audio_model": "dit1.0",
+        }
+        # Keep ComfyUI from instantiating its own DiT; the UnetAdapter below replaces it.
+        conf.unet_config["disable_unet_model_creation"] = True
+
+ project_path = os.path.join("projection_model", "diffusion_pytorch_model.safetensors")
+        project_model = load_file(os.path.join(PIPELINE_FOLDER_PATH, project_path))
+ seconds_start_embedder_weights = utils.state_dict_prefix_replace(project_model, {"start_number_conditioner.time_positional_": "embedder."}, filter_keys=True)
+ seconds_total_embedder_weights = utils.state_dict_prefix_replace(project_model, {"end_number_conditioner.time_positional_": "embedder."}, filter_keys=True)
+
+ model = comfy.model_base.StableAudio1(conf, seconds_start_embedder_weights, seconds_total_embedder_weights)
+
+ model.diffusion_model = unet
+        model.memory_required = lambda *args, **kwargs: 0  # report zero memory so ComfyUI keeps inputs batched as much as possible
+
+ return (comfy.model_patcher.ModelPatcher(model,
+ load_device=comfy.model_management.get_torch_device(),
+ offload_device=comfy.model_management.unet_offload_device()),)
+
+NODE_CLASS_MAPPINGS = {
+    "StableAudioUnetLoader": StableAudioUnetLoader,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "StableAudioUnetLoader": "MindIE SD StableAudio Unet Loader",
+}
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9137f6285f389460aaf1aaad271b09e45747618
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/requirements.txt
@@ -0,0 +1,6 @@
+torch==2.1.0
+torchsde==0.2.6
+diffusers==0.30.0
+transformers==4.40.0
+soundfile==0.12.1
+safetensors
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f484d3fcebf98ab394508963b30f278134d7c236
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/__init__.py
@@ -0,0 +1,4 @@
+from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck
+from .pipeline import StableAudioPipeline, StableAudioProjectionModel
+from .models import StableAudioDiTModel, ModelMixin
+from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler, SchedulerMixin
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f044203d9a52527a303066545740aaea80cc49
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/__init__.py
@@ -0,0 +1,4 @@
+from .attention import (
+ Attention,
+ StableAudioAttnProcessor2_0,
+)
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec5aa7ca009691d8e2b5c15d46ab0afa5fab4d5
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/attention.py
@@ -0,0 +1,373 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch_npu
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import logging
+from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
+from diffusers.models.normalization import FP32LayerNorm, RMSNorm
+from mindiesd import attention_forward, rotary_position_embedding
+from .linear import QKVLinear
+
+logger = logging.get_logger(__name__)
+
+
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ ):
+ super().__init__()
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+
+ # we make use of this private variable to know whether this class is loaded
+ # with an deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_qkv = QKVLinear(
+ attention_dim=query_dim,
+ hidden_size=self.inner_dim,
+ qkv_bias=self.use_bias,
+ cross_attention_dim=cross_attention_dim,
+ cross_hidden_size=self.inner_kv_dim,
+ )
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+
+class StableAudioAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ query, key, value = attn.to_qkv(hidden_states, encoder_hidden_states)
+
+ head_dim = attn.dim_head
+ kv_heads = attn.inner_kv_dim // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ if kv_heads != attn.heads:
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
+ heads_per_kv_head = attn.heads // kv_heads
+ key = key.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1)
+ key = key.view(batch_size, sequence_length, attn.heads, head_dim)
+ value = value.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1)
+ value = value.view(batch_size, sequence_length, attn.heads, head_dim)
+
+ # Apply RoPE if needed
+ if rotary_emb is not None:
+ cos, sin = rotary_emb
+ query_to_rotate, query_unrotated = torch.chunk(query, 2, 3)
+ query_rotated = rotary_position_embedding(query_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True)
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
+
+ if not attn.is_cross_attention:
+ key_to_rotate, key_unrotated = torch.chunk(key, 2, 3)
+ key_rotated = rotary_position_embedding(key_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True)
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
+
+ hidden_states = attention_forward(
+ query, key, value,
+ opt_mode="manual",
+ op_type="fused_attn_score",
+ layout="BSND"
+ )
+
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e8ee84fc878be54988022ef49e5e8649b3f498c
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/layers/linear.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+import torch_npu
+
+
+class QKVLinear(nn.Module):
+ def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None,
+ device=None, dtype=None):
+ super(QKVLinear, self).__init__()
+ self.attention_dim = attention_dim
+ self.hidden_size = hidden_size
+
+ self.cross_attention_dim = cross_attention_dim
+ self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size
+ self.qkv_bias = qkv_bias
+
+ factory_kwargs = {"device": device, "dtype": dtype}
+
+ if cross_attention_dim is None:
+ self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs))
+ if self.qkv_bias:
+ self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs))
+ else:
+ self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs))
+ self.kv_weight = nn.Parameter(
+ torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs))
+
+ if self.qkv_bias:
+ self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs))
+ self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs))
+
+ def forward(self, hidden_states, encoder_hidden_states=None):
+
+ if self.cross_attention_dim is None:
+ if not self.qkv_bias:
+ qkv = torch.matmul(hidden_states, self.weight)
+ else:
+ qkv = torch.addmm(
+ self.bias,
+ hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)),
+ self.weight,
+ beta=1,
+ alpha=1
+ )
+
+ batch, seqlen, _ = hidden_states.shape
+ qkv_shape = (batch, seqlen, 3, -1)
+ qkv = qkv.view(qkv_shape)
+ q, k, v = qkv.unbind(2)
+
+ else:
+ if not self.qkv_bias:
+ q = torch.matmul(hidden_states, self.q_weight)
+ kv = torch.matmul(encoder_hidden_states, self.kv_weight)
+ else:
+ q = torch.addmm(
+ self.q_bias,
+ hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)),
+ self.q_weight,
+ beta=1,
+ alpha=1
+ )
+ kv = torch.addmm(
+ self.kv_bias,
+ encoder_hidden_states.view(
+ encoder_hidden_states.size(0) * encoder_hidden_states.size(1),
+ encoder_hidden_states.size(2)),
+ self.kv_weight,
+ beta=1,
+ alpha=1
+ )
+
+ batch, q_seqlen, _ = hidden_states.shape
+ q = q.view(batch, q_seqlen, -1)
+
+ batch, kv_seqlen, _ = encoder_hidden_states.shape
+ kv_shape = (batch, kv_seqlen, 2, -1)
+
+ kv = kv.view(kv_shape)
+ k, v = kv.unbind(2)
+
+ return q, k, v
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4170f5af4c06c05f2ac5f6059a165c3e497f7eb
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/__init__.py
@@ -0,0 +1 @@
+from .stable_audio_transformer import StableAudioDiTModel, ModelMixin
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb06b2adecc79d8118ad4ba3a51bfc01b551b63
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/models/stable_audio_transformer.py
@@ -0,0 +1,431 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Union
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import FeedForward
+from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
+
+from ..layers.attention import (
+ Attention,
+ StableAudioAttnProcessor2_0,
+)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableAudioGaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ del self.weight
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.weight = self.W
+ del self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :]
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class StableAudioDiTBlock(nn.Module):
+ r"""
+ Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip
+ connection and QKNorm
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for the query states.
+ num_key_value_attention_heads (`int`): The number of heads to use for the key and value states.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ num_key_value_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ upcast_attention: bool = False,
+ norm_eps: float = 1e-5,
+ ff_inner_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=False,
+ upcast_attention=upcast_attention,
+ out_bias=False,
+ processor=StableAudioAttnProcessor2_0(),
+ )
+
+ # 2. Cross-Attn
+ self.norm2 = nn.LayerNorm(dim, norm_eps, True)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ kv_heads=num_key_value_attention_heads,
+ dropout=dropout,
+ bias=False,
+ upcast_attention=upcast_attention,
+ out_bias=False,
+ processor=StableAudioAttnProcessor2_0(),
+ ) # is self-attn if encoder_hidden_states is none
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, norm_eps, True)
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn="swiglu",
+ final_dropout=False,
+ inner_dim=ff_inner_dim,
+ bias=True,
+ )
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ self.cache = None
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ rotary_embedding: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.cache.apply(
+ self.attn1,
+ norm_hidden_states,
+ attention_mask=attention_mask,
+ rotary_emb=rotary_embedding,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ norm_hidden_states = self.norm2(hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
+class StableAudioDiTModel(ModelMixin, ConfigMixin):
+ """
+ The Diffusion Transformer model introduced in Stable Audio.
+
+ Reference: https://github.com/Stability-AI/stable-audio-tools
+
+ Parameters:
+ sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
+ num_key_value_attention_heads (`int`, *optional*, defaults to 12):
+ The number of heads to use for the key and value states.
+ out_channels (`int`, defaults to 64): Number of output channels.
+ cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
+ time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection.
+ global_states_input_dim ( `int`, *optional*, defaults to 1536):
+ Input dimension of the global hidden states projection.
+ cross_attention_input_dim ( `int`, *optional*, defaults to 768):
+ Input dimension of the cross-attention projection
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 1024,
+ in_channels: int = 64,
+ num_layers: int = 24,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 24,
+ num_key_value_attention_heads: int = 12,
+ out_channels: int = 64,
+ cross_attention_dim: int = 768,
+ time_proj_dim: int = 256,
+ global_states_input_dim: int = 1536,
+ cross_attention_input_dim: int = 768,
+ ):
+ super().__init__()
+
+ self.cache_block_start = 11
+ self.cache_step_interval = 2
+ self.cache_num_blocks = 9
+ self.cache_step_start = 5
+
+ self.num_layers = num_layers
+
+ self.sample_size = sample_size
+ self.out_channels = out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.init_dtype = self.dtype
+ self.time_proj = StableAudioGaussianFourierProjection(
+ embedding_size=time_proj_dim // 2,
+ flip_sin_to_cos=True,
+ log=False,
+ set_W_to_weight=False,
+ )
+
+ self.timestep_proj = nn.Sequential(
+ nn.Linear(time_proj_dim, self.inner_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(self.inner_dim, self.inner_dim, bias=True),
+ )
+
+ self.global_proj = nn.Sequential(
+ nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(self.inner_dim, self.inner_dim, bias=False),
+ )
+
+ self.cross_attention_proj = nn.Sequential(
+ nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
+ )
+
+ self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
+ self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ StableAudioDiTBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ num_key_value_attention_heads=num_key_value_attention_heads,
+ attention_head_dim=attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
+ self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ step_id,
+ hidden_states: torch.FloatTensor,
+ timestep: torch.LongTensor = None,
+ encoder_hidden_states: torch.FloatTensor = None,
+ global_hidden_states: torch.FloatTensor = None,
+ rotary_embedding: torch.FloatTensor = None,
+ return_dict: bool = True,
+ attention_mask: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = False,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`StableAudioDiTModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
+ Input `hidden_states`.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
+ Global embeddings that will be prepended to the hidden states.
+ rotary_embedding (`torch.Tensor`):
+ The rotary embeddings to apply on query and key tensors during attention calculation.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
+ Mask to avoid performing attention on padding token indices, formed by concatenating the attention
+ masks
+ for the two text encoders together. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
+ Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating
+ the attention masks
+ for the two text encoders together. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
+ global_hidden_states = self.global_proj(global_hidden_states)
+ time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.init_dtype)))
+
+ global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)
+
+ hidden_states = self.preprocess_conv(hidden_states) + hidden_states
+ # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.proj_in(hidden_states)
+
+ # prepend global states to hidden states
+ hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
+ if attention_mask is not None:
+ prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool)
+ attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
+
+ if not use_cache or (use_cache and step_id < self.cache_step_start):
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=0,
+ end_id=self.num_layers,
+ )
+ else:
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=0,
+ end_id=self.cache_block_start,
+ )
+
+ cache_end = np.minimum(self.cache_block_start + self.cache_num_blocks, self.num_layers)
+ hidden_states_pre_cache = hidden_states.clone()
+ if (step_id - self.cache_step_start) % self.cache_step_interval == 0:
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=self.cache_block_start,
+ end_id=cache_end,
+ )
+ self.delta_cache = hidden_states - hidden_states_pre_cache
+ else:
+ hidden_states = hidden_states_pre_cache + self.delta_cache
+
+ if cache_end < self.num_layers:
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=cache_end,
+ end_id=self.num_layers,
+ )
+
+ hidden_states = self.proj_out(hidden_states)
+
+ # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
+ # remove prepend length that has been added by global hidden states
+ hidden_states = hidden_states.transpose(1, 2)[:, :, 1:]
+ hidden_states = self.postprocess_conv(hidden_states) + hidden_states
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
+
+ def _transformer_blocks_forward(self, hidden_states, encoder_hidden_states, rotary_embedding, start_id, end_id):
+ for block in self.transformer_blocks[start_id: end_id]:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ rotary_embedding=rotary_embedding,
+ )
+ return hidden_states
+
+    def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
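+        # Fuse the checkpoint's separate to_q/to_k/to_v projection weights into the packed
+        # layouts expected by QKVLinear: self-attention uses one [dim, 3 * inner_dim] matrix,
+        # cross-attention keeps q separate and packs k/v into a single matrix.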
+ for i in range(self.num_layers):
+ self_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None)
+ self_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None)
+ self_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None)
+ self_qkv_weight = torch.cat([self_q_weight, self_k_weight, self_v_weight], dim=0).transpose(0, 1).contiguous()
+ state_dict[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = self_qkv_weight
+
+ cross_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_q.weight", None)
+ cross_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_k.weight", None)
+ cross_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_v.weight", None)
+ cross_q_weight = cross_q_weight.transpose(0, 1).contiguous()
+ cross_kv_weight = torch.cat([cross_k_weight, cross_v_weight], dim=0).transpose(0, 1).contiguous()
+ state_dict[f"transformer_blocks.{i}.attn2.to_qkv.q_weight"] = cross_q_weight
+ state_dict[f"transformer_blocks.{i}.attn2.to_qkv.kv_weight"] = cross_kv_weight
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a283da8ecdf566d75e679e29b3a8e849445b4e6e
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/__init__.py
@@ -0,0 +1 @@
+from .pipeline_stable_audio import StableAudioPipeline, StableAudioProjectionModel
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..3abd5ecc9d47ea40f1462b4ff7a488120193da98
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/pipeline/pipeline_stable_audio.py
@@ -0,0 +1,754 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import (
+ T5EncoderModel,
+ T5Tokenizer,
+ T5TokenizerFast,
+)
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck
+from diffusers.utils import (
+ logging,
+ replace_example_docstring,
+)
+
+from ..models import StableAudioDiTModel
+from ..schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import scipy
+ >>> import torch
+ >>> import soundfile as sf
+ >>> from diffusers import StableAudioPipeline
+
+ >>> repo_id = "stabilityai/stable-audio-open-1.0"
+ >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> # define the prompts
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
+ >>> negative_prompt = "Low quality."
+
+ >>> # set the seed for generator
+ >>> generator = torch.Generator("cuda").manual_seed(0)
+
+ >>> # run the generation
+ >>> audio = pipe(
+ ... prompt,
+ ... negative_prompt=negative_prompt,
+ ... num_inference_steps=200,
+ ... audio_end_in_s=10.0,
+ ... num_waveforms_per_prompt=3,
+ ... generator=generator,
+ ... ).audios
+
+ >>> output = audio[0].T.float().cpu().numpy()
+ >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
+ ```
+"""
+
+
+class StableAudioPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-audio generation using StableAudio.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderOobleck`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.T5EncoderModel`]):
+ Frozen text-encoder. StableAudio uses the encoder of
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
+ projection_model ([`StableAudioProjectionModel`]):
+ A trained model used to linearly project the hidden-states from the text encoder model and the start and
+ end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
+ give the input to the transformer model.
+ tokenizer ([`~transformers.T5Tokenizer`]):
+ Tokenizer to tokenize text for the frozen text-encoder.
+ transformer ([`StableAudioDiTModel`]):
+ A `StableAudioDiTModel` to denoise the encoded audio latents.
+ scheduler ([`CosineDPMSolverMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
+
+ def __init__(
+ self,
+ vae: AutoencoderOobleck,
+ text_encoder: T5EncoderModel,
+ projection_model: StableAudioProjectionModel,
+ tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+ transformer: StableAudioDiTModel,
+ scheduler: CosineDPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ projection_model=projection_model,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
+
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ negative_attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # 1. Tokenize text
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
+ f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_input_ids = text_input_ids.to(device)
+ attention_mask = attention_mask.to(device)
+
+ # 2. Text encoder forward
+ self.text_encoder.eval()
+ prompt_embeds = self.text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if do_classifier_free_guidance and negative_prompt is not None:
+ uncond_tokens: List[str]
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # 1. Tokenize text
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_input_ids = uncond_input.input_ids.to(device)
+ negative_attention_mask = uncond_input.attention_mask.to(device)
+
+ # 2. Text encoder forward
+ self.text_encoder.eval()
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input_ids,
+ attention_mask=negative_attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if negative_attention_mask is not None:
+ # set the masked tokens to the null embed
+ negative_prompt_embeds = torch.where(
+ negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
+ )
+
+ # 3. Project prompt_embeds and negative_prompt_embeds
+ if do_classifier_free_guidance and negative_prompt_embeds is not None:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the negative and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if attention_mask is not None and negative_attention_mask is None:
+ negative_attention_mask = torch.ones_like(attention_mask)
+ elif attention_mask is None and negative_attention_mask is not None:
+ attention_mask = torch.ones_like(negative_attention_mask)
+
+ if attention_mask is not None:
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
+
+ prompt_embeds = self.projection_model(
+ text_hidden_states=prompt_embeds,
+ ).text_hidden_states
+ if attention_mask is not None:
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
+
+ return prompt_embeds
+
+ def encode_duration(
+ self,
+ audio_start_in_s,
+ audio_end_in_s,
+ device,
+ do_classifier_free_guidance,
+ batch_size,
+ ):
+ audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
+ audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
+
+ if len(audio_start_in_s) == 1:
+ audio_start_in_s = audio_start_in_s * batch_size
+ if len(audio_end_in_s) == 1:
+ audio_end_in_s = audio_end_in_s * batch_size
+
+ # Cast the inputs to floats
+ audio_start_in_s = [float(x) for x in audio_start_in_s]
+ audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
+
+ audio_end_in_s = [float(x) for x in audio_end_in_s]
+ audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
+
+ projection_output = self.projection_model(
+ start_seconds=audio_start_in_s,
+ end_seconds=audio_end_in_s,
+ )
+ seconds_start_hidden_states = projection_output.seconds_start_hidden_states
+ seconds_end_hidden_states = projection_output.seconds_end_hidden_states
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we repeat the audio hidden states to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
+ seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
+
+ return seconds_start_hidden_states, seconds_end_hidden_states
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ audio_start_in_s,
+ audio_end_in_s,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ attention_mask=None,
+ negative_attention_mask=None,
+ initial_audio_waveforms=None,
+ initial_audio_sampling_rate=None,
+ ):
+ if audio_end_in_s < audio_start_in_s:
+ raise ValueError(
+ f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but "
+ )
+
+ if (
+ audio_start_in_s < self.projection_model.config.min_value
+ or audio_start_in_s > self.projection_model.config.max_value
+ ):
+ raise ValueError(
+ f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
+ f"is {audio_start_in_s}."
+ )
+
+ if (
+ audio_end_in_s < self.projection_model.config.min_value
+ or audio_end_in_s > self.projection_model.config.max_value
+ ):
+ raise ValueError(
+ f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
+ f"is {audio_end_in_s}."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and (prompt_embeds is None):
+ raise ValueError(
+ "Provide either `prompt`, or `prompt_embeds`. Cannot leave"
+ "`prompt` undefined without specifying `prompt_embeds`."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
+ raise ValueError(
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
+ f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
+ )
+
+ if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
+ raise ValueError(
+ "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
+ )
+
+ if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
+ raise ValueError(
+ f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`."
+ "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_vae,
+ sample_size,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ initial_audio_waveforms=None,
+ num_waveforms_per_prompt=None,
+ audio_channels=None,
+ ):
+ shape = (batch_size, num_channels_vae, sample_size)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # encode the initial audio for use by the model
+ if initial_audio_waveforms is not None:
+ # check dimension
+ if initial_audio_waveforms.ndim == 2:
+ initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
+ elif initial_audio_waveforms.ndim != 3:
+ raise ValueError(
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
+ )
+
+ audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
+ audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
+
+ # check num_channels
+ if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
+ initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
+ elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
+ initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
+
+ if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
+ raise ValueError(
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
+ )
+
+ # crop or pad
+ audio_length = initial_audio_waveforms.shape[-1]
+ if audio_length < audio_vae_length:
+ logger.warning(
+ f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
+ )
+ elif audio_length > audio_vae_length:
+ logger.warning(
+ f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
+ )
+
+ audio = initial_audio_waveforms.new_zeros(audio_shape)
+ audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
+
+ encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
+ encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
+ latents = encoded_audio + latents
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ audio_end_in_s: Optional[float] = None,
+ audio_start_in_s: Optional[float] = 0.0,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_waveforms_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ initial_audio_waveforms: Optional[torch.Tensor] = None,
+ initial_audio_sampling_rate: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ negative_attention_mask: Optional[torch.LongTensor] = None,
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ output_type: Optional[str] = "pt",
+ use_cache: Optional[bool] = False,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
+ audio_end_in_s (`float`, *optional*, defaults to 47.55):
+ Audio end index in seconds.
+ audio_start_in_s (`float`, *optional*, defaults to 0):
+ Audio start index in seconds.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ A higher guidance scale value encourages the model to generate audio that is closely linked to the text
+ `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
+ The number of waveforms to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ initial_audio_waveforms (`torch.Tensor`, *optional*):
+ Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
+ `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
+ corresponds to the number of prompts passed to the model.
+ initial_audio_sampling_rate (`int`, *optional*):
+ Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
+ *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
+ argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
+ `negative_prompt` input argument.
+ attention_mask (`torch.LongTensor`, *optional*):
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
+ be computed from `prompt` input argument.
+ negative_attention_mask (`torch.LongTensor`, *optional*):
+ Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ output_type (`str`, *optional*, defaults to `"pt"`):
+ The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
+ `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
+ model (LDM) output.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
+ """
+ # 0. Convert audio input length from seconds to latent length
+ downsample_ratio = self.vae.hop_length
+
+ max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
+ if audio_end_in_s is None:
+ audio_end_in_s = max_audio_length_in_s
+
+ if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
+ raise ValueError(
+ f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
+ )
+
+ waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
+ waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
+ waveform_length = int(self.transformer.config.sample_size)
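+ # Note: waveform_start / waveform_end index audio samples at the VAE sampling rate,
+ # while waveform_length is the latent sequence length fed to the transformer.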
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ audio_start_in_s,
+ audio_end_in_s,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ attention_mask,
+ negative_attention_mask,
+ initial_audio_waveforms,
+ initial_audio_sampling_rate,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ attention_mask,
+ negative_attention_mask,
+ )
+
+ # Encode duration
+ seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
+ audio_start_in_s,
+ audio_end_in_s,
+ device,
+ do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
+ batch_size,
+ )
+
+ # Create text_audio_duration_embeds and audio_duration_embeds
+ text_audio_duration_embeds = torch.cat(
+ [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
+ )
+
+ audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
+
+ # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
+ # to concatenate it to the embeddings
+ if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
+ negative_text_audio_duration_embeds = torch.zeros_like(
+ text_audio_duration_embeds, device=text_audio_duration_embeds.device
+ )
+ text_audio_duration_embeds = torch.cat(
+ [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
+ )
+ audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
+
+ bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
+ # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
+ text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
+ text_audio_duration_embeds = text_audio_duration_embeds.view(
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
+ )
+
+ audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
+ audio_duration_embeds = audio_duration_embeds.view(
+ bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_vae = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_waveforms_per_prompt,
+ num_channels_vae,
+ waveform_length,
+ text_audio_duration_embeds.dtype,
+ device,
+ generator,
+ latents,
+ initial_audio_waveforms,
+ num_waveforms_per_prompt,
+ audio_channels=self.vae.config.audio_channels,
+ )
+
+ # 6. Prepare extra step kwargs
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare rotary positional embedding
+ rotary_embedding = get_1d_rotary_pos_embed(
+ self.rotary_embed_dim,
+ latents.shape[2] + audio_duration_embeds.shape[1],
+ use_real=True,
+ repeat_interleave_real=False,
+ )
+
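+ # Broadcast the rotary cos/sin tables over the batch and attention-head dimensions.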
+ cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16)
+ sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16)
+ rotary_embedding = (cos, sin)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.transformer(
+ i,
+ latent_model_input,
+ t.unsqueeze(0),
+ encoder_hidden_states=text_audio_duration_embeds,
+ global_hidden_states=audio_duration_embeds,
+ rotary_embedding=rotary_embedding,
+ return_dict=False,
+ use_cache=use_cache,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # 9. Post-processing
+ if not output_type == "latent":
+ audio = self.vae.decode(latents).sample
+ else:
+ return AudioPipelineOutput(audios=latents)
+
+ audio = audio[:, :, waveform_start:waveform_end]
+
+ if output_type == "np":
+ audio = audio.cpu().float().numpy()
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (audio,)
+
+ return AudioPipelineOutput(audios=audio)
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bad3d9b150d1826ed727fb3fe27c446b443d2aa
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/__init__.py
@@ -0,0 +1 @@
+from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e7805d1346cd3c6c7a9883127c037942170f27
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -0,0 +1,574 @@
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
+
+
+class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021).
+ This scheduler was used in Stable Audio Open [1].
+
+ [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ sigma_min (`float`, *optional*, defaults to 0.3):
+ Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
+ sigma_max (`float`, *optional*, defaults to 500):
+ Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
+ sigma_data (`float`, *optional*, defaults to 1.0):
+ The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
+ sigma_schedule (`str`, *optional*, defaults to `exponential`):
+ Sigma schedule used to compute the `sigmas`. Acceptable values are `"karras"` (the schedule introduced in
+ the EDM paper, https://arxiv.org/abs/2206.00364) and `"exponential"`. The exponential schedule was
+ incorporated in this model: https://huggingface.co/stabilityai/cosxl.
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ solver_order (`int`, defaults to 2):
+ The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
+ prediction_type (`str`, defaults to `v_prediction`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ solver_type (`str`, defaults to `midpoint`):
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ lower_order_final (`bool`, defaults to `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ euler_at_final (`bool`, defaults to `False`):
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
+ steps, but sometimes may result in blurring.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ """
+
+ _compatibles = []
+ order = 1
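+
+ # A minimal standalone usage sketch (illustrative; `sample` and `model_output` stand in for the
+ # pipeline's latents and the transformer prediction):
+ #
+ # scheduler = CosineDPMSolverMultistepScheduler()
+ # scheduler.set_timesteps(num_inference_steps=100, device="cpu")
+ # for t in scheduler.timesteps:
+ # model_input = scheduler.scale_model_input(sample, t)
+ # model_output = ... # transformer forward pass on model_input
+ # sample = scheduler.step(model_output, t, sample).prev_sample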
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.3,
+ sigma_max: float = 500,
+ sigma_data: float = 1.0,
+ sigma_schedule: str = "exponential",
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "v_prediction",
+ rho: float = 7.0,
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ euler_at_final: bool = False,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ ):
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ ramp = torch.linspace(0, 1, num_train_timesteps)
+ if sigma_schedule == "karras":
+ sigmas = self._compute_karras_sigmas(ramp)
+ elif sigma_schedule == "exponential":
+ sigmas = self._compute_exponential_sigmas(ramp)
+
+ self.timesteps = self.precondition_noise(sigmas)
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ # setable values
+ self.num_inference_steps = None
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ return (self.config.sigma_max**2 + 1) ** 0.5
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
+ def precondition_inputs(self, sample, sigma):
+ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ scaled_sample = sample * c_in
+ return scaled_sample
+
+ def precondition_noise(self, sigma):
+ if not isinstance(sigma, torch.Tensor):
+ sigma = torch.tensor([sigma])
+
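+ # Cosine schedule: map the noise level to a timestep in [0, 1) via t = (2 / pi) * arctan(sigma).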
+ return sigma.atan() / math.pi * 2
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
+ def precondition_outputs(self, sample, model_output, sigma):
+ sigma_data = self.config.sigma_data
+ c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+
+ if self.config.prediction_type == "epsilon":
+ c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+ elif self.config.prediction_type == "v_prediction":
+ c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+ else:
+ raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
+
+ denoised = c_skip * sample + c_out * model_output
+
+ return denoised
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+ sample = self.precondition_inputs(sample, sigma)
+
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ self.num_inference_steps = num_inference_steps
+
+ ramp = torch.linspace(0, 1, self.num_inference_steps)
+ if self.config.sigma_schedule == "karras":
+ sigmas = self._compute_karras_sigmas(ramp)
+ elif self.config.sigma_schedule == "exponential":
+ sigmas = self._compute_exponential_sigmas(ramp)
+
+ sigmas = sigmas.to(dtype=torch.float32, device=device)
+ self.timesteps = self.precondition_noise(sigmas)
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = self.config.sigma_min
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # add an index counter for schedulers that allow duplicated timesteps
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ # if a noise sampler is used, reinitialise it
+ self.noise_sampler = None
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
+ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+ sigma_min = sigma_min or self.config.sigma_min
+ sigma_max = sigma_max or self.config.sigma_max
+
+ rho = self.config.rho
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
+ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
+ """Implementation closely follows k-diffusion.
+
+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+ """
+ sigma_min = sigma_min or self.config.sigma_min
+ sigma_max = sigma_max or self.config.sigma_max
+ sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into the denoising model, so alpha_t = 1
+ sigma_t = sigma
+
+ return alpha_t, sigma_t
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ sample: torch.Tensor = None,
+ ) -> torch.Tensor:
+ """
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
+ integral of the data prediction model.
+
+
+ The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
+ prediction and data prediction models.
+
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ sigma = self.sigmas[self.step_index]
+ x0_pred = self.precondition_outputs(sample, model_output, sigma)
+
+ return x0_pred
+
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.Tensor,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ One step for the first-order DPMSolver (equivalent to DDIM).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
+
+ h = lambda_t - lambda_s
+ if noise is None:
+ raise ValueError("noise must not be None")
+ x_t = (
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+
+ return x_t
+
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ One step for the second-order multistep DPMSolver.
+
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from learned diffusion model at current and latter timesteps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ sigma_t, sigma_s0, sigma_s1 = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1],
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+
+ # sde-dpmsolver++
+ if noise is None:
+ raise ValueError("noise must not be None")
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+
+ return x_t
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ index_candidates = (schedule_timesteps == timestep).nonzero()
+
+ if len(index_candidates) == 0:
+ step_index = len(self.timesteps) - 1
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ elif len(index_candidates) > 1:
+ step_index = index_candidates[1].item()
+ else:
+ step_index = index_candidates[0].item()
+
+ return step_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep DPMSolver.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Improve numerical stability for small number of steps
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+ self.config.euler_at_final
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
+ or self.config.final_sigmas_type == "zero"
+ )
+ lower_order_second = (
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, sample=sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ if self.noise_sampler is None:
+ seed = None
+ if generator is not None:
+ seed = (
+ [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
+ )
+ self.noise_sampler = BrownianTreeNoiseSampler(
+ model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
+ )
+ noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
+ model_output.device
+ )
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py
new file mode 100644
index 0000000000000000000000000000000000000000..8533b082f5c1dd5c07b52d896ec14f4b53f853e8
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_dpmsolver_sde.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from .scheduling_utils import BrownianTree
+
+
+class BatchedBrownianTree:
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
+ t0, t1, self.sign = self.sort(t0, t1)
+ w0 = kwargs.get("w0", torch.zeros_like(x))
+ if seed is None:
+ seed = torch.randint(0, 2**63 - 1, []).item()
+ self.batched = True
+ try:
+ if len(seed) != x.shape[0]:
+ raise ValueError(f"len(seed) should equal x.shape[0], but got {len(seed)}")
+ w0 = w0[0]
+ except TypeError:
+ seed = [seed]
+ self.batched = False
+ self.trees = [BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+ @staticmethod
+ def sort(a, b):
+ return (a, b, 1) if a < b else (b, a, -1)
+
+ def __call__(self, t0, t1):
+ t0, t1, sign = self.sort(t0, t1)
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+ return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+ """A noise sampler backed by a torchsde.BrownianTree.
+
+ Args:
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
+ random samples.
+ sigma_min (float): The low end of the valid interval.
+ sigma_max (float): The high end of the valid interval.
+ seed (int or List[int]): The random seed. If a list of seeds is
+ supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
+ with its own seed.
+ transform (callable): A function that maps sigma to the sampler's
+ internal timestep.
+ """
+
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+ self.transform = transform
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+ def __call__(self, sigma, sigma_next):
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
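+ # Normalise the Brownian increment by sqrt(|t1 - t0|) so the returned noise has unit variance.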
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4734b052f51c53c2399d20b95589e3cd122e8c9
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/StableAudio_Comfyui_Plugin/stableaudio/schedulers/scheduling_utils.py
@@ -0,0 +1,893 @@
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+import trampoline
+
+import numpy as np
+import torch
+
+from torchsde._brownian import brownian_base
+from torchsde.settings import LEVY_AREA_APPROXIMATIONS
+from torchsde.types import Scalar, Optional, Tuple, Union, Tensor
+
+_rsqrt3 = 1 / math.sqrt(3)
+_r12 = 1 / 12
+
+
+def _randn(size, dtype, device, seed):
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+def _is_scalar(x):
+ return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1)
+
+
+def _assert_floating_tensor(name, tensor):
+ if not torch.is_tensor(tensor):
+ raise ValueError(f"{name}={tensor} should be a Tensor.")
+ if not tensor.is_floating_point():
+ raise ValueError(f"{name}={tensor} should be floating point.")
+
+
+def _check_tensor_info(*tensors, size, dtype, device):
+ """Check if sizes, dtypes, and devices of input tensors all match prescribed values."""
+ tensors = list(filter(torch.is_tensor, tensors))
+
+ if dtype is None and len(tensors) == 0:
+ dtype = torch.get_default_dtype()
+ if device is None and len(tensors) == 0:
+ device = torch.device("cpu")
+
+ sizes = [] if size is None else [size]
+ sizes += [t.shape for t in tensors]
+
+ dtypes = [] if dtype is None else [dtype]
+ dtypes += [t.dtype for t in tensors]
+
+ devices = [] if device is None else [device]
+ devices += [t.device for t in tensors]
+
+ if len(sizes) == 0:
+ raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.")
+
+ if not all(i == sizes[0] for i in sizes):
+ raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.")
+ if not all(i == dtypes[0] for i in dtypes):
+ raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.")
+ if not all(i == devices[0] for i in devices):
+ raise ValueError("Multiple devices found. Make sure `device` and `W` or `H` are consistent.")
+
+ # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes.
+ return tuple(sizes[0]), dtypes[0], devices[0]
+
+
+def _davie_foster_approximation(W, H, h, levy_area_approximation, get_noise):
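+ # Approximate the full Levy area from the Brownian increment W and the space-time Levy area H,
+ # following Davie's approximation with Foster's optional variance correction.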
+ if levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.none, LEVY_AREA_APPROXIMATIONS.space_time):
+ return None
+ elif W.ndimension() in (0, 1):
+ # If we have zero or one dimensions then treat the scalar / single dimension we have as batch, so that the
+ # Brownian motion is one dimensional and the Levy area is zero.
+ return torch.zeros_like(W)
+ else:
+ # Davie's approximation to the Levy area from space-time Levy area
+ A = H.unsqueeze(-1) * W.unsqueeze(-2) - W.unsqueeze(-1) * H.unsqueeze(-2)
+ noise = get_noise()
+ noise = noise - noise.transpose(-1, -2) # noise is skew symmetric of variance 2
+ if levy_area_approximation == LEVY_AREA_APPROXIMATIONS.foster:
+ # Foster's additional correction to Davie's approximation
+ tenth_h = 0.1 * h
+ H_squared = H ** 2
+ std = (tenth_h * (tenth_h + H_squared.unsqueeze(-1) + H_squared.unsqueeze(-2))).sqrt()
+ else: # davie approximation
+ std = math.sqrt(_r12 * h ** 2)
+ a_tilde = std * noise
+ A += a_tilde
+ return A
+
+
+def _H_to_U(W: torch.Tensor, H: torch.Tensor, h: float) -> torch.Tensor:
+ return h * (.5 * W + H)
+
+
+class _EmptyDict:
+ def __setitem__(self, key, value):
+ pass
+
+ def __getitem__(self, item):
+ raise KeyError
+
+
+class _LRUDict(dict):
+ def __init__(self, max_size):
+ super().__init__()
+ self._max_size = max_size
+ self._keys = []
+
+ def __setitem__(self, key, value):
+ if key in self:
+ self._keys.remove(key)
+ elif len(self) >= self._max_size:
+ del self[self._keys.pop(0)]
+ super().__setitem__(key, value)
+ self._keys.append(key)
+
+
+class _Interval:
+ # Intervals correspond to some subinterval of the overall interval [t0, t1].
+ # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left
+ # and right subintervals, which partition the parent interval.
+
+ __slots__ = (
+ # These are the things that every interval has
+ '_start',
+ '_end',
+ '_parent',
+ '_is_left',
+ '_top',
+ # These are the things that intervals which are parents also have
+ '_midway',
+ '_spawn_key',
+ '_depth',
+ '_W_seed',
+ '_H_seed',
+ '_left_a_seed',
+ '_right_a_seed',
+ '_left_child',
+ '_right_child')
+
+ def __init__(self, start, end, parent, is_left, top):
+ self._start = top._round(start) # the left hand edge of the interval
+ self._end = top._round(end) # the right hand edge of the interval
+ self._parent = parent # our parent interval
+ self._is_left = is_left # are we the left or right child of our parent
+ self._top = top # the top-level BrownianInterval, where we cache certain state
+ self._midway = None # The point at which we split between left and right subintervals
+
+ ########################################
+ # Calculate increments and levy area #
+ ########################################
+ #
+ # This is a little bit convoluted, so here's an explanation.
+ #
+ # The entry point is _increment_and_levy_area, below. This immediately calls _increment_and_space_time_levy_area,
+ # applies the space-time to full Levy area correction, and then returns.
+ #
+ # _increment_and_space_time_levy_area in turn calls a central LRU cache, as (later on) we'll need the increment and
+ # space-time Levy area of the parent interval to compute our own increment and space-time Levy area, and it's likely
+ # that our parent exists in the cache, as if we're being queried then our parent was probably queried recently as
+ # well.
+ # (The top-level BrownianInterval overrides _increment_and_space_time_levy_area to return its own increment and
+ # space-time Levy area, effectively holding them permanently in the cache.)
+ #
+ # If the request isn't found in the LRU cache then it computes it from its parent.
+ # Now it turns out that the size of our increment and space-time Levy area is really most naturally thought of as a
+ # property of our parent: it depends on our parent's increment, space-time Levy area, and whether we are the left or
+ # right interval within our parent. So _increment_and_space_time_levy_area in turn checks if we are on the
+ # left or right of our parent and does most of the computation using the parent's attributes.
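+    #
+    # Rough call flow (illustrative sketch only):
+    #     _increment_and_levy_area
+    #         -> _increment_and_space_time_levy_area   (LRU-cached; recurses to the parent on a cache miss)
+    #         -> _davie_foster_approximation            (space-time Levy area -> full Levy area, if requested)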
+
+ def _increment_and_levy_area(self):
+ W, H = trampoline.trampoline(self._increment_and_space_time_levy_area())
+ A = _davie_foster_approximation(W, H, self._end - self._start, self._top._levy_area_approximation,
+ self._randn_levy)
+ return W, H, A
+
+ def _increment_and_space_time_levy_area(self):
+ try:
+ return self._top._increment_and_space_time_levy_area_cache[self]
+ except KeyError:
+ parent = self._parent
+
+ W, H = yield parent._increment_and_space_time_levy_area()
+ h_reciprocal = 1 / (parent._end - parent._start)
+ left_diff = parent._midway - parent._start
+ right_diff = parent._end - parent._midway
+
+ if self._top._have_H:
+ left_diff_squared = left_diff ** 2
+ right_diff_squared = right_diff ** 2
+ left_diff_cubed = left_diff * left_diff_squared
+ right_diff_cubed = right_diff * right_diff_squared
+
+ v = 0.5 * math.sqrt(left_diff * right_diff / (left_diff_cubed + right_diff_cubed))
+
+ a = v * left_diff_squared * h_reciprocal
+ b = v * right_diff_squared * h_reciprocal
+ c = v * _rsqrt3
+
+ X1 = parent._randn(parent._W_seed)
+ X2 = parent._randn(parent._H_seed)
+
+ third_coeff = 2 * (a * left_diff + b * right_diff) * h_reciprocal
+
+ if self._is_left:
+ first_coeff = left_diff * h_reciprocal
+ second_coeff = 6 * first_coeff * right_diff * h_reciprocal
+ out_W = first_coeff * W + second_coeff * H + third_coeff * X1
+ out_H = first_coeff ** 2 * H - a * X1 + c * right_diff * X2
+ else:
+ first_coeff = right_diff * h_reciprocal
+ second_coeff = 6 * first_coeff * left_diff * h_reciprocal
+ out_W = first_coeff * W - second_coeff * H - third_coeff * X1
+ out_H = first_coeff ** 2 * H - b * X1 - c * left_diff * X2
+ else:
+ # Don't compute space-time Levy area unless we need to
+
+ mean = left_diff * h_reciprocal * W
+ var = left_diff * right_diff * h_reciprocal
+ noise = parent._randn(parent._W_seed)
+ left_W = mean + math.sqrt(var) * noise
+
+ if self._is_left:
+ out_W = left_W
+ else:
+ out_W = W - left_W
+ out_H = None
+
+ self._top._increment_and_space_time_levy_area_cache[self] = (out_W, out_H)
+ return out_W, out_H
+
+ def _randn(self, seed):
+ # We generate random noise deterministically wrt some seed; this seed is determined by the generator.
+ # This means that if we drop out of the cache, then we'll create the same random noise next time, as we still
+ # have the generator.
+ size = self._top._size
+ return _randn(size, self._top._dtype, self._top._device, seed)
+
+ def _a_seed(self):
+ return self._parent._left_a_seed if self._is_left else self._parent._right_a_seed
+
+ def _randn_levy(self):
+ size = (*self._top._size, *self._top._size[-1:])
+ return _randn(size, self._top._dtype, self._top._device, self._a_seed())
+
+ ########################################
+ # Locate an interval in the hierarchy #
+ ########################################
+ #
+ # The other important piece of this construction is a way to locate any given interval within the binary tree
+ # hierarchy. (This is typically the slightly slower part, actually, so if you want to speed things up then this is
+ # the bit to target.)
+ #
+ # loc finds the interval [ta, tb] - and creates it in the appropriate place (as a child of some larger interval) if
+ # it doesn't already exist. As in principle we may request an interval that covers multiple existing intervals, then
+ # in fact the interval [ta, tb] is returned as an ordered list of existing subintervals.
+ #
+ # It calls _loc, which operates recursively. See _loc for more details on how the search works.
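+    #
+    # For example (illustrative): on an interval covering [0, 1], a query
+    #     self._loc(0.2, 0.7)
+    # returns an ordered list of subintervals whose union is exactly [0.2, 0.7], e.g.
+    # [Interval(0.2, 0.5), Interval(0.5, 0.7)] depending on how the tree has previously been split,
+    # creating any missing subintervals on the fly.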
+
+ def _loc(self, ta, tb):
+ out = []
+ ta = self._top._round(ta)
+ tb = self._top._round(tb)
+ trampoline.trampoline(self._loc_inner(ta, tb, out))
+ return out
+
+ def _loc_inner(self, ta, tb, out):
+ # Expect to have ta < tb
+
+ # First, we (this interval) only have jurisdiction over [self._start, self._end]. So if we're asked for
+ # something outside of that then we pass the buck up to our parent, who is strictly larger.
+ if ta < self._start or tb > self._end:
+ raise trampoline.TailCall(self._parent._loc_inner(ta, tb, out))
+
+ # If it's us that's being asked for, then we add ourselves on to out and return.
+ if ta == self._start and tb == self._end:
+ out.append(self)
+ return
+
+ # If we've got this far then we know that it's an interval that's within our jurisdiction, and that it's not us.
+ # So next we check if it's up to us to figure out, or up to our children.
+ if self._midway is None:
+ # It's up to us. Create subintervals (_split) if appropriate.
+ if ta == self._start:
+ self._split(tb)
+ raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out))
+ # implies ta > self._start
+ self._split(ta)
+ # Query our (newly created) right_child: if tb == self._end then our right child will be the result, and it
+ # will tell us so. But if tb < self._end then our right_child will need to make another split of its own.
+ raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out))
+
+ # If we're here then we have children: self._midway is not None
+ if tb <= self._midway:
+ # Strictly our left_child's problem
+ raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out))
+ if ta >= self._midway:
+ # Strictly our right_child's problem
+ raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out))
+ # It's a problem for both of our children: the requested interval overlaps our midpoint. Call the left_child
+ # first (to append to out in the correct order), then call our right child.
+ # (Implies ta < self._midway < tb)
+ yield self._left_child._loc_inner(ta, self._midway, out)
+ raise trampoline.TailCall(self._right_child._loc_inner(self._midway, tb, out))
+
+ def _set_spawn_key_and_depth(self):
+ self._spawn_key = 2 * self._parent._spawn_key + (0 if self._is_left else 1)
+ self._depth = self._parent._depth + 1
+
+ def _split(self, midway):
+ if self._top._halfway_tree:
+ self._split_exact(0.5 * (self._end + self._start))
+ # self._midway is now the rounded halfway point.
+ if midway > self._midway:
+ self._right_child._split(midway)
+ elif midway < self._midway:
+ self._left_child._split(midway)
+ else:
+ self._split_exact(midway)
+
+ def _split_exact(self, midway): # Create two children
+ self._midway = self._top._round(midway)
+ # Use splittable PRNGs to generate noise.
+ self._set_spawn_key_and_depth()
+ generator = np.random.SeedSequence(entropy=self._top._entropy,
+ spawn_key=(self._spawn_key, self._depth),
+ pool_size=self._top._pool_size)
+ self._W_seed, self._H_seed, self._left_a_seed, self._right_a_seed = generator.generate_state(4)
+
+ self._left_child = _Interval(start=self._start,
+ end=midway,
+ parent=self,
+ is_left=True,
+ top=self._top)
+ self._right_child = _Interval(start=midway,
+ end=self._end,
+ parent=self,
+ is_left=False,
+ top=self._top)
+
+
+class BrownianInterval(brownian_base.BaseBrownian, _Interval):
+ """Brownian interval with fixed entropy.
+
+ Computes increments (and optionally Levy area).
+
+ To use:
+ >>> bm = BrownianInterval(t0=0.0, t1=1.0, size=(4, 1), device='cuda')
+ >>> bm(0., 0.5)
+ tensor([[ 0.0733],
+ [-0.5692],
+ [ 0.1872],
+ [-0.3889]], device='cuda:0')
+ """
+
+ __slots__ = (
+ # Inputs
+ '_size',
+ '_dtype',
+ '_device',
+ '_entropy',
+ '_levy_area_approximation',
+ '_dt',
+ '_tol',
+ '_pool_size',
+ '_cache_size',
+ '_halfway_tree',
+ # Quantisation
+ '_round',
+ # Caching, searching and computing values
+ '_increment_and_space_time_levy_area_cache',
+ '_last_interval',
+ '_have_H',
+ '_have_A',
+ '_w_h',
+ '_top_a_seed',
+ # Dependency tree creation
+ '_average_dt',
+ '_tree_dt',
+ '_num_evaluations'
+ )
+
+ def __init__(self,
+ t0: Optional[Scalar] = 0.,
+ t1: Optional[Scalar] = 1.,
+ size: Optional[Tuple[int, ...]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ entropy: Optional[int] = None,
+ dt: Optional[Scalar] = None,
+ tol: Scalar = 0.,
+ pool_size: int = 8,
+ cache_size: Optional[int] = 45,
+ halfway_tree: bool = False,
+ levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none,
+ W: Optional[Tensor] = None,
+ H: Optional[Tensor] = None):
+ """Initialize the Brownian interval.
+
+ Args:
+ t0 (float or Tensor): Initial time.
+ t1 (float or Tensor): Terminal time.
+ size (tuple of int): The shape of each Brownian sample.
+ If zero dimensional represents a scalar Brownian motion.
+ If one dimensional represents a batch of scalar Brownian motions.
+                If two or more dimensional, the last dimension represents the size
+                of a multidimensional Brownian motion, and all previous dimensions
+                represent batch dimensions.
+ dtype (torch.dtype): The dtype of each Brownian sample.
+ Defaults to the PyTorch default.
+ device (str or torch.device): The device of each Brownian sample.
+ Defaults to the CPU.
+ entropy (int): Global seed, defaults to `None` for random entropy.
+ levy_area_approximation (str): Whether to also approximate Levy
+ area. Defaults to 'none'. Valid options are 'none',
+ 'space-time', 'davie' or 'foster', corresponding to different
+ approximation types.
+ This is needed for some higher-order SDE solvers.
+ dt (float or Tensor): The expected average step size of the SDE
+ solver. Set it if you know it (e.g. when using a fixed-step
+ solver); else it will be estimated from the first few queries.
+ This is used to set up the data structure such that it is
+ efficient to query at these intervals.
+ tol (float or Tensor): What tolerance to resolve the Brownian motion
+ to. Must be non-negative. Defaults to zero, i.e. floating point
+ resolution. Usually worth setting in conjunction with
+ `halfway_tree`, below.
+ pool_size (int): Size of the pooled entropy. If you care about
+ statistical randomness then increasing this will help (but will
+ slow things down).
+ cache_size (int): How big a cache of recent calculations to use.
+ (As new calculations depend on old calculations, this speeds
+ things up dramatically, rather than recomputing things.)
+ Set this to `None` to use an infinite cache, which will be fast
+ but memory inefficient.
+ halfway_tree (bool): Whether the dependency tree (the internal data
+ structure) should be the dyadic tree. Defaults to `False`.
+ Normally, the sample path is determined by both `entropy`,
+ _and_ the locations and order of the query points. Setting this
+ to `True` will make it deterministic with respect to just
+ `entropy`; however this is much slower.
+ W (Tensor): The increment of the Brownian motion over the interval
+ [t0, t1]. Will be generated randomly if not provided.
+ H (Tensor): The space-time Levy area of the Brownian motion over the
+ interval [t0, t1]. Will be generated randomly if not provided.
+ """
+
+ #####################################
+ # Check and normalise inputs #
+ #####################################
+
+ if not _is_scalar(t0):
+ raise ValueError('Initial time t0 should be a float or 0-d torch.Tensor.')
+ if not _is_scalar(t1):
+ raise ValueError('Terminal time t1 should be a float or 0-d torch.Tensor.')
+ if dt is not None and not _is_scalar(dt):
+ raise ValueError('Expected average time step dt should be a float or 0-d torch.Tensor.')
+
+ if t0 > t1:
+ raise ValueError(f'Initial time {t0} should be less than terminal time {t1}.')
+ t0 = float(t0)
+ t1 = float(t1)
+ if dt is not None:
+ dt = float(dt)
+
+ if halfway_tree:
+ if tol <= 0.:
+ raise ValueError("`tol` should be positive.")
+ if dt is not None:
+ raise ValueError("`dt` is not used and should be set to `None` if `halfway_tree` is True.")
+ else:
+ if tol < 0.:
+ raise ValueError("`tol` should be non-negative.")
+
+ size, dtype, device = _check_tensor_info(W, H, size=size, dtype=dtype, device=device)
+
+ # Let numpy dictate randomness, so we have fewer seeds to set for reproducibility.
+ if entropy is None:
+ entropy = np.random.randint(0, 2 ** 31 - 1)
+
+ if levy_area_approximation not in LEVY_AREA_APPROXIMATIONS:
+ raise ValueError(f"`levy_area_approximation` must be one of {LEVY_AREA_APPROXIMATIONS}, but got "
+ f"'{levy_area_approximation}'.")
+
+ #####################################
+ # Record inputs #
+ #####################################
+
+ self._size = size
+ self._dtype = dtype
+ self._device = device
+ self._entropy = entropy
+ self._levy_area_approximation = levy_area_approximation
+ self._dt = dt
+ self._tol = tol
+ self._pool_size = pool_size
+ self._cache_size = cache_size
+ self._halfway_tree = halfway_tree
+
+ #####################################
+ # A miscellany of other things #
+ #####################################
+
+ # We keep a cache of recent queries, and their results. This is very important for speed, so that we don't
+ # recurse all the way up to the top every time we have a query.
+ if cache_size is None:
+ self._increment_and_space_time_levy_area_cache = {}
+ elif cache_size == 0:
+ self._increment_and_space_time_levy_area_cache = _EmptyDict()
+ else:
+ self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size)
+
+ # We keep track of the most recently queried interval, and start searching for the next interval from that
+ # element of the binary tree. This is because subsequent queries are likely to be near the most recent query.
+ self._last_interval = self
+
+ # Precompute these as we don't want to spend lots of time checking strings in hot loops.
+ self._have_H = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.space_time,
+ LEVY_AREA_APPROXIMATIONS.davie,
+ LEVY_AREA_APPROXIMATIONS.foster)
+ self._have_A = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.davie,
+ LEVY_AREA_APPROXIMATIONS.foster)
+
+ # If we like we can quantise what level we want to compute the Brownian motion to.
+ if math.isclose(tol, 0., rel_tol=1e-5):
+ def round_func(x):
+ return x
+ else:
+ ndigits = -int(math.log10(tol))
+
+ def round_func(x):
+ return round(x, ndigits)
+
+ self._round = round_func
+
+        # Initialise as _Interval.
+ # (Must come after _round but before _w_h)
+ super(BrownianInterval, self).__init__(start=t0,
+ end=t1,
+ parent=None,
+ is_left=None,
+ top=self)
+
+ # Set the global increment and space-time Levy area
+ generator = np.random.SeedSequence(entropy=entropy, pool_size=pool_size)
+ initial_W_seed, initial_H_seed, top_a_seed = generator.generate_state(3)
+ if W is None:
+ W = self._randn(initial_W_seed) * math.sqrt(t1 - t0)
+ else:
+ _assert_floating_tensor('W', W)
+ if H is None:
+ H = self._randn(initial_H_seed) * math.sqrt((t1 - t0) / 12)
+ else:
+ _assert_floating_tensor('H', H)
+ self._w_h = (W, H)
+ self._top_a_seed = top_a_seed
+
+ if not self._halfway_tree:
+ # We create a binary tree dependency between the points. If we don't do this then the forward pass is still
+ # efficient at O(N), but we end up with a dependency chain stretching along the interval [t0, t1], making
+ # the backward pass O(N^2). By setting up a dependency tree of depth relative to `dt` and `cache_size` we
+ # can instead make both directions O(N log N).
+ self._average_dt = 0
+ self._tree_dt = t1 - t0
+ self._num_evaluations = -100 # start off with a warmup period to get a decent estimate of the average
+ if dt is not None:
+ # Create the dependency tree based on the supplied hint `dt`.
+ self._create_dependency_tree(dt)
+ # If dt is None, then create the dependency tree based on observed statistics of query points. (In __call__)
+
+ # Effectively permanently store our increment and space-time Levy area in the cache.
+ def _increment_and_space_time_levy_area(self):
+ return self._w_h
+ yield # make it a generator
+
+ def _a_seed(self):
+ return self._top_a_seed
+
+ def _set_spawn_key_and_depth(self):
+ self._spawn_key = 0
+ self._depth = 0
+
+ def __call__(self, ta, tb=None, return_U=False, return_A=False):
+ if tb is None:
+ warnings.warn(f"{self.__class__.__name__} is optimised for interval-based queries, not point evaluation.")
+ ta, tb = self._start, ta
+ tb_name = 'ta'
+ else:
+ tb_name = 'tb'
+ ta = float(ta)
+ tb = float(tb)
+ if ta < self._start:
+ warnings.warn(f"Should have ta>=t0 but got ta={ta} and t0={self._start}.")
+ ta = self._start
+ if tb < self._start:
+ warnings.warn(f"Should have {tb_name}>=t0 but got {tb_name}={tb} and t0={self._start}.")
+ tb = self._start
+ if ta > self._end:
+ warnings.warn(f"Should have ta<=t1 but got ta={ta} and t1={self._end}.")
+ ta = self._end
+ if tb > self._end:
+ warnings.warn(f"Should have {tb_name}<=t1 but got {tb_name}={tb} and t1={self._end}.")
+ tb = self._end
+ if ta > tb:
+ raise RuntimeError(f"Query times ta={ta:.3f} and tb={tb:.3f} must respect ta <= tb.")
+
+ if ta == tb:
+ W = torch.zeros(self._size, dtype=self._dtype, device=self._device)
+ H = None
+ A = None
+ if self._have_H:
+ H = torch.zeros(self._size, dtype=self._dtype, device=self._device)
+ if self._have_A:
+ size = (*self._size, *self._size[-1:]) # not self._size[-1] as that may not exist
+ A = torch.zeros(size, dtype=self._dtype, device=self._device)
+ else:
+ if self._dt is None and not self._halfway_tree:
+ self._num_evaluations += 1
+ # We start off with "negative" num evaluations, to give us a small warm-up period at the start.
+ if self._num_evaluations > 0:
+ # Compute average step size so far
+ dt = tb - ta
+ self._average_dt = (dt + self._average_dt * (self._num_evaluations - 1)) / self._num_evaluations
+ if self._average_dt < 0.5 * self._tree_dt:
+ # If 'dt' wasn't specified, then check the average interval length against the size of the
+ # bottom of the dependency tree. If we're below halfway then refine the tree by splitting all
+ # the bottom pieces into two.
+ self._create_dependency_tree(dt)
+
+ # Find the intervals that correspond to the query. We start our search at the last interval we accessed in
+ # the binary tree, as it's likely that the next query will come nearby.
+ intervals = self._last_interval._loc(ta, tb)
+ # Ideally we'd keep track of intervals[0] on the backward pass. Practically speaking len(intervals) tends to
+ # be 1 or 2 almost always so this isn't a huge deal.
+ self._last_interval = intervals[-1]
+
+ W, H, A = intervals[0]._increment_and_levy_area()
+ if len(intervals) > 1:
+ # If we have multiple intervals then add up their increments and Levy areas.
+
+ for interval in intervals[1:]:
+ Wi, Hi, Ai = interval._increment_and_levy_area()
+ if self._have_H:
+ # Aggregate H:
+ # Given s < u < t, then
+ # H_{s,t} = (term1 + term2) / (t - s)
+ # where
+ # term1 = (t - u) * (H_{u, t} + W_{s, u} / 2)
+ # term2 = (u - s) * (H_{s, u} - W_{u, t} / 2)
+ term1 = (interval._end - interval._start) * (Hi + 0.5 * W)
+ term2 = (interval._start - ta) * (H - 0.5 * Wi)
+ H = (term1 + term2) / (interval._end - ta)
+ if self._have_A and len(self._size) not in (0, 1):
+ # If len(self._size) in (0, 1) then we treat our scalar / single dimension as a batch
+ # dimension, so we have zero Levy area. (And these unsqueezes will result in a tensor of shape
+ # (batch, batch) which is wrong.)
+
+ # Let B_{x, y} = \int_x^y W^1_{s,u} dW^2_u.
+ # Then
+ # B_{s, t} = \int_s^t W^1_{s,u} dW^2_u
+ # = \int_s^v W^1_{s,u} dW^2_u + \int_v^t W^1_{s,v} dW^2_u + \int_v^t W^1_{v,u} dW^2_u
+ # = B_{s, v} + W^1_{s, v} W^2_{v, t} + B_{v, t}
+ #
+ # A is now the antisymmetric part of B, which gives the formula below.
+ A = A + Ai + 0.5 * (W.unsqueeze(-1) * Wi.unsqueeze(-2) - Wi.unsqueeze(-1) * W.unsqueeze(-2))
+ W = W + Wi
+
+ U = None
+ if self._have_H:
+ U = _H_to_U(W, H, tb - ta)
+
+ if return_U:
+ if return_A:
+ return W, U, A
+ else:
+ return W, U
+ else:
+ if return_A:
+ return W, A
+ else:
+ return W
+
+ def _create_dependency_tree(self, dt):
+ # For safety we take a min with 100: if people take very large cache sizes then this would then break the
+ # logarithmic into linear, which causes RecursionErrors.
+ if self._cache_size is None: # cache_size=None corresponds to infinite cache.
+ cache_size = 100
+ else:
+ cache_size = min(self._cache_size, 100)
+
+ self._tree_dt = min(self._tree_dt, dt)
+ # Rationale: We are prepared to hold `cache_size` many things in memory, so when making steps of size `dt`
+ # then we can afford to have the intervals at the bottom of our binary tree be of size `dt * cache_size`.
+ # For safety we then make this a bit smaller by multiplying by 0.8.
+ piece_length = self._tree_dt * cache_size * 0.8
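+        # e.g. (illustrative) dt=0.01 with cache_size=45 gives bottom-level pieces of length 0.01 * 45 * 0.8 = 0.36.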
+
+ def _set_points(interval):
+ start = interval._start
+ end = interval._end
+ if end - start > piece_length:
+ midway = (end + start) / 2
+ interval._loc(start, midway)
+ _set_points(interval._left_child)
+ _set_points(interval._right_child)
+
+ _set_points(self)
+
+ def __repr__(self):
+ if self._dt is None:
+ dt = None
+ else:
+ dt = f"{self._dt:.3f}"
+ return (f"{self.__class__.__name__}("
+ f"t0={self._start:.3f}, "
+ f"t1={self._end:.3f}, "
+ f"size={self._size}, "
+ f"dtype={self._dtype}, "
+ f"device={repr(self._device)}, "
+ f"entropy={self._entropy}, "
+ f"dt={dt}, "
+ f"tol={self._tol}, "
+ f"pool_size={self._pool_size}, "
+ f"cache_size={self._cache_size}, "
+ f"levy_area_approximation={repr(self._levy_area_approximation)}"
+ f")")
+
+ def display_binary_tree(self):
+ stack = [(self, 0)]
+ out = []
+ while stack:
+ elem, depth = stack.pop()
+ out.append(" " * depth + f"({elem._start}, {elem._end})")
+ if elem._midway is not None:
+ stack.append((elem._right_child, depth + 1))
+ stack.append((elem._left_child, depth + 1))
+ print("\n".join(out))
+
+ @property
+ def shape(self):
+ return self._size
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def device(self):
+ return self._device
+
+ @property
+ def entropy(self):
+ return self._entropy
+
+ @property
+ def levy_area_approximation(self):
+ return self._levy_area_approximation
+
+ @property
+ def dt(self):
+ return self._dt
+
+ @property
+ def tol(self):
+ return self._tol
+
+ @property
+ def pool_size(self):
+ return self._pool_size
+
+ @property
+ def cache_size(self):
+ return self._cache_size
+
+ @property
+ def halfway_tree(self):
+ return self._halfway_tree
+
+ def size(self):
+ return self._size
+
+
+class BrownianTree(brownian_base.BaseBrownian):
+ """Brownian tree with fixed entropy.
+
+ Useful when the map from entropy -> Brownian motion shouldn't depend on the
+ locations and order of the query points. (As the usual BrownianInterval
+ does - note that BrownianTree is slower as a result though.)
+
+ To use:
+    >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1, device='cuda'))
+ >>> bm(0., 0.5)
+ tensor([[ 0.0733],
+ [-0.5692],
+ [ 0.1872],
+ [-0.3889]], device='cuda:0')
+ """
+
+ def __init__(self, t0: Scalar,
+ w0: Tensor,
+ t1: Optional[Scalar] = None,
+ w1: Optional[Tensor] = None,
+ entropy: Optional[int] = None,
+ tol: float = 1e-6,
+ pool_size: int = 24,
+ cache_depth: int = 9,
+ safety: Optional[float] = None):
+ """Initialize the Brownian tree.
+
+ The random value generation process exploits the parallel random number paradigm and uses
+ `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`).
+
+ Arguments:
+ t0: Initial time.
+ w0: Initial state.
+ t1: Terminal time.
+ w1: Terminal state.
+ entropy: Global seed, defaults to `None` for random entropy.
+            tol: Error tolerance before the binary search is terminated; the search depth scales as log2(1/tol).
+ pool_size: Size of the pooled entropy. This parameter affects the query speed significantly.
+ cache_depth: Unused; deprecated.
+ safety: Unused; deprecated.
+ """
+
+ if t1 is None:
+ t1 = t0 + 1
+ if w1 is None:
+ W = None
+ else:
+ W = w1 - w0
+ self._w0 = w0
+ self._interval = BrownianInterval(t0=t0,
+ t1=t1,
+ size=w0.shape,
+ dtype=w0.dtype,
+ device=w0.device,
+ entropy=entropy,
+ tol=tol,
+ pool_size=pool_size,
+ halfway_tree=True,
+ W=W)
+ super(BrownianTree, self).__init__()
+
+ def __call__(self, t, tb=None, return_U=False, return_A=False):
+ # Deliberately called t rather than ta, for backward compatibility
+ out = self._interval(t, tb, return_U=return_U, return_A=return_A)
+ if tb is None and not return_U and not return_A:
+ out = out + self._w0
+ return out
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(interval={self._interval})"
+
+ @property
+ def dtype(self):
+ return self._interval.dtype
+
+ @property
+ def device(self):
+ return self._interval.device
+
+ @property
+ def shape(self):
+ return self._interval.shape
+
+ @property
+ def levy_area_approximation(self):
+ return self._interval.levy_area_approximation
+
+
+def brownian_interval_like(y: Tensor,
+ t0: Optional[Scalar] = 0.,
+ t1: Optional[Scalar] = 1.,
+ size: Optional[Tuple[int, ...]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ **kwargs):
+ """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor."""
+ size = y.shape if size is None else size
+ dtype = y.dtype if dtype is None else dtype
+ device = y.device if device is None else device
+ return brownian_interval.BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs)
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7b49fb36bd6905ff4984b704e7c090ba9907f2f
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/inference_stableaudio.py
@@ -0,0 +1,161 @@
+import time
+import json
+import os
+import argparse
+
+import torch
+import torch_npu
+import soundfile as sf
+from safetensors.torch import load_file
+
+from stableaudio import (
+ StableAudioPipeline,
+ AutoencoderOobleck,
+)
+from mindiesd import CacheConfig, CacheAgent
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--prompt_file",
+ type=str,
+ default="./prompts/prompts.txt",
+ help="The prompts file to guide audio generation.",
+ )
+ parser.add_argument(
+ "--negative_prompt",
+ type=str,
+ default="",
+ help="The prompt or prompts to guide what to not include in audio generation.",
+ )
+ parser.add_argument(
+ "--num_inference_steps",
+ type=int,
+ default=100,
+        help="The number of denoising steps. More denoising steps usually lead to higher quality audio at the expense of slower inference.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="./stable-audio-open-1.0",
+ help="The path of stable-audio-open-1.0.",
+ )
+ parser.add_argument(
+ "--audio_end_in_s",
+ nargs='+',
+ default=[10],
+        help="Audio end time in seconds (one value per prompt; defaults to 10).",
+ )
+ parser.add_argument(
+ "--device",
+ type=int,
+ default=0,
+ help="NPU device id.",
+ )
+ parser.add_argument(
+ "--save_dir",
+ type=str,
+ default="./results",
+ help="Path to save result audio files.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=-1,
+        help="Random seed; the default -1 means the seed is not fixed.",
+ )
+ parser.add_argument(
+ "--use_cache",
+ action="store_true",
+        help="Whether to turn on the DiT cache.",
+ )
+ parser = add_attentioncache_args(parser)
+ return parser.parse_args()
+
+
+def add_attentioncache_args(parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(title="Attention Cache args")
+ group.add_argument("--use_attentioncache", action='store_true')
+ group.add_argument("--attentioncache_ratio", type=float, default=1.4)
+ group.add_argument("--attentioncache_interval", type=int, default=5)
+ group.add_argument("--start_step", type=int, default=60)
+ group.add_argument("--end_step", type=int, default=97)
+
+ return parser
+
+
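+# Example invocation (illustrative; adjust paths and values to your environment):
+#     python inference_stableaudio.py --model ./stable-audio-open-1.0 \
+#         --prompt_file ./prompts/prompts.txt --num_inference_steps 100 --device 0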
+def main():
+ args = parse_arguments()
+ save_dir = args.save_dir
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ npu_stream = torch_npu.npu.Stream()
+ torch_npu.npu.set_device(args.device)
+ if args.seed != -1:
+ torch.manual_seed(args.seed)
+ latents = torch.randn(1, 64, 1024, dtype=torch.float16, device="cpu").to("npu")
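+    # The fixed fp16 latents above (shape (1, 64, 1024): batch, in_channels, latent length) are reused
+    # for every prompt, so each generation starts from the same initial noise.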
+
+ with open(os.path.join(args.model, "vae", "config.json")) as f:
+ vae_config = json.load(f)
+ vae = AutoencoderOobleck.from_config(vae_config)
+ vae.load_state_dict(load_file(os.path.join(args.model, "vae", "diffusion_pytorch_model.safetensors")))
+
+ pipe = StableAudioPipeline.from_pretrained(args.model, vae=vae)
+ pipe.to(torch.float16).to("npu")
+
+    transformer = pipe.transformer
+ if args.use_attentioncache:
+ config = CacheConfig(
+ method="attention_cache",
+ blocks_count=len(transformer.transformer_blocks),
+ steps_count=args.num_inference_steps,
+ step_start=args.start_step,
+ step_interval=args.attentioncache_interval,
+ step_end=args.end_step
+ )
+ else:
+ config = CacheConfig(
+ method="attention_cache",
+ blocks_count=len(transformer.transformer_blocks),
+ steps_count=args.num_inference_steps
+ )
+ cache = CacheAgent(config)
+ for block in transformer.transformer_blocks:
+ block.cache = cache
+
+ total_time = 0
+ prompts_num = 0
+ average_time = 0
+ skip = 2
+ with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f:
+ for i, prompt in enumerate(f):
+ with torch.no_grad():
+ audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0
+ npu_stream.synchronize()
+ begin = time.time()
+ audio = pipe(
+ prompt=prompt,
+ negative_prompt=args.negative_prompt,
+ num_inference_steps=args.num_inference_steps,
+ latents=latents,
+ audio_end_in_s=audio_end_in_s,
+ use_cache=args.use_cache,
+ ).audios
+ npu_stream.synchronize()
+ end = time.time()
+ if i > skip - 1:
+ total_time += end - begin
+ prompts_num = i + 1
+ output = audio[0].T.float().cpu().numpy()
+ file_path = os.path.join(args.save_dir, f"audio_by_prompt{prompts_num}.wav")
+ sf.write(file_path, output, pipe.vae.sampling_rate)
+ if prompts_num > skip:
+ average_time = total_time / (prompts_num - skip)
+ else:
+        raise ValueError("The average inference time skips the first two prompts; "
+                         "make sure prompts.txt contains at least three prompts.")
+    print(f"Average inference time: {average_time:.3f}s\n")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt b/MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..523a9bcbcf3c2b09423dbbb24df08a2e26eac78e
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/prompts/prompts.txt
@@ -0,0 +1,6 @@
+Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP.
+Uplifting acoustic loop. 120 BPM.
+Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM.
+Warm arpeggios on an analog synthesizer with a gradually rising filter cutoff and a reverb tail.
+Blackbird song, summer, dusk in the forest.
+Rock beat played in a treated studio, session drumming on an acoustic kit.
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt b/MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6353806d3f71e2b739caf8df040a9e852530ffc2
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/requirements.txt
@@ -0,0 +1,6 @@
+torch==2.1.0
+torchsde==0.2.6
+diffusers==0.30.0
+transformers==4.40.0
+soundfile==0.12.1
+torch_npu==2.1.0.post6
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f484d3fcebf98ab394508963b30f278134d7c236
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/__init__.py
@@ -0,0 +1,4 @@
+from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck
+from .pipeline import StableAudioPipeline, StableAudioProjectionModel
+from .models import StableAudioDiTModel, ModelMixin
+from .schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler, SchedulerMixin
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f044203d9a52527a303066545740aaea80cc49
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/__init__.py
@@ -0,0 +1,4 @@
+from .attention import (
+ Attention,
+ StableAudioAttnProcessor2_0,
+)
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..aec5aa7ca009691d8e2b5c15d46ab0afa5fab4d5
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/attention.py
@@ -0,0 +1,373 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch_npu
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import logging
+from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
+from diffusers.models.normalization import FP32LayerNorm, RMSNorm
+from mindiesd import attention_forward, rotary_position_embedding
+from .linear import QKVLinear
+
+logger = logging.get_logger(__name__)
+
+
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ ):
+ super().__init__()
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+
+ # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_qkv = QKVLinear(
+ attention_dim=query_dim,
+ hidden_size=self.inner_dim,
+ qkv_bias=self.use_bias,
+ cross_attention_dim=cross_attention_dim,
+ cross_hidden_size=self.inner_kv_dim,
+ )
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+
+class StableAudioAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
+ """
+
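+    # Typical wiring (illustrative): pass an instance as the `processor` argument of `Attention`,
+    # e.g. Attention(query_dim=dim, heads=num_heads, dim_head=64, processor=StableAudioAttnProcessor2_0());
+    # Attention.forward then dispatches to __call__ below.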
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ query, key, value = attn.to_qkv(hidden_states, encoder_hidden_states)
+
+ head_dim = attn.dim_head
+ kv_heads = attn.inner_kv_dim // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ if kv_heads != attn.heads:
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
+ heads_per_kv_head = attn.heads // kv_heads
+ key = key.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1)
+ key = key.view(batch_size, sequence_length, attn.heads, head_dim)
+ value = value.unsqueeze(3).repeat(1, 1, 1, heads_per_kv_head, 1)
+ value = value.view(batch_size, sequence_length, attn.heads, head_dim)
+
+ # Apply RoPE if needed
+ if rotary_emb is not None:
+ cos, sin = rotary_emb
+ query_to_rotate, query_unrotated = torch.chunk(query, 2, 3)
+ query_rotated = rotary_position_embedding(query_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True)
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
+
+ if not attn.is_cross_attention:
+ key_to_rotate, key_unrotated = torch.chunk(key, 2, 3)
+ key_rotated = rotary_position_embedding(key_to_rotate, cos, sin, rotated_mode="rotated_half", fused=True)
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
+
+ hidden_states = attention_forward(
+ query, key, value,
+ opt_mode="manual",
+ op_type="fused_attn_score",
+ layout="BSND"
+ )
+
+ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e8ee84fc878be54988022ef49e5e8649b3f498c
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/layers/linear.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn as nn
+import torch_npu
+
+
+class QKVLinear(nn.Module):
+ def __init__(self, attention_dim, hidden_size, qkv_bias=True, cross_attention_dim=None, cross_hidden_size=None,
+ device=None, dtype=None):
+ super(QKVLinear, self).__init__()
+ self.attention_dim = attention_dim
+ self.hidden_size = hidden_size
+
+ self.cross_attention_dim = cross_attention_dim
+ self.cross_hidden_size = self.hidden_size if cross_hidden_size is None else cross_hidden_size
+ self.qkv_bias = qkv_bias
+
+ factory_kwargs = {"device": device, "dtype": dtype}
+
+ if cross_attention_dim is None:
+ self.weight = nn.Parameter(torch.empty([self.attention_dim, 3 * self.hidden_size], **factory_kwargs))
+ if self.qkv_bias:
+ self.bias = nn.Parameter(torch.empty([3 * self.hidden_size], **factory_kwargs))
+ else:
+ self.q_weight = nn.Parameter(torch.empty([self.attention_dim, self.hidden_size], **factory_kwargs))
+ self.kv_weight = nn.Parameter(
+ torch.empty([self.cross_attention_dim, 2 * self.cross_hidden_size], **factory_kwargs))
+
+ if self.qkv_bias:
+ self.q_bias = nn.Parameter(torch.empty([self.hidden_size], **factory_kwargs))
+ self.kv_bias = nn.Parameter(torch.empty([2 * self.cross_hidden_size], **factory_kwargs))
+
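+    # Shape sketch (illustrative): for self-attention, an input of shape (batch, seq, attention_dim)
+    # yields q, k, v each of shape (batch, seq, hidden_size); for cross-attention, k and v are taken
+    # from `encoder_hidden_states` and have shape (batch, encoder_seq, cross_hidden_size).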
+ def forward(self, hidden_states, encoder_hidden_states=None):
+
+ if self.cross_attention_dim is None:
+ if not self.qkv_bias:
+ qkv = torch.matmul(hidden_states, self.weight)
+ else:
+ qkv = torch.addmm(
+ self.bias,
+ hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)),
+ self.weight,
+ beta=1,
+ alpha=1
+ )
+
+ batch, seqlen, _ = hidden_states.shape
+ qkv_shape = (batch, seqlen, 3, -1)
+ qkv = qkv.view(qkv_shape)
+ q, k, v = qkv.unbind(2)
+
+ else:
+ if not self.qkv_bias:
+ q = torch.matmul(hidden_states, self.q_weight)
+ kv = torch.matmul(encoder_hidden_states, self.kv_weight)
+ else:
+ q = torch.addmm(
+ self.q_bias,
+ hidden_states.view(hidden_states.size(0) * hidden_states.size(1), hidden_states.size(2)),
+ self.q_weight,
+ beta=1,
+ alpha=1
+ )
+ kv = torch.addmm(
+ self.kv_bias,
+ encoder_hidden_states.view(
+ encoder_hidden_states.size(0) * encoder_hidden_states.size(1),
+ encoder_hidden_states.size(2)),
+ self.kv_weight,
+ beta=1,
+ alpha=1
+ )
+
+ batch, q_seqlen, _ = hidden_states.shape
+ q = q.view(batch, q_seqlen, -1)
+
+ batch, kv_seqlen, _ = encoder_hidden_states.shape
+ kv_shape = (batch, kv_seqlen, 2, -1)
+
+ kv = kv.view(kv_shape)
+ k, v = kv.unbind(2)
+
+ return q, k, v
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4170f5af4c06c05f2ac5f6059a165c3e497f7eb
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/__init__.py
@@ -0,0 +1 @@
+from .stable_audio_transformer import StableAudioDiTModel, ModelMixin
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb06b2adecc79d8118ad4ba3a51bfc01b551b63
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/models/stable_audio_transformer.py
@@ -0,0 +1,431 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Union
+from collections import OrderedDict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import FeedForward
+from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
+
+from ..layers.attention import (
+ Attention,
+ StableAudioAttnProcessor2_0,
+)
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableAudioGaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ del self.weight
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.weight = self.W
+ del self.W
+
+ def forward(self, x):
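+        # Embeds the noise-level input into shape (batch, 2 * embedding_size) using the fixed Gaussian
+        # frequencies; with `log=True` the embedding is computed on log(x).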
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :]
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class StableAudioDiTBlock(nn.Module):
+ r"""
+ Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip
+ connection and QKNorm
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for the query states.
+ num_key_value_attention_heads (`int`): The number of heads to use for the key and value states.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ num_key_value_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ upcast_attention: bool = False,
+ norm_eps: float = 1e-5,
+ ff_inner_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=False,
+ upcast_attention=upcast_attention,
+ out_bias=False,
+ processor=StableAudioAttnProcessor2_0(),
+ )
+
+ # 2. Cross-Attn
+ self.norm2 = nn.LayerNorm(dim, norm_eps, True)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ kv_heads=num_key_value_attention_heads,
+ dropout=dropout,
+ bias=False,
+ upcast_attention=upcast_attention,
+ out_bias=False,
+ processor=StableAudioAttnProcessor2_0(),
+ ) # is self-attn if encoder_hidden_states is none
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim, norm_eps, True)
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn="swiglu",
+ final_dropout=False,
+ inner_dim=ff_inner_dim,
+ bias=True,
+ )
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
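+        # `self.cache` is only initialized to None here; it is expected to be assigned from outside this
+        # module (e.g. by a runtime cache helper) before `forward` calls `self.cache.apply(...)`.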
+ self.cache = None
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ rotary_embedding: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.cache.apply(
+ self.attn1,
+ norm_hidden_states,
+ attention_mask=attention_mask,
+ rotary_emb=rotary_embedding,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention
+ norm_hidden_states = self.norm2(hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 3. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = ff_output + hidden_states
+
+ return hidden_states
+
+
+class StableAudioDiTModel(ModelMixin, ConfigMixin):
+ """
+ The Diffusion Transformer model introduced in Stable Audio.
+
+ Reference: https://github.com/Stability-AI/stable-audio-tools
+
+ Parameters:
+ sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states.
+ num_key_value_attention_heads (`int`, *optional*, defaults to 12):
+ The number of heads to use for the key and value states.
+ out_channels (`int`, defaults to 64): Number of output channels.
+ cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection.
+ time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection.
+ global_states_input_dim ( `int`, *optional*, defaults to 1536):
+ Input dimension of the global hidden states projection.
+ cross_attention_input_dim ( `int`, *optional*, defaults to 768):
+ Input dimension of the cross-attention projection
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 1024,
+ in_channels: int = 64,
+ num_layers: int = 24,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 24,
+ num_key_value_attention_heads: int = 12,
+ out_channels: int = 64,
+ cross_attention_dim: int = 768,
+ time_proj_dim: int = 256,
+ global_states_input_dim: int = 1536,
+ cross_attention_input_dim: int = 768,
+ ):
+ super().__init__()
+
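+        # Delta-cache hyperparameters: starting from denoising step `cache_step_start`, the outputs of
+        # transformer blocks [cache_block_start, cache_block_start + cache_num_blocks) are recomputed only
+        # every `cache_step_interval` steps; on the remaining steps the stored output delta is reused (see forward).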
+ self.cache_block_start = 11
+ self.cache_step_interval = 2
+ self.cache_num_blocks = 9
+ self.cache_step_start = 5
+
+ self.num_layers = num_layers
+
+ self.sample_size = sample_size
+ self.out_channels = out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.init_dtype = self.dtype
+ self.time_proj = StableAudioGaussianFourierProjection(
+ embedding_size=time_proj_dim // 2,
+ flip_sin_to_cos=True,
+ log=False,
+ set_W_to_weight=False,
+ )
+
+ self.timestep_proj = nn.Sequential(
+ nn.Linear(time_proj_dim, self.inner_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(self.inner_dim, self.inner_dim, bias=True),
+ )
+
+ self.global_proj = nn.Sequential(
+ nn.Linear(global_states_input_dim, self.inner_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(self.inner_dim, self.inner_dim, bias=False),
+ )
+
+ self.cross_attention_proj = nn.Sequential(
+ nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(cross_attention_dim, cross_attention_dim, bias=False),
+ )
+
+ self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False)
+ self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ StableAudioDiTBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ num_key_value_attention_heads=num_key_value_attention_heads,
+ attention_head_dim=attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False)
+ self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ step_id,
+ hidden_states: torch.FloatTensor,
+ timestep: torch.LongTensor = None,
+ encoder_hidden_states: torch.FloatTensor = None,
+ global_hidden_states: torch.FloatTensor = None,
+ rotary_embedding: torch.FloatTensor = None,
+ return_dict: bool = True,
+ attention_mask: Optional[torch.LongTensor] = None,
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = False,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`StableAudioDiTModel`] forward method.
+
+        Args:
+            step_id (`int`):
+                Index of the current denoising step, used to decide when the transformer-block delta cache is
+                refreshed when `use_cache=True`.
+            hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`):
+                Input `hidden_states`.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`):
+ Global embeddings that will be prepended to the hidden states.
+ rotary_embedding (`torch.Tensor`):
+ The rotary embeddings to apply on query and key tensors during attention calculation.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
+                Mask to avoid performing attention on padding token indices, formed by concatenating the attention
+                masks for the two text encoders together. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*):
+                Mask to avoid performing cross-attention on padding token indices, formed by concatenating the
+                attention masks for the two text encoders together. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            use_cache (`bool`, *optional*, defaults to `False`):
+                Whether to reuse the cached output delta of the middle transformer blocks on non-refresh denoising
+                steps to speed up inference.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states)
+ global_hidden_states = self.global_proj(global_hidden_states)
+ time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.init_dtype)))
+
+ global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1)
+
+ hidden_states = self.preprocess_conv(hidden_states) + hidden_states
+ # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim)
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.proj_in(hidden_states)
+
+ # prepend global states to hidden states
+ hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2)
+ if attention_mask is not None:
+ prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool)
+ attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1)
+
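+        # Without caching (or during the warm-up steps) all transformer blocks run. With caching enabled, the
+        # blocks before `cache_block_start` always run, the middle blocks either run and refresh `delta_cache`
+        # or are approximated by adding the stored delta, and the trailing blocks always run.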
+ if not use_cache or (use_cache and step_id < self.cache_step_start):
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=0,
+ end_id=self.num_layers,
+ )
+ else:
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=0,
+ end_id=self.cache_block_start,
+ )
+
+ cache_end = np.minimum(self.cache_block_start + self.cache_num_blocks, self.num_layers)
+ hidden_states_pre_cache = hidden_states.clone()
+ if (step_id - self.cache_step_start) % self.cache_step_interval == 0:
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=self.cache_block_start,
+ end_id=cache_end,
+ )
+ self.delta_cache = hidden_states - hidden_states_pre_cache
+ else:
+ hidden_states = hidden_states_pre_cache + self.delta_cache
+
+ if cache_end < self.num_layers:
+ hidden_states = self._transformer_blocks_forward(
+ hidden_states=hidden_states,
+ encoder_hidden_states=cross_attention_hidden_states,
+ rotary_embedding=rotary_embedding,
+ start_id=cache_end,
+ end_id=self.num_layers,
+ )
+
+ hidden_states = self.proj_out(hidden_states)
+
+ # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length)
+ # remove prepend length that has been added by global hidden states
+ hidden_states = hidden_states.transpose(1, 2)[:, :, 1:]
+ hidden_states = self.postprocess_conv(hidden_states) + hidden_states
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
+
+ def _transformer_blocks_forward(self, hidden_states, encoder_hidden_states, rotary_embedding, start_id, end_id):
+ for block in self.transformer_blocks[start_id: end_id]:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ rotary_embedding=rotary_embedding,
+ )
+ return hidden_states
+
+    def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
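+        # Repack the per-projection Q/K/V weights of the original checkpoint into the fused (and transposed)
+        # `to_qkv` layouts expected by the attention layers in `..layers.attention`.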
+ for i in range(self.num_layers):
+ self_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_q.weight", None)
+ self_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_k.weight", None)
+ self_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn1.to_v.weight", None)
+ self_qkv_weight = torch.cat([self_q_weight, self_k_weight, self_v_weight], dim=0).transpose(0, 1).contiguous()
+ state_dict[f"transformer_blocks.{i}.attn1.to_qkv.weight"] = self_qkv_weight
+
+ cross_q_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_q.weight", None)
+ cross_k_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_k.weight", None)
+ cross_v_weight = state_dict.pop(f"transformer_blocks.{i}.attn2.to_v.weight", None)
+ cross_q_weight = cross_q_weight.transpose(0, 1).contiguous()
+ cross_kv_weight = torch.cat([cross_k_weight, cross_v_weight], dim=0).transpose(0, 1).contiguous()
+ state_dict[f"transformer_blocks.{i}.attn2.to_qkv.q_weight"] = cross_q_weight
+ state_dict[f"transformer_blocks.{i}.attn2.to_qkv.kv_weight"] = cross_kv_weight
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a283da8ecdf566d75e679e29b3a8e849445b4e6e
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/__init__.py
@@ -0,0 +1 @@
+from .pipeline_stable_audio import StableAudioPipeline, StableAudioProjectionModel
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..3abd5ecc9d47ea40f1462b4ff7a488120193da98
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/pipeline/pipeline_stable_audio.py
@@ -0,0 +1,754 @@
+# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import (
+ T5EncoderModel,
+ T5Tokenizer,
+ T5TokenizerFast,
+)
+
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel
+from diffusers.models.embeddings import get_1d_rotary_pos_embed
+from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck
+from diffusers.utils import (
+ logging,
+ replace_example_docstring,
+)
+
+from ..models import StableAudioDiTModel
+from ..schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import scipy
+ >>> import torch
+ >>> import soundfile as sf
+ >>> from diffusers import StableAudioPipeline
+
+ >>> repo_id = "stabilityai/stable-audio-open-1.0"
+ >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> # define the prompts
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
+ >>> negative_prompt = "Low quality."
+
+ >>> # set the seed for generator
+ >>> generator = torch.Generator("cuda").manual_seed(0)
+
+ >>> # run the generation
+ >>> audio = pipe(
+ ... prompt,
+ ... negative_prompt=negative_prompt,
+ ... num_inference_steps=200,
+ ... audio_end_in_s=10.0,
+ ... num_waveforms_per_prompt=3,
+ ... generator=generator,
+ ... ).audios
+
+ >>> output = audio[0].T.float().cpu().numpy()
+ >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate)
+ ```
+"""
+
+
+class StableAudioPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-audio generation using StableAudio.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ vae ([`AutoencoderOobleck`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.T5EncoderModel`]):
+ Frozen text-encoder. StableAudio uses the encoder of
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant.
+ projection_model ([`StableAudioProjectionModel`]):
+ A trained model used to linearly project the hidden-states from the text encoder model and the start and
+ end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to
+ give the input to the transformer model.
+ tokenizer ([`~transformers.T5Tokenizer`]):
+ Tokenizer to tokenize text for the frozen text-encoder.
+ transformer ([`StableAudioDiTModel`]):
+ A `StableAudioDiTModel` to denoise the encoded audio latents.
+ scheduler ([`CosineDPMSolverMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded audio latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"
+
+ def __init__(
+ self,
+ vae: AutoencoderOobleck,
+ text_encoder: T5EncoderModel,
+ projection_model: StableAudioProjectionModel,
+ tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+ transformer: StableAudioDiTModel,
+ scheduler: CosineDPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ projection_model=projection_model,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2
+
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ negative_attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # 1. Tokenize text
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ f"The following part of your input was truncated because {self.text_encoder.config.model_type} can "
+ f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_input_ids = text_input_ids.to(device)
+ attention_mask = attention_mask.to(device)
+
+ # 2. Text encoder forward
+ self.text_encoder.eval()
+ prompt_embeds = self.text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ if do_classifier_free_guidance and negative_prompt is not None:
+ uncond_tokens: List[str]
+ if type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # 1. Tokenize text
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ uncond_input_ids = uncond_input.input_ids.to(device)
+ negative_attention_mask = uncond_input.attention_mask.to(device)
+
+ # 2. Text encoder forward
+ self.text_encoder.eval()
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input_ids,
+ attention_mask=negative_attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if negative_attention_mask is not None:
+ # set the masked tokens to the null embed
+ negative_prompt_embeds = torch.where(
+ negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0
+ )
+
+ # 3. Project prompt_embeds and negative_prompt_embeds
+ if do_classifier_free_guidance and negative_prompt_embeds is not None:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the negative and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ if attention_mask is not None and negative_attention_mask is None:
+ negative_attention_mask = torch.ones_like(attention_mask)
+ elif attention_mask is None and negative_attention_mask is not None:
+ attention_mask = torch.ones_like(negative_attention_mask)
+
+ if attention_mask is not None:
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
+
+ prompt_embeds = self.projection_model(
+ text_hidden_states=prompt_embeds,
+ ).text_hidden_states
+ if attention_mask is not None:
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
+
+ return prompt_embeds
+
+ def encode_duration(
+ self,
+ audio_start_in_s,
+ audio_end_in_s,
+ device,
+ do_classifier_free_guidance,
+ batch_size,
+ ):
+ audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s]
+ audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s]
+
+ if len(audio_start_in_s) == 1:
+ audio_start_in_s = audio_start_in_s * batch_size
+ if len(audio_end_in_s) == 1:
+ audio_end_in_s = audio_end_in_s * batch_size
+
+ # Cast the inputs to floats
+ audio_start_in_s = [float(x) for x in audio_start_in_s]
+ audio_start_in_s = torch.tensor(audio_start_in_s).to(device)
+
+ audio_end_in_s = [float(x) for x in audio_end_in_s]
+ audio_end_in_s = torch.tensor(audio_end_in_s).to(device)
+
+ projection_output = self.projection_model(
+ start_seconds=audio_start_in_s,
+ end_seconds=audio_end_in_s,
+ )
+ seconds_start_hidden_states = projection_output.seconds_start_hidden_states
+ seconds_end_hidden_states = projection_output.seconds_end_hidden_states
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we repeat the audio hidden states to avoid doing two forward passes
+ if do_classifier_free_guidance:
+ seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0)
+ seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0)
+
+ return seconds_start_hidden_states, seconds_end_hidden_states
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ audio_start_in_s,
+ audio_end_in_s,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ attention_mask=None,
+ negative_attention_mask=None,
+ initial_audio_waveforms=None,
+ initial_audio_sampling_rate=None,
+ ):
+ if audio_end_in_s < audio_start_in_s:
+ raise ValueError(
+                f"`audio_end_in_s={audio_end_in_s}` must be higher than `audio_start_in_s={audio_start_in_s}`."
+ )
+
+ if (
+ audio_start_in_s < self.projection_model.config.min_value
+ or audio_start_in_s > self.projection_model.config.max_value
+ ):
+ raise ValueError(
+ f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
+ f"is {audio_start_in_s}."
+ )
+
+ if (
+ audio_end_in_s < self.projection_model.config.min_value
+ or audio_end_in_s > self.projection_model.config.max_value
+ ):
+ raise ValueError(
+ f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but "
+ f"is {audio_end_in_s}."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and (prompt_embeds is None):
+ raise ValueError(
+ "Provide either `prompt`, or `prompt_embeds`. Cannot leave"
+ "`prompt` undefined without specifying `prompt_embeds`."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
+ raise ValueError(
+                    "`attention_mask` should have the same batch size and sequence length as `prompt_embeds`, but got:"
+                    f" `attention_mask`: {attention_mask.shape} != `prompt_embeds`: {prompt_embeds.shape}"
+ )
+
+ if initial_audio_sampling_rate is None and initial_audio_waveforms is not None:
+ raise ValueError(
+                "`initial_audio_waveforms` is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`."
+ )
+
+ if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate:
+ raise ValueError(
+                f"`initial_audio_sampling_rate` must be {self.vae.sampling_rate}, but is `{initial_audio_sampling_rate}`. "
+ "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. "
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_vae,
+ sample_size,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ initial_audio_waveforms=None,
+ num_waveforms_per_prompt=None,
+ audio_channels=None,
+ ):
+ shape = (batch_size, num_channels_vae, sample_size)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # encode the initial audio for use by the model
+ if initial_audio_waveforms is not None:
+ # check dimension
+ if initial_audio_waveforms.ndim == 2:
+ initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1)
+ elif initial_audio_waveforms.ndim != 3:
+ raise ValueError(
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
+ )
+
+ audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
+ audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)
+
+ # check num_channels
+ if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2:
+ initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1)
+ elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1:
+ initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True)
+
+ if initial_audio_waveforms.shape[:2] != audio_shape[:2]:
+ raise ValueError(
+ f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`"
+ )
+
+ # crop or pad
+ audio_length = initial_audio_waveforms.shape[-1]
+ if audio_length < audio_vae_length:
+ logger.warning(
+ f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded."
+ )
+ elif audio_length > audio_vae_length:
+ logger.warning(
+ f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped."
+ )
+
+ audio = initial_audio_waveforms.new_zeros(audio_shape)
+ audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length]
+
+ encoded_audio = self.vae.encode(audio).latent_dist.sample(generator)
+ encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1))
+ latents = encoded_audio + latents
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ audio_end_in_s: Optional[float] = None,
+ audio_start_in_s: Optional[float] = 0.0,
+ num_inference_steps: int = 100,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_waveforms_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ initial_audio_waveforms: Optional[torch.Tensor] = None,
+ initial_audio_sampling_rate: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ negative_attention_mask: Optional[torch.LongTensor] = None,
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ output_type: Optional[str] = "pt",
+ use_cache: Optional[bool] = False,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
+ audio_end_in_s (`float`, *optional*, defaults to 47.55):
+ Audio end index in seconds.
+ audio_start_in_s (`float`, *optional*, defaults to 0):
+ Audio start index in seconds.
+ num_inference_steps (`int`, *optional*, defaults to 100):
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ A higher guidance scale value encourages the model to generate audio that is closely linked to the text
+ `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
+ The number of waveforms to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ initial_audio_waveforms (`torch.Tensor`, *optional*):
+ Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape
+ `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size`
+ corresponds to the number of prompts passed to the model.
+ initial_audio_sampling_rate (`int`, *optional*):
+ Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs,
+ *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input
+ argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
+ `negative_prompt` input argument.
+ attention_mask (`torch.LongTensor`, *optional*):
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
+ be computed from `prompt` input argument.
+ negative_attention_mask (`torch.LongTensor`, *optional*):
+ Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+            output_type (`str`, *optional*, defaults to `"pt"`):
+                The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
+                `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
+                model (LDM) output.
+            use_cache (`bool`, *optional*, defaults to `False`):
+                Whether to enable the transformer's delta cache, which skips recomputation of the middle transformer
+                blocks on some denoising steps to speed up inference.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
+ """
+ # 0. Convert audio input length from seconds to latent length
+ downsample_ratio = self.vae.hop_length
+
+ max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate
+ if audio_end_in_s is None:
+ audio_end_in_s = max_audio_length_in_s
+
+ if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
+ raise ValueError(
+ f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
+ )
+
+ waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
+ waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate)
+ waveform_length = int(self.transformer.config.sample_size)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ audio_start_in_s,
+ audio_end_in_s,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ attention_mask,
+ negative_attention_mask,
+ initial_audio_waveforms,
+ initial_audio_sampling_rate,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ attention_mask,
+ negative_attention_mask,
+ )
+
+ # Encode duration
+ seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration(
+ audio_start_in_s,
+ audio_end_in_s,
+ device,
+ do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None),
+ batch_size,
+ )
+
+ # Create text_audio_duration_embeds and audio_duration_embeds
+ text_audio_duration_embeds = torch.cat(
+ [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1
+ )
+
+ audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2)
+
+ # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and
+ # to concatenate it to the embeddings
+ if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None:
+ negative_text_audio_duration_embeds = torch.zeros_like(
+ text_audio_duration_embeds, device=text_audio_duration_embeds.device
+ )
+ text_audio_duration_embeds = torch.cat(
+ [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0
+ )
+ audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0)
+
+ bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape
+ # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method
+ text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
+ text_audio_duration_embeds = text_audio_duration_embeds.view(
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
+ )
+
+ audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1)
+ audio_duration_embeds = audio_duration_embeds.view(
+ bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1]
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_vae = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_waveforms_per_prompt,
+ num_channels_vae,
+ waveform_length,
+ text_audio_duration_embeds.dtype,
+ device,
+ generator,
+ latents,
+ initial_audio_waveforms,
+ num_waveforms_per_prompt,
+ audio_channels=self.vae.config.audio_channels,
+ )
+
+ # 6. Prepare extra step kwargs
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare rotary positional embedding
+ rotary_embedding = get_1d_rotary_pos_embed(
+ self.rotary_embed_dim,
+ latents.shape[2] + audio_duration_embeds.shape[1],
+ use_real=True,
+ repeat_interleave_real=False,
+ )
+
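+        # The cos/sin rotary tables are pre-cast to float16 on the latents' device; this assumes the
+        # transformer itself is run in float16 (as in the repository example).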
+ cos = rotary_embedding[0][None, :, None, :].to(latents.device).to(torch.float16)
+ sin = rotary_embedding[1][None, :, None, :].to(latents.device).to(torch.float16)
+ rotary_embedding = (cos, sin)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
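+                # the current step index `i` is forwarded so the DiT can decide when to refresh its
+                # block-output delta cache (only used when `use_cache=True`)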
+ noise_pred = self.transformer(
+ i,
+ latent_model_input,
+ t.unsqueeze(0),
+ encoder_hidden_states=text_audio_duration_embeds,
+ global_hidden_states=audio_duration_embeds,
+ rotary_embedding=rotary_embedding,
+ return_dict=False,
+ use_cache=use_cache,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ # 9. Post-processing
+        if output_type != "latent":
+ audio = self.vae.decode(latents).sample
+ else:
+ return AudioPipelineOutput(audios=latents)
+
+ audio = audio[:, :, waveform_start:waveform_end]
+
+ if output_type == "np":
+ audio = audio.cpu().float().numpy()
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (audio,)
+
+ return AudioPipelineOutput(audios=audio)
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bad3d9b150d1826ed727fb3fe27c446b443d2aa
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/__init__.py
@@ -0,0 +1 @@
+from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e7805d1346cd3c6c7a9883127c037942170f27
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -0,0 +1,574 @@
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
+
+
+class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
+ """
+    Implements a variant of `DPMSolverMultistepScheduler` with a cosine schedule, as proposed by Nichol and Dhariwal (2021).
+ This scheduler was used in Stable Audio Open [1].
+
+ [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ sigma_min (`float`, *optional*, defaults to 0.3):
+ Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
+ sigma_max (`float`, *optional*, defaults to 500):
+ Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
+ sigma_data (`float`, *optional*, defaults to 1.0):
+ The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
+ sigma_schedule (`str`, *optional*, defaults to `exponential`):
+            Sigma schedule to compute the `sigmas`. Acceptable values are `"karras"` (the schedule introduced in the
+            EDM paper, https://arxiv.org/abs/2206.00364) and `"exponential"`. The exponential schedule was
+            incorporated in this model: https://huggingface.co/stabilityai/cosxl.
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ solver_order (`int`, defaults to 2):
+ The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
+ prediction_type (`str`, defaults to `v_prediction`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ solver_type (`str`, defaults to `midpoint`):
+ Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
+ sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ lower_order_final (`bool`, defaults to `True`):
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+ euler_at_final (`bool`, defaults to `False`):
+ Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+ richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
+ steps, but sometimes may result in blurring.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.3,
+ sigma_max: float = 500,
+ sigma_data: float = 1.0,
+ sigma_schedule: str = "exponential",
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "v_prediction",
+ rho: float = 7.0,
+ solver_type: str = "midpoint",
+ lower_order_final: bool = True,
+ euler_at_final: bool = False,
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ ):
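+        # deprecated solver types ("logrho", "bh1", "bh2") are silently remapped to "midpoint";
+        # any other unknown value is rejected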
+ if solver_type not in ["midpoint", "heun"]:
+ if solver_type in ["logrho", "bh1", "bh2"]:
+ self.register_to_config(solver_type="midpoint")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ ramp = torch.linspace(0, 1, num_train_timesteps)
+ if sigma_schedule == "karras":
+ sigmas = self._compute_karras_sigmas(ramp)
+ elif sigma_schedule == "exponential":
+ sigmas = self._compute_exponential_sigmas(ramp)
+
+ self.timesteps = self.precondition_noise(sigmas)
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ # setable values
+ self.num_inference_steps = None
+ self.model_outputs = [None] * solver_order
+ self.lower_order_nums = 0
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ return (self.config.sigma_max**2 + 1) ** 0.5
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
+ def precondition_inputs(self, sample, sigma):
+ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ scaled_sample = sample * c_in
+ return scaled_sample
+
+ def precondition_noise(self, sigma):
+ if not isinstance(sigma, torch.Tensor):
+ sigma = torch.tensor([sigma])
+
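+        # map sigma in [0, inf) to a "timestep" in [0, 1) via arctan, i.e. t = atan(sigma) * 2 / pi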
+ return sigma.atan() / math.pi * 2
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
+ def precondition_outputs(self, sample, model_output, sigma):
+ sigma_data = self.config.sigma_data
+ c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
+
+ if self.config.prediction_type == "epsilon":
+ c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+ elif self.config.prediction_type == "v_prediction":
+ c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
+ else:
+ raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")
+
+ denoised = c_skip * sample + c_out * model_output
+
+ return denoised
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.Tensor`:
+ A scaled input sample.
+ """
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+ sample = self.precondition_inputs(sample, sigma)
+
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+
+ self.num_inference_steps = num_inference_steps
+
+ ramp = torch.linspace(0, 1, self.num_inference_steps)
+ if self.config.sigma_schedule == "karras":
+ sigmas = self._compute_karras_sigmas(ramp)
+ elif self.config.sigma_schedule == "exponential":
+ sigmas = self._compute_exponential_sigmas(ramp)
+
+ sigmas = sigmas.to(dtype=torch.float32, device=device)
+ self.timesteps = self.precondition_noise(sigmas)
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = self.config.sigma_min
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+ )
+
+ self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)])
+
+ self.model_outputs = [
+ None,
+ ] * self.config.solver_order
+ self.lower_order_nums = 0
+
+ # add an index counter for schedulers that allow duplicated timesteps
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
+
+ # if a noise sampler is used, reinitialise it
+ self.noise_sampler = None
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
+ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+ sigma_min = sigma_min or self.config.sigma_min
+ sigma_max = sigma_max or self.config.sigma_max
+
+ rho = self.config.rho
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
+ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
+ """Implementation closely follows k-diffusion.
+
+ https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+ """
+ sigma_min = sigma_min or self.config.sigma_min
+ sigma_max = sigma_max or self.config.sigma_max
+ sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ def _sigma_to_alpha_sigma_t(self, sigma):
+ alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
+ sigma_t = sigma
+
+ return alpha_t, sigma_t
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ sample: torch.Tensor = None,
+ ) -> torch.Tensor:
+ """
+ Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
+ designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
+ integral of the data prediction model.
+
+        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
+        prediction and data prediction models.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The converted model output.
+ """
+ sigma = self.sigmas[self.step_index]
+ x0_pred = self.precondition_outputs(sample, model_output, sigma)
+
+ return x0_pred
+
+ def dpm_solver_first_order_update(
+ self,
+ model_output: torch.Tensor,
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ One step for the first-order DPMSolver (equivalent to DDIM).
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
+
+ h = lambda_t - lambda_s
+ if noise is None:
+ raise ValueError("noise must not be None")
+ x_t = (
+ (sigma_t / sigma_s * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+
+ return x_t
+
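+ # Sanity-check sketch (comments only): since _sigma_to_alpha_sigma_t returns alpha_t = 1,
+ # lambda = -log(sigma), so h = log(sigma_s / sigma_t) and exp(-h) = sigma_t / sigma_s.
+ # With hypothetical sigma_s = 1.0 and sigma_t = 0.5:
+ #
+ #     coeff_sample = (sigma_t / sigma_s) * math.exp(-h)             # = 0.5 ** 2 = 0.25
+ #     coeff_x0     = 1.0 - math.exp(-2.0 * h)                       # = 1 - 0.25 = 0.75
+ #     coeff_noise  = sigma_t * math.sqrt(1.0 - math.exp(-2.0 * h))  # ~= 0.433
+ #
+ # Note coeff_sample + coeff_x0 = 1, i.e. the deterministic part of the SDE-DPM-Solver++
+ # first-order step is a convex combination of the current sample and the x0 prediction.
+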
+ def multistep_dpm_solver_second_order_update(
+ self,
+ model_output_list: List[torch.Tensor],
+ sample: torch.Tensor = None,
+ noise: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ One step for the second-order multistep DPMSolver.
+
+ Args:
+ model_output_list (`List[torch.Tensor]`):
+ The direct outputs from learned diffusion model at current and latter timesteps.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+
+ Returns:
+ `torch.Tensor`:
+ The sample tensor at the previous timestep.
+ """
+ sigma_t, sigma_s0, sigma_s1 = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ self.sigmas[self.step_index - 1],
+ )
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+ alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
+
+ m0, m1 = model_output_list[-1], model_output_list[-2]
+
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
+ r0 = h_0 / h
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
+
+ # sde-dpmsolver++
+ if noise is None:
+ raise ValueError("noise must not be None")
+ if self.config.solver_type == "midpoint":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+ elif self.config.solver_type == "heun":
+ x_t = (
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
+ + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
+ + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
+ )
+
+ return x_t
+
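+ # Note on solver_type (comments only): the "midpoint" and "heun" branches differ only in the
+ # coefficient applied to the first-order difference D1. For a hypothetical step with h = 0.1:
+ #
+ #     0.5 * (1 - math.exp(-2.0 * 0.1))                    # midpoint: ~= 0.0906
+ #     (1 - math.exp(-2.0 * 0.1)) / (-2.0 * 0.1) + 1.0     # heun:     ~= 0.0937
+ #
+ # Both coefficients agree with h to first order, so the two variants coincide as the step
+ # size shrinks and differ only in the higher-order correction.
+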
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ index_candidates = (schedule_timesteps == timestep).nonzero()
+
+ if len(index_candidates) == 0:
+ step_index = len(self.timesteps) - 1
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ elif len(index_candidates) > 1:
+ step_index = index_candidates[1].item()
+ else:
+ step_index = index_candidates[0].item()
+
+ return step_index
+
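+ # Illustrative example (comments only): for a hypothetical schedule
+ # timesteps = tensor([999, 999, 500, 0]), the duplicated leading entry resolves to index 1
+ # (the second candidate), so starting mid-schedule never skips a sigma; timestep 500 maps
+ # to index 2, and a timestep absent from the schedule falls back to the last index.
+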
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ """
+ Initialize the step_index counter for the scheduler.
+ """
+
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: Union[int, torch.Tensor],
+ sample: torch.Tensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
+ the multistep DPMSolver.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Improve numerical stability for small number of steps
+ lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+ self.config.euler_at_final
+ or (self.config.lower_order_final and len(self.timesteps) < 15)
+ or self.config.final_sigmas_type == "zero"
+ )
+ lower_order_second = (
+ (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
+ )
+
+ model_output = self.convert_model_output(model_output, sample=sample)
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.model_outputs[-1] = model_output
+
+ if self.noise_sampler is None:
+ seed = None
+ if generator is not None:
+ seed = (
+ [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
+ )
+ self.noise_sampler = BrownianTreeNoiseSampler(
+ model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
+ )
+ noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
+ model_output.device
+ )
+
+ if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
+ prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
+ elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
+ prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
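+ # A minimal, hypothetical usage sketch of the step loop (comments only), assuming the
+ # `set_timesteps` and `scale_model_input` methods defined earlier in this file and an
+ # external `model` callable:
+ #
+ #     scheduler.set_timesteps(num_inference_steps=100, device=sample.device)
+ #     for t in scheduler.timesteps:
+ #         model_input = scheduler.scale_model_input(sample, t)
+ #         model_output = model(model_input, t)
+ #         sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample
+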
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ ) -> torch.Tensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py
new file mode 100644
index 0000000000000000000000000000000000000000..8533b082f5c1dd5c07b52d896ec14f4b53f853e8
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_dpmsolver_sde.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from .scheduling_utils import BrownianTree
+
+
+class BatchedBrownianTree:
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
+ t0, t1, self.sign = self.sort(t0, t1)
+ w0 = kwargs.get("w0", torch.zeros_like(x))
+ if seed is None:
+ seed = torch.randint(0, 2**63 - 1, []).item()
+ self.batched = True
+ try:
+ if len(seed) != x.shape[0]:
+ raise ValueError(f"len(seed) should equal x.shape[0] ({x.shape[0]}), but got {len(seed)}")
+ w0 = w0[0]
+ except TypeError:
+ seed = [seed]
+ self.batched = False
+ self.trees = [BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+ @staticmethod
+ def sort(a, b):
+ return (a, b, 1) if a < b else (b, a, -1)
+
+ def __call__(self, t0, t1):
+ t0, t1, sign = self.sort(t0, t1)
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+ return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+ """A noise sampler backed by a torchsde.BrownianTree.
+
+ Args:
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
+ random samples.
+ sigma_min (float): The low end of the valid interval.
+ sigma_max (float): The high end of the valid interval.
+ seed (int or List[int]): The random seed. If a list of seeds is
+ supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
+ with its own seed.
+ transform (callable): A function that maps sigma to the sampler's
+ internal timestep.
+ """
+
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
+ self.transform = transform
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+ self.tree = BatchedBrownianTree(x, t0, t1, seed)
+
+ def __call__(self, sigma, sigma_next):
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
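+
+# Hypothetical usage sketch (comments only), assuming a latent tensor `x` and a sigma pair
+# taken from a scheduler's schedule:
+#
+#     sampler = BrownianTreeNoiseSampler(x, sigma_min=0.3, sigma_max=500.0, seed=42)
+#     noise = sampler(sigma=10.0, sigma_next=8.0)   # same shape/device/dtype as x
+#     # The division by sqrt(|t1 - t0|) in __call__ normalises the Brownian increment, so
+#     # `noise` has (approximately) unit variance regardless of the gap between sigmas.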
\ No newline at end of file
diff --git a/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4734b052f51c53c2399d20b95589e3cd122e8c9
--- /dev/null
+++ b/MindIE/MultiModal/Stable-Audio-Open-1.0/stableaudio/schedulers/scheduling_utils.py
@@ -0,0 +1,893 @@
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import warnings
+import trampoline
+
+import numpy as np
+import torch
+
+from torchsde._brownian import brownian_base
+from torchsde.settings import LEVY_AREA_APPROXIMATIONS
+from torchsde.types import Scalar, Optional, Tuple, Union, Tensor
+
+_rsqrt3 = 1 / math.sqrt(3)
+_r12 = 1 / 12
+
+
+def _randn(size, dtype, device, seed):
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+def _is_scalar(x):
+ return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1)
+
+
+def _assert_floating_tensor(name, tensor):
+ if not torch.is_tensor(tensor):
+ raise ValueError(f"{name}={tensor} should be a Tensor.")
+ if not tensor.is_floating_point():
+ raise ValueError(f"{name}={tensor} should be floating point.")
+
+
+def _check_tensor_info(*tensors, size, dtype, device):
+ """Check if sizes, dtypes, and devices of input tensors all match prescribed values."""
+ tensors = list(filter(torch.is_tensor, tensors))
+
+ if dtype is None and len(tensors) == 0:
+ dtype = torch.get_default_dtype()
+ if device is None and len(tensors) == 0:
+ device = torch.device("cpu")
+
+ sizes = [] if size is None else [size]
+ sizes += [t.shape for t in tensors]
+
+ dtypes = [] if dtype is None else [dtype]
+ dtypes += [t.dtype for t in tensors]
+
+ devices = [] if device is None else [device]
+ devices += [t.device for t in tensors]
+
+ if len(sizes) == 0:
+ raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.")
+
+ if not all(i == sizes[0] for i in sizes):
+ raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.")
+ if not all(i == dtypes[0] for i in dtypes):
+ raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.")
+ if not all(i == devices[0] for i in devices):
+ raise ValueError("Multiple devices found. Make sure `device` and `W` or `H` are consistent.")
+
+ # Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes.
+ return tuple(sizes[0]), dtypes[0], devices[0]
+
+
+def _davie_foster_approximation(W, H, h, levy_area_approximation, get_noise):
+ if levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.none, LEVY_AREA_APPROXIMATIONS.space_time):
+ return None
+ elif W.ndimension() in (0, 1):
+ # If we have zero or one dimensions then treat the scalar / single dimension we have as batch, so that the
+ # Brownian motion is one dimensional and the Levy area is zero.
+ return torch.zeros_like(W)
+ else:
+ # Davie's approximation to the Levy area from space-time Levy area
+ A = H.unsqueeze(-1) * W.unsqueeze(-2) - W.unsqueeze(-1) * H.unsqueeze(-2)
+ noise = get_noise()
+ noise = noise - noise.transpose(-1, -2) # noise is skew symmetric of variance 2
+ if levy_area_approximation == LEVY_AREA_APPROXIMATIONS.foster:
+ # Foster's additional correction to Davie's approximation
+ tenth_h = 0.1 * h
+ H_squared = H ** 2
+ std = (tenth_h * (tenth_h + H_squared.unsqueeze(-1) + H_squared.unsqueeze(-2))).sqrt()
+ else: # davie approximation
+ std = math.sqrt(_r12 * h ** 2)
+ a_tilde = std * noise
+ A += a_tilde
+ return A
+
+
+def _H_to_U(W: torch.Tensor, H: torch.Tensor, h: float) -> torch.Tensor:
+ return h * (.5 * W + H)
+
+
+class _EmptyDict:
+ def __setitem__(self, key, value):
+ pass
+
+ def __getitem__(self, item):
+ raise KeyError
+
+
+class _LRUDict(dict):
+ def __init__(self, max_size):
+ super().__init__()
+ self._max_size = max_size
+ self._keys = []
+
+ def __setitem__(self, key, value):
+ if key in self:
+ self._keys.remove(key)
+ elif len(self) >= self._max_size:
+ del self[self._keys.pop(0)]
+ super().__setitem__(key, value)
+ self._keys.append(key)
+
+
+class _Interval:
+ # Intervals correspond to some subinterval of the overall interval [t0, t1].
+ # They are arranged as a binary tree: each node corresponds to an interval. If a node has children, they are left
+ # and right subintervals, which partition the parent interval.
+
+ __slots__ = (
+ # These are the things that every interval has
+ '_start',
+ '_end',
+ '_parent',
+ '_is_left',
+ '_top',
+ # These are the things that intervals which are parents also have
+ '_midway',
+ '_spawn_key',
+ '_depth',
+ '_W_seed',
+ '_H_seed',
+ '_left_a_seed',
+ '_right_a_seed',
+ '_left_child',
+ '_right_child')
+
+ def __init__(self, start, end, parent, is_left, top):
+ self._start = top._round(start) # the left hand edge of the interval
+ self._end = top._round(end) # the right hand edge of the interval
+ self._parent = parent # our parent interval
+ self._is_left = is_left # are we the left or right child of our parent
+ self._top = top # the top-level BrownianInterval, where we cache certain state
+ self._midway = None # The point at which we split between left and right subintervals
+
+ ########################################
+ # Calculate increments and levy area #
+ ########################################
+ #
+ # This is a little bit convoluted, so here's an explanation.
+ #
+ # The entry point is _increment_and_levy_area, below. This immediately calls _increment_and_space_time_levy_area,
+ # applies the space-time to full Levy area correction, and then returns.
+ #
+ # _increment_and_space_time_levy_area in turn calls a central LRU cache, as (later on) we'll need the increment and
+ # space-time Levy area of the parent interval to compute our own increment and space-time Levy area, and it's likely
+ # that our parent exists in the cache, as if we're being queried then our parent was probably queried recently as
+ # well.
+ # (The top-level BrownianInterval overrides _increment_and_space_time_levy_area to return its own increment and
+ # space-time Levy area, effectively holding them permanently in the cache.)
+ #
+ # If the request isn't found in the LRU cache then it computes it from its parent.
+ # Now it turns out that the size of our increment and space-time Levy area is really most naturally thought of as a
+ # property of our parent: it depends on our parent's increment, space-time Levy area, and whether we are the left or
+ # right interval within our parent. So _increment_and_space_time_levy_area in turn checks if we are on the
+ # left or right of our parent and does most of the computation using the parent's attributes.
+
+ def _increment_and_levy_area(self):
+ W, H = trampoline.trampoline(self._increment_and_space_time_levy_area())
+ A = _davie_foster_approximation(W, H, self._end - self._start, self._top._levy_area_approximation,
+ self._randn_levy)
+ return W, H, A
+
+ def _increment_and_space_time_levy_area(self):
+ try:
+ return self._top._increment_and_space_time_levy_area_cache[self]
+ except KeyError:
+ parent = self._parent
+
+ W, H = yield parent._increment_and_space_time_levy_area()
+ h_reciprocal = 1 / (parent._end - parent._start)
+ left_diff = parent._midway - parent._start
+ right_diff = parent._end - parent._midway
+
+ if self._top._have_H:
+ left_diff_squared = left_diff ** 2
+ right_diff_squared = right_diff ** 2
+ left_diff_cubed = left_diff * left_diff_squared
+ right_diff_cubed = right_diff * right_diff_squared
+
+ v = 0.5 * math.sqrt(left_diff * right_diff / (left_diff_cubed + right_diff_cubed))
+
+ a = v * left_diff_squared * h_reciprocal
+ b = v * right_diff_squared * h_reciprocal
+ c = v * _rsqrt3
+
+ X1 = parent._randn(parent._W_seed)
+ X2 = parent._randn(parent._H_seed)
+
+ third_coeff = 2 * (a * left_diff + b * right_diff) * h_reciprocal
+
+ if self._is_left:
+ first_coeff = left_diff * h_reciprocal
+ second_coeff = 6 * first_coeff * right_diff * h_reciprocal
+ out_W = first_coeff * W + second_coeff * H + third_coeff * X1
+ out_H = first_coeff ** 2 * H - a * X1 + c * right_diff * X2
+ else:
+ first_coeff = right_diff * h_reciprocal
+ second_coeff = 6 * first_coeff * left_diff * h_reciprocal
+ out_W = first_coeff * W - second_coeff * H - third_coeff * X1
+ out_H = first_coeff ** 2 * H - b * X1 - c * left_diff * X2
+ else:
+ # Don't compute space-time Levy area unless we need to
+
+ mean = left_diff * h_reciprocal * W
+ var = left_diff * right_diff * h_reciprocal
+ noise = parent._randn(parent._W_seed)
+ left_W = mean + math.sqrt(var) * noise
+
+ if self._is_left:
+ out_W = left_W
+ else:
+ out_W = W - left_W
+ out_H = None
+
+ self._top._increment_and_space_time_levy_area_cache[self] = (out_W, out_H)
+ return out_W, out_H
+
+ def _randn(self, seed):
+ # We generate random noise deterministically wrt some seed; this seed is determined by the generator.
+ # This means that if we drop out of the cache, then we'll create the same random noise next time, as we still
+ # have the generator.
+ size = self._top._size
+ return _randn(size, self._top._dtype, self._top._device, seed)
+
+ def _a_seed(self):
+ return self._parent._left_a_seed if self._is_left else self._parent._right_a_seed
+
+ def _randn_levy(self):
+ size = (*self._top._size, *self._top._size[-1:])
+ return _randn(size, self._top._dtype, self._top._device, self._a_seed())
+
+ ########################################
+ # Locate an interval in the hierarchy #
+ ########################################
+ #
+ # The other important piece of this construction is a way to locate any given interval within the binary tree
+ # hierarchy. (This is typically the slightly slower part, actually, so if you want to speed things up then this is
+ # the bit to target.)
+ #
+ # loc finds the interval [ta, tb] - and creates it in the appropriate place (as a child of some larger interval) if
+ # it doesn't already exist. As in principle we may request an interval that covers multiple existing intervals, then
+ # in fact the interval [ta, tb] is returned as an ordered list of existing subintervals.
+ #
+ # It calls _loc, which operates recursively. See _loc for more details on how the search works.
+
+ def _loc(self, ta, tb):
+ out = []
+ ta = self._top._round(ta)
+ tb = self._top._round(tb)
+ trampoline.trampoline(self._loc_inner(ta, tb, out))
+ return out
+
+ def _loc_inner(self, ta, tb, out):
+ # Expect to have ta < tb
+
+ # First, we (this interval) only have jurisdiction over [self._start, self._end]. So if we're asked for
+ # something outside of that then we pass the buck up to our parent, who is strictly larger.
+ if ta < self._start or tb > self._end:
+ raise trampoline.TailCall(self._parent._loc_inner(ta, tb, out))
+
+ # If it's us that's being asked for, then we add ourselves on to out and return.
+ if ta == self._start and tb == self._end:
+ out.append(self)
+ return
+
+ # If we've got this far then we know that it's an interval that's within our jurisdiction, and that it's not us.
+ # So next we check if it's up to us to figure out, or up to our children.
+ if self._midway is None:
+ # It's up to us. Create subintervals (_split) if appropriate.
+ if ta == self._start:
+ self._split(tb)
+ raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out))
+ # implies ta > self._start
+ self._split(ta)
+ # Query our (newly created) right_child: if tb == self._end then our right child will be the result, and it
+ # will tell us so. But if tb < self._end then our right_child will need to make another split of its own.
+ raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out))
+
+ # If we're here then we have children: self._midway is not None
+ if tb <= self._midway:
+ # Strictly our left_child's problem
+ raise trampoline.TailCall(self._left_child._loc_inner(ta, tb, out))
+ if ta >= self._midway:
+ # Strictly our right_child's problem
+ raise trampoline.TailCall(self._right_child._loc_inner(ta, tb, out))
+ # It's a problem for both of our children: the requested interval overlaps our midpoint. Call the left_child
+ # first (to append to out in the correct order), then call our right child.
+ # (Implies ta < self._midway < tb)
+ yield self._left_child._loc_inner(ta, self._midway, out)
+ raise trampoline.TailCall(self._right_child._loc_inner(self._midway, tb, out))
+
+ def _set_spawn_key_and_depth(self):
+ self._spawn_key = 2 * self._parent._spawn_key + (0 if self._is_left else 1)
+ self._depth = self._parent._depth + 1
+
+ def _split(self, midway):
+ if self._top._halfway_tree:
+ self._split_exact(0.5 * (self._end + self._start))
+ # self._midway is now the rounded halfway point.
+ if midway > self._midway:
+ self._right_child._split(midway)
+ elif midway < self._midway:
+ self._left_child._split(midway)
+ else:
+ self._split_exact(midway)
+
+ def _split_exact(self, midway): # Create two children
+ self._midway = self._top._round(midway)
+ # Use splittable PRNGs to generate noise.
+ self._set_spawn_key_and_depth()
+ generator = np.random.SeedSequence(entropy=self._top._entropy,
+ spawn_key=(self._spawn_key, self._depth),
+ pool_size=self._top._pool_size)
+ self._W_seed, self._H_seed, self._left_a_seed, self._right_a_seed = generator.generate_state(4)
+
+ self._left_child = _Interval(start=self._start,
+ end=midway,
+ parent=self,
+ is_left=True,
+ top=self._top)
+ self._right_child = _Interval(start=midway,
+ end=self._end,
+ parent=self,
+ is_left=False,
+ top=self._top)
+
+
+class BrownianInterval(brownian_base.BaseBrownian, _Interval):
+ """Brownian interval with fixed entropy.
+
+ Computes increments (and optionally Levy area).
+
+ To use:
+ >>> bm = BrownianInterval(t0=0.0, t1=1.0, size=(4, 1), device='cuda')
+ >>> bm(0., 0.5)
+ tensor([[ 0.0733],
+ [-0.5692],
+ [ 0.1872],
+ [-0.3889]], device='cuda:0')
+ """
+
+ __slots__ = (
+ # Inputs
+ '_size',
+ '_dtype',
+ '_device',
+ '_entropy',
+ '_levy_area_approximation',
+ '_dt',
+ '_tol',
+ '_pool_size',
+ '_cache_size',
+ '_halfway_tree',
+ # Quantisation
+ '_round',
+ # Caching, searching and computing values
+ '_increment_and_space_time_levy_area_cache',
+ '_last_interval',
+ '_have_H',
+ '_have_A',
+ '_w_h',
+ '_top_a_seed',
+ # Dependency tree creation
+ '_average_dt',
+ '_tree_dt',
+ '_num_evaluations'
+ )
+
+ def __init__(self,
+ t0: Optional[Scalar] = 0.,
+ t1: Optional[Scalar] = 1.,
+ size: Optional[Tuple[int, ...]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ entropy: Optional[int] = None,
+ dt: Optional[Scalar] = None,
+ tol: Scalar = 0.,
+ pool_size: int = 8,
+ cache_size: Optional[int] = 45,
+ halfway_tree: bool = False,
+ levy_area_approximation: str = LEVY_AREA_APPROXIMATIONS.none,
+ W: Optional[Tensor] = None,
+ H: Optional[Tensor] = None):
+ """Initialize the Brownian interval.
+
+ Args:
+ t0 (float or Tensor): Initial time.
+ t1 (float or Tensor): Terminal time.
+ size (tuple of int): The shape of each Brownian sample.
+ If zero dimensional represents a scalar Brownian motion.
+ If one dimensional represents a batch of scalar Brownian motions.
+ If two or more dimensional, the last dimension represents the size of
+ a multidimensional Brownian motion, and all previous dimensions
+ represent batch dimensions.
+ dtype (torch.dtype): The dtype of each Brownian sample.
+ Defaults to the PyTorch default.
+ device (str or torch.device): The device of each Brownian sample.
+ Defaults to the CPU.
+ entropy (int): Global seed, defaults to `None` for random entropy.
+ levy_area_approximation (str): Whether to also approximate Levy
+ area. Defaults to 'none'. Valid options are 'none',
+ 'space-time', 'davie' or 'foster', corresponding to different
+ approximation types.
+ This is needed for some higher-order SDE solvers.
+ dt (float or Tensor): The expected average step size of the SDE
+ solver. Set it if you know it (e.g. when using a fixed-step
+ solver); else it will be estimated from the first few queries.
+ This is used to set up the data structure such that it is
+ efficient to query at these intervals.
+ tol (float or Tensor): What tolerance to resolve the Brownian motion
+ to. Must be non-negative. Defaults to zero, i.e. floating point
+ resolution. Usually worth setting in conjunction with
+ `halfway_tree`, below.
+ pool_size (int): Size of the pooled entropy. If you care about
+ statistical randomness then increasing this will help (but will
+ slow things down).
+ cache_size (int): How big a cache of recent calculations to use.
+ (As new calculations depend on old calculations, this speeds
+ things up dramatically, rather than recomputing things.)
+ Set this to `None` to use an infinite cache, which will be fast
+ but memory inefficient.
+ halfway_tree (bool): Whether the dependency tree (the internal data
+ structure) should be the dyadic tree. Defaults to `False`.
+ Normally, the sample path is determined by both `entropy`,
+ _and_ the locations and order of the query points. Setting this
+ to `True` will make it deterministic with respect to just
+ `entropy`; however this is much slower.
+ W (Tensor): The increment of the Brownian motion over the interval
+ [t0, t1]. Will be generated randomly if not provided.
+ H (Tensor): The space-time Levy area of the Brownian motion over the
+ interval [t0, t1]. Will be generated randomly if not provided.
+ """
+
+ #####################################
+ # Check and normalise inputs #
+ #####################################
+
+ if not _is_scalar(t0):
+ raise ValueError('Initial time t0 should be a float or 0-d torch.Tensor.')
+ if not _is_scalar(t1):
+ raise ValueError('Terminal time t1 should be a float or 0-d torch.Tensor.')
+ if dt is not None and not _is_scalar(dt):
+ raise ValueError('Expected average time step dt should be a float or 0-d torch.Tensor.')
+
+ if t0 > t1:
+ raise ValueError(f'Initial time {t0} should be less than terminal time {t1}.')
+ t0 = float(t0)
+ t1 = float(t1)
+ if dt is not None:
+ dt = float(dt)
+
+ if halfway_tree:
+ if tol <= 0.:
+ raise ValueError("`tol` should be positive.")
+ if dt is not None:
+ raise ValueError("`dt` is not used and should be set to `None` if `halfway_tree` is True.")
+ else:
+ if tol < 0.:
+ raise ValueError("`tol` should be non-negative.")
+
+ size, dtype, device = _check_tensor_info(W, H, size=size, dtype=dtype, device=device)
+
+ # Let numpy dictate randomness, so we have fewer seeds to set for reproducibility.
+ if entropy is None:
+ entropy = np.random.randint(0, 2 ** 31 - 1)
+
+ if levy_area_approximation not in LEVY_AREA_APPROXIMATIONS:
+ raise ValueError(f"`levy_area_approximation` must be one of {LEVY_AREA_APPROXIMATIONS}, but got "
+ f"'{levy_area_approximation}'.")
+
+ #####################################
+ # Record inputs #
+ #####################################
+
+ self._size = size
+ self._dtype = dtype
+ self._device = device
+ self._entropy = entropy
+ self._levy_area_approximation = levy_area_approximation
+ self._dt = dt
+ self._tol = tol
+ self._pool_size = pool_size
+ self._cache_size = cache_size
+ self._halfway_tree = halfway_tree
+
+ #####################################
+ # A miscellany of other things #
+ #####################################
+
+ # We keep a cache of recent queries, and their results. This is very important for speed, so that we don't
+ # recurse all the way up to the top every time we have a query.
+ if cache_size is None:
+ self._increment_and_space_time_levy_area_cache = {}
+ elif cache_size == 0:
+ self._increment_and_space_time_levy_area_cache = _EmptyDict()
+ else:
+ self._increment_and_space_time_levy_area_cache = _LRUDict(max_size=cache_size)
+
+ # We keep track of the most recently queried interval, and start searching for the next interval from that
+ # element of the binary tree. This is because subsequent queries are likely to be near the most recent query.
+ self._last_interval = self
+
+ # Precompute these as we don't want to spend lots of time checking strings in hot loops.
+ self._have_H = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.space_time,
+ LEVY_AREA_APPROXIMATIONS.davie,
+ LEVY_AREA_APPROXIMATIONS.foster)
+ self._have_A = self._levy_area_approximation in (LEVY_AREA_APPROXIMATIONS.davie,
+ LEVY_AREA_APPROXIMATIONS.foster)
+
+ # If we like we can quantise what level we want to compute the Brownian motion to.
+ if math.isclose(tol, 0., rel_tol=1e-5):
+ def round_func(x):
+ return x
+ else:
+ ndigits = -int(math.log10(tol))
+
+ def round_func(x):
+ return round(x, ndigits)
+
+ self._round = round_func
+
+ # Initialise as _Interval.
+ # (Must come after _round but before _w_h)
+ super(BrownianInterval, self).__init__(start=t0,
+ end=t1,
+ parent=None,
+ is_left=None,
+ top=self)
+
+ # Set the global increment and space-time Levy area
+ generator = np.random.SeedSequence(entropy=entropy, pool_size=pool_size)
+ initial_W_seed, initial_H_seed, top_a_seed = generator.generate_state(3)
+ if W is None:
+ W = self._randn(initial_W_seed) * math.sqrt(t1 - t0)
+ else:
+ _assert_floating_tensor('W', W)
+ if H is None:
+ H = self._randn(initial_H_seed) * math.sqrt((t1 - t0) / 12)
+ else:
+ _assert_floating_tensor('H', H)
+ self._w_h = (W, H)
+ self._top_a_seed = top_a_seed
+
+ if not self._halfway_tree:
+ # We create a binary tree dependency between the points. If we don't do this then the forward pass is still
+ # efficient at O(N), but we end up with a dependency chain stretching along the interval [t0, t1], making
+ # the backward pass O(N^2). By setting up a dependency tree of depth relative to `dt` and `cache_size` we
+ # can instead make both directions O(N log N).
+ self._average_dt = 0
+ self._tree_dt = t1 - t0
+ self._num_evaluations = -100 # start off with a warmup period to get a decent estimate of the average
+ if dt is not None:
+ # Create the dependency tree based on the supplied hint `dt`.
+ self._create_dependency_tree(dt)
+ # If dt is None, then create the dependency tree based on observed statistics of query points. (In __call__)
+
+ # Effectively permanently store our increment and space-time Levy area in the cache.
+ def _increment_and_space_time_levy_area(self):
+ return self._w_h
+ yield # make it a generator
+
+ def _a_seed(self):
+ return self._top_a_seed
+
+ def _set_spawn_key_and_depth(self):
+ self._spawn_key = 0
+ self._depth = 0
+
+ def __call__(self, ta, tb=None, return_U=False, return_A=False):
+ if tb is None:
+ warnings.warn(f"{self.__class__.__name__} is optimised for interval-based queries, not point evaluation.")
+ ta, tb = self._start, ta
+ tb_name = 'ta'
+ else:
+ tb_name = 'tb'
+ ta = float(ta)
+ tb = float(tb)
+ if ta < self._start:
+ warnings.warn(f"Should have ta>=t0 but got ta={ta} and t0={self._start}.")
+ ta = self._start
+ if tb < self._start:
+ warnings.warn(f"Should have {tb_name}>=t0 but got {tb_name}={tb} and t0={self._start}.")
+ tb = self._start
+ if ta > self._end:
+ warnings.warn(f"Should have ta<=t1 but got ta={ta} and t1={self._end}.")
+ ta = self._end
+ if tb > self._end:
+ warnings.warn(f"Should have {tb_name}<=t1 but got {tb_name}={tb} and t1={self._end}.")
+ tb = self._end
+ if ta > tb:
+ raise RuntimeError(f"Query times ta={ta:.3f} and tb={tb:.3f} must respect ta <= tb.")
+
+ if ta == tb:
+ W = torch.zeros(self._size, dtype=self._dtype, device=self._device)
+ H = None
+ A = None
+ if self._have_H:
+ H = torch.zeros(self._size, dtype=self._dtype, device=self._device)
+ if self._have_A:
+ size = (*self._size, *self._size[-1:]) # not self._size[-1] as that may not exist
+ A = torch.zeros(size, dtype=self._dtype, device=self._device)
+ else:
+ if self._dt is None and not self._halfway_tree:
+ self._num_evaluations += 1
+ # We start off with "negative" num evaluations, to give us a small warm-up period at the start.
+ if self._num_evaluations > 0:
+ # Compute average step size so far
+ dt = tb - ta
+ self._average_dt = (dt + self._average_dt * (self._num_evaluations - 1)) / self._num_evaluations
+ if self._average_dt < 0.5 * self._tree_dt:
+ # If 'dt' wasn't specified, then check the average interval length against the size of the
+ # bottom of the dependency tree. If we're below halfway then refine the tree by splitting all
+ # the bottom pieces into two.
+ self._create_dependency_tree(dt)
+
+ # Find the intervals that correspond to the query. We start our search at the last interval we accessed in
+ # the binary tree, as it's likely that the next query will come nearby.
+ intervals = self._last_interval._loc(ta, tb)
+ # Ideally we'd keep track of intervals[0] on the backward pass. Practically speaking len(intervals) tends to
+ # be 1 or 2 almost always so this isn't a huge deal.
+ self._last_interval = intervals[-1]
+
+ W, H, A = intervals[0]._increment_and_levy_area()
+ if len(intervals) > 1:
+ # If we have multiple intervals then add up their increments and Levy areas.
+
+ for interval in intervals[1:]:
+ Wi, Hi, Ai = interval._increment_and_levy_area()
+ if self._have_H:
+ # Aggregate H:
+ # Given s < u < t, then
+ # H_{s,t} = (term1 + term2) / (t - s)
+ # where
+ # term1 = (t - u) * (H_{u, t} + W_{s, u} / 2)
+ # term2 = (u - s) * (H_{s, u} - W_{u, t} / 2)
+ term1 = (interval._end - interval._start) * (Hi + 0.5 * W)
+ term2 = (interval._start - ta) * (H - 0.5 * Wi)
+ H = (term1 + term2) / (interval._end - ta)
+ if self._have_A and len(self._size) not in (0, 1):
+ # If len(self._size) in (0, 1) then we treat our scalar / single dimension as a batch
+ # dimension, so we have zero Levy area. (And these unsqueezes will result in a tensor of shape
+ # (batch, batch) which is wrong.)
+
+ # Let B_{x, y} = \int_x^y W^1_{s,u} dW^2_u.
+ # Then
+ # B_{s, t} = \int_s^t W^1_{s,u} dW^2_u
+ # = \int_s^v W^1_{s,u} dW^2_u + \int_v^t W^1_{s,v} dW^2_u + \int_v^t W^1_{v,u} dW^2_u
+ # = B_{s, v} + W^1_{s, v} W^2_{v, t} + B_{v, t}
+ #
+ # A is now the antisymmetric part of B, which gives the formula below.
+ A = A + Ai + 0.5 * (W.unsqueeze(-1) * Wi.unsqueeze(-2) - Wi.unsqueeze(-1) * W.unsqueeze(-2))
+ W = W + Wi
+
+ U = None
+ if self._have_H:
+ U = _H_to_U(W, H, tb - ta)
+
+ if return_U:
+ if return_A:
+ return W, U, A
+ else:
+ return W, U
+ else:
+ if return_A:
+ return W, A
+ else:
+ return W
+
+ def _create_dependency_tree(self, dt):
+ # For safety we take a min with 100: if people use very large cache sizes then this would turn the
+ # logarithmic recursion depth into a linear one, which causes RecursionErrors.
+ if self._cache_size is None: # cache_size=None corresponds to infinite cache.
+ cache_size = 100
+ else:
+ cache_size = min(self._cache_size, 100)
+
+ self._tree_dt = min(self._tree_dt, dt)
+ # Rationale: We are prepared to hold `cache_size` many things in memory, so when making steps of size `dt`
+ # then we can afford to have the intervals at the bottom of our binary tree be of size `dt * cache_size`.
+ # For safety we then make this a bit smaller by multiplying by 0.8.
+ piece_length = self._tree_dt * cache_size * 0.8
+
+ def _set_points(interval):
+ start = interval._start
+ end = interval._end
+ if end - start > piece_length:
+ midway = (end + start) / 2
+ interval._loc(start, midway)
+ _set_points(interval._left_child)
+ _set_points(interval._right_child)
+
+ _set_points(self)
+
+ def __repr__(self):
+ if self._dt is None:
+ dt = None
+ else:
+ dt = f"{self._dt:.3f}"
+ return (f"{self.__class__.__name__}("
+ f"t0={self._start:.3f}, "
+ f"t1={self._end:.3f}, "
+ f"size={self._size}, "
+ f"dtype={self._dtype}, "
+ f"device={repr(self._device)}, "
+ f"entropy={self._entropy}, "
+ f"dt={dt}, "
+ f"tol={self._tol}, "
+ f"pool_size={self._pool_size}, "
+ f"cache_size={self._cache_size}, "
+ f"levy_area_approximation={repr(self._levy_area_approximation)}"
+ f")")
+
+ def display_binary_tree(self):
+ stack = [(self, 0)]
+ out = []
+ while stack:
+ elem, depth = stack.pop()
+ out.append(" " * depth + f"({elem._start}, {elem._end})")
+ if elem._midway is not None:
+ stack.append((elem._right_child, depth + 1))
+ stack.append((elem._left_child, depth + 1))
+ print("\n".join(out))
+
+ @property
+ def shape(self):
+ return self._size
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def device(self):
+ return self._device
+
+ @property
+ def entropy(self):
+ return self._entropy
+
+ @property
+ def levy_area_approximation(self):
+ return self._levy_area_approximation
+
+ @property
+ def dt(self):
+ return self._dt
+
+ @property
+ def tol(self):
+ return self._tol
+
+ @property
+ def pool_size(self):
+ return self._pool_size
+
+ @property
+ def cache_size(self):
+ return self._cache_size
+
+ @property
+ def halfway_tree(self):
+ return self._halfway_tree
+
+ def size(self):
+ return self._size
+
+
+class BrownianTree(brownian_base.BaseBrownian):
+ """Brownian tree with fixed entropy.
+
+ Useful when the map from entropy -> Brownian motion shouldn't depend on the
+ locations and order of the query points. (As the usual BrownianInterval
+ does - note that BrownianTree is slower as a result though.)
+
+ To use:
+ >>> bm = BrownianTree(t0=0.0, w0=torch.zeros(4, 1))
+ >>> bm(0., 0.5)
+ tensor([[ 0.0733],
+ [-0.5692],
+ [ 0.1872],
+ [-0.3889]])
+ """
+
+ def __init__(self, t0: Scalar,
+ w0: Tensor,
+ t1: Optional[Scalar] = None,
+ w1: Optional[Tensor] = None,
+ entropy: Optional[int] = None,
+ tol: float = 1e-6,
+ pool_size: int = 24,
+ cache_depth: int = 9,
+ safety: Optional[float] = None):
+ """Initialize the Brownian tree.
+
+ The random value generation process exploits the parallel random number paradigm and uses
+ `numpy.random.SeedSequence`. The default generator is PCG64 (used by `default_rng`).
+
+ Arguments:
+ t0: Initial time.
+ w0: Initial state.
+ t1: Terminal time.
+ w1: Terminal state.
+ entropy: Global seed, defaults to `None` for random entropy.
+ tol: Error tolerance before the binary search is terminated; the search depth ~ log2(tol).
+ pool_size: Size of the pooled entropy. This parameter affects the query speed significantly.
+ cache_depth: Unused; deprecated.
+ safety: Unused; deprecated.
+ """
+
+ if t1 is None:
+ t1 = t0 + 1
+ if w1 is None:
+ W = None
+ else:
+ W = w1 - w0
+ self._w0 = w0
+ self._interval = BrownianInterval(t0=t0,
+ t1=t1,
+ size=w0.shape,
+ dtype=w0.dtype,
+ device=w0.device,
+ entropy=entropy,
+ tol=tol,
+ pool_size=pool_size,
+ halfway_tree=True,
+ W=W)
+ super(BrownianTree, self).__init__()
+
+ def __call__(self, t, tb=None, return_U=False, return_A=False):
+ # Deliberately called t rather than ta, for backward compatibility
+ out = self._interval(t, tb, return_U=return_U, return_A=return_A)
+ if tb is None and not return_U and not return_A:
+ out = out + self._w0
+ return out
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(interval={self._interval})"
+
+ @property
+ def dtype(self):
+ return self._interval.dtype
+
+ @property
+ def device(self):
+ return self._interval.device
+
+ @property
+ def shape(self):
+ return self._interval.shape
+
+ @property
+ def levy_area_approximation(self):
+ return self._interval.levy_area_approximation
+
+
+def brownian_interval_like(y: Tensor,
+ t0: Optional[Scalar] = 0.,
+ t1: Optional[Scalar] = 1.,
+ size: Optional[Tuple[int, ...]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ **kwargs):
+ """Returns a BrownianInterval object with the same size, device, and dtype as a given tensor."""
+ size = y.shape if size is None else size
+ dtype = y.dtype if dtype is None else dtype
+ device = y.device if device is None else device
+ return BrownianInterval(t0=t0, t1=t1, size=size, dtype=dtype, device=device, **kwargs)
\ No newline at end of file