diff --git a/MindIE/MultiModal/DiT/README.md b/MindIE/MultiModal/DiT/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6846cbf389e1bffee9659e2464b350a36cdde8e9 --- /dev/null +++ b/MindIE/MultiModal/DiT/README.md @@ -0,0 +1,231 @@ +# DiT模型-推理指导 + +# 概述 + +DiT一种基于Transformer的扩散模型,全称为Diffusion Transformer,DiT遵循ViT的技术方法。有关DiT模型的更多信息,请参考[DiT github](https://github.com/facebookresearch/DiT)。 + +- 设备支持: +Atlas 800I A2推理设备 +Atlas 300I Duo推理卡 + +## 输入输出数据 + +image_num为需要生成的图片数量 + +batch_size = image_num * 2 + +latent_size = image_size // 8 + +- 输入数据 + + | 输入数据 | 数据类型 | 大小 | 数据排布格式 | + | -------- | -------- | ------------------------------------------ | ------------ | + | x | FLOAT32 | batch_size x 4 x latent_size x latent_size | NCHW | + | t | INT64 | batch_size | ND | + | y | INT64 | batch_size | ND | + + +- 输出数据 + + | 输出数据 | 数据类型 | 大小 | 数据排布格式 | + | -------- | -------- | ----------------------------------------- | ------------ | + | output | FLOAT32 | image_num x 4 x latent_size x latent_size | NCHW | + +# 推理环境准备 + +**表 1** 版本配套表 + +| 配套 | 版本 | 环境准备指导 | +| ------------------------------------ | ------- | ------------ | +| Python | 3.10.13 | - | +| PyTorch | 2.1.0 | - | +| 硬件:Atlas 300I Duo, Atlas 800I A2 | \ | \ | + +请以CANN版本选择对应的固件与驱动版本。 + +# 快速上手 + +## 获取源码 + +1. 获取源码,然后把当前目录下的几个文件移到DiT工程下 + + ```bash + git clone https://github.com/facebookresearch/DiT + mv background_runtime.py export_model.py models_npu.py sample_npu.py vision.patch timm_patch.py requirements.txt fid_test.py ./DiT + ``` + +2. 安装依赖 + + ```bash + pip3 install -r requirements.txt + + ``` + +3. 安装mindie包 + + ```bash + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +4. 代码修改 + + ``` + cd ./DiT + # 若环境没有patch工具,请自行安装 + python3 timm_patch.py + ``` + +## 准备数据集 + +本模型输入图片类别信息生成图片,无需数据集。 + +## 模型推理 + +1. 下载模型 + + DiT权重文件下载链接如下,按需下载: + + [DiT-XL-2-256x256下载链接](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt) + + [DiT-XL-2-512x512下载链接](https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt) + + vae权重文件下载链接如下,按需下载: + + ```bash + # ema + git clone https://huggingface.co/stabilityai/sd-vae-ft-ema + # mse + git clone https://huggingface.co/stabilityai/sd-vae-ft-mse + ``` + +2. 模型转换,该步骤会生成编译之后的pt模型 + + ```bash + # Atlas 300I Duo卡 + python3 export_model.py \ + --ckpt ./DiT-XL-2-512x512.pt \ + --vae_model ./sd-vae-ft-mse \ + --image_size 512 \ + --device 0 \ + --soc Duo \ + --output_dir ./models \ + --parallel + + # Atlas 800I A2 + python3 export_model.py \ + --ckpt ./DiT-XL-2-512x512.pt \ + --vae_model ./sd-vae-ft-mse \ + --image_size 512 \ + --device 0 \ + --soc A2 \ + --output_dir ./models + ``` + + 参数说明: + + - --ckpt:DiT-XL-2的权重路径 + - --vae_model:vae的权重路径 + - --image_size:分辨率,支持256和512。默认为512 + - --device:使用哪张卡 + - --soc:soc_version,只支持Duo和A2 + - --output_dir:pt模型输出目录 + - --parallel:【可选】模型使用并行进行推理 + +3. 开始推理 + + 1. 开启cpu高性能模式 + + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 
执行推理,会在当前路径生成sample.png + + ```bash + # Atlas 300I Duo + python3 sample_npu.py \ + --vae mse \ + --image_size 512 \ + --ckpt ./DiT-XL-2-512x512.pt \ + --device 0 \ + --class_label 0 \ + --output_dir ./models \ + --parallel + + # Atlas 800I A2 + python3 sample_npu.py \ + --vae mse \ + --image_size 512 \ + --ckpt ./DiT-XL-2-512x512.pt \ + --device 0 \ + --class_label 0 \ + --output_dir ./models \ + --warmup + ``` + + 参数说明: + + - --vae:使用哪种vae模型,支持mse和ema + - --image_size:分辨率,支持256和512。默认为512 + - --ckpt:DiT-XL-2的权重路径 + - --device:使用哪张卡 + - --class_label:可在0~999中任意指定一个整数,代表image_net的种类 + - --output_dir:上一步骤指定的pt模型输出目录 + - --parallel:【可选】模型使用并行进行推理 + - --warmup:【可选】使用warmup可使得时间更准确。并行场景使用该选项会有问题,不建议使用 + +4. 精度验证 + + 下载数据集[ImageNet512x512](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)(VIRTUAL_imagenet512.npz)和[ImageNet256x256](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)(VIRTUAL_imagenet256_labeled.npz),放在任意路径 + + 然后执行以下命令: + + ```bash + # Atlas 300I Duo + python3 fid_test.py \ + --vae mse \ + --image_size 512 \ + --ckpt ./DiT-XL-2-512x512.pt \ + --device 0 \ + --output_dir ./models \ + --parallel \ + --results results + + # Atlas 800I A2 + python3 fid_test.py \ + --vae mse \ + --image_size 512 \ + --ckpt ./DiT-XL-2-512x512.pt \ + --device 0 \ + --output_dir ./models \ + --results results + ``` + + 参数说明: + + - --results:生成的1000张图片存放路径 + - image_size:分辨率,支持256和512。默认为512 + + 之后进行FID计算: + + ```bash + # 512分辨率使用VIRTUAL_imagenet512.npz数据集 + python3 -m pytorch_fid ./VIRTUAL_imagenet512.npz ./results + # 256分辨率使用VIRTUAL_imagenet256_labeled.npz数据集 + python3 -m pytorch_fid ./VIRTUAL_imagenet256_labeled.npz ./results + ``` + +# 模型推理性能&精度 + +性能参考下列数据。 + +| 分辨率 | 硬件形态 | 迭代次数 | 平均耗时 | +| ------ | -------- | -------- | -------- | +| 512 | Atlas 300I Duo | 250 | 19.6s | +| | Atlas 800I A2 (32G) | 250 | 10.49s | +| 256 | Atlas 300I Duo | 250 | 9.5s | +| | Atlas 800I A2 (32G) | 50 | 4.13s | diff --git a/MindIE/MultiModal/DiT/background_runtime.py b/MindIE/MultiModal/DiT/background_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..aa536b5b5474360199e2e78efbf9ebd475b60e19 --- /dev/null +++ b/MindIE/MultiModal/DiT/background_runtime.py @@ -0,0 +1,182 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
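+
+# BackgroundRuntime offloads half of each classifier-free-guidance batch to a second
+# NPU: it forks a worker process that loads the compiled DiT model, exchanges inputs
+# and outputs through shared-memory RawArray buffers, and synchronizes with the main
+# process over a multiprocessing Pipe.
+#
+# A usage sketch (shapes assume DiT-XL/2 at 512x512, i.e. latent_size = 64; the
+# compiled-model path below is illustrative):
+#
+#   io_info = RuntimeIOInfo(
+#       input_shapes=[(1, 4, 64, 64), (1,), (1,)],
+#       input_dtypes=[np.float32, np.int64, np.int64],
+#       output_shapes=[(1, 8, 64, 64)],
+#       output_dtypes=[np.float32])
+#   bg = BackgroundRuntime(device_id=1,
+#                          model_path="./models/dit/dit_model_512_parallel_compiled.pt",
+#                          io_info=io_info)
+#   out = bg.infer([x_np, t_np, y_np])   # blocks until the worker finishes one step
+#   bg.stop()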
+ +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfo + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send('') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: str, + ) -> None: + # The sub process function + # Create a runtime + mindietorch.set_device(device_id) + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model = torch.jit.load(model_path).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + # Tell the main function that we are ready + sync_pipe.send('') + + infer_num = 0 + preprocess_time = 0 + infer_time = 0 + forward_time = 0 + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != 'STOP': + start = time.time() + x, t, y = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + x_npu = x.to(torch.float32).to(f"npu:{device_id}") + t_npu = t.to(torch.int64).to(f"npu:{device_id}") + y_npu = 
y.to(torch.int64).to(f"npu:{device_id}") + + preprocess_time += time.time() - start + + start2 = time.time() + with mindietorch.npu.stream(stream): + inf_start = time.time() + output_npu = model(x_npu, t_npu, y_npu) + stream.synchronize() + inf_end = time.time() + + output_cpu = output_npu.to('cpu') + forward_time += inf_end - inf_start + infer_time += time.time() - start2 + + for i, _ in enumerate(output_arrays): + output = output_cpu.numpy() + output_arrays[i][:] = output[i][:] + + infer_num += 1 + sync_pipe.send('') + + infer_num /= 50 + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/DiT/export_model.py b/MindIE/MultiModal/DiT/export_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5c11b0562b80eb0156d8f654db9aa4dc330606 --- /dev/null +++ b/MindIE/MultiModal/DiT/export_model.py @@ -0,0 +1,198 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from argparse import Namespace +import torch +from models import DiT_models +from download import find_model +from diffusers.models import AutoencoderKL +import mindietorch +from mindietorch import _enums + +class DiTExport(torch.nn.Module): + def __init__(self, dit_model): + super().__init__() + self.dit_model = dit_model + + def forward(self, x, t, y): + return self.dit_model(x, t, y) + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model): + super().__init__() + self.vae_model = vae_model + + def forward(self, latents): + return self.vae_model.decode(latents)[0] + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--image_size", type=int, choices=[256, 512], default=512) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="Path of directory to save models" + ) + parser.add_argument( + "--ckpt", + type=str, + default="./DiT-XL-2-256x256.pt", + help="Path or name of the pre-trained model." + ) + parser.add_argument( + "--vae_model", + type=str, + default="./sd-vae-ft-ema", + help="Path or name of the vae pre-trained model." 
+ ) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--parallel", action="store_true", help="Use parallel during inference") + parser.add_argument( + "--soc", + type=str, + default="Duo", + choices=["Duo", "A2"], + help="soc_version" + ) + return parser.parse_args() + +def export_dit(args, soc_version): + print(f"start trace dit_{args.image_size}---------->") + dit_path = os.path.join(args.output_dir, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + device = "cpu" + ckpt_path = args.ckpt + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + if args.parallel: + batch = 1 + traced_path = os.path.join(dit_path, f"dit_model_{args.image_size}_parallel.pt") + compiled_path = os.path.join(dit_path, f"dit_model_{args.image_size}_parallel_compiled.pt") + else: + batch = 2 + traced_path = os.path.join(dit_path, f"dit_model_{args.image_size}.pt") + compiled_path = os.path.join(dit_path, f"dit_model_{args.image_size}_compiled.pt") + dummy_input = ( + torch.ones([batch, 4, latent_size, latent_size], dtype=torch.float32), + torch.ones([batch,], dtype=torch.int64), + torch.ones([batch,], dtype=torch.int64), + ) + # trace模型 + if not os.path.exists(traced_path): + dit_model = DiTExport(model) + dit_model.eval() + torch.jit.trace(dit_model, dummy_input).save(traced_path) + + # compile模型 + print(f"start compile dit_{args.image_size}---------->") + inputs = [ + mindietorch.Input((batch, 4, latent_size, latent_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch,), + dtype=mindietorch.dtype.INT64) + ] + if not os.path.exists(compiled_path): + jit_model = torch.jit.load(traced_path).eval() + compiled_model = ( + mindietorch.compile(jit_model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0) + ) + torch.jit.save(compiled_model, compiled_path) + +def export_vae(args, soc_version): + if "ema" in args.vae_model: + kind = "ema" + elif "mse" in args.vae_model: + kind = "mse" + else: + print("unsupport vae weights name, must be sd-vae-ft-ema or sd-vae-ft-mse.") + return + print(f"start trace vae_{kind}_{args.image_size}---------->") + vae_path = os.path.join(args.output_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + device = "cpu" + vae = AutoencoderKL.from_pretrained(args.vae_model).to(device) + latent_size = args.image_size // 8 + batch = 1 + dummy_input = ( + torch.ones([batch, 4, latent_size, latent_size], dtype=torch.float32) + ) + traced_path = os.path.join(vae_path, f"vae_{kind}_{args.image_size}.pt") + compiled_path = os.path.join(vae_path, f"vae_{kind}_{args.image_size}_compiled.pt") + + # trace模型 + if not os.path.exists(traced_path): + vae_model = VaeExport(vae) + vae_model.eval() + torch.jit.trace(vae_model, dummy_input).save(traced_path) + # compile模型 + print(f"start compile vae_{kind}_{args.image_size}---------->") + inputs = [ + mindietorch.Input((batch, 4, latent_size, latent_size), + dtype=mindietorch.dtype.FLOAT) + ] + if not os.path.exists(compiled_path): + jit_model = torch.jit.load(traced_path).eval() + compiled_model = ( + mindietorch.compile(jit_model, + inputs=inputs, + 
allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0) + ) + torch.jit.save(compiled_model, compiled_path) + + +def main(): + args = parse_arguments() + device_id = args.device + mindietorch.set_device(device_id) + + if args.soc == "Duo": + soc_version = "Ascend310P3" + elif args.soc == "A2": + soc_version = "Ascend910B4" + else: + print("Unsupport soc_version") + return + export_dit(args, soc_version) + export_vae(args, soc_version) + mindietorch.finalize() + print("Done") + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/DiT/fid_test.py b/MindIE/MultiModal/DiT/fid_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5b360f7a233c62e32b052169ff4f3e58bc5edfd9 --- /dev/null +++ b/MindIE/MultiModal/DiT/fid_test.py @@ -0,0 +1,113 @@ +import torch +from torchvision.utils import save_image +from diffusion import create_diffusion +from download import find_model +import argparse +from argparse import Namespace +import mindietorch +from models_npu import DiT_models +import os + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") + parser.add_argument("--image_size", type=int, choices=[256, 512], default=512) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=1.5) + parser.add_argument("--num-sampling-steps", type=int, default=250) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--ckpt", + type=str, + default="./DiT-XL-2-256x256.pt", + help="Path or name of the pre-trained model." + ) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--parallel", action="store_true", help="Use parallel during inference") + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="Path of directory to save models" + ) + parser.add_argument( + "--results", + type=str, + default="./results", + help="Path of directory to save all class images" + ) + return parser.parse_args() + +def main(args): + results_path = args.results + if not os.path.exists(results_path): + os.makedirs(results_path, mode=0o640) + torch.set_grad_enabled(False) + device = "cpu" + device_id = args.device + + if args.parallel: + mindie_model_path = f"{args.output_dir}/dit/dit_model_{args.image_size}_parallel_compiled.pt" + else: + mindie_model_path = f"{args.output_dir}/dit/dit_model_{args.image_size}_compiled.pt" + vae_compiled_model_path = f"{args.output_dir}/vae/vae_{args.vae}_{args.image_size}_compiled.pt" + vae_compiled_model = torch.jit.load(vae_compiled_model_path).eval() + dit_compiled_model = torch.jit.load(mindie_model_path).eval() + + if args.ckpt is None: + assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." 
+ assert args.image_size in [256, 512] + assert args.num_classes == 1000 + + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + ckpt_path = args.ckpt + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + model.eval() + + diffusion = create_diffusion(str(args.num_sampling_steps)) + + mindietorch.set_device(device_id) + model.set_npu_model_stream(args.parallel, device_id, args.image_size, mindie_model_path, dit_compiled_model) + + all_class = 1000 + for i in range(all_class): + torch.manual_seed(args.seed) + class_labels = [i] + + # Create sampling noise + n = len(class_labels) + z = torch.randn(n, 4, latent_size, latent_size, device=torch.device('cpu')).to(device) + y = torch.tensor(class_labels, device=device) + + # Setup classifier-free guidance + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) + + samples = diffusion.p_sample_loop( + model.forward_with_cfg, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_kwargs, + progress=True, + device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae_compiled_model((samples / 0.18215).to(f"npu:{device_id}")).to('cpu') # 0.18215 is scale factor + + save_image(samples, f"{results_path}/sample_{i}.png", nrow=4, normalize=True, value_range=(-1, 1)) + if args.parallel: + model.end_asyn() + mindietorch.finalize() + +if __name__ == "__main__": + args = parse_arguments() + main(args) \ No newline at end of file diff --git a/MindIE/MultiModal/DiT/models_npu.py b/MindIE/MultiModal/DiT/models_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..796fcc9a47ba7da94670fa16254b07ca744c0b7c --- /dev/null +++ b/MindIE/MultiModal/DiT/models_npu.py @@ -0,0 +1,62 @@ +import numpy as np +import torch +import mindietorch +from models import DiT +from background_runtime import BackgroundRuntime, RuntimeIOInfo + +class MindIEDiT(DiT): + def forward_with_cfg(self, x, t, y, cfg_scale): + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + + device = self.device + if self.parallel: + combined, combined_2 = combined.chunk(2) + t, t_2 = t.chunk(2) + y, y_2 = y.chunk(2) + self.bg.infer_asyn([ + combined_2.numpy().astype(np.float32), + t_2.numpy().astype(np.int64), + y_2.numpy().astype(np.int64) + ]) + with mindietorch.npu.stream(self.stream): + model_out_npu = self.model_npu(combined.to(f"npu:{device}"), + t.to(f"npu:{device}"), + y.to(f"npu:{device}")) + self.stream.synchronize() + model_out = model_out_npu.to("cpu") + + if self.parallel: + model_out_2 = torch.from_numpy(self.bg.wait_and_get_outputs()[0]) + model_out = torch.cat([model_out, model_out_2]) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + def set_npu_model_stream(self, parallel, device, image_size, mindie_model_path, dit_compiled_model): + latent_size = image_size // 8 + self.device, device_2 = device, device + 1 + self.stream = mindietorch.npu.Stream(f"npu:{self.device}") + self.parallel = parallel + self.model_npu = dit_compiled_model + if parallel: + runtime_info = RuntimeIOInfo( + input_shapes=[(1, 4, latent_size, latent_size), (1,), (1,)], + input_dtypes=[np.float32, np.int64, 
np.int64], + output_shapes=[(1, 8, latent_size, latent_size)], + output_dtypes=[np.float32] + ) + self.bg = BackgroundRuntime(device_2, mindie_model_path, runtime_info) + print('success init') + + def end_asyn(self): + self.bg.stop() + +def DiT_XL_2(**kwargs): + return MindIEDiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +DiT_models = { + 'DiT-XL/2': DiT_XL_2 +} \ No newline at end of file diff --git a/MindIE/MultiModal/DiT/requirements.txt b/MindIE/MultiModal/DiT/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f0c8f2635a24178615aba3c756eeb5de326c767 --- /dev/null +++ b/MindIE/MultiModal/DiT/requirements.txt @@ -0,0 +1,7 @@ +torch==2.1.0 +torchvision==0.16.0 +timm==0.9.12 +diffusers==0.26.3 +accelerate==0.21.0 +scipy==1.11.1 +pytorch-fid \ No newline at end of file diff --git a/MindIE/MultiModal/DiT/sample_npu.py b/MindIE/MultiModal/DiT/sample_npu.py new file mode 100644 index 0000000000000000000000000000000000000000..6c1959a59a5df1ac320c8446eaeea03540696801 --- /dev/null +++ b/MindIE/MultiModal/DiT/sample_npu.py @@ -0,0 +1,149 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchvision.utils import save_image +from diffusion import create_diffusion +from download import find_model +import argparse +from argparse import Namespace +import mindietorch +import time +from models_npu import DiT_models + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + choices=list(DiT_models.keys()), + default="DiT-XL/2" + ) + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse") + parser.add_argument("--image_size", type=int, choices=[256, 512], default=512) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--cfg-scale", type=float, default=1.5) + parser.add_argument("--num-sampling-steps", type=int, default=250) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--ckpt", + type=str, + default="./DiT-XL-2-256x256.pt", + help="Path or name of the pre-trained model." 
+ ) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--parallel", action="store_true", help="Use parallel during inference") + parser.add_argument("--class_label", type=int, default=0) + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="Path of directory to save models" + ) + parser.add_argument("--warmup", action="store_true", help="Use warmup") + return parser.parse_args() + +def warm_up(args, dit_compiled_model, vae_compiled_model): + batch = 1 if args.parallel else 2 + latent_size = args.image_size // 8 + x1 = torch.ones([batch, 4, latent_size, latent_size], dtype=torch.float32) + x2 = torch.ones([batch,], dtype=torch.int64) + x3 = torch.ones([batch,], dtype=torch.int64) + x4 = torch.ones([1, 4, latent_size, latent_size], dtype=torch.float32) + count = 5 + stream = mindietorch.npu.Stream(f"npu:{args.device}") + for _ in range(count): + with mindietorch.npu.stream(stream): + dit_out_npu = dit_compiled_model(x1.to(f"npu:{args.device}"), + x2.to(f"npu:{args.device}"), + x3.to(f"npu:{args.device}")) + stream.synchronize() + dit_out_cpu = dit_out_npu.to("cpu") + + with mindietorch.npu.stream(stream): + vae_out_npu = vae_compiled_model(x4.to(f"npu:{args.device}")) + stream.synchronize() + vae_out_cpu = vae_out_npu.to("cpu") + +def main(args): + torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cpu" + device_id = args.device + + if args.parallel: + mindie_model_path = f"{args.output_dir}/dit/dit_model_{args.image_size}_parallel_compiled.pt" + else: + mindie_model_path = f"{args.output_dir}/dit/dit_model_{args.image_size}_compiled.pt" + vae_compiled_model_path = f"{args.output_dir}/vae/vae_{args.vae}_{args.image_size}_compiled.pt" + vae_compiled_model = torch.jit.load(vae_compiled_model_path).eval() + dit_compiled_model = torch.jit.load(mindie_model_path).eval() + + if args.ckpt is None: + assert args.model == "DiT-XL/2", "Only DiT-XL/2 models are available for auto-download." 
+ assert args.image_size in [256, 512] + assert args.num_classes == 1000 + + latent_size = args.image_size // 8 + model = DiT_models[args.model]( + input_size=latent_size, + num_classes=args.num_classes + ).to(device) + ckpt_path = args.ckpt + state_dict = find_model(ckpt_path) + model.load_state_dict(state_dict) + model.eval() + + diffusion = create_diffusion(str(args.num_sampling_steps)) + class_labels = [args.class_label] + + # Create sampling noise + n = len(class_labels) + z = torch.randn(n, 4, latent_size, latent_size, device=device) + y = torch.tensor(class_labels, device=device) + + # Setup classifier-free guidance + z = torch.cat([z, z], 0) + y_null = torch.tensor([1000] * n, device=device) + y = torch.cat([y, y_null], 0) + model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) + + mindietorch.set_device(device_id) + if args.warmup: + warm_up(args, dit_compiled_model, vae_compiled_model) + model.set_npu_model_stream(args.parallel, device_id, args.image_size, mindie_model_path, dit_compiled_model) + start = time.time() + samples = diffusion.p_sample_loop( + model.forward_with_cfg, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_kwargs, + progress=True, + device=device + ) + samples, _ = samples.chunk(2, dim=0) # Remove null class samples + samples = vae_compiled_model((samples / 0.18215).to(f"npu:{device_id}")).to('cpu') # 0.18215 is scale factor + end = time.time() + print(f"sample time is: {(end-start):.2f}s") + if args.parallel: + model.end_asyn() + + save_image(samples, "sample.png", nrow=4, normalize=True, value_range=(-1, 1)) + mindietorch.finalize() + +if __name__ == "__main__": + args = parse_arguments() + main(args) + diff --git a/MindIE/MultiModal/DiT/timm_patch.py b/MindIE/MultiModal/DiT/timm_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..eeea5ab899400e7105535b4de3d8cae7b40865f8 --- /dev/null +++ b/MindIE/MultiModal/DiT/timm_patch.py @@ -0,0 +1,27 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
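+
+# This helper applies vision.patch to the installed timm package (timm==0.9.12 is
+# expected, see requirements.txt). The patch rewrites Attention.forward in
+# timm/models/vision_transformer.py so that it always uses the explicit
+# q @ k^T -> softmax -> @ v computation instead of F.scaled_dot_product_attention,
+# keeping the fused attention kernel out of the traced DiT graph. The system
+# `patch` tool must be installed for this script to work.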
+ +import os +import timm + +def main(): + timm_path = timm.__path__ + timm_version = timm.__version__ + + assert timm_version is not '0.9.12', "expectation timm==0.9.12" + os.system(f'patch -p0 {timm_path[0]}/models/vision_transformer.py vision.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/DiT/vision.patch b/MindIE/MultiModal/DiT/vision.patch new file mode 100644 index 0000000000000000000000000000000000000000..c0e52c83220ccef616a9ff972713e10b4ceda41c --- /dev/null +++ b/MindIE/MultiModal/DiT/vision.patch @@ -0,0 +1,37 @@ +--- vision_transformer.py 2024-05-16 14:54:32.831213500 +0800 ++++ vision_transformer_new.py 2024-05-16 15:01:23.736957600 +0800 +@@ -87,17 +87,23 @@ + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + +- if self.fused_attn: +- x = F.scaled_dot_product_attention( +- q, k, v, +- dropout_p=self.attn_drop.p if self.training else 0., +- ) +- else: +- q = q * self.scale +- attn = q @ k.transpose(-2, -1) +- attn = attn.softmax(dim=-1) +- attn = self.attn_drop(attn) +- x = attn @ v ++ # if self.fused_attn: ++ # x = F.scaled_dot_product_attention( ++ # q, k, v, ++ # dropout_p=self.attn_drop.p if self.training else 0., ++ # ) ++ # else: ++ # q = q * self.scale ++ # attn = q @ k.transpose(-2, -1) ++ # attn = attn.softmax(dim=-1) ++ # attn = self.attn_drop(attn) ++ # x = attn @ v ++ ++ q = q * self.scale ++ attn = q @ k.transpose(-2, -1) ++ attn = attn.softmax(dim=-1) ++ attn = self.attn_drop(attn) ++ x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) diff --git a/MindIE/MultiModal/IP-Adapter/README.md b/MindIE/MultiModal/IP-Adapter/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4461199b37f6729f6276178ca43fcb92e3658768 --- /dev/null +++ b/MindIE/MultiModal/IP-Adapter/README.md @@ -0,0 +1,212 @@ +# IP-Adapter模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [准备数据集](#section183221994411) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + + +# 概述 + +在Stable Diffusion研究中,如何有效地将文本提示和图像提示整合到预训练的文生图模型中一直一个挑战。IPAdapter通过引入一个轻量级的适配器模块创新地解决了这个问题,请查看[IP-Adapter](https://github.com/tencent-ailab/IP-Adapter)。 + +- 参考实现: + ```bash + # IP-Adapter + https://github.com/tencent-ailab/IP-Adapter + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 +- + | 配套 | 版本 | 环境准备指导 | + | ------------------------------------------------------------ |--------| ------------------------------------------------------------ | + | Python | 3.10.13 | - | + | torch| 2.1.0 | - | + | 硬件 | Atlas 300I Duo | - | + +请以CANN版本选择对应的固件与驱动版本。 + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 + +## 获取源码 + +1. 获取源码,然后把当前目录下的几个文件移到IP-Adapter工程下 + + ```bash + git clone https://github.com/tencent-ailab/IP-Adapter + mv attention_processor.patch clip.patch export_ts.py inference.py stable_diffusion_patch.py stable_diffusion_pipeline.py requirements.txt ./IP-Adapter + ``` + +2. 按照requirements.txt要求的版本安装相关依赖,避免导出模型失败。 + + ```bash + cd IP-Adapter + pip3 install -r requirements.txt + ``` + +3. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +4. 代码修改 + + 执行命令: + + ```bash + python3 stable_diffusion_patch.py + ``` + + +## 准备数据集 + +1. 
获取原始数据集。 + + 本模型输入文本信息和图片生成图片,无需数据集。 + +## 模型推理 + +1. 模型转换。【可选】 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放在任意路径,以避免执行后面步骤时可能会出现下载失败。 + + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) + + [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) + + git clone https://huggingface.co/h94/IP-Adapter IP-Adapter-weights + + 1. 导出pt模型并进行编译。(可选) + + 设置模型名称或路径 + ```bash + # v1.5 + base_model_path="runwayml/stable-diffusion-v1-5" + + # vae + vae_model_path="stabilityai/sd-vae-ft-mse" + + # image_encoder + image_encoder_path="IP-Adapter-weights/models/image_encoder" + + # ip_ckpt + ip_ckpt="IP-Adapter-weights/models/ip-adapter_sd15.bin" + ``` + + 执行命令: + + ```bash + # 导出pt模型 + python3 export_ts.py \ + --base_model_path ${base_model_path} \ + --vae_model_path ${vae_model_path} \ + --image_encoder_path ${image_encoder_path} \ + --batch_size 1 \ + --output_dir ./models \ + --device 0 \ + --soc Duo + ``` + + 参数说明: + - --base_model_path:SD的模型名称或本地模型目录的路径 + - --vae_model_path:VAE的模型名称或本地模型目录的路径 + - --image_encoder_path:image_encoder的模型名称或本地模型目录的路径 + - --output_dir: 导出的模型输出目录 + - batch_size:目前只支持batch为1 + - --device:使用的NPU芯片,默认是0 + - soc:soc_version。默认为Duo,可支持A2 + + + +2. 开始推理验证。 + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 执行推理脚本。 + ```bash + numactl -C 0-23 python3 inference.py \ + --base_model_path ${base_model_path} \ + --vae_model_path ${vae_model_path} \ + --image_encoder_path ${image_encoder_path} \ + --ip_ckpt ${ip_ckpt} \ + --output_dir ./models \ + --device 0 \ + --image_path ./assets/images/woman.png \ + --save_image_path ./test.png \ + --prompt "A girl" + ``` + + 参数说明: + - --base_model_path:SD的模型名称或本地模型目录的路径 + - --vae_model_path:VAE的模型名称或本地模型目录的路径 + - --image_encoder_path:image_encoder的模型名称或本地模型目录的路径 + - --ip_ckpt:ipadpter的模型名称或本地模型目录的路径 + - --output_dir: 导出的模型输出目录 + - --device:使用的NPU芯片,默认是0 + - --image_path: 输入的图片路径 + - --save_image_path:输出的图片路径 + - --prompt:文本提示词 + + + + +# 模型推理性能 + +性能参考下列数据。 + +### IP-Adapter + +| 硬件形态 | 迭代次数 | 平均耗时 | +| :------: |:----:|:----:| +| Atlas 300I Duo | 50 | 4.09s | diff --git a/MindIE/MultiModal/IP-Adapter/attention_processor.patch b/MindIE/MultiModal/IP-Adapter/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..bd15281c5a3acf9752eec8a239323f66f1beadb7 --- /dev/null +++ b/MindIE/MultiModal/IP-Adapter/attention_processor.patch @@ -0,0 +1,18 @@ +--- attention_processor.py 2024-07-02 07:42:32.312000000 +0000 ++++ attention_processor.py 2024-07-02 07:44:55.100000000 +0000 +@@ -205,10 +205,11 @@ + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( diff --git a/MindIE/MultiModal/IP-Adapter/clip.patch b/MindIE/MultiModal/IP-Adapter/clip.patch new file mode 100644 index 0000000000000000000000000000000000000000..e3e4719b66f771ebb660f25151c33d140566c3f3 --- /dev/null +++ b/MindIE/MultiModal/IP-Adapter/clip.patch @@ -0,0 +1,10 @@ +22a23 +> import numpy as np +760c761,762 +< mask.triu_(1) # zero out the lower diagonal +--- +> # mask.triu_(1) # zero out the lower diagonal +> mask = torch.from_numpy(np.triu(mask.numpy(), 1)) +1324a1327 +> + diff --git a/MindIE/MultiModal/IP-Adapter/export_ts.py b/MindIE/MultiModal/IP-Adapter/export_ts.py new file mode 100644 index 0000000000000000000000000000000000000000..d2222322c142833f0c15e8adce90fbd62d4bfd77 --- /dev/null +++ b/MindIE/MultiModal/IP-Adapter/export_ts.py @@ -0,0 +1,275 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from argparse import Namespace + +import torch +import torch.nn as nn +from diffusers import DDIMScheduler +from diffusers import StableDiffusionPipeline, AutoencoderKL +from transformers import CLIPVisionModelWithProjection +import mindietorch +from mindietorch import _enums + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--base_model_path", + type=str, + default="./stable-diffusion-v1-5", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--vae_model_path", + type=str, + default="./sd-vae-ft-mse", + help="vae_model_path.", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default="./image_encoder", + help="image_encoder_path.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." 
+ ) + parser.add_argument( + "--soc", + choices=["Duo", "A2"], + default="Duo", + help="soc version.", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device", + ) + + return parser.parse_args() + +class ImageEncoderExport(torch.nn.Module): + def __init__(self, image_encoder_model): + super().__init__() + self.image_encoder_model = image_encoder_model + + def forward(self, x): + return self.image_encoder_model(x)[0] + +def export_image_encoder(sd_pipeline, args): + print("Exporting the image encoder...") + image_path = os.path.join(args.output_dir, "image_encoder") + if not os.path.exists(image_path): + os.makedirs(image_path, mode=0o640) + batch_size = args.batch_size + image_encoder_pt_path = os.path.join(image_path, f"image_encoder_bs{batch_size}.pt") + image_encoder_compiled_path = os.path.join(image_path, f"image_encoder_bs{batch_size}_compiled.ts") + + image_encoder_model = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path).to('cpu') + if not os.path.exists(image_encoder_pt_path): + dummy_input = torch.ones([1, 3, 224, 224], dtype=torch.float32) + image_export = ImageEncoderExport(image_encoder_model) + torch.jit.trace(image_export, dummy_input).save(image_encoder_pt_path) + if not os.path.exists(image_encoder_compiled_path): + model = torch.jit.load(image_encoder_pt_path).eval() + compiled_image_model = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((1, 3, 224, 224), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_image_model, image_encoder_compiled_path) + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x): + return self.clip_model(x)[0] + +def export_clip(sd_pipeline, args): + print("Exporting the text encoder...") + clip_path = os.path.join(args.output_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + batch_size = args.batch_size + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip_compiled_path = os.path.join(clip_path, f"clip_bs{batch_size}_compiled.ts") + + clip_model = sd_pipeline.text_encoder + max_position_embeddings = clip_model.config.max_position_embeddings + + if not os.path.exists(clip_pt_path): + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + clip_export = ClipExport(clip_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + if not os.path.exists(clip_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + compiled_clip_model = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, max_position_embeddings), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_clip_model, clip_compiled_path) + +class UnetExportInit(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward(self, sample, timestep, encoder_hidden_states): + return self.unet_model(sample, timestep, encoder_hidden_states)[0] + +def export_unet_init(sd_pipeline, args): + print("Exporting the image information creater...") + unet_path = 
os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + batch_size = args.batch_size * 2 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compiled.ts") + + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + max_position_embeddings = 81 + + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + ) + unet = UnetExportInit(unet_model).eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + compiled_unet_model = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, + sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model, scaling_factor): + super().__init__() + self.vae_model = vae_model + self.scaling_factor = scaling_factor + + def forward(self, latents): + latents = 1 / self.scaling_factor * latents + image = self.vae_model.decode(latents)[0] + image = (image / 2 + 0.5) + return image.permute(0, 2, 3, 1) + +def export_vae(sd_pipeline, args): + print("Exporting the image decoder...") + vae_path = os.path.join(args.output_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + batch_size = args.batch_size + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_path = os.path.join(vae_path, f"vae_bs{batch_size}_compiled.ts") + + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + scaling_factor = vae_model.config.scaling_factor + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size]) + vae_export = VaeExport(vae_model,scaling_factor) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + if not os.path.exists(vae_compiled_path): + model = torch.jit.load(vae_pt_path).eval() + compiled_vae_model = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, + sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_vae_model, vae_compiled_path) + +def export(args): + vae = AutoencoderKL.from_pretrained(args.vae_model_path).to("cpu") + pipeline = StableDiffusionPipeline.from_pretrained(args.base_model_path, vae=vae).to("cpu") + mindietorch.set_device(args.device) 
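+    # Trace and compile each sub-model (image encoder, text encoder, VAE decoder, UNet)
+    # into MindIE TorchScript; files that already exist under output_dir are reused.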
+ export_image_encoder(pipeline, args) + export_clip(pipeline, args) + export_vae(pipeline, args) + export_unet_init(pipeline, args) + mindietorch.finalize() + +def main(args): + export(args) + print("Done!") + +if __name__ == "__main__": + args = parse_arguments() + if args.soc == "Duo": + soc_version = "Ascend310P3" + elif args.soc == "A2": + soc_version = "Ascend910B4" + main(args) \ No newline at end of file diff --git a/MindIE/MultiModal/IP-Adapter/inference.py b/MindIE/MultiModal/IP-Adapter/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..74f1e17b26e84312512d0d4791771b48ce542f4e --- /dev/null +++ b/MindIE/MultiModal/IP-Adapter/inference.py @@ -0,0 +1,308 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List +import torch +from diffusers import StableDiffusionPipeline, AutoencoderKL, DDIMScheduler +from PIL import Image +from safetensors import safe_open +import argparse +from stable_diffusion_pipeline import AIEStableDiffusionPipeline +from ip_adapter.attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor +from ip_adapter.utils import get_generator +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +import mindietorch +import time + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_path", + type=str, + default="./stable-diffusion-v1-5", + help="SD base_model_path", + ) + parser.add_argument( + "--vae_model_path", + type=str, + default="./sd-vae-ft-mse", + help="vae_model_path", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default="./image_encoder", + help="image_encoder_path", + ) + parser.add_argument( + "--ip_ckpt", + type=str, + default="./ip-adapter_sd15.bin", + help="SD1.5 ip_ckpt", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device", + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "--image_path", + type=str, + default="./assets/images/woman.png", + help="Path of image.", + ) + parser.add_argument( + "--save_image_path", + type=str, + default="./test.png", + help="Path of image.", + ) + parser.add_argument( + "--prompt", + type=str, + default="A girl", + help="prompt.", + ) + return parser.parse_args() + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, 
self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + +class IPAdapter: + def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, args=None): + self.device = device + self.image_encoder_path = image_encoder_path + self.ip_ckpt = ip_ckpt + self.num_tokens = num_tokens + + self.pipe = sd_pipe.to(self.device) + self.pipe.output_dir = args.output_dir + self.device_0 = args.device + self.pipe.device_0 = self.device_0 + self.pipe.device_1 = self.device_0 + 1 + self.pipe.compile_aie_model() + self.set_ip_adapter() + + # load image encoder + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( + self.device, dtype=torch.float32 + ) + self.clip_image_processor = CLIPImageProcessor() + # image proj model + self.image_proj_model = self.init_proj() + + self.load_ip_adapter() + self.image_encoder_compiled = torch.jit.load(f"{args.output_dir}/image_encoder/image_encoder_bs1_compiled.ts") + mindietorch.set_device(self.device_0) + + def init_proj(self): + image_proj_model = ImageProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float32) + return image_proj_model + + def set_ip_adapter(self): + unet = self.pipe.unet + attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + attn_procs[name] = IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.num_tokens, + ).to(self.device, dtype=torch.float32) + unet.set_attn_processor(attn_procs) + if hasattr(self.pipe, "controlnet"): + if isinstance(self.pipe.controlnet, MultiControlNetModel): + for controlnet in self.pipe.controlnet.nets: + controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + else: + self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) + + def load_ip_adapter(self): + if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(self.ip_ckpt, map_location="cpu") + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + @torch.inference_mode() + def get_image_embeds(self, pil_image=None, clip_image_embeds=None): + if pil_image is not None: + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + 
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image_embeds = self.image_encoder_compiled( + clip_image.to(dtype=torch.float32).to(f"npu:{self.device_0}")).to("cpu") + else: + clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32) + image_prompt_embeds = self.image_proj_model(clip_image_embeds) + uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) + return image_prompt_embeds, uncond_image_prompt_embeds + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + def generate( + self, + pil_image=None, + clip_image_embeds=None, + prompt=None, + negative_prompt=None, + scale=1.0, + num_samples=4, + seed=None, + guidance_scale=7.5, + num_inference_steps=30, + **kwargs, + ): + self.set_scale(scale) + + if pil_image is not None: + num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) + else: + num_prompts = clip_image_embeds.size(0) + + if prompt is None: + prompt = "best quality, high quality" + if negative_prompt is None: + negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + + if not isinstance(prompt, List): + prompt = [prompt] * num_prompts + if not isinstance(negative_prompt, List): + negative_prompt = [negative_prompt] * num_prompts + + image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( + pil_image=pil_image, clip_image_embeds=clip_image_embeds + ) + bs_embed, seq_len, _ = image_prompt_embeds.shape + image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) + image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) + uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) + + with torch.inference_mode(): + prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( + prompt, + device=self.device, + num_images_per_prompt=num_samples, + do_classifier_free_guidance=True, + negative_prompt=negative_prompt, + ) + prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) + negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) + + generator = get_generator(seed, self.device) + + images = self.pipe.ascendie_infer( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps,\ + ) + + return images + +def load_pipe(args): + noise_scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + vae = AutoencoderKL.from_pretrained(args.vae_model_path) + pipe = AIEStableDiffusionPipeline.from_pretrained( + args.base_model_path, + scheduler=noise_scheduler, + vae=vae + ) + return pipe + +def main(args): + device = 'cpu' + pipe = load_pipe(args) + image_path = args.image_path + save_image_path = args.save_image_path + image = Image.open(image_path) + image = image.resize((256, 256)) + ip_model = IPAdapter(pipe, args.image_encoder_path, args.ip_ckpt, device, args=args) + prompt = args.prompt + print(f"start warm up------>") + for _ in range(5): + images = ip_model.generate(pil_image=image, num_samples=1, num_inference_steps=50, seed=42, prompt=prompt) + start = 
time.time()
+    images = ip_model.generate(pil_image=image, num_samples=1, num_inference_steps=50, seed=42, prompt=prompt)
+    print(f"use time is: {time.time() - start}s")
+
+    image = images[0][0]
+    image.save(save_image_path)
+    mindietorch.finalize()
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)
\ No newline at end of file
diff --git a/MindIE/MultiModal/IP-Adapter/requirements.txt b/MindIE/MultiModal/IP-Adapter/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7f86c3e1addd7d12eba75cd728e37945799dc067
--- /dev/null
+++ b/MindIE/MultiModal/IP-Adapter/requirements.txt
@@ -0,0 +1,4 @@
+torch==2.1.0
+diffusers==0.26.3
+transformers==4.26.1
+open_clip_torch==2.20.0
\ No newline at end of file
diff --git a/MindIE/MultiModal/IP-Adapter/stable_diffusion_patch.py b/MindIE/MultiModal/IP-Adapter/stable_diffusion_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..63415abee790147eed068f8f9e5cdfc924496106
--- /dev/null
+++ b/MindIE/MultiModal/IP-Adapter/stable_diffusion_patch.py
@@ -0,0 +1,35 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import diffusers
+import transformers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.26.3', "expected diffusers==0.26.3"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch')
+
+    transformers_path = transformers.__path__
+    transformers_version = transformers.__version__
+
+    assert transformers_version == '4.26.1', "expected transformers==4.26.1"
+    os.system(f'patch -p0 {transformers_path[0]}/models/clip/modeling_clip.py clip.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/IP-Adapter/stable_diffusion_pipeline.py b/MindIE/MultiModal/IP-Adapter/stable_diffusion_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..96fc389bcfd17d9e3430538607991e49904409d2
--- /dev/null
+++ b/MindIE/MultiModal/IP-Adapter/stable_diffusion_pipeline.py
@@ -0,0 +1,365 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
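+
+# NOTE: AIEStableDiffusionPipeline defined below subclasses the diffusers
+# StableDiffusionPipeline but replaces the CLIP text encoder, UNet and VAE
+# with pre-compiled MindIE-Torch TorchScript modules (clip_bs*_compiled.ts,
+# unet_bs*_compiled.ts, vae_bs*_compiled.ts) loaded from self.output_dir.
+# The compiled models are executed on the NPU (npu:{self.device_0}), while
+# classifier-free guidance and the scheduler step are computed on the CPU.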
+ +import argparse +import csv +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +from diffusers import StableDiffusionPipeline +from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler, SASolverScheduler + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 +scheduler_time = 0 + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +class AIEStableDiffusionPipeline(StableDiffusionPipeline): + device_0 = None + device_1 = None + runtime = None + use_parallel_inferencing = False + unet_bg = None + output_dir = None + is_init = False + + def compile_aie_model(self): + if self.is_init: + return + in_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size = self.text_encoder.config.hidden_size + max_position_embeddings = 81 + + batch_size = 1 + size = batch_size * 2 + clip_compiled_path = os.path.join(self.output_dir, f"clip/clip_bs{batch_size}_compiled.ts") + self.compiled_clip_model = torch.jit.load(clip_compiled_path).eval() + + vae_compiled_path = os.path.join(self.output_dir, f"vae/vae_bs{batch_size}_compiled.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + unet_compiled_path = os.path.join(self.output_dir, f"unet/unet_bs{size}_compiled.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + + self.is_init = True + + @torch.no_grad() + def ascendie_infer( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + do_classifier_free_guidance = guidance_scale > 1.0 and self.unet.config.time_cond_proj_dim is None + + if not self.is_init: + self.compile_aie_model() + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + if i == 50: + break + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + stream = mindietorch.npu.Stream(f"npu:{self.device_0}") + with mindietorch.npu.stream(stream): + latent_model_input_npu = latent_model_input.to(f"npu:{self.device_0}") + t_npu = t[None].to(f"npu:{self.device_0}") + prompt_embeds_npu = prompt_embeds.to(f"npu:{self.device_0}") + noise_pred = self.compiled_unet_model(latent_model_input_npu, t_npu, prompt_embeds_npu) + stream.synchronize() + noise_pred = noise_pred.to("cpu") + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + image = self.compiled_vae_model(latents.to(f"npu:{self.device_0}")).to("cpu") + + image = image.clamp(0, 1).float().numpy() + + has_nsfw_concept = False + + if output_type == "pil": + image = self.numpy_to_pil(image) + return (image, has_nsfw_concept) + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.compiled_clip_model(text_input_ids.to(f"npu:{self.device_0}")).to("cpu") + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.compiled_clip_model(uncond_input.input_ids.to(f"npu:{self.device_0}")).to("cpu") + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.0/.isort.cfg b/MindIE/MultiModal/OpenSora-1.0/.isort.cfg new file mode 100644 index 0000000000000000000000000000000000000000..ccbf575fdbfacd185cf880431ad81462e0ae8fdf --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/.isort.cfg @@ -0,0 +1,7 @@ +[settings] +line_length = 120 +multi_line_output=3 +include_trailing_comma = true +ignore_comments = true +profile = black +honor_noqa = true diff --git a/MindIE/MultiModal/OpenSora-1.0/.pre-commit-config.yaml b/MindIE/MultiModal/OpenSora-1.0/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7eb6c127ae04d0ad36df126bdc067d00afdd2c66 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/.pre-commit-config.yaml @@ -0,0 +1,32 @@ +repos: + + - repo: https://github.com/PyCQA/autoflake + rev: v2.2.1 + hooks: + - id: autoflake + name: autoflake (python) + args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports'] + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: sort all imports (python) + + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.9.1 + hooks: + - id: black + name: black formatter + args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310'] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-yaml + - id: check-merge-conflict + - id: check-case-conflict + - id: trailing-whitespace + - id: end-of-file-fixer + - id: mixed-line-ending + args: ['--fix=lf'] diff --git a/MindIE/MultiModal/OpenSora-1.0/CONTRIBUTING.md b/MindIE/MultiModal/OpenSora-1.0/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..b2ef579cbb287bf85192e1bf7de4974c2aeed981 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/CONTRIBUTING.md @@ -0,0 +1,91 @@ +# Contributing + +The Open-Sora project welcomes any constructive contribution from the community and the team is more than willing to work on problems you have encountered to make it a better project. + +## Development Environment Setup + +To contribute to Open-Sora, we would like to first guide you to set up a proper development environment so that you can better implement your code. 
You can install this library from source with the `editable` flag (`-e`, for development mode) so that your changes to the source code are reflected at runtime without re-installation.
+
+You can refer to the [Installation Section](./README.md#installation) and replace `pip install -v .` with `pip install -v -e .`.
+
+
+### Code Style
+
+We run some static checks when you commit your code changes. Please make sure you can pass all the tests and that the coding style meets our requirements. We use pre-commit hooks to keep the code aligned with the writing standard. To set up the code style checking, follow the steps below.
+
+```shell
+# these commands are executed under the Open-Sora directory
+pip install pre-commit
+pre-commit install
+```
+
+Code format checking will be automatically executed when you commit your changes.
+
+
+## Contribution Guide
+
+You need to follow the steps below to contribute to the main repository via a pull request. You can learn about the details of pull requests [here](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests).
+
+### 1. Fork the Official Repository
+
+Firstly, you need to visit the [Open-Sora repository](https://github.com/hpcaitech/Open-Sora) and fork it into your own account. The `fork` button is at the top right corner of the web page, alongside buttons such as `watch` and `star`.
+
+Now you can clone your own forked repository into your local environment (replace `<your-username>` with your GitHub username).
+
+```shell
+git clone https://github.com/<your-username>/Open-Sora.git
+```
+
+### 2. Configure Git
+
+You need to set the official repository as your upstream so that you can synchronize with the latest updates in the official repository. You can learn about upstream [here](https://www.atlassian.com/git/tutorials/git-forks-and-upstreams).
+
+Then add the original repository as upstream:
+
+```shell
+cd Open-Sora
+git remote add upstream https://github.com/hpcaitech/Open-Sora.git
+```
+
+You can use the following command to verify that the remote is set. You should see both `origin` and `upstream` in the output.
+
+```shell
+git remote -v
+```
+
+### 3. Synchronize with Official Repository
+
+Before you make changes to the codebase, it is always good to fetch the latest updates from the official repository. You can use the commands below to do so.
+
+```shell
+git fetch upstream
+git checkout main
+git merge upstream/main
+git push origin main
+```
+
+### 4. Create a New Branch
+
+You should not make changes to the `main` branch of your forked repository, as this might make upstream synchronization difficult. Instead, create a new branch with an appropriate name. Branch names should generally start with `hotfix/` or `feature/`: `hotfix` is for bug fixes and `feature` is for the addition of a new feature.
+
+
+```shell
+git checkout -b <new-branch-name>
+```
+
+### 5. Implementation and Code Commit
+
+Now you can implement your code change in the source code. Remember that you installed the project in development mode, so you do not need to reinstall it for the change to take effect. The code change will be reflected in every new Python execution.
+You can commit the changes and push them to your forked repository. The changes should be kept logical, modular and atomic.
+
+```shell
+git add -A
+git commit -m "<commit-message>"
+git push -u origin <new-branch-name>
+```
+
+### 6. Open a Pull Request
+
+You can now create a pull request on the GitHub webpage of your repository.
The source branch is `` of your repository and the target branch should be `main` of `hpcaitech/Open-Sora`. After creating this pull request, you should be able to see it [here](https://github.com/hpcaitech/Open-Sora/pulls). + +The Open-Sora team will review your code change and merge your code if applicable. diff --git a/MindIE/MultiModal/OpenSora-1.0/LICENSE b/MindIE/MultiModal/OpenSora-1.0/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7327c123dd164dc24fc361a8eaf37c62125c3aa2 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/LICENSE @@ -0,0 +1,681 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + ========================================================================= + This project is inspired by the listed projects and is subject to the following licenses: + + 1. Latte (https://github.com/Vchitect/Latte/blob/main/LICENSE) + + Copyright 2024 Latte + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + 2. PixArt-alpha (https://github.com/PixArt-alpha/PixArt-alpha/blob/master/LICENSE) + + Copyright (C) 2024 PixArt-alpha/PixArt-alpha + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + + 3. dpm-solver (https://github.com/LuChengTHU/dpm-solver/blob/main/LICENSE) + + MIT License + + Copyright (c) 2022 Cheng Lu + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + 4. DiT (https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt) + + Attribution-NonCommercial 4.0 International + + ======================================================================= + + Creative Commons Corporation ("Creative Commons") is not a law firm and + does not provide legal services or legal advice. Distribution of + Creative Commons public licenses does not create a lawyer-client or + other relationship. Creative Commons makes its licenses and related + information available on an "as-is" basis. Creative Commons gives no + warranties regarding its licenses, any material licensed under their + terms and conditions, or any related information. Creative Commons + disclaims all liability for damages resulting from their use to the + fullest extent possible. + + Using Creative Commons Public Licenses + + Creative Commons public licenses provide a standard set of terms and + conditions that creators and other rights holders may use to share + original works of authorship and other material subject to copyright + and certain other rights specified in the public license below. The + following considerations are for informational purposes only, are not + exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + + ======================================================================= + + Creative Commons Attribution-NonCommercial 4.0 International Public + License + + By exercising the Licensed Rights (defined below), You accept and agree + to be bound by the terms and conditions of this Creative Commons + Attribution-NonCommercial 4.0 International Public License ("Public + License"). To the extent this Public License may be interpreted as a + contract, You are granted the Licensed Rights in consideration of Your + acceptance of these terms and conditions, and the Licensor grants You + such rights in consideration of benefits the Licensor receives from + making the Licensed Material available under these terms and + conditions. + + Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. 
Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. 
Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + Section 3 -- License Conditions. + + Your exercise of the Licensed Rights is expressly made subject to the + following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + Section 4 -- Sui Generis Database Rights. + + Where the Licensed Rights include Sui Generis Database Rights that + apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. 
+ + For the avoidance of doubt, this Section 4 supplements and does not + replace Your obligations under this Public License where the Licensed + Rights include other Copyright and Similar Rights. + + Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. 
To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + + ======================================================================= + + Creative Commons is not a party to its public + licenses. Notwithstanding, Creative Commons may elect to apply one of + its public licenses to material it publishes and in those instances + will be considered the “Licensor.” The text of the Creative Commons + public licenses is dedicated to the public domain under the CC0 Public + Domain Dedication. Except for the limited purpose of indicating that + material is shared under a Creative Commons public license or as + otherwise permitted by the Creative Commons policies published at + creativecommons.org/policies, Creative Commons does not authorize the + use of the trademark "Creative Commons" or any other trademark or logo + of Creative Commons without its prior written consent including, + without limitation, in connection with any unauthorized modifications + to any of its public licenses or any other arrangements, + understandings, or agreements concerning use of licensed material. For + the avoidance of doubt, this paragraph does not form part of the + public licenses. + + Creative Commons may be contacted at creativecommons.org. + + 5. OpenDiT (https://github.com/NUS-HPC-AI-Lab/OpenDiT/blob/master/LICENSE) + + Copyright OpenDiT + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + diff --git a/MindIE/MultiModal/OpenSora-1.0/README.md b/MindIE/MultiModal/OpenSora-1.0/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c53e0b904cf7125bddb65a90f67bbe7e990c4342 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/README.md @@ -0,0 +1,128 @@ +# OpenSora模型-推理指导 + +# 概述 + +Open Sora采用动态掩码策略等技术细节复现Sora,并已实现可变长宽比、可变分辨率和可变时长等功能。 + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 + +## 输入输出数据 + +输入一个prompt,输入一个2s长的视频 + +# 推理环境准备 + +**表 1** 版本配套表 + +| 配套 | 版本 | 环境准备指导 | +| ------------------------------------ | ------- | ------------ | +| Python | 3.10.13 | - | +| PyTorch | 2.1.0 | - | +| 硬件:Atlas 300I Duo ,Atlas 800I A2 | \ | \ | + +请以CANN版本选择对应的固件与驱动版本。 + +# 快速上手 + +## 获取源码 + +1. 安装依赖 + + ```bash + pip3 install -r requirements.txt + ``` + +2. 
安装mindie包 + + ```bash + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +## 准备数据集 + +本模型输入prompt生成视频,无需数据集。 + +## 模型推理 + +1. 下载模型 + + ST-DIT权重文件下载链接如下,按需下载: + + [ST-DIT-256x256下载链接](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x256x256.pth) + + [ST-DIT-512x512下载链接](https://huggingface.co/hpcai-tech/Open-Sora/blob/main/OpenSora-v1-HQ-16x512x512.pth) + + vae权重文件下载链接如下,按需下载: + + ```bash + # ema + git clone https://huggingface.co/stabilityai/sd-vae-ft-ema + ``` + + encoder权重文件 + + ```bash + https://huggingface.co/DeepFloyd/t5-v1_1-xxl/tree/main + ``` + +2. 模型转换,该步骤会生成编译之后的pt模型 + + ```bash + python3 export_model.py \ + --output_dir ./models \ + --encoder_model_path ./DeepFloyd--t5-v1_1-xxl \ + --dit_model_path ./OpenSora-v1-HQ-16x512x512.pth \ + --vae_model_path ./sd-vae-ft-ema \ + --resolution 16x512x512 \ + --device_id 0 + ``` + + 参数说明: + + - --encoder_model_path:encoder的权重路径 + - --dit_model_path:dit的权重路径 + - --vae_model_path:vae的权重路径 + - --resolution:分辨率。支持256和512 + - --device_id:NPU芯片 + - --output_dir:pt模型输出目录 + +3. 开始推理 + + 1. 开启cpu高性能模式 + + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 执行推理,会在当前路径生成sample.mp4 + + ```bash + python inference.py \ + ./configs/opensora/inference/16x256x256.py \ + --ckpt-path ./OpenSora-v1-HQ-16x512x512.pth \ + --prompt-path ./assets/texts/t2v_samples.txt \ + --use_mindie 1 \ + --device_id 0 + ``` + + 参数说明: + + - --ckpt-path:STDIT的权重路径 + - --prompt-path:prompt数据集的路径 + - --use_mindie:是否使用MindIE推理。1代表是,0代表否 + - --device_id:使用哪张卡 + +# 模型推理性能 + +性能参考下列数据。 + +| 分辨率 | 硬件形态 | 平均耗时 | +| ------ | -------- | -------- | +| 512 | Atlas 800I A2(8*32G) | 110.8s | +| 256 | Atlas 300I Duo | 22.2s | diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/images/imagenet/train/n01440764/n01440764_10026.JPEG b/MindIE/MultiModal/OpenSora-1.0/assets/images/imagenet/train/n01440764/n01440764_10026.JPEG new file mode 100644 index 0000000000000000000000000000000000000000..b985769e1c1f09585e67291a4926537186a40e49 Binary files /dev/null and b/MindIE/MultiModal/OpenSora-1.0/assets/images/imagenet/train/n01440764/n01440764_10026.JPEG differ diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/images/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG b/MindIE/MultiModal/OpenSora-1.0/assets/images/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG new file mode 100644 index 0000000000000000000000000000000000000000..1b332471a78cbb3e362a0871d8e2dfad14320910 Binary files /dev/null and b/MindIE/MultiModal/OpenSora-1.0/assets/images/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG differ diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/imagenet_id.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/imagenet_id.txt new file mode 100644 index 0000000000000000000000000000000000000000..9085aa0034c05cc60e40b1f14be1bb4a2a171d2f --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/imagenet_id.txt @@ -0,0 +1,8 @@ +207 +360 +387 +974 +88 +979 +417 +279 diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/imagenet_labels.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/imagenet_labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..6493fdbf907465063a2cee904fe3994a90d420cd --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/imagenet_labels.txt @@ -0,0 +1,8 @@ +golden retriever +otter +lesser panda +geyser +macaw +valley +balloon +golden 
panda diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2i_samples.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2i_samples.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b729527cee2d4da1d28415e42c52c6627217d10 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2i_samples.txt @@ -0,0 +1,8 @@ +A small cactus with a happy face in the Sahara desert. +Bright scene, aerial view,ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens. +Nature vs human nature, surreal, UHD, 8k, hyper details, rich colors, photograph. +Poster of a mechanical cat, techical Schematics viewed from front. +Luffy from ONEPIECE, handsome face, fantasy. +Real beautiful woman. +A alpaca made of colorful building blocks, cyberpunk. +artistic diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_latte.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_latte.txt new file mode 100644 index 0000000000000000000000000000000000000000..a61359ca41325db817d8eb2c6a255d997f0382ca --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_latte.txt @@ -0,0 +1,7 @@ +Yellow and black tropical fish dart through the sea. +An epic tornado attacking above aglowing city at night. +Slow pan upward of blazing oak fire in an indoor fireplace. +a cat wearing sunglasses and working as a lifeguard at pool. +Sunset over the sea. +A dog in astronaut suit and sunglasses floating in space. +A astronaut in flying in space, 4k, high resolution diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_samples.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_samples.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f40f58206847e4037afe5bd87ac5d3e4d248a6e --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_samples.txt @@ -0,0 +1,10 @@ +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. 
As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. 
The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_samples_bk.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_samples_bk.txt new file mode 100644 index 0000000000000000000000000000000000000000..312db4603e3692f98dc9883236bee7d031276022 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_samples_bk.txt @@ -0,0 +1,10 @@ +A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. +The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty. +A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. 
Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains. +The vibrant beauty of a sunflower field. The sunflowers, with their bright yellow petals and dark brown centers, are in full bloom, creating a stunning contrast against the green leaves and stems. The sunflowers are arranged in neat rows, creating a sense of order and symmetry. The sun is shining brightly, casting a warm glow on the flowers and highlighting their intricate details. The video is shot from a low angle, looking up at the sunflowers, which adds a sense of grandeur and awe to the scene. The sunflowers are the main focus of the video, with no other objects or people present. The video is a celebration of nature's beauty and the simple joy of a sunny day in the countryside. +A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene. The video is shot from a slightly elevated angle, providing a comprehensive view of the turtle's surroundings. The overall style of the video is calm and peaceful, capturing the beauty and tranquility of the underwater world. +A vibrant underwater scene. A group of blue fish, with yellow fins, are swimming around a coral reef. The coral reef is a mix of brown and green, providing a natural habitat for the fish. The water is a deep blue, indicating a depth of around 30 feet. The fish are swimming in a circular pattern around the coral reef, indicating a sense of motion and activity. The overall scene is a beautiful representation of marine life. +A bustling city street at night, filled with the glow of car headlights and the ambient light of streetlights. The scene is a blur of motion, with cars speeding by and pedestrians navigating the crosswalks. The cityscape is a mix of towering buildings and illuminated signs, creating a vibrant and dynamic atmosphere. The perspective of the video is from a high angle, providing a bird's eye view of the street and its surroundings. The overall style of the video is dynamic and energetic, capturing the essence of urban life at night. +A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road. 
+The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements. +A serene night scene in a forested area. The first frame shows a tranquil lake reflecting the star-filled sky above. The second frame reveals a beautiful sunset, casting a warm glow over the landscape. The third frame showcases the night sky, filled with stars and a vibrant Milky Way galaxy. The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. The style of the video is naturalistic, emphasizing the beauty of the night sky and the peacefulness of the forest. diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_sora.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_sora.txt new file mode 100644 index 0000000000000000000000000000000000000000..eeb887b1863e590e45054fb766694c1275cee987 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/t2v_sora.txt @@ -0,0 +1,48 @@ +A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. +Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. +A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors. +Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. +Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image. +A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures. 
+This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird’s head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird’s striking appearance. +Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. +A young man at his 20s is sitting on a piece of cloud in the sky, reading a book. +Historical footage of California during the gold rush. +A close up view of a glass sphere that has a zen garden within it. There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand. +Extreme close up of a 24 year old woman’s eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic +A cartoon kangaroo disco dances. +A beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera. +A petri dish with a bamboo forest growing within it that has tiny red pandas running around. +The camera rotates around a large stack of vintage televisions all showing different programs — 1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery. +3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest. +The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. +Reflections in the window of a train traveling through the Tokyo suburbs. +A drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast Italy, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography. 
+A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect. +A flock of paper airplanes flutters through a dense jungle, weaving around trees as if they were migrating birds. +A cat waking up its sleeping owner demanding breakfast. The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer. +Borneo wildlife on the Kinabatangan River +A Chinese Lunar New Year celebration video with Chinese Dragon. +Tour of an art gallery with many beautiful works of art in different styles. +Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes. +A stop motion animation of a flower growing out of the windowsill of a suburban house. +The story of a robot’s life in a cyberpunk setting. +An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film. +A beautiful silhouette animation shows a wolf howling at the moon, feeling lonely, until it finds its pack. +New York City submerged like Atlantis. Fish, whales, sea turtles and sharks swim through the streets of New York. +A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. +Step-printing scene of a person running, cinematic film shot in 35mm. +Five gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing. +Basketball through hoop then explodes. +Archeologists discover a generic plastic chair in the desert, excavating and dusting it with great care. +A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. 
The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood. +The camera directly faces colorful buildings in Burano Italy. An adorable dalmation looks through a window on a building on the ground floor. Many people are walking and cycling along the canal streets in front of the buildings. +An adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands, 3D digital render art style. +This close-up shot of a chameleon showcases its striking color changing capabilities. The background is blurred, drawing attention to the animal’s striking appearance. +A corgi vlogging itself in tropical Maui. +A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field. +Aerial view of Santorini during the blue hour, showcasing the stunning architecture of white Cycladic buildings with blue domes. The caldera views are breathtaking, and the lighting creates a beautiful, serene atmosphere. +Tiltshift of a construction site filled with workers, equipment, and heavy machinery. +A giant, towering cloud in the shape of a man looms over the earth. The cloud man shoots lighting bolts down to the earth. +A Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glistens off of their fur. +The Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort William. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop for the train journey. The sky is blue and the sun is shining, making for a beautiful day to explore this majestic spot. 
diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/ucf101_id.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/ucf101_id.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8371f00609f33a59378dd2f6bb4385a7df8bd63 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/ucf101_id.txt @@ -0,0 +1,6 @@ +0 +1 +2 +3 +4 +5 diff --git a/MindIE/MultiModal/OpenSora-1.0/assets/texts/ucf101_labels.txt b/MindIE/MultiModal/OpenSora-1.0/assets/texts/ucf101_labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..264dbfd8837a4b89b81d05b06c48b567dfa1d150 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/assets/texts/ucf101_labels.txt @@ -0,0 +1,6 @@ +Apply Eye Makeup +Apply Lipstick +Archery +Baby Crawling +Balance Beam +Band Marching diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb1d796824c0b459b569e44d5ab66543814d748 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/16x256x256.py @@ -0,0 +1,31 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="DiT-XL/2", + condition="text", + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/ucf101_labels.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/1x256x256-class.py b/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/1x256x256-class.py new file mode 100644 index 0000000000000000000000000000000000000000..24d1c8af390a408bf3d43ef4cd9c87d18d3fea2b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/1x256x256-class.py @@ -0,0 +1,31 @@ +num_frames = 1 +fps = 1 +image_size = (256, 256) + +# Define model +model = dict( + type="DiT-XL/2", + no_temporal_pos_emb=True, + condition="label_1000", + from_pretrained="DiT-XL-2-256x256.pt", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="classes", + num_classes=1000, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/imagenet_id.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/1x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/1x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..31a5b9f1f2f315b19b528b2c4b98cfeb8b213c58 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/dit/inference/1x256x256.py @@ -0,0 +1,32 @@ +num_frames = 1 +fps = 1 +image_size = (256, 256) + +# Define model +model = dict( + type="DiT-XL/2", + no_temporal_pos_emb=True, + condition="text", + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "fp16" + 
+# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/imagenet_labels.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/dit/train/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/dit/train/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..af8ee8768af253ee124e2679706ea4320bb97def --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/dit/train/16x256x256.py @@ -0,0 +1,50 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = False +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="DiT-XL/2", + from_pretrained="DiT-XL-2-256x256.pt", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/dit/train/1x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/dit/train/1x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..667e0a835652d25c41fbd1d7947e65291972f49c --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/dit/train/1x256x256.py @@ -0,0 +1,50 @@ +num_frames = 1 +frame_interval = 1 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = True +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = False +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="DiT-XL/2", + no_temporal_pos_emb=True, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 128 +lr = 1e-4 # according to DiT repo +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/latte/inference/16x256x256-class.py b/MindIE/MultiModal/OpenSora-1.0/configs/latte/inference/16x256x256-class.py new file mode 100644 index 0000000000000000000000000000000000000000..c46f4bc362f60effbb80c74e4cea3662d39302a1 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/latte/inference/16x256x256-class.py @@ -0,0 +1,30 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="Latte-XL/2", + condition="label_101", + from_pretrained="Latte-XL-2-256x256-ucf101.pt", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="classes", + num_classes=101, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/ucf101_id.txt" +save_dir = "./outputs/samples/" diff --git 
a/MindIE/MultiModal/OpenSora-1.0/configs/latte/inference/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/latte/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..cb502371d39b9324084bcda151d0a168e69fafaf --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/latte/inference/16x256x256.py @@ -0,0 +1,31 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="Latte-XL/2", + condition="text", + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=4.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/ucf101_labels.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/latte/train/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/latte/train/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf6bd4126c8517d526c2af1b75d5af8a1660df0 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/latte/train/16x256x256.py @@ -0,0 +1,49 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="Latte-XL/2", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="clip", + from_pretrained="openai/clip-vit-base-patch32", + model_max_length=77, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..f8608e475d96466a8df3bf2fa6cbee27cabbf6b9 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/16x256x256.py @@ -0,0 +1,30 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (256, 256) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + enable_flashattn=False, + enable_layernorm_kernel=False, + from_pretrained="PRETRAINED_MODEL", +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, +) +dtype = "fp32" + +# Others +batch_size = 1 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./outputs/samples/" +vae_path = "stabilityai/sd-vae-ft-ema" +t5_path = "./t5-v1_1-xxl" +use_mindie = 1 +device_id = 0 +output_dir="./models" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/16x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/16x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..561c1563c1a33c23e329084fe882865da4a74a1b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/16x512x512.py @@ -0,0 +1,37 @@ +num_frames = 16 +fps = 24 // 3 +image_size = (512, 512) + +# Define model +model = dict( 
+ type="STDiT-XL/2", + space_scale=1.0, + time_scale=1.0, + enable_flashattn=False, + enable_layernorm_kernel=False, + from_pretrained="PRETRAINED_MODEL" +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", # 待修改 + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", # 待修改 + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, +) +dtype = "fp32" + +# Others +batch_size = 1 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./outputs/samples/" +use_mindie = 0 +device_id = 0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/64x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/64x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..e15649a35c8205d162cdd6873808e8737f8afb25 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/inference/64x512x512.py @@ -0,0 +1,35 @@ +num_frames = 64 +fps = 24 // 2 +image_size = (512, 512) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + enable_flashattn=True, + enable_layernorm_kernel=True, + from_pretrained="PRETRAINED_MODEL", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, +) +scheduler = dict( + type="iddpm", + num_sampling_steps=100, + cfg_scale=7.0, +) +dtype = "fp16" + +# Others +batch_size = 1 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a318f0c72f2786690c2631eb43884898684d8 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/16x256x256.py @@ -0,0 +1,53 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="PixArt-XL-2-512x512.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/16x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/16x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..885aad1fed966acddfa9ce609c65b24449cc9c05 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/16x512x512.py @@ -0,0 +1,54 @@ +num_frames = 16 +frame_interval = 3 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration 
+dtype = "bf16" +grad_checkpoint = False +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=1.0, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 500 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/360x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/360x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6f75995b96152a80ad14e6a40f4b1e2482c1e9 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/360x512x512.py @@ -0,0 +1,55 @@ +num_frames = 360 +frame_interval = 1 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2-seq" +sp_size = 2 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, + enable_sequence_parallelism=True, # enable sq here +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 250 +load = None + +batch_size = 1 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/64x512x512-sp.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/64x512x512-sp.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b9062c987e7e90c75e5e1d2064fe8654e22b46 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/64x512x512-sp.py @@ -0,0 +1,54 @@ +num_frames = 64 +frame_interval = 2 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2-seq" +sp_size = 2 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, + enable_sequence_parallelism=True, # enable sq here +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 1 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/64x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/64x512x512.py new 
file mode 100644 index 0000000000000000000000000000000000000000..dfcdcc08d250e0a1d23ece174c023975309d2ae1 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/opensora/train/64x512x512.py @@ -0,0 +1,54 @@ +num_frames = 64 +frame_interval = 2 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=64, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 250 +load = None + +batch_size = 4 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc8ee653c0fa23b76c29f17e728e263c738c2ea --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/16x256x256.py @@ -0,0 +1,32 @@ +num_frames = 16 +fps = 8 +image_size = (256, 256) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="outputs/098-F16S3-PixArt-XL-2/epoch7-global_step30000/model_ckpt.pt", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2v_samples.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x1024MS.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x1024MS.py new file mode 100644 index 0000000000000000000000000000000000000000..41cc97ad0402d54610302ff6c10a7a4630d1f15b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x1024MS.py @@ -0,0 +1,34 @@ +num_frames = 1 +fps = 1 +image_size = (1920, 512) +multi_resolution = True + +# Define model +model = dict( + type="PixArtMS-XL/2", + space_scale=2.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-1024-MS.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2i_samples.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..11e06d777af8450b6610e4e99f29e10c548900c1 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x256x256.py @@ 
-0,0 +1,33 @@ +num_frames = 1 +fps = 1 +image_size = (256, 256) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-256x256.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2i_samples.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..5674259b5a36afc48384b8170fcb978a43717753 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/inference/1x512x512.py @@ -0,0 +1,33 @@ +num_frames = 1 +fps = 1 +image_size = (512, 512) + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-512x512.pth", +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, +) +scheduler = dict( + type="dpm-solver", + num_sampling_steps=20, + cfg_scale=7.0, +) +dtype = "fp16" + +# Others +batch_size = 2 +seed = 42 +prompt_path = "./assets/texts/t2i_samples.txt" +save_dir = "./outputs/samples/" diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/16x256x256.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/16x256x256.py new file mode 100644 index 0000000000000000000000000000000000000000..b47731e2d5fcb1418c23b68442ee1cae54425726 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/16x256x256.py @@ -0,0 +1,53 @@ +num_frames = 16 +frame_interval = 3 +image_size = (256, 256) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = False +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=0.5, + time_scale=1.0, + from_pretrained="PixArt-XL-2-512x512.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 8 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/1x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/1x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..619c9aafd03a68a36815b5bbc7d12d59c3ea40c6 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/1x512x512.py @@ -0,0 +1,54 @@ +num_frames = 1 +frame_interval = 1 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = True +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" 
+sp_size = 1 + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=1.0, + no_temporal_pos_emb=True, + from_pretrained="PixArt-XL-2-512x512.pth", + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 1000 +load = None + +batch_size = 32 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/64x512x512.py b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/64x512x512.py new file mode 100644 index 0000000000000000000000000000000000000000..628cf254fe3d379e4fe6661d62ddad6511003abc --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/configs/pixart/train/64x512x512.py @@ -0,0 +1,54 @@ +num_frames = 64 +frame_interval = 2 +image_size = (512, 512) + +# Define dataset +root = None +data_path = "CSV_PATH" +use_image_transform = False +num_workers = 4 + +# Define acceleration +dtype = "bf16" +grad_checkpoint = True +plugin = "zero2" +sp_size = 1 + +# Define model +model = dict( + type="PixArt-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained=None, + enable_flashattn=True, + enable_layernorm_kernel=True, +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, +) +scheduler = dict( + type="iddpm", + timestep_respacing="", +) + +# Others +seed = 42 +outputs = "outputs" +wandb = False + +epochs = 1000 +log_every = 10 +ckpt_every = 250 +load = None + +batch_size = 4 +lr = 2e-5 +grad_clip = 1.0 diff --git a/MindIE/MultiModal/OpenSora-1.0/docs/README_zh.md b/MindIE/MultiModal/OpenSora-1.0/docs/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/docs/acceleration.md b/MindIE/MultiModal/OpenSora-1.0/docs/acceleration.md new file mode 100644 index 0000000000000000000000000000000000000000..3a0a68eb7095c8d2d668ffe8edabc6a3dfc5628d --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/docs/acceleration.md @@ -0,0 +1,57 @@ +# Acceleration + +Open-Sora aims to provide a high-speed training framework for diffusion models. We can achieve **55%** training speed acceleration when training on **64 frames 512x512 videos**. Our framework support training **1min 1080p videos**. + +## Accelerated Transformer + +Open-Sora boosts the training speed by: + +- Kernal optimization including [flash attention](https://github.com/Dao-AILab/flash-attention), fused layernorm kernal, and the ones compiled by colossalAI. +- Hybrid parallelism including ZeRO. +- Gradient checkpointing for larger batch size. + +Our training speed on images is comparable to [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT), an project to accelerate DiT training. The training speed is measured on 8 H800 GPUs with batch size 128, image size 256x256. 
+ +| Model | Throughput (img/s/GPU) | Throughput (tokens/s/GPU) | +| -------- | ---------------------- | ------------------------- | +| DiT | 100 | 26k | +| OpenDiT | 175 | 45k | +| OpenSora | 175 | 45k | + +## Efficient STDiT + +Our STDiT adopts spatial-temporal attention to model the video data. Compared with directly applying full attention on DiT, our STDiT is more efficient as the number of frames increases. Our current framework only supports sequence parallelism for very long sequences. + +The training speed is measured on 8 H800 GPUs with acceleration techniques applied; GC means gradient checkpointing. Both models use T5 conditioning like PixArt. + +| Model | Setting | Throughput (sample/s/GPU) | Throughput (tokens/s/GPU) | +| ---------------- | -------------- | ------------------------- | ------------------------- | +| DiT | 16x256 (4k) | 7.20 | 29k | +| STDiT | 16x256 (4k) | 7.00 | 28k | +| DiT | 16x512 (16k) | 0.85 | 14k | +| STDiT | 16x512 (16k) | 1.45 | 23k | +| DiT (GC) | 64x512 (65k) | 0.08 | 5k | +| STDiT (GC) | 64x512 (65k) | 0.40 | 25k | +| STDiT (GC, sp=2) | 360x512 (370k) | 0.10 | 18k | + +With a 4x downsampling in the temporal dimension with Video-VAE, a 24fps video has 450 frames. The gap between the speed of STDiT (28k tokens/s) and DiT on images (up to 45k tokens/s) mainly comes from the T5 and VAE encoding, and temporal attention. + +## Accelerated Encoder (T5, VAE) + +During training, texts are encoded by T5, and videos are encoded by VAE. Typically, there are two ways to accelerate the training: + +1. Preprocess text and video data in advance and save them to disk. +2. Encode text and video data during training, and accelerate the encoding process. + +For option 1, 120 tokens for one sample require 1MB of disk space, and a 64x64x64 latent requires 4MB. Considering a training dataset with 10M video clips, the total disk space required is 50TB. Our storage system is not ready at this time for this scale of data. + +For option 2, we boost T5 speed and reduce its memory requirement. According to [OpenDiT](https://github.com/NUS-HPC-AI-Lab/OpenDiT), we find that VAE consumes a large amount of GPU memory. Thus we split the batch into smaller micro-batches for VAE encoding. With both techniques, we can greatly accelerate the training. + +The training speed is measured on 8 H800 GPUs with STDiT. + +| Acceleration | Setting | Throughput (img/s/GPU) | Throughput (tokens/s/GPU) | +| ------------ | ------------- | ---------------------- | ------------------------- | +| Baseline | 16x256 (4k) | 6.16 | 25k | +| w. faster T5 | 16x256 (4k) | 7.00 | 29k | +| Baseline | 64x512 (65k) | 0.94 | 15k | +| w. both | 64x512 (65k) | 1.45 | 23k | diff --git a/MindIE/MultiModal/OpenSora-1.0/docs/commands.md b/MindIE/MultiModal/OpenSora-1.0/docs/commands.md new file mode 100644 index 0000000000000000000000000000000000000000..28ee285de143c9d5baf56d513728249f7aa82730 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/docs/commands.md @@ -0,0 +1,91 @@ +# Commands + +## Inference + +You can modify corresponding config files to change the inference settings. See more details [here](/docs/structure.md#inference-config-demos). + +### Inference with DiT pretrained on ImageNet + +The following command automatically downloads the pretrained weights on ImageNet and runs inference.
+
+```bash
+python scripts/inference.py configs/dit/inference/1x256x256-class.py --ckpt-path DiT-XL-2-256x256.pt
+```
+
+### Inference with Latte pretrained on UCF101
+
+The following command automatically downloads the pretrained weights on UCF101 and runs inference.
+
+```bash
+python scripts/inference.py configs/latte/inference/16x256x256-class.py --ckpt-path Latte-XL-2-256x256-ucf101.pt
+```
+
+### Inference with PixArt-α pretrained weights
+
+Download T5 into `./pretrained_models` and run the following command.
+
+```bash
+# 256x256
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x256x256.py --ckpt-path PixArt-XL-2-256x256.pth
+
+# 512x512
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x512x512.py --ckpt-path PixArt-XL-2-512x512.pth
+
+# 1024 multi-scale
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/pixart/inference/1x1024MS.py --ckpt-path PixArt-XL-2-1024MS.pth
+```
+
+### Inference with checkpoints saved during training
+
+During training, an experiment logging folder is created in the `outputs` directory. Under each checkpoint folder, e.g. `epoch12-global_step2000`, there is an `ema.pt` file and the shared `model` folder. Run the following command to perform inference.
+
+```bash
+# inference with ema model
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000/ema.pt
+
+# inference with model
+torchrun --standalone --nproc_per_node 1 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
+
+# inference with sequence parallelism
+# sequence parallelism is enabled automatically when nproc_per_node is larger than 1
+torchrun --standalone --nproc_per_node 2 scripts/inference.py configs/opensora/inference/16x256x256.py --ckpt-path outputs/001-STDiT-XL-2/epoch12-global_step2000
+```
+
+The second command will automatically generate a `model_ckpt.pt` file in the checkpoint folder.
+
+### Inference Hyperparameters
+
+1. DPM-Solver is good at fast inference for images. However, the video results are not satisfactory. You can use it for fast demo purposes.
+
+```python
+type="dpm-solver"
+num_sampling_steps=20
+```
+
+2. You can use [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)'s finetuned VAE decoder on videos for inference (it consumes more memory). However, we do not see a significant improvement in the video results. To use it, download [the pretrained weights](https://huggingface.co/maxin-cn/Latte/tree/main/t2v_required_models/vae_temporal_decoder) into `./pretrained_models/vae_temporal_decoder` and modify the config file as follows.
+
+```python
+vae = dict(
+    type="VideoAutoencoderKLTemporalDecoder",
+    from_pretrained="pretrained_models/vae_temporal_decoder",
+)
+```
+
+## Training
+
+To resume training, run the following command. ``--load`` differs from ``--ckpt-path`` in that it also loads the optimizer and dataloader states.
+
+```bash
+torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py configs/opensora/train/64x512x512.py --data-path YOUR_CSV_PATH --load YOUR_PRETRAINED_CKPT
+```
+
+To enable wandb logging, add `--wandb` to the command.
+
+```bash
+WANDB_API_KEY=YOUR_WANDB_API_KEY torchrun --nnodes=1 --nproc_per_node=8 scripts/train.py configs/opensora/train/64x512x512.py --data-path YOUR_CSV_PATH --wandb True
+```
+
+You can modify the corresponding config files to change the training settings. See more details [here](/docs/structure.md#training-config-demos).
+
+### Training Hyperparameters
+
+1. `dtype` is the data type for training. Only `fp16` and `bf16` are supported. ColossalAI automatically enables mixed-precision training for `fp16` and `bf16`. During training, we find `bf16` more stable.
diff --git a/MindIE/MultiModal/OpenSora-1.0/docs/datasets.md b/MindIE/MultiModal/OpenSora-1.0/docs/datasets.md
new file mode 100644
index 0000000000000000000000000000000000000000..c06835b03c67126dc5c1f11062340dcff60d9a21
--- /dev/null
+++ b/MindIE/MultiModal/OpenSora-1.0/docs/datasets.md
@@ -0,0 +1,28 @@
+# Datasets
+
+## Datasets used for now
+
+### HD-VG-130M
+
+[HD-VG-130M](https://github.com/daooshee/HD-VG-130M?tab=readme-ov-file) comprises 130M text-video pairs. The captions are generated by BLIP-2. We find both the cuts and the text quality relatively poor. It contains 20 splits; for OpenSora 1.0, we use the first split. We plan to use the whole dataset and re-process it.
+
+### Inter4K
+
+[Inter4K](https://github.com/alexandrosstergiou/Inter4K) is a dataset containing 1k video clips in 4K resolution. The dataset was proposed for super-resolution tasks. We use it for HQ training. The videos are processed as mentioned [here](/README.md#data-processing).
+
+### Pexels.com
+
+[Pexels.com](https://www.pexels.com/) is a website that provides free stock photos and videos. We collect 19K video clips from this website for HQ training. The videos are processed as mentioned [here](/README.md#data-processing).
+
+## Datasets watching list
+
+We are also watching the following datasets and considering using them in the future, depending on our disk space and the quality of each dataset.
+
+| Name              | Size         | Description                   |
+| ----------------- | ------------ | ----------------------------- |
+| Panda-70M         | 70M videos   | High quality video-text pairs |
+| WebVid-10M        | 10M videos   | Low quality                   |
+| InternVid-10M-FLT | 10M videos   |                               |
+| EGO4D             | 3670 hours   |                               |
+| OpenDV-YouTube    | 1700 hours   |                               |
+| VidProM           | 6.69M videos |                               |
diff --git a/MindIE/MultiModal/OpenSora-1.0/docs/report_v1.md b/MindIE/MultiModal/OpenSora-1.0/docs/report_v1.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3b8073cb2aa5fd257664c6947b9187683c35e6a
--- /dev/null
+++ b/MindIE/MultiModal/OpenSora-1.0/docs/report_v1.md
@@ -0,0 +1,47 @@
+# Open-Sora v1 Report
+
+OpenAI's Sora is amazing at generating one-minute, high-quality videos. However, it reveals almost no information about its details. To make AI more "open", we are dedicated to building an open-source version of Sora. This report describes our first attempt to train a transformer-based video diffusion model.
+
+## Efficiency in choosing the architecture
+
+To lower the computational cost, we want to utilize existing VAE models. Sora uses a spatial-temporal VAE to reduce the temporal dimension. However, we found that there is no open-source high-quality spatial-temporal VAE model. [MAGVIT](https://github.com/google-research/magvit)'s 4x4x4 VAE is not open-sourced, while [VideoGPT](https://wilson1yan.github.io/videogpt/index.html)'s 2x4x4 VAE shows low quality in our experiments. Thus, we decided to use a 2D VAE (from [Stability-AI](https://huggingface.co/stabilityai/sd-vae-ft-mse-original)) in our first version.
+
+Video training involves a large number of tokens. Considering a 24fps 1min video, we have 1440 frames. With VAE downsampling 4x and patch size downsampling 2x, we have 1440x1024≈1.5M tokens. Full attention on 1.5M tokens leads to a huge computational cost. Thus, we use spatial-temporal attention to reduce the cost, following [Latte](https://github.com/Vchitect/Latte).
+
+As shown in the figure, we insert a temporal attention layer right after each spatial attention layer in STDiT (ST stands for spatial-temporal). This is similar to variant 3 in Latte's paper, although we do not control for a similar number of parameters across these variants. While Latte's paper claims their variant is better than variant 3, our experiments on 16x256x256 videos show that, with the same number of iterations, the performance ranks as: DiT (full) > STDiT (Sequential) > STDiT (Parallel) ≈ Latte. Thus, we choose STDiT (Sequential) for its efficiency. A speed benchmark is provided [here](/docs/acceleration.md#efficient-stdit).
+
+![Architecture Comparison](https://i0.imgs.ovh/2024/03/15/eLk9D.png)
+
+To focus on video generation, we hope to train the model based on a powerful image generation model. [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) is an efficiently trained, high-quality image generation model with a T5-conditioned DiT structure. We initialize our model with PixArt-α and initialize the projection layer of each inserted temporal attention with zero. This initialization preserves the model's image generation ability at the beginning of training, which Latte's architecture cannot. The inserted attention increases the number of parameters from 580M to 724M.
+
+![Architecture](https://i0.imgs.ovh/2024/03/16/erC1d.png)
+
+Drawing from the success of PixArt-α and Stable Video Diffusion, we also adopt a progressive training strategy: 16x256x256 on 366K pretraining video clips, and then 16x256x256, 16x512x512, and 64x512x512 on 20K video clips. With scaled position embeddings, this strategy greatly reduces the computational cost.
+
+We also tried to use a 3D patch embedder in DiT. However, with 2x downsampling on the temporal dimension, the generated videos have low quality. Thus, we leave the downsampling to a temporal VAE in our next version. For now, we sample every 3 frames for 16-frame training and every 2 frames for 64-frame training.
+
+## Data is the key to high quality
+
+We find that the number and quality of data have a great impact on the quality of the generated videos, even larger than that of the model architecture and training strategy. At this time, we only prepared the first split (366K video clips) from [HD-VG-130M](https://github.com/daooshee/HD-VG-130M). The quality of these videos varies greatly, and the captions are not that accurate. Thus, we further collected 20K relatively high-quality videos from [Pexels](https://www.pexels.com/), which provides videos under a free license. We caption the videos with LLaVA, an image captioning model, using three frames and a designed prompt. With the designed prompt, LLaVA generates good-quality captions.
+
+![Caption](https://i0.imgs.ovh/2024/03/16/eXdvC.png)
+
+As we lay more emphasis on the quality of data, we plan to collect more data and build a video preprocessing pipeline in our next version.
+
+## Training Details
+
+With a limited training budget, we made only a few explorations. We find a learning rate of 1e-4 too large and scale it down to 2e-5. When training with a large batch size, we find `fp16` less stable than `bf16`; it may lead to generation failure. Thus, we switch to `bf16` for training on 64x512x512. For other hyper-parameters, we follow previous works.
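+
+For reference, the choices above map directly onto training config fields already used in this repo (a minimal sketch mirroring the training config demo in [structure.md](/docs/structure.md#training-config-demos); see the full files under `configs/` for the remaining fields):
+
+```python
+# Key settings from the training details above (sketch, not a complete config)
+dtype = "bf16"    # bf16 proved more stable than fp16 at large batch sizes
+plugin = "zero2"  # ZeRO-based hybrid parallelism via ColossalAI
+lr = 2e-5         # scaled down from 1e-4, which was too large
+grad_clip = 1.0   # gradient clipping
+```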
+
+## Loss curves
+
+16x256x256 Pretraining Loss Curve
+
+![16x256x256 Pretraining Loss Curve](https://i0.imgs.ovh/2024/03/16/erXQj.png)
+
+16x256x256 HQ Training Loss Curve
+
+![16x256x256 HQ Training Loss Curve](https://i0.imgs.ovh/2024/03/16/ernXv.png)
+
+16x512x512 HQ Training Loss Curve
+
+![16x512x512 HQ Training Loss Curve](https://i0.imgs.ovh/2024/03/16/erHBe.png)
diff --git a/MindIE/MultiModal/OpenSora-1.0/docs/structure.md b/MindIE/MultiModal/OpenSora-1.0/docs/structure.md
new file mode 100644
index 0000000000000000000000000000000000000000..0fc087ee45b7d82c2009c4dae198fa91b9896d1a
--- /dev/null
+++ b/MindIE/MultiModal/OpenSora-1.0/docs/structure.md
@@ -0,0 +1,178 @@
+# Repo & Config Structure
+
+## Repo Structure
+
+```plaintext
+Open-Sora
+├── README.md
+├── docs
+│   ├── acceleration.md  -> Acceleration & Speed benchmark
+│   ├── commands.md      -> Commands for training & inference
+│   ├── datasets.md      -> Datasets used in this project
+│   ├── structure.md     -> This file
+│   └── report_v1.md     -> Report for Open-Sora v1
+├── scripts
+│   ├── train.py         -> diffusion training script
+│   └── inference.py     -> diffusion inference script
+├── configs              -> Configs for training & inference
+├── opensora
+│   ├── __init__.py
+│   ├── registry.py      -> Registry helper
+│   ├── acceleration     -> Acceleration related code
+│   ├── dataset          -> Dataset related code
+│   ├── models
+│   │   ├── layers       -> Common layers
+│   │   ├── vae          -> VAE as image encoder
+│   │   ├── text_encoder -> Text encoder
+│   │   │   ├── classes.py -> Class id encoder (inference only)
+│   │   │   ├── clip.py    -> CLIP encoder
+│   │   │   └── t5.py      -> T5 encoder
+│   │   ├── dit
+│   │   ├── latte
+│   │   ├── pixart
+│   │   └── stdit        -> Our STDiT related code
+│   ├── schedulers       -> Diffusion schedulers
+│   │   ├── iddpm        -> IDDPM for training and inference
+│   │   └── dpms         -> DPM-Solver for fast inference
+│   └── utils
+└── tools                -> Tools for data processing and more
+```
+
+## Configs
+
+Our config files follow [MMEngine](https://github.com/open-mmlab/mmengine). MMEngine reads a config file (a `.py` file) and parses it into a dictionary-like object.
+
+```plaintext
+Open-Sora
+└── configs -> Configs for training & inference
+    ├── opensora -> STDiT related configs
+    │   ├── inference
+    │   │   ├── 16x256x256.py -> Sample videos 16 frames 256x256
+    │   │   ├── 16x512x512.py -> Sample videos 16 frames 512x512
+    │   │   └── 64x512x512.py -> Sample videos 64 frames 512x512
+    │   └── train
+    │       ├── 16x256x256.py -> Train on videos 16 frames 256x256
+    │       ├── 16x512x512.py -> Train on videos 16 frames 512x512
+    │       └── 64x512x512.py -> Train on videos 64 frames 512x512
+    ├── dit -> DiT related configs
+    │   ├── inference
+    │   │   ├── 1x256x256-class.py -> Sample images with ckpts from DiT
+    │   │   ├── 1x256x256.py -> Sample images with clip condition
+    │   │   └── 16x256x256.py -> Sample videos
+    │   └── train
+    │       ├── 1x256x256.py -> Train on images with clip condition
+    │       └── 16x256x256.py -> Train on videos
+    ├── latte -> Latte related configs
+    └── pixart -> PixArt related configs
+```
+
+## Inference config demos
+
+To change the inference settings, you can directly modify the corresponding config file, or pass arguments to override it ([config_utils.py](/opensora/utils/config_utils.py)). To change the sampling prompts, modify the `.txt` file passed to the `--prompt_path` argument.
+ +```plaintext +--prompt_path ./assets/texts/t2v_samples.txt -> prompt_path +--ckpt-path ./path/to/your/ckpt.pth -> model["from_pretrained"] +``` + +The explanation of each field is provided below. + +```python +# Define sampling size +num_frames = 64 # number of frames +fps = 24 // 2 # frames per second (divided by 2 for frame_interval=2) +image_size = (512, 512) # image size (height, width) + +# Define model +model = dict( + type="STDiT-XL/2", # Select model type (STDiT-XL/2, DiT-XL/2, etc.) + space_scale=1.0, # (Optional) Space positional encoding scale (new height / old height) + time_scale=2 / 3, # (Optional) Time positional encoding scale (new frame_interval / old frame_interval) + enable_flashattn=True, # (Optional) Speed up training and inference with flash attention + enable_layernorm_kernel=True, # (Optional) Speed up training and inference with fused kernel + from_pretrained="PRETRAINED_MODEL", # (Optional) Load from pretrained model + no_temporal_pos_emb=True, # (Optional) Disable temporal positional encoding (for image) +) +vae = dict( + type="VideoAutoencoderKL", # Select VAE type + from_pretrained="stabilityai/sd-vae-ft-ema", # Load from pretrained VAE + micro_batch_size=128, # VAE with micro batch size to save memory +) +text_encoder = dict( + type="t5", # Select text encoder type (t5, clip) + from_pretrained="./pretrained_models/t5_ckpts", # Load from pretrained text encoder + model_max_length=120, # Maximum length of input text +) +scheduler = dict( + type="iddpm", # Select scheduler type (iddpm, dpm-solver) + num_sampling_steps=100, # Number of sampling steps + cfg_scale=7.0, # hyper-parameter for classifier-free diffusion +) +dtype = "fp16" # Computation type (fp16, fp32, bf16) + +# Other settings +batch_size = 1 # batch size +seed = 42 # random seed +prompt_path = "./assets/texts/t2v_samples.txt" # path to prompt file +save_dir = "./samples" # path to save samples +``` + +## Training config demos + +```python +# Define sampling size +num_frames = 64 +frame_interval = 2 # sample every 2 frames +image_size = (512, 512) + +# Define dataset +root = None # root path to the dataset +data_path = "CSV_PATH" # path to the csv file +use_image_transform = False # True if training on images +num_workers = 4 # number of workers for dataloader + +# Define acceleration +dtype = "bf16" # Computation type (fp16, bf16) +grad_checkpoint = True # Use gradient checkpointing +plugin = "zero2" # Plugin for distributed training (zero2, zero2-seq) +sp_size = 1 # Sequence parallelism size (1 for no sequence parallelism) + +# Define model +model = dict( + type="STDiT-XL/2", + space_scale=1.0, + time_scale=2 / 3, + from_pretrained="YOUR_PRETRAINED_MODEL", + enable_flashattn=True, # Enable flash attention + enable_layernorm_kernel=True, # Enable layernorm kernel +) +vae = dict( + type="VideoAutoencoderKL", + from_pretrained="stabilityai/sd-vae-ft-ema", + micro_batch_size=128, +) +text_encoder = dict( + type="t5", + from_pretrained="./pretrained_models/t5_ckpts", + model_max_length=120, + shardformer=True, # Enable shardformer for T5 acceleration +) +scheduler = dict( + type="iddpm", + timestep_respacing="", # Default 1000 timesteps +) + +# Others +seed = 42 +outputs = "outputs" # path to save checkpoints +wandb = False # Use wandb for logging + +epochs = 1000 # number of epochs (just large enough, kill when satisfied) +log_every = 10 +ckpt_every = 250 +load = None # path to resume training + +batch_size = 4 +lr = 2e-5 +grad_clip = 1.0 # gradient clipping +``` diff --git 
a/MindIE/MultiModal/OpenSora-1.0/export_model.py b/MindIE/MultiModal/OpenSora-1.0/export_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a872b0bab28148ae4b93ed89b597bfdc90b3572e --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/export_model.py @@ -0,0 +1,230 @@ +import numpy as np +import torch +import argparse +import os +import mindietorch +from mindietorch import _enums +from opensora.models.stdit.stdit import STDiT_XL_2 +from opensora.models.vae.vae import VideoAutoencoderKL +from transformers import T5EncoderModel + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="save dir" + ) + parser.add_argument( + "--encoder_model_path", + type=str, + default="./DeepFloyd--t5-v1_1-xxl", + help="encoder model path" + ) + parser.add_argument( + "--dit_model_path", + type=str, + default="./OpenSora-v1-HQ-16x512x512.pth", + help="stdit model path" + ) + parser.add_argument( + "--vae_model_path", + type=str, + default="./sd-vae-ft-ema", + help="vae model path" + ) + parser.add_argument( + "--micro_batch_size", + type=int, + default=4, + help="vae micro_batch_size" + ) + parser.add_argument( + "--resolution", + type=str, + default='16x512x512', + choices=['16x256x256', '16x512x512'] + ) + parser.add_argument( + "--device_id", + type=int, + default=0, + help="npu device id" + ) + return parser.parse_args() + +class TextEncoderExport(torch.nn.Module): + def __init__(self, textencoder_model): + super(TextEncoderExport, self).__init__() + self.textencoder_model = textencoder_model + + def forward(self, input_ids, attention_mask): + return self.textencoder_model(input_ids=input_ids, + attention_mask=attention_mask, + return_dict=False)[0] + +def export_textencoder(args, save_dir, batch_size): + encoder_path = os.path.join(save_dir, "encoder") + if not os.path.exists(encoder_path): + os.makedirs(encoder_path, mode=0o640) + traced_path = os.path.join(encoder_path, "encoder.pt") + compiled_path = os.path.join(encoder_path, "encoder_compiled.pt") + model_path = args.encoder_model_path + max_lenth = 120 + if not os.path.exists(traced_path): + text_encoder = T5EncoderModel.from_pretrained(model_path, cache_dir="cache_dir", torch_dtype=torch.float).to('cpu') + dummy_input = ( + torch.ones([batch_size, max_lenth], dtype=torch.int64), + torch.ones([batch_size, max_lenth], dtype=torch.int64) + ) + encoder = TextEncoderExport(text_encoder) + encoder.eval() + torch.jit.trace(encoder, dummy_input).save(traced_path) + if not os.path.exists(compiled_path): + model = torch.jit.load(traced_path).eval() + compiled_model = mindietorch.compile( + model, + inputs=[mindietorch.Input((batch_size, max_lenth), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_lenth), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version="Ascend910B4", + optimization_level=0 + ) + torch.jit.save(compiled_model, compiled_path) + +class STDiTExport(torch.nn.Module): + def __init__(self, dit_model): + super(STDiTExport, self).__init__() + self.dit_model = dit_model + + def forward(self, x, timestep, y, mask): + return self.dit_model(x, timestep, y, mask) + +def export_dit(args, save_dir, batch_size): + dit_path = os.path.join(save_dir, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + resolution = args.resolution + latent1 = 
int(resolution.split('x')[1]) + latent2 = int(resolution.split('x')[2]) + height, width = latent1 // 8, latent2 // 8 + traced_path = os.path.join(dit_path, f"dit_{latent1}_{latent2}.pt") + compiled_path = os.path.join(dit_path, f"dit_{latent1}_{latent2}_compiled.pt") + model_path = args.dit_model_path + + kwargs = { + 'space_scale': 0.5, + 'time_scale': 1.0, + 'enable_flashattn': False, + 'enable_layernorm_kernel': False, + 'input_size': [16, height, width], + 'in_channels': 4, + 'caption_channels': 4096, + 'model_max_length': 120, + 'dtype': torch.float32, + 'enable_sequence_parallelism': False + } + + video_lenth = kwargs['input_size'][0] + in_channels = kwargs['in_channels'] + model_max_length = kwargs['model_max_length'] + caption_channels = kwargs['caption_channels'] + if not os.path.exists(traced_path): + dit_model = STDiT_XL_2(from_pretrained=model_path, **kwargs) + dummy_input = ( + torch.ones([batch_size, in_channels, video_lenth, height, width], dtype=torch.float32), + torch.ones([batch_size,], dtype=torch.int64), + torch.ones([batch_size, 1, model_max_length, caption_channels], dtype=torch.float32), + torch.ones([1, model_max_length], dtype=torch.int64) + ) + dit = STDiTExport(dit_model) + dit.eval() + torch.jit.trace(dit, dummy_input).save(traced_path) + if not os.path.exists(compiled_path): + model = torch.jit.load(traced_path).eval() + compiled_model = mindietorch.compile( + model, + inputs=[mindietorch.Input((batch_size, in_channels, video_lenth, height, width), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 1, model_max_length, caption_channels), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1, model_max_length), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version="Ascend910B4", + optimization_level=0 + ) + torch.jit.save(compiled_model, compiled_path) + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model): + super(VaeExport, self).__init__() + self.vae_model = vae_model + + def forward(self, latents): + return self.vae_model.decode(latents) + +def export_vae(args, save_dir, batch_size): + vae_path = os.path.join(save_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + resolution = args.resolution + latent1 = int(resolution.split('x')[1]) + latent2 = int(resolution.split('x')[2]) + height, width = latent1 // 8, latent2 // 8 + traced_path = os.path.join(vae_path, f"vae_{latent1}_{latent2}.pt") + compiled_path = os.path.join(vae_path, f"vae_{latent1}_{latent2}_compiled.pt") + model_path = args.vae_model_path + micro_batch_size = args.micro_batch_size + in_channels = 4 + video_lenth = 16 + + if not os.path.exists(traced_path): + vae_model = VideoAutoencoderKL(from_pretrained=model_path, micro_batch_size=micro_batch_size) + dummy_input = ( + torch.ones([batch_size, in_channels, video_lenth, height, width], dtype=torch.float32) + ) + vae = VaeExport(vae_model) + vae.eval() + torch.jit.trace(vae, dummy_input).save(traced_path) + if not os.path.exists(compiled_path): + model = torch.jit.load(traced_path).eval() + compiled_model = mindietorch.compile( + model, + inputs=[mindietorch.Input((batch_size, in_channels, video_lenth, height, width), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + 
precision_policy=_enums.PrecisionPolicy.FP16, + soc_version="Ascend910B4", + optimization_level=0 + ) + torch.jit.save(compiled_model, compiled_path) + +def main(): + args = parse_arguments() + device_id = args.device_id + save_dir = args.output_dir + mindietorch.set_device(device_id) + batch_size = 1 + + export_textencoder(args, save_dir, batch_size) + export_dit(args, save_dir, batch_size*2) + export_vae(args, save_dir, batch_size) + print("export model done!") + mindietorch.finalize() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.0/inference.py b/MindIE/MultiModal/OpenSora-1.0/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c31da68eb28a2f5c2436ce4a7d5f45bbd3998d12 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/inference.py @@ -0,0 +1,142 @@ +import os +import torch +import time +import torch.distributed as dist +from mmengine.runner import set_random_seed +from opensora.datasets import save_sample +from opensora.registry import MODELS, SCHEDULERS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.misc import to_torch_dtype +from opensora.models.text_encoder.t5 import T5Encoder +from opensora.models.vae.vae import VideoAutoencoderKL +import mindietorch + +def load_prompts(prompt_path): + with open(prompt_path, "r") as f: + prompts = [line.strip() for line in f.readlines()] + return prompts + + +def main(): + # ====================================================== + # 1. cfg and init distributed env + # ====================================================== + cfg = parse_configs(training=False) + print(cfg) + + image_size1, image_size2 = cfg.image_size[0], cfg.image_size[1] + device_id = 0 + vae_npu = None + use_mindie = False + output_dir = cfg.output_dir + absolute_path = os.path.abspath(output_dir) + if cfg.use_mindie == 0: + print("inference by CPU") + use_mindie = False + elif cfg.use_mindie == 1: + print("inference by MindIE") + use_mindie = True + device_id = cfg.device_id + mindietorch.set_device(device_id) + vae_npu = torch.jit.load(f"{output_dir}/vae/vae_{image_size1}_{image_size2}_compiled.pt") + + enable_sequence_parallelism = False + + # ====================================================== + # 2. runtime variables + # ====================================================== + torch.set_grad_enabled(False) + device = "cpu" + dtype = to_torch_dtype(cfg.dtype) + set_random_seed(seed=cfg.seed) + prompts = load_prompts(cfg.prompt_path) + + # ====================================================== + # 3. build model & load weights + # ====================================================== + # 3.1. build model + input_size = (cfg.num_frames, *cfg.image_size) + vae = VideoAutoencoderKL(cfg.vae_path) + latent_size = vae.get_latent_size(input_size) + text_encoder = T5Encoder( + from_pretrained=cfg.t5_path, + model_max_length=120, + device=device, + use_mindie=use_mindie, + device_id=device_id, + absolute_path=absolute_path + ) + model = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + dtype=dtype, + enable_sequence_parallelism=enable_sequence_parallelism, + use_mindie=use_mindie, + device_id=device_id, + absolute_path=absolute_path + ) + text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance + + # 3.2. 
move to device & eval + vae = vae.to(device, dtype).eval() + model = model.to(device, dtype).eval() + + # 3.3. build scheduler + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + # 3.4. support for multi-resolution + model_args = dict() + if cfg.multi_resolution: + image_size = cfg.image_size + hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1) + ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1) + model_args["data_info"] = dict(ar=ar, hw=hw) + + # ====================================================== + # 4. inference + # ====================================================== + sample_idx = 0 + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + use_cache = 1 + cache_steps = [] + use_time = 0 + infer_num = 0 + for i in range(0, len(prompts), cfg.batch_size): + batch_prompts = prompts[i : i + cfg.batch_size] + infer_num += 1 + start_time = time.time() + samples = scheduler.sample( + model, + text_encoder, + z_size=(vae.out_channels, *latent_size), + prompts=batch_prompts, + device=device, + additional_args=model_args, + use_cache=use_cache, + cache_steps=cache_steps + ) + if use_mindie: + samples = vae_npu(samples.to(dtype).to(f"npu:{device_id}")).to('cpu') + else: + samples = vae.decode(samples.to(dtype)) + + if i > 4: + use_time += (time.time() - start_time) + + for idx, sample in enumerate(samples): + print(f"Prompt: {batch_prompts[idx]}") + save_path = os.path.join(save_dir, f"sample_{sample_idx}") + save_sample(sample, fps=cfg.fps, save_path=save_path) + sample_idx += 1 + if use_mindie: + mindietorch.finalize() + infer_num = infer_num - 5 + print(f"average time: {use_time / infer_num:.3f}s\n") + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3175b2df160a6b1215dc75eb1cc1a91bfc51ae0 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/__init__.py @@ -0,0 +1,4 @@ +from .acceleration import * +from .datasets import * +from .models import * +from .registry import * diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/checkpoint.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d832a0105ac278982feee34109bc585b4bf4d9d0 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/checkpoint.py @@ -0,0 +1,24 @@ +from collections.abc import Iterable + +import torch.nn as nn +from torch.utils.checkpoint import checkpoint, checkpoint_sequential + + +def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): + assert isinstance(model, nn.Module) + + def set_attr(module): + module.grad_checkpointing = True + module.fp32_attention = use_fp32_attention + module.grad_checkpointing_step = gc_step + + model.apply(set_attr) + + +def auto_grad_checkpoint(module, *args, **kwargs): + if getattr(module, "grad_checkpointing", False): + if not isinstance(module, Iterable): + return checkpoint(module, *args, **kwargs) + gc_step = module[0].grad_checkpointing_step + return checkpoint_sequential(module, gc_step, *args, **kwargs) + return 
module(*args, **kwargs) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/communications.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/communications.py new file mode 100644 index 0000000000000000000000000000000000000000..d0900d20841248a250b5aeb31755fac689474ff8 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/communications.py @@ -0,0 +1,188 @@ +import torch +import torch.distributed as dist + + +# ==================== +# All-To-All +# ==================== +def _all_to_all( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + scatter_dim: int, + gather_dim: int, +): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class _AllToAll(torch.autograd.Function): + """All-to-all communication. + + Args: + input_: input matrix + process_group: communication group + scatter_dim: scatter dimension + gather_dim: gather dimension + """ + + @staticmethod + def forward(ctx, input_, process_group, scatter_dim, gather_dim): + ctx.process_group = process_group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.world_size = dist.get_world_size(process_group) + output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim) + return output + + @staticmethod + def backward(ctx, grad_output): + grad_output = _all_to_all( + grad_output, + ctx.world_size, + ctx.process_group, + ctx.gather_dim, + ctx.scatter_dim, + ) + return ( + grad_output, + None, + None, + None, + ) + + +def all_to_all( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + scatter_dim: int = 2, + gather_dim: int = 1, +): + return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim) + + +def _gather( + input_: torch.Tensor, + world_size: int, + group: dist.ProcessGroup, + gather_dim: int, +): + if gather_list is None: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + dist.gather(input_, gather_list, group=group, gather_dim=gather_dim) + return gather_list + + +# ==================== +# Gather-Split +# ==================== + + +def _split(input_, pg: dist.ProcessGroup, dim=-1): + # skip if only one rank involved + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, pg: dist.ProcessGroup, dim=-1): + # skip if only one rank involved + input_ = input_.contiguous() + world_size = dist.get_world_size(pg) + dist.get_rank(pg) + + if world_size == 1: + return input_ + + # all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + assert input_.device.type == "cuda" + torch.distributed.all_gather(tensor_list, input_, group=pg) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + process_group: parallel mode. 
+ dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _gather(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, grad_scale): + ctx.mode = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + return _gather(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.mode) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.mode) + + return _split(grad_output, ctx.mode, ctx.dim), None, None, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + process_group: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, process_group, dim, grad_scale): + ctx.mode = process_group + ctx.dim = dim + ctx.grad_scale = grad_scale + return _split(input_, process_group, dim) + + @staticmethod + def backward(ctx, grad_output): + if ctx.grad_scale == "up": + grad_output = grad_output * dist.get_world_size(ctx.mode) + elif ctx.grad_scale == "down": + grad_output = grad_output / dist.get_world_size(ctx.mode) + return _gather(grad_output, ctx.mode, ctx.dim), None, None, None + + +def split_forward_gather_backward(input_, process_group, dim, grad_scale=1.0): + return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale) + + +def gather_forward_split_backward(input_, process_group, dim, grad_scale=None): + return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/parallel_states.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/parallel_states.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2893e33c86da4cb8a5170566917355af882825 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/parallel_states.py @@ -0,0 +1,19 @@ +import torch.distributed as dist + +_GLOBAL_PARALLEL_GROUPS = dict() + + +def set_data_parallel_group(group: dist.ProcessGroup): + _GLOBAL_PARALLEL_GROUPS["data"] = group + + +def get_data_parallel_group(): + return _GLOBAL_PARALLEL_GROUPS.get("data", None) + + +def set_sequence_parallel_group(group: dist.ProcessGroup): + _GLOBAL_PARALLEL_GROUPS["sequence"] = group + + +def get_sequence_parallel_group(): + return _GLOBAL_PARALLEL_GROUPS.get("sequence", None) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/plugin.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..c657a9539d8fb1f0d65e8f452777a4bb73a84d4d --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/plugin.py @@ -0,0 +1,100 @@ +import random +from typing import Optional + +import numpy as np +import torch +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import ProcessGroupMesh +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +DP_AXIS, SP_AXIS = 0, 1 + + +class ZeroSeqParallelPlugin(LowLevelZeroPlugin): + def __init__( + self, + sp_size: int = 1, + stage: int = 2, + precision: str = "fp16", + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + 
max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + cpu_offload: bool = False, + master_weights: bool = True, + verbose: bool = False, + ) -> None: + super().__init__( + stage=stage, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + reduce_bucket_size_in_m=reduce_bucket_size_in_m, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + master_weights=master_weights, + verbose=verbose, + ) + self.sp_size = sp_size + assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size" + self.dp_size = self.world_size // sp_size + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS) + self.dp_rank = self.pg_mesh.coordinate(DP_AXIS) + self.sp_rank = self.pg_mesh.coordinate(SP_AXIS) + + def __del__(self): + """Destroy the prcess groups in ProcessGroupMesh""" + self.pg_mesh.destroy_mesh_process_groups() + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + _kwargs = kwargs.copy() + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/modeling/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/modeling/t5.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/modeling/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfb80841c92a57628fba81425627053afc76a3b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/modeling/t5.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. 
Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + @staticmethod + def from_native_module(module, *args, **kwargs): + assert module.__class__.__name__ == "FusedRMSNorm", ( + "Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm." + "Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48" + ) + + layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps) + layer_norm.weight.data.copy_(module.weight.data) + layer_norm = layer_norm.to(module.weight.device) + return layer_norm diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/policy/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/policy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/policy/t5_encoder.py b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/policy/t5_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..85c994ecc1a911da5f76b23819148cb1e17b16fa --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/acceleration/shardformer/policy/t5_encoder.py @@ -0,0 +1,67 @@ +from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func +from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward +from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription + + +class T5EncoderPolicy(Policy): + def config_sanity_check(self): + assert not self.shard_config.enable_tensor_parallelism + assert not self.shard_config.enable_flash_attention + + def preprocess(self): + return self.model + + def module_policy(self): + from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack + + policy = {} + + # check whether apex is installed + try: + from opensora.acceleration.shardformer.modeling.t5 import T5LayerNorm + + # recover hf from fused rms norm to T5 norm which is faster + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=T5LayerNorm, + ), + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm), + policy=policy, + target_key=T5LayerSelfAttention, + ) + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm), + policy=policy, + target_key=T5Stack, + ) + except (ImportError, ModuleNotFoundError): + pass + + # use jit operator + if self.shard_config.enable_jit_fused: + self.append_or_create_method_replacement( + description={ + "forward": get_jit_fused_T5_layer_ff_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerFF, + ) + self.append_or_create_method_replacement( 
+ description={ + "forward": get_T5_layer_self_attention_forward(), + "dropout_add": get_jit_fused_dropout_add_func(), + }, + policy=policy, + target_key=T5LayerSelfAttention, + ) + + return policy + + def postprocess(self): + return self.model diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9b33954089543694925a8ef96b1341a01cb7c70 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/__init__.py @@ -0,0 +1,2 @@ +from .datasets import DatasetFromCSV, get_transforms_image, get_transforms_video +from .utils import prepare_dataloader, save_sample diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/datasets.py b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9317271924d2c8a79778c2f69eb60d15343dda --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/datasets.py @@ -0,0 +1,114 @@ +import csv +import os + +import numpy as np +import torch +import torchvision +import torchvision.transforms as transforms +from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader + +from . import video_transforms +from .utils import center_crop_arr + + +def get_transforms_video(resolution=256): + transform_video = transforms.Compose( + [ + video_transforms.ToTensorVideo(), # TCHW + video_transforms.RandomHorizontalFlipVideo(), + video_transforms.UCFCenterCropVideo(resolution), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + return transform_video + + +def get_transforms_image(image_size=256): + transform = transforms.Compose( + [ + transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + return transform + + +class DatasetFromCSV(torch.utils.data.Dataset): + """load video according to the csv file. + + Args: + target_video_len (int): the number of video frames will be load. + align_transform (callable): Align different videos in a specified size. + temporal_sample (callable): Sample the target length of a video. + """ + + def __init__( + self, + csv_path, + num_frames=16, + frame_interval=1, + transform=None, + root=None, + ): + self.csv_path = csv_path + with open(csv_path, "r") as f: + reader = csv.reader(f) + self.samples = list(reader) + + ext = self.samples[0][0].split(".")[-1] + if ext.lower() in ("mp4", "avi", "mov", "mkv"): + self.is_video = True + else: + assert f".{ext.lower()}" in IMG_EXTENSIONS, f"Unsupported file format: {ext}" + self.is_video = False + + self.transform = transform + + self.num_frames = num_frames + self.frame_interval = frame_interval + self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval) + self.root = root + + def getitem(self, index): + sample = self.samples[index] + path = sample[0] + if self.root: + path = os.path.join(self.root, path) + text = sample[1] + + if self.is_video: + vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW") + total_frames = len(vframes) + + # Sampling video frames + start_frame_ind, end_frame_ind = self.temporal_sample(total_frames) + assert ( + end_frame_ind - start_frame_ind >= self.num_frames + ), f"{path} with index {index} has not enough frames." 
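+            # Evenly sample `self.num_frames` frame indices from the selected temporal window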
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int) + + video = vframes[frame_indice] + video = self.transform(video) # T C H W + else: + image = pil_loader(path) + image = self.transform(image) + video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1) + + # TCHW -> CTHW + video = video.permute(1, 0, 2, 3) + + return {"video": video, "text": text} + + def __getitem__(self, index): + for _ in range(10): + try: + return self.getitem(index) + except Exception as e: + print(e) + index = np.random.randint(len(self)) + raise RuntimeError("Too many bad data.") + + def __len__(self): + return len(self.samples) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/utils.py b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd268ae610ddb557a633c4c15b09ef71df92bdec --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/utils.py @@ -0,0 +1,135 @@ +import random +from typing import Iterator, Optional + +import numpy as np +import torch +from PIL import Image +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_default_group +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler +from torchvision.io import write_video +from torchvision.utils import save_image + + +def save_sample(x, fps=8, save_path=None, normalize=True, value_range=(-1, 1)): + """ + Args: + x (Tensor): shape [C, T, H, W] + """ + assert x.ndim == 4 + + if x.shape[1] == 1: # T = 1: save as image + save_path += ".png" + x = x.squeeze(1) + save_image([x], save_path, normalize=normalize, value_range=value_range) + else: + save_path += ".mp4" + if normalize: + low, high = value_range + x.clamp_(min=low, max=high) + x.sub_(low).div_(max(high - low, 1e-5)) + + x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8) + write_video(save_path, x, fps=fps, video_codec="h264") + print(f"Saved to {save_path}") + + +class StatefulDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + self.start_index: int = 0 + + def __iter__(self) -> Iterator: + iterator = super().__iter__() + indices = list(iterator) + indices = indices[self.start_index :] + return iter(indices) + + def __len__(self) -> int: + return self.num_samples - self.start_index + + def set_start_index(self, start_index: int) -> None: + self.start_index = start_index + + +def prepare_dataloader( + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + process_group: Optional[ProcessGroup] = None, + **kwargs, +): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `StatefulDistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. 
If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + process_group = process_group or _get_default_group() + sampler = StatefulDistributedSampler( + dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/video_transforms.py b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d1cec8c72a3f1a5fb7d71489b61973d0178580 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/datasets/video_transforms.py @@ -0,0 +1,501 @@ +# Copyright 2024 Vchitect/Latte + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte + +# - This file is adapted from https://github.com/Vchitect/Latte/blob/main/datasets/video_transforms.py + + +import numbers +import random + +import torch + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +def resize_scale(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + H, W = clip.size(-2), clip.size(-1) + scale_ = target_size[0] / min(H, W) + return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. 
Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def center_crop_using_short_edge(clip): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + if h < w: + th, tw = h, h + i = 0 + j = int(round((w - tw) / 2.0)) + else: + th, tw = w, w + i = int(round((h - th) / 2.0)) + j = 0 + return crop(clip, i, j, th, tw) + + +def random_shift_crop(clip): + """ + Slide along the long edge, with the short edge as crop size + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + + if h <= w: + short_edge = h + else: + short_edge = w + + th, tw = short_edge, short_edge + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + # print(mean) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) + Returns: + flipped clip (torch.tensor): Size is (T, C, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) + + +class RandomCropVideo: + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: randomly cropped video clip. 
+ size is (T, C, OH, OW) + """ + i, j, h, w = self.get_params(clip) + return crop(clip, i, j, h, w) + + def get_params(self, clip): + h, w = clip.shape[-2:] + th, tw = self.size + + if h < th or w < tw: + raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") + + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() + + return i, j, th, tw + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +class CenterCropResizeVideo: + """ + First use the short side for cropping length, + center crop video, then resize to the specified size + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop_using_short_edge(clip) + clip_center_crop_resize = resize( + clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode + ) + return clip_center_crop_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class UCFCenterCropVideo: + """ + First scale to the specified size in equal proportion to the short edge, + then center cropping + """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized / center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + clip_center_crop = center_crop(clip_resize, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class KineticsRandomCropResizeVideo: + """ + Slide along the long edge, with the short edge as crop size. And resie to the desired size. 
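+
+ Illustrative usage sketch (assumes a float (T, C, H, W) clip, e.g. the output of
+ ToTensorVideo):
+
+ >>> transform = KineticsRandomCropResizeVideo(224)
+ >>> transform(torch.rand(16, 3, 240, 320)).shape
+ torch.Size([16, 3, 224, 224])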
+ """ + + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + clip_random_crop = random_shift_crop(clip) + clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) + return clip_resize + + +class CenterCropVideo: + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: center cropped video clip. + size is (T, C, crop_size, crop_size) + """ + clip_center_crop = center_crop(clip, self.size) + return clip_center_crop + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class NormalizeVideo: + """ + Normalize the video clip by mean subtraction and division by standard deviation + Args: + mean (3-tuple): pixel RGB mean + std (3-tuple): pixel RGB standard deviation + inplace (boolean): whether do in-place normalization + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) + """ + return normalize(clip, self.mean, self.std, self.inplace) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class RandomHorizontalFlipVideo: + """ + Flip the video clip along the horizontal direction with a given probability + Args: + p (float): probability of the clip being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Size is (T, C, H, W) + Return: + clip (torch.tensor): Size is (T, C, H, W) + """ + if random.random() < self.p: + clip = hflip(clip) + return clip + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(p={self.p})" + + +# ------------------------------------------------------------ +# --------------------- Sampling --------------------------- +# ------------------------------------------------------------ +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. 
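+
+ Illustrative sketch of the sampling behaviour: for a 100-frame video and size=32,
+ the returned window always spans exactly ``size`` frames, because the start index
+ is drawn from [0, total_frames - size - 1]:
+
+ >>> sampler = TemporalRandomCrop(32)
+ >>> begin, end = sampler(100)
+ >>> end - begin
+ 32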
+ """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + + +if __name__ == "__main__": + import os + + import numpy as np + import torchvision.io as io + from torchvision import transforms + from torchvision.utils import save_image + + vframes, aframes, info = io.read_video(filename="./v_Archery_g01_c03.avi", pts_unit="sec", output_format="TCHW") + + trans = transforms.Compose( + [ + ToTensorVideo(), + RandomHorizontalFlipVideo(), + UCFCenterCropVideo(512), + # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + target_video_len = 32 + frame_interval = 1 + total_frames = len(vframes) + print(total_frames) + + temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) + + # Sampling video frames + start_frame_ind, end_frame_ind = temporal_sample(total_frames) + # print(start_frame_ind) + # print(end_frame_ind) + assert end_frame_ind - start_frame_ind >= target_video_len + frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) + print(frame_indice) + + select_vframes = vframes[frame_indice] + print(select_vframes.shape) + print(select_vframes.dtype) + + select_vframes_trans = trans(select_vframes) + print(select_vframes_trans.shape) + print(select_vframes_trans.dtype) + + select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) + print(select_vframes_trans_int.dtype) + print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) + + io.write_video("./test.avi", select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) + + for i in range(target_video_len): + save_image( + select_vframes_trans[i], os.path.join("./test000", "%04d.png" % i), normalize=True, value_range=(-1, 1) + ) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60253499b07d5c9f4e0848d1b76b26fa5d2ea048 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/__init__.py @@ -0,0 +1,6 @@ +from .dit import * +from .latte import * +from .pixart import * +from .stdit import * +from .text_encoder import * +from .vae import * diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/dit/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94548a363f00ee5bbd7c5b38eaf53d26a4919b11 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/dit/__init__.py @@ -0,0 +1 @@ +from .dit import DiT, DiT_XL_2, DiT_XL_2x2 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/dit/dit.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/dit/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..a23dd7b56d74f6ea6575429b265a13cba88c64f1 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/dit/dit.py @@ -0,0 +1,284 @@ +# Modified from Meta DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange +from timm.models.vision_transformer import Mlp + +from opensora.acceleration.checkpoint import auto_grad_checkpoint +from opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + FinalLayer, + LabelEmbedder, + PatchEmbed3D, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + modulate, +) +from opensora.registry import MODELS +from opensora.utils.ckpt_utils import load_checkpoint + + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + enable_flashattn=False, + enable_layernorm_kernel=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.enable_flashattn = enable_flashattn + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = Attention( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flashattn=enable_flashattn, + ) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1, x, shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2, x, shift_mlp, scale_mlp)) + return x + + +@MODELS.register_module() +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
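+
+ Illustrative shape sketch (assuming the default arguments): with
+ input_size=(16, 32, 32) and patch_size=(1, 2, 2), the latent video is split into
+ (16 / 1) * (32 / 2) * (32 / 2) = 4096 tokens of width hidden_size; forward() then
+ takes x of shape (B, 4, 16, 32, 32), timesteps t of shape (B,) and a condition y,
+ and returns (B, 8, 16, 32, 32), because learn_sigma=True doubles the output channels.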
+ """ + + def __init__( + self, + input_size=(16, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + condition="text", + no_temporal_pos_emb=False, + caption_channels=512, + model_max_length=77, + dtype=torch.float32, + enable_flashattn=False, + enable_layernorm_kernel=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.num_heads = num_heads + self.dtype = dtype + self.use_text_encoder = not condition.startswith("label") + if enable_flashattn: + assert dtype in [ + torch.float16, + torch.bfloat16, + ], f"Flash attention only supports float16 and bfloat16, but got {self.dtype}" + self.no_temporal_pos_emb = no_temporal_pos_emb + self.mlp_ratio = mlp_ratio + self.depth = depth + + self.register_buffer("pos_embed_spatial", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, embed_dim=hidden_size) + if not self.use_text_encoder: + num_classes = int(condition.split("_")[-1]) + self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + else: + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=1, # pooled token + ) + self.t_embedder = TimestepEmbedder(hidden_size) + self.blocks = nn.ModuleList( + [ + DiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + enable_flashattn=enable_flashattn, + enable_layernorm_kernel=enable_layernorm_kernel, + ) + for _ in range(depth) + ] + ) + self.final_layer = FinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + self.initialize_weights() + self.enable_flashattn = enable_flashattn + self.enable_layernorm_kernel = enable_layernorm_kernel + + def get_spatial_pos_embed(self): + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + self.input_size[1] // self.patch_size[1], + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def unpatchify(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def forward(self, x, t, y): + """ + Forward pass of DiT. 
+ x: (B, C, T, H, W) tensor of inputs + t: (B,) tensor of diffusion timesteps + y: list of text + """ + # origin inputs should be float32, cast to specified dtype + x = x.to(self.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + self.pos_embed_spatial + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(t, dtype=x.dtype) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + if self.use_text_encoder: + y = y.squeeze(1).squeeze(1) + condition = t + y + + # blocks + for _, block in enumerate(self.blocks): + c = condition + x = auto_grad_checkpoint(block, x, c) # (B, N, D) + + # final process + x = self.final_layer(x, condition) # (B, N, num_patches * out_channels) + x = self.unpatchify(x) # (B, out_channels, T, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + if module.weight.requires_grad_: + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + # Zero-out text embedding layers: + if self.use_text_encoder: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + +@MODELS.register_module("DiT-XL/2") +def DiT_XL_2(from_pretrained=None, **kwargs): + model = DiT( + depth=28, + hidden_size=1152, + patch_size=(1, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("DiT-XL/2x2") +def DiT_XL_2x2(from_pretrained=None, **kwargs): + model = DiT( + depth=28, + hidden_size=1152, + patch_size=(2, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/latte/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/latte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d918ad01c676a2c2c0dc25f68aa008101773d3 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/latte/__init__.py @@ -0,0 +1 @@ +from .latte import Latte, Latte_XL_2, Latte_XL_2x2 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/latte/latte.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/latte/latte.py new file mode 
100644 index 0000000000000000000000000000000000000000..3f8f9685e00b72e601f662b49925d82a57f9e253 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/latte/latte.py @@ -0,0 +1,112 @@ +# Copyright 2024 Vchitect/Latte +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte +# +# +# This file is mofied from https://github.com/Vchitect/Latte/blob/main/models/latte.py +# +# With references to: +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main + + +import torch +from einops import rearrange, repeat + +from opensora.acceleration.checkpoint import auto_grad_checkpoint +from opensora.models.dit import DiT +from opensora.registry import MODELS +from opensora.utils.ckpt_utils import load_checkpoint + + +@MODELS.register_module() +class Latte(DiT): + def forward(self, x, t, y): + """ + Forward pass of DiT. + x: (B, C, T, H, W) tensor of inputs + t: (B,) tensor of diffusion timesteps + y: list of text + """ + # origin inputs should be float32, cast to specified dtype + x = x.to(self.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + self.pos_embed_spatial + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(t, dtype=x.dtype) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + if self.use_text_encoder: + y = y.squeeze(1).squeeze(1) + condition = t + y + condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal) + condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial) + + # blocks + for i, block in enumerate(self.blocks): + if i % 2 == 0: + # spatial + x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial) + c = condition_spatial + else: + # temporal + x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial) + c = condition_temporal + if i == 1: + x = x + self.pos_embed_temporal + + x = auto_grad_checkpoint(block, x, c) # (B, N, D) + + if i % 2 == 0: + x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial) + else: + x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial) + + # final process + x = self.final_layer(x, condition) # (B, N, num_patches * out_channels) + x = self.unpatchify(x) # (B, out_channels, T, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + +@MODELS.register_module("Latte-XL/2") +def Latte_XL_2(from_pretrained=None, **kwargs): + model = Latte( + depth=28, + hidden_size=1152, + patch_size=(1, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("Latte-XL/2x2") +def Latte_XL_2x2(from_pretrained=None, **kwargs): + model = Latte( + depth=28, + hidden_size=1152, + patch_size=(2, 2, 2), + num_heads=16, + **kwargs, + ) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return 
model diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/layers/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/layers/blocks.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/layers/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..730ac622a160a922721cf1408a28795b137d2e9c --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/layers/blocks.py @@ -0,0 +1,612 @@ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- + +import math + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from timm.models.vision_transformer import Mlp + +from opensora.acceleration.communications import all_to_all, split_forward_gather_backward +from opensora.acceleration.parallel_states import get_sequence_parallel_group + +approx_gelu = lambda: nn.GELU(approximate="tanh") + + +def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel: bool): + # if use_kernel: + # try: + # from apex.normalization import FusedLayerNorm + + # return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps) + # except ImportError: + # raise RuntimeError("FusedLayerNorm not available. Please install apex.") + # else: + # return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine) + return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine) + + +def modulate(norm_func, x, shift, scale): + # Suppose x is (B, N, D), shift is (B, D), scale is (B, D) + dtype = x.dtype + x = norm_func(x.to(torch.float32)).to(dtype) + x = x * (scale.unsqueeze(1) + 1) + shift.unsqueeze(1) + x = x.to(dtype) + return x + + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +# =============================================== +# General-purpose Layers +# =============================================== + + +class PatchEmbed3D(nn.Module): + """Video to Patch Embedding. + + Args: + patch_size (int): Patch token size. Default: (2,4,4). + in_chans (int): Number of input video channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__( + self, + patch_size=(2, 4, 4), + in_chans=3, + embed_dim=96, + norm_layer=None, + flatten=True, + ): + super().__init__() + self.patch_size = patch_size + self.flatten = flatten + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, D, H, W = x.size() + if W % self.patch_size[2] != 0: + x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) + if H % self.patch_size[1] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) + if D % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0])) + + x = self.proj(x) # (B C T H W) + if self.norm is not None: + D, Wh, Ww = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC + return x + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + enable_flashattn: bool = False, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.enable_flashattn = enable_flashattn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x) + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + if self.enable_flashattn: + qkv_permute_shape = (2, 0, 1, 3, 4) + else: + qkv_permute_shape = (2, 0, 3, 1, 4) + qkv = qkv.view(qkv_shape).permute(qkv_permute_shape) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + if self.enable_flashattn: + from flash_attn import flash_attn_func + + x = flash_attn_func( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + softmax_scale=self.scale, + ) + else: + dtype = q.dtype + q = q * self.scale + attn = q @ k.transpose(-2, -1) # translate attn to float32 + attn = attn.to(torch.float32) + attn = attn.softmax(dim=-1) + attn = attn.to(dtype) # cast back attn to original dtype + attn = self.attn_drop(attn) + x = attn @ v + + x_output_shape = (B, N, C) + if not self.enable_flashattn: + x = x.transpose(1, 2) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SeqParallelAttention(Attention): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + enable_flashattn: bool = False, + ) -> None: + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + enable_flashattn=enable_flashattn, + ) + + def 
forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape # for sequence parallel here, the N is a local sequence length + qkv = self.qkv(x) + qkv_shape = (B, N, 3, self.num_heads, self.head_dim) + + qkv = qkv.view(qkv_shape) + + sp_group = get_sequence_parallel_group() + + # apply all_to_all to gather sequence and split attention heads + # [B, SUB_N, 3, NUM_HEAD, HEAD_DIM] -> [B, N, 3, NUM_HEAD_PER_DEVICE, HEAD_DIM] + qkv = all_to_all(qkv, sp_group, scatter_dim=3, gather_dim=1) + + if self.enable_flashattn: + qkv_permute_shape = (2, 0, 1, 3, 4) # [3, B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] + else: + qkv_permute_shape = (2, 0, 3, 1, 4) # [3, B, NUM_HEAD_PER_DEVICE, N, HEAD_DIM] + qkv = qkv.permute(qkv_permute_shape) + + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + if self.enable_flashattn: + from flash_attn import flash_attn_func + + x = flash_attn_func( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + softmax_scale=self.scale, + ) + else: + dtype = q.dtype + q = q * self.scale + attn = q @ k.transpose(-2, -1) # translate attn to float32 + attn = attn.to(torch.float32) + attn = attn.softmax(dim=-1) + attn = attn.to(dtype) # cast back attn to original dtype + attn = self.attn_drop(attn) + x = attn @ v + + if not self.enable_flashattn: + x = x.transpose(1, 2) + + # apply all to all to gather back attention heads and split sequence + # [B, N, NUM_HEAD_PER_DEVICE, HEAD_DIM] -> [B, SUB_N, NUM_HEAD, HEAD_DIM] + x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2) + + # reshape outputs back to [B, N, C] + x_output_shape = (B, N, C) + x = x.reshape(x_output_shape) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = nn.Linear(d_model, d_model) + self.kv_linear = nn.Linear(d_model, d_model * 2) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(d_model, d_model) + self.proj_drop = nn.Dropout(proj_drop) + + def torch_impl(self, q, k, v, mask, B, N, C): + q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2) + + attn_mask = torch.zeros(B, N, k.shape[2], dtype=torch.float32, device=q.device) + for i, m in enumerate(mask): + attn_mask[i, :, m:] = -1e8 + + scale = 1 / math.sqrt(self.head_dim) + q = q * scale + attn = q @ k.transpose(-2, -1) + attn = attn.to(torch.float32) + if mask is not None: + attn = attn + attn_mask.unsqueeze(1) + attn = attn.softmax(-1) + attn = attn.to(v.dtype) + out = attn @ v + + x = out.transpose(1, 2).contiguous().view(B, N, C) + return x + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + # attn_bias = None + # if mask is not None: + # attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + # x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + # x = x.view(B, -1, C) + x = self.torch_impl(q, k, v, 
mask, B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SeqParallelMultiHeadCrossAttention(MultiHeadCrossAttention): + def __init__( + self, + d_model, + num_heads, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__(d_model=d_model, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + sp_group = get_sequence_parallel_group() + sp_size = dist.get_world_size(sp_group) + B, SUB_N, C = x.shape + N = SUB_N * sp_size + + # shape: + # q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM] + q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(B, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + # apply all_to_all to gather sequence and split attention heads + q = all_to_all(q, sp_group, scatter_dim=2, gather_dim=1) + + k = split_forward_gather_backward(k, get_sequence_parallel_group(), dim=2, grad_scale="down") + v = split_forward_gather_backward(v, get_sequence_parallel_group(), dim=2, grad_scale="down") + + q = q.view(1, -1, self.num_heads // sp_size, self.head_dim) + k = k.view(1, -1, self.num_heads // sp_size, self.head_dim) + v = v.view(1, -1, self.num_heads // sp_size, self.head_dim) + + # compute attention + attn_bias = None + if mask is not None: + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) + x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) + + # apply all to all to gather back attention heads and scatter sequence + x = x.view(B, -1, self.num_heads // sp_size, self.head_dim) + x = all_to_all(x, sp_group, scatter_dim=1, gather_dim=2) + + # apply output projection + x = x.view(B, -1, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, num_patch, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final, x, shift, scale) + x = self.linear(x) + return x + + +class T2IFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + + def __init__(self, hidden_size, num_patch, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5) + self.out_channels = out_channels + + def forward(self, x, t): + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +# =============================================== +# Embedding Layers for Timesteps and Class Labels +# =============================================== + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
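+
+ A quick illustrative sketch of the sinusoidal helper (the static method needs no
+ trained weights):
+
+ >>> emb = TimestepEmbedder.timestep_embedding(torch.arange(4), dim=256)
+ >>> emb.shape
+ torch.Size([4, 256])
+
+ The module then maps this (N, frequency_embedding_size) encoding through its
+ two-layer MLP to an (N, hidden_size) embedding.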
+ """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half) + freqs = freqs.to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, dtype): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + if t_freq.dtype != dtype: + t_freq = t_freq.to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + return self.embedding_table(labels) + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size) + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs // s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype) + s_emb = self.mlp(s_freq) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
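+
+ Note: in practice this module projects caption feature tensors rather than integer
+ class labels. An illustrative shape sketch, assuming the defaults used elsewhere in
+ this repo (caption_channels=4096, token_num=120):
+
+ >>> embedder = CaptionEmbedder(4096, 1152, 0.1, act_layer=approx_gelu, token_num=120)
+ >>> embedder(torch.randn(2, 1, 120, 4096), train=False).shape
+ torch.Size([2, 1, 120, 1152])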
+ """ + + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0 + ) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +# =============================================== +# Sine/Cosine Positional Embedding Functions +# =============================================== +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if not isinstance(grid_size, tuple): + grid_size = (grid_size, grid_size) + + grid_h = np.arange(grid_size[0], dtype=np.float32) / scale + grid_w = np.arange(grid_size[1], dtype=np.float32) / scale + if base_size is not None: + grid_h *= base_size / grid_size[0] + grid_w *= base_size / grid_size[1] + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0): + pos = np.arange(0, length)[..., None] / scale + return get_1d_sincos_pos_embed_from_grid(embed_dim, pos) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/pixart/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/pixart/__init__.py new file 
mode 100644 index 0000000000000000000000000000000000000000..cf8320211a82bd0a4689b2afa9b600adeee6cfeb --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/pixart/__init__.py @@ -0,0 +1 @@ +from .pixart import PixArt, PixArt_XL_2 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/pixart/pixart.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/pixart/pixart.py new file mode 100644 index 0000000000000000000000000000000000000000..849470ae438aebec8b700e429a03857d9125e76f --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/pixart/pixart.py @@ -0,0 +1,389 @@ +# Adapted from PixArt +# +# Copyright (C) 2023 PixArt-alpha/PixArt-alpha +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# DiT: https://github.com/facebookresearch/DiT/tree/main +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +# from .builder import MODELS +from opensora.acceleration.checkpoint import auto_grad_checkpoint +from opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + MultiHeadCrossAttention, + PatchEmbed3D, + SeqParallelAttention, + SeqParallelMultiHeadCrossAttention, + SizeEmbedder, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from opensora.registry import MODELS +from opensora.utils.ckpt_utils import load_checkpoint + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. 
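+
+ Conditioning sketch (illustrative reading of the forward pass): t is expected to
+ pack the six adaLN-single parameters along its last dimension, i.e. shape
+ (B, 6 * hidden_size). The block reshapes it to (B, 6, hidden_size), adds the learned
+ scale_shift_table, and splits the result into shift/scale/gate terms for the
+ self-attention and MLP branches; the caption tokens y enter only through the
+ cross-attention layer.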
+ """ + + def __init__( + self, + hidden_size, + num_heads, + mlp_ratio=4.0, + drop_path=0.0, + enable_flashattn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.enable_flashattn = enable_flashattn + self._enable_sequence_parallelism = enable_sequence_parallelism + + if enable_sequence_parallelism: + self.attn_cls = SeqParallelAttention + self.mha_cls = SeqParallelMultiHeadCrossAttention + else: + self.attn_cls = Attention + self.mha_cls = MultiHeadCrossAttention + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flashattn=enable_flashattn, + ) + self.cross_attn = self.mha_cls(hidden_size, num_heads) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward(self, x, y, t, mask=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +@MODELS.register_module() +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. 
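+
+ Illustrative shape sketch (assuming the default arguments): with
+ input_size=(1, 32, 32) and patch_size=(1, 2, 2) the latent is flattened into
+ 1 * 16 * 16 = 256 tokens; forward(x, timestep, y, mask) expects x of shape
+ (N, 4, 1, 32, 32), timestep of shape (N,), caption embeddings y of shape
+ (N, 1, model_max_length, caption_channels), and returns (N, 8, 1, 32, 32)
+ when pred_sigma=True.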
+ """ + + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + freeze=None, + space_scale=1.0, + time_scale=1.0, + enable_flashattn=False, + enable_layernorm_kernel=False, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.base_size = int(np.sqrt(self.num_spatial)) + self.num_heads = num_heads + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = depth + self.mlp_ratio = mlp_ratio + self.enable_flashattn = enable_flashattn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + PixArtBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + drop_path=drop_path[i], + enable_flashattn=enable_flashattn, + enable_layernorm_kernel=enable_layernorm_kernel, + ) + for i in range(depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + self.initialize_weights() + if freeze is not None: + assert freeze in ["text"] + if freeze == "text": + self.freeze_text() + + def forward(self, x, timestep, y, mask=None): + """ + Forward pass of PixArt. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + self.pos_embed + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(timestep, dtype=x.dtype) # (N, D) + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for block in self.blocks: + x = auto_grad_checkpoint(block, x, y, t0, y_lens) + + # final process + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + def unpatchify(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + base_size=self.base_size, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + 
nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module() +class PixArtMS(PixArt): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3" + self.csize_embedder = SizeEmbedder(self.hidden_size // 3) + self.ar_embedder = SizeEmbedder(self.hidden_size // 3) + + def forward(self, x, timestep, y, mask=None, data_info=None): + """ + Forward pass of PixArt. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) tensor of class labels + """ + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + + c_size = data_info["hw"] + ar = data_info["ar"] + pos_embed = self.get_spatial_pos_embed((x.shape[-2], x.shape[-1])).to(x.dtype) + + # embedding + x = self.x_embedder(x) # (B, N, D) + x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial) + x = x + pos_embed.to(x.device) + if not self.no_temporal_pos_emb: + x = rearrange(x, "b t s d -> b s t d") + x = x + self.pos_embed_temporal + x = rearrange(x, "b s t d -> b (t s) d") + else: + x = rearrange(x, "b t s d -> b (t s) d") + + t = self.t_embedder(timestep, dtype=x.dtype) # (N, D) + B = x.shape[0] + csize = self.csize_embedder(c_size, B) + ar = self.ar_embedder(ar, B) + t = t + torch.cat([csize, ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, 1, L, D) + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + for block in self.blocks: + x = block(x, y, t0, y_lens) + + # final process + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + + # cast to float32 for better accuracy + x = x.to(torch.float32) + return x + + +@MODELS.register_module("PixArt-XL/2") +def PixArt_XL_2(from_pretrained=None, **kwargs): + model = PixArt(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model + + +@MODELS.register_module("PixArtMS-XL/2") +def PixArtMS_XL_2(from_pretrained=None, **kwargs): + model = PixArtMS(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/stdit/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/stdit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca2cc91f8316e80a7e594e432431913b447207b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/stdit/__init__.py @@ -0,0 +1 @@ +from .stdit import STDiT diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/stdit/stdit.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/stdit/stdit.py new file mode 100644 index 0000000000000000000000000000000000000000..14fd67d1b332562437ee3b1d76abf6ac90a475ee --- /dev/null +++ 
b/MindIE/MultiModal/OpenSora-1.0/opensora/models/stdit/stdit.py @@ -0,0 +1,490 @@ +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import os +from einops import rearrange +from timm.models.layers import DropPath +from timm.models.vision_transformer import Mlp + +from opensora.acceleration.checkpoint import auto_grad_checkpoint +from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward +from opensora.acceleration.parallel_states import get_sequence_parallel_group +from opensora.models.layers.blocks import ( + Attention, + CaptionEmbedder, + MultiHeadCrossAttention, + PatchEmbed3D, + SeqParallelAttention, + SeqParallelMultiHeadCrossAttention, + T2IFinalLayer, + TimestepEmbedder, + approx_gelu, + get_1d_sincos_pos_embed, + get_2d_sincos_pos_embed, + get_layernorm, + t2i_modulate, +) +from opensora.registry import MODELS +from opensora.utils.ckpt_utils import load_checkpoint + + +class STDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + d_s=None, + d_t=None, + mlp_ratio=4.0, + drop_path=0.0, + enable_flashattn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.enable_flashattn = enable_flashattn + self._enable_sequence_parallelism = enable_sequence_parallelism + + if enable_sequence_parallelism: + self.attn_cls = SeqParallelAttention + self.mha_cls = SeqParallelMultiHeadCrossAttention + else: + self.attn_cls = Attention + self.mha_cls = MultiHeadCrossAttention + + self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.attn = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flashattn=enable_flashattn, + ) + self.cross_attn = self.mha_cls(hidden_size, num_heads) + self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel) + self.mlp = Mlp( + in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + + # temporal attention + self.d_s = d_s + self.d_t = d_t + + if self._enable_sequence_parallelism: + sp_size = dist.get_world_size(get_sequence_parallel_group()) + # make sure d_t is divisible by sp_size + assert d_t % sp_size == 0 + self.d_t = d_t // sp_size + + self.attn_temp = self.attn_cls( + hidden_size, + num_heads=num_heads, + qkv_bias=True, + enable_flashattn=self.enable_flashattn, + ) + + def forward(self, x, y, t, mask=None, tpe=None): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + t.reshape(B, 6, -1) + ).chunk(6, dim=1) + x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa) + + # spatial branch + x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s) + x_s = self.attn(x_s) + x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s) + x = x + self.drop_path(gate_msa * x_s) + + # temporal branch + x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s) + if tpe is not None: + x_t = x_t + tpe + x_t = self.attn_temp(x_t) + x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s) + x = x + self.drop_path(gate_msa * x_t) + + # cross attn + x = x + self.cross_attn(x, y, mask) + + # mlp + x = x + self.drop_path(gate_mlp * 
self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +@MODELS.register_module() +class STDiT(nn.Module): + def __init__( + self, + input_size=(1, 32, 32), + in_channels=4, + patch_size=(1, 2, 2), + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path=0.0, + no_temporal_pos_emb=False, + caption_channels=4096, + model_max_length=120, + dtype=torch.float32, + space_scale=1.0, + time_scale=1.0, + freeze=None, + enable_flashattn=False, + enable_layernorm_kernel=False, + enable_sequence_parallelism=False, + use_mindie=False, + device_id=0, + absolute_path=None, + cache_start=7, # 256: 7, 512:4 + num_cache_layer=11 # 256: 11, 512:21 + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.hidden_size = hidden_size + self.patch_size = patch_size + self.input_size = input_size + num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)]) + self.num_patches = num_patches + self.num_temporal = input_size[0] // patch_size[0] + self.num_spatial = num_patches // self.num_temporal + self.num_heads = num_heads + self.dtype = dtype + self.no_temporal_pos_emb = no_temporal_pos_emb + self.depth = depth + self.mlp_ratio = mlp_ratio + self.enable_flashattn = enable_flashattn + self.enable_layernorm_kernel = enable_layernorm_kernel + self.space_scale = space_scale + self.time_scale = time_scale + + self.register_buffer("pos_embed", self.get_spatial_pos_embed()) + self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed()) + + self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, + hidden_size=hidden_size, + uncond_prob=class_dropout_prob, + act_layer=approx_gelu, + token_num=model_max_length, + ) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] + self.blocks = nn.ModuleList( + [ + STDiTBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=self.mlp_ratio, + drop_path=drop_path[i], + enable_flashattn=self.enable_flashattn, + enable_layernorm_kernel=self.enable_layernorm_kernel, + enable_sequence_parallelism=enable_sequence_parallelism, + d_t=self.num_temporal, + d_s=self.num_spatial, + ) + for i in range(self.depth) + ] + ) + self.final_layer = T2IFinalLayer(hidden_size, np.prod(self.patch_size), self.out_channels) + + # init model + self.initialize_weights() + self.initialize_temporal() + if freeze is not None: + assert freeze in ["not_temporal", "text"] + if freeze == "not_temporal": + self.freeze_not_temporal() + elif freeze == "text": + self.freeze_text() + + # sequence parallel related configs + self.enable_sequence_parallelism = enable_sequence_parallelism + if enable_sequence_parallelism: + self.sp_rank = dist.get_rank(get_sequence_parallel_group()) + else: + self.sp_rank = None + self.use_mindie = use_mindie + self.absolute_path = absolute_path + + if self.use_mindie: + if os.path.exists(f"{self.absolute_path}/dit/dit_256_256_compiled.pt"): + self.model_npu = torch.jit.load(f"{self.absolute_path}/dit/dit_256_256_compiled.pt").eval() + if os.path.exists(f"{self.absolute_path}/dit/dit_256_256_0_compiled.pt"): + self.model_npu_cache = torch.jit.load(f"{self.absolute_path}/dit/dit_256_256_0_compiled.pt").eval() + if 
os.path.exists(f"{self.absolute_path}/dit/dit_256_256_1_compiled.pt"): + self.model_npu_skip = torch.jit.load(f"{self.absolute_path}/dit/dit_256_256_1_compiled.pt").eval() + self.device_id = device_id + + # delta cache configs + self.num_cache_layer = num_cache_layer + self.cache_start = cache_start + + def forward(self, x, timestep, y, mask=None, use_cache=0, current_if_cache=0, delta_cache=torch.tensor([])): + """ + Forward pass of STDiT. + Args: + x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W] + timestep (torch.Tensor): diffusion time steps; of shape [B] + y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C] + mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token] + + + Returns: + x (torch.Tensor): output latent representation; of shape [B, C, T, H, W] + """ + if self.use_mindie: + if not use_cache: + out = self.model_npu(x.to(f"npu:{self.device_id}"), + timestep.to(f"npu:{self.device_id}"), + y.to(f"npu:{self.device_id}"), + mask.to(f"npu:{self.device_id}")) + x = out.to('cpu') + elif not current_if_cache: + use_cache_tensor = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + + # x:[2,4,16,32,32] + out = self.model_npu_cache(x.to(f"npu:{self.device_id}"), + timestep.to(f"npu:{self.device_id}"), + y.to(f"npu:{self.device_id}"), + mask.to(f"npu:{self.device_id}"), + use_cache_tensor.to(f"npu:{self.device_id}"), + cache_flag.to(f"npu:{self.device_id}"), + ) + x = out[0].to('cpu') # [2,8,16,32,32] + delta_cache = out[1] + else: + use_cache_tensor = torch.ones([1], dtype=torch.long) + skip_flag = torch.ones([1], dtype=torch.long) + out = self.model_npu_skip(x.to(f"npu:{self.device_id}"), + timestep.to(f"npu:{self.device_id}"), + y.to(f"npu:{self.device_id}"), + mask.to(f"npu:{self.device_id}"), + use_cache_tensor.to(f"npu:{self.device_id}"), + skip_flag.to(f"npu:{self.device_id}"), + delta_cache, ) + x = out.to('cpu') + else: + x = x.to(self.dtype) + timestep = timestep.to(self.dtype) + y = y.to(self.dtype) + + # embedding + x = self.x_embedder(x) # [B, N, C] + x = rearrange(x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial) + x = x + self.pos_embed + x = rearrange(x, "B T S C -> B (T S) C") + + # shard over the sequence dim if sp is enabled + if self.enable_sequence_parallelism: + x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down") + + t = self.t_embedder(timestep, dtype=x.dtype) # [B, C] + t0 = self.t_block(t) # [B, C] + y = self.y_embedder(y, self.training) # [B, 1, N_token, C] + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + + # blocks + if not use_cache: + x = self.forward_blocks(x, y, t0, y_lens, use_cache, current_if_cache, delta_cache) + elif not current_if_cache: + x, delta_cache = self.forward_blocks(x, y, t0, y_lens, use_cache, current_if_cache, delta_cache) + else: + x = self.forward_blocks(x, y, t0, y_lens, use_cache, current_if_cache, delta_cache) + + if self.enable_sequence_parallelism: + x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up") + # x.shape: [B, N, C] + + # final process + x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out] + x = self.unpatchify(x) # 
[B, C_out, T, H, W] + + # cast to float32 for better accuracy + x = x.to(torch.float32) + + if not use_cache: + return x + elif not current_if_cache: + return (x, delta_cache) + else: + return x + + # forward blocks in range [start_idx, end_idx) + # return input and output + def forward_blocks_range(self, x, y, t0, y_lens, start_idx, end_idx): + for block_idx, block in enumerate(self.blocks[start_idx: end_idx]): + if (start_idx == 0) & (block_idx == 0): + tpe = self.pos_embed_temporal + else: + tpe = None + x = block(x, y, t0, y_lens, tpe) + return x + + def forward_blocks(self, x, y, t0, y_lens, use_cache=0, current_if_cache=0, delta_cache=torch.tensor([])): + + if not use_cache: + x = self.forward_blocks_range(x, y, t0, y_lens, + 0, len(self.blocks)) + else: + cache_end = np.minimum(self.cache_start + self.num_cache_layer, len(self.blocks)) + # 1.0 infer [0, cache_start) + x = self.forward_blocks_range(x, y, t0, y_lens, + 0, self.cache_start) + if not current_if_cache: + # 2.0 infer [cache_start, cache_end) + tmp = self.forward_blocks_range(x, y, t0, y_lens, + self.cache_start, cache_end) + delta_cache = tmp - x + else: + tmp = x + delta_cache + + # 3.0 infer [cache_end, len(self.blocks)) + x = self.forward_blocks_range(tmp, y, t0, y_lens, + cache_end, len(self.blocks)) + if not use_cache: + return x + elif not current_if_cache: + return x, delta_cache + else: + return x + + def unpatchify(self, x): + """ + Args: + x (torch.Tensor): of shape [B, N, C] + + Return: + x (torch.Tensor): of shape [B, C_out, T, H, W] + """ + + N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + T_p, H_p, W_p = self.patch_size + x = rearrange( + x, + "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)", + N_t=N_t, + N_h=N_h, + N_w=N_w, + T_p=T_p, + H_p=H_p, + W_p=W_p, + C_out=self.out_channels, + ) + return x + + def unpatchify_old(self, x): + c = self.out_channels + t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)] + pt, ph, pw = self.patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = rearrange(x, "n t h w r p q c -> n c t r h p w q") + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + return imgs + + def get_spatial_pos_embed(self, grid_size=None): + if grid_size is None: + grid_size = self.input_size[1:] + pos_embed = get_2d_sincos_pos_embed( + self.hidden_size, + (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]), + scale=self.space_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def get_temporal_pos_embed(self): + pos_embed = get_1d_sincos_pos_embed( + self.hidden_size, + self.input_size[0] // self.patch_size[0], + scale=self.time_scale, + ) + pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) + return pos_embed + + def freeze_not_temporal(self): + for n, p in self.named_parameters(): + if "attn_temp" not in n: + p.requires_grad = False + + def freeze_text(self): + for n, p in self.named_parameters(): + if "cross_attn" in n: + p.requires_grad = False + + def initialize_temporal(self): + for block in self.blocks: + nn.init.constant_(block.attn_temp.proj.weight, 0) + nn.init.constant_(block.attn_temp.proj.bias, 0) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + 
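+        # apply the Xavier-uniform / zero-bias initialization above to every nn.Linear in the model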
self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + nn.init.normal_(self.t_block[1].weight, std=0.02) + + # Initialize caption embedding MLP: + nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) + nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) + + # Zero-out adaLN modulation layers in PixArt blocks: + for block in self.blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + +@MODELS.register_module("STDiT-XL/2") +def STDiT_XL_2(from_pretrained=None, use_mindie=False, device_id=0, absolute_path=None, **kwargs): + model = STDiT( + depth=28, + hidden_size=1152, + patch_size=(1, 2, 2), + num_heads=16, + use_mindie=use_mindie, + device_id=device_id, + absolute_path=absolute_path, + **kwargs) + if from_pretrained is not None: + load_checkpoint(model, from_pretrained) + return model diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc6a9995d9652099a51159907eb1ebb7cc219c2 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/__init__.py @@ -0,0 +1,3 @@ +from .classes import ClassEncoder +from .clip import ClipEncoder +from .t5 import T5Encoder diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/classes.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/classes.py new file mode 100644 index 0000000000000000000000000000000000000000..f02c9f299f9a611f62141d063a80f38cd1b34b45 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/classes.py @@ -0,0 +1,20 @@ +import torch + +from opensora.registry import MODELS + + +@MODELS.register_module("classes") +class ClassEncoder: + def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float): + self.num_classes = num_classes + self.y_embedder = None + + self.model_max_length = model_max_length + self.output_dim = None + self.device = device + + def encode(self, text): + return dict(y=torch.tensor([int(t) for t in text]).to(self.device)) + + def null(self, n): + return torch.tensor([self.num_classes] * n).to(self.device) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/clip.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..c628d02bb1aab2c6ee74be1daa0ec824dda160ff --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/clip.py @@ -0,0 +1,114 @@ +# Copyright 2024 Vchitect/Latte +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License.# Modified from Latte +# +# This file is adapted from the Latte project. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# Latte: https://github.com/Vchitect/Latte +# DiT: https://github.com/facebookresearch/DiT/tree/main +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +import transformers +from transformers import CLIPTextModel, CLIPTokenizer + +from opensora.registry import MODELS + +transformers.logging.set_verbosity_error() + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(path) + self.transformer = CLIPTextModel.from_pretrained(path) + self.device = device + self.max_length = max_length + self._freeze() + + def _freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + pooled_z = outputs.pooler_output + return z, pooled_z + + def encode(self, text): + return self(text) + + +@MODELS.register_module("clip") +class ClipEncoder: + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. 
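+
+    `encode` returns a dict whose `y` entry is the pooled CLIP embedding reshaped to
+    (B, 1, 1, hidden_size), matching the [B, 1, N_token, C] prompt layout used by STDiT.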
+ """ + + def __init__( + self, + from_pretrained, + model_max_length=77, + device="cuda", + dtype=torch.float, + ): + super().__init__() + assert from_pretrained is not None, "Please specify the path to the T5 model" + + self.text_encoder = FrozenCLIPEmbedder(path=from_pretrained, max_length=model_max_length).to(device, dtype) + self.y_embedder = None + + self.model_max_length = model_max_length + self.output_dim = self.text_encoder.transformer.config.hidden_size + + def encode(self, text): + _, pooled_embeddings = self.text_encoder.encode(text) + y = pooled_embeddings.unsqueeze(1).unsqueeze(1) + return dict(y=y) + + def null(self, n): + null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + return null_y + + def to(self, dtype): + self.text_encoder = self.text_encoder.to(dtype) + return self diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/t5.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..64d714f2995823dbea5b3bf341011136fba74ec7 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/text_encoder/t5.py @@ -0,0 +1,338 @@ +# Adapted from PixArt +# +# Copyright (C) 2023 PixArt-alpha/PixArt-alpha +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# T5: https://github.com/google-research/text-to-text-transfer-transformer +# -------------------------------------------------------- + + +import html +import os +import re +import urllib.parse as ul + +import ftfy +import torch +from bs4 import BeautifulSoup +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, T5EncoderModel + +from opensora.registry import MODELS +import mindietorch + + +class T5Embedder: + available_models = ["t5-v1_1-xxl"] + bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" + ) # noqa + + def __init__( + self, + device, + dir_or_name="t5-v1_1-xxl", + *, + local_cache=False, + cache_dir=None, + hf_token=None, + use_text_preprocessing=True, + t5_model_kwargs=None, + torch_dtype=None, + use_offload_folder=None, + model_max_length=120, + use_mindie=False, + device_id=0, + absolute_path=None + ): + self.device = torch.device(device) + self.torch_dtype = torch_dtype or torch.bfloat16 + if t5_model_kwargs is None: + t5_model_kwargs = {"low_cpu_mem_usage": True, "torch_dtype": self.torch_dtype} + if use_offload_folder is not None: + t5_model_kwargs["offload_folder"] = use_offload_folder + t5_model_kwargs["device_map"] = { + "shared": self.device, + "encoder.embed_tokens": self.device, + "encoder.block.0": self.device, + "encoder.block.1": self.device, + "encoder.block.2": self.device, + "encoder.block.3": self.device, + "encoder.block.4": self.device, + "encoder.block.5": self.device, + "encoder.block.6": self.device, + "encoder.block.7": self.device, + "encoder.block.8": self.device, + "encoder.block.9": self.device, + "encoder.block.10": self.device, + "encoder.block.11": self.device, + "encoder.block.12": "disk", + "encoder.block.13": "disk", + "encoder.block.14": "disk", + "encoder.block.15": "disk", + "encoder.block.16": "disk", + "encoder.block.17": "disk", + "encoder.block.18": "disk", + "encoder.block.19": "disk", + "encoder.block.20": "disk", + "encoder.block.21": "disk", + "encoder.block.22": "disk", + "encoder.block.23": "disk", + "encoder.final_layer_norm": "disk", + "encoder.dropout": "disk", + } + else: + t5_model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device} + + self.use_text_preprocessing = use_text_preprocessing + self.hf_token = hf_token + self.cache_dir = cache_dir or os.path.expanduser("~/.cache/IF_") + self.dir_or_name = dir_or_name + tokenizer_path, path = self.cache_dir, self.cache_dir + + print(tokenizer_path) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + self.model_max_length = model_max_length + self.use_mindie = use_mindie + self.absolute_path = absolute_path + if self.use_mindie: + self.model_npu = torch.jit.load(f"{self.absolute_path}/encoder/encoder_compiled.pt").eval() + self.device_id = device_id + + def get_text_embeddings(self, texts): + texts = [self.text_preprocessing(text) for text in texts] + + text_tokens_and_mask = self.tokenizer( + texts, + max_length=self.model_max_length, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + if self.use_mindie: + input_ids=text_tokens_and_mask["input_ids"].to('cpu') + attention_mask=text_tokens_and_mask["attention_mask"].to('cpu') + text_encoder_embs = 
self.model_npu(input_ids.to(f"npu:{self.device_id}"), + attention_mask.to(f"npu:{self.device_id}")).to('cpu') + else: + with torch.no_grad(): + text_encoder_embs = self.model( + input_ids=text_tokens_and_mask["input_ids"].to(self.device), + attention_mask=text_tokens_and_mask["attention_mask"].to(self.device), + )["last_hidden_state"].detach() + return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device) + + def text_preprocessing(self, text): + if self.use_text_preprocessing: + # The exact text cleaning as was in the training stage: + text = self.clean_caption(text) + text = self.clean_caption(text) + return text + else: + return text.lower().strip() + + @staticmethod + def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + def clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . 
" + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = self.basic_clean(caption) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + +@MODELS.register_module("t5") +class T5Encoder: + def __init__( + self, + from_pretrained=None, + model_max_length=120, + device="cuda", + dtype=torch.float, + local_cache=True, + shardformer=False, + use_mindie=False, + device_id=0, + absolute_path=None + ): + assert from_pretrained is not None, "Please specify the path to the T5 model" + + self.t5 = T5Embedder( + device=device, + torch_dtype=dtype, + local_cache=local_cache, + cache_dir=from_pretrained, + model_max_length=model_max_length, + use_mindie=use_mindie, + device_id=device_id, + absolute_path=absolute_path + ) + self.t5.model.to(dtype=dtype) + self.y_embedder = None + + self.model_max_length = model_max_length + self.output_dim = self.t5.model.config.d_model + + if shardformer: + self.shardformer_t5() + + def shardformer_t5(self): + from colossalai.shardformer import ShardConfig, ShardFormer + + from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy + from opensora.utils.misc import requires_grad + + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_flash_attention=False, + enable_jit_fused=True, + enable_sequence_parallelism=False, + enable_sequence_overlap=False, + ) + shard_former = ShardFormer(shard_config=shard_config) + optim_model, _ = shard_former.optimize(self.t5.model, policy=T5EncoderPolicy()) + self.t5.model = optim_model.half() + + # ensure the weights are frozen + requires_grad(self.t5.model, False) + + def encode(self, text): + caption_embs, emb_masks = self.t5.get_text_embeddings(text) + caption_embs = caption_embs[:, None] + return dict(y=caption_embs, mask=emb_masks) + + def null(self, n): + null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + return null_y diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/vae/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63510b08b2036160c01d38b0ad3484757f6bcff7 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/vae/__init__.py @@ -0,0 +1 @@ +from .vae import 
VideoAutoencoderKL, VideoAutoencoderKLTemporalDecoder diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/models/vae/vae.py b/MindIE/MultiModal/OpenSora-1.0/opensora/models/vae/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..363bbfed43f8b191762c5d78138159ce6eb4c4a1 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/models/vae/vae.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder +from einops import rearrange + +from opensora.registry import MODELS + + +@MODELS.register_module() +class VideoAutoencoderKL(nn.Module): + def __init__(self, from_pretrained=None, micro_batch_size=None): + super().__init__() + self.module = AutoencoderKL.from_pretrained(from_pretrained) + self.out_channels = self.module.config.latent_channels + self.patch_size = (1, 8, 8) + self.micro_batch_size = micro_batch_size + + def encode(self, x): + # x: (B, C, T, H, W) + B = x.shape[0] + x = rearrange(x, "B C T H W -> (B T) C H W") + + if self.micro_batch_size is None: + x = self.module.encode(x).latent_dist.sample().mul_(0.18215) + else: + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.encode(x_bs).latent_dist.sample().mul_(0.18215) + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def decode(self, x): + # x: (B, C, T, H, W) + B = x.shape[0] + x = rearrange(x, "B C T H W -> (B T) C H W") + if self.micro_batch_size is None: + x = self.module.decode(x / 0.18215).sample + else: + bs = self.micro_batch_size + x_out = [] + for i in range(0, x.shape[0], bs): + x_bs = x[i : i + bs] + x_bs = self.module.decode(x_bs / 0.18215).sample + x_out.append(x_bs) + x = torch.cat(x_out, dim=0) + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def get_latent_size(self, input_size): + for i in range(3): + assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size" + input_size = [input_size[i] // self.patch_size[i] for i in range(3)] + return input_size + + +@MODELS.register_module() +class VideoAutoencoderKLTemporalDecoder(nn.Module): + def __init__(self, from_pretrained=None): + super().__init__() + self.module = AutoencoderKLTemporalDecoder.from_pretrained(from_pretrained) + self.out_channels = self.module.config.latent_channels + self.patch_size = (1, 8, 8) + + def encode(self, x): + raise NotImplementedError + + def decode(self, x): + B, _, T = x.shape[:3] + x = rearrange(x, "B C T H W -> (B T) C H W") + x = self.module.decode(x / 0.18215, num_frames=T).sample + x = rearrange(x, "(B T) C H W -> B C T H W", B=B) + return x + + def get_latent_size(self, input_size): + for i in range(3): + assert input_size[i] % self.patch_size[i] == 0, "Input size must be divisible by patch size" + input_size = [input_size[i] // self.patch_size[i] for i in range(3)] + return input_size diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/registry.py b/MindIE/MultiModal/OpenSora-1.0/opensora/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7797d36bd76fd482766562d030948ddb85c17aef --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/registry.py @@ -0,0 +1,39 @@ +from copy import deepcopy + +import torch.nn as nn +from mmengine.registry import Registry + + +def build_module(module, builder, **kwargs): + """Build module from config or return the module itself. 
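+
+    A dict config is dispatched to `builder.build` with any extra keyword arguments
+    merged into a copy of the config; an `nn.Module` is returned unchanged and `None`
+    passes through as `None`.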
+ + Args: + module (Union[dict, nn.Module]): The module to build. + builder (Registry): The registry to build module. + *args, **kwargs: Arguments passed to build function. + + Returns: + Any: The built module. + """ + if isinstance(module, dict): + cfg = deepcopy(module) + for k, v in kwargs.items(): + cfg[k] = v + return builder.build(cfg) + elif isinstance(module, nn.Module): + return module + elif module is None: + return None + else: + raise TypeError(f"Only support dict and nn.Module, but got {type(module)}.") + + +MODELS = Registry( + "model", + locations=["opensora.models"], +) + +SCHEDULERS = Registry( + "scheduler", + locations=["opensora.schedulers"], +) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97ea76f92f8b99664e35c51172e35d66d704edc4 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/__init__.py @@ -0,0 +1,2 @@ +from .dpms import DPMS +from .iddpm import IDDPM diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/dpms/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/dpms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cebbcd21fe591ace55857ff5cb64f82293335e --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/dpms/__init__.py @@ -0,0 +1,50 @@ +from functools import partial + +import torch + +from opensora.registry import SCHEDULERS + +from .dpm_solver import DPMS + + +@SCHEDULERS.register_module("dpm-solver") +class DMP_SOLVER: + def __init__(self, num_sampling_steps=None, cfg_scale=4.0): + self.num_sampling_steps = num_sampling_steps + self.cfg_scale = cfg_scale + + def sample( + self, + model, + text_encoder, + z_size, + prompts, + device, + additional_args=None, + ): + n = len(prompts) + z = torch.randn(n, *z_size, device=device) + model_args = text_encoder.encode(prompts) + y = model_args.pop("y") + null_y = text_encoder.null(n) + if additional_args is not None: + model_args.update(additional_args) + + dpms = DPMS( + partial(forward_with_dpmsolver, model), + condition=y, + uncondition=null_y, + cfg_scale=self.cfg_scale, + model_kwargs=model_args, + ) + samples = dpms.sample(z, steps=self.num_sampling_steps, order=2, skip_type="time_uniform", method="multistep") + return samples + + +def forward_with_dpmsolver(self, x, timestep, y, **kwargs): + """ + dpm solver donnot need variance prediction + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + model_out = self.forward(x, timestep, y, **kwargs) + return model_out.chunk(2, dim=1)[0] diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/dpms/dpm_solver.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/dpms/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..106e59ec9c2a22de935210ecfd8153bcf7ebb551 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/dpms/dpm_solver.py @@ -0,0 +1,1570 @@ +# MIT License +# +# Copyright (c) 2022 Cheng Lu +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following 
conditions: +# +# +# This file is adapted from the dpm-solver project +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# PixArt: https://github.com/PixArt-alpha/PixArt-alpha +# dpm-solver: https://github.com/LuChengTHU/dpm-solver +# -------------------------------------------------------- + +import math + +import numpy as np +import torch +from tqdm import tqdm + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
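+    :return: a 1-D numpy array of `num_diffusion_timesteps` betas.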
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support the linear VPSDE for the continuous time setting. The hyperparameters for the noise + schedule are the default settings in Yang Song's ScoreSDE: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). 
+ + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear"]: + raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear'") + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.T = 1.0 + self.log_alpha_array = ( + self.numerical_clip_alpha(log_alphas) + .reshape( + ( + 1, + -1, + ) + ) + .to(dtype=dtype) + ) + self.total_N = self.log_alpha_array.shape[1] + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + else: + self.T = 1.0 + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + + def numerical_clip_alpha(self, log_alphas, clipped_lambda=-5.1): + """ + For some beta schedules such as cosine schedule, the log-SNR has numerical isssues. + We clip the log-SNR near t=T within -5.1 to ensure the stability. + Such a trick is very useful for diffusion models with the cosine schedule, such as i-DDPM, guided-diffusion and GLIDE. + """ + log_sigmas = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_alphas)) + lambs = log_alphas - log_sigmas + idx = torch.searchsorted(torch.flip(lambs, [0]), clipped_lambda) + if idx > 0: + log_alphas = log_alphas[:-idx] + return log_alphas + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device) + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 
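+        This is the inverse of `marginal_lambda`.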
+ """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). 
+ + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim()) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -expand_dims(sigma_t, x.dim()) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. 
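+        For "classifier-free" guidance, the conditional and unconditional branches are batched
+        together and combined as noise_uncond + guidance_scale * (noise_cond - noise_uncond).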
+ """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. 
+ + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). 
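+
+        For example, `skip_type="time_uniform"` with `t_T=1.0`, `t_0=1e-3` and `N=4` simply returns
+        `torch.linspace(1.0, 1e-3, 5)`, while `skip_type="logSNR"` spaces the steps uniformly in the
+        half-logSNR variable `lambda` and maps them back to time via `inverse_lambda`.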
+ """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + return torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + else: + raise ValueError( + f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'" + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. 
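+
+        For example, `steps=6, order=3` gives K = 3 and `orders = [3, 2, 1]`, while `steps=7, order=3`
+        gives K = 3 and `orders = [3, 3, 1]`; in both cases the orders sum to `steps` (the total NFE).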
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + return (x_t, {"model_s": model_s}) if return_intermediate else x_t + + def singlestep_dpm_solver_second_update( + self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver" + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
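+
+        For `algorithm_type="dpmsolver++"` with `solver_type="dpmsolver"`, the model is evaluated once more
+        at the intermediate time `s1 = inverse_lambda(lambda_s + r1 * h)` with `h = lambda_t - lambda_s`,
+        and the update computes
+            x_t = (sigma_t / sigma_s) * x - (alpha_t * phi_1) * model_s
+                  - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s),
+        where `phi_1 = expm1(-h)`.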
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
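+
+        The defaults r1 = 1/3 and r2 = 2/3 place the two intermediate evaluations at
+        `s1 = inverse_lambda(lambda_s + r1 * h)` and `s2 = inverse_lambda(lambda_s + r2 * h)`, so a single
+        call costs up to three model evaluations, fewer when `model_s` and/or `model_s1` are passed in.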
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. 
The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}") + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + return ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + return ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + def singlestep_dpm_solver_update( + self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1 + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2 + ) + else: + raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}") + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}") + + def dpm_solver_adaptive( + self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver" + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. 
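+
+        At each iteration the lower-order and higher-order estimates are compared: the step is accepted only
+        if the error ratio E = norm((x_higher - x_lower) / max(atol, rtol * max(|x_lower|, |x_prev|))) <= 1,
+        and the logSNR step size is then adapted as h <- min(theta * h * E**(-1 / order), lambda_0 - lambda_s).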
+ """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs + ) + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs + ) + else: + raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}") + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + return xt.squeeze(0) if t.shape[0] == 1 else xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. 
For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol, + return_intermediate=return_intermediate, + ) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. 
+ + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. 
The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive( + x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type + ) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, t, step, solver_type=solver_type + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1)): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. 
+ if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device + ) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps( + skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device + ) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError(f"Got wrong method {method}") + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + return (x, intermediates) if return_intermediate else x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. 
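+
+    For example, with keypoints xp = [[0., 1., 2.]] and yp = [[0., 10., 20.]], an input x = [[0.5]] gives
+    f(x) = [[5.0]], and an out-of-range input x = [[3.0]] is extrapolated along the last segment to [[30.0]].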
+ """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + return start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] + + +def DPMS( + model, + condition, + uncondition, + cfg_scale, + model_type="noise", + noise_schedule="linear", + guidance_type="classifier-free", + model_kwargs=None, + diffusion_steps=1000, +): + if model_kwargs is None: + model_kwargs = {} + betas = torch.tensor(get_named_beta_schedule(noise_schedule, diffusion_steps)) + + ## 1. Define the noise schedule. + noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas) + + ## 2. Convert your discrete-time `model` to the continuous-time + ## noise prediction model. Here is an example for a diffusion model + ## `model` with the noise prediction type ("noise") . + model_fn = model_wrapper( + model, + noise_schedule, + model_type=model_type, + model_kwargs=model_kwargs, + guidance_type=guidance_type, + condition=condition, + unconditional_condition=uncondition, + guidance_scale=cfg_scale, + ) + ## 3. Define dpm-solver and sample by multistep DPM-Solver. + return DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb80341a80abf34213ae8e6442219e3ed2372d5 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/__init__.py @@ -0,0 +1,107 @@ +from functools import partial + +import torch + +from opensora.registry import SCHEDULERS + +from . 
import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +@SCHEDULERS.register_module("iddpm") +class IDDPM(SpacedDiffusion): + def __init__( + self, + num_sampling_steps=None, + timestep_respacing=None, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=True, + rescale_learned_sigmas=False, + diffusion_steps=1000, + cfg_scale=4.0, + ): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if num_sampling_steps is not None: + assert timestep_respacing is None + timestep_respacing = str(num_sampling_steps) + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + super().__init__( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X), + model_var_type=( + (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + # rescale_timesteps=rescale_timesteps, + ) + + self.cfg_scale = cfg_scale + + def sample( + self, + model, + text_encoder, + z_size, + prompts, + device, + additional_args=None, + use_cache=0, + cache_steps=[] + ): + n = len(prompts) + z = torch.randn(n, *z_size, device=device) + z = torch.cat([z, z], 0) + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if additional_args is not None: + model_args.update(additional_args) + + forward = partial(forward_with_cfg, model, cfg_scale=self.cfg_scale) + samples = self.p_sample_loop( + forward, + z.shape, + z, + clip_denoised=False, + model_kwargs=model_args, + progress=True, + device=device, + use_cache=use_cache, + cache_steps=cache_steps + ) + samples, _ = samples.chunk(2, dim=0) + return samples + + +def forward_with_cfg(model, x, timestep, y, cfg_scale, **kwargs): + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = model.forward(combined, timestep, y, **kwargs) + model_out = model_out["x"] if isinstance(model_out, dict) else model_out + if isinstance(model_out, tuple): + model_out_new, tmp = model_out + eps, rest = model_out_new[:, :3], model_out_new[:, 3:] + else: + eps, rest = model_out[:, :3], model_out[:, 3:] + + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + if isinstance(model_out, tuple): + return torch.cat([eps, rest], dim=1), tmp + else: + return torch.cat([eps, rest], dim=1) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/diffusion_utils.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c097ac59c6de771c4aeb8b9193aba48a4dfc7c7e --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/diffusion_utils.py @@ -0,0 +1,87 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
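+
+    Each target value is treated as a bin of half-width 1/255 in [-1, 1] (one uint8 level), and the returned
+    log-probability is the Gaussian mass over that bin; the outermost bins are extended to -inf and +inf,
+    which is what the x < -0.999 and x > 0.999 branches handle.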
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/gaussian_diffusion.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..7450e4bd6d2ff1a44793fb1895ca6317deec05a5 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/gaussian_diffusion.py @@ -0,0 +1,902 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + +import enum +import math + +import numpy as np +import torch as th +import torch + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. 
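+
+    For example, `get_beta_schedule("linear", beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=1000)`
+    is exactly what `get_named_beta_schedule("linear", 1000)` constructs (the Ho et al. linear schedule).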
+ """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__(self, *, betas, model_mean_type, model_var_type, loss_type): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. 
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = ( + np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + if len(self.posterior_variance) > 1 + else np.array([]) + ) + + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
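+
+        Concretely, x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise, with
+        `noise` drawn from a standard normal when it is not supplied.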
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, + use_cache=0, current_if_cache=0, delta_cache=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] # x:[2, 4, 16, 32, 32] + + assert t.shape == (B,) + if not use_cache: + model_output = model(x, t, **model_kwargs) + elif not current_if_cache: + model_kwargs['use_cache'] = use_cache + model_kwargs['current_if_cache'] = current_if_cache + model_kwargs['delta_cache'] = delta_cache + model_output, delta_cache = model(x, t, **model_kwargs) # 会去调用respcae.py-->__call__ + else: + model_kwargs['use_cache'] = use_cache + model_kwargs['current_if_cache'] = current_if_cache + model_kwargs['delta_cache'] = delta_cache + model_output = model(x, t, **model_kwargs) + + if isinstance(model_output, tuple): + model_output, extra = model_output + + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + "delta_cache": delta_cache + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + use_cache=0, + current_if_cache=0, + delta_cache=None + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. 
+ :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + use_cache=use_cache, + current_if_cache=current_if_cache, + delta_cache=delta_cache + ) + + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"], "delta_cache": out['delta_cache']} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + use_cache=0, + cache_steps=None + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + use_cache=use_cache, + cache_steps=cache_steps + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + use_cache=0, + cache_steps=None + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + delta_cache = torch.tensor([]) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + if i % 2 == 0: + current_if_cache = 1 + elif i % 2 == 1: + current_if_cache = 0 + + with th.no_grad(): + + if not use_cache: + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs + ) + elif not current_if_cache: + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + use_cache=use_cache, + current_if_cache=current_if_cache, + # delta_cache=delta_cache + ) + delta_cache = out['delta_cache'] + else: + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + use_cache=use_cache, + current_if_cache=current_if_cache, + delta_cache=delta_cache + ) + + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = th.randn_like(x) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. 
reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. 
+ Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
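+        Note: this evaluates the model once for every timestep, so it costs
+        num_timesteps forward passes per batch.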
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/respace.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..d5ea16cedc356d0b93ed1f16cc802958b2af50ac --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/respace.py @@ -0,0 +1,127 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. 
+ :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. 
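+        # SpacedDiffusion leaves t untouched here; _WrappedModel (below) remaps the
+        # respaced indices back onto the original timesteps via timestep_map.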
+ return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/timestep_sampler.py b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..52b6717d528f398cd08f34c347b7fb69f4d5a9a3 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/schedulers/iddpm/timestep_sampler.py @@ -0,0 +1,150 @@ +# Adapted from DiT + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DiT: https://github.com/facebookresearch/DiT/tree/main +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# -------------------------------------------------------- + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. 
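+                   (the weights are 1 / (num_timesteps * p[t]), so the reweighted
+                   losses remain an unbiased estimate of the uniform-average loss)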
+ """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. 
+ self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/utils/__init__.py b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/utils/ckpt_utils.py b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/ckpt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..27adfba19cffaf2a0aa587c71188f3c5702ce3cc --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/ckpt_utils.py @@ -0,0 +1,216 @@ +import functools +import json +import logging +import operator +import os +from typing import Tuple + +import colossalai +import torch +import torch.distributed as dist +import torch.nn as nn +from colossalai.booster import Booster +from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.cluster import DistCoordinator +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torchvision.datasets.utils import download_url + +pretrained_models = { + "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt", + "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt", + "Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt", + "PixArt-XL-2-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth", + "PixArt-XL-2-SAM-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth", + "PixArt-XL-2-512x512.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth", + "PixArt-XL-2-1024-MS.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth", +} + + +def reparameter(ckpt, name=None): + if "DiT" in name: + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + elif "Latte" in name: + ckpt = ckpt["ema"] + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + del ckpt["temp_embed"] + elif "PixArt" in name: + ckpt = ckpt["state_dict"] + ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2) + del ckpt["pos_embed"] + return ckpt + + +def find_model(model_name): + """ + Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + if model_name in pretrained_models: # Find/download our pre-trained DiT checkpoints + model = download_model(model_name) + model = reparameter(model, model_name) + return model + else: # Load a custom DiT checkpoint: + assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}" + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "pos_embed_temporal" in checkpoint: + del checkpoint["pos_embed_temporal"] + if "pos_embed" in checkpoint: + del checkpoint["pos_embed"] + if "ema" in checkpoint: # supports checkpoints from train.py + checkpoint = checkpoint["ema"] + return checkpoint + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. 
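+    Downloaded checkpoints are cached under ./pretrained_models and re-used on
+    subsequent calls; the state dict is loaded with a CPU map_location.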
+ """ + assert model_name in pretrained_models + local_path = f"pretrained_models/{model_name}" + if not os.path.isfile(local_path): + os.makedirs("pretrained_models", exist_ok=True) + web_path = pretrained_models[model_name] + download_url(web_path, "pretrained_models", model_name) + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +def load_from_sharded_state_dict(model, ckpt_path): + ckpt_io = GeneralCheckpointIO() + ckpt_io.load_model(model, os.path.join(ckpt_path, "model")) + +def model_sharding(model: torch.nn.Module): + global_rank = dist.get_rank() + world_size = dist.get_world_size() + for _, param in model.named_parameters(): + padding_size = (world_size - param.numel() % world_size) % world_size + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param = param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // world_size) + splited_params = splited_params[global_rank] + param.data = splited_params + + +def load_json(file_path: str): + with open(file_path, "r") as f: + return json.load(f) + + +def save_json(data, file_path: str): + with open(file_path, "w") as f: + json.dump(data, f, indent=4) + + +def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor: + return tensor[: functools.reduce(operator.mul, original_shape)] + + +def model_gathering(model: torch.nn.Module, model_shape_dict: dict): + global_rank = dist.get_rank() + global_size = dist.get_world_size() + for name, param in model.named_parameters(): + all_params = [torch.empty_like(param.data) for _ in range(global_size)] + dist.all_gather(all_params, param.data, group=dist.group.WORLD) + if int(global_rank) == 0: + all_params = torch.cat(all_params) + param.data = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name]) + dist.barrier() + + +def record_model_param_shape(model: torch.nn.Module) -> dict: + param_shape = {} + for name, param in model.named_parameters(): + param_shape[name] = param.shape + return param_shape + + +def save( + booster: Booster, + model: nn.Module, + ema: nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + global_step: int, + batch_size: int, + coordinator: DistCoordinator, + save_dir: str, + shape_dict: dict, +): + save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}") + os.makedirs(os.path.join(save_dir, "model"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "model"), shard=True) + # ema is not boosted, so we don't need to use booster.save_model + model_gathering(ema, shape_dict) + global_rank = dist.get_rank() + if int(global_rank) == 0: + torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt")) + model_sharding(ema) + + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096) + if lr_scheduler is not None: + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "global_step": global_step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + dist.barrier() + + +def load( + booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str +) -> Tuple[int, int, int]: + booster.load_model(model, os.path.join(load_dir, "model")) + # ema is not 
boosted, so we don't use booster.load_model + # ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"))) + ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))) + booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer")) + if lr_scheduler is not None: + booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler")) + running_states = load_json(os.path.join(load_dir, "running_states.json")) + dist.barrier() + return running_states["epoch"], running_states["step"], running_states["sample_start_index"] + + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + format="[\033[34m%(asctime)s\033[0m] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")], + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def load_checkpoint(model, ckpt_path, save_as_pt=True): + if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"): + state_dict = find_model(ckpt_path) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print(f"Missing keys: {missing_keys}") + print(f"Unexpected keys: {unexpected_keys}") + elif os.path.isdir(ckpt_path): + load_from_sharded_state_dict(model, ckpt_path) + if save_as_pt: + save_path = os.path.join(ckpt_path, "model_ckpt.pt") + torch.save(model.state_dict(), save_path) + print(f"Model checkpoint saved to {save_path}") + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/utils/config_utils.py b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..007ce523c1b646ede727c966cf5f513f6791650b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/config_utils.py @@ -0,0 +1,102 @@ +import argparse +import json +import os +from glob import glob + +from mmengine.config import Config +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(training=False): + parser = argparse.ArgumentParser() + + # model config + parser.add_argument("config", help="model config file path") + + parser.add_argument("--seed", default=42, type=int, help="generation seed") + parser.add_argument("--ckpt-path", type=str, help="path to model ckpt; will overwrite cfg.ckpt_path if specified") + parser.add_argument("--vae_path", type=str, help="path to vae model; will overwrite cfg.vae_path if specified") + parser.add_argument("--t5_path", type=str, help="path to t5 model; will overwrite cfg.t5_path if specified") + parser.add_argument("--batch-size", default=None, type=int, help="batch size") + parser.add_argument("--use_mindie", default=0, type=int, help="1 is use mindie; 0 is use cpu") + parser.add_argument("--device_id", default=0, type=int, help="npu device id") + parser.add_argument("--output_dir", type=str, default="./models", help="compiled models path") + + # ====================================================== + # Inference + # ====================================================== + + if not training: + # prompt + parser.add_argument("--prompt-path", default=None, type=str, help="path to prompt txt file") + parser.add_argument("--save-dir", default=None, type=str, help="path to save 
generated samples") + + # hyperparameters + parser.add_argument("--num-sampling-steps", default=None, type=int, help="sampling steps") + parser.add_argument("--cfg-scale", default=None, type=float, help="balance between cond & uncond") + else: + parser.add_argument("--wandb", default=None, type=bool, help="enable wandb") + parser.add_argument("--load", default=None, type=str, help="path to continue training") + parser.add_argument("--data-path", default=None, type=str, help="path to data csv") + + return parser.parse_args() + + +def merge_args(cfg, args, training=False): + if args.ckpt_path is not None: + cfg.model["from_pretrained"] = args.ckpt_path + args.ckpt_path = None + + if not training: + if args.cfg_scale is not None: + cfg.scheduler["cfg_scale"] = args.cfg_scale + args.cfg_scale = None + + if "multi_resolution" not in cfg: + cfg["multi_resolution"] = False + for k, v in vars(args).items(): + if k in cfg and v is not None: + cfg[k] = v + + return cfg + + +def parse_configs(training=False): + args = parse_args(training) + cfg = Config.fromfile(args.config) + cfg = merge_args(cfg, args, training) + return cfg + + +def create_experiment_workspace(cfg): + """ + This function creates a folder for experiment tracking. + + Args: + args: The parsed arguments. + + Returns: + exp_dir: The path to the experiment folder. + """ + # Make outputs folder (holds all experiment subfolders) + os.makedirs(cfg.outputs, exist_ok=True) + experiment_index = len(glob(f"{cfg.outputs}/*")) + + # Create an experiment folder + model_name = cfg.model["type"].replace("/", "-") + exp_name = f"{experiment_index:03d}-F{cfg.num_frames}S{cfg.frame_interval}-{model_name}" + exp_dir = f"{cfg.outputs}/{exp_name}" + os.makedirs(exp_dir, exist_ok=True) + return exp_name, exp_dir + + +def save_training_config(cfg, experiment_dir): + with open(f"{experiment_dir}/config.txt", "w") as f: + json.dump(cfg, f, indent=4) + + +def create_tensorboard_writer(exp_dir): + tensorboard_dir = f"{exp_dir}/tensorboard" + os.makedirs(tensorboard_dir, exist_ok=True) + writer = SummaryWriter(tensorboard_dir) + return writer diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/utils/misc.py b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..d162526162ea0847400a9dc9d1ab39e45d8e5abf --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/misc.py @@ -0,0 +1,286 @@ +import collections +import importlib +import logging +import os +import time +from collections import OrderedDict +from collections.abc import Sequence +from itertools import repeat + +import numpy as np +import torch +import torch.distributed as dist + + +def print_rank(var_name, var_value, rank=0): + if dist.get_rank() == rank: + print(f"[Rank {rank}] {var_name}: {var_value}") + + +def print_0(*args, **kwargs): + if dist.get_rank() == 0: + print(*args, **kwargs) + + +def requires_grad(model: torch.nn.Module, flag: bool = True) -> None: + """ + Set requires_grad flag for all parameters in a model. 
+ """ + for p in model.parameters(): + p.requires_grad = flag + + +def format_numel_str(numel: int) -> str: + B = 1024**3 + M = 1024**2 + K = 1024 + if numel >= B: + return f"{numel / B:.2f} B" + elif numel >= M: + return f"{numel / M:.2f} M" + elif numel >= K: + return f"{numel / K:.2f} K" + else: + return f"{numel}" + + +def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: + dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) + tensor.div_(dist.get_world_size()) + return tensor + + +def get_model_numel(model: torch.nn.Module) -> (int, int): + num_params = 0 + num_params_trainable = 0 + for p in model.parameters(): + num_params += p.numel() + if p.requires_grad: + num_params_trainable += p.numel() + return num_params, num_params_trainable + + +def try_import(name): + """Try to import a module. + + Args: + name (str): Specifies what module to import in absolute or relative + terms (e.g. either pkg.mod or ..mod). + Returns: + ModuleType or None: If importing successfully, returns the imported + module, otherwise returns None. + """ + try: + return importlib.import_module(name) + except ImportError: + return None + + +def transpose(x): + """ + transpose a list of list + Args: + x (list[list]): + """ + ret = list(map(list, zip(*x))) + return ret + + +def get_timestamp(): + timestamp = time.strftime("%Y%m%d-%H%M%S", time.localtime(time.time())) + return timestamp + + +def format_time(seconds): + days = int(seconds / 3600 / 24) + seconds = seconds - days * 3600 * 24 + hours = int(seconds / 3600) + seconds = seconds - hours * 3600 + minutes = int(seconds / 60) + seconds = seconds - minutes * 60 + secondsf = int(seconds) + seconds = seconds - secondsf + millis = int(seconds * 1000) + + f = "" + i = 1 + if days > 0: + f += str(days) + "D" + i += 1 + if hours > 0 and i <= 2: + f += str(hours) + "h" + i += 1 + if minutes > 0 and i <= 2: + f += str(minutes) + "m" + i += 1 + if secondsf > 0 and i <= 2: + f += str(secondsf) + "s" + i += 1 + if millis > 0 and i <= 2: + f += str(millis) + "ms" + i += 1 + if f == "": + f = "0ms" + return f + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. 
+ """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not isinstance(data, str): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f"type {type(data)} cannot be converted to tensor.") + + +def to_ndarray(data): + if isinstance(data, torch.Tensor): + return data.numpy() + elif isinstance(data, np.ndarray): + return data + elif isinstance(data, Sequence): + return np.array(data) + elif isinstance(data, int): + return np.ndarray([data], dtype=int) + elif isinstance(data, float): + return np.array([data], dtype=float) + else: + raise TypeError(f"type {type(data)} cannot be converted to ndarray.") + + +def to_torch_dtype(dtype): + if isinstance(dtype, torch.dtype): + return dtype + elif isinstance(dtype, str): + dtype_mapping = { + "float64": torch.float64, + "float32": torch.float32, + "float16": torch.float16, + "fp32": torch.float32, + "fp16": torch.float16, + "half": torch.float16, + "bf16": torch.bfloat16, + } + if dtype not in dtype_mapping: + raise ValueError + dtype = dtype_mapping[dtype] + return dtype + else: + raise ValueError + + +def count_params(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def convert_SyncBN_to_BN2d(model_cfg): + for k in model_cfg: + v = model_cfg[k] + if k == "norm_cfg" and v["type"] == "SyncBN": + v["type"] = "BN2d" + elif isinstance(v, dict): + convert_SyncBN_to_BN2d(v) + + +def get_topk(x, dim=4, k=5): + x = to_tensor(x) + inds = x[..., dim].topk(k)[1] + return x[inds] + + +def param_sigmoid(x, alpha): + ret = 1 / (1 + (-alpha * x).exp()) + return ret + + +def inverse_param_sigmoid(x, alpha, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) / alpha + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. 
+ """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def count_columns(df, columns): + cnt_dict = OrderedDict() + num_samples = len(df) + + for col in columns: + d_i = df[col].value_counts().to_dict() + for k in d_i: + d_i[k] = (d_i[k], d_i[k] / num_samples) + cnt_dict[col] = d_i + + return cnt_dict + + +def build_logger(work_dir, cfgname): + log_file = cfgname + ".log" + log_path = os.path.join(work_dir, log_file) + + logger = logging.getLogger(cfgname) + logger.setLevel(logging.INFO) + # formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s') + formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + + handler1 = logging.FileHandler(log_path) + handler1.setFormatter(formatter) + + handler2 = logging.StreamHandler() + handler2.setFormatter(formatter) + + logger.addHandler(handler1) + logger.addHandler(handler2) + logger.propagate = False + + return logger diff --git a/MindIE/MultiModal/OpenSora-1.0/opensora/utils/train_utils.py b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f84604392b7536aaf35f27f6aec24970e782a62b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/opensora/utils/train_utils.py @@ -0,0 +1,31 @@ +from collections import OrderedDict + +import torch + + +@torch.no_grad() +def update_ema( + ema_model: torch.nn.Module, model: torch.nn.Module, optimizer=None, decay: float = 0.9999, sharded: bool = True +) -> None: + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + if name == "pos_embed": + continue + if param.requires_grad == False: + continue + if not sharded: + param_data = param.data + ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) + else: + if param.data.dtype != torch.float32: + param_id = id(param) + master_param = optimizer._param_store.working_to_master_param[param_id] + param_data = master_param.data + else: + param_data = param.data + ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay) diff --git a/MindIE/MultiModal/OpenSora-1.0/requirements.txt b/MindIE/MultiModal/OpenSora-1.0/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0724a86ba2e917b752f9c4cc822bca0737bc5eda --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/requirements.txt @@ -0,0 +1,13 @@ +colossalai +accelerate +diffusers +ftfy +gdown +mmengine +pre-commit +pyav +tensorboard +timm +tqdm +transformers +wandb diff --git a/MindIE/MultiModal/OpenSora-1.0/run.sh b/MindIE/MultiModal/OpenSora-1.0/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..45f62c52fb914f3b47b391dbc1d188e93f05a359 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/run.sh @@ -0,0 +1,6 @@ +python inference.py \ +./configs/opensora/inference/16x256x256.py \ +--ckpt-path /home/mazhixin/OpenSora-v1-HQ-16x256x256.pth \ +--prompt-path ./assets/texts/t2v_samples.txt \ +--use_mindie 1 \ +--device_id 0 \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.0/scripts/inference.py b/MindIE/MultiModal/OpenSora-1.0/scripts/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..900870be6aa4679bd6141c472ec83a245c7ab189 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/scripts/inference.py @@ -0,0 +1,112 @@ +import os + +import torch +import colossalai +import 
torch.distributed as dist +from mmengine.runner import set_random_seed + +from opensora.datasets import save_sample +from opensora.registry import MODELS, SCHEDULERS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.misc import to_torch_dtype +from opensora.acceleration.parallel_states import set_sequence_parallel_group +from colossalai.cluster import DistCoordinator + + +def load_prompts(prompt_path): + with open(prompt_path, "r") as f: + prompts = [line.strip() for line in f.readlines()] + return prompts + + +def main(): + # ====================================================== + # 1. cfg and init distributed env + # ====================================================== + cfg = parse_configs(training=False) + print(cfg) + + # init distributed + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + + if coordinator.world_size > 1: + set_sequence_parallel_group(dist.group.WORLD) + enable_sequence_parallelism = True + else: + enable_sequence_parallelism = False + + # ====================================================== + # 2. runtime variables + # ====================================================== + torch.set_grad_enabled(False) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = to_torch_dtype(cfg.dtype) + set_random_seed(seed=cfg.seed) + prompts = load_prompts(cfg.prompt_path) + + # ====================================================== + # 3. build model & load weights + # ====================================================== + # 3.1. build model + input_size = (cfg.num_frames, *cfg.image_size) + vae = build_module(cfg.vae, MODELS) + latent_size = vae.get_latent_size(input_size) + text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32 + model = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + dtype=dtype, + enable_sequence_parallelism=enable_sequence_parallelism, + ) + text_encoder.y_embedder = model.y_embedder # hack for classifier-free guidance + + # 3.2. move to device & eval + vae = vae.to(device, dtype).eval() + model = model.to(device, dtype).eval() + + # 3.3. build scheduler + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + # 3.4. support for multi-resolution + model_args = dict() + if cfg.multi_resolution: + image_size = cfg.image_size + hw = torch.tensor([image_size], device=device, dtype=dtype).repeat(cfg.batch_size, 1) + ar = torch.tensor([[image_size[0] / image_size[1]]], device=device, dtype=dtype).repeat(cfg.batch_size, 1) + model_args["data_info"] = dict(ar=ar, hw=hw) + + # ====================================================== + # 4. 
inference + # ====================================================== + sample_idx = 0 + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + for i in range(0, len(prompts), cfg.batch_size): + batch_prompts = prompts[i : i + cfg.batch_size] + samples = scheduler.sample( + model, + text_encoder, + z_size=(vae.out_channels, *latent_size), + prompts=batch_prompts, + device=device, + additional_args=model_args, + ) + samples = vae.decode(samples.to(dtype)) + + if coordinator.is_master(): + for idx, sample in enumerate(samples): + print(f"Prompt: {batch_prompts[idx]}") + save_path = os.path.join(save_dir, f"sample_{sample_idx}") + save_sample(sample, fps=cfg.fps, save_path=save_path) + sample_idx += 1 + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/OpenSora-1.0/scripts/train.py b/MindIE/MultiModal/OpenSora-1.0/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9f611b7d3c07ef1a4af0678e245c8914d2d485fb --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/scripts/train.py @@ -0,0 +1,287 @@ +from copy import deepcopy + +import colossalai +import torch +import torch.distributed as dist +import wandb +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +from tqdm import tqdm + +from opensora.acceleration.checkpoint import set_grad_checkpoint +from opensora.acceleration.parallel_states import ( + get_data_parallel_group, + set_data_parallel_group, + set_sequence_parallel_group, +) +from opensora.acceleration.plugin import ZeroSeqParallelPlugin +from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader +from opensora.registry import MODELS, SCHEDULERS, build_module +from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save +from opensora.utils.config_utils import ( + create_experiment_workspace, + create_tensorboard_writer, + parse_configs, + save_training_config, +) +from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype +from opensora.utils.train_utils import update_ema + + +def main(): + # ====================================================== + # 1. args & cfg + # ====================================================== + cfg = parse_configs(training=True) + print(cfg) + exp_name, exp_dir = create_experiment_workspace(cfg) + save_training_config(cfg._cfg_dict, exp_dir) + + # ====================================================== + # 2. runtime variables & colossalai launch + # ====================================================== + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + assert cfg.dtype in ["fp16", "bf16"], f"Unknown mixed precision {cfg.dtype}" + + # 2.1. colossalai init distributed training + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + device = get_current_device() + dtype = to_torch_dtype(cfg.dtype) + + # 2.2. init logger, tensorboard & wandb + if not coordinator.is_master(): + logger = create_logger(None) + else: + logger = create_logger(exp_dir) + logger.info(f"Experiment directory created at {exp_dir}") + + writer = create_tensorboard_writer(exp_dir) + if cfg.wandb: + wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict) + + # 2.3. 
initialize ColossalAI booster + if cfg.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=cfg.dtype, + initial_scale=2**16, + max_norm=cfg.grad_clip, + ) + set_data_parallel_group(dist.group.WORLD) + elif cfg.plugin == "zero2-seq": + plugin = ZeroSeqParallelPlugin( + sp_size=cfg.sp_size, + stage=2, + precision=cfg.dtype, + initial_scale=2**16, + max_norm=cfg.grad_clip, + ) + set_sequence_parallel_group(plugin.sp_group) + set_data_parallel_group(plugin.dp_group) + else: + raise ValueError(f"Unknown plugin {cfg.plugin}") + booster = Booster(plugin=plugin) + + # ====================================================== + # 3. build dataset and dataloader + # ====================================================== + dataset = DatasetFromCSV( + cfg.data_path, + # TODO: change transforms + transform=( + get_transforms_video(cfg.image_size[0]) + if not cfg.use_image_transform + else get_transforms_image(cfg.image_size[0]) + ), + num_frames=cfg.num_frames, + frame_interval=cfg.frame_interval, + root=cfg.root, + ) + + # TODO: use plugin's prepare dataloader + # a batch contains: + # { + # "video": torch.Tensor, # [B, C, T, H, W], + # "text": List[str], + # } + dataloader = prepare_dataloader( + dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + shuffle=True, + drop_last=True, + pin_memory=True, + process_group=get_data_parallel_group(), + ) + logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})") + + total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size + logger.info(f"Total batch size: {total_batch_size}") + + # ====================================================== + # 4. build model + # ====================================================== + # 4.1. build model + input_size = (cfg.num_frames, *cfg.image_size) + vae = build_module(cfg.vae, MODELS) + latent_size = vae.get_latent_size(input_size) + text_encoder = build_module(cfg.text_encoder, MODELS, device=device) # T5 must be fp32 + model = build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=vae.out_channels, + caption_channels=text_encoder.output_dim, + model_max_length=text_encoder.model_max_length, + dtype=dtype, + ) + model_numel, model_numel_trainable = get_model_numel(model) + logger.info( + f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}" + ) + + # 4.2. create ema + ema = deepcopy(model).to(torch.float32).to(device) + requires_grad(ema, False) + ema_shape_dict = record_model_param_shape(ema) + + # 4.3. move to device + vae = vae.to(device, dtype) + model = model.to(device, dtype) + + # 4.4. build scheduler + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + # 4.5. setup optimizer + optimizer = HybridAdam( + filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True + ) + lr_scheduler = None + + # 4.6. prepare for training + if cfg.grad_checkpoint: + set_grad_checkpoint(model) + model.train() + update_ema(ema, model, decay=0, sharded=False) + ema.eval() + + # ======================================================= + # 5. 
boost model for distributed training with colossalai + # ======================================================= + torch.set_default_dtype(dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader + ) + torch.set_default_dtype(torch.float) + num_steps_per_epoch = len(dataloader) + logger.info("Boost model for distributed training") + + # ======================================================= + # 6. training loop + # ======================================================= + start_epoch = start_step = log_step = sampler_start_idx = 0 + running_loss = 0.0 + + # 6.1. resume training + if cfg.load is not None: + logger.info("Loading checkpoint") + start_epoch, start_step, sampler_start_idx = load(booster, model, ema, optimizer, lr_scheduler, cfg.load) + logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}") + logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch") + + dataloader.sampler.set_start_index(sampler_start_idx) + model_sharding(ema) + + # 6.2. training loop + for epoch in range(start_epoch, cfg.epochs): + dataloader.sampler.set_epoch(epoch) + dataloader_iter = iter(dataloader) + logger.info(f"Beginning epoch {epoch}...") + + with tqdm( + range(start_step, num_steps_per_epoch), + desc=f"Epoch {epoch}", + disable=not coordinator.is_master(), + total=num_steps_per_epoch, + initial=start_step, + ) as pbar: + for step in pbar: + batch = next(dataloader_iter) + x = batch["video"].to(device, dtype) # [B, C, T, H, W] + y = batch["text"] + + with torch.no_grad(): + # Prepare visual inputs + x = vae.encode(x) # [B, C, T, H/P, W/P] + # Prepare text inputs + model_args = text_encoder.encode(y) + + # Diffusion + t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device) + loss_dict = scheduler.training_losses(model, x, t, model_args) + + # Backward & update + loss = loss_dict["loss"].mean() + booster.backward(loss=loss, optimizer=optimizer) + optimizer.step() + optimizer.zero_grad() + + # Update EMA + update_ema(ema, model.module, optimizer=optimizer) + + # Log loss values: + all_reduce_mean(loss) + running_loss += loss.item() + global_step = epoch * num_steps_per_epoch + step + log_step += 1 + + # Log to tensorboard + if coordinator.is_master() and (global_step + 1) % cfg.log_every == 0: + avg_loss = running_loss / log_step + pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step}) + running_loss = 0 + log_step = 0 + writer.add_scalar("loss", loss.item(), global_step) + if cfg.wandb: + wandb.log( + { + "iter": global_step, + "num_samples": global_step * total_batch_size, + "epoch": epoch, + "loss": loss.item(), + "avg_loss": avg_loss, + }, + step=global_step, + ) + + # Save checkpoint + if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0: + save( + booster, + model, + ema, + optimizer, + lr_scheduler, + epoch, + step + 1, + global_step + 1, + cfg.batch_size, + coordinator, + exp_dir, + ema_shape_dict, + ) + logger.info( + f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}" + ) + + # the continue epochs are not resumed, so we need to reset the sampler start index and start step + dataloader.sampler.set_start_index(0) + start_step = 0 + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/OpenSora-1.0/setup.py b/MindIE/MultiModal/OpenSora-1.0/setup.py new file mode 100644 index 
0000000000000000000000000000000000000000..45049bbaeda5c2c7fd79e99d16683173dca153dd --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/setup.py @@ -0,0 +1,60 @@ +from typing import List + +from setuptools import find_packages, setup + + +def fetch_requirements(path) -> List[str]: + """ + This function reads the requirements file. + + Args: + path (str): the path to the requirements file. + + Returns: + The lines in the requirements file. + """ + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme() -> str: + """ + This function reads the README.md file in the current directory. + + Returns: + The lines in the README file. + """ + with open("README.md", encoding="utf-8") as f: + return f.read() + + +setup( + name="opensora", + version="1.0.0", + packages=find_packages( + exclude=( + "assets", + "configs", + "docs", + "outputs", + "pretrained_models", + "scripts", + "tests", + "tools", + "*.egg-info", + ) + ), + description="Democratizing Efficient Video Production for All", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/MindIE/MultiModal/OpenSora-1.0/tests/test_seq_parallel_attention.py b/MindIE/MultiModal/OpenSora-1.0/tests/test_seq_parallel_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..00966ad013fb330d0e3a767013a8c0e9cfb6b9d4 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tests/test_seq_parallel_attention.py @@ -0,0 +1,149 @@ +import colossalai +import torch +import torch.distributed as dist +from colossalai.testing import spawn + +from opensora.acceleration.communications import gather_forward_split_backward, split_forward_gather_backward +from opensora.acceleration.parallel_states import set_sequence_parallel_group +from opensora.models.layers.blocks import ( + Attention, + MultiHeadCrossAttention, + SeqParallelAttention, + SeqParallelMultiHeadCrossAttention, +) + + +def run_attention(rank, world_size): + # create model + torch.manual_seed(1024) + set_sequence_parallel_group(dist.group.WORLD) + + seq_parallel_attention = SeqParallelAttention(dim=256, num_heads=4, qkv_bias=True, enable_flashattn=False).cuda() + + torch.manual_seed(1024) + attention = Attention( + dim=256, + num_heads=4, + qkv_bias=True, + enable_flashattn=False, + ).cuda() + + # create inputs + torch.manual_seed(1024) + x = torch.randn(4, 64, 256).cuda() + seq_x = x.clone().detach() + + x.requires_grad = True + x.retain_grad() + seq_x.requires_grad = True + seq_x.retain_grad() + + sub_seq_x = split_forward_gather_backward(seq_x, dist.group.WORLD, dim=1, grad_scale="down") + + # run model + out = attention(x) + sub_seq_out = seq_parallel_attention(sub_seq_x) + seq_out = gather_forward_split_backward(sub_seq_out, dist.group.WORLD, dim=1, grad_scale="up") + + assert torch.allclose(seq_out, out, atol=1e-7), f"{seq_out}\nvs\n{out}" + + # run backward + seq_out.mean().backward() + out.mean().backward() + + # all reduce gradient for sp + for p in seq_parallel_attention.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, group=dist.group.WORLD) + p.grad.div_(world_size) + + # check grad + for p1, 
p2 in zip(seq_parallel_attention.parameters(), attention.parameters()): + assert torch.allclose(p1.grad, p2.grad, atol=1e-7), f"{p1.grad}\nvs\n{p2.grad}" + + # check input grad + assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}" + + +def run_cross_attention(rank, world_size): + # create model + torch.manual_seed(1024) + set_sequence_parallel_group(dist.group.WORLD) + seq_parallel_attention = SeqParallelMultiHeadCrossAttention( + d_model=256, + num_heads=4, + ).cuda().to(torch.bfloat16) + + torch.manual_seed(1024) + attention = MultiHeadCrossAttention( + d_model=256, + num_heads=4, + ).cuda().to(torch.bfloat16) + + # make sure the weights are the same + for p1, p2 in zip(seq_parallel_attention.parameters(), attention.parameters()): + p1.data.copy_(p2.data) + + # create inputs + torch.manual_seed(1024) + x = torch.randn(4, 64, 256).cuda().to(torch.bfloat16) + y = torch.randn(4, 32, 256).cuda().to(torch.bfloat16) + + mask = [2, 10, 8, 16] + mask = None + seq_x = x.clone().detach() + seq_y = y.clone().detach() + + # set grad + x.requires_grad = True + x.retain_grad() + seq_x.requires_grad = True + seq_x.retain_grad() + y.requires_grad = True + y.retain_grad() + seq_y.requires_grad = True + seq_y.retain_grad() + + # split by sequence + sub_seq_x = split_forward_gather_backward(seq_x, dist.group.WORLD, dim=1, grad_scale="down") + + # run model + out = attention(x, y, mask) + sub_seq_out = seq_parallel_attention(sub_seq_x, seq_y, mask) + seq_out = gather_forward_split_backward(sub_seq_out, dist.group.WORLD, dim=1, grad_scale="up") + + assert torch.allclose(seq_out, out, rtol=1e-5, atol=1e-6), f"\n{seq_out}\nvs\n{out}" + + # run backward + seq_out.mean().backward() + out.mean().backward() + + # all reduce gradient for sp + for name, p in seq_parallel_attention.named_parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, group=dist.group.WORLD) + p.grad.div_(world_size) + else: + print(f"grad of {name} is None") + + # # check grad + for p1, p2 in zip(seq_parallel_attention.named_parameters(), attention.named_parameters()): + assert torch.allclose(p1[1].grad, p2[1].grad, rtol=1e-3, atol=1e-4), f"\n{p1[0]}\nvs\n{p2[0]}:\n{p1[1].grad}\nvs\n{p2[1].grad}" + + # # check input grad + assert torch.allclose(x.grad, seq_x.grad, atol=1e-7), f"{x.grad}\nvs\n{seq_x.grad}" + assert torch.allclose(y.grad, seq_y.grad, atol=1e-7), f"{y.grad}\nvs\n{seq_y.grad}" + + +def run_dist(rank, world_size, port): + colossalai.launch({}, rank=rank, world_size=world_size, host="localhost", port=port) + # run_attention(rank, world_size) + run_cross_attention(rank, world_size) + + +def test_seq_parallel_attention(): + spawn(run_dist, nprocs=2) + + +if __name__ == "__main__": + test_seq_parallel_attention() diff --git a/MindIE/MultiModal/OpenSora-1.0/tests/test_t5_shardformer.py b/MindIE/MultiModal/OpenSora-1.0/tests/test_t5_shardformer.py new file mode 100644 index 0000000000000000000000000000000000000000..68040ab39e57d7b8508e7eb4c2d330d7492f30ea --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tests/test_t5_shardformer.py @@ -0,0 +1,71 @@ +import time +from copy import deepcopy + +import colossalai +import torch +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import spawn + +from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy +from opensora.models.text_encoder.t5 import T5Embedder + + +def run_t5_encoder(rank, world_size, port): + colossalai.launch({}, rank=rank, world_size=world_size, port=port, host="localhost") + + 
# t5 embedder + t5_path = "./pretrained_models/t5_ckpts" + hf_t5 = T5Embedder(device="cuda", local_cache=True, cache_dir=t5_path, torch_dtype=torch.float) + sf_t5 = deepcopy(hf_t5) + + # create huggingface model as normal + shard_config = ShardConfig( + tensor_parallel_process_group=None, + pipeline_stage_manager=None, + enable_tensor_parallelism=False, + enable_fused_normalization=False, + enable_flash_attention=False, + enable_jit_fused=True, + enable_sequence_parallelism=False, + enable_sequence_overlap=False, + ) + shard_former = ShardFormer(shard_config=shard_config) + sharded_model, _ = shard_former.optimize(sf_t5.model, policy=T5EncoderPolicy()) + sf_t5.model = sharded_model + + # test t5 embedder + texts = ["Who is the best player in the history of NBA?", "How to study computer science?"] + for i in range(5): + hf_embs, hf_masks = hf_t5.get_text_embeddings(texts) + sf_embs, sf_masks = sf_t5.get_text_embeddings(texts) + + # check accuracy + assert torch.allclose(hf_embs, sf_embs, rtol=1e-4, atol=1e-5), f"{hf_embs} \nvs\n{sf_embs}" + assert torch.allclose(hf_masks, sf_masks), f"{hf_masks} \nvs\n{sf_masks}" + + # measure perf + torch.cuda.synchronize() + hf_start = time.time() + for i in range(20): + hf_embs, hf_masks = hf_t5.get_text_embeddings(texts) + torch.cuda.synchronize() + hf_end = time.time() + + # convert sf to fp16 + hf_t5.model = hf_t5.model.half() + torch.cuda.synchronize() + sf_start = time.time() + for i in range(20): + hf_embs, hf_masks = hf_t5.get_text_embeddings(texts) + torch.cuda.synchronize() + sf_end = time.time() + + print(f"[Performance] native: {hf_end - hf_start}s, shardformer: {sf_end - sf_start} s") + + +def test_t5_encoder(): + spawn(run_t5_encoder) + + +if __name__ == "__main__": + test_t5_encoder() diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/__init__.py b/MindIE/MultiModal/OpenSora-1.0/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/caption/README.md b/MindIE/MultiModal/OpenSora-1.0/tools/caption/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e9289d0550d49fea9b5a0e6e3ce12f3501e78714 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/caption/README.md @@ -0,0 +1,25 @@ +# Video Captioning + +Human labeling of videos is expensive and time-consuming. We adopt powerful image captioning models to generate captions for videos. Although GPT-4V achieves a better performance, its 20s/sample speed is too slow for us. With batch inference, we can achieve a speed of 3s/sample with LLaVA, and the quality is comparable. LLaVA is the second best open-source model in [MMMU](https://mmmu-benchmark.github.io/) and accepts any resolution. + +![Caption](https://i0.imgs.ovh/2024/03/16/eXdvC.png) + +## GPT-4V Captioning + +Run the following command to generate captions for videos with GPT-4V: + +```bash +python -m tools.caption.caption_gpt4 FOLDER_WITH_VIDEOS output.csv --key $OPENAI_API_KEY +``` + +The cost is approximately $0.01 per video (3 frames per video). The output is a CSV file with path and caption. + +## LLaVA Captioning + +First, install LLaVA according to their [official instructions](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#install). We use the `liuhaotian/llava-v1.6-34b` model for captioning, which can be download [here](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b). 
Then, run the following command to generate captions for videos with LLaVA: + +```bash +CUDA_VISIBLE_DEVICES=0,1 python -m tools.caption.caption_llava samples output.csv +``` + +The Yi-34B requires 2 80GB GPUs and 3s/sample. The output is a CSV file with path and caption. diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/caption/__init__.py b/MindIE/MultiModal/OpenSora-1.0/tools/caption/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/caption/caption_gpt4.py b/MindIE/MultiModal/OpenSora-1.0/tools/caption/caption_gpt4.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c7590f15b05993d463defc9615c401090ed953 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/caption/caption_gpt4.py @@ -0,0 +1,69 @@ +import argparse +import csv +import os + +import requests +import tqdm + +from .utils import extract_frames, prompts, read_video_list + + +def get_caption(frame, prompt, api_key): + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + payload = { + "model": "gpt-4-vision-preview", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt, + }, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[0]}"}}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[1]}"}}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame[2]}"}}, + ], + } + ], + "max_tokens": 300, + } + response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=60) + caption = response.json()["choices"][0]["message"]["content"] + caption = caption.replace("\n", " ") + return caption + + +def main(args): + # ====================================================== + # 1. read video list + # ====================================================== + videos = read_video_list(args.video_folder, args.output_file) + f = open(args.output_file, "a") + writer = csv.writer(f) + + # ====================================================== + # 2. 
generate captions + # ====================================================== + for video in tqdm.tqdm(videos): + video_path = os.path.join(args.video_folder, video) + frame, length = extract_frames(video_path, base_64=True) + if len(frame) < 3: + continue + + prompt = prompts[args.prompt] + caption = get_caption(frame, prompt, args.key) + + writer.writerow((video, caption, length)) + f.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("video_folder", type=str) + parser.add_argument("output_file", type=str) + parser.add_argument("--prompt", type=str, default="three_frames") + parser.add_argument("--key", type=str) + args = parser.parse_args() + + main(args) diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/caption/caption_llava.py b/MindIE/MultiModal/OpenSora-1.0/tools/caption/caption_llava.py new file mode 100644 index 0000000000000000000000000000000000000000..8f4278c3bd7089b115123b41e53856e595156080 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/caption/caption_llava.py @@ -0,0 +1,352 @@ +import argparse +import csv +import os +import warnings + +import torch +from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX +from llava.conversation import conv_templates +from llava.mm_utils import get_anyres_image_grid_shape, get_model_name_from_path, process_images, tokenizer_image_token +from llava.model.builder import load_pretrained_model +from llava.model.llava_arch import unpad_image +from llava.utils import disable_torch_init +from tqdm import tqdm + +from .utils import extract_frames, prompts, read_video_list + +disable_torch_init() + + +def prepare_inputs_labels_for_multimodal( + self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None +): + # llava_arch.py + vision_tower = self.get_vision_tower() + if vision_tower is None or images is None or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if type(images) is list or images.ndim == 5: + if type(images) is list: + images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] + concat_images = torch.cat([image for image in images], dim=0) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in images] + image_features = torch.split(image_features, split_sizes, dim=0) + mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat") + image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") + if mm_patch_merge_type == "flat": + image_features = [x.flatten(0, 1) for x in image_features] + elif mm_patch_merge_type.startswith("spatial"): + new_image_features = [] + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = self.get_vision_tower().num_patches_per_side + assert height * width == base_image_feature.shape[0] + if image_aspect_ratio == "anyres": + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_sizes[image_idx], + self.config.image_grid_pinpoints, + self.get_vision_tower().config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + else: + raise NotImplementedError + if "unpad" in mm_patch_merge_type: + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, 
image_sizes[image_idx]) + image_feature = torch.cat( + ( + image_feature, + self.model.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + else: + image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() + image_feature = image_feature.flatten(0, 3) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + (image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0 + ) + new_image_features.append(image_feature) + image_features = new_image_features + else: + raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") + else: + image_features = self.encode_images(images) + + # TODO: image start / end is not implemented here to support pretraining. + if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False): + raise NotImplementedError + + # Let's just add dummy tensors if they do not exist, + # it is a headache to deal with None all the time. + # But it is not ideal, and if you have a better idea, + # please open an issue / submit a PR, thanks. + _labels = labels + _position_ids = position_ids + _attention_mask = attention_mask + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + else: + attention_mask = attention_mask.bool() + if position_ids is None: + position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) + if labels is None: + labels = torch.full_like(input_ids, IGNORE_INDEX) + + # remove the padding using attention_mask -- FIXME + input_ids = [ + cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) + ] + labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] + + new_input_embeds = [] + new_labels = [] + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + if num_images == 0: + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) + new_input_embeds.append(cur_input_embeds) + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + + image_token_indices = ( + [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] + ) + cur_input_ids_noim = [] + cur_labels = labels[batch_idx] + cur_labels_noim = [] + for i in range(len(image_token_indices) - 1): + cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) + split_sizes = [x.shape[0] for x in cur_labels_noim] + cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) + cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) + cur_new_input_embeds = [] + cur_new_labels = [] + + for i in range(num_images + 1): + cur_new_input_embeds.append(cur_input_embeds_no_im[i]) + cur_new_labels.append(cur_labels_noim[i]) + if i < num_images: + cur_image_features = image_features[cur_image_idx] + cur_image_idx += 1 + cur_new_input_embeds.append(cur_image_features) + 
cur_new_labels.append( + torch.full( + (cur_image_features.shape[0],), + IGNORE_INDEX, + device=cur_labels.device, + dtype=cur_labels.dtype, + ) + ) + + cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] + + cur_new_input_embeds = torch.cat(cur_new_input_embeds) + cur_new_labels = torch.cat(cur_new_labels) + + new_input_embeds.append(cur_new_input_embeds) + new_labels.append(cur_new_labels) + + # Truncate sequences to max length as image embeddings can make the sequence longer + tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None) + if tokenizer_model_max_length is not None: + new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] + new_labels = [x[:tokenizer_model_max_length] for x in new_labels] + + # Combine them + max_len = max(x.shape[0] for x in new_input_embeds) + batch_size = len(new_input_embeds) + + new_input_embeds_padded = [] + new_labels_padded = torch.full( + (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device + ) + attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) + position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) + + for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): + cur_len = cur_new_embed.shape[0] + if getattr(self.config, "tokenizer_padding_side", "right") == "left": + new_input_embeds_padded.append( + torch.cat( + ( + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + cur_new_embed, + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, -cur_len:] = cur_new_labels + attention_mask[i, -cur_len:] = True + position_ids[i, -cur_len:] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + else: + new_input_embeds_padded.append( + torch.cat( + ( + cur_new_embed, + torch.zeros( + (max_len - cur_len, cur_new_embed.shape[1]), + dtype=cur_new_embed.dtype, + device=cur_new_embed.device, + ), + ), + dim=0, + ) + ) + if cur_len > 0: + new_labels_padded[i, :cur_len] = cur_new_labels + attention_mask[i, :cur_len] = True + position_ids[i, :cur_len] = torch.arange( + 0, cur_len, dtype=position_ids.dtype, device=position_ids.device + ) + + new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) + + if _labels is None: + new_labels = None + else: + new_labels = new_labels_padded + + if _attention_mask is None: + attention_mask = None + else: + attention_mask = attention_mask.to(dtype=_attention_mask.dtype) + + if _position_ids is None: + position_ids = None + + return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels + + +@torch.inference_mode() +def main(args): + # ====================================================== + # 1. read video list + # ====================================================== + videos = read_video_list(args.video_folder, args.output_file) + f = open(args.output_file, "a") + writer = csv.writer(f) + + # ====================================================== + # 2. 
load model and prepare prompts + # ====================================================== + model_path = "liuhaotian/llava-v1.6-34b" + query = prompts[args.prompt] + print(f"Prompt: {query}") + conv = conv_templates["chatml_direct"].copy() + conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + query) + prompt = conv.get_prompt() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # Pytorch non-meta copying warning fills out the console + tokenizer, model, image_processor, context_len = load_pretrained_model( + model_path=model_path, + model_base=None, + model_name=get_model_name_from_path(model_path), + ) + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") + input_ids = input_ids.unsqueeze(0).to(model.device) + + # ====================================================== + # 3. generate captions + # ====================================================== + bs = args.bs + for i in tqdm(range(0, len(videos), bs)): + # prepare a batch of inputs + video_files = videos[i : i + bs] + frames = [] + video_lengths = [] + for video_file in video_files: + frame, length = extract_frames(os.path.join(args.video_folder, video_file)) + if len(frame) < 3: + continue + frames.append(frame) + video_lengths.append(length) + if len(frames) == 0: + continue + + # encode the batch of inputs + samples = [] + for imgs in frames: + imgs_size = [img.size for img in imgs] + imgs = process_images(imgs, image_processor, model.config) + imgs = imgs.to(model.device, dtype=torch.float16) + with torch.inference_mode(): + _, _, _, _, inputs_embeds, _ = prepare_inputs_labels_for_multimodal( + model, input_ids, None, None, None, None, images=imgs, image_sizes=imgs_size + ) + samples.append(inputs_embeds) + + # padding + max_len = max([sample.shape[1] for sample in samples]) + attention_mask = torch.tensor( + [[0] * (max_len - samples[i].shape[1]) + [1] * samples[i].shape[1] for i in range(len(samples))] + ).to(model.device) + inputs_embeds = [ + torch.cat( + [ + torch.zeros( + (1, max_len - samples[i].shape[1], samples[i].shape[-1]), + device=model.device, + dtype=torch.float16, + ), + samples[i], + ], + dim=1, + ) + for i in range(len(samples)) + ] + inputs_embeds = torch.cat(inputs_embeds, dim=0) + + # generate outputs + output_ids = super(type(model), model).generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=True, + temperature=0.2, + max_new_tokens=512, + use_cache=True, + ) + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + outputs = [output.replace("\n", " ").strip() for output in outputs] + + # save results + result = list(zip(video_files, outputs, video_lengths)) + for t in result: + writer.writerow(t) + + f.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("video_folder", type=str) + parser.add_argument("output_file", type=str) + parser.add_argument("--bs", type=int, default=32) + parser.add_argument("--prompt", type=str, default="three_frames") + args = parser.parse_args() + + main(args) diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/caption/utils.py b/MindIE/MultiModal/OpenSora-1.0/tools/caption/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3912f0ccd1811fae507d1da2169190324dc20f30 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/caption/utils.py @@ -0,0 +1,67 @@ +import base64 +import csv +import os + +import cv2 +from PIL import Image + +prompts = { + "naive": "Describe the video", + 
"three_frames": "A video is given by providing three frames in chronological order. Describe this video and its style to generate a description. Pay attention to all objects in the video. Do not describe each frame individually. Do not reply with words like 'first frame'. The description should be useful for AI to re-generate the video. The description should be less than six sentences. Here are some examples of good descriptions: 1. A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2. Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. 3. Drone view of waves crashing against the rugged cliffs along Big Sur's garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff's edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.", +} + + +def get_filelist(file_path): + Filelist = [] + VID_EXTENSIONS = ("mp4", "avi", "mov", "mkv") + for home, dirs, files in os.walk(file_path): + for filename in files: + ext = filename.split(".")[-1] + if ext in VID_EXTENSIONS: + Filelist.append(filename) + return Filelist + + +def get_video_length(cap): + return int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def extract_frames(video_path, points=(0.2, 0.5, 0.8), base_64=False): + cap = cv2.VideoCapture(video_path) + length = get_video_length(cap) + points = [int(length * point) for point in points] + frames = [] + if length < 3: + return frames, length + for point in points: + cap.set(cv2.CAP_PROP_POS_FRAMES, point) + ret, frame = cap.read() + if not base_64: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + else: + _, buffer = cv2.imencode(".jpg", frame) + frame = base64.b64encode(buffer).decode("utf-8") + frames.append(frame) + return frames, length + + +def read_video_list(video_folder, output_file): + processed_videos = [] + if os.path.exists(output_file): + with open(output_file, "r") as f: + reader = csv.reader(f) + samples = list(reader) + processed_videos = [sample[0] for sample in samples] + + # read video list + videos = get_filelist(video_folder) + print(f"Dataset contains {len(videos)} videos.") + videos = [video for video in videos if video not in processed_videos] + print(f"Processing {len(videos)} new videos.") + return videos diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/datasets/README.md b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..0118c8f180a25b43f3de0bd1b22f0ce8afd9ca84 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/README.md @@ -0,0 +1,48 @@ +# Dataset Download and Management + +## Dataset Format + +The training data should be provided in a CSV file with the following format: + +```csv +/absolute/path/to/image1.jpg, caption1, num_of_frames +/absolute/path/to/image2.jpg, caption2, num_of_frames +``` + +## HD-VG-130M + +This dataset comprises 130M text-video pairs. You can download the dataset and prepare it for training according to [the dataset repository's instructions](https://github.com/daooshee/HD-VG-130M). There is a README.md file in the Google Drive link that provides instructions on how to download and cut the videos. For this version, we directly use the dataset provided by the authors. + +## Demo Dataset + +You can use ImageNet and UCF101 for a quick demo. After downloading the datasets, you can use the following command to prepare the csv file for the dataset: + +```bash +# ImageNet +python -m tools.datasets.convert_dataset imagenet IMAGENET_FOLDER --split train +# UCF101 +python -m tools.datasets.convert_dataset ucf101 UCF101_FOLDER --split videos +``` + +## Manage datasets + +We provide `csvutils.py` to manage the CSV files. You can use the following commands to process the CSV files: + +```bash +# generate DATA_fmin_128_fmax_256.csv with frames between 128 and 256 +python -m tools.datasets.csvutil DATA.csv --fmin 128 --fmax 256 +# generate DATA_root.csv with absolute path +python -m tools.datasets.csvutil DATA.csv --root /absolute/path/to/dataset +# remove videos with no captions +python -m tools.datasets.csvutil DATA.csv --remove-empty-caption +# compute the number of frames for each video +python -m tools.datasets.csvutil DATA.csv --relength +# remove caption prefix +python -m tools.datasets.csvutil DATA.csv --remove-caption-prefix +``` + +To merge multiple CSV files, you can use the following command: + +```bash +cat *csv > combined.csv +``` diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/datasets/__init__.py b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/datasets/convert_dataset.py b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/convert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff904fc20cbe58e696c9d606fbfb871561ed45c --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/convert_dataset.py @@ -0,0 +1,66 @@ +import argparse +import csv +import os + +from torchvision.datasets import ImageNet + + +def get_filelist(file_path): + Filelist = [] + for home, dirs, files in os.walk(file_path): + for filename in files: + Filelist.append(os.path.join(home, filename)) + return Filelist + + +def split_by_capital(name): + # BoxingPunchingBag -> Boxing Punching Bag + new_name = "" + for i in range(len(name)): + if name[i].isupper() and i != 0: + new_name += " " + new_name += name[i] + return new_name + + +def process_imagenet(root, split): + root = os.path.expanduser(root) + data = ImageNet(root, split=split) + samples = [(path, data.classes[label][0]) for path, label in data.samples] + output = f"imagenet_{split}.csv" + + with open(output, "w") as f: + writer = csv.writer(f) + writer.writerows(samples) + + print(f"Saved {len(samples)} samples to {output}.") + + +def process_ucf101(root, split): + root = os.path.expanduser(root) + 
video_lists = get_filelist(os.path.join(root, split)) + classes = [x.split("/")[-2] for x in video_lists] + classes = [split_by_capital(x) for x in classes] + samples = list(zip(video_lists, classes)) + output = f"ucf101_{split}.csv" + + with open(output, "w") as f: + writer = csv.writer(f) + writer.writerows(samples) + + print(f"Saved {len(samples)} samples to {output}.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("dataset", type=str, choices=["imagenet", "ucf101"]) + parser.add_argument("root", type=str) + parser.add_argument("--split", type=str, default="train") + args = parser.parse_args() + + if args.dataset == "imagenet": + process_imagenet(args.root, args.split) + elif args.dataset == "ucf101": + process_ucf101(args.root, args.split) + else: + raise ValueError("Invalid dataset") diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/datasets/csvutil.py b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/csvutil.py new file mode 100644 index 0000000000000000000000000000000000000000..4bbd22db24962ce2c66656445a043c35fbeed38b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/datasets/csvutil.py @@ -0,0 +1,96 @@ +import argparse +import csv +import os + +from tqdm import tqdm + +# path, name, #frames +PREFIX = [ + "The video shows", + "The video captures", + "The video features", + "The video depicts", + "The video presents", + "The video features", + "The video is ", + "In the video,", +] + + +def get_video_length(path): + import cv2 + + cap = cv2.VideoCapture(path) + return int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + +def main(args): + input_path = args.input + output_path = args.output + if output_path is None: + name = os.path.basename(input_path) + name, ext = os.path.splitext(name) + if args.fmin is not None: + name += f"_fmin_{args.fmin}" + if args.fmax is not None: + name += f"_fmax_{args.fmax}" + if args.remove_empty_caption: + name += "_rec" + if args.remove_caption_prefix: + name += "_rcp" + if args.root is not None: + name += f"_root" + if args.relength: + name += "_relength" + output_path = os.path.join(os.path.dirname(input_path), name + ext) + + with open(input_path, "r") as f: + reader = csv.reader(f) + data = list(reader) + print("Number of videos before filtering:", len(data)) + + data_new = [] + for i, row in tqdm(enumerate(data)): + path = row[0] + caption = row[1] + n_frames = int(row[2]) + if args.fmin is not None and n_frames < args.fmin: + continue + if args.fmax is not None and n_frames > args.fmax: + continue + if args.remove_empty_caption and len(caption) == 0: + continue + if args.remove_caption_prefix: + for prefix in PREFIX: + if caption.startswith(prefix): + caption = caption[len(prefix) :].strip() + if caption[0].islower(): + caption = caption[0].upper() + caption[1:] + row[1] = caption + break + if args.root is not None: + row[0] = os.path.join(args.root, path) + if args.relength: + n_frames = get_video_length(row[0]) + row[2] = n_frames + data_new.append(row) + + print("Number of videos after filtering:", len(data_new)) + with open(output_path, "w") as f: + writer = csv.writer(f) + writer.writerows(data_new) + print("Output saved to", output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("input", type=str) + parser.add_argument("--output", type=str, default=None) + parser.add_argument("--fmin", type=int, default=None) + parser.add_argument("--fmax", type=int, default=None) + parser.add_argument("--root", type=str, default=None) + 
parser.add_argument("--remove-empty-caption", action="store_true") + parser.add_argument("--remove-caption-prefix", action="store_true") + parser.add_argument("--relength", action="store_true") + args = parser.parse_args() + main(args) diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/intepolate/README.md b/MindIE/MultiModal/OpenSora-1.0/tools/intepolate/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cd406e140267824d8f30405e09e7dfcb591eb207 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/intepolate/README.md @@ -0,0 +1 @@ +# To be added diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/README.md b/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8052733739109237ea620a1983ba31b7473fee92 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/README.md @@ -0,0 +1,9 @@ +# Scene Detection and Video Split + +Raw videos from the Internet may be too long for training. +Thus, we detect scenes in raw videos and split them into short clips based on the scenes. +First prepare the video processing packages. +```bash +pip install scenedetect moviepy opencv-python +``` +Then run `scene_detect.py`. We provide efficient processing using `multiprocessing`. Don't forget to specify your own dataset path. diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/scene_detect.py b/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/scene_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..c46e59d5abce24575d62a3e3bdffb2aed49efa0b --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/scene_detect.py @@ -0,0 +1,138 @@ +import os +from multiprocessing import Pool + +from mmengine.logging import MMLogger +from scenedetect import ContentDetector, detect +from tqdm import tqdm + +from opensora.utils.misc import get_timestamp + +from .utils import check_mp4_integrity, clone_folder_structure, iterate_files, split_video + +# config +target_fps = 30 # int +shorter_size = 512 # int +min_seconds = 1 # float +max_seconds = 5 # float +assert max_seconds > min_seconds +cfg = dict( + target_fps=target_fps, + min_seconds=min_seconds, + max_seconds=max_seconds, + shorter_size=shorter_size, +) + + +def process_folder(root_src, root_dst): + # create logger + folder_path_log = os.path.dirname(root_dst) + log_name = os.path.basename(root_dst) + timestamp = get_timestamp() + log_path = os.path.join(folder_path_log, f"{log_name}_{timestamp}.log") + logger = MMLogger.get_instance(log_name, log_file=log_path) + + # clone folder structure + clone_folder_structure(root_src, root_dst) + + # all source videos + mp4_list = [x for x in iterate_files(root_src) if x.endswith(".mp4")] + mp4_list = sorted(mp4_list) + + for idx, sample_path in tqdm(enumerate(mp4_list)): + folder_src = os.path.dirname(sample_path) + folder_dst = os.path.join(root_dst, os.path.relpath(folder_src, root_src)) + + # check src video integrity + if not check_mp4_integrity(sample_path, logger=logger): + continue + + # detect scenes + scene_list = detect(sample_path, ContentDetector(), start_in_scene=True) + + # split scenes + save_path_list = split_video(sample_path, scene_list, save_dir=folder_dst, **cfg, logger=logger) + + # check integrity of generated clips + for x in save_path_list: + check_mp4_integrity(x, logger=logger) + + +def scene_detect(): + """detect & cut scenes using a single process + Expected dataset structure: + data/ + your_dataset/ + raw_videos/ + xxx.mp4 + yyy.mp4 + 
+ This function results in: + data/ + your_dataset/ + raw_videos/ + xxx.mp4 + yyy.mp4 + zzz.mp4 + clips/ + xxx_scene-0.mp4 + yyy_scene-0.mp4 + yyy_scene-1.mp4 + """ + # TODO: specify your dataset root + root_src = f"./data/your_dataset/raw_videos" + root_dst = f"./data/your_dataset/clips" + + process_folder(root_src, root_dst) + + +def scene_detect_mp(): + """detect & cut scenes using multiple processes + Expected dataset structure: + data/ + your_dataset/ + raw_videos/ + split_0/ + xxx.mp4 + yyy.mp4 + split_1/ + xxx.mp4 + yyy.mp4 + + This function results in: + data/ + your_dataset/ + raw_videos/ + split_0/ + xxx.mp4 + yyy.mp4 + split_1/ + xxx.mp4 + yyy.mp4 + clips/ + split_0/ + xxx_scene-0.mp4 + yyy_scene-0.mp4 + split_1/ + xxx_scene-0.mp4 + yyy_scene-0.mp4 + yyy_scene-1.mp4 + """ + # TODO: specify your dataset root + root_src = f"./data/your_dataset/raw_videos" + root_dst = f"./data/your_dataset/clips" + + # TODO: specify your splits + splits = ["split_0", "split_1"] + + # process folders + root_src_list = [os.path.join(root_src, x) for x in splits] + root_dst_list = [os.path.join(root_dst, x) for x in splits] + + with Pool(processes=len(splits)) as pool: + pool.starmap(process_folder, list(zip(root_src_list, root_dst_list))) + + +if __name__ == "__main__": + # TODO: choose single process or multiprocessing + scene_detect() + # scene_detect_mp() diff --git a/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/utils.py b/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19eae31463bc1464b887877856bcf5c49ccde923 --- /dev/null +++ b/MindIE/MultiModal/OpenSora-1.0/tools/scenedetect/utils.py @@ -0,0 +1,145 @@ +import os +import subprocess + +import cv2 +from imageio_ffmpeg import get_ffmpeg_exe +from mmengine.logging import print_log +from moviepy.editor import VideoFileClip +from scenedetect import FrameTimecode + + +def iterate_files(folder_path): + for root, dirs, files in os.walk(folder_path): + # root contains the current directory path + # dirs contains the list of subdirectories in the current directory + # files contains the list of files in the current directory + + # Process files in the current directory + for file in files: + file_path = os.path.join(root, file) + # print("File:", file_path) + yield file_path + + # Process subdirectories and recursively call the function + for subdir in dirs: + subdir_path = os.path.join(root, subdir) + # print("Subdirectory:", subdir_path) + iterate_files(subdir_path) + + +def iterate_folders(folder_path): + for root, dirs, files in os.walk(folder_path): + for subdir in dirs: + subdir_path = os.path.join(root, subdir) + yield subdir_path + # print("Subdirectory:", subdir_path) + iterate_folders(subdir_path) + + +def clone_folder_structure(root_src, root_dst, verbose=False): + src_path_list = iterate_folders(root_src) + src_relpath_list = [os.path.relpath(x, root_src) for x in src_path_list] + + os.makedirs(root_dst, exist_ok=True) + dst_path_list = [os.path.join(root_dst, x) for x in src_relpath_list] + for folder_path in dst_path_list: + os.makedirs(folder_path, exist_ok=True) + if verbose: + print(f"Create folder: '{folder_path}'") + + +def count_files(root, suffix=".mp4"): + files_list = iterate_files(root) + cnt = len([x for x in files_list if x.endswith(suffix)]) + return cnt + + +def check_mp4_integrity(file_path, verbose=True, logger=None): + try: + VideoFileClip(file_path) + if verbose: + print_log(f"The MP4 file '{file_path}' is intact.", logger=logger) + return 
True
+    except Exception as e:
+        if verbose:
+            print_log(f"Error: {e}", logger=logger)
+            print_log(f"The MP4 file '{file_path}' is not intact.", logger=logger)
+        return False
+
+
+def count_frames(video_path):
+    cap = cv2.VideoCapture(video_path)
+
+    if not cap.isOpened():
+        print(f"Error: Could not open video file '{video_path}'")
+        return
+
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    print(f"Total frames in the video '{video_path}': {total_frames}")
+
+    cap.release()
+
+
+def split_video(
+    sample_path,
+    scene_list,
+    save_dir,
+    target_fps=30,
+    min_seconds=1,
+    max_seconds=10,
+    shorter_size=512,
+    verbose=False,
+    logger=None,
+):
+    FFMPEG_PATH = get_ffmpeg_exe()
+
+    save_path_list = []
+    for idx, scene in enumerate(scene_list):
+        s, t = scene  # FrameTimecode
+        fps = s.framerate
+        max_duration = FrameTimecode(timecode="00:00:00", fps=fps)
+        max_duration.frame_num = round(fps * max_seconds)
+        duration = min(max_duration, t - s)
+        if duration.get_frames() < round(min_seconds * fps):
+            continue
+
+        # save path
+        fname = os.path.basename(sample_path)
+        fname_wo_ext = os.path.splitext(fname)[0]
+        # TODO: fname pattern
+        save_path = os.path.join(save_dir, f"{fname_wo_ext}_scene-{idx}.mp4")
+
+        # ffmpeg cmd
+        cmd = [FFMPEG_PATH]
+
+        # Only show ffmpeg output for the first call, which will display any
+        # errors if it fails, and then break the loop. We only show error messages
+        # for the remaining calls.
+        # cmd += ['-v', 'error']
+
+        # input path
+        cmd += ["-i", sample_path]
+
+        # clip to cut
+        cmd += ["-nostdin", "-y", "-ss", str(s.get_seconds()), "-t", str(duration.get_seconds())]
+
+        # target fps
+        # cmd += ['-vf', 'select=mod(n\,2)']
+        cmd += ["-r", f"{target_fps}"]
+
+        # aspect ratio
+        cmd += ["-vf", f"scale='if(gt(iw,ih),-2,{shorter_size})':'if(gt(iw,ih),{shorter_size},-2)'"]
+        # cmd += ['-vf', f"scale='if(gt(iw,ih),{shorter_size},trunc(ow/a/2)*2)':-2"]
+
+        cmd += ["-map", "0", save_path]
+
+        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+        stdout, stderr = proc.communicate()
+        if verbose:
+            stdout = stdout.decode("utf-8")
+            print_log(stdout, logger=logger)
+
+        # record the generated clip (not the source video) so callers can verify the new files
+        save_path_list.append(save_path)
+        print_log(f"Video clip saved to '{save_path}'", logger=logger)
+
+    return save_path_list
diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/README.md b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2450eae04cbb0c4eab7fcb48bc209a247ccab285
--- /dev/null
+++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/README.md
@@ -0,0 +1,154 @@
+# SDWebUI-TorchAIE Inference Guide
+mindie_extension implements a plugin for the SD WebUI interface. It replaces the original UNetModel with an optimized diffusers.Unet2DConditionModel for inference and supports SD text-to-image and image-to-image generation. Under the hood it calls MindIE's build compilation and optimization capability and improves inference performance through graph-rewriting passes, batch parallelism, and other techniques.
+
+
+# Overview
+
+SD WebUI is a Gradio-based web interface that lets you configure inputs and parameters for SD text-to-image, image-to-image, and related features. For more information about SD WebUI, see [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
+
+- Supported devices:
+Atlas 800I A2 inference server
+Atlas 300I Duo inference card
+
+# Inference Environment Setup
+
+This plugin requires a torch 2.1.0 and Python 3.10 environment.
+
+# Quick Start
+
+## Environment Setup
+
+1. Install the dependencies at the versions pinned in requirements.txt; otherwise model export may fail.
+
+    ```
+    pip install -r requirements.txt
+    ```
+
+2. Install the mindie and mindietorch packages (an optional way to verify the installation is shown right after this list).
+
+    ```bash
+    # install mindie
+    chmod +x ./Ascend-mindie_xxx.run
+    ./Ascend-mindie_xxx.run --install
+    source /usr/local/Ascend/mindie-rt/set_env.sh
+    # install mindietorch
+    tar -zxvf Ascend-mindie-torch_xxx.tar.gz
+    pip install mindietorch-1.0.rc1+torch2.1.0xxx.whl
+    ```
+
+3. Modify the code: patch the attention implementation so that the correct model is traced.
+
+    ```bash
+    python sd_webui_patch.py
+    ```
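+
+As an optional sanity check for steps 1 and 2 above, you can confirm that both frameworks import cleanly before moving on. This is only a sketch, and it assumes the mindietorch wheel installs an importable module named `mindietorch`:
+
+```bash
+python -c "import torch, mindietorch; print(torch.__version__)"
+```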
代码修改,修改attention,用于trace正确的模型 + + ```bash + python sd_webui_patch.py + ``` + +## sd_webui部署 + +1. 拉取webui工程代码[stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) + + ```bash + git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git + ``` + +2. 拉取mindie_extension工程,放在stable-diffusion-webui/extensions路径下 + +3. 获取权重 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # v1.5,将该权重放在stable-diffusion-webui/extensions/mindie_extension/models路径下 + cd stable-diffusion-webui/extensions/mindie_extension/models + git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 + + # v2.1,将该权重放在stable-diffusion-webui/extensions/mindie_extension/models路径下 + git clone https://huggingface.co/stabilityai/stable-diffusion-2-1-base + + # sdxl,将该权重放在stable-diffusion-webui/extensions/mindie_extension/models路径下 + git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + ``` + +4. 将特定权重放在stable-diffusion-webui/models/Stable-diffusion路径下。注意:本插件支持的webui权重如下: + + ```bash + # v1.5 二选一即可,推荐safetensors + v1-5-pruned-emaonly.safetensors + v1-5-pruned-emaonly.ckpt + # v2.1 二选一即可,推荐safetensors + v2-1_512-ema-pruned.safetensors + v2-1_512-ema-pruned.ckpt + # SDXL + sd_xl_base_1.0.safetensors + ``` + +```bash +# 举例: +cp stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors ../../../models/Stable-diffusion +``` + +5. 拉取相关代码 + + ```bash + cd stable-diffusion-webui + mkdir repositories && cd repositories + git clone https://github.com/Stability-AI/stablediffusion stable-diffusion-stability-ai + git clone https://github.com/Stability-AI/generative-models.git + git clone https://github.com/crowsonkb/k-diffusion.git + git clone https://github.com/sczhou/CodeFormer.git + git clone https://github.com/salesforce/BLIP.git + git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets + ``` + +6. 下载clip-vit-large-patch14,放在自定义路径 + + ```bash + git lfs install + git clone https://huggingface.co/openai/clip-vit-large-patch14 + ``` + + 然后修改webui的源码: + + ①文件:stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/modules/encoders/modules.py + + 将该文件中涉及到的version="openai/clip-vit-large-patch14"改为version="下载的clip-vit-large-patch14路径" + + ②文件:stable-diffusion-webui/repositories/generative-models/sgm/modules/encoders/modules.py + + 将该文件中涉及到的version="openai/clip-vit-large-patch14"改为version="下载的clip-vit-large-patch14路径" + +7. 将mindie_extension工程的diff_1.patch放到stable-diffusion-webui路径下 + + ```bash + mv diff_1.patch ../.. + patch -p0 < diff_1.patch + ``` + +## 运行功能 + +1. 执行命令启动webui +```bash +python launch.py --skip-torch-cuda-test --enable-insecure-extension-access --listen --log-startup --disable-safe-unpickle --no-half --skip-prepare-environment +``` +2. 使用该插件后,原始的webui界面中的某些配置受到限制,如下: + + 可配置参数: + + ``` + Sampling method + Sampling steps + CFG Scale + Seed + ``` + + 受限制参数: + + ``` + 使用SD1.5和SD2.1时,Width和Height都要设置为512 + 使用SDXL时,Width和Height都要设置为1024 + Batch count要固定为1 + Batch size要固定为1 + ``` + +3. 
界面启动后,请先选择硬件配置,Duo或A2。然后选择MindIE_torch按钮,第一次启动服务时,点击MindIE_torch按钮后,会对于原始模型做一些处理,请耐心等待,直到服务端显示"You can generate image now!"字样后,再根据上述参数配置,点击generate生成结果。 + diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/attention_processor.patch b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..26f296526adcfaf629f3c47a311b88bb4aa002a2 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/attention_processor.patch @@ -0,0 +1,18 @@ +--- attention_processor.py 2024-02-22 19:06:56.596000000 +0800 ++++ attention_processor.py 2024-02-22 19:07:17.232000000 +0800 +@@ -205,10 +205,11 @@ + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/config.py b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/config.py new file mode 100644 index 0000000000000000000000000000000000000000..013a1eabd5db77d3682b7da5e62662b292155cf1 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/config.py @@ -0,0 +1,9 @@ +class NpuConfig(object): + use_cpu = True + Duo = False + A2 = False + compiled_unet_model_1_5 = None + compiled_unet_model_2_1 = None + compiled_unet_model_xl = None + + diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/diff_1.patch b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/diff_1.patch new file mode 100644 index 0000000000000000000000000000000000000000..4346f26d520d7dae8e1bb770fa0dbb61f0dfcc9a --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/diff_1.patch @@ -0,0 +1,46 @@ +--- repositories/generative-models/sgm/modules/diffusionmodules/util.py 2024-03-09 12:04:41.836000000 +0000 ++++ repositories/generative-models/sgm/modules/diffusionmodules/util_temp.py 2024-03-09 12:05:41.828000000 +0000 +@@ -160,11 +160,12 @@ def checkpoint(func, inputs, params, fla + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ """ +- if flag: +- args = tuple(inputs) + tuple(params) +- return CheckpointFunction.apply(func, len(inputs), *args) +- else: +- return func(*inputs) ++ # if flag: ++ # args = tuple(inputs) + tuple(params) ++ # return CheckpointFunction.apply(func, len(inputs), *args) ++ # else: ++ # return func(*inputs) ++ return func(*inputs) + + + class CheckpointFunction(torch.autograd.Function): + +--- repositories/generative-models/sgm/modules/diffusionmodules/wrappers.py 2024-03-09 12:04:41.836000000 +0000 ++++ repositories/generative-models/sgm/modules/diffusionmodules/wrappers_temp.py 2024-03-09 12:05:41.828000000 +0000 +@@ -8,13 +8,14 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusi + class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, compile_model: bool = False): + super().__init__() +- compile = ( +- torch.compile +- if (version.parse(torch.__version__) >= version.parse("2.0.0")) +- and compile_model +- else lambda x: x +- ) +- self.diffusion_model = compile(diffusion_model) ++ # compile = ( ++ # torch.compile ++ # if (version.parse(torch.__version__) >= version.parse("2.0.0")) ++ # and compile_model ++ # else lambda x: x ++ # ) ++ # self.diffusion_model = compile(diffusion_model) ++ self.diffusion_model = diffusion_model + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/env_init_sd.py b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/env_init_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd86f0dd7a53c95eaf2df6440a22da6830d780c --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/env_init_sd.py @@ -0,0 +1,121 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +from diffusers import StableDiffusionPipeline +from config import NpuConfig + +class UnetExport(torch.nn.Module): + def __init__(self, model): + super(UnetExport, self).__init__() + self.unet_model = model + + def forward(self, sample, timestep, encoder_hidden_states): + return self.unet_model(sample, timestep, encoder_hidden_states)[0] + +def export_unet(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int, soc_version: str) -> None: + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + max_position_embeddings = clip_model.config.max_position_embeddings + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + ) + + traced_model = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + if not os.path.exists(traced_model): + print("Exporting the image information creater...") + unet = UnetExport(unet_model) + unet.eval() + model = torch.jit.trace(unet, dummy_input) + torch.jit.save(model, traced_model) + else: + model = torch.jit.load(traced_model).eval() + + compiled_model = os.path.join(unet_path, f"unet_bs{batch_size}_compiled.pt") + if not os.path.exists(compiled_model): + print("start compile unet model...") + unet_input_info = [ + mindietorch.Input( + (batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT + ), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT + ) + ] + compiled_unet_model = mindietorch.compile( + model, + inputs=unet_input_info, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + torch.jit.save(compiled_unet_model, compiled_model) + +def init_model(device): + mindietorch.set_device(device) + cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + save_dir = os.path.join(cur_dir_path, "models") + save_dir_1_5 = os.path.join(save_dir, "models-1-5") + save_dir_2_1 = os.path.join(save_dir, "models-2-1") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + if not os.path.exists(save_dir_1_5): + os.makedirs(save_dir_1_5) + if not os.path.exists(save_dir_2_1): + os.makedirs(save_dir_2_1) + + unet_path_1_5 = os.path.join(save_dir_1_5, "unet") + unet_path_2_1 = os.path.join(save_dir_2_1, "unet") + + weights_1_5 = os.path.join(save_dir, "stable-diffusion-v1-5") + weights_2_1 = os.path.join(save_dir, "stable-diffusion-2-1-base") + + batch_size = 1 + pipe_1_5 = StableDiffusionPipeline.from_pretrained(weights_1_5).to('cpu') + pipe_2_1 = StableDiffusionPipeline.from_pretrained(weights_2_1).to('cpu') + + if NpuConfig.Duo: + soc_version = "Ascend310P3" + elif NpuConfig.A2: + soc_version = "Ascend910B4" + + export_unet(pipe_1_5, save_dir_1_5, batch_size * 2, soc_version) + export_unet(pipe_2_1, save_dir_2_1, batch_size * 2, soc_version) + + mindietorch.finalize() + + return 
weights_1_5, weights_2_1, save_dir_1_5, save_dir_2_1 + \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/env_init_sdxl.py b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/env_init_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..eee6c59ccbe0033caf11a42599cd6ec40851ce1b --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/env_init_sdxl.py @@ -0,0 +1,122 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +from omegaconf import ListConfig, OmegaConf +from ldm.util import instantiate_from_config +from modules import shared, paths +from config import NpuConfig +from diffusers import StableDiffusionXLPipeline + +class UnetExport(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.unet_model = model + + def forward(self, x, timesteps, context, y): + return self.unet_model(x, timesteps, context, y)[0] + +def export_unet(pipe: StableDiffusionXLPipeline, save_dir: str, batch_size: int, soc_version: str) -> None: + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_model = pipe.unet + clip_model = pipe.text_encoder + + unet_model.config.addition_embed_type = None + in_channels = unet_model.config.in_channels + sample_size = unet_model.config.sample_size + encoder_hidden_size = unet_model.config.cross_attention_dim + max_position_embeddings = clip_model.config.max_position_embeddings + adm_in_channels = unet_model.config.projection_class_embeddings_input_dim + + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size], dtype=torch.float32), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, adm_in_channels], dtype=torch.float32), + ) + + traced_model = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + if not os.path.exists(traced_model): + print("Exporting the SDXL image information creater...") + unet = UnetExport(unet_model) + unet.eval() + model = torch.jit.trace(unet, dummy_input) + torch.jit.save(model, traced_model) + else: + model = torch.jit.load(traced_model).eval() + + compiled_model = os.path.join(unet_path, f"unet_bs{batch_size}_compiled.pt") + if not os.path.exists(compiled_model): + print("start compile unet model...") + unet_input_info = [ + mindietorch.Input( + (batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT + ), + mindietorch.Input((batch_size,), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input( + (batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT + ), + mindietorch.Input((batch_size, adm_in_channels), dtype=mindietorch.dtype.FLOAT) + ] + compiled_unet_model = mindietorch.compile( + model, + inputs=unet_input_info, + 
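+        # The options below request full-graph compilation of the traced UNet in FP16 for the
+        # selected SoC; they mirror the settings used for the SD 1.5/2.1 UNet in env_init_sd.py.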
allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + torch.jit.save(compiled_unet_model, compiled_model) + +def init_model_xl(device): + if NpuConfig.Duo: + soc_version = "Ascend310P3" + elif NpuConfig.A2: + soc_version = "Ascend910B4" + + mindietorch.set_device(device) + cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + save_dir = os.path.join(cur_dir_path, "models") + save_dir_sdxl = os.path.join(save_dir, "models-sdxl") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + if not os.path.exists(save_dir_sdxl): + os.makedirs(save_dir_sdxl) + weights_xl = os.path.join(save_dir, "stable-diffusion-xl-base-1.0") + + unet_sdxl = os.path.join(save_dir_sdxl, "unet") + batch_size = 2 + compiled_unet_model = os.path.join(unet_sdxl, f"unet_bs{batch_size}_compiled.pt") + if not os.path.exists(compiled_unet_model): + pipe_xl = StableDiffusionXLPipeline.from_pretrained(weights_xl).to('cpu') + export_unet(pipe_xl, save_dir_sdxl, batch_size, soc_version) + + mindietorch.finalize() + + return save_dir_sdxl + \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/replace_torch_aie.py b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/replace_torch_aie.py new file mode 100644 index 0000000000000000000000000000000000000000..d11651fb414fb0eaf52c5547eeeec868f1d84152 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/replace_torch_aie.py @@ -0,0 +1,107 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
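+# NOTE: replace_torch_aie() switches the webui to NPU inference: it builds/loads the
+# mindietorch-compiled UNet via init_model()/init_model_xl(), then overrides
+# UNetModel.forward (ldm, SD 1.5/2.1) and UNetModel_XL.forward (sgm, SDXL) so that each
+# denoising step moves its inputs to npu:<device>, runs the compiled UNet, and copies the
+# predicted noise back to the CPU for the rest of the pipeline.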
+ +import os +import sys +import time +import math +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModel_XL +from modules import shared +from diffusers import StableDiffusionPipeline +from config import NpuConfig +from env_init_sd import init_model +from env_init_sdxl import init_model_xl + +def replace_torch_aie(): + + device_0, device_1 = 0, None + mindietorch.set_device(device_0) + + model_base_1_5, model_base_2_1, save_dir_1_5, save_dir_2_1 = init_model(device_0) + save_dir_sdxl = init_model_xl(device_0) + print("You can generate image now!") + + def mindietorch_unet(self, x, timesteps=None, context=None, y=None, **kwargs): + if x.shape[-1] != 64: + return x + checkpoint = shared.opts.data['sd_model_checkpoint'] + if "v1-5-pruned-emaonly" in checkpoint: + unet_model = NpuConfig.compiled_unet_model_1_5 + model_base = model_base_1_5 + unet_path = os.path.join(save_dir_1_5, "unet") + elif "v2-1_512-ema-pruned" in checkpoint: + unet_model = NpuConfig.compiled_unet_model_2_1 + model_base = model_base_2_1 + unet_path = os.path.join(save_dir_2_1, "unet") + + if not unet_model: + batch_size = 2 + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compiled.pt") + + pipe = StableDiffusionPipeline.from_pretrained(model_base).to('cpu') + sample_size = pipe.unet.config.sample_size + in_channels = pipe.unet.config.in_channels + encoder_hidden_size = pipe.text_encoder.config.hidden_size + max_position_embeddings = pipe.text_encoder.config.max_position_embeddings + + if "v1-5-pruned-emaonly" in checkpoint: + NpuConfig.compiled_unet_model_1_5 = torch.jit.load(unet_compiled_path).eval() + elif "v2-1_512-ema-pruned" in checkpoint: + NpuConfig.compiled_unet_model_2_1 = torch.jit.load(unet_compiled_path).eval() + + if "v1-5-pruned-emaonly" in checkpoint: + noise_pred = NpuConfig.compiled_unet_model_1_5( + x.to(f"npu:{device_0}"), + timesteps[0][None].type(torch.int64).to(f"npu:{device_0}"), + context.to(f"npu:{device_0}") + ).to("cpu") + elif "v2-1_512-ema-pruned" in checkpoint: + noise_pred = NpuConfig.compiled_unet_model_2_1( + x.to(f"npu:{device_0}"), + timesteps[0][None].type(torch.int64).to(f"npu:{device_0}"), + context.to(f"npu:{device_0}") + ).to("cpu") + + return noise_pred + UNetModel.forward = mindietorch_unet + + def mindietorch_unet_xl(self, x, timesteps=None, context=None, y=None, **kwargs): + checkpoint = shared.opts.data['sd_model_checkpoint'] + if x.shape[-1] != 128: + print("The width and height should be 1024!") + return x + assert "sd_xl_base_1.0" in checkpoint, "Please select correct weight: sd_xl_base_1.0.safetensors" + + unet_model = NpuConfig.compiled_unet_model_xl + unet_path = os.path.join(save_dir_sdxl, "unet") + batch_size = 2 + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compiled.pt") + + if not unet_model: + NpuConfig.compiled_unet_model_xl = torch.jit.load(unet_compiled_path).eval() + noise_pred = NpuConfig.compiled_unet_model_xl( + x.to(f"npu:{device_0}"), + timesteps.to(f"npu:{device_0}"), + context.to(f"npu:{device_0}"), + y.to(f"npu:{device_0}") + ).to("cpu") + return noise_pred + UNetModel_XL.forward = mindietorch_unet_xl + + \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/requirements.txt b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/requirements.txt new file mode 100644 index 
0000000000000000000000000000000000000000..17d719391ef546a453f1cdf88efd1cb70ad612a6 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/requirements.txt @@ -0,0 +1,38 @@ +torch==2.1.0 +diffusers==0.26.3 +transformers==4.26.1 +GitPython==3.1.32 +Pillow==9.5.0 +accelerate==0.21.0 +basicsr==1.4.2 +blendmodes==2022 +clean-fid==0.1.35 +einops==0.4.1 +fastapi==0.94.0 +gfpgan==1.3.8 +gradio==3.41.2 +httpcore==0.15.0 +inflection==0.5.1 +jsonmerge==1.8.0 +kornia==0.6.7 +lark==1.1.2 +numpy==1.23.5 +omegaconf==2.2.3 +open-clip-torch==2.20.0 +piexif==1.1.3 +psutil==5.9.5 +pytorch_lightning==1.9.4 +realesrgan==0.3.0 +resize-right==0.0.2 +safetensors==0.3.1 +scikit-image==0.21.0 +timm==0.9.2 +tomesd==0.1.3 +torchdiffeq==0.2.3 +torchsde==0.2.6 +httpx==0.24.1 +tb-nightly +clip +dctorch +facexlib==0.3.0 +pydantic==1.10.14 diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/scripts/mindie_plugin.py b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/scripts/mindie_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..500844e7c133ffbafc207677710b99af8b4dba8b --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/scripts/mindie_plugin.py @@ -0,0 +1,54 @@ +import logging +import gradio as gr +from basicsr.utils import get_root_logger +from modules import scripts +from replace_torch_aie import replace_torch_aie +from config import NpuConfig + +def listen_change(choice): + if choice == 'MindIE_torch': + print("switch to MindIE_torch") + replace_torch_aie() + return + +class TorchAscendIEPlugin(scripts.Script): + + def __init__(self) -> None: + super().__init__() + self.logger = get_root_logger() + self.logger.info("import MindIEPlugin") + self.logger.setLevel(logging.INFO) + + def title(self): + return "webui-npu-extension" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_txt2img): + with gr.Group(): + with gr.Accordion("npu-extensions", open = True): + device_radio = gr.Radio(choices = ['None', 'Duo', 'A2'], value = "None", label="Ascend device: Duo is 310P3; A2 is 910B4") + device_radio.change(self.listen_device_change, inputs=device_radio) + npu_radio = gr.Radio(choices = ['None', 'MindIE_torch'], value = "None", label="Inference Engine choices") + npu_radio.change(listen_change, inputs = npu_radio) + + def listen_device_change(self, choice): + if choice == 'None': + print("do not use npu, use cpu default.") + NpuConfig.use_cpu = True + NpuConfig.Duo = False + NpuConfig.A2 = False + return + elif choice == 'Duo': + print("use Duo...") + NpuConfig.use_cpu = False + NpuConfig.Duo = True + NpuConfig.A2 = False + return + elif choice == 'A2': + print("use A2...") + NpuConfig.use_cpu = False + NpuConfig.Duo = False + NpuConfig.A2 = True + return \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/mindietorch_extension/sd_webui_patch.py b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/sd_webui_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..ad394fcaa1e3fdc86c2d037bdc1f7a6c7a80b3e4 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/mindietorch_extension/sd_webui_patch.py @@ -0,0 +1,30 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import transformers +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + assert diffusers_version == "0.26.3", "expectation diffusers==0.26.3" + os.system( + f"patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch" + ) + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/README.md b/MindIE/MultiModal/SD-WebUI/onnx_extension/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1aed2cf81ecf33915d828b958a1f595d7cb5fcf8 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/README.md @@ -0,0 +1,100 @@ +# SDWebUI-TorchAIE推理指导 +onnx_extension实现了一个SDWebUI界面的插件,用优化后的diffusers.Unet2DConditionModel替换原有的UNetModel进行推理,支持SD文生图和图生图功能。 + +# 概述 + + SDWebUI是一个基于Gradio库的WebUI界面,支持设置输入和参数用于SD模型的文生图、图生图等功能。有关SDWebUI的更多信息,请查看[Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)。 + +- 设备支持: +Atlas 800I A2推理设备 +Atlas 300I Duo推理卡 + +# 推理环境准备 + +该插件依赖torch2.1.0, python3.10环境 + +# 快速上手 +## sd_webui部署 +1. 拉取代码[stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) + + ```bash + git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git + ``` + +2. 拉取onnx_extension工程,放在stable-diffusion-webui/extensions路径下 + +3. 获取权重 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # v2.1,将该权重放在stable-diffusion-webui/extensions/onnx_extension/models路径下 + cd stable-diffusion-webui/extensions/onnx_extension/models + git clone https://huggingface.co/stabilityai/stable-diffusion-2-1-base + + # 将stable-diffusion-2-1-base下的v2-1_512-ema-pruned.safetensors复制到stable-diffusion-webui/models/Stable-diffusion路径下 + cp stable-diffusion-2-1-base/v2-1_512-ema-pruned.safetensors ../../../models/Stable-diffusion + cd ../../.. + ``` + +4. 在webui工程路径下执行命令启动webui,自动安装需要的环境 + + ```bash + python launch.py --skip-torch-cuda-test --port 22 --enable-insecure-extension-access --listen --log-startup --disable-safe-unpickle --no-half + ``` + +## 插件部署 +1. 按照requirements.txt要求的版本安装相关依赖,避免导出模型失败! +```bash + pip install -r requirements.txt +``` +2. 安装昇腾推理工具 + + 请访问[msit代码仓](https://gitee.com/ascend/msit/tree/master/msit/),根据readme文档进行工具安装。可只安装需要的组件:debug surgeon,其他组件为可选安装。 + + 请访问[ais_bench](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench),根据readme文件进行工具安装,建议使用whl包进行安装。 + +3. 代码修改,修改clip和cross_attention,用于导出正确的模型 +```bash + python sd_webui_patch.py +``` +4. 安装aie包和torch_aie包,配置AIE目录下的环境变量 +```bash + chmod +x ./Ascend-cann-aie_xxx.run + ./Ascend-cann-aie_xxx.run --install + source set_env.sh +``` + +## 运行功能 +1. 执行命令启动webui +```bash +python launch.py --skip-torch-cuda-test --port 22 --enable-insecure-extension-access --listen --log-startup --disable-safe-unpickle --no-half --skip-prepare-environment +``` +2. 请优先选择device,Duo或A2 +3. 文生图:选择ONNX按钮,输入文本,设置相关参数,点击generate生成结果 +4. 图生图:选择ONNX按钮,输入图像、文本,设置相关参数,点击generate生成结果 +5. 运用并行加速:点击Use_Parallel_Inferencing按钮选择 + +# 备注 + +1. 
使用昇腾插件后,原始的webui界面中的某些配置受到限制,如下: + + 可配置参数: + + ``` + Sampling method + Sampling steps + CFG Scale + Seed + ``` + + 受限制参数: + + ``` + Width和Height要固定为512 + Batch count要固定为1 + Batch size要固定为1 + ``` + +2. 点击ONNX按钮,在第一次启动服务后,会做模型的处理,该处理会耗时10分钟左右,当后台输出"You can generate image now!"字样时,可进行图生成等操作。 diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/background_session.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/background_session.py new file mode 100644 index 0000000000000000000000000000000000000000..040945c631f3cd4c2bc4c45ead9776ba0f487e5d --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/background_session.py @@ -0,0 +1,184 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing as mp +from dataclasses import dataclass +from typing import List, Optional +import multiprocessing.connection as connection +import numpy as np +from ais_bench.infer.interface import InferSession + + +@dataclass +class SessionIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +@dataclass +class BackgroundInferSessionOptions: + device_id: int + model_path: str + io_info: SessionIOInfo + acl_json_path: Optional[str] = None + debug: Optional[bool] = False + loop: Optional[int] = 1 + + +class BackgroundInferSession: + def __init__( + self, + device_id: int, + model_path: str, + io_info: SessionIOInfo, + acl_json_path: Optional[str] = None, + debug: Optional[bool] = False, + loop: Optional[int] = 1 + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [np.frombuffer(b, dtype=t).reshape(s) for ( + b, s, t) in zip(input_spaces, io_info.input_shapes, io_info.input_dtypes)] + self.output_arrays = [np.frombuffer(b, dtype=t).reshape(s) for ( + b, s, t) in zip(output_spaces, io_info.output_shapes, io_info.output_dtypes)] + + mp.set_start_method('forkserver', force=True) + self.p = mp.Process( + target=self.run_session, + args=[sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path, + acl_json_path, debug, loop] + ) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i in range(len(self.input_arrays)): + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send('') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + # This function should work as same as InferSession.infer() + self.infer_asyn(feeds) + return 
self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @classmethod + def clone( + cls, + session: InferSession, + device_id: int) -> 'BackgroundInferSession': + # Get shapes, datatypes, and model path from an existed InferSession, + # then use them to create a BackgroundInferSession + io_info = cls.get_io_info_from_session(session) + + return cls(device_id, session.model_path, io_info) + + @staticmethod + def get_io_info_from_session(session: InferSession) -> SessionIOInfo: + # Map aclruntime datatype to numpy datatype + np_types = (np.float32, np.float16, np.int8, np.int32, + np.uint8, '', np.int16, np.uint16, np.uint32, + np.int64, np.uint64) + + # Get input shapes and datatypes + inputs = session.get_inputs() + input_shapes = [t.shape for t in inputs] + input_dtypes = [np_types[t.datatype] for t in inputs] + + # Get output shapes and datatypes + outputs = session.get_outputs() + output_shapes = [t.shape for t in outputs] + output_dtypes = [np_types[t.datatype] for t in outputs] + + return SessionIOInfo(input_shapes, input_dtypes, + output_shapes, output_dtypes) + + @staticmethod + def create_shared_buffers(shapes: List[tuple], dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + @staticmethod + def run_session( + sync_pipe: connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: SessionIOInfo, + device_id: int, + model_path: str, + acl_json_path: Optional[str] = None, + debug: Optional[bool] = False, + loop: Optional[int] = 1 + ) -> None: + # The sub process function + + # Create an InferSession + session = InferSession( + device_id, + model_path, + acl_json_path, + debug, + loop + ) + + # Build numpy arrays on the shared buffers + input_arrays = [np.frombuffer(b, dtype=t).reshape(s) for ( + b, s, t) in zip(input_spaces, io_info.input_shapes, io_info.input_dtypes)] + + output_arrays = [np.frombuffer(b, dtype=t).reshape(s) for ( + b, s, t) in zip(output_spaces, io_info.output_shapes, io_info.output_dtypes)] + + # Tell the main function that we are ready + sync_pipe.send('') + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != 'STOP': + output = session.infer(input_arrays) + for i in range(len(output_arrays)): + output_arrays[i][:] = output[i][:] + + sync_pipe.send('') diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/clip.patch b/MindIE/MultiModal/SD-WebUI/onnx_extension/clip.patch new file mode 100644 index 0000000000000000000000000000000000000000..e3e4719b66f771ebb660f25151c33d140566c3f3 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/clip.patch @@ -0,0 +1,10 @@ +22a23 +> import numpy as np +760c761,762 +< mask.triu_(1) # zero out the lower diagonal +--- +> # mask.triu_(1) # zero out the lower diagonal +> mask = torch.from_numpy(np.triu(mask.numpy(), 1)) +1324a1327 +> + diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/config.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f214c44cbb624a8c1fab8e3530a9d16e902a7624 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/config.py @@ -0,0 +1,9 @@ +class NpuConfig(object): + use_cpu = True + Duo = False + A2 = False + use_parallel_inferencing = False + unet_session = False + clip_session = False + vae_session = 
False + unet_session_bg = False diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/cross_attention.patch b/MindIE/MultiModal/SD-WebUI/onnx_extension/cross_attention.patch new file mode 100644 index 0000000000000000000000000000000000000000..b2fbe0d511f4e8678ed229ab952ddeb3fceea355 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/cross_attention.patch @@ -0,0 +1,14 @@ +--- cross_attention.py 2023-12-12 03:15:11.776000000 +0000 ++++ cross_attention.py 2023-12-12 03:15:25.400000000 +0000 +@@ -101,8 +101,9 @@ class CrossAttention(nn.Module): + # set attention processor + # We use the AttnProcessor2_0 by default when torch2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention +- if processor is None: +- processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() ++ #if processor is None: ++ # processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() ++ processor = CrossAttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/export2onnx.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/export2onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..1ccde926fcaf5a9f242d113bbf1f17b1b03d0e58 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/export2onnx.py @@ -0,0 +1,153 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from argparse import Namespace + +import torch +from diffusers import StableDiffusionPipeline + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save ONNX models.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." 
+ ) + + return parser.parse_args() + + +def export_clip(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size:int) -> None: + print("Exporting the text encoder...") + clip_path = os.path.join(save_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o744) + if os.path.exists(os.path.join(clip_path, "clip.onnx")): + return + clip_model = sd_pipeline.text_encoder + + max_position_embeddings = clip_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + torch.onnx.export( + clip_model, + dummy_input, + os.path.join(clip_path, "clip.onnx"), + input_names=["prompt"], + output_names=["text_embeddings"], + opset_version=11, + ) + + +def export_unet(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o744) + if os.path.exists(os.path.join(unet_path, "unet.onnx")): + return + + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + max_position_embeddings = clip_model.config.max_position_embeddings + + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + ) + + torch.onnx.export( + unet_model, + dummy_input, + os.path.join(unet_path, f"unet.onnx"), + input_names=["latent_model_input", "t", "encoder_hidden_states"], + output_names=["sample"], + opset_version=11, + ) + + +def export_vae(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int) -> None: + print("Exporting the image decoder...") + + vae_path = os.path.join(save_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o744) + if os.path.exists(os.path.join(vae_path, "vae.onnx")): + return + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size]) + + torch.onnx.export( + vae_model.decoder, + dummy_input, + os.path.join(vae_path, "vae.onnx"), + input_names=["latents"], + output_names=["image"], + opset_version=11, + ) + + +def export_onnx(model_path: str, save_dir: str, batch_size:int, parallel: bool=False) -> None: + pipeline = StableDiffusionPipeline.from_pretrained(model_path).to("cpu") + + export_clip(pipeline, save_dir, batch_size) + + if parallel: + export_unet(pipeline, save_dir, batch_size) + else: + export_unet(pipeline, save_dir, batch_size * 2) + + export_vae(pipeline, save_dir, batch_size) + + +def main(): + args = parse_arguments() + cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + model = os.path.join(cur_dir_path, "models/stable-diffusion-2-1-base") + + parallel = False + output_dir = os.path.join(cur_dir_path, "models/SD2.1/models_bs1") + export_onnx(model, output_dir, args.batch_size, parallel) + parallel = True + output_dir = os.path.join(cur_dir_path, "models/SD2.1/models_bs1_parallel") + export_onnx(model, output_dir, args.batch_size, parallel) + print("Done.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git 
a/MindIE/MultiModal/SD-WebUI/onnx_extension/modify_onnx.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/modify_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..43298a8822ec560622865e9d25ed07cca136459d --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/modify_onnx.py @@ -0,0 +1,468 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import argparse + +import numpy as np +from auto_optimizer import OnnxGraph +from config import NpuConfig + + +def del_add(model): + init = [n.name for n in model.get_nodes('Initializer')] + for node in model.get_nodes('Add'): + if 'attn' in node.name and node.inputs[1] in init: + value = model[node.inputs[1]].value + if (value == 0).all(): + model.remove(node.name) + + +def add_flash_attention(model, fa_name, soc_type): + for node in model.get_nodes('Mul'): + name = node.name + if soc_type == 1: + flag = 'attn' in name + else: + flag = 'attn1' in name + if flag: + matmul = model[name[:-3] + 'to_q/MatMul'] + reshape = model[name[:-3] + 'Reshape'] + if soc_type == 2 and model[reshape.inputs[1]].value[1] != 4096: + continue + softmax_node = model.get_next_nodes(node.outputs[0])[0] + if soc_type == 1: + # move mul to q + softmax_node.inputs[0] = node.inputs[0] + node.inputs[0] = matmul.outputs[0] + reshape.inputs[0] = node.outputs[0] + + # add flashattention + new_node = model.add_node(name[:-3] + fa_name, fa_name) + inputs = [None, None, None] + # input 0: q + if soc_type == 1: + matmul_node = model.get_prev_node(softmax_node.inputs[0]) + if soc_type == 2: + matmul_node = model.get_prev_node(node.inputs[0]) + inputs[0] = matmul_node.inputs[0] + # input 1: k + transpose_node = model.get_prev_node(matmul_node.inputs[1]) + inputs[1] = transpose_node.inputs[0] + # input 2: v + cast_node = model.get_next_nodes(softmax_node.outputs[0])[0] + last_node = model.get_next_nodes(cast_node.outputs[0])[0] + inputs[2] = last_node.inputs[1] + # output + outputs = last_node.outputs + # update link + new_node.inputs = inputs + new_node.outputs = outputs + + model.remove(matmul_node.name, {}) + model.remove(transpose_node.name, {}) + model.remove(softmax_node.name, {}) + model.remove(cast_node.name, {}) + model.remove(last_node.name, {}) + model.update_map() + for node in model.get_nodes(fa_name): + for _ in range(soc_type): + for i in range(3): + prev_node = model.get_prev_node(node.inputs[i]) + model.remove(prev_node.name) + next_node = model.get_next_nodes(node.outputs[0])[0] + model.remove(next_node.name) + if soc_type == 2: + name = node.name.replace(fa_name, 'Cast') + cast = model.add_node(name, 'Cast', attrs={'to': 1}) + model.insert_node(node.name, cast) + + +def change_input_type(model): + model.remove('t') + model.add_input('t', 'int32', [1]) + model.inputs[1], model.inputs[2] = model.inputs[2], model.inputs[1] + + +def get_index(model, init, name): + if name in init: + return model[name].value + else: + return name + + +def replace_slice(model, fast): + # find pairs of slice + 
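+    # Each detected pair (two Slice nodes that cut one tensor into two contiguous chunks) is
+    # replaced by a single Split node; when `fast` is set and the slices feed a Mul, the
+    # pattern is fused into the custom SliceTransGeluMul op instead.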
slice_pair = [] + for node in model.get_nodes('Slice'): + if node.name[-2:] == '_1': + slice_pair.append((model[node.name[:-2]], model[node.name])) + # replace + init = [n.name for n in model.get_nodes('Initializer')] + for pair in slice_pair: + next_node = model.get_next_nodes(pair[0].outputs[0])[0] + if fast and next_node.op_type == 'Mul': + name = pair[0].name[:-5] + 'SliceTransGeluMul' + model.add_node(name, 'SliceTransGeluMul', inputs=[pair[0].inputs[0]], outputs=next_node.outputs) + model.remove(next_node.name, {}) + else: + name = pair[0].name[:-5] + 'Split' + data = pair[0].inputs[0] + start_0 = get_index(model, init, pair[0].inputs[1]) + end_0 = get_index(model, init, pair[0].inputs[2]) + start_1 = get_index(model, init, pair[1].inputs[1]) + end_1 = get_index(model, init, pair[1].inputs[2]) + if start_1 == end_0: + outputs = pair[0].outputs + pair[1].outputs + elif start_0 == end_1: + outputs = pair[1].outputs + pair[0].outputs + + axes = pair[0].inputs[3] + axis = model[axes].value[0] + model.add_node(name, 'Split', inputs=[data], outputs=outputs, attrs={'axis': axis}) + model.remove(pair[0].name, {}) + model.remove(pair[1].name, {}) + model.update_map() + + +def build_index(h, w, sy=2, sx=2): + # random select one from a 2x2 block + hsy = h // sy + wsx = w // sx + rand_idx = np.random.randint(sy * sx, size=(hsy, wsx)) + + idx = np.ones((hsy, wsx, sy * sx), dtype=np.int64) + for i in range(hsy): + for j in range(wsx): + idx[i, j][rand_idx[i, j]] = 0 + idx = idx.reshape(hsy, wsx, sy, sx).transpose(0, 2, 1, 3) + idx_rand = idx.reshape(-1).argsort() + index_a = np.sort(idx_rand[hsy * wsx:]) + index_b = np.sort(idx_rand[:hsy * wsx]) + return index_a, index_b + + +def get_block(model): + # find self-attention block + norms = [] + for node in model.get_nodes('Add'): + next_nodes = model.get_next_nodes(node.outputs[0]) + if len(next_nodes) != 3: + continue + op_type = set(n.op_type for n in next_nodes) + if len(op_type) == 1 and 'MatMul' in op_type: + if model[node.inputs[1]].value.shape[0] == 320: + norms.append(node) + return norms + + +def find_nodes(model, node): + prev_node = model.get_prev_node(node.inputs[0]) + while prev_node.op_type != 'Sub': + prev_node = model.get_prev_node(prev_node.inputs[0]) + inp = prev_node.inputs[0] + next_nodes = model.get_next_nodes(inp) + for next_node in next_nodes: + if next_node.op_type == 'Add': + if next_node.inputs[0] == inp: + out = next_node.inputs[1] + else: + out = next_node.inputs[0] + return inp, out + + +def build_tome_block(model, name, inputs, inputs_un): + # link merge to attn + for node in model.get_next_nodes(inputs[1]): + ind = 0 + for inp in node.inputs: + if inp == inputs[1]: + node.inputs[ind] = name + 'Concat_output' + ind += 1 + # norm block + model.add_node( + name + 'Mul', + 'Mul', + inputs=[inputs[0], inputs[0]], + outputs=[name + 'Mul_output'] + ) + model.add_node( + name + 'ReduceSum', + 'ReduceSum', + inputs=[name + 'Mul_output'], + outputs=[name + 'ReduceSum_output'], + attrs={'axes': [-1], 'keepdims': 1} + ) + model.add_node( + name + 'Sqrt', + 'Sqrt', + inputs=[name + 'ReduceSum_output'], + outputs=[name + 'Sqrt_output'] + ) + model.add_node( + name + 'Div', + 'Div', + inputs=[inputs[0], name + 'Sqrt_output'], + outputs=[name + 'Div_output'] + ) + # compute similarity + model.add_node( + name + 'Gather_0', + 'Gather', + inputs=[name + 'Div_output', 'tome/Gather_index_a'], + outputs=[name + 'Gather_0_output'], + attrs={'axis': 1} + ) + model.add_node( + name + 'Gather_1', + 'Gather', + inputs=[name + 'Div_output', 
'tome/Gather_index_b'], + outputs=[name + 'Gather_1_output'], + attrs={'axis': 1} + ) + model.add_node( + name + 'Transpose', + 'Transpose', + inputs=[name + 'Gather_1_output'], + outputs=[name + 'Transpose_output'], + attrs={'perm': [0, 2, 1]} + ) + model.add_node( + name + 'MatMul', + 'MatMul', + inputs=[name + 'Gather_0_output', name + 'Transpose_output'], + outputs=[name + 'MatMul_output'] + ) + model.add_node( + name + 'FindMax', + 'FindMax', + inputs=[name + 'MatMul_output'], + outputs=[name + 'FindMax_output_0', name + 'FindMax_output_1'], + attrs={} + ) + model.add_node( + name + 'TopK', + 'TopK', + inputs=[name + 'FindMax_output_0', 'tome/Topk_k'], + outputs=[name + 'TopK_output_0', name + 'TopK_output_1'], + attrs={'axis': -1, 'largest': 1} + ) + # split token + model.add_node( + name + 'Gather_2', + 'Gather', + inputs=[inputs[1], 'tome/Gather_index_a'], + outputs=[name + 'Gather_2_output'], + attrs={'axis': 1} + ) + model.add_node( + name + 'Gather_3', + 'Gather', + inputs=[inputs[1], 'tome/Gather_index_b'], + outputs=[name + 'Gather_3_output'], + attrs={'axis': 1} + ) + model.add_node( + name + 'Cast_0', + 'Cast', + inputs=[name + 'Gather_2_output'], + outputs=[name + 'Cast_0_output'], + attrs={'to': 1} + ) + model.add_node( + name + 'Cast_1', + 'Cast', + inputs=[name + 'Gather_3_output'], + outputs=[name + 'Cast_1_output'], + attrs={'to': 1} + ) + # tome merge + merge_inputs = [ + name + 'Cast_0_output', + name + 'Cast_1_output', + name + 'TopK_output_1', + name + 'FindMax_output_1' + ] + merge_outputs = [ + name + 'TomeMerged_output_0', + name + 'TomeMerged_output_1', + name + 'TomeMerged_output_2' + ] + model.add_node( + name + 'TomeMerged', + 'TomeMerged', + inputs=merge_inputs, + outputs=merge_outputs + ) + model.add_node( + name + 'ReduceSum_1', + 'ReduceSum', + inputs=[name + 'TomeMerged_output_1'], + outputs=[name + 'ReduceSum_1_output'], + attrs={'axes': [1], 'keepdims': 0} + ) + model.add_node( + name + 'ReduceSum_2', + 'ReduceSum', + inputs=[name + 'TomeMerged_output_2'], + outputs=[name + 'ReduceSum_2_output'], + attrs={'axes': [1], 'keepdims': 0} + ) + model.add_node( + name + 'Unsqueeze', + 'Unsqueeze', + inputs=[name + 'ReduceSum_2_output'], + outputs=[name + 'Unsqueeze_output'], + attrs={'axes': [2]} + ) + model.add_node( + name + 'Div_1', + 'Div', + inputs=[name + 'ReduceSum_1_output', name + 'Unsqueeze_output'], + outputs=[name + 'Div_1_output'] + ) + model.add_node( + name + 'Concat', + 'Concat', + inputs=[name + 'TomeMerged_output_0', name + 'Div_1_output'], + outputs=[name + 'Concat_output'], + attrs={'axis': 1} + ) + # link unmerge to norm + for node in model.get_next_nodes(inputs_un[0]): + ind = 0 + for inp in node.inputs: + if inp == inputs_un[0]: + node.inputs[ind] = name + 'TomeUngerme_output' + ind += 1 + # add unmerge node + unmerge_inputs = inputs_un + [name + 'TopK_output_1', name + 'FindMax_output_1'] + model.add_node( + name + 'tome/TomeUnmerge', + 'TomeUnmerge', + inputs=unmerge_inputs, + outputs=[name + 'TomeUngerme_output'] + ) + model.update_map() + + +def insert_tome_block(model, max_num): + bs = model['latent_model_input'].shape[0] + h, w = model['latent_model_input'].shape[2:] + index_a, index_b = build_index(h, w) + # add initializer + model.add_initializer('tome/Gather_index_a', index_a) + model.add_initializer('tome/Gather_index_b', index_b) + bs_index_a = np.tile(index_a.reshape(1, -1), [bs, 1]) + bs_index_b = np.tile(index_b.reshape(1, -1), [bs, 1]) + model.add_initializer('tome/index_a', bs_index_a) + 
model.add_initializer('tome/index_b', bs_index_b) + model.add_initializer('tome/Topk_k', np.array([3072])) + # get reshape nodes + reshapes = model.get_nodes('Reshape') + # find inputs + norm_outs = get_block(model)[:max_num] + for node in norm_outs: + name = node.name.rsplit('/', 2)[0] + '/attn1/' + norm_input, sa_output = find_nodes(model, node) + inputs_0 = [norm_input] + node.outputs + inputs_1 = [sa_output] + ['tome/index_a', 'tome/index_b'] + # add tome block + build_tome_block(model, name.replace('attn', 'tome'), inputs_0, inputs_1) + # change shape of reshape + for reshape in reshapes: + if name in reshape.name: + shape = model[reshape.inputs[1]].value.copy() + ind = 0 + for size in shape: + if size == 4096: + shape[ind] = '-1' + ind += 1 + model[reshape.inputs[1]].value = shape + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="models/unet/unet.onnx", + help="Path of the unet onnx model.", + ) + parser.add_argument( + "--new_model", + type=str, + default="models/unet/unet_md.onnx", + help="Path to save the modified model", + ) + parser.add_argument( + "--FA_soc", + choices=["None", "Duo", "A2"], + default="None", + help="Type of FA operator.", + ) + parser.add_argument( + "--TOME_num", + type=int, + default=5, + help="Number of TOME used in the model", + ) + parser.add_argument( + "--faster_gelu", + default=True, + action="store_true", + help="Use specific gelu operation" + ) + return parser.parse_args() + + +def main(): + cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(cur_dir_path, "models/SD2.1/models_bs1/unet/unet.onnx") + new_model = os.path.join(cur_dir_path, "models/SD2.1/models_bs1/unet/unet_md.onnx") + if not os.path.exists(new_model): + model = OnnxGraph.parse(model_path) + del_add(model) + if NpuConfig.Duo: + add_flash_attention(model, 'FlashAttentionTik', soc_type=1) + elif NpuConfig.A2: + add_flash_attention(model, 'UnpadFlashAttentionMix', soc_type=2) + if args.TOME_num: + insert_tome_block(model, args.TOME_num) + change_input_type(model) + replace_slice(model, args.faster_gelu) + model.remove_unused_nodes() + model.save(new_model) + + model_parallel_path = os.path.join(cur_dir_path, "models/SD2.1/models_bs1_parallel/unet/unet.onnx") + new_model = os.path.join(cur_dir_path, "models/SD2.1/models_bs1_parallel/unet/unet_md.onnx") + if not os.path.exists(new_model): + model_parallel = OnnxGraph.parse(model_parallel_path) + del_add(model_parallel) + if NpuConfig.Duo: + add_flash_attention(model_parallel, 'FlashAttentionTik', soc_type=1) + elif NpuConfig.A2: + add_flash_attention(model_parallel, 'UnpadFlashAttentionMix', soc_type=2) + if args.TOME_num: + insert_tome_block(model_parallel, args.TOME_num) + change_input_type(model_parallel) + replace_slice(model_parallel, args.faster_gelu) + model_parallel.remove_unused_nodes() + new_model = os.path.join(cur_dir_path, "models/SD2.1/models_bs1_parallel/unet/unet_md.onnx") + model_parallel.save(new_model) + + +if __name__ == '__main__': + args = parse_arguments() + main() + \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/replace_onnx.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/replace_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..270089842a12575783daff760b7a5dd216dd3155 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/replace_onnx.py @@ -0,0 +1,104 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import time +import math +import numpy as np +import torch +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from ldm.models.autoencoder import AutoencoderKL +from ldm.modules.encoders.modules import FrozenOpenCLIPEmbedder +from ldm.modules.diffusionmodules.util import timestep_embedding +from diffusers import StableDiffusionPipeline +from config import NpuConfig +from ais_bench.infer.interface import InferSession +from background_session import BackgroundInferSession +from modules import shared + +def env_om(cur_dir_path): + os.system(os.path.join(cur_dir_path, "setup.sh")) + +def replace_onnx(): + cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + env_om(cur_dir_path) + + device_0, device_1 = 2, None + if NpuConfig.unet_session_bg: + NpuConfig.unet_session_bg.stop() + NpuConfig.unet_session_bg = False + print("You can generate image now!") + + def unet_onnx(self, x, timesteps=None, context=None, y=None, **kwargs): + if x.shape[-1] != 64: # 64 = 512 // 8 + return x + + checkpoint = shared.opts.data['sd_model_checkpoint'] + if NpuConfig.use_parallel_inferencing: + device_1 = 3 + unet_path = os.path.join(cur_dir_path, "models/SD2.1/models_bs1_parallel/unet/unet.om") + else: + unet_path = os.path.join(cur_dir_path, "models/SD2.1/models_bs1/unet/unet.om") + + if not NpuConfig.unet_session: + NpuConfig.unet_session = InferSession(device_0, unet_path) + if NpuConfig.use_parallel_inferencing: + NpuConfig.unet_session_bg = BackgroundInferSession.clone(NpuConfig.unet_session, device_1) + if NpuConfig.use_parallel_inferencing: + context, context_2 = context.chunk(2) + x, x_2 = x.chunk(2) + NpuConfig.unet_session_bg.infer_asyn( + [ + x_2.cpu().numpy(), + timesteps[0][None].cpu().numpy().astype(np.int32), + context_2.cpu().numpy() + ] + ) + x = x.cpu().numpy() + t = timesteps[0][None].cpu().numpy().astype(np.int32) + context = context.cpu().numpy() + noise_pred = torch.from_numpy( + NpuConfig.unet_session.infer( + [ + x, + t, + context + ] + )[0] + ) + if NpuConfig.use_parallel_inferencing: + noise_pred_text = torch.from_numpy( + NpuConfig.unet_session_bg.wait_and_get_outputs()[0] + ) + noise_pred = torch.cat([noise_pred, noise_pred_text]) + + return noise_pred + UNetModel.forward = unet_onnx + + def clip_onnx(self, text): + clip_path = os.path.join(cur_dir_path, "models/SD2.1/models_bs1/clip/clip.om") + if not NpuConfig.clip_session: + NpuConfig.clip_session = InferSession(device_0, clip_path) + x = torch.from_numpy(NpuConfig.clip_session.infer([text.numpy()])[0]) + return x + FrozenOpenCLIPEmbedder.encode_with_transformer = clip_onnx + + def vae_onnx(self, z): + vae_path = os.path.join(cur_dir_path, "models/SD2.1/models_bs1/vae/vae.om") + z = self.post_quant_conv(z) + if not NpuConfig.vae_session: + NpuConfig.vae_session = InferSession(device_0, vae_path) + dec = torch.from_numpy(NpuConfig.vae_session.infer([z.numpy()])[0]) + return dec + AutoencoderKL.decode = vae_onnx \ No newline at end of file diff --git 
a/MindIE/MultiModal/SD-WebUI/onnx_extension/requirements.txt b/MindIE/MultiModal/SD-WebUI/onnx_extension/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2813fba74e04c4adae41738b9f362703e7ec7fc7 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/requirements.txt @@ -0,0 +1,33 @@ +torch==2.1.0 +diffusers==0.14.0 +transformers==4.26.1 +GitPython==3.1.32 +Pillow==9.5.0 +accelerate==0.21.0 +basicsr==1.4.2 +blendmodes==2022 +clean-fid==0.1.35 +einops==0.4.1 +fastapi==0.94.0 +gfpgan==1.3.8 +gradio==3.41.2 +httpcore==0.15.0 +inflection==0.5.1 +jsonmerge==1.8.0 +kornia==0.6.7 +lark==1.1.2 +numpy==1.23.5 +omegaconf==2.2.3 +open-clip-torch==2.20.0 +piexif==1.1.3 +psutil==5.9.5 +pytorch_lightning==1.9.4 +realesrgan==0.3.0 +resize-right==0.0.2 +safetensors==0.3.1 +scikit-image==0.21.0 +timm==0.9.2 +tomesd==0.1.3 +torchdiffeq==0.2.3 +torchsde==0.2.6 +httpx==0.24.1 \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/scripts/onnx_plugin.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/scripts/onnx_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..5282390e7468b4333ef4f9d0558926416877d7f6 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/scripts/onnx_plugin.py @@ -0,0 +1,75 @@ +import logging +import gradio as gr +from basicsr.utils import get_root_logger +from modules import scripts +from replace_onnx import replace_onnx +from config import NpuConfig + +def listen_change(choice): + if choice == 'ONNX': + print("switch to ONNX") + replace_onnx() + return + else: + print("do nothing...") + + +class AscendIEPlugin(scripts.Script): + + def __init__(self) -> None: + super().__init__() + self.logger = get_root_logger() + self.logger.info("import AscendIEPlugin") + self.logger.setLevel(logging.INFO) + + def title(self): + return "webui-npu-extension" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_txt2img): + with gr.Group(): + with gr.Accordion("npu-extensions", open = True): + device_radio = gr.Radio(choices = ['None', 'Duo', 'A2'], value = "None", label="Ascend device: Duo is 310P3; A2 is 910B4") + device_radio.change(self.listen_device_change, inputs=device_radio) + + npu_radio = gr.Radio(choices = ['None', 'ONNX'], value = "None", label="Ascend boosting choices") + npu_radio.change(listen_change, inputs=npu_radio) + + parallel_inferencing_checkbox = gr.Checkbox(label = 'Use_Parallel_Inferencing', info="Do you want to use parallel inferencing?") + parallel_inferencing_checkbox.change(self.listen_parallel_status, inputs = parallel_inferencing_checkbox) + + def listen_parallel_status(self, status): + if status: + self.logger.info("Start to use parallel inferencing") + NpuConfig.use_parallel_inferencing = True + NpuConfig.unet_session = False + else: + self.logger.info("Stop using parallel inferencing") + NpuConfig.use_parallel_inferencing = False + NpuConfig.unet_session = False + if NpuConfig.unet_session_bg: + print("stop unet_session_bg") + NpuConfig.unet_session_bg.stop() + NpuConfig.unet_session_bg = False + + def listen_device_change(self, choice): + if choice == 'None': + print("do not use npu, use cpu default.") + NpuConfig.use_cpu = True + NpuConfig.Duo = False + NpuConfig.A2 = False + return + elif choice == 'Duo': + print("use Duo...") + NpuConfig.use_cpu = False + NpuConfig.Duo = True + NpuConfig.A2 = False + return + elif choice == 'A2': + print("use A2...") + NpuConfig.use_cpu = False + NpuConfig.Duo = False + NpuConfig.A2 = True + return diff 
--git a/MindIE/MultiModal/SD-WebUI/onnx_extension/sd_webui_patch.py b/MindIE/MultiModal/SD-WebUI/onnx_extension/sd_webui_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9860b8c72888bfe4d539415ec08aae8696d393 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/sd_webui_patch.py @@ -0,0 +1,38 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import transformers +import diffusers + + +def main(): + transformers_path = transformers.__path__ + transformers_version = transformers.__version__ + + assert transformers_version is not "4.26.1", "expectation transformers==4.26.1" + os.system( + f"patch -p0 {transformers_path[0]}/models/clip/modeling_clip.py clip.patch" + ) + + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + assert diffusers_version is not "0.14.0", "expectation diffusers==0.14.0" + os.system( + f"patch -p0 {diffusers_path[0]}/models/cross_attention.py cross_attention.patch" + ) + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/SD-WebUI/onnx_extension/setup.sh b/MindIE/MultiModal/SD-WebUI/onnx_extension/setup.sh new file mode 100644 index 0000000000000000000000000000000000000000..7b232d0df4295cd4a8a6303ff1ee05060214c99e --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/onnx_extension/setup.sh @@ -0,0 +1,80 @@ +#! /bin/bash + +soc_version=$(echo $(npu-smi info) | cut -d "|" -f12 | cut -d " " -f3) + +current_path=$(cd $(dirname $0);pwd) + +python ${current_path}/export2onnx.py +python ${current_path}/modify_onnx.py + +model_clip="models/SD2.1/models_bs1/clip" +model_unet="models/SD2.1/models_bs1/unet" +model_vae="models/SD2.1/models_bs1/vae" + +# clip +if [ ! -e "${current_path}/${model_clip}/clip.om" ]; then + atc --framework=5 \ + --model=${current_path}/${model_clip}/clip.onnx \ + --output=${current_path}/${model_clip}/clip \ + --input_format=ND \ + --log=error \ + --soc_version=Ascend${soc_version} +fi + +# unet +if [ ! -e "${current_path}/${model_unet}/unet.om" ]; then + atc --framework=5 \ + --model=${current_path}/${model_unet}/unet_md.onnx \ + --output=${current_path}/${model_unet}/unet \ + --input_format=NCHW \ + --log=error \ + --optypelist_for_implmode="Gelu,Sigmoid" \ + --op_select_implmode=high_performance \ + --soc_version=Ascend${soc_version} +fi + +# vae +if [ ! -e "${current_path}/${model_vae}/vae.om" ]; then + atc --framework=5 \ + --model=${current_path}/${model_vae}/vae.onnx \ + --output=${current_path}/${model_vae}/vae \ + --input_format=NCHW \ + --log=error \ + --soc_version=Ascend${soc_version} +fi + +model_parallel_clip="models/SD2.1/models_bs1_parallel/clip" +model_parallel_unet="models/SD2.1/models_bs1_parallel/unet" +model_parallel_vae="models/SD2.1/models_bs1_parallel/vae" + +# clip +if [ ! 
-e "${current_path}/${model_parallel_clip}/clip.om" ]; then + atc --framework=5 \ + --model=${current_path}/${model_parallel_clip}/clip.onnx \ + --output=${current_path}/${model_parallel_clip}/clip \ + --input_format=ND \ + --log=error \ + --soc_version=Ascend${soc_version} +fi + +# unet +if [ ! -e "${current_path}/${model_parallel_unet}/unet.om" ]; then + atc --framework=5 \ + --model=${current_path}/${model_parallel_unet}/unet_md.onnx \ + --output=${current_path}/${model_parallel_unet}/unet \ + --input_format=NCHW \ + --log=error \ + --optypelist_for_implmode="Gelu,Sigmoid" \ + --op_select_implmode=high_performance \ + --soc_version=Ascend${soc_version} +fi + +# vae +if [ ! -e "${current_path}/${model_parallel_vae}/vae.om" ]; then + atc --framework=5 \ + --model=${current_path}/${model_parallel_vae}/vae.onnx \ + --output=${current_path}/${model_parallel_vae}/vae \ + --input_format=NCHW \ + --log=error \ + --soc_version=Ascend${soc_version} +fi \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/README.md b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3f751c88b1c4f6243c66142fcf1de861f1b83d72 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/README.md @@ -0,0 +1,90 @@ +# SDWebUI-TorchAIE推理指导 +torch_aie_extension实现了一个SDWebUI界面的插件,用优化后的diffusers.Unet2DConditionModel替换原有的UNetModel进行推理,支持SD文生图和图生图功能。底层调用了MindIE的build编译优化功能,通过PASS改图、Batch并行等优化手段,提升了推理性能。 + + +# 概述 + + SDWebUI是一个基于Gradio库的WebUi界面,支持设置输入和参数用于SD模型的文生图、图生图等功能。有关SDWebUI的更多信息,请查看[Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)。 + +# 推理环境准备 + +该插件依赖torch2.1.0, python3.10环境 + +# 快速上手 + +## 环境准备 + +1. 按照requirements.txt要求的版本安装相关依赖,避免导出模型失败! + + ``` + pip install -r requirements.txt + ``` + +2. 安装mindie包和mindietorch包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/aie/set_env.sh + # 安装mindietorch + tar -zxvf Ascend-mindie-torch_xxx.tar.gz + pip install mindietorch-1.0.rc1+torch2.1.0xxx.whl + ``` + +3. 代码修改,修改clip和cross_attention,用于trace正确的模型 + + ```bash + python sd_webui_patch.py + ``` + +## sd_webui部署 + +1. 拉取webui工程代码[stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) + + ```bash + git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git + ``` + +2. 拉取torch_aie_extension工程,放在stable-diffusion-webui/extensions路径下 + +3. 获取权重 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # v1.5,将该权重放在stable-diffusion-webui/extensions/torch_aie_extension/models路径下 + cd stable-diffusion-webui/extensions/torch_aie_extension/models + git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 + ``` + +4. 将特定权重放在stable-diffusion-webui/models/Stable-diffusion路径下。注意:本插件支持的webui权重如下: + + ```bash + # 二选一即可,推荐safetensors + v1-5-pruned-emaonly.safetensors + v1-5-pruned-emaonly.ckpt + ``` + + ```bash + # 举例: + cp stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors ../../../models/Stable-diffusion + ``` + +5. 在stable-diffusion-webui工程路径下执行命令启动webui,自动安装需要的环境 + + ```bash + python launch.py --skip-torch-cuda-test --port 22 --enable-insecure-extension-access --listen --log-startup --disable-safe-unpickle --no-half + ``` + +## 运行功能 +1. 执行命令启动webui +```bash +python launch.py --skip-torch-cuda-test --port 22 --enable-insecure-extension-access --listen --log-startup --disable-safe-unpickle --no-half --skip-prepare-environment +``` +2. 
文生图:选择torch_aie按钮,输入文本,设置相关参数,点击generate生成结果 + +3. 图生图:选择torch_aie按钮,输入图像、文本,设置相关参数,点击generate生成结果 + +4. 运用并行加速:点击Use_Parallel_Inferencing按钮选择 diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/clip.patch b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/clip.patch new file mode 100644 index 0000000000000000000000000000000000000000..e3e4719b66f771ebb660f25151c33d140566c3f3 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/clip.patch @@ -0,0 +1,10 @@ +22a23 +> import numpy as np +760c761,762 +< mask.triu_(1) # zero out the lower diagonal +--- +> # mask.triu_(1) # zero out the lower diagonal +> mask = torch.from_numpy(np.triu(mask.numpy(), 1)) +1324a1327 +> + diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/config.py b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/config.py new file mode 100644 index 0000000000000000000000000000000000000000..61dab178daca80ef392b1e9af8665abbad5c4dc4 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/config.py @@ -0,0 +1,4 @@ +class NpuConfig(object): + compiled_unet_model = None + use_parallel_inferencing = False + unet_bg = None \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/cross_attention.patch b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/cross_attention.patch new file mode 100644 index 0000000000000000000000000000000000000000..b2fbe0d511f4e8678ed229ab952ddeb3fceea355 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/cross_attention.patch @@ -0,0 +1,14 @@ +--- cross_attention.py 2023-12-12 03:15:11.776000000 +0000 ++++ cross_attention.py 2023-12-12 03:15:25.400000000 +0000 +@@ -101,8 +101,9 @@ class CrossAttention(nn.Module): + # set attention processor + # We use the AttnProcessor2_0 by default when torch2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention +- if processor is None: +- processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() ++ #if processor is None: ++ # processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() ++ processor = CrossAttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/pt_background_runtime_np.py b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/pt_background_runtime_np.py new file mode 100644 index 0000000000000000000000000000000000000000..94cb0ede1fdae3dbac67eaeba9aff4d4f75a5e1e --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/pt_background_runtime_np.py @@ -0,0 +1,196 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
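+# NOTE (added commentary, not part of the original sources): this module is a
+# small parallel-inference helper. RuntimeIOInfo records the input/output
+# shapes and dtypes of the compiled model (here, the UNet), and
+# BackgroundRuntime spawns a worker process that loads the compiled
+# TorchScript model on a second NPU and exchanges tensors with the parent
+# through multiprocessing.RawArray shared buffers, so the two halves of a
+# classifier-free-guidance batch can be denoised on two devices in parallel.
+#
+# Minimal usage sketch (shapes and names below are illustrative assumptions
+# for an SD1.5 UNet at 512x512, batch size 1; they are not taken from this
+# repository):
+#
+#   io_info = RuntimeIOInfo(
+#       input_shapes=[(1, 4, 64, 64), (1,), (1, 77, 768)],
+#       input_dtypes=[np.float32, np.int64, np.float32],
+#       output_shapes=[(1, 4, 64, 64)],
+#       output_dtypes=[np.float32],
+#   )
+#   bg = BackgroundRuntime.clone(device_id=1,
+#                                model_path="models/unet_aie_compile_bs1.pt",
+#                                runtime_info=io_info)
+#   bg.infer_asyn([latent, timestep, context])        # non-blocking submit
+#   noise_pred_uncond = bg.wait_and_get_outputs()[0]  # block until ready
+#   bg.stop()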
+ +import multiprocessing as mp +import numpy as np +import torch +import mindietorch +import time +from typing import List +from dataclasses import dataclass + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__(self, device_id: int, model_path: str, io_info: RuntimeIOInfo): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers( + io_info.input_shapes, io_info.input_dtypes + ) + output_spaces = self.create_shared_buffers( + io_info.output_shapes, io_info.output_dtypes + ) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) + for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes + ) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) + for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes + ) + ] + + mp.set_start_method("spawn", force=True) + self.p = mp.Process( + target=self.run_infer, + args=[ + sync_pipe_peer, + input_spaces, + output_spaces, + io_info, + device_id, + model_path, + ], + ) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers( + shapes: List[tuple], dtypes: List[type] + ) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i, _ in enumerate(self.input_arrays): + print(f"bg input shape: {self.input_arrays[i].shape}") + print(f"feeds shape: {feeds[i].shape}") + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send("") + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send("STOP") + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: str, + ) -> None: + # The sub process function + + # Create a runtime + mindietorch.set_device(device_id) + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model = torch.jit.load(model_path).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) + for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes + ) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) + for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes + ) + ] + + # Tell the main function that we are ready + sync_pipe.send("") + + infer_num = 0 + preprocess_time = 0 + infer_time = 0 + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != "STOP": + start = time.time() + sample, timestep, encoder_hidden_states = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + sample_npu = 
sample.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to( + f"npu:{device_id}" + ) + preprocess_time += time.time() - start + + start2 = time.time() + output_npu = model(sample_npu, timestep_npu, encoder_hidden_states_npu).to( + "cpu" + ) + infer_time += time.time() - start2 + + for i, _ in enumerate(output_arrays): + output = output_npu.numpy() + output_arrays[i][:] = output[i][:] + + infer_num += 1 + sync_pipe.send("") + + infer_num /= 50 + print( + f"" + f"bg preprocess_time time: {preprocess_time / infer_num:.3f}s\n" + f"bg infer time: {infer_time / infer_num:.3f}s\n" + ) + + @classmethod + def clone( + cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo + ) -> "BackgroundRuntime": + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/replace_torch_aie.py b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/replace_torch_aie.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae25739f7a9866ef04286120ef1527905287d5b --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/replace_torch_aie.py @@ -0,0 +1,114 @@ +import os +import sys +import time +import math +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +from ldm.modules.diffusionmodules.openaimodel import UNetModel +from ldm.modules.diffusionmodules.util import timestep_embedding +from diffusers import StableDiffusionPipeline +from config import NpuConfig +from pt_background_runtime_np import BackgroundRuntime, RuntimeIOInfo + + +class UnetExport(torch.nn.Module): + def __init__(self, model): + super(UnetExport, self).__init__() + self.unet_model = model + + def forward(self, sample, timestep, encoder_hidden_states): + return self.unet_model(sample, timestep, encoder_hidden_states)[0] + + +def replace_unet_torch_aie(): + cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + save_dir = os.path.join(cur_dir_path, "models") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + device_0, device_1 = 2, None + mindietorch.set_device(device_0) + model_base = os.path.join(save_dir, "stable-diffusion-v1-5") + if NpuConfig.use_parallel_inferencing: + batch_size = 1 + device_1 = 3 + unet_path = os.path.join(save_dir, "unet_aie_compile_bs1.pt") + else: + batch_size = 2 + unet_path = os.path.join(save_dir, "unet_aie_compile_bs2.pt") + + def torch_aie_unet(self, x, timesteps=None, context=None, y=None, **kwargs): + if not NpuConfig.compiled_unet_model: + if not os.path.exists(unet_path): + pipe = StableDiffusionPipeline.from_pretrained(model_base).to("cpu") + in_channels = pipe.unet.config.out_channels + sample_size = pipe.unet.config.sample_size + encoder_hidden_size = pipe.text_encoder.config.hidden_size + max_position_embeddings = ( + pipe.text_encoder.config.max_position_embeddings + ) + dummy_input = ( + torch.ones( + [batch_size, in_channels, sample_size, sample_size], + dtype=torch.float32, + ), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], + dtype=torch.float32, + ), + ) + unet = UnetExport(pipe.unet) + model = torch.jit.trace(unet, dummy_input) + unet_input_info = [ + mindietorch.Input( + (batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT, + ), + mindietorch.Input((1,), 
dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT, + ), + ] + compiled_unet_model = mindietorch.compile( + model, + inputs=unet_input_info, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version="Ascend910B3", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=1, + ) + torch.jit.save(compiled_unet_model, unet_path) + NpuConfig.compiled_unet_model = compiled_unet_model + else: + NpuConfig.compiled_unet_model = torch.jit.load(unet_path).eval() + if NpuConfig.use_parallel_inferencing: + NpuConfig.unet_bg = BackgroundRuntime.clone( + device_1, unet_path, runtime_info + ) + + if NpuConfig.use_parallel_inferencing: + context, context_2 = context.chunk(2) + x, x_2 = x.chunk(2) + NpuConfig.unet_bg.infer_asyn( + x_2.numpy(), + timesteps[0][None].numpy().astype(np.int64), + context_2.numpy(), + ) + noise_pred = NpuConfig.compiled_unet_model( + x.to(f"npu:{device_0}"), + timesteps[0][None].type(torch.int64).to(f"npu:{device_0}"), + context.to(f"npu:{device_0}"), + ).to("cpu") + if NpuConfig.use_parallel_inferencing: + noise_pred_text = torch.from_numpy( + NpuConfig.unet_bg.wait_and_get_outputs()[0] + ) + noise_pred = torch.cat([noise_pred, noise_pred_text]) + return noise_pred + + UNetModel.forward = torch_aie_unet diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/requirements.txt b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2813fba74e04c4adae41738b9f362703e7ec7fc7 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/requirements.txt @@ -0,0 +1,33 @@ +torch==2.1.0 +diffusers==0.14.0 +transformers==4.26.1 +GitPython==3.1.32 +Pillow==9.5.0 +accelerate==0.21.0 +basicsr==1.4.2 +blendmodes==2022 +clean-fid==0.1.35 +einops==0.4.1 +fastapi==0.94.0 +gfpgan==1.3.8 +gradio==3.41.2 +httpcore==0.15.0 +inflection==0.5.1 +jsonmerge==1.8.0 +kornia==0.6.7 +lark==1.1.2 +numpy==1.23.5 +omegaconf==2.2.3 +open-clip-torch==2.20.0 +piexif==1.1.3 +psutil==5.9.5 +pytorch_lightning==1.9.4 +realesrgan==0.3.0 +resize-right==0.0.2 +safetensors==0.3.1 +scikit-image==0.21.0 +timm==0.9.2 +tomesd==0.1.3 +torchdiffeq==0.2.3 +torchsde==0.2.6 +httpx==0.24.1 \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/scripts/torch_aie_plugin.py b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/scripts/torch_aie_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..00d35810ecd6186543aaa349ca241fde32b23148 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/scripts/torch_aie_plugin.py @@ -0,0 +1,42 @@ +import logging +import gradio as gr +from basicsr.utils import get_root_logger +from modules import scripts +from replace_torch_aie import replace_unet_torch_aie +from config import NpuConfig + +def listen_change(choice): + if choice == 'torch_aie': + print("switch to torch_aie") + replace_unet_torch_aie() + return + +class TorchAscendIEPlugin(scripts.Script): + + def __init__(self) -> None: + super().__init__() + self.logger = get_root_logger() + self.logger.info("import TorchAscendIEPlugin") + self.logger.setLevel(logging.INFO) + + def title(self): + return "webui-npu-extension" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_txt2img): + with gr.Group(): + with gr.Accordion("npu-extensions", open = True): + npu_radio = gr.Radio(choices = ['torch_aie'], 
value = "None") + npu_radio.change(listen_change, inputs = npu_radio) + parallel_inferencing_checkbox = gr.Checkbox(label = 'Use_Parallel_Inferencing', info="Do you want to use parallel inferencing?") + parallel_inferencing_checkbox.change(self.listen_parallel_status, inputs = parallel_inferencing_checkbox) + + def listen_parallel_status(self, status): + if status: + self.logger.info("Start to use parallel inferencing") + NpuConfig.use_parallel_inferencing = True + else: + self.logger.info("Stop using parallel inferencing") + NpuConfig.use_parallel_inferencing = False \ No newline at end of file diff --git a/MindIE/MultiModal/SD-WebUI/torch_aie_extension/sd_webui_patch.py b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/sd_webui_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..eb9860b8c72888bfe4d539415ec08aae8696d393 --- /dev/null +++ b/MindIE/MultiModal/SD-WebUI/torch_aie_extension/sd_webui_patch.py @@ -0,0 +1,38 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import transformers +import diffusers + + +def main(): + transformers_path = transformers.__path__ + transformers_version = transformers.__version__ + + assert transformers_version is not "4.26.1", "expectation transformers==4.26.1" + os.system( + f"patch -p0 {transformers_path[0]}/models/clip/modeling_clip.py clip.patch" + ) + + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + assert diffusers_version is not "0.14.0", "expectation diffusers==0.14.0" + os.system( + f"patch -p0 {diffusers_path[0]}/models/cross_attention.py cross_attention.patch" + ) + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/README.md b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25bc51996be7706a2c87931cf5292b3c7d161289 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/README.md @@ -0,0 +1,186 @@ +# stable-audio-open-1.0模型-diffusers方式推理指导 + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + +# 概述 + + [此处获得](https://huggingface.co/stabilityai/stable-audio-open-1.0) + +- 参考实现: + ```bash + # StableAudioOpen1.0 + https://huggingface.co/stabilityai/stable-audio-open-1.0 + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 +## 获取源码 +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + apt-get update + apt-get install libsndfile1 + ``` + +2. 
安装mindie包 + + ```bash + # 安装mindie + source /usr/local/Ascend/ascend-toolkit/set_env.sh + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + +- 执行命令: + ```bash + python3 diffusers_aie_patch.py + python3 brownian_interval_patch.py + ``` + +4. MindieTorch配套Torch_NPU使用 + + MindieTorch采用dlopen的方式动态加载Torch_NPU,需要手动编译libtorch_npu_bridge.so,并将其放在libtorch_aie.so同一路径下,或者将其路径设置到LD_LIBRARY_PATH环境变量中,具体参考: + ```bash + https://www.hiascend.com/document/detail/zh/mindie/10RC2/mindietorch/Torchdev/mindie_torch0017.html + ``` + +## 模型推理 + +1. 模型转换。 + + 1. 提前下载权重,放到代码同级目录下。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # 下载stable-audio-open-1.0权重 + git clone https://huggingface.co/stabilityai/stable-audio-open-1.0 + ``` + + 2. 导出pt模型并进行编译。 + + (1) 设置模型权重的路径 + ```bash + # stable-audio-open-1.0 (执行时下载权重) + model_base="stabilityai/stable-audio-open-1.0" + + # stable-audio-open-1.0 (使用上一步下载的权重) + model_base="./stable-audio-open-1.0" + ``` + + (2) 执行命令查看芯片名称($\{chip\_name\})。 + + ``` + npu-smi info + ``` + + (3) 执行export命令 + + ```bash + python3 export_ts.py --model ${model_base} --output_dir ./models --soc Ascend${chip_name} --device 0 + ``` + + 参数说明: + - --model:模型权重路径 + - --output_dir: 存放导出模型的路径 + - --soc:处理器型号。 + - --device:推理设备ID + + 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 + +2. 开始推理验证。 + + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 执行推理脚本。 + ```bash + numactl -C 0-23 python3 stable_audio_open_aie_pipeline.py \ + --model ${model_base} \ + --output_dir ./models \ + --prompt_file ./prompts.txt \ + --num_inference_steps 100 \ + --audio_end_in_s 10 10 47 \ + --num_waveforms_per_prompt 1 \ + --guidance_scale 7 \ + --save_dir ./results \ + --device 0 + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --num_inference_steps: 语音生成迭代次数。 + - --audio_end_in_s:生成语音的时长,如不输入则默认生成10s。 + - --num_waveforms_per_prompt:一个提示词生成的语音数量。 + - --guidance_scale:音频生成质量与准确度系数。 + - --save_dir:生成语音的存放目录。 + - --device:推理设备ID。 + + 执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + + + +# 模型推理性能&精度 +性能参考下列数据。 + +### Stable-Audio-Open-1.0 + +| 硬件形态 | 迭代次数 | 平均耗时| +| :------: |:----:|:----:| +| Atlas 800I A2(8*32G) | 100 | 5.895s | \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/attention_processor.patch b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..66a6016d8d4f4a9e113402cded7a14d2f58f1a6a --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/attention_processor.patch @@ -0,0 +1,36 @@ +--- attention_processor.py 2024-09-26 09:41:41.531223700 +0800 ++++ attention_processor_patch.py 2024-10-16 15:56:30.103608300 +0800 +@@ -2312,8 +2312,12 @@ + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. 
+ heads_per_kv_head = attn.heads // kv_heads +- key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) +- value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) ++ key = key.unsqueeze(2) ++ value = value.unsqueeze(2) ++ key = key.repeat(1, 1, heads_per_kv_head, 1, 1) ++ value = value.repeat(1, 1, heads_per_kv_head, 1, 1) ++ key = key.view(batch_size, attn.heads, sequence_length, head_dim) ++ value = value.view(batch_size, attn.heads, sequence_length, head_dim) + + if attn.norm_q is not None: + query = attn.norm_q(query) +@@ -2344,9 +2348,15 @@ + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 +- hidden_states = F.scaled_dot_product_attention( +- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False +- ) ++ import mindietorch ++ scale_factor = query.size(-1)**-0.5 ++ hidden_states = torch.ops.aie.flash_attention(query,key,value, ++ num_head=attn.heads, ++ attn_mask=attention_mask, ++ pse=None, ++ scale=scale_factor, ++ layout="BNSD", ++ type="PFA") + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/brownian_interval.patch b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/brownian_interval.patch new file mode 100644 index 0000000000000000000000000000000000000000..168e1e26edf6d8aacf47565cc5276735949a8d0f --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/brownian_interval.patch @@ -0,0 +1,11 @@ +--- brownian_interval.py 2024-09-06 18:31:27.596848500 +0800 ++++ brownian_interval_patch.py 2024-09-06 18:39:44.069857000 +0800 +@@ -226,7 +226,7 @@ + else: + # Don't compute space-time Levy area unless we need to + +- mean = left_diff * W * h_reciprocal ++ mean = left_diff * h_reciprocal * W + var = left_diff * right_diff * h_reciprocal + noise = parent._randn(parent._W_seed) + left_W = mean + math.sqrt(var) * noise diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/brownian_interval_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/brownian_interval_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..da4ec84a8738b4aa8c8a2b7fb139dc457374a122 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/brownian_interval_patch.py @@ -0,0 +1,14 @@ +import os +import torchsde + + +def main(): + torchsde_path = torchsde.__path__ + torchsde_version = torchsde.__version__ + + assert torchsde_version is not '0.2.6', "expectation torchsde_version==0.2.6" + os.system(f'patch -p0 {torchsde_path[0]}/_brownian/brownian_interval.py brownian_interval.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/diffusers_aie_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/diffusers_aie_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..659eaca188a0903afef9f7b52535a907392c24ae --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/diffusers_aie_patch.py @@ -0,0 +1,16 @@ +import os +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version is not '0.30.0', "expectation diffusers==0.30.0" + os.system(f'patch -p0 {diffusers_path[0]}/models/embeddings.py embeddings.patch') + os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch') + 
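+    # NOTE (added commentary): the two patches above adapt diffusers for
+    # MindIE Torch export -- embeddings.patch replaces the tensor unbind in
+    # the Stable Audio rotary embedding with plain indexing, and
+    # attention_processor.patch rewrites repeat_interleave as
+    # unsqueeze/repeat/view and routes attention through
+    # torch.ops.aie.flash_attention. The two patches below make the DiT
+    # return a bare tensor (return_dict=False) and let the pipeline call
+    # transformer.forward() without indexing a tuple output.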
os.system(f'patch -p0 {diffusers_path[0]}/models/transformers/stable_audio_transformer.py stable_audio_transformer.patch') + os.system(f'patch -p0 {diffusers_path[0]}/pipelines/stable_audio/pipeline_stable_audio.py pipeline_stable_audio.patch') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/embeddings.patch b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/embeddings.patch new file mode 100644 index 0000000000000000000000000000000000000000..ffc5390f2f13fe82ef3a0a0634079f82be58213b --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/embeddings.patch @@ -0,0 +1,13 @@ +--- embeddings.py 2024-09-27 17:00:05.872952500 +0800 ++++ embeddings_patch.py 2024-09-27 17:03:21.806110200 +0800 +@@ -524,7 +524,9 @@ + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Use for example in Stable Audio +- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] ++ x_ = x.reshape(*x.shape[:-1], 2, -1) ++ x_real = x_[...,0,:] ++ x_imag = x_[...,1,:] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/export_ts.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/export_ts.py new file mode 100644 index 0000000000000000000000000000000000000000..7401c69016d2786dd05eff27c1f04292bb45bd44 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/export_ts.py @@ -0,0 +1,88 @@ +import os +import torch +import mindietorch +from mindietorch import _enums +import argparse +from argparse import Namespace +from diffusers.models.transformers.stable_audio_transformer import StableAudioDiTModel + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt&ts models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-audio-open-1.0", + help="The path of the pretrained stable-audio-open-1.0.", + ) + parser.add_argument( + "--soc", + help="soc_version.", + ) + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device.", + ) + return parser.parse_args() + +def export(args) -> None: + print("Exporting the dit...") + audio_dit = StableAudioDiTModel.from_pretrained(args.model+"/transformer").to("cpu") + audio_dit.to(torch.float32).eval() + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir, mode=0o640) + + #trace + dit_pt_path = os.path.join(args.output_dir, f"dit.pt") + if not os.path.exists(dit_pt_path): + prompt = torch.randn([2, 64, 1024], dtype=torch.float32) + t = torch.ones([1], dtype=torch.float32) + encoder_hidden_states = torch.randn([2, 130, 768], dtype=torch.float32) + global_hidden_states = torch.randn([2, 1, 1536], dtype=torch.float32) + rotary_embedding = torch.randn([2, 1025, 32], dtype=torch.float32) + dummy_input = (prompt, t, encoder_hidden_states, global_hidden_states, rotary_embedding) + torch.jit.trace(audio_dit, dummy_input).save(dit_pt_path) + + #compile + dit_compiled_path = os.path.join(args.output_dir, f"dit_compile.ts") + if not os.path.exists(dit_compiled_path): + dit = torch.jit.load(dit_pt_path).eval() + compiled_dit = ( + mindietorch.compile(dit, + inputs=[mindietorch.Input((2, 64, 1024), + dtype=mindietorch.dtype.FLOAT16), + mindietorch.Input((1,), + 
dtype=mindietorch.dtype.FLOAT16), + mindietorch.Input((2, 130, 768), + dtype=mindietorch.dtype.FLOAT16), + mindietorch.Input((2, 1, 1536), + dtype=mindietorch.dtype.FLOAT16), + mindietorch.Input((2, 1025, 32), + dtype=mindietorch.dtype.FLOAT16)], + allow_tensor_replace_int=False, + require_full_compilation=False, + truncate_long_and_double=False, + soc_version=args.soc, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_dit, dit_compiled_path) + +def main(args): + mindietorch.set_device(args.device) + export(args) + print("Done.") + mindietorch.finalize() + +if __name__ == "__main__": + args = parse_arguments() + main(args) \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/pipeline_stable_audio.patch b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/pipeline_stable_audio.patch new file mode 100644 index 0000000000000000000000000000000000000000..e82f24496ab73ac8038a32a3fa70785d0a195bc8 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/pipeline_stable_audio.patch @@ -0,0 +1,19 @@ +--- pipeline_stable_audio.py 2024-09-24 21:57:27.340788500 +0800 ++++ pipeline_stable_audio_patch.py 2024-09-24 22:01:08.883637400 +0800 +@@ -702,14 +702,14 @@ + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual +- noise_pred = self.transformer( ++ noise_pred = self.transformer.forward( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, +- )[0] ++ ) + + # perform guidance + if do_classifier_free_guidance: diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_brownian_interval.patch b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_brownian_interval.patch new file mode 100644 index 0000000000000000000000000000000000000000..d9d94e58016f1ae0d0ca7a347759abe2d25907e7 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_brownian_interval.patch @@ -0,0 +1,11 @@ +--- brownian_interval.py 2024-09-10 20:28:33.814433800 +0800 ++++ brownian_interval_patch.py 2024-09-10 20:59:44.334826200 +0800 +@@ -28,8 +28,8 @@ + + + def _randn(size, dtype, device, seed): +- generator = torch.Generator(device).manual_seed(int(seed)) +- return torch.randn(size, dtype=dtype, device=device, generator=generator) ++ torch.manual_seed(int(seed)) ++ return torch.randn(size, dtype=dtype, device="cpu").to(device) + diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_brownian_interval_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_brownian_interval_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb3b09711929a24c841cfd9591c794c52fcd288 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_brownian_interval_patch.py @@ -0,0 +1,14 @@ +import os +import torchsde + + +def main(): + torchsde_path = torchsde.__path__ + torchsde_version = torchsde.__version__ + + assert torchsde_version is not '0.2.6', "expectation torchsde_version==0.2.6" + os.system(f'patch -p0 {torchsde_path[0]}/_brownian/brownian_interval.py precision_brownian_interval.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_scheduling_cosine_dpmsolver_multistep.patch 
b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_scheduling_cosine_dpmsolver_multistep.patch new file mode 100644 index 0000000000000000000000000000000000000000..789545a1caeebd49344989e12cb48a7d4c9d40bc --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_scheduling_cosine_dpmsolver_multistep.patch @@ -0,0 +1,11 @@ +--- scheduling_cosine_dpmsolver_multistep.py 2024-09-12 09:03:22.789896000 +0800 ++++ scheduling_cosine_dpmsolver_multistep_patch.py 2024-09-12 09:04:46.648920700 +0800 +@@ -512,7 +512,7 @@ + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( +- model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed ++ model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=0 + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_scheduling_cosine_dpmsolver_multistep_pacth.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_scheduling_cosine_dpmsolver_multistep_pacth.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d9aedcd9749179ff36a45ff67391ffaf5d1403 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/precision_scheduling_cosine_dpmsolver_multistep_pacth.py @@ -0,0 +1,14 @@ +import os +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version is not '0.30.0', "expectation diffusers==0.30.0" + os.system(f'patch -p0 {diffusers_path[0]}/schedulers/scheduling_cosine_dpmsolver_multistep.py precision_scheduling_cosine_dpmsolver_multistep.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/prompts.txt b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..e1c7734ef9c418f15b6c67c338c81f2cb39b1e7e --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/prompts.txt @@ -0,0 +1,3 @@ +Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP. +Uplifting acoustic loop. 120 BPM. +Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM. 
\ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/requirements.txt b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6353806d3f71e2b739caf8df040a9e852530ffc2 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/requirements.txt @@ -0,0 +1,6 @@ +torch==2.1.0 +torchsde==0.2.6 +diffusers==0.30.0 +transformers==4.40.0 +soundfile==0.12.1 +torch_npu==2.1.0.post6 \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_open_aie_pipeline.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_open_aie_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..acc69b700acaf78dcab1477a7b518e501965136e --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_open_aie_pipeline.py @@ -0,0 +1,183 @@ +import torch +import torch_npu +import mindietorch +import sys +import time +import json +import os +import argparse +import copy +import soundfile as sf +from safetensors.torch import load_file +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from diffusers import StableAudioPipeline +from transformers import T5TokenizerFast +from transformers import T5EncoderModel +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.models.transformers.stable_audio_transformer import StableAudioDiTModel +from diffusers.schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler +from typing import Optional + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="The prompts file to guide audio generation.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt or prompts to guide what to not include in audio generation.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="The number of denoising steps. 
More denoising steps usually lead to a higher quality audio at the expense of slower inference.", + ) + parser.add_argument( + "--model", + type=str, + default="./stable-audio-open-1.0", + help="The path of stable-audio-open-1.0.", + ) + parser.add_argument( + "--audio_end_in_s", + nargs='+', + default=[10], + help="Audio end index in seconds.", + ) + parser.add_argument( + "--num_waveforms_per_prompt", + type=int, + default=1, + help="The number of waveforms to generate per prompt.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7, + help="A higher guidance scale value encourages the model to generate audio that is closely linked to the text `prompt` at the expense of lower sound quality.", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device id.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result audio files.", + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt&ts models.", + ) + return parser.parse_args() + +class AieDiTModel(): + def __init__( + self, + config:None, + args:None, + ): + super().__init__() + self.config = config + dit_compiled_path = os.path.join(args.output_dir+"/dit_compile.ts") + if os.path.exists(dit_compiled_path): + self.compiled_dit = torch.jit.load(dit_compiled_path).eval() + else: + print("%s have no dit_compile.ts, please run export_ts.py first, program is exiting..."%(args.output_dir)) + sys.exit() + + def forward( + self, + hidden_states: torch.FloatTensor, + timestep: torch.FloatTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: Optional[bool] = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + ): + rotary_embedding = torch.stack([torch.tensor(re,dtype=torch.float16) for re in rotary_embedding]).to("npu") + timestep = torch.tensor(timestep, dtype=torch.float16) + output = self.compiled_dit(hidden_states, timestep, encoder_hidden_states, global_hidden_states,rotary_embedding).to(torch.float16) + return output + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch_npu.npu.set_device(args.device) + torch.manual_seed(1) + latents = torch.randn(1, 64, 1024,dtype=torch.float16,device="cpu").to("npu") + with open(args.model + "/vae/config.json", "r", encoding="utf-8") as reader: + data = reader.read() + json_data = json.loads(data) + init_dict = {key: json_data[key] for key in json_data} + vae = AutoencoderOobleck(**init_dict) + vae.load_state_dict(load_file(args.model + "/vae/diffusion_pytorch_model.safetensors"), strict=False) + + tokenizer = T5TokenizerFast.from_pretrained(args.model + "/tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.model + "/text_encoder") + projection_model = StableAudioProjectionModel.from_pretrained(args.model + "/projection_model") + audio_dit0 = StableAudioDiTModel.from_pretrained(args.model + "/transformer") + config = copy.deepcopy(audio_dit0.config) + del audio_dit0 + audio_dit = AieDiTModel(config,args) + scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.model + "/scheduler") + + npu_stream = torch_npu.npu.Stream() + vae = vae.to("npu").to(torch.float16).eval() + text_encoder = text_encoder.to("npu").to(torch.float16).eval() + 
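+    # NOTE (added commentary): only the VAE, text encoder and projection
+    # model run as fp16 torch_npu modules; the DiT itself is served by
+    # AieDiTModel, which loads the precompiled dit_compile.ts exported by
+    # export_ts.py and casts the timestep and rotary embedding to float16
+    # before each call.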
projection_model = projection_model.to("npu").to(torch.float16).eval() + + pipe = StableAudioPipeline(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, + projection_model=projection_model, transformer=audio_dit, scheduler=scheduler) + + total_time = 0 + prompts_num = 0 + average_time = 0 + skip = 2 + with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f: + for i, prompt in enumerate(f): + with torch.no_grad(): + npu_stream.synchronize() + audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0 + begin = time.time() + audio = pipe( + prompt=prompt.strip(), + negative_prompt=args.negative_prompt, + num_inference_steps=args.num_inference_steps, + latents=latents, + audio_end_in_s=audio_end_in_s, + num_waveforms_per_prompt=args.num_waveforms_per_prompt, + ).audios + npu_stream.synchronize() + end = time.time() + if i > skip-1: + total_time += end - begin + prompts_num = i+1 + output = audio[0].T.float().cpu().numpy() + sf.write(args.save_dir+"/audio_by_prompt"+str(prompts_num)+".wav", output, pipe.vae.sampling_rate) + if prompts_num>skip: + average_time = total_time/(prompts_num-skip) + else: + print("Infer average time skip first two prompts, make sure prompts.txt has three more prompts") + print(f"Infer average time: {average_time:.3f}s\n") + mindietorch.finalize() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_open_pipeline.py b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_open_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..434860e4ae3c8391902b10a0228508293298c2a5 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_open_pipeline.py @@ -0,0 +1,139 @@ +import torch +import torch_npu +import sys +import time +import json +import os +import argparse +import soundfile as sf +from safetensors.torch import load_file +from diffusers.models.autoencoders.autoencoder_oobleck import AutoencoderOobleck +from diffusers import StableAudioPipeline +from transformers import T5TokenizerFast +from transformers import T5EncoderModel +from diffusers.pipelines.stable_audio.modeling_stable_audio import StableAudioProjectionModel +from diffusers.models.transformers.stable_audio_transformer import StableAudioDiTModel +from diffusers.schedulers.scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="The prompts file to guide audio generation.", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="", + help="The prompt or prompts to guide what to not include in audio generation.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="The number of denoising steps. 
More denoising steps usually lead to a higher quality audio at the expense of slower inference.", + ) + parser.add_argument( + "--stable_audio_open_dir", + type=str, + default="./stable-audio-open-1.0", + help="The path of stable-audio-open-1.0.", + ) + parser.add_argument( + "--audio_end_in_s", + nargs='+', + default=[10], + help="Audio end index in seconds.", + ) + parser.add_argument( + "--num_waveforms_per_prompt", + type=int, + default=1, + help="The number of waveforms to generate per prompt.", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7, + help="A higher guidance scale value encourages the model to generate audio that is closely linked to the text `prompt` at the expense of lower sound quality.", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device id.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result audio files.", + ) + return parser.parse_args() + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch_npu.npu.set_device(args.device) + torch.manual_seed(1) + latents = torch.randn(1, 64, 1024,dtype=torch.float16,device="cpu") + with open(args.stable_audio_open_dir + "/vae/config.json", "r", encoding="utf-8") as reader: + data = reader.read() + json_data = json.loads(data) + init_dict = {key: json_data[key] for key in json_data} + vae = AutoencoderOobleck(**init_dict) + vae.load_state_dict(load_file(args.stable_audio_open_dir + "/vae/diffusion_pytorch_model.safetensors"), strict=False) + + tokenizer = T5TokenizerFast.from_pretrained(args.stable_audio_open_dir + "/tokenizer") + text_encoder = T5EncoderModel.from_pretrained(args.stable_audio_open_dir + "/text_encoder") + projection_model = StableAudioProjectionModel.from_pretrained(args.stable_audio_open_dir + "/projection_model") + audio_dit = StableAudioDiTModel.from_pretrained(args.stable_audio_open_dir + "/transformer") + scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(args.stable_audio_open_dir + "/scheduler") + + npu_stream = torch_npu.npu.Stream() + vae = vae.to("npu").to(torch.float16).eval() + text_encoder = text_encoder.to("npu").to(torch.float16).eval() + projection_model = projection_model.to("npu").to(torch.float16).eval() + audio_dit = audio_dit.to("npu").to(torch.float16).eval() + + pipe = StableAudioPipeline(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, + projection_model=projection_model, transformer=audio_dit, scheduler=scheduler) + pipe.to("npu") + total_time = 0 + prompts_num = 0 + average_time = 0 + skip = 2 + with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f: + for i, prompt in enumerate(f): + with torch.no_grad(): + npu_stream.synchronize() + audio_end_in_s = float(args.audio_end_in_s[i]) if (len(args.audio_end_in_s) > i) else 10.0 + begin = time.time() + audio = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + num_inference_steps=args.num_inference_steps, + latents=latents.to("npu"), + audio_end_in_s=audio_end_in_s, + num_waveforms_per_prompt=args.num_waveforms_per_prompt, + ).audios + npu_stream.synchronize() + end = time.time() + if i > skip-1: + total_time += end - begin + prompts_num = i+1 + output = audio[0].T.float().cpu().numpy() + sf.write(args.save_dir+"/audio_by_prompt"+str(prompts_num)+".wav", output, pipe.vae.sampling_rate) + if prompts_num>skip: + average_time = total_time/(prompts_num-skip) + else: + print("Infer average time skip first two 
prompts, make sure prompts.txt has three more prompts") + print(f"Infer average time: {average_time:.3f}s\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_transformer.patch b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_transformer.patch new file mode 100644 index 0000000000000000000000000000000000000000..dbe5e44af7934ebbc52d541c60e78148fdef9350 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/diffusers/stable_audio_transformer.patch @@ -0,0 +1,19 @@ +--- stable_audio_transformer.py 2024-09-24 22:13:12.323307700 +0800 ++++ stable_audio_transformer_patch.py 2024-09-24 22:15:08.051444400 +0800 +@@ -356,7 +356,7 @@ + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, +- return_dict: bool = True, ++ return_dict: bool = False, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: +@@ -453,6 +453,6 @@ + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: +- return (hidden_states,) ++ return hidden_states + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/README.md b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6a265d10996f881ac0001f61b4b07b318420f6e7 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/README.md @@ -0,0 +1,161 @@ +# stable-audio-open-1.0模型-stable-audio-tools方式推理指导 + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + +# 概述 + + [此处获得](https://huggingface.co/stabilityai/stable-audio-open-1.0) + +- 参考实现: + ```bash + # StableAudioOpen1.0 + https://huggingface.co/stabilityai/stable-audio-open-1.0 + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 +## 获取源码 +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + apt-get update + apt-get install libsndfile1 + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + source /usr/local/Ascend/ascend-toolkit/set_env.sh + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + ```bash + python3 brownian_interval_patch.py + python3 conditioners_patch.py + python3 pretrained_patch.py + python3 transformer_patch.py + ``` + +## 模型推理 + +1. 模型准备。 + 1. 获取模型权重 + + 可提前下载权重,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # 下载stable-audio-open-1.0权重 + git clone https://huggingface.co/stabilityai/stable-audio-open-1.0 + ``` + + 2. 设置模型权重的路径。 + ```bash + # stable-audio-open-1.0 (执行时下载权重) + model_base="stabilityai/stable-audio-open-1.0" + + # stable-audio-open-1.0 (使用上一步下载的权重) + model_base="./stable-audio-open-1.0" + ``` + + 3. 获取T5模型权重 + + 推理过程中会自动从huggingface下载T5-base的模型权重,若希望以加载本地T5-base模型权重方式进行推理,请将`model_base`路径下的`tokenizer`和`text_encoder`文件夹复制到推理代码的执行路径中。 + + +2. 
开始推理验证。 + + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 执行推理脚本。 + ```bash + numactl -C 0-23 python3 stable_audio_open_tools_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --num_inference_steps 100 \ + --seconds_total 10 10 47 \ + --save_dir ./results \ + --seed -1 \ + --device 0 + ``` + + 参数说明: + - --model:模型权重路径。 + - --prompt_file:提示词文件。 + - --num_inference_steps: 语音生成迭代次数。 + - --seconds_total:生成语音的时长,与prompts.txt中的prompt对应,如不输入则默认生成10s。 + - --save_dir:生成语音的存放目录。 + - --seed: 用于固定生成语音的随机种子,默认值-1表示使用随机种子 + - --device:推理设备ID。 + + 执行完成后在`./results`目录下生成推理语音,语音生成顺序与文本中prompt顺序保持一致,并在终端显示推理时间。 + + + +# 模型推理性能&精度 +性能参考下列数据。 + +### Stable-Audio-Open-1.0 + +| 硬件形态 | 迭代次数 | 平均耗时| +| :------: |:----:|:----:| +| Atlas 800I A2(8*32G) | 100 | 7.886s | \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/brownian_interval.patch b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/brownian_interval.patch new file mode 100644 index 0000000000000000000000000000000000000000..168e1e26edf6d8aacf47565cc5276735949a8d0f --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/brownian_interval.patch @@ -0,0 +1,11 @@ +--- brownian_interval.py 2024-09-06 18:31:27.596848500 +0800 ++++ brownian_interval_patch.py 2024-09-06 18:39:44.069857000 +0800 +@@ -226,7 +226,7 @@ + else: + # Don't compute space-time Levy area unless we need to + +- mean = left_diff * W * h_reciprocal ++ mean = left_diff * h_reciprocal * W + var = left_diff * right_diff * h_reciprocal + noise = parent._randn(parent._W_seed) + left_W = mean + math.sqrt(var) * noise diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/brownian_interval_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/brownian_interval_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..da4ec84a8738b4aa8c8a2b7fb139dc457374a122 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/brownian_interval_patch.py @@ -0,0 +1,14 @@ +import os +import torchsde + + +def main(): + torchsde_path = torchsde.__path__ + torchsde_version = torchsde.__version__ + + assert torchsde_version is not '0.2.6', "expectation torchsde_version==0.2.6" + os.system(f'patch -p0 {torchsde_path[0]}/_brownian/brownian_interval.py brownian_interval.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/conditioners.patch b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/conditioners.patch new file mode 100644 index 0000000000000000000000000000000000000000..c61a74932a4584ef1cb4dabe31f8c370811f1bf1 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/conditioners.patch @@ -0,0 +1,24 @@ +--- conditioners.py 2024-09-30 15:31:32.480360700 +0800 ++++ conditioners_patch.py 2024-09-30 18:20:43.344830200 +0800 +@@ -280,10 +280,17 @@ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: +- # self.tokenizer = 
T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) +- # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) +- self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) +- model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) ++ import os ++ tokenizer_path = os.path.join(os.getcwd(), "tokenizer") ++ text_encoder_path = os.path.join(os.getcwd(), "text_encoder") ++ if os.path.exists(tokenizer_path) and os.path.exists(text_encoder_path): ++ print("From local import T5-base . . .") ++ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) ++ model = T5EncoderModel.from_pretrained(text_encoder_path).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) ++ else: ++ print("From HuggingFace download T5-base . . .") ++ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) ++ model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + finally: + logging.disable(previous_level) + diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/conditioners_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/conditioners_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..71db741779f701c76d73093bb735523fe01200d1 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/conditioners_patch.py @@ -0,0 +1,12 @@ +import os +import stable_audio_tools + + +def main(): + stable_audio_tools_path = stable_audio_tools.__path__ + + os.system(f'patch -p0 {stable_audio_tools_path[0]}/models/conditioners.py conditioners.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_brownian_interval.patch b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_brownian_interval.patch new file mode 100644 index 0000000000000000000000000000000000000000..d9d94e58016f1ae0d0ca7a347759abe2d25907e7 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_brownian_interval.patch @@ -0,0 +1,11 @@ +--- brownian_interval.py 2024-09-10 20:28:33.814433800 +0800 ++++ brownian_interval_patch.py 2024-09-10 20:59:44.334826200 +0800 +@@ -28,8 +28,8 @@ + + + def _randn(size, dtype, device, seed): +- generator = torch.Generator(device).manual_seed(int(seed)) +- return torch.randn(size, dtype=dtype, device=device, generator=generator) ++ torch.manual_seed(int(seed)) ++ return torch.randn(size, dtype=dtype, device="cpu").to(device) + diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_brownian_interval_path.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_brownian_interval_path.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb3b09711929a24c841cfd9591c794c52fcd288 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_brownian_interval_path.py @@ -0,0 +1,14 @@ +import os +import torchsde + + +def main(): + torchsde_path = torchsde.__path__ + torchsde_version = torchsde.__version__ + + assert torchsde_version is not '0.2.6', "expectation torchsde_version==0.2.6" + os.system(f'patch -p0 {torchsde_path[0]}/_brownian/brownian_interval.py precision_brownian_interval.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git 
a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_generation.patch b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_generation.patch new file mode 100644 index 0000000000000000000000000000000000000000..b419e35190cfb8d51b19c980243057ea8e5bc8b9 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_generation.patch @@ -0,0 +1,11 @@ +--- generation.py 2024-10-10 19:37:20.875115700 +0800 ++++ generation_patch.py 2024-10-10 19:38:24.088056700 +0800 +@@ -139,7 +139,7 @@ + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed +- noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) ++ noise = torch.randn([batch_size, model.io_channels, sample_size], device="cpu").to(device) + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_generation_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_generation_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..c278ca17e2a0ac2f807c83a1f0e75ccc6dc80747 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/precision_generation_patch.py @@ -0,0 +1,12 @@ +import os +import stable_audio_tools + + +def main(): + stable_audio_tools_path = stable_audio_tools.__path__ + + os.system(f'patch -p0 {stable_audio_tools_path[0]}/inference/generation.py precision_generation.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/pretrained.patch b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/pretrained.patch new file mode 100644 index 0000000000000000000000000000000000000000..f51e6a1d90f1ff875f2fee8a1fde06a21b7f1eb3 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/pretrained.patch @@ -0,0 +1,29 @@ +--- pretrained.py 2024-09-30 15:31:40.672485200 +0800 ++++ pretrained_patch.py 2024-10-07 14:54:18.756960100 +0800 +@@ -1,4 +1,5 @@ + import json ++import os + + from .factory import create_model_from_config + from .utils import load_ckpt_state_dict +@@ -7,7 +8,7 @@ + + def get_pretrained_model(name: str): + +- model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') ++ model_config_path = os.path.join(name, "model_config.json") + + with open(model_config_path) as f: + model_config = json.load(f) +@@ -15,10 +16,7 @@ + model = create_model_from_config(model_config) + + # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file +- try: +- model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') +- except Exception as e: +- model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') ++ model_ckpt_path = os.path.join(name, "model.safetensors") + + model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/pretrained_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/pretrained_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..4abdad47b0b49a4263f674bc5d9a17768602ac66 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/pretrained_patch.py @@ -0,0 +1,12 @@ +import os +import stable_audio_tools + + +def main(): + stable_audio_tools_path = 
stable_audio_tools.__path__ + + os.system(f'patch -p0 {stable_audio_tools_path[0]}/models/pretrained.py pretrained.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/prompts.txt b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..e1c7734ef9c418f15b6c67c338c81f2cb39b1e7e --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/prompts.txt @@ -0,0 +1,3 @@ +Berlin techno, rave, drum machine, kick, ARP synthesizer, dark, moody, hypnotic, evolving, 135BPM. LOOP. +Uplifting acoustic loop. 120 BPM. +Disco, Driving Drum Machine, Synthesizer, Bass, Piano, Guitars, Instrumental, Clubby, Euphoric, Chicago, New York, 115 BPM. \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/requirements.txt b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..df8e40939ae837bb785620795a23d154ba3fa37d --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/requirements.txt @@ -0,0 +1,5 @@ +torch==2.1.0 +torchaudio==2.1.0 +stable_audio_tools==0.0.16 +transformers==4.40.0 +torch_npu==2.1.0.post6 \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/stable_audio_open_tools_pipeline.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/stable_audio_open_tools_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0a136fceb09a826cee214963dd7cc067e9b9267d --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/stable_audio_open_tools_pipeline.py @@ -0,0 +1,120 @@ +import torch +import torch_npu +import time +import os +import argparse +from safetensors.torch import load_file +import torchaudio +from einops import rearrange +from stable_audio_tools import get_pretrained_model +from stable_audio_tools.inference.generation import generate_diffusion_cond + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="The prompts file to guide audio generation.", + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=100, + help="The number of denoising steps. 
More denoising steps usually lead to a higher quality audio at the expense of slower inference.", + ) + parser.add_argument( + "--model", + type=str, + default="./stable-audio-open-1.0", + help="The path of stable-audio-open-1.0.", + ) + parser.add_argument( + "--seconds_total", + nargs='+', + default=[10], + help="Audio end index in seconds.", + ) + parser.add_argument( + "--device", + type=int, + default=0, + help="NPU device id.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result audio files.", + ) + parser.add_argument( + "--seed", + type=int, + default=-1, + help="The random seed to use for generation, or default -1 to use a random seed.", + ) + return parser.parse_args() + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch_npu.npu.set_device(args.device) + npu_stream = torch_npu.npu.Stream() + + model, model_config = get_pretrained_model(args.model) + sample_rate = model_config["sample_rate"] + sample_size = model_config["sample_size"] + + model = model.to("npu").to(torch.float16).eval() + + conditioning = [{ + "prompt":"", + "seconds_start": 0, + "seconds_total": 0, + }] + total_time = 0 + prompts_num = 0 + average_time = 0 + skip = 2 + with os.fdopen(os.open(args.prompt_file, os.O_RDONLY), "r") as f: + for i, prompt in enumerate(f): + with torch.no_grad(): + conditioning[0]["prompt"] = prompt.strip() + conditioning[0]["seconds_total"] = float(args.seconds_total[i]) if (len(args.seconds_total) > i) else 10.0 + + npu_stream.synchronize() + begin = time.time() + output = generate_diffusion_cond( + model, + steps=args.num_inference_steps, + cfg_scale=7, + conditioning=conditioning, + sample_size=sample_size, + sigma_min=0.3, + sigma_max=500, + sampler_type="dpmpp-3m-sde", + device="npu", + seed=args.seed, + ) + npu_stream.synchronize() + end = time.time() + if i > skip-1: + total_time += end - begin + prompts_num = i+1 + waveform_start = int(conditioning[0]["seconds_start"] * sample_rate) + waveform_end = int(conditioning[0]["seconds_total"] * sample_rate) + output = output[:, :, waveform_start:waveform_end] + output = rearrange(output, "b d n -> d (b n)") + output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1,1).mul(32767).to(torch.int16).cpu() + torchaudio.save(args.save_dir + "/audio_by_prompt" + str(prompts_num) + ".wav", output, sample_rate) + if prompts_num > skip: + average_time = total_time / (prompts_num-skip) + else: + print("Infer average time skip first two prompts, make sure prompts.txt has three more prompts") + print(f"Infer average time: {average_time:.3f}s\n") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/transformer.patch b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/transformer.patch new file mode 100644 index 0000000000000000000000000000000000000000..9984393d93b7362e82dc2144ec18234fa1468bfa --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/transformer.patch @@ -0,0 +1,93 @@ +--- transformer.py 2024-10-09 15:09:11.860795900 +0800 ++++ transformer_patch.py 2024-10-09 15:15:51.763587200 +0800 +@@ -500,34 +500,74 @@ + out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask) + + else: +- # Fall back to custom implementation + +- if h != kv_h: ++ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device ++ kv_heads = k.shape[1] ++ # Recommended for 
multi-query single-key-value attention by Tri Dao ++ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) ++ ++ if heads != kv_heads: + # Repeat interleave kv_heads to match q_heads +- heads_per_kv_head = h // kv_h ++ heads_per_kv_head = heads // kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + +- scale = 1. / (q.shape[-1] ** 0.5) ++ if k.ndim == 3: ++ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) ++ ++ if v.ndim == 3: ++ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + +- kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' ++ causal = self.causal if causal is None else causal + +- dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale ++ if q_len == 1 and causal: ++ causal = False + +- i, j, dtype = *dots.shape[-2:], dots.dtype ++ if mask is not None: ++ assert mask.ndim == 4 ++ mask = mask.expand(batch, heads, q_len, k_len) ++ ++ # handle kv cache - this should be bypassable in updated flash attention 2 ++ ++ if k_len > q_len and causal: ++ causal_mask = self.create_causal_mask(q_len, k_len, device = device) ++ if mask is None: ++ mask = ~causal_mask ++ else: ++ mask = mask & ~causal_mask ++ causal = False + +- mask_value = -torch.finfo(dots.dtype).max ++ # manually handle causal mask, if another mask was given + +- if final_attn_mask is not None: +- dots = dots.masked_fill(~final_attn_mask, mask_value) ++ row_is_entirely_masked = None + +- if causal: +- causal_mask = self.create_causal_mask(i, j, device = device) +- dots = dots.masked_fill(causal_mask, mask_value) ++ if mask is not None and causal: ++ causal_mask = self.create_causal_mask(q_len, k_len, device = device) ++ mask = mask & ~causal_mask ++ ++ # protect against an entire row being masked out ++ ++ row_is_entirely_masked = ~mask.any(dim = -1) ++ mask[..., 0] = mask[..., 0] | row_is_entirely_masked ++ ++ causal = False ++ ++ if mask is not None: ++ mask =~ mask ++ mask = mask.to(torch.bool) ++ import torch_npu ++ out = torch_npu.npu_prompt_flash_attention(q, k, v, ++ atten_mask=mask, ++ input_layout='BNSD', ++ scale_value=q.shape[-1]**-0.5, ++ pre_tokens=65535, ++ next_tokens=65535, ++ num_heads=h, ++ ) + +- attn = F.softmax(dots, dim=-1, dtype=torch.float32) +- attn = attn.type(dtype) ++ # for a row that is entirely masked out, should zero out the output of that row token + +- out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) ++ if row_is_entirely_masked is not None: ++ out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
+ + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') diff --git a/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/transformer_patch.py b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/transformer_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..c92517411cc03f23ae82eb27d3bab58caafc9b43 --- /dev/null +++ b/MindIE/MultiModal/StableAudioOpen-1.0/stable-audio-tools/transformer_patch.py @@ -0,0 +1,12 @@ +import os +import stable_audio_tools + + +def main(): + stable_audio_tools_path = stable_audio_tools.__path__ + + os.system(f'patch -p0 {stable_audio_tools_path[0]}/models/transformer.py transformer.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-3/README.md b/MindIE/MultiModal/StableDiffusion-3/README.md new file mode 100644 index 0000000000000000000000000000000000000000..541a02a6570e4a347108d96009c632203ec5333d --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/README.md @@ -0,0 +1,435 @@ +# stable-diffusion3模型-推理指导 + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + +# 概述 + + SD3 由一组用于潜在扩散的专家管道组成: 在第一步中,使用基础模型生成(噪声)潜伏, 然后使用专门用于最终降噪步骤的细化模型[此处获得](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/) + +- 参考实现: + ```bash + # StableDiffusion3 + https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1或2 +Atlas 300I Duo推理卡:支持的卡数为1,可双芯并行 + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | ------------------------- | ------------ | + | prompt | 1 x 77 | INT64| ND| + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ----------- | + | output1 | 1 x 3 x 1024 x 1024 | FLOAT32 | NCHW | + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ----- | ----- |-----| + | Python | 3.10.2 | - | + | torch | 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 +## 获取源码 +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + + # 若要使用hpsv2验证精度,则还需要按照以下步骤安装hpsv2 + git clone https://github.com/tgxs002/HPSv2.git + cd HPSv2 + pip3 install -e . + ``` + - 注意:当前sd3推理暂不支持mindie与torch_npu混用,请确保实际推理环境中没有安装torch_npu + +2. 安装mindie包 + + ```bash + # 安装mindie + source /usr/local/Ascend/ascend-toolkit/set_env.sh + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改(可选) +(1)若需要开启DiTCache、序列压缩等优化,需要执行以下代码修改操作: +- 若环境没有patch工具,请自行安装: + ```bash + apt update + apt install patch + ``` +- 执行命令: + ```bash + python3 attention_patch.py + python3 attention_processor_patch.py + python3 transformer_sd3_patch.py + ``` + +## 模型推理 + +1. 模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放到代码同级目录下,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # 下载sd3权重 + git clone https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers + ``` + + 1. 
导出pt模型并进行编译。 + (1) 设置模型权重的路径 + ```bash + # sd3 (执行时下载权重) + model_base="stabilityai/stable-diffusion-3-medium-diffusers" + + # sd3 (使用上一步下载的权重) + model_base="./stable-diffusion-3-medium-diffusers" + ``` + (2) 创建文件夹./models存放导出的模型 + ```bash + mkdir ./models + ``` + (3) 执行命令查看芯片名称($\{chip\_name\})。 + + ``` + npu-smi info + ``` + + (4) 执行export命令 + + ```bash + # Atlas 800I A2,非并行,未加DiTCache优化 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0 + # Atlas 800I A2,非并行。开启DiTCache优化 + python3 export_model.py --model ${model_base} --output_dir ./models --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0 --use_cache + + # Atlas 300I Duo,并行 + python3 export_model.py --model ${model_base} --output_dir ./models --parallel --batch_size 1 --soc Ascend${chip_name} --device_type Duo --device 0 + ``` + 参数说明: + - --model:模型权重路径 + - --output_dir: 存放导出模型的路径 + - --parallel: 【可选】导出适用于并行方案的模型 + - --batch_size: 设置batch_size, 默认值为1, 当前仅支持batch_size=1的场景 + - --soc:处理器型号。 + - --device_type: 设备形态,当前支持A2、Duo两种形态。 + - --device:推理设备ID + - --use_cache:开启DiTCache优化,不配置则不开启 + 注意:trace+compile耗时较长且占用较多的CPU资源,请勿在执行export命令时运行其他占用CPU内存的任务,避免程序意外退出。 + +2. 开始推理验证。 + + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 执行推理脚本。 + ```bash + # 不使用DiTCache,单卡推理,适用Atlas 800I A2场景 + numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 不使用DiTCache,使用双卡并行推理,适用Atlas 300I Duo场景 + numactl -C 0-23 python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0,1 \ + --save_dir ./results_parallel \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用DiTCache,单卡推理,适用Atlas 800I A2场景 + numactl -C 0-23 python3 stable_diffusion3_pipeline_cache.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + - --height:生成图像高度,当前只支持1024 + - --width:生成图像宽度,当前只支持1024 + - --use_cache:开启DiTCache优化,不配置则不开启 + + 非并行策略,执行完成后在`./results`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 并行策略,同时使用双卡并行策略,执行完成后在`./results_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + 注意:当前MindIE-Torch和torch_npu的synchronizing stream不兼容,为避免出错,建议在运行推理前先卸载torch_npu。 + +## 精度验证 + + 由于生成的图片存在随机性,提供两种精度验证方法: + 1. 
CLIP-score(文图匹配度量):评估图片和输入文本的相关性,分数的取值范围为[-1, 1],越高越好。使用Parti数据集进行验证。 + 2. HPSv2(图片美学度量):评估生成图片的人类偏好评分,分数的取值范围为[0, 1],越高越好。使用HPSv2数据集进行验证 + + 注意,由于要生成的图片数量较多,进行完整的精度验证需要耗费很长的时间。 + + 1. 下载Parti数据集和hpsv2数据集 + + ```bash + # 下载Parti数据集 + wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate + ``` + hpsv2数据集下载链接:https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_benchmark_prompts.json + + 2. 下载模型权重 + + ```bash + # Clip Score和HPSv2均需要使用的权重 + # 安装git-lfs + apt install git-lfs + git lfs install + + # Clip Score权重 + git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K + + # HPSv2权重 + wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate + ``` + 也可手动下载[权重](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/open_clip_pytorch_model.bin) + 将权重放到`CLIP-ViT-H-14-laion2B-s32B-b79K`目录下,手动下载[HPSv2权重](https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt)放到当前路径 + + 3. 使用推理脚本读取Parti数据集,生成图片 + + ```bash + # 不使用并行 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用DitCache + python3 stable_diffusion3_pipeline_cache.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache + + # 使用双卡并行策略 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0,1 \ + --save_dir ./results_PartiPrompts_parallel \ + --steps 28 \ + --output_dir ./models \ + --use_cache \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --num_images_per_prompt: 每个prompt生成的图片数量。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --max_num_prompts:限制prompt数量为前X个,0表示不限制。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + + 不使用并行策略,执行完成后在`./results_PartiPrompts`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 使用双卡并行策略,执行完成后在`./results_PartiPrompts_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + +4. 
使用推理脚本读取hpsv2数据集,生成图片 + + ```bash + # 不使用并行 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file_type hpsv2 \ + --num_images_per_prompt 1 \ + --info_file_save_path ./image_info_hpsv2.json \ + --device 0 \ + --save_dir ./results_hpsv2 \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用DitCache + python3 stable_diffusion3_pipeline_cache.py \ + --model ${model_base} \ + --prompt_file_type hpsv2 \ + --num_images_per_prompt 1 \ + --info_file_save_path ./image_info_hpsv2.json \ + --device 0 \ + --save_dir ./results_hpsv2 \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache + + # 使用双卡并行策略 + python3 stable_diffusion3_pipeline.py \ + --model ${model_base} \ + --prompt_file_type hpsv2 \ + --num_images_per_prompt 1 \ + --info_file_save_path ./image_info_hpsv2.json \ + --device 0,1 \ + --save_dir ./results_hpsv2_parallel \ + --steps 28 \ + --output_dir ./models \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + ``` + 参数说明: + - --info_file_save_path:生成图片信息的json文件路径。 + + 不使用并行策略,执行完成后在`./results_hpsv2`目录下生成推理图片,在当前目录生成一个`image_info_hpsv2.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 使用双卡并行策略,执行完成后在`./results_hpsv2_parallel`目录下生成推理图片,在当前目录生成一个`image_info_hpsv2.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + +5. 计算精度指标 + 1. CLIP-score + ```bash + python3 clip_score.py \ + --device=cpu \ + --image_info="image_info.json" \ + --model_name="ViT-H-14" \ + --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --device: 推理设备,默认为"cpu",如果是cuda设备可设置为"cuda"。 + - --image_info: 上一步生成的`image_info.json`文件。 + - --model_name: Clip模型名称。 + - --model_weights_path: Clip模型权重文件路径。 + + clip_score.py脚本可参考[SDXL](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/clip_score.py),执行完成后会在屏幕打印出精度计算结果。 + + 2. HPSv2 + ```bash + python3 hpsv2_score.py \ + --image_info="image_info_hpsv2.json" \ + --HPSv2_checkpoint="./HPS_v2_compressed.pt" \ + --clip_checkpoint="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --image_info: 上一步生成的`image_info_hpsv2.json`文件。 + - --HPSv2_checkpoint: HPSv2模型权重文件路径。 + - --clip_checkpointh: Clip模型权重文件路径。 + + hpsv2_score.py脚本可参考[SDXL](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/MindIE-Torch/built-in/foundation/stable_diffusion_xl/hpsv2_score.py),执行完成后会在屏幕打印出精度计算结果。 + +# 模型推理性能&精度 + +调用ACL接口推理计算,性能参考下列数据。 + +### StableDiffusion3 +| 硬件形态 | cpu规格 | batch size | 迭代次数 | 优化手段 | 平均耗时 | 精度 | +| :------: | :------: | :------: |:----:| :------: |:-----:|:----------------:| +| Atlas 800I A2(8*32G) | 64核(arm) | 1 | 28 | w/o UnetCache | 6.15s | clip score 0.380 | + +性能测试需要独占npu和cpu \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-3/attention.patch b/MindIE/MultiModal/StableDiffusion-3/attention.patch new file mode 100644 index 0000000000000000000000000000000000000000..ce183bd9141104a16b4e2e14173b1782f3b10bcb --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/attention.patch @@ -0,0 +1,19 @@ +--- attention.py 2024-09-04 09:22:31.768000000 +0000 ++++ attention.py 2024-09-04 09:17:12.680000000 +0000 +@@ -100,7 +100,7 @@ + processing of `context` conditions. 
+ """ + +- def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): ++ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, layer_idx=0): + super().__init__() + + self.context_pre_only = context_pre_only +@@ -134,6 +134,7 @@ + context_pre_only=context_pre_only, + bias=True, + processor=processor, ++ layer_idx=layer_idx + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-3/attention_patch.py b/MindIE/MultiModal/StableDiffusion-3/attention_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..2d289cb99ed59c8ea3ed16db71c9333b619a44a6 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/attention_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import logging +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention.py", + "attention.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion-3/attention_processor.patch b/MindIE/MultiModal/StableDiffusion-3/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..864158fdaadaf4f6021c3c2c72ed96140b5cc75f --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/attention_processor.patch @@ -0,0 +1,86 @@ +--- attention_processor.py 2024-09-04 09:22:16.048000000 +0000 ++++ attention_processor.py 2024-09-10 06:38:37.076000000 +0000 +@@ -115,6 +115,7 @@ + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, ++ layer_idx: int = 0, + out_dim: int = None, + context_pre_only=None, + ): +@@ -132,13 +133,14 @@ + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only ++ self.layer_idx = layer_idx + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk +- self.scale = dim_head**-0.5 if self.scale_qk else 1.0 ++ self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation +@@ -561,6 +563,7 @@ + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ++ layer_idx=self.layer_idx, + **cross_attention_kwargs, + ) + +@@ -1095,6 +1098,7 @@ + hidden_states: torch.FloatTensor, + 
encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, ++ layer_idx=0, + *args, + **kwargs, + ) -> torch.FloatTensor: +@@ -1112,7 +1116,19 @@ + batch_size = encoder_hidden_states.shape[0] + + # `sample` projections. ++ # use todo + query = attn.to_q(hidden_states) ++ if layer_idx <= 11: ++ hidden_dim = hidden_states.shape[-1] ++ cur_h = int(math.sqrt(hidden_states.shape[1])) ++ cur_w = cur_h ++ hidden_states = hidden_states.transpose(1, 2).view(batch_size, hidden_dim, cur_h, cur_w) ++ new_h = int(cur_h / 2.2) ++ new_w = int(cur_w / 2.2) ++ item = F.interpolate(hidden_states, size=(new_h, new_w), ++ mode='bilinear') ++ item = item.permute(0, 2, 3, 1) ++ hidden_states = item.reshape(batch_size, new_h * new_w, -1) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + +@@ -1128,14 +1144,16 @@ + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads +- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ++ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ++ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ++ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ++ scale_factor = 1 / math.sqrt(query.size(-1)) ++ ++ import mindietorch ++ hidden_states = torch.ops.aie.flash_attention(query, key, value, attn.heads, attention_mask, None, scale_factor, ++ "BNSD", "PFA").transpose(1, 2) + +- hidden_states = hidden_states = F.scaled_dot_product_attention( +- query, key, value, dropout_p=0.0, is_causal=False +- ) +- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ++ hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Split the attention outputs. diff --git a/MindIE/MultiModal/StableDiffusion-3/attention_processor_patch.py b/MindIE/MultiModal/StableDiffusion-3/attention_processor_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..97585e6af71fd0df9f9af7ba99eff923e3352b15 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/attention_processor_patch.py @@ -0,0 +1,32 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
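+#
+# attention_processor_patch.py applies attention_processor.patch to the locally
+# installed diffusers package (diffusers==0.29.0 is required, see the assert below)
+# so that the SD3 attention processor uses the MindIE flash-attention kernel and
+# the sequence-compression optimization described in the README. Run it once,
+# before exporting the model:
+#     python3 attention_processor_patch.py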
+ +import subprocess +import logging +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/attention_processor.py", + "attention_processor.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion-3/background_runtime.py b/MindIE/MultiModal/StableDiffusion-3/background_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..482d82309cb1dac6e75a4aa8839a1af6b46d4784 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/background_runtime.py @@ -0,0 +1,191 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +NUM_LAYERS = 28 + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfo + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send('') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return 
self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: str, + ) -> None: + # The sub process function + # Create a runtime + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model = torch.jit.load(model_path).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + mindietorch.set_device(device_id) + + # Tell the main function that we are ready + sync_pipe.send('') + + infer_num = 0 + preprocess_time = 0 + infer_time = 0 + forward_time = 0 + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != 'STOP': + start = time.time() + hidden_states, encoder_hidden_states, pooled_projections, timestep = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + hidden_states_npu = hidden_states.to(torch.float32).to(f"npu:{device_id}") + encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}") + pooled_projections_npu = pooled_projections.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + + preprocess_time += time.time() - start + + start2 = time.time() + with mindietorch.npu.stream(stream): + inf_start = time.time() + output_npu = model( + hidden_states_npu, + encoder_hidden_states_npu, + pooled_projections_npu, + timestep_npu + ) + stream.synchronize() + inf_end = time.time() + + output_cpu = output_npu.to('cpu') + forward_time += inf_end - inf_start + infer_time += time.time() - start2 + + for i, _ in enumerate(output_arrays): + output = output_cpu.numpy() + output_arrays[i][:] = output[:] + + infer_num += 1 + sync_pipe.send('') + + infer_num /= NUM_LAYERS + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/StableDiffusion-3/compile_model.py b/MindIE/MultiModal/StableDiffusion-3/compile_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4001de18cdc15e05587fef41ec3fde754bb9079e --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/compile_model.py @@ -0,0 +1,166 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
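+#
+# compile_model.py collects the export-time wrappers and mindietorch compilation
+# helpers for the SD3 sub-models: ClipExport, VaeExport, DiTExport and
+# DiTExportCache wrap the original modules with fixed forward signatures so they
+# can be traced, Scheduler re-implements the FlowMatchEulerDiscrete step using the
+# precomputed SCHEDULER_SIGMAS table, and common_compile() lowers a traced
+# TorchScript module to an Ascend .ts model in FP16 via mindietorch.compile().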
+ +import torch +from dataclasses import dataclass +from typing import List +import mindietorch +from mindietorch import _enums + +# Scheduler coefficient, compute coefficient manually in advance to compile scheduler npu model. For details, see: +# https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +SCHEDULER_SIGMAS = torch.tensor([1.0000, 0.9874, 0.9741, 0.9601, 0.9454, 0.9298, 0.9133, 0.8959, 0.8774, + 0.8577, 0.8367, 0.8143, 0.7904, 0.7647, 0.7371, 0.7073, 0.6751, 0.6402, + 0.6022, 0.5606, 0.5151, 0.4649, 0.4093, 0.3474, 0.2780, 0.1998, 0.1109, + 0.0089, 0.0000], dtype=torch.float32) + + +@dataclass +class CompileParam: + inputs: List[mindietorch.Input] = None + soc_version: str = "" + allow_tensor_replace_int: bool = True + require_full_compilation: bool = True + truncate_long_and_double: bool = True + min_block_size: int = 1 + + +def common_compile(model, compiled_path, compile_param): + compiled_model = ( + mindietorch.compile(model, + inputs=compile_param.inputs, + allow_tensor_replace_int=compile_param.allow_tensor_replace_int, + require_full_compilation=compile_param.require_full_compilation, + truncate_long_and_double=compile_param.truncate_long_and_double, + min_block_size=compile_param.min_block_size, + soc_version=compile_param.soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_model, compiled_path) + + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x, output_hidden_states=True, return_dict=False): + return self.clip_model(x, output_hidden_states=output_hidden_states, return_dict=return_dict) + + +def compile_clip(model, inputs, clip_compiled_path, soc_version): + clip_param = CompileParam(inputs, soc_version, True, False, False) + common_compile(model, clip_compiled_path, clip_param) + + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model, scaling_factor, shift_factor): + super().__init__() + self.vae_model = vae_model + self.scaling_factor = scaling_factor + self.shift_factor = shift_factor + + def forward(self, latents): + latents = (latents / self.scaling_factor) + self.shift_factor + image = self.vae_model.decode(latents, return_dict=False)[0] + return image + + +def compile_vae(model, inputs, vae_compiled_path, soc_version): + vae_param = CompileParam(inputs, soc_version) + common_compile(model, vae_compiled_path, vae_param) + + +class Scheduler(torch.nn.Module): + def __init__(self): + super(Scheduler, self).__init__() + self.sigmas = SCHEDULER_SIGMAS + + def forward( + self, + model_output: torch.FloatTensor, + sample: torch.FloatTensor, + step_index: torch.LongTensor + ): + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + sigma = self.sigmas[step_index] + + sigma_hat = sigma + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + denoised = sample - model_output * sigma + # 2. 
Convert to an ODE derivative + derivative = (sample - denoised) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + prev_sample = sample + derivative * dt + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + return prev_sample + + +def compile_scheduler(model, inputs, scheduler_compiled_path, soc_version): + scheduler_param = CompileParam(inputs, soc_version, True, True, False) + common_compile(model, scheduler_compiled_path, scheduler_param) + + +class DiTExport(torch.nn.Module): + def __init__(self, dit_model): + super().__init__() + self.dit_model = dit_model + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + ): + return self.dit_model(hidden_states, encoder_hidden_states, pooled_projections, + timestep, None, False)[0] + + +def compile_dit(model, inputs, dit_compiled_path, soc_version): + dit_param = CompileParam(inputs, soc_version) + common_compile(model, dit_compiled_path, dit_param) + + +class DiTExportCache(torch.nn.Module): + def __init__(self, dit_cache_model): + super().__init__() + self.dit_cache_model = dit_cache_model + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + cache_param, + if_skip: int = 0, + delta_cache: torch.FloatTensor = None, + delta_cache_hidden: torch.FloatTensor = None, + use_cache: bool = True, + ): + return self.dit_cache_model(hidden_states, encoder_hidden_states, pooled_projections, timestep, + cache_param, if_skip, delta_cache, delta_cache_hidden, use_cache, + joint_attention_kwargs=None, return_dict=False) + + +def compile_dit_cache(model, inputs, dit_cache_compiled_path, soc_version): + dit_cache_param = CompileParam(inputs, soc_version, True, False, True) + common_compile(model, dit_cache_compiled_path, dit_cache_param) diff --git a/MindIE/MultiModal/StableDiffusion-3/export_model.py b/MindIE/MultiModal/StableDiffusion-3/export_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ea11642ab6fb01e400a34f75de6c94d0824c24a2 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/export_model.py @@ -0,0 +1,401 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
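+#
+# export_model.py traces the StableDiffusion3Pipeline sub-models (text encoders,
+# DiT, VAE and scheduler) to TorchScript and compiles them for Ascend NPUs with
+# the helpers in compile_model.py. Typical invocation (see the README for the
+# full parameter list):
+#     python3 export_model.py --model ${model_base} --output_dir ./models \
+#         --batch_size 1 --soc Ascend${chip_name} --device_type A2 --device 0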
+import logging +import os +import argparse +from argparse import Namespace + +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers import StableDiffusion3Pipeline +import mindietorch +from compile_model import * + + +def check_owner(path: str): + """ + check the path owner + param: the input path + return: whether the path owner is current user or not + """ + path_stat = os.stat(path) + path_owner, path_gid = path_stat.st_uid, path_stat.st_gid + user_check = path_owner == os.getuid() and path_owner == os.geteuid() + return path_owner == 0 or path_gid in os.getgroups() or user_check + + +def path_check(path: str): + """ + check path + param: path + return: data real path after check + """ + if os.path.islink(path) or path is None: + raise RuntimeError("The path should not be None or a symbolic link file.") + path = os.path.realpath(path) + if not check_owner(path): + raise RuntimeError("The path is not owned by current user or root.") + return path + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("-steps", "--steps", type=int, default=28, help="steps.") + parser.add_argument("-guid", "--guidance_scale", type=float, default=7.0, help="guidance_scale") + parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") + parser.add_argument("-p", "--parallel", action="store_true", + help="Export the unet of bs=1 for parallel inferencing.") + parser.add_argument("--soc", help="soc_version.") + parser.add_argument("--device_type", choices=["A2", "Duo"], default="A2", help="device type.") + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + parser.add_argument( + "--cache_param", + type=str, + default="1,2,20,10", + help="Steps to use cache data." 
+ ) + return parser.parse_args() + + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path, t5_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + t5_model = sd_pipeline.text_encoder_3 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + else: + logging.info("clip_pt_path already exists.") + + if not os.path.exists(clip2_pt_path): + clip2_export = ClipExport(encoder_2_model) + torch.jit.trace(clip2_export, dummy_input).save(clip2_pt_path) + else: + logging.info("clip2_pt_path already exists.") + + if not os.path.exists(t5_pt_path): + t5_export = ClipExport(t5_model) + torch.jit.trace(t5_export, dummy_input).save(t5_pt_path) + else: + logging.info("t5_pt_path already exists.") + + +def export_clip(sd_pipeline, args): + print("Exporting the text encoder...") + standard_path = path_check(args.output_dir) + clip_path = os.path.join(standard_path, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + batch_size = args.batch_size + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, f"clip2_bs{batch_size}.pt") + t5_pt_path = os.path.join(clip_path, f"t5_bs{batch_size}.pt") + clip1_compiled_path = os.path.join(clip_path, + f"clip_bs{batch_size}_compile_{args.height}x{args.width}.ts") + clip2_compiled_path = os.path.join(clip_path, + f"clip2_bs{batch_size}_compile_{args.height}x{args.width}.ts") + t5_compiled_path = os.path.join(clip_path, + f"t5_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + encoder_model = sd_pipeline.text_encoder + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path, t5_pt_path) + + # compile + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, args.soc) + else: + logging.info("clip1_compiled_path already exists.") + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, args.soc) + else: + logging.info("clip2_compiled_path already exists.") + if not os.path.exists(t5_compiled_path): + model = torch.jit.load(t5_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, t5_compiled_path, args.soc) + else: + logging.info("t5_compiled_path already exists.") + + +def export_dit(sd_pipeline, args): + print("Exporting the dit...") + standard_path = path_check(args.output_dir) + dit_path = os.path.join(standard_path, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + dit_model = sd_pipeline.transformer + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + if not args.parallel: + batch_size = args.batch_size * 2 + else: + batch_size = args.batch_size + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + 
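+    # Shape note: the traced DiT expects prompt embeddings of width
+    # encoder_hidden_size * 2 = 2 * (768 + 1280) = 4096 (the T5-XXL hidden size)
+    # and pooled projections of width 768 + 1280 = 2048 (the two CLIP text
+    # encoders' pooled outputs concatenated).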
encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + + dit_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}.pt") + dit_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + # trace + if not os.path.exists(dit_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64) + ) + dit = DiTExport(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_pt_path) + else: + logging.info("dit_pt_path already exists.") + + # compile + if not os.path.exists(dit_compiled_path): + model = torch.jit.load(dit_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + compile_dit(model, inputs, dit_compiled_path, args.soc) + else: + logging.info("dit_compiled_path already exists.") + + +def export_vae(sd_pipeline, args): + print("Exporting the image decoder...") + standard_path = path_check(args.output_dir) + vae_path = os.path.join(standard_path, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + batch_size = args.batch_size + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_path = os.path.join(vae_path, + f"vae_bs{batch_size}_compile_{args.height}x{args.width}.ts") + + vae_model = sd_pipeline.vae + dit_model = sd_pipeline.transformer + scaling_factor = vae_model.config.scaling_factor + shift_factor = vae_model.config.shift_factor + in_channels = vae_model.config.latent_channels + sample_size = dit_model.config.sample_size + + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model, scaling_factor, shift_factor) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + else: + logging.info("vae_pt_path already exists.") + + # compile + if not os.path.exists(vae_compiled_path): + model = torch.jit.load(vae_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_path, args.soc) + else: + logging.info("vae_compiled_path already exists.") + + +def trace_scheduler(sd_pipeline, args, scheduler_pt_path): + batch_size = args.batch_size + if not os.path.exists(scheduler_pt_path): + dummy_input = ( + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.randn([batch_size, 16, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64) + ) + scheduler = FlowMatchEulerDiscreteScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + new_scheduler = Scheduler() + new_scheduler.eval() + torch.jit.trace(new_scheduler, dummy_input).save(scheduler_pt_path) + + +def export_scheduler(sd_pipeline, args): + print("Exporting 
the scheduler...") + scheduler_path = os.path.join(args.output_dir, "scheduler") + if not os.path.exists(scheduler_path): + os.makedirs(scheduler_path, mode=0o744) + batch_size = args.batch_size + height_size, width_size = args.height // 8, args.width // 8 + scheduler_pt_path = os.path.join(scheduler_path, f"scheduler_bs{batch_size}.pt") + scheduler_compiled_path = os.path.join(scheduler_path, + f"scheduler_bs{batch_size}_compile_{args.height}x{args.width}.ts") + in_channels = 16 + + # trace + trace_scheduler(sd_pipeline, args, scheduler_pt_path) + + # compile + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(scheduler_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) + ] + compile_scheduler(model, inputs, scheduler_compiled_path, args.soc) + + +def export_dit_cache(sd_pipeline, args, if_skip, flag=""): + print("Exporting the dit_cache...") + cache_param = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_param.split(',') + cache_param[0] = int(cache_list[0]) + cache_param[1] = int(cache_list[1]) + cache_param[2] = int(cache_list[2]) + cache_param[3] = int(cache_list[3]) + dit_path = os.path.join(args.output_dir, "dit") + if not os.path.exists(dit_path): + os.makedirs(dit_path, mode=0o640) + + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + dit_model = sd_pipeline.transformer + if args.parallel or flag == "end": + batch_size = args.batch_size + else: + batch_size = args.batch_size * 2 + sample_size = dit_model.config.sample_size + in_channels = dit_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings * 2 + dit_cache_pt_path = os.path.join(dit_path, f"dit_bs{batch_size}_{if_skip}.pt") + dit_cache_compiled_path = os.path.join(dit_path, + f"dit_bs{batch_size}_{if_skip}_compile_{args.height}x{args.width}.ts") + + # trace + if not os.path.exists(dit_cache_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size * 2], dtype=torch.float32), + torch.ones([batch_size, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + cache_param, + torch.tensor([if_skip], dtype=torch.int64), + torch.ones([batch_size, 4096, 1536], dtype=torch.float32), + torch.ones([batch_size, 154, 1536], dtype=torch.float32), + ) + print("dummy_input.shape:") + for ele in dummy_input: + if isinstance(ele, torch.Tensor): + print(ele.shape) + dit = DiTExportCache(dit_model).eval() + torch.jit.trace(dit, dummy_input).save(dit_cache_pt_path) + + # compile + if not os.path.exists(dit_cache_compiled_path): + model = torch.jit.load(dit_cache_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size * 2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((4,), 
dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 4096, 1536), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 154, 1536), + dtype=mindietorch.dtype.FLOAT), + ] + compile_dit_cache(model, inputs, dit_cache_compiled_path, args.soc) + + +def export(args) -> None: + pipeline = StableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") + export_clip(pipeline, args) + if args.use_cache: + export_dit_cache(pipeline, args, 0) + export_dit_cache(pipeline, args, 1) + if args.device_type == "A2": + export_dit_cache(pipeline, args, 0, "end") + export_dit_cache(pipeline, args, 1, "end") + else: + export_dit(pipeline, args) + export_vae(pipeline, args) + export_scheduler(pipeline, args) + + +def main(args): + mindietorch.set_device(args.device) + export(args) + print("Done.") + mindietorch.finalize() + + +if __name__ == "__main__": + args = parse_arguments() + main(args) diff --git a/MindIE/MultiModal/StableDiffusion-3/prompts.txt b/MindIE/MultiModal/StableDiffusion-3/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..a375a0bb63931d0d5da6c6d91df1e14f870f47d0 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/prompts.txt @@ -0,0 +1,16 @@ +Beautiful illustration of The ocean. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Islands in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Seaports in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The waves. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Grassland. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Wheat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Hut Tong. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The boat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Pine trees. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Bamboo. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The temple. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Cloud in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Sun in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Spring. 
in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Lotus. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Snow piles. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-3/requirements.txt b/MindIE/MultiModal/StableDiffusion-3/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..576e9300ea64658afa5fe8099e5c01b2dd63d085 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/requirements.txt @@ -0,0 +1,10 @@ +accelerate==0.31.0 +torch==2.1.0 +torchvision==0.16.0 +ftfy +diffusers==0.29.0 +transformers>=4.41.2 +tensorboard +Jinja2 +peft==0.11.1 +open_clip_torch==2.20.0 \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-3/stable_diffusion3_pipeline.py b/MindIE/MultiModal/StableDiffusion-3/stable_diffusion3_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc7e11581532272308f2690f20c2328a6e9111b --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/stable_diffusion3_pipeline.py @@ -0,0 +1,815 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
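+
+# This module wraps diffusers' StableDiffusion3Pipeline so that the CLIP/T5 text encoders,
+# the MMDiT transformer and the VAE decoder run as MindIE Torch compiled TorchScript (.ts)
+# models loaded from --output_dir. When two NPU ids are passed via --device, the positive
+# DiT pass is dispatched to the second device through BackgroundRuntime so that the two
+# classifier-free-guidance branches run in parallel.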
+ +import argparse +import logging +import csv +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np +import torch +import mindietorch +from diffusers import StableDiffusion3Pipeline +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from background_runtime import BackgroundRuntime, RuntimeIOInfo + +clip_time = 0 +t5_time = 0 +dit_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + elif prompt_file_type == 'hpsv2': + self.load_prompts_hpsv2(max_num_prompts) + else: + print("This operation is not supported!") + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + def load_prompts_hpsv2(self, max_num_prompts: int): + with open('hpsv2_benchmark_prompts.json', 'r') as file: + all_prompts = json.load(file) + count = 0 + for style, prompts in all_prompts.items(): + for prompt in prompts: + count += 1 + if max_num_prompts and count >= max_num_prompts: + break + + if style not in self.catagories: + self.catagories.append(style) + + catagory_id = self.catagories.index(style) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusion3Pipeline(StableDiffusion3Pipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if 
isinstance(self.args.device, list): + self.device_0, self.device_1 = args.device + else: + self.device_0 = args.device + self.data = None + + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + if hasattr(self, 'device_1'): + batch_size = self.args.batch_size + else: + batch_size = self.args.batch_size * 2 + + tail = f"_{self.args.height}x{self.args.width}" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + t5_compiled_path = os.path.join(self.args.output_dir, f"clip/t5_bs{size}_compile{tail}.ts") + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() + + dit_compiled_path = os.path.join(self.args.output_dir, f"dit/dit_bs{batch_size}_compile{tail}.ts") + self.compiled_dit_model = torch.jit.load(dit_compiled_path).eval() + + self.use_parallel_inferencing = False + + if hasattr(self, 'device_1'): + sample_size = self.transformer.config.sample_size + in_channels = self.transformer.config.in_channels + encoder_hidden_size_2 = self.text_encoder_2.config.hidden_size + encoder_hidden_size = self.text_encoder.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = self.text_encoder.config.max_position_embeddings * 2 + + runtime_info = RuntimeIOInfo( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (batch_size, max_position_embeddings, encoder_hidden_size * 2), + (batch_size, encoder_hidden_size), + (1,), + ], + input_dtypes=[np.float32, np.float32, np.float32, np.int64], + output_shapes=[(batch_size, in_channels, sample_size, sample_size)], + output_dtypes=[np.float32] + ) + self.dit_bg = BackgroundRuntime.clone(self.device_1, dit_compiled_path, runtime_info) + self.use_parallel_inferencing = True + + self.is_init = True + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + ): + device = f"npu:{self.device_0}" + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_3( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logging.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + global t5_time + start = time.time() + prompt_embeds = self.compiled_t5_model(text_input_ids.to(device))[0].to('cpu') + t5_time += (time.time() - start) + + dtype = self.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device='cpu') + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + 
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + clip_skip: Optional[int] = None, + clip_model_index: int = 0, + ): + device = f"npu:{self.device_0}" + clip_tokenizers = [self.tokenizer, self.tokenizer_2] + clip_text_encoders = [self.compiled_clip_model, self.compiled_clip_model_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logging.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + + global clip_time + start = time.time() + prompt_embeds = text_encoder(text_input_ids.to(device)) + pooled_prompt_embeds = prompt_embeds[0].to('cpu') + clip_time += (time.time() - start) + + if clip_skip is None: + prompt_embeds = prompt_embeds[2][-2].to('cpu') + else: + prompt_embeds = prompt_embeds[2][-(clip_skip + 2)].to('cpu') + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device='cpu') + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]], + prompt_3: Union[str, List[str]], + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + clip_skip: Optional[int] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + 
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = self._get_t5_prompt_embeds( + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds( + negative_prompt, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds( + negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = self._get_t5_prompt_embeds( + prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + if self.do_classifier_free_guidance and not self.use_parallel_inferencing: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + prompt_embeds, prompt_embeds_1 = negative_prompt_embeds, prompt_embeds + pooled_prompt_embeds, pooled_prompt_embeds_1 = negative_pooled_prompt_embeds, pooled_prompt_embeds + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + global dit_time + global vae_time + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_npu = t.to(torch.int64)[None].to(f"npu:{self.device_0}") + else: + latent_model_input = latents + timestep = t.to(torch.int64) + self.dit_bg.infer_asyn([ + latent_model_input.numpy(), + prompt_embeds_1.numpy(), + pooled_prompt_embeds_1.numpy(), + timestep[None].numpy().astype(np.int64) + ]) + timestep_npu = timestep[None].to(f"npu:{self.device_0}") + + latent_model_input_npu = latent_model_input.to(f"npu:{self.device_0}") + prompt_embeds_npu = prompt_embeds.to(f"npu:{self.device_0}") + pooled_prompt_embeds_npu = pooled_prompt_embeds.to(f"npu:{self.device_0}") + + start = time.time() + noise_pred = self.compiled_dit_model( + latent_model_input_npu, + prompt_embeds_npu, + pooled_prompt_embeds_npu, + timestep_npu + ).to("cpu") + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance: + if self.use_parallel_inferencing: + noise_pred_text = torch.from_numpy(self.dit_bg.wait_and_get_outputs()[0]) + else: + noise_pred, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred + self.guidance_scale * (noise_pred_text - noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += (time.time() - start1) + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f"npu:{self.device_0}")).to("cpu") + vae_time += (time.time() - start) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + p3_time += (time.time() - start2) + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. 
valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=28, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--scheduler", + choices=["FlowMatchEuler"], + default="FlowMatchEuler", + help="Type of Sampling methods. Default FlowMatchEuler", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." 
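+        # NOTE: in this (non-cache) pipeline, --use_cache and --cache_param appear to be
+        # accepted only for CLI parity with stable_diffusion3_pipeline_cache.py; they are
+        # not referenced elsewhere in this file.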
+ ) + parser.add_argument( + "--cache_param", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusion3Pipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.5, + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n") + + if hasattr(pipe, 'device_1'): + if (pipe.dit_bg): + pipe.dit_bg.stop() + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableDiffusion-3/stable_diffusion3_pipeline_cache.py b/MindIE/MultiModal/StableDiffusion-3/stable_diffusion3_pipeline_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..5f980cb3662cb74a5f93038c2304f9fb48e441bf --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/stable_diffusion3_pipeline_cache.py @@ -0,0 +1,578 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
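+
+# Cache-enabled variant of AIEStableDiffusion3Pipeline. It drives the DiT models exported
+# with --use_cache: a "full" model that recomputes the cached block range and refreshes the
+# residual (delta) caches, and a "skip" model that reuses them, plus a compiled
+# FlowMatchEuler scheduler step. After `tgate` denoising steps the negative
+# (unconditional) CFG branch is dropped and the single-batch "*_end" models take over.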
+ +import argparse +import json +import os +import time +from typing import Any, Callable, Dict, List, Optional, Union +import torch +import mindietorch +from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps +from stable_diffusion3_pipeline import PromptLoader, AIEStableDiffusion3Pipeline + +tgate = 20 +dit_time = 0 +vae_time = 0 +scheduler_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusion3CachePipeline(AIEStableDiffusion3Pipeline): + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + tail = f"_{self.args.height}x{self.args.width}" + + vae_compiled_path = os.path.join(self.args.output_dir, + f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, + f"scheduler/scheduler_bs{size}_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, + f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + t5_compiled_path = os.path.join(self.args.output_dir, + f"clip/t5_bs{size}_compile{tail}.ts") + self.compiled_t5_model = torch.jit.load(t5_compiled_path).eval() + + dit_cache_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_0_compile{tail}.ts") + self.compiled_dit_cache_model = torch.jit.load(dit_cache_compiled_path).eval() + + if self.args.use_cache: + dit_skip_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{batch_size}_1_compile{tail}.ts") + self.compiled_dit_skip_model = torch.jit.load(dit_skip_compiled_path).eval() + + dit_cache_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_0_compile{tail}.ts") + self.compiled_dit_cache_end_model = torch.jit.load(dit_cache_end_compiled_path).eval() + + dit_skip_end_compiled_path = os.path.join(self.args.output_dir, + f"dit/dit_bs{size}_1_compile{tail}.ts") + self.compiled_dit_skip_end_model = torch.jit.load(dit_skip_end_compiled_path).eval() + + self.is_init = True + + @torch.no_grad() + def dit_infer(self, compiled_model, latent_model_input, prompt_embeds, pooled_prompt_embeds, timestep_npu, + cache_param, skip_flag, delta_cache, delta_cache_hidden): + (noise_pred, delta_cache, delta_cache_hidden) = compiled_model( + latent_model_input.to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + pooled_prompt_embeds.to(f'npu:{self.device_0}'), + timestep_npu, + cache_param.to(f'npu:{self.device_0}'), + skip_flag, + delta_cache.to(f'npu:{self.device_0}'), + delta_cache_hidden.to(f'npu:{self.device_0}'), + ) + noise_pred = noise_pred.to("cpu") + delta_cache = delta_cache.to("cpu") + delta_cache_hidden = delta_cache_hidden.to("cpu") + return (noise_pred, delta_cache, delta_cache_hidden) + + @torch.no_grad() + def forward( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = 
None, + guidance_scale: float = 7.0, + cache_param: torch.LongTensor = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt_3: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + global p1_time, p2_time, p3_time + start = time.time() + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + prompt_3, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_3=prompt_3, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + negative_prompt_3=negative_prompt_3, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + clip_skip=self.clip_skip, + num_images_per_prompt=num_images_per_prompt, + ) + + p1_time += (time.time() - start) + start1 = time.time() + + prompt_embeds_origin = prompt_embeds.clone() + pooled_prompt_embeds_origin = pooled_prompt_embeds.clone() + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + global dit_time + global vae_time + global scheduler_time + + skip_flag_true = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + skip_flag_false = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + + delta_cache = torch.zeros([2, 4096, 1536], dtype=torch.float32) + delta_cache_hidden = torch.zeros([2, 154, 1536], dtype=torch.float32) + + cache_interval = cache_param[1] + step_contrast = cache_param[3] % 2 + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + start = time.time() + timestep_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + if not self.args.use_cache: + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i < tgate: + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + else: + if i == tgate: + _, delta_cache = delta_cache.chunk(2) + _, delta_cache_hidden = delta_cache_hidden.chunk(2) + latent_model_input = latents + + if i < cache_param[3]: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, delta_cache, + delta_cache_hidden) + else: + if i % cache_interval == step_contrast: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_cache_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_param, + skip_flag_true, + delta_cache, + delta_cache_hidden) + else: + if i < tgate: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_model, + latent_model_input, + prompt_embeds, + pooled_prompt_embeds, + timestep_npu, cache_param, + skip_flag_false, + delta_cache, + delta_cache_hidden) + else: + (noise_pred, delta_cache, delta_cache_hidden) = self.dit_infer( + self.compiled_dit_skip_end_model, + latent_model_input, + prompt_embeds_origin, + pooled_prompt_embeds_origin, + timestep_npu, cache_param, + skip_flag_false, + delta_cache, + delta_cache_hidden) + + dit_time += (time.time() - start) + + # perform guidance + if self.do_classifier_free_guidance and i < tgate: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + start = time.time() + # latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + step_index = torch.tensor(i).long() + 
latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + step_index[None].to(f'npu:{self.device_0}') + ).to('cpu') + scheduler_time += (time.time() - start) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + p2_time += time.time() - start1 + start2 = time.time() + + if output_type == "latent": + image = latents + else: + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to("cpu") + vae_time += time.time() - start + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + p3_time += time.time() - start2 + + if not return_dict: + return (image,) + + return StableDiffusion3PipelineOutput(images=image) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-3-medium-diffusers", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=28, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. 
Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--scheduler", + choices=["FlowMatchEuler"], + default="FlowMatchEuler", + help="Type of Sampling methods. Default FlowMatchEuler", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--cache_param", + default="1,2,20,10", + type=str, + help="steps to use cache data" + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + if isinstance(args.device, list): + mindietorch.set_device(args.device[0]) + else: + mindietorch.set_device(args.device) + pipe = AIEStableDiffusion3CachePipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + pipe.compile_aie_model() + + cache_param = torch.zeros([4], dtype=torch.int64) + cache_list = args.cache_param.split(',') + cache_param[0] = int(cache_list[0]) + cache_param[1] = int(cache_list[1]) + cache_param[2] = int(cache_list[2]) + cache_param[3] = int(cache_list[3]) + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + images = pipe.forward( + prompts, + negative_prompt="", + num_inference_steps=args.steps, + guidance_scale=7.0, + cache_param=cache_param + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + print( + f"[info] infer number: {infer_num - 5}; use time: {use_time:.3f}s\n" + f"average time: {use_time / (infer_num - 5):.3f}s\n" + f"dit time: {dit_time / infer_num:.3f}s\n" + f"scheduler_time time: {scheduler_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n" + ) + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with 
os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableDiffusion-3/transformer_sd3.patch b/MindIE/MultiModal/StableDiffusion-3/transformer_sd3.patch new file mode 100644 index 0000000000000000000000000000000000000000..bdbbb9c8671d021ed5b4e1669ab88389d03f5ad4 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/transformer_sd3.patch @@ -0,0 +1,123 @@ +--- transformer_sd3.py 2024-09-04 09:21:58.280000000 +0000 ++++ transformer_sd3.py 2024-09-04 10:01:47.196000000 +0000 +@@ -97,6 +97,7 @@ + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.inner_dim, + context_pre_only=i == num_layers - 1, ++ layer_idx=i + ) + for i in range(self.config.num_layers) + ] +@@ -106,6 +107,8 @@ + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False ++ self.delta_cache = None ++ self.delta_cache_hidden = None + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: +@@ -245,9 +248,14 @@ + encoder_hidden_states: torch.FloatTensor = None, + pooled_projections: torch.FloatTensor = None, + timestep: torch.LongTensor = None, ++ cache_dict: torch.LongTensor = None, ++ if_skip: int = 0, ++ delta_cache: torch.FloatTensor = None, ++ delta_cache_hidden: torch.FloatTensor = None, ++ use_cache: bool = False, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ++ ): + """ + The [`SD3Transformer2DModel`] forward method. + +@@ -281,10 +289,6 @@ + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) +- else: +- logger.warning( +- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+- ) + + height, width = hidden_states.shape[-2:] + +@@ -292,9 +296,8 @@ + temb = self.time_text_embed(timestep, pooled_projections) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + +- for block in self.transformer_blocks: +- if self.training and self.gradient_checkpointing: +- ++ if self.training and self.gradient_checkpointing: ++ for block in self.transformer_blocks: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: +@@ -312,11 +315,14 @@ + temb, + **ckpt_kwargs, + ) +- +- else: +- encoder_hidden_states, hidden_states = block( +- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb +- ) ++ else: ++ ( ++ (encoder_hidden_states, hidden_states), ++ delta_cache, ++ delta_cache_hidden ++ ) = self.forward_blocks(hidden_states, encoder_hidden_states, temb, ++ use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden) + + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) +@@ -339,6 +345,43 @@ + unscale_lora_layers(self, lora_scale) + + if not return_dict: +- return (output,) ++ return (output, delta_cache, delta_cache_hidden) + + return Transformer2DModelOutput(sample=output) ++ ++ def forward_blocks_range(self, hidden_states, encoder_hidden_states, temb, start_idx, end_idx): ++ for block_idx, block in enumerate(self.transformer_blocks[start_idx: end_idx]): ++ encoder_hidden_states, hidden_states = block( ++ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ++ ) ++ ++ return hidden_states, encoder_hidden_states ++ ++ def forward_blocks(self, hidden_states, encoder_hidden_states, temb, use_cache, if_skip, cache_dict, delta_cache, ++ delta_cache_hidden): ++ if not use_cache: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, len(self.transformer_blocks)) ++ else: ++ # infer [0, cache_start) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ 0, cache_dict[0]) ++ ++ # infer [cache_start, cache_end) ++ cache_end = cache_dict[0] + cache_dict[2] ++ hidden_states_before_cache = hidden_states.clone() ++ encoder_hidden_states_before_cache = encoder_hidden_states.clone() ++ if not if_skip: ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, ++ temb, cache_dict[0], ++ cache_end) ++ delta_cache = hidden_states - hidden_states_before_cache ++ delta_cache_hidden = encoder_hidden_states - encoder_hidden_states_before_cache ++ else: ++ hidden_states = hidden_states_before_cache + delta_cache ++ encoder_hidden_states = encoder_hidden_states_before_cache + delta_cache_hidden ++ ++ # infer [cache_end, len(self.blocks)) ++ hidden_states, encoder_hidden_states = self.forward_blocks_range(hidden_states, encoder_hidden_states, temb, ++ cache_end, len(self.transformer_blocks)) ++ return (encoder_hidden_states, hidden_states), delta_cache, delta_cache_hidden diff --git a/MindIE/MultiModal/StableDiffusion-3/transformer_sd3_patch.py b/MindIE/MultiModal/StableDiffusion-3/transformer_sd3_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..8556cbac2ba1974972fcf3965b78a9f7592b60c8 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-3/transformer_sd3_patch.py @@ -0,0 +1,33 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use 
this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import logging +import diffusers + + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.29.0', "expectation diffusers==0.29.0" + result = subprocess.run(["patch", "-p0", f"{diffusers_path[0]}/models/transformers/transformer_sd3.py", + "transformer_sd3.patch"], capture_output=True, text=True) + if result.returncode != 0: + logging.error("Patch failed, error message: s%", result.stderr) + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/README.md b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fd613dfb2e76046043c5d49d738196c9446b1d2e --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/README.md @@ -0,0 +1,179 @@ +# stable-diffusionxl-controlnet模型-推理指导 + +- [概述](#ZH-CN_TOPIC_0000001172161501) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + +# 概述 + + ControlNet是一种神经网络架构,可将控制信息添加到预训练的扩散模型中。作用是通过添加额外控制条件,来引导Stable Diffusion生成图像,从而提升 AI 图像生成的可控性和精度。在使用ControlNet模型之后,Stable Diffusion模型的权重被复制出两个相同的部分,分别是“锁定”副本和“可训练”副本。ControlNet主要在“可训练”副本上施加控制条件,然后将施加控制条件之后的结果和原来SD模型的结果相加获得最终的输出结果。神经架构与“零卷积”(零初始化卷积层)连接,参数从零逐渐增长,确保微调的过程不会受到噪声影响。这样可以使用小批量数据集就能对控制条件进行学习训练,同时不会破坏Stable Diffusion模型原本的能力。 + ControlNet的应用包括:控制人物姿势、线稿上色、画质修复等。 +- 参考实现: + ```bash + # controlnet-canny-sdxl-1.0 + https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0 + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + | ------ | ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + +# 快速上手 + +## 获取源码 + +1. 安装依赖。 + + ```bash + pip3 install -r requirements.txt + ``` +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` +3. 代码修改 + + 执行命令: + + ```bash + python3 stable_diffusion_clip_patch.py + ``` + + ```bash + python3 stable_diffusion_attention_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + + ControlNet是一个控制预训练图像扩散模型的神经网络,允许输入调节图像,然后使用该调节图像来操控图像生成。调节图像可从官网下载。 + ```bash + wget https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png + ``` + +## 模型推理 + +1. 模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放到代码同级目录下,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # xl + git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + # controlnet-canny-sdxl-1.0 + git clone https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0 + # sdxl-vae-fp16-fix + git clone https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + ``` + 1. 
导出pt模型并进行编译。(可选) + + 设置模型名称或路径 + ```bash + # xl (执行时下载权重) + model_base="stabilityai/stable-diffusion-xl-base-1.0" + # controlnet-canny-sdxl-1.0 (执行时下载权重) + model_controlnet="diffusers/controlnet-canny-sdxl-1.0" + # sdxl-vae-fp16-fix (执行时下载权重) + model_vae="madebyollin/sdxl-vae-fp16-fix" + + # xl (使用上一步下载的权重) + model_base="./stable-diffusion-xl-base-1.0" + # controlnet-canny-sdxl-1.0 (使用上一步下载的权重) + model_controlnet="./controlnet-canny-sdxl-1.0" + # sdxl-vae-fp16-fix (使用上一步下载的权重) + model_vae="./sdxl-vae-fp16-fix" + ``` + + 执行命令: + ```bash + # 静态模型 + python3 export_ts_controlnet.py --model ${model_base} --controlnet_model ${model_controlnet} --vae_model ${model_vae} --output_dir ./models --batch_size 1 --flag 0 --soc A2 --device 0 + + # 动态分档模型,仅支持1024*1024、512*512两种 + python3 export_ts_controlnet.py --model ${model_base} --controlnet_model ${model_controlnet} --vae_model ${model_vae} --output_dir ./models --batch_size 1 --flag 1 --soc A2 --device 0 + + ``` + + 参数说明: + + - --model:模型权重路径 + - --controlnet_model: controlnet模型权重路径 + - --vae_model: vae模型权重路径 + - --output_dir: ONNX模型输出目录 + - --batch_size: 设置batch_size, 默认值为1,当前仅支持batch_size=1的场景 + - --falg: 设置模型编译方式。默认值为1。值为0表示静态模型,值为1表示动态分档模型。 + - --soc: 默认值为A2,当前仅支持Atlas 800I A2场景。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + + 静态编译场景: + + - ./models/clip/clip_bs{batch_size}.pt, ./models/clip/clip_bs{batch_size}_compile.ts 和 ./models/clip/clip2_bs{batch_size}.pt, ./models/clip/clip2_bs{batch_size}_compile.ts + - ./models/unet/unet_bs{batch_size x 2}.pt, ./models/unet/unet_bs{batch_size x 2}_compile_static.ts + - ./models/vae/vae_bs{batch_size}.pt, ./models/vae/vae_bs{batch_size}_compile_static.ts + - ./models/control/control_bs{batch_size}.pt, ./models/control/control_bs{batch_size}_compile_static.ts + + 动态分档场景: + + - ./models/clip/clip_bs{batch_size}.pt, ./models/clip/clip_bs{batch_size}_compile.ts 和 ./models/clip/clip2_bs{batch_size}.pt, ./models/clip/clip2_bs{batch_size}_compile.ts + - ./models/unet/unet_bs{batch_size x 2}.pt, ./models/unet/unet_bs{batch_size x 2}_compile.ts + - ./models/vae/vae_bs{batch_size}.pt, ./models/vae/vae_bs{batch_size}_compile.ts + - ./models/control/control_bs{batch_size}.pt, ./models/control/control_bs{batch_size}_compile.ts + +2. 开始推理验证。 + + 1. 
执行推理脚本。 + + ```bash + python3 stable_diffusionxl_pipeline_controlnet.py \ + --model ${model_base} \ + --controlnet_model ${model_controlnet} \ + --vae_model ${model_vae} \ + --device 0 \ + --save_dir ./results \ + --output_dir ./models \ + --soc A2 \ + --flag 1 \ + --w_h 1024 + ``` + + 参数说明: + + - --model:模型名称或本地模型目录的路径。 + - --controlnet_model: controlnet模型权重路径 + - --vae_model: vae模型权重路径 + - --device:推理设备ID。 + - --save_dir:生成图片的存放目录。 + - --output_dir:存放导出模型的目录。 + - --soc: 默认值为A2,当前仅支持Atlas 800I A2场景。 + - --falg: 设置模型编译方式。默认值为1。值为0表示静态模型,值为1表示动态分档模型。 + - --w_h: image的宽高,设置为1024表示宽高均为1024,设置为512表示宽高均为512。仅支持这两种分辨率。 + + 执行完成后在 `./results`目录下生成推理图片。 diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/attention_processor.patch b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..26f296526adcfaf629f3c47a311b88bb4aa002a2 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/attention_processor.patch @@ -0,0 +1,18 @@ +--- attention_processor.py 2024-02-22 19:06:56.596000000 +0800 ++++ attention_processor.py 2024-02-22 19:07:17.232000000 +0800 +@@ -205,10 +205,11 @@ + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/clip.patch b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/clip.patch new file mode 100644 index 0000000000000000000000000000000000000000..e3e4719b66f771ebb660f25151c33d140566c3f3 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/clip.patch @@ -0,0 +1,10 @@ +22a23 +> import numpy as np +760c761,762 +< mask.triu_(1) # zero out the lower diagonal +--- +> # mask.triu_(1) # zero out the lower diagonal +> mask = torch.from_numpy(np.triu(mask.numpy(), 1)) +1324a1327 +> + diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/compile_model.py b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/compile_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7caa00254c8ceb0e4e290d4d8d90e562beed3970 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/compile_model.py @@ -0,0 +1,70 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
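+
+# Module overview (descriptive note): these helpers wrap mindietorch.compile for each
+# submodule of the StableDiffusion-XL ControlNet pipeline (CLIP text encoders, VAE
+# decoder, UNet and ControlNet). Each helper compiles an already-traced TorchScript
+# module with an FP16 precision policy for the target Ascend SoC and saves the
+# compiled .ts artifact with torch.jit.save; they are invoked from export_ts_controlnet.py.
+# Illustrative call (shapes are examples only, not fixed values):
+#   inputs = [mindietorch.Input((1, 77), dtype=mindietorch.dtype.INT64)]
+#   compile_clip(torch.jit.load("clip_bs1.pt").eval(), inputs, "clip_bs1_compile.ts", "Ascend910B4")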
+ +import torch +import mindietorch +from mindietorch import _enums + +def compile_clip(model, inputs, clip_compiled_path, soc_version): + compiled_clip_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + min_block_size = 1, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_clip_model, clip_compiled_path) + +def compile_vae(model, inputs, vae_compiled_path, soc_version): + compiled_vae_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_vae_model, vae_compiled_path) + + +def compile_unet(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_control(model, inputs, control_compiled_path, soc_version): + compiled_control_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_control_model, control_compiled_path) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/export_ts_controlnet.py b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/export_ts_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0fa46dccf43f1db1e5a28bad2cb22c63720e42dc --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/export_ts_controlnet.py @@ -0,0 +1,471 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
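+
+# Export entry point (descriptive note): for each submodule (CLIP text encoders, UNet,
+# VAE decoder, ControlNet) this script first traces a TorchScript module with
+# torch.jit.trace on CPU, then compiles it via the helpers in compile_model.py.
+# With --flag 0 a single static shape is compiled; with --flag 1 two input gears are
+# registered, matching 1024x1024 and 512x512 generation.
+# Example invocation (same arguments as the README's static-model command):
+#   python3 export_ts_controlnet.py --model ./stable-diffusion-xl-base-1.0 \
+#       --controlnet_model ./controlnet-canny-sdxl-1.0 --vae_model ./sdxl-vae-fp16-fix \
+#       --output_dir ./models --batch_size 1 --flag 0 --soc A2 --device 0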
+ +import os +import argparse +from argparse import Namespace + +import torch +import mindietorch +import torch.nn as nn + +from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL +import torch +from compile_model import compile_clip, compile_vae, compile_unet, compile_control + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "-control", + "--controlnet_model", + type=str, + default="./controlnet-canny-sdxl-1.0", + help="Path or name of the pre-trained controlnet model.", + ) + parser.add_argument( + "-vae", + "--vae_model", + type=str, + default="./sdxl-vae-fp16-fix", + help="Path or name of the pre-trained vae model.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-cond_scale", + "--conditioning_scale", + type=float, + default=0.5, + help="conditioning_scale" + ) + parser.add_argument( + "--flag", + type=int, + default=1, + choices=[0, 1], + help="0 is static; 1 is dynamic rankl.", + ) + parser.add_argument( + "--soc", + choices=["A2"], + default="A2", + help="soc_version.", + ) + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + + return parser.parse_args() + + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x, output_hidden_states=True, return_dict=False): + return self.clip_model(x, output_hidden_states=output_hidden_states, return_dict=return_dict) + + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + if not os.path.exists(clip2_pt_path): + clip_export = ClipExport(encoder_2_model) + torch.jit.trace(clip_export, dummy_input).save(clip2_pt_path) + + +def export_clip(sd_pipeline: StableDiffusionXLControlNetPipeline, save_dir: str, + batch_size: int, flag: int, + soc_version: str) -> None: + print("Exporting the text encoder...") + clip_path = os.path.join(save_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, f"clip2_bs{batch_size}.pt") + clip1_compiled_path = os.path.join(clip_path, f"clip_bs{batch_size}_compile.ts") + clip2_compiled_path = os.path.join(clip_path, f"clip2_bs{batch_size}_compile.ts") + + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + + max_position_embeddings = encoder_model.config.max_position_embeddings + print(f'max_position_embeddings: {max_position_embeddings}') + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path) + + # compile + if flag == 0 or flag == 1: + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = 
[mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, soc_version) + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class UnetExportInit(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + latent_model_input, + timestep, + encoder_hidden_states, + down_block_res_samples0, + down_block_res_samples1, + down_block_res_samples2, + down_block_res_samples3, + down_block_res_samples4, + down_block_res_samples5, + down_block_res_samples6, + down_block_res_samples7, + down_block_res_samples8, + mid_block_res_sample, + text_embeds, + time_ids + ): + return self.unet_model(latent_model_input, timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=[down_block_res_samples0, down_block_res_samples1, + down_block_res_samples2, + down_block_res_samples3, down_block_res_samples4, + down_block_res_samples5, down_block_res_samples6, + down_block_res_samples7, down_block_res_samples8], + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids})[0] + + +def export_unet(sd_pipeline: StableDiffusionXLControlNetPipeline, save_dir: str, batch_size: int, flag, + soc_version: str) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile.ts") + + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.float32), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([2, 320, 128, 128], dtype=torch.float32), + torch.ones([2, 320, 128, 128], dtype=torch.float32), + torch.ones([2, 320, 128, 128], dtype=torch.float32), + torch.ones([2, 320, 64, 64], dtype=torch.float32), + torch.ones([2, 640, 64, 64], dtype=torch.float32), + torch.ones([2, 640, 64, 64], dtype=torch.float32), + torch.ones([2, 640, 32, 32], dtype=torch.float32), + torch.ones([2, 1280, 32, 32], dtype=torch.float32), + torch.ones([2, 1280, 32, 32], dtype=torch.float32), + torch.ones([2, 1280, 32, 32], dtype=torch.float32), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32) + ) + + unet = UnetExportInit(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + # compile + 
if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + # 静态shape + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 128, 128), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 128, 128), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 128, 128), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT)] + compile_unet(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + # 动态分档 + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = [ + mindietorch.Input((batch_size, in_channels, 1024 // 8, 1024 // 8,), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 128, 128), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 128, 128), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 128, 128), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear_1) + inputs_gear_2 = [ + mindietorch.Input((batch_size, in_channels, 512 // 8, 512 // 8,), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 64, 64), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 320, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 32, 32), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 640, 16, 16), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 16, 16), dtype=mindietorch.dtype.FLOAT), + 
mindietorch.Input((2, 1280, 16, 16), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((2, 1280, 16, 16), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear_2) + + compile_unet(model, inputs, unet_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model): + super().__init__() + self.vae_model = vae_model + + def forward(self, latents): + return self.vae_model.decoder(latents) + + +def export_vae(sd_pipeline: StableDiffusionXLControlNetPipeline, save_dir: str, batch_size: int, flag: int, + vaepath: str, soc_version: str) -> None: + print("Exporting the image decoder...") + + vae_path = os.path.join(save_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_static_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile_static.ts") + vae_compiled_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile.ts") + + vae_model = AutoencoderKL.from_pretrained(vaepath) + unet_model = sd_pipeline.unet + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + # compile + if flag == 0: + if not os.path.exists(vae_compiled_static_path): + model = torch.jit.load(vae_pt_path).eval() + # 静态shape + inputs = [ + mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(vae_compiled_path): + # 动态分档 + model = torch.jit.load(vae_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input((batch_size, in_channels, + 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((batch_size, in_channels, + 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear_2) + + compile_vae(model, inputs, vae_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class ControlNetExport(torch.nn.Module): + def __init__(self, controlnet, conditioning_scale): + super().__init__() + self.controlnet = controlnet + self.conditioning_scale = conditioning_scale # 0.5 + + def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, text_embeds, time_ids): + return self.controlnet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond, conditioning_scale=self.conditioning_scale, + guess_mode=False, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + return_dict=False) + + +def export_control(model, save_path, controlnet_path, conditioning_scale, flag, soc_version: str, + batch_size: int): + print("Exporting the controlnet...") + control_path = os.path.join(save_path, "control") + if not os.path.exists(control_path): + os.makedirs(control_path, mode=0o744) + control_pt_path = os.path.join(control_path, f"control_bs{batch_size}.pt") + control_compiled_static_path = os.path.join(control_path, 
f"control_bs{batch_size}_compile_static.ts") + control_compiled_path = os.path.join(control_path, f"control_bs{batch_size}_compile.ts") + controlnet = ControlNetModel.from_pretrained(controlnet_path) + + # trace + if not os.path.exists(control_pt_path): + dummy_input = ( + torch.ones([2, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.float32), + torch.ones([2, 77, 2048], dtype=torch.float32), + torch.ones([2, 3, 1024, 1024], dtype=torch.float32), + torch.ones([2, 1280], dtype=torch.float32), + torch.ones([2, 6], dtype=torch.float32), + ) + model_export = ControlNetExport(controlnet, conditioning_scale).eval() + torch.jit.trace(model_export, dummy_input).save(control_pt_path) + + # compile + if flag == 0: + if not os.path.exists(control_compiled_static_path): + model = torch.jit.load(control_pt_path).eval() + # 静态shape + inputs = [mindietorch.Input(([2, 4, 128, 128]), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(([2, 77, 2048]), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 3, 1024, 1024], dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 1280], dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 6], dtype=mindietorch.dtype.FLOAT)] + compile_control(model, inputs, control_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(control_compiled_path): + # 动态分档 + model = torch.jit.load(control_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input(([2, 4, 128, 128]), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(([2, 77, 2048]), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 3, 1024, 1024], dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 1280], dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 6], dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input(([2, 4, 64, 64]), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(([2, 77, 2048]), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 3, 512, 512], dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 1280], dtype=mindietorch.dtype.FLOAT), + mindietorch.Input([2, 6], dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear_2) + + compile_control(model, inputs, control_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +def export(model_path: str, controlnet_path: str, vae_path: str, save_dir: str, batch_size: int, + conditioning_scale: float, flag: int, soc: str) -> None: + if soc == "A2": + soc_version = "Ascend910B4" + else: + print("unsupport soc_version, please check!") + return + + controlnet = ControlNetModel.from_pretrained(controlnet_path) + vae = AutoencoderKL.from_pretrained(vae_path) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(model_path, + controlnet=controlnet, vae=vae).to("cpu") + + export_clip(pipeline, save_dir, batch_size, flag, soc_version) + export_vae(pipeline, save_dir, batch_size, flag, vae_path, soc_version) + export_unet(pipeline, save_dir, batch_size * 2, flag, soc_version) + # controlnet功能,只支持800IA2单卡不带unetcache + export_control(pipeline, save_dir, controlnet_path, conditioning_scale, flag, soc_version, batch_size) + + +def main(): + args = parse_arguments() + mindietorch.set_device(args.device) + export(args.model, args.controlnet_model, args.vae_model, args.output_dir, + args.batch_size, args.conditioning_scale, + args.flag, 
args.soc)
+    print("Done.")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/requirements.txt b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ee96f049dd2ea00fcd7255fe050df602386e12c4
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/requirements.txt
@@ -0,0 +1,5 @@
+setuptools==57.5.0
+torch==2.1.0
+diffusers==0.26.3
+transformers==4.26.1
+open_clip_torch==2.20.0
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusion_attention_patch.py b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusion_attention_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c1b369bb23389b6abc12c710f63e5b986b836d
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusion_attention_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.26.3', "diffusers==0.26.3 is required"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusion_clip_patch.py b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusion_clip_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae2a2c4dca8d774982e69323343eb48be008e43
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusion_clip_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
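+
+# Descriptive note: this script applies clip.patch to the installed transformers
+# package (models/clip/modeling_clip.py). The patch replaces the in-place
+# mask.triu_(1) with a numpy-based np.triu when building the causal attention
+# mask, presumably to keep the CLIP text encoder traceable for TorchScript
+# export and mindietorch compilation on NPU.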
+ +import os +import transformers + + +def main(): + transformers_path = transformers.__path__ + transformers_version = transformers.__version__ + + assert transformers_version is not '4.26.1', "expectation transformers==4.26.1" + os.system(f'patch -p0 {transformers_path[0]}/models/clip/modeling_clip.py clip.patch') + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusionxl_pipeline_controlnet.py b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusionxl_pipeline_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..a823cd6d0479b10976a20313d11fc1f40fd9c961 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-ControlNet/stable_diffusionxl_pipeline_controlnet.py @@ -0,0 +1,1478 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import time +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np + +import torch +import mindietorch +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor + +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +import PIL.Image + +from mindietorch import _enums +import torch.nn.functional as F + +from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) + +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) + +from diffusers.utils import load_image +from PIL import Image +import cv2 + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class AIEStableDiffusionXLPipeline(StableDiffusionXLControlNetPipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0 = self.args.device[0] + else: + self.device_0 = args.device + + def compile_aie_model(self): + if self.is_init: + return + + in_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size_2 = self.text_encoder_2.config.hidden_size + encoder_hidden_size = self.text_encoder.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = self.text_encoder.config.max_position_embeddings + + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + + if self.args.flag == 0 or self.args.flag == 1: + tail = "" + if self.args.flag == 0: + tail = "_static" + elif self.args.flag == 1: + tail = "" + else: + print("This operation is not supported!") + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + 
clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + control_compiled_path = os.path.join(self.args.output_dir, f"control/control_bs{size}_compile{tail}.ts") + self.compiled_control_model = torch.jit.load(control_compiled_path).eval() + + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + + else: + print("This operation is not supported!") + + self.is_init = True + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt): + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. \ + Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + + return image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, \ + but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. \ + Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. \ + Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, \ + but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` \ + has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} \ + elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
+ ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, \ + list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. \ + image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device="cpu", dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, \ + but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. \ + Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
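+
+        Example (illustrative only; `pipe` stands for an instance of this pipeline):
+
+        ```py
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = pipe.encode_prompt(
+            prompt="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
+            negative_prompt="low quality, bad quality, sketches",
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=True,
+        )
+        ```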
+ """ + # device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + + text_encoders = ( + [self.compiled_clip_model, self.compiled_clip_model_2] if self.compiled_clip_model is not None + else [self.compiled_clip_model_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + global clip_time + start = time.time() + prompt_embeds_npu = text_encoder(text_input_ids.to(f'npu:{self.device_0}')) + pooled_prompt_embeds = prompt_embeds_npu[0].to('cpu') + clip_time += time.time() - start + if clip_skip is None: + prompt_embeds = prompt_embeds_npu[2][-2].to('cpu') + else: + # "2" because SDXL always indexes from the penultimate layer. 
+ prompt_embeds = prompt_embeds_npu.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds_npu = text_encoder(uncond_input.input_ids.to(f'npu:{self.device_0}')) + negative_pooled_prompt_embeds = negative_prompt_embeds_npu[0].to('cpu') + # We are only ALWAYS interested in the pooled output of the final text encoder + + negative_prompt_embeds = negative_prompt_embeds_npu[2][-2].to('cpu') + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + # prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device="cpu") + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device="cpu") + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, 
num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def ascendie_infer( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. 
`PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = "cpu" + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, # [1, 77, 2048] + negative_prompt_embeds, # [1, 77, 2048] + pooled_prompt_embeds, # [1, 1280] + negative_pooled_prompt_embeds, # [1, 1280] + ) = self.encode_prompt( + prompt, + # "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" + prompt_2, # None + device, + num_images_per_prompt, # 1 + self.do_classifier_free_guidance, # True + negative_prompt, # low quality, bad quality, sketches + negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 6. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels # nv + + latents = self.prepare_latents( # [1,4,64,64] + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) + + # 7.2 Prepare added time ids & embeddings + if isinstance(image, list): + original_size = original_size or image[0].shape[-2:] + else: + original_size = original_size or image.shape[-2:] + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnet_compiled = is_compiled_module(self.controlnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # controlnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + controlnet_added_cond_kwargs = { + "text_embeds": add_text_embeds.chunk(2)[1], + "time_ids": add_time_ids.chunk(2)[1], + } + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples_npu, mid_block_res_sample_npu = self.compiled_control_model( + control_model_input.to(f'npu:{self.device_0}'), + t[None].to(f'npu:{self.device_0}'), controlnet_prompt_embeds.to(f'npu:{self.device_0}'), + image.to(f'npu:{self.device_0}'), + controlnet_added_cond_kwargs["text_embeds"].to(f'npu:{self.device_0}'), + controlnet_added_cond_kwargs["time_ids"].to(f'npu:{self.device_0}')) + + down_block_res_samples_0 = down_block_res_samples_npu[0].to('cpu') + down_block_res_samples_1 = down_block_res_samples_npu[1].to('cpu') + down_block_res_samples_2 = down_block_res_samples_npu[2].to('cpu') + down_block_res_samples_3 = down_block_res_samples_npu[3].to('cpu') + down_block_res_samples_4 = down_block_res_samples_npu[4].to('cpu') + down_block_res_samples_5 = down_block_res_samples_npu[5].to('cpu') + down_block_res_samples_6 = down_block_res_samples_npu[6].to('cpu') + down_block_res_samples_7 = down_block_res_samples_npu[7].to('cpu') + down_block_res_samples_8 = down_block_res_samples_npu[8].to('cpu') + mid_block_res_sample = mid_block_res_sample_npu.to('cpu') + + if guess_mode and self.do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. 
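+ # Concatenating zeros along the batch dimension doubles each residual so it
+ # matches latent_model_input (unconditional half first, then conditional half).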
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + # predict the noise residual + noise_pred = self.compiled_unet_model( + latent_model_input.to(f'npu:{self.device_0}'), # [2, 4, 128, 128] + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), # [1] + prompt_embeds.to(f'npu:{self.device_0}'), # [2, 77, 2048] + down_block_res_samples_0.to(f'npu:{self.device_0}'), # [2,320,128,128] + down_block_res_samples_1.to(f'npu:{self.device_0}'), # [2,320,128,128] + down_block_res_samples_2.to(f'npu:{self.device_0}'), # [2,320,128,128] + down_block_res_samples_3.to(f'npu:{self.device_0}'), # [2,320,64,64] + down_block_res_samples_4.to(f'npu:{self.device_0}'), # [2,640,64,64] + down_block_res_samples_5.to(f'npu:{self.device_0}'), # [2,640,64,64] + down_block_res_samples_6.to(f'npu:{self.device_0}'), # [2,640,32,32] + down_block_res_samples_7.to(f'npu:{self.device_0}'), # [2,1280,32,32] + down_block_res_samples_8.to(f'npu:{self.device_0}'), # [2,1280,32,32] + mid_block_res_sample.to(f'npu:{self.device_0}'), # [2, 1280, 32, 32] + added_cond_kwargs["text_embeds"].to(f'npu:{self.device_0}'), # [2, 1280] + added_cond_kwargs["time_ids"].to(f'npu:{self.device_0}')) # [2, 6] + + noise_pred = noise_pred.to('cpu') + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + latents = latents / self.vae.config.scaling_factor + + start = time.time() + latents = self.vae.post_quant_conv(latents) + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + return (image, None) + + +def check_device_range_valid(value): + # if 
contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "-control", + "--controlnet_model", + type=str, + default="./controlnet-canny-sdxl-1.0", + help="Path or name of the pre-trained controlnet model.", + ) + parser.add_argument( + "-vae", + "--vae_model", + type=str, + default="./sdxl-vae-fp16-fix", + help="Path or name of the pre-trained vae model.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--soc", + choices=["A2"], + default="A2", + help="soc_version.", + ) + parser.add_argument( + "-cond_scale", + "--conditioning_scale", + type=float, + default=0.5, + help="conditioning_scale" + ) + parser.add_argument( + "--flag", + type=int, + default=1, + choices=[0, 1], + help="0 is static; 1 is dynamic rankl.", + ) + parser.add_argument( + "--w_h", + type=int, + default=1024, + choices=[512, 1024], + help="512: witdh=height=512, 1024:width=height=1024", + ) + parser.add_argument( + "--img_path", + type=str, + default="./hf-logo.png", + help="images path", + ) + parser.add_argument( + "--prompt", + type=str, + default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting", + help="prompt", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="low quality, bad quality, sketches", + help="negative_prompt", + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + prompt = args.prompt + negative_prompt = args.negative_prompt + + image = load_image(args.img_path) + + controlnet_conditioning_scale = 0.5 # recommended for good generalization + + controlnet = ControlNetModel.from_pretrained(args.controlnet_model) + vae = AutoencoderKL.from_pretrained(args.vae_model) + pipe = AIEStableDiffusionXLPipeline.from_pretrained(args.model, controlnet=controlnet, vae=vae).to("cpu") + pipe.enable_model_cpu_offload() + + pipe.parser_args(args) + pipe.compile_aie_model() + mindietorch.set_device(args.device) + + image = np.array(image) + image = cv2.Canny(image, 100, 200) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + image = Image.fromarray(image) + + stream = mindietorch.npu.Stream("npu:" + str(args.device)) + height = 1024 + width = 1024 + if args.w_h == 1024: + height = 1024 
+ width = 1024 + elif args.w_h == 512: + height = 512 + width = 512 + else: + print("This operation is not supported!") + + with mindietorch.npu.stream(stream): + images = pipe.ascendie_infer( + prompt, negative_prompt=negative_prompt, height=height, width=width, image=image, + controlnet_conditioning_scale=controlnet_conditioning_scale + ) + images[0][0].save(os.path.join(save_dir, f"hug_lab.png")) + + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/README.md b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/README.md new file mode 100644 index 0000000000000000000000000000000000000000..877367e3cd95588b71f1cf4bad399b1c3346ab09 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/README.md @@ -0,0 +1,196 @@ +# stable-diffusionxl-inpainting模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + + +# 概述 + + stable diffusion xl inpainting,图像重绘。是指对图像进行修改、调整和优化的过程。可以包括对图像的颜色、对比度、亮度、饱和度等进行调整,以及修复图像中的缺陷、删除不需要的元素、添加新的图像内容等操作。主要是通过给定一个想要编辑的区域mask,并在这个区域mask圈定的范围内进行文本生成图像的操作,从而编辑mask区域的图像内容。图像inpainting整体上和图生图流程一致,不过为了保证mask以外的图像区域不发生改变,在去噪过程的每一步,我们利用mask将Latent特征中不需要重建的部分都替换成原图最初的特征,只在mask部分进行特征的重建与优化。 + +- 参考实现: + ```bash + # stable-diffusion-xl-1.0-inpainting-0.1 + https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1 + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + | 配套 | 版本 | 环境准备指导 | + | ------ | ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0 | - | | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + + +# 快速上手 + +## 获取源码 + +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + + ```bash + python3 stable_diffusion_clip_patch.py + ``` + + ```bash + python3 stable_diffusion_attention_patch.py + ``` + + ```bash + # 若使用unetCache + python3 stable_diffusionxl_unet_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + + Inpainting图像重绘。图像编辑是指对图像进行修改、调整和优化的过程。可以包括对图像的颜色、对比度、亮度、饱和度等进行调整,以及修复图像中的缺陷、删除不需要的元素、添加新的图像内容等操作。 + ```bash + # img + wget https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png + + #mask img + wget https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png + ``` + +## 模型推理 + +1. 模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放到代码同级目录下,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # xl + git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + ``` + + 1. 
导出pt模型并进行编译。(可选) + + ```bash + # xl (执行时下载权重) + model_base="stabilityai/stable-diffusion-xl-base-1.0" + + # xl (使用上一步下载的权重) + model_base="./stable-diffusion-xl-base-1.0" + ``` + + 执行命令: + + ```bash + python3 export_ts_inpainting.py --model ${model_base} --output_dir ./models --batch_size 1 --flag 1 --soc A2 --device 0 + ``` + 参数说明: + - --model:模型权重路径 + - --output_dir: ONNX模型输出目录 + - --batch_size: 设置batch_size, 默认值为1,当前仅支持batch_size=1的场景 + - --flag:默认为1。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512。 + - --soc:当前仅支持A2。 + - --device:推理设备ID + - --use_cache: 【可选】在推理过程中使用cache + + 静态编译场景: + + - ./models/clip/clip_bs{batch_size}.pt, ./models/clip/clip_bs{batch_size}_compile.ts 和 ./models/clip/clip2_bs{batch_size}.pt, ./models/clip/clip2_bs{batch_size}_compile.ts + - ./models/unet/unet_bs{batch_size x 2}.pt, ./models/unet/unet_bs{batch_size x 2}_compile_static.ts + - ./models/vae/vae_bs{batch_size}.pt, ./models/vae/vae_bs{batch_size}_compile_static.ts + - ./models/image_encode/image_encode_bs{batch_size}.pt, ./models/image_encode/image_encode_bs{batch_size}_compile_static.ts + + 动态分档场景: + + - ./models/clip/clip_bs{batch_size}.pt, ./models/clip/clip_bs{batch_size}_compile.ts 和 ./models/clip/clip2_bs{batch_size}.pt, ./models/clip/clip2_bs{batch_size}_compile.ts + - ./models/unet/unet_bs{batch_size x 2}.pt, ./models/unet/unet_bs{batch_size x 2}_compile.ts + - ./models/vae/vae_bs{batch_size}.pt, ./models/vae/vae_bs{batch_size}_compile.ts + - ./models/image_encode/image_encode_bs{batch_size}.pt, ./models/image_encode/image_encode_bs{batch_size}_compile.ts + + +2. 开始推理验证。 + + 1. 执行推理脚本。 + ```bash + # 不使用unetCache策略 + python3 stable_diffusionxl_pipeline_inpainting.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --save_dir ./results \ + --steps 30 \ + --device 0 \ + --output_dir ./models \ + --soc A2 \ + --flag 1 \ + --w_h 1024 \ + --strength 0.99 \ + --img_url ./imgs \ + --mask_url ./mask_imgs + + # 使用UnetCache策略 + python3 stable_diffusionxl_pipeline_inpainting.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --save_dir ./results \ + --steps 30 \ + --device 0 \ + --output_dir ./models \ + --soc A2 \ + --flag 1 \ + --w_h 1024 \ + --strength 0.99 \ + --img_url ./imgs \ + --mask_url ./mask_imgs \ + --use_cache + + ``` + + 参数说明: + - --model:模型名称或本地模型目录的路径。 + - --prompt_file:提示词文件。 + - --save_dir:生成图片的存放目录。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + - --output_dir:存放导出模型的目录。 + - --soc:当前仅支持A2。 + - --flag:默认为1。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512。**注意**:请与导出模型时设置的flag保持一致 + - --w_h: image的宽高,设置为1024表示宽高均为1024,设置为512表示宽高均为512。仅支持这两种分辨率。 + - --strength:当w_h=1024时,设置该值为0.99。当w_h=512时,设置该值为0.6。 + - --img_url: img图片路径 + - --mask_url: mask_imgs图片路径 + - --use_cache:【可选】在推理过程中使用cache。 + + 执行完成后在 `./results`目录下生成推理图片。 diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/attention_processor.patch b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..26f296526adcfaf629f3c47a311b88bb4aa002a2 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/attention_processor.patch @@ -0,0 +1,18 @@ +--- attention_processor.py 2024-02-22 19:06:56.596000000 +0800 ++++ attention_processor.py 2024-02-22 19:07:17.232000000 +0800 +@@ -205,10 +205,11 @@ + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only 
if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/clip.patch b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/clip.patch new file mode 100644 index 0000000000000000000000000000000000000000..e3e4719b66f771ebb660f25151c33d140566c3f3 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/clip.patch @@ -0,0 +1,10 @@ +22a23 +> import numpy as np +760c761,762 +< mask.triu_(1) # zero out the lower diagonal +--- +> # mask.triu_(1) # zero out the lower diagonal +> mask = torch.from_numpy(np.triu(mask.numpy(), 1)) +1324a1327 +> + diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/compile_model.py b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/compile_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e881bda79bec649a885ac931e03d4889eab2cf52 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/compile_model.py @@ -0,0 +1,96 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
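+# Thin wrappers around mindietorch.compile: each helper takes an already traced TorchScript module
+# plus its mindietorch.Input specs, compiles it for the target Ascend SoC (FP16 precision policy for
+# the CLIP/UNet/VAE paths, FP32 for the image encoder), and saves the result with torch.jit.save.
+#
+# Typical use (sketch; `soc_version` is whatever string the export script resolves for the target SoC):
+#     model = torch.jit.load("./models/clip/clip_bs1.pt").eval()
+#     inputs = [mindietorch.Input((1, 77), dtype=mindietorch.dtype.INT64)]
+#     compile_clip(model, inputs, "./models/clip/clip_bs1_compile.ts", soc_version)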
+ +import torch +import mindietorch +from mindietorch import _enums + +def compile_clip(model, inputs, clip_compiled_path, soc_version): + compiled_clip_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + min_block_size = 1, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_clip_model, clip_compiled_path) + +def compile_vae(model, inputs, vae_compiled_path, soc_version): + compiled_vae_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_vae_model, vae_compiled_path) + + +def compile_unet_cache(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0, + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_unet_skip(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_unet_init(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_img_encode(model, inputs, image_encode_compiled_path, soc_version): + compiled_image_encode_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP32, + optimization_level=0 + )) + torch.jit.save(compiled_image_encode_model, image_encode_compiled_path) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/export_ts_inpainting.py b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/export_ts_inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..6792e2558fafe363fd49f4919ac3911c7d334718 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/export_ts_inpainting.py @@ -0,0 +1,659 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from argparse import Namespace + +import torch +import mindietorch +import torch.nn as nn +from diffusers import ControlNetModel, StableDiffusionXLInpaintPipeline, AutoencoderKL +from compile_model import compile_clip, compile_vae, compile_unet_cache,\ + compile_unet_skip, compile_unet_init, compile_img_encode + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "--flag", + type=int, + default=1, + choices=[0, 1], + help="0 is static; 1 is dynamic rankl.", + ) + parser.add_argument( + "--soc", + choices=["A2"], + default="A2", + help="soc_version.", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + + return parser.parse_args() + + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x, output_hidden_states=True, return_dict=False): + return self.clip_model(x, output_hidden_states=output_hidden_states, + return_dict=return_dict) + + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + if not os.path.exists(clip2_pt_path): + clip_export = ClipExport(encoder_2_model) + torch.jit.trace(clip_export, dummy_input).save(clip2_pt_path) + + +def export_clip(sd_pipeline: StableDiffusionXLInpaintPipeline, save_dir: str, + batch_size: int, flag: int, soc_version: str) -> None: + print("Exporting the text encoder...") + clip_path = os.path.join(save_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, f"clip2_bs{batch_size}.pt") + clip1_compiled_path = os.path.join(clip_path, f"clip_bs{batch_size}_compile.ts") + clip2_compiled_path = os.path.join(clip_path, f"clip2_bs{batch_size}_compile.ts") + + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path) + + # compile + if flag == 0 or flag == 1: + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), + dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, soc_version) + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = 
[mindietorch.Input((batch_size, max_position_embeddings), + dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class UnetExportInit(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids + ): + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids})[0] + + +def trace_unet_init(sd_pipeline, batch_size, unet_pt_path): + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + max_position_embeddings = encoder_model.config.max_position_embeddings + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32) + ) + + unet = UnetExportInit(unet_model) + unet.eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + +def export_unet_init(sd_pipeline: StableDiffusionXLInpaintPipeline, + save_dir: str, batch_size: int, flag: int, + soc_version: str) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile.ts") + unet_compiled_dynamic_path = os.path.join(unet_path, f"unet_compile_dynamic.ts") + + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_unet_init(sd_pipeline, batch_size, unet_pt_path) + + # compile + if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT)] + compile_unet_init(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = 
[mindietorch.Input((batch_size, in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT) + ] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((batch_size, in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT) + ] + inputs.append(inputs_gear_2) + compile_unet_init(model, inputs, unet_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class UnetExport(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids, + if_skip, + inputCache=None + ): + if if_skip: + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + if_skip=if_skip, inputCache=inputCache)[0] + else: + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + if_skip=if_skip) + + +def export_unet_skip(sd_pipeline: StableDiffusionXLInpaintPipeline, save_dir: str, + batch_size: int, flag: int, + soc_version: str) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_1.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_1_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_1_haveshape.ts") + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, 1280, + math.ceil(sample_size / 2), math.ceil(sample_size / 2)], + dtype=torch.float32), + ) + unet = UnetExport(unet_model) + unet.eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + # compile + if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = 
[mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (batch_size, 1280, math.ceil(sample_size / 2), + math.ceil(sample_size / 2)), + dtype=mindietorch.dtype.FLOAT)] + compile_unet_skip(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input((batch_size, in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (batch_size, 1280, math.ceil(1024 // 8 / 2), + math.ceil(1024 // 8 / 2)), + dtype=mindietorch.dtype.FLOAT), + ] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((batch_size, in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (batch_size, 1280, math.ceil(512 // 8 / 2), + math.ceil(512 // 8 / 2)), + dtype=mindietorch.dtype.FLOAT), + ] + inputs.append(inputs_gear_2) + compile_unet_skip(model, inputs, unet_compiled_path, soc_version) + + else: + print("This operation is not supported!") + + +def export_unet_cache(sd_pipeline: StableDiffusionXLInpaintPipeline, save_dir: str, + batch_size: int, flag: int, + soc_version: str) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_0.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_0_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_0.ts") + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + 
[batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32), + torch.zeros([1], dtype=torch.int64), + ) + unet = UnetExport(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + # compile + if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)] + compile_unet_cache(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input((batch_size, + in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + ] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((batch_size, + in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + ] + inputs.append(inputs_gear_2) + compile_unet_cache(model, inputs, unet_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model): + super().__init__() + self.vae_model = vae_model + + def forward(self, latents): + return self.vae_model.decoder(latents) + + +def export_vae(sd_pipeline: StableDiffusionXLInpaintPipeline, save_dir: str, + batch_size: int, flag: int, + soc_version: str) -> None: + print("Exporting the image decoder...") + + vae_path = os.path.join(save_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_static_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile_static.ts") + vae_compiled_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile.ts") + + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model) + torch.jit.trace(vae_export, 
dummy_input).save(vae_pt_path)
+    # compile
+    if flag == 0:
+        if not os.path.exists(vae_compiled_static_path):
+            model = torch.jit.load(vae_pt_path).eval()
+            # static shape
+            inputs = [
+                mindietorch.Input((batch_size, in_channels, sample_size, sample_size),
+                                  dtype=mindietorch.dtype.FLOAT)]
+            compile_vae(model, inputs, vae_compiled_static_path, soc_version)
+    elif flag == 1:
+        if not os.path.exists(vae_compiled_path):
+            # dynamic gears
+            model = torch.jit.load(vae_pt_path).eval()
+            inputs = []
+            inputs_gear_1 = [mindietorch.Input((batch_size, in_channels,
+                                                1024 // 8, 1024 // 8),
+                                               dtype=mindietorch.dtype.FLOAT)]
+            inputs.append(inputs_gear_1)
+            inputs_gear_2 = [mindietorch.Input((batch_size, in_channels,
+                                                512 // 8, 512 // 8),
+                                               dtype=mindietorch.dtype.FLOAT)]
+            inputs.append(inputs_gear_2)
+
+            compile_vae(model, inputs, vae_compiled_path, soc_version)
+    else:
+        print("This operation is not supported!")
+
+
+class ImageEncodeExport(torch.nn.Module):
+    def __init__(self, vae):
+        super().__init__()
+        self.tile_sample_min_size = vae.tile_sample_min_size
+        self.use_tiling = vae.use_tiling
+        self.encoder = vae.encoder
+        self.quant_conv = vae.quant_conv
+        self.tiled_encode = vae.tiled_encode
+        self.use_slicing = vae.use_slicing
+
+    def forward(self, x, return_dict=True):
+        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+            return self.tiled_encode(x, return_dict=return_dict)
+
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+
+        moments = self.quant_conv(h)
+        return moments
+
+
+def export_image_encode(sd_pipeline: StableDiffusionXLInpaintPipeline, save_dir: str,
+                        batch_size: int, flag: int,
+                        soc_version: str) -> None:
+    print("Exporting the image encoder...")
+
+    img_encode_path = os.path.join(save_dir, "image_encode")
+    if not os.path.exists(img_encode_path):
+        os.makedirs(img_encode_path, mode=0o640)
+    img_encode_pt_path = os.path.join(img_encode_path, f"image_encode_bs{batch_size}.pt")
+    img_encode_compiled_static_path = os.path.join(img_encode_path,
+                                                   f"image_encode_bs{batch_size}_compile_static.ts")
+    img_encode_compiled_path = os.path.join(img_encode_path, f"image_encode_bs{batch_size}_compile.ts")
+
+    # trace
+    vae_model = sd_pipeline.vae
+    if not os.path.exists(img_encode_pt_path):
+        dummy_input = torch.ones([1, 3, 1024, 1024], dtype=torch.float32)
+        vae_export = ImageEncodeExport(vae_model)
+        torch.jit.trace(vae_export, dummy_input).save(img_encode_pt_path)
+
+    # compile
+    if flag == 0:
+        if not os.path.exists(img_encode_compiled_static_path):
+            model = torch.jit.load(img_encode_pt_path).eval()
+            # static shape
+            inputs = [mindietorch.Input((1, 3, 1024, 1024), dtype=mindietorch.dtype.FLOAT)]
+            compile_img_encode(model, inputs, img_encode_compiled_static_path, soc_version)
+    elif flag == 1:
+        if not os.path.exists(img_encode_compiled_path):
+            # dynamic gears
+            model = torch.jit.load(img_encode_pt_path).eval()
+            inputs = []
+            inputs_gear_1 = [mindietorch.Input((1, 3, 1024, 1024), dtype=mindietorch.dtype.FLOAT)]
+            inputs.append(inputs_gear_1)
+            inputs_gear_2 = [mindietorch.Input((1, 3, 512, 512), dtype=mindietorch.dtype.FLOAT)]
+            inputs.append(inputs_gear_2)
+
+            compile_img_encode(model, inputs, img_encode_compiled_path, soc_version)
+    else:
+        print("This operation is not supported!")
+
+
+def export(model_path: str, save_dir: str, batch_size: int, flag: int, soc: str, use_cache: bool) -> None:
+    if soc == "A2":
+        soc_version = "Ascend910B4"
+    else:
+        print("unsupported soc_version, please check!")
+        return
+
+    pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(model_path).to("cpu")
+
+    export_clip(pipeline, save_dir, batch_size, flag, soc_version)
+    export_vae(pipeline, save_dir, batch_size, flag, soc_version)
+    export_image_encode(pipeline, save_dir, batch_size, flag, soc_version)
+
+    if use_cache:
+        # single-device export with UnetCache enabled
+        export_unet_cache(pipeline, save_dir, batch_size * 2, flag, soc_version)
+        export_unet_skip(pipeline, save_dir, batch_size * 2, flag, soc_version)
+    else:
+        # single-device export without UnetCache
+        export_unet_init(pipeline, save_dir, batch_size * 2, flag, soc_version)
+
+
+def main():
+    args = parse_arguments()
+    mindietorch.set_device(args.device)
+    export(args.model, args.output_dir, args.batch_size, args.flag, args.soc, args.use_cache)
+    print("Done.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/prompts.txt b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f919b6b0b51b0a71e7f557e38032f535c81de8d8
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/prompts.txt
@@ -0,0 +1 @@
+a tiger sitting on a park bench
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/requirements.txt b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ee96f049dd2ea00fcd7255fe050df602386e12c4
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/requirements.txt
@@ -0,0 +1,5 @@
+setuptools==57.5.0
+torch==2.1.0
+diffusers==0.26.3
+transformers==4.26.1
+open_clip_torch==2.20.0
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusion_attention_patch.py b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusion_attention_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c1b369bb23389b6abc12c710f63e5b986b836d
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusion_attention_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.26.3', "expected diffusers==0.26.3"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusion_clip_patch.py b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusion_clip_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae2a2c4dca8d774982e69323343eb48be008e43
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusion_clip_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import transformers
+
+
+def main():
+    transformers_path = transformers.__path__
+    transformers_version = transformers.__version__
+
+    assert transformers_version == '4.26.1', "expected transformers==4.26.1"
+    os.system(f'patch -p0 {transformers_path[0]}/models/clip/modeling_clip.py clip.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusionxl_pipeline_inpainting.py b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusionxl_pipeline_inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..926fff53f855114253bfefe165f0321ca2c6d6e1
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusionxl_pipeline_inpainting.py
@@ -0,0 +1,1543 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
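+
+# NOTE: this module adapts diffusers' StableDiffusionXLInpaintPipeline to MindIE/NPU
+# inference. AIEStableDiffusionXLInpaintPipeline loads the pre-compiled TorchScript
+# modules (clip/clip2, unet or its cache/skip variants, vae and image_encode *.ts files
+# found under --output_dir) and runs them on the NPU selected by --device, while
+# PromptLoader batches prompts read from a plain text or parti-style tsv file.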
+
+import argparse
+import csv
+import inspect
+import json
+import os
+import time
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import mindietorch
+from diffusers import StableDiffusionXLInpaintPipeline
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionXLLoraLoaderMixin, \
+    TextualInversionLoaderMixin
+from diffusers.models.lora import adjust_lora_scale_text_encoder  # used in encode_prompt without the PEFT backend
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, is_invisible_watermark_available, is_torch_xla_available, \
+    logging, \
+    replace_example_docstring, scale_lora_layers, unscale_lora_layers
+
+from diffusers.utils import load_image
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+
+from mindietorch import _enums
+
+if is_invisible_watermark_available():
+    # absolute import: this file is not part of the diffusers package
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)
+
+clip_time = 0
+unet_time = 0
+vae_time = 0
+p1_time = 0
+p2_time = 0
+p3_time = 0
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusionXLInpaintPipeline(StableDiffusionXLInpaintPipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0 = self.args.device[0] + else: + self.device_0 = args.device + + def compile_aie_model(self): + if self.is_init: + return + + in_channels = self.unet.config.in_channels + out_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size_2 = self.text_encoder_2.config.hidden_size + encoder_hidden_size = self.text_encoder.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = self.text_encoder.config.max_position_embeddings + + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + + if self.args.flag == 0 or self.args.flag == 1: + if self.args.flag == 0: + tail = "_static" + elif self.args.flag == 1: + tail = "" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = 
torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + img_encode_compiled_path = os.path.join(self.args.output_dir, + f"image_encode/image_encode_bs{size}_compile{tail}.ts") + self.compiled_image_encode_model = torch.jit.load(img_encode_compiled_path).eval() + + if not self.args.use_cache: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + if self.args.use_cache: + unet_skip_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_compile_1{tail}.ts") + self.compiled_unet_model_skip = torch.jit.load(unet_skip_compiled_path).eval() + + unet_cache_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_compile_0{tail}.ts") + self.compiled_unet_model_cache = torch.jit.load(unet_cache_compiled_path).eval() + else: + print("This operation is not supported!") + + self.is_init = True + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + + text_encoders = ( + [self.compiled_clip_model, self.compiled_clip_model_2] if self.compiled_clip_model is not None + else [self.compiled_clip_model_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + global clip_time + start = time.time() + prompt_embeds_npu = text_encoder(text_input_ids.to(f'npu:{self.device_0}')) + pooled_prompt_embeds = prompt_embeds_npu[0].to('cpu') + clip_time += 
time.time() - start + if clip_skip is None: + prompt_embeds = prompt_embeds_npu[2][-2].to('cpu') + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds_npu.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(f'npu:{self.device_0}'))[0].to('cpu') + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_prompt_embeds = [torch.from_numpy(text) for text in negative_prompt_embeds] + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, 
device="cpu") + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device="cpu") + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + add_noise=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if image.shape[1] == 4: + image_latents = image.to(device=device, dtype=dtype) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + elif return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + start = time.time() + image_latents = self._encode_vae_image(image=image, generator=generator) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + if latents is None and add_noise: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. 
then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + elif add_noise: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = image_latents.to(device) + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + dtype = image.dtype + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list): + + image_latents = [ + retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + + if image.is_cpu: + image = image.contiguous().to(f'npu:{self.device_0}') + + moments = self.compiled_image_encode_model(image).to('cpu') + image_latents = DiagonalGaussianDistribution(moments).sample(generator) + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + image_latents = image_latents.to(dtype) + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + + if masked_image is not None and masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + masked_image_latents = None + + if masked_image is not None: + if masked_image_latents is None: + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @torch.no_grad() + def ascendie_infer( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image=None, + mask_image=None, + masked_image_latents: torch.FloatTensor = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.9999, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_start: Optional[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 8.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + skip_steps=None, + flag_cache=0, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. 
This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If + `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and + contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on + the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large + and contain information inreleant for inpainging, such as background. + strength (`float`, *optional*, defaults to 0.9999): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. Note that in the case of `denoising_start` being declared as an + integer, the value of `strength` will be ignored. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_start (`float`, *optional*): + When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be + bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and + it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, + strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline + is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. 
As a result, the returned sample will + still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be + denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the + final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline + forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output). + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. 
+ output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. 
+ Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + start = time.time() + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + prompt_2, + image, + mask_image, + height, + width, + strength, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._denoising_start = denoising_start + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. 
Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. set timesteps + def denoising_value_valid(dnv): + return isinstance(self.denoising_end, float) and (0 < dnv < 1) + + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, + strength, + device, + denoising_start=self.denoising_start if denoising_value_valid else None, + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is not None: + masked_image = masked_image_latents + elif init_image.shape[1] == 4: + # if images are in latent space, we can't mask it + masked_image = None + else: + masked_image = init_image * (mask < 0.5) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + add_noise = True if self.denoising_start is None else False + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + add_noise=add_noise, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. 
Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + # 8.1 Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Prepare extra step kwargs. + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 10. Prepare added time ids & embeddings + if negative_original_size is None: + negative_original_size = original_size + if negative_target_size is None: + negative_target_size = target_size + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + if ip_adapter_image is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, device, batch_size * num_images_per_prompt + ) + + # 11. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + if ( + self.denoising_end is not None + and self.denoising_start is not None + and denoising_value_valid(self.denoising_end) + and denoising_value_valid(self.denoising_start) + and self.denoising_start >= self.denoising_end + ): + raise ValueError( + f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: " + + f" {self.denoising_end} when using type float." + ) + elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 11.1 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + skip_flag = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + global unet_time + global vae_time + with self.progress_bar(total=num_inference_steps) as progress_bar: + + for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + latent_model_input = latent_model_input.to(f'npu:{self.device_0}') + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + start = time.time() + if flag_cache: + if skip_steps[i]: + noise_pred = self.compiled_unet_model_skip(latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + skip_flag.to(f'npu:{self.device_0}'), + cache, ).to('cpu') + else: + outputs = self.compiled_unet_model_cache(latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + cache_flag.to(f'npu:{self.device_0}'), + ) + noise_pred = outputs[0].to('cpu') + cache = outputs[1] + else: + noise_pred = self.compiled_unet_model(latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}')).to('cpu') + unet_time += time.time() - start + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * 
(noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents + if self.do_classifier_free_guidance: + init_mask, _ = mask.chunk(2) + else: + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) + mask = callback_outputs.pop("mask", mask) + masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + latents = latents / self.vae.config.scaling_factor + + start = time.time() + latents = self.vae.post_quant_conv(latents) + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + vae_time += time.time() - start + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] + + # Offload all models + self.maybe_free_model_hooks() + + return (image, None) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise 
argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=30, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--soc", + choices=["A2"], + default="A2", + help="soc_version.", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--cache_steps", + type=str, + default="1,2,4,6,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\ + 30,31,33,34,36,37,39,40,42,43,45,47,48,49", # 17+33 + help="Steps to use cache data." 
+ ) + parser.add_argument( + "--flag", + type=int, + default=1, + choices=[0, 1], + help="0 is static; 1 is dynamic rankl.", + ) + parser.add_argument( + "--w_h", + type=int, + default=512, + choices=[512, 1024], + help="0 is static; 1 is dynamic rankl.", + ) + parser.add_argument( + "--strength", + type=float, + default=0.6, + choices=[0.6, 0.99], + help="512 use 0.6, 1024 use 0.99", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=8.0, + help="guidance_scale", + ) + parser.add_argument( + "--img_url", + type=str, + default="./imgs", + help="img_url", + ) + parser.add_argument( + "--mask_url", + type=str, + default="./mask_imgs", + help="mask_url", + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusionXLInpaintPipeline.from_pretrained(args.model).to("cpu") + + pipe.parser_args(args) + pipe.compile_aie_model() + mindietorch.set_device(args.device) + skip_steps = [0] * args.steps + flag_cache = 0 + if args.use_cache: + flag_cache = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt, + args.max_num_prompts) + + height = 1024 + width = 1024 + if args.w_h == 1024: + height = 1024 + width = 1024 + elif args.w_h == 512: + height = 512 + width = 512 + else: + print("This operation is not supported!") + + prompts_2 = "" + infer_num = 0 + image_info = [] + + img_urls = [os.path.join(args.img_url, filename) for filename in os.listdir(args.img_url)] + mask_urls = [os.path.join(args.mask_url, filename) for filename in os.listdir(args.mask_url)] + + current_prompt = None + + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + + image = load_image(img_urls[i]).resize((height, width)) + mask_image = load_image(mask_urls[i]).resize((height, width)) + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + stream = mindietorch.npu.Stream("npu:" + str(args.device)) + start_time = time.time() + with mindietorch.npu.stream(stream): + + images = pipe.ascendie_infer( + prompt=prompts, + image=image, + mask_image=mask_image, + num_inference_steps=args.steps, + guidance_scale=args.guidance_scale, + strength=args.strength, + skip_steps=skip_steps, + flag_cache=flag_cache, + width=width, + height=height, + ) + + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n" + f"clip time: {clip_time:.3f}s\n" + f"average clip time: {clip_time / infer_num:.3f}s\n" + f"unet time: {unet_time:.3f}s\n" + f"average unet time: {unet_time / infer_num:.3f}s\n" + f"vae time: {vae_time:.3f}s\n" + f"average vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + 
          f"p3 time: {p3_time / infer_num:.3f}s\n")
+
+    # Save image information to a json file
+    if os.path.exists(args.info_file_save_path):
+        os.remove(args.info_file_save_path)
+
+    with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f:
+        json.dump(image_info, f)
+    mindietorch.finalize()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusionxl_unet_patch.py b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusionxl_unet_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..112dd0c88e73617f4f239235afc3a704f9b9052f
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/stable_diffusionxl_unet_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.26.3', "expected diffusers==0.26.3"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_condition.py unet_2d_condition.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-Inpainting/unet_2d_condition.patch b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/unet_2d_condition.patch
new file mode 100644
index 0000000000000000000000000000000000000000..88c9cb09510b12e4bf796c7ff3ccaf712b411b28
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-Inpainting/unet_2d_condition.patch
@@ -0,0 +1,244 @@
+--- ../mindie-sdxl/unet_2d_condition_bak.py	2024-04-17 07:41:03.000000000 +0000
++++ ../mindie-sdxl/unet_2d_condition.py	2024-04-17 07:41:05.000000000 +0000
+@@ -855,6 +855,8 @@
+         down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
++        if_skip: int = 0,
++        inputCache: torch.FloatTensor = None,
+     ) -> Union[UNet2DConditionOutput, Tuple]:
+         r"""
+         The [`UNet2DConditionModel`] forward method.
+@@ -1110,29 +1112,56 @@ + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + +- down_block_res_samples = (sample,) +- for downsample_block in self.down_blocks: +- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: +- # For t2i-adapter CrossAttnDownBlock2D +- additional_residuals = {} +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) +- +- sample, res_samples = downsample_block( +- hidden_states=sample, +- temb=emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- **additional_residuals, +- ) +- else: +- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- sample += down_intrablock_additional_residuals.pop(0) +- +- down_block_res_samples += res_samples ++ if not if_skip: ++ down_block_res_samples = (sample,) ++ for downsample_block in self.down_blocks: ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples ++ else: ++ down_block_res_samples = (sample,) ++ for tmp, downsample_block in enumerate(self.down_blocks): ++ if tmp >= 2: ++ break ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () +@@ -1146,61 +1175,93 @@ + down_block_res_samples = new_down_block_res_samples + + # 4. 
mid +- if self.mid_block is not None: +- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: +- sample = self.mid_block( +- sample, +- emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = self.mid_block(sample, emb) +- +- # To support T2I-Adapter-XL +- if ( +- is_adapter +- and len(down_intrablock_additional_residuals) > 0 +- and sample.shape == down_intrablock_additional_residuals[0].shape +- ): +- sample += down_intrablock_additional_residuals.pop(0) ++ if not if_skip: ++ if self.mid_block is not None: ++ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: ++ sample = self.mid_block( ++ sample, ++ emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = self.mid_block(sample, emb) ++ ++ # To support T2I-Adapter-XL ++ if ( ++ is_adapter ++ and len(down_intrablock_additional_residuals) > 0 ++ and sample.shape == down_intrablock_additional_residuals[0].shape ++ ): ++ sample += down_intrablock_additional_residuals.pop(0) ++ else: ++ sample = inputCache + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up +- for i, upsample_block in enumerate(self.up_blocks): +- is_final_block = i == len(self.up_blocks) - 1 +- +- res_samples = down_block_res_samples[-len(upsample_block.resnets) :] +- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] +- +- # if we have not reached the final block and need to forward the +- # upsample size, we do it here +- if not is_final_block and forward_upsample_size: +- upsample_size = down_block_res_samples[-1].shape[2:] +- +- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- encoder_hidden_states=encoder_hidden_states, +- cross_attention_kwargs=cross_attention_kwargs, +- upsample_size=upsample_size, +- attention_mask=attention_mask, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- upsample_size=upsample_size, +- scale=lora_scale, +- ) ++ if not if_skip: ++ for i, upsample_block in enumerate(self.up_blocks): ++ is_final_block = i == len(self.up_blocks) - 1 ++ ++ res_samples = down_block_res_samples[-len(upsample_block.resnets) :] ++ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] ++ ++ # if we have not reached the final block and need to forward the ++ # upsample size, we do it here ++ if not is_final_block and forward_upsample_size: ++ upsample_size = down_block_res_samples[-1].shape[2:] ++ ++ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ 
upsample_size=upsample_size, ++ scale=lora_scale, ++ ) ++ ++ if (not if_skip) and (i == 0): ++ inputCache = sample ++ ++ else: ++ ++ for i, upsample_block in enumerate(self.up_blocks): ++ if i==1: ++ res_samples = down_block_res_samples[-4:-1] ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ if i==2: ++ res_samples = down_block_res_samples[:3] ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ upsample_size=upsample_size, ++ scale=lora_scale, ++ ) + + # 6. post-process + if self.conv_norm_out: +@@ -1215,4 +1276,7 @@ + if not return_dict: + return (sample,) + +- return UNet2DConditionOutput(sample=sample) ++ if (not if_skip): ++ return (sample, inputCache) ++ else: ++ return UNet2DConditionOutput(sample=sample) diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/README.md b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e67173076f3b0029550c0379911f909dddaf88ad --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/README.md @@ -0,0 +1,179 @@ +# stable-diffusionxl-prompt-weighting模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + + +# 概述 + + stable-diffusionxl-prompt-weighting描述增强(类似++的操作):通过“提示权重(prompt weighting)”来精细调控模型对输入文本提示中不同概念的关注程度,从而影响最终生成图像的内容和焦点。 + +- 参考实现: + ```bash + # stable-diffusionxl-prompt-weighting + https://huggingface.co/docs/diffusers/using-diffusers/weighted_prompts#stable-diffusion-xl + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1 +Atlas 300I Duo推理卡:支持的卡数为1 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + | 配套 | 版本 | 环境准备指导 | + | ------ | ------- | ------------ | + | Python | 3.10.13 | - | + | torch | 2.1.0 | - | +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + + +# 快速上手 + +## 获取源码 + +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + + ```bash + python3 stable_diffusion_clip_patch.py + ``` + + ```bash + python3 stable_diffusion_attention_patch.py + ``` + + ```bash + # 若使用unetCache + python3 stable_diffusionxl_unet_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + + 本模型输入文本信息生成图片,无需数据集。 + + +## 模型推理 + +1. 模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放到代码同级目录下,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # xl + git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + ``` + + 1. 
导出pt模型并进行编译。(可选) + + ```bash + # xl (执行时下载权重) + model_base="stabilityai/stable-diffusion-xl-base-1.0" + + # xl (使用上一步下载的权重) + model_base="./stable-diffusion-xl-base-1.0" + ``` + + 执行命令: + + ```bash + python3 export_ts_prompt_weight.py --model ${model_base} --output_dir ./models --batch_size 1 --flag 1 --soc A2 --device 0 + + ``` + 参数说明: + - --model:模型权重路径 + - --output_dir: ONNX模型输出目录 + - --batch_size: 设置batch_size, 默认值为1,当前仅支持batch_size=1的场景 + - --flag:默认为1。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512。 + - --soc:只支持Duo和A2。默认为A2。 + - --device:推理设备ID + - --use_cache: 【可选】在推理过程中使用cache + + 静态编译场景: + + - ./models/clip/clip_bs{batch_size}.pt, ./models/clip/clip_bs{batch_size}_compile.ts 和 ./models/clip/clip2_bs{batch_size}.pt, ./models/clip/clip2_bs{batch_size}_compile.ts + - ./models/unet/unet_bs{batch_size x 2}.pt, ./models/unet/unet_bs{batch_size x 2}_compile_static.ts + - ./models/vae/vae_bs{batch_size}.pt, ./models/vae/vae_bs{batch_size}_compile_static.ts + - ./models/ddim/ddim_bs{batch_size}.pt, ./models/ddim/ddim_bs{batch_size}_compile_static.ts + + 动态分档场景: + + - ./models/clip/clip_bs{batch_size}.pt, ./models/clip/clip_bs{batch_size}_compile.ts 和 ./models/clip/clip2_bs{batch_size}.pt, ./models/clip/clip2_bs{batch_size}_compile.ts + - ./models/unet/unet_bs{batch_size x 2}.pt, ./models/unet/unet_bs{batch_size x 2}_compile.ts + - ./models/vae/vae_bs{batch_size}.pt, ./models/vae/vae_bs{batch_size}_compile.ts + - ./models/ddim/ddim_bs{batch_size}.pt, ./models/ddim/ddim_bs{batch_size}_compile.ts + + + +2. 开始推理验证。 + + 1. 执行推理脚本。 + ```bash + # 不使用unetCache策略 + python3 stable_diffusionxl_pipeline_prompt_weight.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --output_dir ./models \ + --flag 1 \ + --w_h 1024 + + # 使用UnetCache策略 + python3 stable_diffusionxl_pipeline_prompt_weight.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results_unetCache \ + --steps 50 \ + --output_dir ./models \ + --flag 1 \ + --w_h 1024 \ + --use_cache + ``` + + 参数说明: + - --model: 模型名称或本地模型目录的路径。 + - --prompt_file: 提示词文件。 + - --device: 推理设备ID。 + - --save_dir: 生成图片的存放目录。 + - --steps: 生成图片迭代次数。 + - --output_dir: 存放导出模型的目录。 + - --flag:默认为1。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512。**注意**:请与导出模型时设置的flag保持一致 + - --w_h: image的宽高,设置为1024表示宽高均为1024,设置为512表示宽高均为512。仅支持这两种分辨率。 + - --use_cache: 【可选】在推理过程中使用cache。 + + 执行完成后在 `./results`目录下生成推理图片。 \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/attention_processor.patch b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..26f296526adcfaf629f3c47a311b88bb4aa002a2 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/attention_processor.patch @@ -0,0 +1,18 @@ +--- attention_processor.py 2024-02-22 19:06:56.596000000 +0800 ++++ attention_processor.py 2024-02-22 19:07:17.232000000 +0800 +@@ -205,10 +205,11 @@ + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/clip.patch b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/clip.patch new file mode 100644 index 0000000000000000000000000000000000000000..e3e4719b66f771ebb660f25151c33d140566c3f3 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/clip.patch @@ -0,0 +1,10 @@ +22a23 +> import numpy as np +760c761,762 +< mask.triu_(1) # zero out the lower diagonal +--- +> # mask.triu_(1) # zero out the lower diagonal +> mask = torch.from_numpy(np.triu(mask.numpy(), 1)) +1324a1327 +> + diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/compile_model.py b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/compile_model.py new file mode 100644 index 0000000000000000000000000000000000000000..374198f162565722bdbd43c7841f153ada3c948d --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/compile_model.py @@ -0,0 +1,94 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
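+
+# compile_model.py collects thin wrappers around mindietorch.compile, one per
+# submodel. Every helper compiles the traced TorchScript module in FP16
+# (PrecisionPolicy.FP16, optimization_level=0) for the given soc_version and
+# saves the result with torch.jit.save; the CLIP and UNet-cache wrappers allow
+# partial compilation (require_full_compilation=False), while the VAE, DDIM
+# and remaining UNet wrappers require full graph compilation.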
+ +import torch +import mindietorch +from mindietorch import _enums + +def compile_clip(model, inputs, clip_compiled_path, soc_version): + compiled_clip_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + min_block_size = 1, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_clip_model, clip_compiled_path) + +def compile_vae(model, inputs, vae_compiled_path, soc_version): + compiled_vae_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_vae_model, vae_compiled_path) + +def compile_ddim(model, inputs, scheduler_compiled_path, soc_version): + compiled_scheduler = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=False, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(compiled_scheduler, scheduler_compiled_path) + +def compile_unet_cache(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0, + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_unet_skip(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_unet_init(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(compiled_unet_model, unet_compiled_path) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/export_ts_prompt_weight.py b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/export_ts_prompt_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..7c82b3090b18d97c5250689a1fcc6b6b98a29070 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/export_ts_prompt_weight.py @@ -0,0 +1,709 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
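+
+# export_ts_prompt_weight.py traces the prompt-weighting pipeline's submodels
+# (the two CLIP text encoders, the UNet init/cache/skip variants, a scheduler
+# step that fuses classifier-free guidance with the DDIM update, among others)
+# and compiles them with the helpers in compile_model.py. --flag 0 builds
+# static shapes from the model's sample_size; --flag 1 builds the 1024x1024
+# and 512x512 gears.
+# Typical invocation, as documented in the README (paths are examples):
+#   python3 export_ts_prompt_weight.py --model ./stable-diffusion-xl-base-1.0 \
+#       --output_dir ./models --batch_size 1 --flag 1 --soc A2 --device 0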
+ +import os +import argparse +from argparse import Namespace + +import torch +import torch.nn as nn +from diffusers import DDIMScheduler +from diffusers import StableDiffusionXLPipeline +import mindietorch +from mindietorch import _enums +import math +from compile_model import compile_clip, compile_vae, compile_ddim,\ + compile_unet_cache, compile_unet_skip, compile_unet_init + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("-steps", "--steps", type=int, default=50, help="steps.") + parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") + parser.add_argument("--soc", choices=["A2"], default="A2", help="soc_version.", ) + parser.add_argument( + "--flag", + type=int, + default=1, + choices=[0, 1], + help="0 is static; 1 is dynamic rankl.", + ) + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + + return parser.parse_args() + + +class NewScheduler(torch.nn.Module): + def __init__(self, num_train_timesteps=1000, num_inference_steps=50, alphas_cumprod=None, + guidance_scale=5.0, alpha_prod_t_prev_cache=None): + super(NewScheduler, self).__init__() + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.alphas_cumprod = alphas_cumprod + self.guidance_scale = guidance_scale + self.alpha_prod_t_prev_cache = alpha_prod_t_prev_cache + + def forward(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, step_index: int): + divide_batch = (model_output.shape[0]) // 2 + noise_pred_uncond = model_output[:divide_batch, ..., ..., ...] + noise_pred_text = model_output[divide_batch:, ..., ..., ...] 
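+        # The traced scheduler step fuses classifier-free guidance with the DDIM
+        # update: the incoming batch stacks the unconditional and text-conditional
+        # noise predictions, which are combined with guidance_scale below, and
+        # alpha_prod_t_prev is looked up from the cache precomputed in trace_ddim
+        # (indexed by step_index) rather than derived from prev_timestep here.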
+ model_output = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alpha_prod_t_prev_cache[step_index] + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + return prev_sample + + +def trace_ddim(sd_pipeline, steps, guidance_scale, batch_size, ddim_pt_path): + if not os.path.exists(ddim_pt_path): + dummy_input = ( + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size // 2, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(steps, device="cpu") + + timesteps = scheduler.timesteps[:steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = NewScheduler( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + + new_ddim.eval() + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + + +def export_ddim(sd_pipeline: StableDiffusionXLPipeline, save_dir: str, steps: int, guidance_scale: float, + batch_size: int, flag: int) -> None: + print("Exporting the ddim...") + ddim_path = os.path.join(save_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o744) + + ddim_pt_path = os.path.join(ddim_path, f"ddim_bs{batch_size}.pt") + scheduler_compiled_static_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_compile_static.ts") + scheduler_compiled_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_compile.ts") + + unet_model = sd_pipeline.unet + ddim_model = sd_pipeline.scheduler + sample_size = unet_model.config.sample_size + + in_channels = 4 + # trace + trace_ddim(sd_pipeline, steps, guidance_scale, batch_size, ddim_pt_path) + + # compile + if flag == 0: + if not os.path.exists(scheduler_compiled_static_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [mindietorch.Input((batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, + in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)] + compile_ddim(model, inputs, scheduler_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(ddim_pt_path).eval() + # 动态分档 + inputs = [] + inputs_gear_1 = [mindietorch.Input((batch_size, in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, + in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), 
+ dtype=mindietorch.dtype.INT64)] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((batch_size, in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, + in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)] + inputs.append(inputs_gear_2) + compile_ddim(model, inputs, scheduler_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x, output_hidden_states=True, return_dict=False): + return self.clip_model(x, output_hidden_states=output_hidden_states, return_dict=return_dict) + + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + if not os.path.exists(clip2_pt_path): + clip_export = ClipExport(encoder_2_model) + torch.jit.trace(clip_export, dummy_input).save(clip2_pt_path) + + +def export_clip(sd_pipeline: StableDiffusionXLPipeline, save_dir: str, batch_size: int, flag: int) -> None: + print("Exporting the text encoder...") + clip_path = os.path.join(save_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, f"clip2_bs{batch_size}.pt") + clip1_compiled_path = os.path.join(clip_path, f"clip_bs{batch_size}_compile.ts") + clip2_compiled_path = os.path.join(clip_path, f"clip2_bs{batch_size}_compile.ts") + + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + + max_position_embeddings = encoder_model.config.max_position_embeddings + print(f'max_position_embeddings: {max_position_embeddings}') + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path) + + # compile + if flag == 0 or flag == 1: + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, soc_version) + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class UnetExportInit(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids + ): + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids})[0] + + +def trace_unet_init(sd_pipeline, batch_size, unet_pt_path): + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + sample_size = 
unet_model.config.sample_size + in_channels = unet_model.config.in_channels + max_position_embeddings = encoder_model.config.max_position_embeddings + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32) + ) + + unet = UnetExportInit(unet_model) + unet.eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + +def export_unet_init(sd_pipeline: StableDiffusionXLPipeline, save_dir: str, batch_size: int, flag: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + compile_batch_size = batch_size + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{compile_batch_size}_compile_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{compile_batch_size}_compile.ts") + + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_unet_init(sd_pipeline, batch_size, unet_pt_path) + + # compile + if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [mindietorch.Input((compile_batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT)] + compile_unet_init(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input((compile_batch_size, in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT) + ] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((compile_batch_size, in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + 
dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT) + ] + inputs.append(inputs_gear_2) + compile_unet_init(model, inputs, unet_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class UnetExport(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids, + if_skip, + inputCache=None + ): + if if_skip: + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + if_skip=if_skip, inputCache=inputCache)[0] + else: + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + if_skip=if_skip) + + +def export_unet_skip(sd_pipeline: StableDiffusionXLPipeline, save_dir: str, batch_size: int, flag: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + compile_batch_size = batch_size + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_1.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{compile_batch_size}_compile_1_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{compile_batch_size}_compile_1_haveshape.ts") + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, 1280, math.ceil(sample_size / 2), math.ceil(sample_size / 2)], dtype=torch.float32), + ) + unet = UnetExport(unet_model) + unet.eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + # compile + if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [mindietorch.Input((compile_batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (compile_batch_size, 1280, math.ceil(sample_size / 2), + math.ceil(sample_size / 2)), + dtype=mindietorch.dtype.FLOAT)] + compile_unet_skip(model, inputs, 
unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input((compile_batch_size, in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (compile_batch_size, 1280, math.ceil(1024 // 8 / 2), + math.ceil(1024 // 8 / 2)), + dtype=mindietorch.dtype.FLOAT), + ] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((compile_batch_size, in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input( + (compile_batch_size, 1280, math.ceil(512 // 8 / 2), + math.ceil(512 // 8 / 2)), + dtype=mindietorch.dtype.FLOAT), + ] + inputs.append(inputs_gear_2) + compile_unet_skip(model, inputs, unet_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +def export_unet_cache(sd_pipeline: StableDiffusionXLPipeline, save_dir: str, + batch_size: int, flag: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_0.pt") + compile_batch_size = batch_size + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{compile_batch_size}_compile_0_static.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{compile_batch_size}_compile_0.ts") + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32), + torch.zeros([1], dtype=torch.int64), + ) + unet = UnetExport(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + # compile + if flag == 0: + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [mindietorch.Input((compile_batch_size, in_channels, sample_size, sample_size), + 
dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)] + compile_unet_cache(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + inputs_gear_1 = [mindietorch.Input((compile_batch_size, + in_channels, 1024 // 8, 1024 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + ] + inputs.append(inputs_gear_1) + inputs_gear_2 = [mindietorch.Input((compile_batch_size, + in_channels, 512 // 8, 512 // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((compile_batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, + encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((compile_batch_size, 6), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + ] + inputs.append(inputs_gear_2) + compile_unet_cache(model, inputs, unet_compiled_path, soc_version) + else: + print("This operation is not supported!") + + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model): + super().__init__() + self.vae_model = vae_model + + def forward(self, latents): + return self.vae_model.decoder(latents) + + +def export_vae(sd_pipeline: StableDiffusionXLPipeline, save_dir: str, batch_size: int, flag: int) -> None: + print("Exporting the image decoder...") + + compile_batch_size = batch_size + + vae_path = os.path.join(save_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_static_path = os.path.join(vae_path, f"vae_bs{compile_batch_size}_compile_static.ts") + vae_compiled_path = os.path.join(vae_path, f"vae_bs{compile_batch_size}_compile.ts") + + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + # compile + if flag == 0: + if not os.path.exists(vae_compiled_static_path): + model = torch.jit.load(vae_pt_path).eval() + # 静态shape + inputs = [ + mindietorch.Input((compile_batch_size, in_channels, sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(vae_compiled_path): + # 动态分档 + model = 
torch.jit.load(vae_pt_path).eval()
+            inputs = []
+            inputs_gear_1 = [mindietorch.Input((compile_batch_size, in_channels,
+                                                1024 // 8, 1024 // 8),
+                                               dtype=mindietorch.dtype.FLOAT)]
+            inputs.append(inputs_gear_1)
+            inputs_gear_2 = [mindietorch.Input((compile_batch_size, in_channels,
+                                                512 // 8, 512 // 8),
+                                               dtype=mindietorch.dtype.FLOAT)]
+            inputs.append(inputs_gear_2)
+
+            compile_vae(model, inputs, vae_compiled_path, soc_version)
+    else:
+        print("This operation is not supported!")
+
+
+def export(model_path: str, save_dir: str, batch_size: int, steps: int, guidance_scale: float, use_cache: bool,
+           flag: int) -> None:
+    pipeline = StableDiffusionXLPipeline.from_pretrained(model_path).to("cpu")
+
+    export_clip(pipeline, save_dir, batch_size, flag)
+    export_vae(pipeline, save_dir, batch_size, flag)
+
+    if use_cache:
+        # single-device export with the UnetCache optimization
+        export_unet_cache(pipeline, save_dir, batch_size * 2, flag)
+        export_unet_skip(pipeline, save_dir, batch_size * 2, flag)
+    else:
+        # single-device export without UnetCache
+        export_unet_init(pipeline, save_dir, batch_size * 2, flag)
+
+    # single-device scheduler export
+    export_ddim(pipeline, save_dir, steps, guidance_scale, batch_size * 2, flag)
+
+
+def main():
+    args = parse_arguments()
+    export(args.model,
+           args.output_dir,
+           args.batch_size,
+           args.steps,
+           args.guidance_scale,
+           args.use_cache,
+           args.flag)
+    print("Done.")
+    mindietorch.finalize()
+
+
+if __name__ == "__main__":
+    min_batch, max_batch = 1, 32
+    min_height, max_height = 512 // 8, 1024 // 8
+    min_width, max_width = 512 // 8, 1664 // 8
+    args = parse_arguments()
+    mindietorch.set_device(args.device)
+    if args.soc == "Duo":
+        soc_version = "Ascend310P3"
+    elif args.soc == "A2":
+        soc_version = "Ascend910B4"
+    else:
+        print("unsupported soc_version, please check!")
+    main()
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/prompts.txt b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/prompts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f4e68659a19f1e0cc51f258294037d400fbc2d17
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/prompts.txt
@@ -0,0 +1,2 @@
+a red cat playing with a (ball)1.5
+a red cat playing with a (ball)0.1
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/requirements.txt b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3fe7a7de5e96ff7d66bdfca8e77a98adf5eaef73
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/requirements.txt
@@ -0,0 +1,6 @@
+setuptools==57.5.0
+torch==2.1.0
+diffusers==0.26.3
+transformers==4.26.1
+open_clip_torch==2.20.0
+compel==2.0.2
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusion_attention_patch.py b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusion_attention_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c1b369bb23389b6abc12c710f63e5b986b836d
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusion_attention_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.26.3', "expected diffusers==0.26.3"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusion_clip_patch.py b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusion_clip_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae2a2c4dca8d774982e69323343eb48be008e43
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusion_clip_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import transformers
+
+
+def main():
+    transformers_path = transformers.__path__
+    transformers_version = transformers.__version__
+
+    assert transformers_version == '4.26.1', "expected transformers==4.26.1"
+    os.system(f'patch -p0 {transformers_path[0]}/models/clip/modeling_clip.py clip.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusionxl_pipeline_prompt_weight.py b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusionxl_pipeline_prompt_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8412ff839033cb6fbf236a6420bedd6034ac688
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusionxl_pipeline_prompt_weight.py
@@ -0,0 +1,1003 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
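Note: the pipeline below consumes prompt embeddings produced by compel, which is what makes the weighted syntax in prompts.txt take effect ("(ball)1.5" up-weights the token, "(ball)0.1" suppresses it). The following is a minimal, hedged sketch of that conversion, not part of the original sources; it reuses the same Compel arguments as the script's main() and assumes SDXL base weights are available at the script's default --model path.

```python
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline

# Assumption: local SDXL base checkpoint at the script's default --model path.
pipe = StableDiffusionXLPipeline.from_pretrained("./stable-diffusion-xl-base-1.0").to("cpu")

# Same constructor arguments as in main() further down.
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# "(ball)1.5" boosts attention on "ball"; "(ball)0.1" suppresses it (compel syntax, as in prompts.txt).
conditioning, pooled = compel("a red cat playing with a (ball)1.5")

# These tensors are what ascendie_infer() receives as prompt_embeds / pooled_prompt_embeds.
```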
+ +import argparse +import csv +import json +import os +import time +from typing import Callable, List, Optional, Union +import numpy as np + +import torch +import mindietorch +from diffusers import StableDiffusionXLPipeline +from diffusers.loaders import TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin +from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, SASolverScheduler +from diffusers.image_processor import PipelineImageInput + +from compel import Compel, ReturnedEmbeddingsType +from diffusers.utils import make_image_grid, USE_PEFT_BACKEND + +from mindietorch import _enums + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class PromptLoader: + def __init__(self, prompt_file: str, prompt_file_type: str, batch_size: int, + num_images_per_prompt: int = 1, max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = {'prompts': [], 'catagories': [], 'save_names': [], 'n_prompts': self.batch_size,} + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusionXLPipeline(StableDiffusionXLPipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0 = self.args.device[0] + else: + self.device_0 = args.device + + def compile_aie_model(self): + if self.is_init: + return + + in_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size_2 = self.text_encoder_2.config.hidden_size + encoder_hidden_size = self.text_encoder.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = self.text_encoder.config.max_position_embeddings + + size = 
self.args.batch_size + batch_size = self.args.batch_size * 2 + + if self.args.flag == 0 or self.args.flag == 1: + if self.args.flag == 0: + tail = "_static" + elif self.args.flag == 1: + tail = "" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + if not self.args.use_cache: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + if self.args.use_cache: + unet_skip_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_compile_1{tail}.ts") + self.compiled_unet_model_skip = torch.jit.load(unet_skip_compiled_path).eval() + + unet_cache_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_compile_0{tail}.ts") + self.compiled_unet_model_cache = torch.jit.load(unet_cache_compiled_path).eval() + else: + print("This operation is not supported!") + + self.is_init = True + + def encode_prompt(self, prompt: str, prompt_2: Optional[str] = None, num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.compiled_clip_model, self.compiled_clip_model_2] if self.compiled_clip_model is not None + else [self.compiled_clip_model_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, + truncation=True, return_tensors="pt",) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1]) + + # We are only ALWAYS interested in the pooled output of the final text encoder + global clip_time + start = time.time() + prompt_embeds_npu = text_encoder(text_input_ids.to(f'npu:{self.device_0}')) + + pooled_prompt_embeds = prompt_embeds_npu[0].to('cpu') + clip_time += time.time() - start + + if clip_skip is None: + prompt_embeds = prompt_embeds_npu[2][-2].to('cpu') + + else: + # "2" because SDXL always indexes from the penultimate layer.????待定 + prompt_embeds = prompt_embeds_npu.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = 
torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer(negative_prompt, padding="max_length", max_length=max_length, + truncation=True, return_tensors="pt",) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(f'npu:{self.device_0}'))[0].to('cpu') + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_prompt_embeds = [torch.from_numpy(text) for text in negative_prompt_embeds] + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * 
num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def ascendie_infer(self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[dict[str, any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[tuple[int, int]] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: Optional[tuple[int, int]] = None, + negative_original_size: Optional[tuple[int, int]] = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: Optional[tuple[int, int]] = None, + clip_skip: Optional[int] = None, + skip_steps=None, + flag_ddim: int = None, + flag_cache: int = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + """ + global p1_time, p2_time, p3_time + start = time.time() + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + # 3. Encode input prompt + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + ( + prompt_embeds, # [2,77,2048] + negative_prompt_embeds, # [2,77,2048] + pooled_prompt_embeds, # [2,1280] + negative_pooled_prompt_embeds, # [2,1280] + ) = self.encode_prompt( # nv + prompt=prompt, # None + prompt_2=prompt_2, # None + num_images_per_prompt=num_images_per_prompt, # 1 + do_classifier_free_guidance=do_classifier_free_guidance, # True + negative_prompt=negative_prompt, # None + negative_prompt_2=negative_prompt_2, # None + prompt_embeds=prompt_embeds, # [2,77,2048] + negative_prompt_embeds=negative_prompt_embeds, # None + pooled_prompt_embeds=pooled_prompt_embeds, # [2, 1280] + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, # None + lora_scale=lora_scale, # None + clip_skip=clip_skip, # None + ) + + p1_time += time.time() - start + start1 = time.time() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=self.text_encoder_2.config.projection_dim + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) # [4,1280] + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # [4,6] + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + # 8.1 Apply denoising_end + if ( + denoising_end is not None + and isinstance(denoising_end, float) + and denoising_end > 0 + and denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + cache = None + global unet_time + global vae_time + + skip_flag = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + if latents.is_cpu: + latent_model_input = torch.cat([latents] * 2) # latent_model_input:[4,4,128,128] + + # latent_model_input:[4,4,128,128] + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t).to(f'npu:{self.device_0}') + + start = time.time() + if flag_cache: + if skip_steps[i]: + noise_pred = self.compiled_unet_model_skip(latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + skip_flag.to(f'npu:{self.device_0}'), + cache, ) + else: + outputs = self.compiled_unet_model_cache(latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + cache_flag.to(f'npu:{self.device_0}'), + ) + noise_pred = outputs[0] + cache = outputs[1] + else: + + noise_pred = self.compiled_unet_model(latent_model_input, # [4,4,128,128] + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), # [1] + prompt_embeds.to(f'npu:{self.device_0}'), # [4,77,2048] + add_text_embeds.to(f'npu:{self.device_0}'), # [4,1280] + add_time_ids.to(f'npu:{self.device_0}')) # [4,6] + unet_time += time.time() - start + + 
# perform guidance + if do_classifier_free_guidance: + if flag_ddim: + if not noise_pred.is_cpu: + noise_pred = noise_pred.to('cpu') + x = np.array(i, dtype=np.int64) + y = torch.from_numpy(x).long() + + latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + t[None].to(torch.int64).to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + y[None].to(f'npu:{self.device_0}')).to('cpu') + else: + noise_pred = noise_pred.to('cpu') + noise_pred, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred + guidance_scale * (noise_pred_text - + noise_pred) + + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + p2_time = time.time() - start1 + start2 = time.time() + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents / self.vae.config.scaling_factor + + start = time.time() + latents = self.vae.post_quant_conv(latents) + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + vae_time += time.time() - start + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if output_type == "pil": + image = self.numpy_to_pil(image) + + p3_time += time.time() - start2 + return (image, None) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. 
Give 2 ids to enable parallel inferencing.",
+    )
+    parser.add_argument(
+        "--num_images_per_prompt",
+        default=1,
+        type=int,
+        help="Number of images generated for each prompt.",
+    )
+    parser.add_argument(
+        "--max_num_prompts",
+        default=0,
+        type=int,
+        help="Limit the number of prompts (0: no limit).",
+    )
+    parser.add_argument(
+        "-bs",
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size."
+    )
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        type=str,
+        default="./models",
+        help="Path of directory to save compiled models.",
+    )
+    parser.add_argument(
+        "--soc",
+        choices=["A2"],
+        default="A2",
+        help="soc_version.",
+    )
+    parser.add_argument(
+        "--scheduler",
+        choices=["DDIM", "Euler", "DPM", "SA-Solver"],
+        default="DDIM",
+        help="Type of sampling method. Choose from DDIM, Euler, DPM, SA-Solver.",
+    )
+    parser.add_argument(
+        "--use_cache",
+        action="store_true",
+        help="Use cache during inference."
+    )
+    parser.add_argument(
+        "--cache_steps",
+        type=str,
+        default="1,2,4,6,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\
+                30,31,33,34,36,37,39,40,42,43,45,47,48,49",  # 17+33
+        help="Steps to use cache data."
+    )
+    parser.add_argument(
+        "--flag",
+        type=int,
+        default=1,
+        choices=[0, 1],
+        help="0 is static shape; 1 is dynamic gears.",
+    )
+    parser.add_argument(
+        "--w_h",
+        type=int,
+        default=1024,
+        choices=[512, 1024],
+        help="Width and height of the generated image, 512 or 1024.",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_arguments()
+    save_dir = args.save_dir
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    pipe = AIEStableDiffusionXLPipeline.from_pretrained(args.model).to("cpu")
+
+    flag_ddim = 0
+    if args.scheduler == "DDIM":
+        flag_ddim = 1
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "Euler":
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "DPM":
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "SA-Solver":
+        pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config)
+
+    pipe.parser_args(args)
+    pipe.compile_aie_model()
+    mindietorch.set_device(args.device)
+    skip_steps = [0] * args.steps
+    flag_cache = 0
+    if args.use_cache:
+        flag_cache = 1
+        for i in args.cache_steps.split(','):
+            if int(i) >= args.steps:
+                continue
+            skip_steps[int(i)] = 1
+
+    use_time = 0
+    prompt_loader = PromptLoader(args.prompt_file,
+                                 args.prompt_file_type,
+                                 args.batch_size,
+                                 args.num_images_per_prompt,
+                                 args.max_num_prompts)
+
+    height = 1024
+    width = 1024
+    if args.w_h == 1024:
+        height = 1024
+        width = 1024
+    elif args.w_h == 512:
+        height = 512
+        width = 512
+    else:
+        print("This operation is not supported!")
+
+    prompts_2 = ""
+    infer_num = 0
+    image_info = []
+    current_prompt = None
+    for i, input_info in enumerate(prompt_loader):
+        prompt = input_info['prompts']
+        catagories = input_info['catagories']
+        save_names = input_info['save_names']
+        n_prompts = input_info['n_prompts']
+
+        print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompt}")
+        infer_num += args.batch_size
+
+        start_time = time.time()
+
+        stream = mindietorch.npu.Stream("npu:" + str(args.device))
+        compel = Compel(
+            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+            text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+            requires_pooled=[False, True]
+        )
+
+        # apply weights
+        conditioning, pooled = compel(prompt)  # time-consuming step
+
+        # generate image
+        generator = [torch.Generator().manual_seed(33) for _ in range(len(prompt))]
+
+        with mindietorch.npu.stream(stream):
+            images = pipe.ascendie_infer(
+                prompt_embeds=conditioning,
+                pooled_prompt_embeds=pooled,
+                generator=generator,
+                num_inference_steps=30,
+                height=height,
+                width=width,
+            )
+        make_image_grid(images[0], rows=1, cols=1)
+
+        images[0][0].save(os.path.join(save_dir, f"{i}.jpg"))
+
+        use_time += time.time() - start_time
+
+    print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n"
+          f"average time: {use_time / infer_num:.3f}s\n"
+          f"clip time: {clip_time / infer_num:.3f}s\n"
+          f"unet time: {unet_time / infer_num:.3f}s\n"
+          f"vae time: {vae_time / infer_num:.3f}s\n"
+          f"p1 time: {p1_time / infer_num:.3f}s\n"
+          f"p2 time: {p2_time / infer_num:.3f}s\n"
+          f"p3 time: {p3_time / infer_num:.3f}s\n")
+
+    # Save image information to a json file
+    if os.path.exists(args.info_file_save_path):
+        os.remove(args.info_file_save_path)
+
+    with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f:
+        json.dump(image_info, f)
+    mindietorch.finalize()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusionxl_unet_patch.py b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusionxl_unet_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..112dd0c88e73617f4f239235afc3a704f9b9052f
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/stable_diffusionxl_unet_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    assert diffusers_version == '0.26.3', "expected diffusers==0.26.3"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_condition.py unet_2d_condition.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/unet_2d_condition.patch b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/unet_2d_condition.patch
new file mode 100644
index 0000000000000000000000000000000000000000..88c9cb09510b12e4bf796c7ff3ccaf712b411b28
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL-PromptWeight/unet_2d_condition.patch
@@ -0,0 +1,244 @@
+--- ../mindie-sdxl/unet_2d_condition_bak.py 2024-04-17 07:41:03.000000000 +0000
++++ ../mindie-sdxl/unet_2d_condition.py 2024-04-17 07:41:05.000000000 +0000
+@@ -855,6 +855,8 @@
+         down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
++        if_skip: int = 0,
++        inputCache: torch.FloatTensor = None,
+     ) -> Union[UNet2DConditionOutput, Tuple]:
+         r"""
+         The [`UNet2DConditionModel`] forward method.
+@@ -1110,29 +1112,56 @@ + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + +- down_block_res_samples = (sample,) +- for downsample_block in self.down_blocks: +- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: +- # For t2i-adapter CrossAttnDownBlock2D +- additional_residuals = {} +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) +- +- sample, res_samples = downsample_block( +- hidden_states=sample, +- temb=emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- **additional_residuals, +- ) +- else: +- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- sample += down_intrablock_additional_residuals.pop(0) +- +- down_block_res_samples += res_samples ++ if not if_skip: ++ down_block_res_samples = (sample,) ++ for downsample_block in self.down_blocks: ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples ++ else: ++ down_block_res_samples = (sample,) ++ for tmp, downsample_block in enumerate(self.down_blocks): ++ if tmp >= 2: ++ break ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () +@@ -1146,61 +1175,93 @@ + down_block_res_samples = new_down_block_res_samples + + # 4. 
mid +- if self.mid_block is not None: +- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: +- sample = self.mid_block( +- sample, +- emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = self.mid_block(sample, emb) +- +- # To support T2I-Adapter-XL +- if ( +- is_adapter +- and len(down_intrablock_additional_residuals) > 0 +- and sample.shape == down_intrablock_additional_residuals[0].shape +- ): +- sample += down_intrablock_additional_residuals.pop(0) ++ if not if_skip: ++ if self.mid_block is not None: ++ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: ++ sample = self.mid_block( ++ sample, ++ emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = self.mid_block(sample, emb) ++ ++ # To support T2I-Adapter-XL ++ if ( ++ is_adapter ++ and len(down_intrablock_additional_residuals) > 0 ++ and sample.shape == down_intrablock_additional_residuals[0].shape ++ ): ++ sample += down_intrablock_additional_residuals.pop(0) ++ else: ++ sample = inputCache + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up +- for i, upsample_block in enumerate(self.up_blocks): +- is_final_block = i == len(self.up_blocks) - 1 +- +- res_samples = down_block_res_samples[-len(upsample_block.resnets) :] +- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] +- +- # if we have not reached the final block and need to forward the +- # upsample size, we do it here +- if not is_final_block and forward_upsample_size: +- upsample_size = down_block_res_samples[-1].shape[2:] +- +- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- encoder_hidden_states=encoder_hidden_states, +- cross_attention_kwargs=cross_attention_kwargs, +- upsample_size=upsample_size, +- attention_mask=attention_mask, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- upsample_size=upsample_size, +- scale=lora_scale, +- ) ++ if not if_skip: ++ for i, upsample_block in enumerate(self.up_blocks): ++ is_final_block = i == len(self.up_blocks) - 1 ++ ++ res_samples = down_block_res_samples[-len(upsample_block.resnets) :] ++ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] ++ ++ # if we have not reached the final block and need to forward the ++ # upsample size, we do it here ++ if not is_final_block and forward_upsample_size: ++ upsample_size = down_block_res_samples[-1].shape[2:] ++ ++ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ 
upsample_size=upsample_size, ++ scale=lora_scale, ++ ) ++ ++ if (not if_skip) and (i == 0): ++ inputCache = sample ++ ++ else: ++ ++ for i, upsample_block in enumerate(self.up_blocks): ++ if i==1: ++ res_samples = down_block_res_samples[-4:-1] ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ if i==2: ++ res_samples = down_block_res_samples[:3] ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ upsample_size=upsample_size, ++ scale=lora_scale, ++ ) + + # 6. post-process + if self.conv_norm_out: +@@ -1215,4 +1276,7 @@ + if not return_dict: + return (sample,) + +- return UNet2DConditionOutput(sample=sample) ++ if (not if_skip): ++ return (sample, inputCache) ++ else: ++ return UNet2DConditionOutput(sample=sample) diff --git a/MindIE/MultiModal/StableDiffusion-XL/README.md b/MindIE/MultiModal/StableDiffusion-XL/README.md new file mode 100644 index 0000000000000000000000000000000000000000..952279c3a06c1a4aeb26b12dbbd89cc77170db36 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/README.md @@ -0,0 +1,840 @@ +# stable-diffusionxl模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + + +# 概述 + + SDXL 由一组用于潜在扩散的专家管道组成: 在第一步中,使用基础模型生成(噪声)潜伏, 然后使用专门用于最终降噪步骤的细化模型[此处获得](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/) + +- 参考实现: + ```bash + # StableDiffusionxl + https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1或2 +Atlas 300I Duo推理卡:支持的卡数为1,可双芯并行 + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | ------------------------- | ------------ | + | prompt | 1 x 77 | INT64| ND| + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | output1 | 1 x 3 x 1024 x 1024 | FLOAT32 | NCHW | + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 + | 配套 | 版本 | 环境准备指导 | + | ------------------------------------------------------------ |--------| ------------------------------------------------------------ | + | Python | 3.10.13 | - | + | torch| 2.1.0 | - | + +该模型性能受CPU规格影响,建议使用64核CPU(arm)以复现性能 + + +# 快速上手 + +## 获取源码 + +1. 安装依赖。 + ```bash + pip3 install -r requirements.txt + + # 若要使用hpsv2验证精度,则还需要按照以下步骤安装hpsv2 + git clone https://github.com/tgxs002/HPSv2.git + cd HPSv2 + pip3 install -e . + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + + ```bash + # 若环境没有patch工具,请自行安装 + ``` + + ```bash + python3 stable_diffusion_attention_patch.py + ``` + + ```bash + # 若使用unetCache + python3 stable_diffusionxl_unet_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + + 本模型输入文本信息生成图片,无需数据集。 + + +## 模型推理 + +1. 模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 获取权重(可选) + + 可提前下载权重,放到代码同级目录下,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # xl + git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + ``` + + 1. 
导出pt模型并进行编译。(可选) + + ```bash + # xl (执行时下载权重) + model_base="stabilityai/stable-diffusion-xl-base-1.0" + + xl (使用上一步下载的权重) + model_base="./stable-diffusion-xl-base-1.0" + ``` + 执行命令: + + ```bash + # 使用unetCache, 非并行 + python3 export_ts.py --model ${model_base} --output_dir ./models --use_cache --batch_size 1 --flag 0 --soc A2 --device 0 + + # 使用unetCache, 并行 + python3 export_ts.py --model ${model_base} --output_dir ./models --use_cache --parallel --batch_size 1 --flag 0 --soc Duo --device 0 + + ``` + 参数说明: + - --model:模型权重路径 + - --output_dir: 存放导出模型的路径 + - --use_cache: 【可选】推荐在推理过程中使用unetCache策略 + - --parallel: 【可选】导出适用于并行方案的模型, 当前仅带unetCache优化时,支持并行 + - --batch_size: 设置batch_size, 默认值为1, 当前最大支持batch_size=2 + - --flag:默认为0。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512;2代表动态shape,height的范围为[512, 1024],width的范围是[512, 1664]。 + - --soc:只支持Duo和A2。 + - --device:推理设备ID + +2. 开始推理验证。 + + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 执行推理脚本。 + ```bash + # 不使用unetCache策略 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用UnetCache策略 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results_unetCache \ + --steps 50 \ + --output_dir ./models \ + --use_cache \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用UnetCache策略,同时使用双卡并行策略 + numactl -C 0-23 python3 stable_diffusionxl_pipeline_cache_parallel.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0,1 \ + --save_dir ./results_unetCache_parallel \ + --steps 50 \ + --output_dir ./models \ + --use_cache \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + - --use_cache: 【可选】推荐在推理过程中使用unetCache策略。 + - --flag:默认为0。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512;2代表动态shape,height的范围为[512, 1024],width的范围是[512, 1664]。**注意**:请与导出模型时设置的flag保持一致 + - --height:与flag标志位对应的height一致 + - --width:与flag标志位对应的width一致 + + 不带unetCache策略,执行完成后在`./results`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 带unetCache策略,执行完成后在`./results_unetCache`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + 带unetCache策略,同时使用双卡并行策略,执行完成后在`./results_unetCache_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + +## 精度验证 + + 由于生成的图片存在随机性,提供两种精度验证方法: + 1. CLIP-score(文图匹配度量):评估图片和输入文本的相关性,分数的取值范围为[-1, 1],越高越好。使用Parti数据集进行验证。 + 2. 
HPSv2(图片美学度量):评估生成图片的人类偏好评分,分数的取值范围为[0, 1],越高越好。使用HPSv2数据集进行验证 + + 注意,由于要生成的图片数量较多,进行完整的精度验证需要耗费很长的时间。 + + 1. 下载Parti数据集 + + ```bash + wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate + ``` + + 2. 下载模型权重 + + ```bash + # Clip Score和HPSv2均需要使用的权重 + # 安装git-lfs + apt install git-lfs + git lfs install + + # Clip Score权重 + git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K + + # HPSv2权重 + wget https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt --no-check-certificate + ``` + 也可手动下载[权重](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/open_clip_pytorch_model.bin) + 将权重放到`CLIP-ViT-H-14-laion2B-s32B-b79K`目录下,手动下载[HPSv2权重](https://huggingface.co/spaces/xswu/HPSv2/resolve/main/HPS_v2_compressed.pt)放到当前路径 + + 3. 使用推理脚本读取Parti数据集,生成图片 + + ```bash + # 不使用unetCache策略 + python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用UnetCache策略 + python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_unetCache \ + --steps 50 \ + --output_dir ./models \ + --use_cache \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + + # 使用UnetCache策略,同时使用双卡并行策略 + python3 stable_diffusionxl_pipeline_cache_parallel.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0,1 \ + --save_dir ./results_PartiPrompts_unetCache_parallel \ + --steps 50 \ + --output_dir ./models \ + --use_cache \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + ``` + + 参数说明: + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --num_images_per_prompt: 每个prompt生成的图片数量。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --max_num_prompts:限制prompt数量为前X个,0表示不限制。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + + 不带unetCache策略,执行完成后在`./results_PartiPrompts`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 带unetCache策略,执行完成后在`./results_PartiPrompts_unetCache`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + 带unetCache策略,同时使用双卡并行策略,执行完成后在`./results_PartiPrompts_unetCache_parallel`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + + 4. 计算精度指标 + + 1. CLIP-score + + ```bash + python3 clip_score.py \ + --device=cpu \ + --image_info="image_info.json" \ + --model_name="ViT-H-14" \ + --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --device: 推理设备。 + - --image_info: 上一步生成的`image_info.json`文件。 + - --model_name: Clip模型名称。 + - --model_weights_path: Clip模型权重文件路径。 + + 执行完成后会在屏幕打印出精度计算结果。 + + 2. 
HPSv2 + + ```bash + python3 hpsv2_score.py \ + --image_info="image_info.json" \ + --HPSv2_checkpoint="./HPS_v2_compressed.pt" \ + --clip_checkpoint="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --image_info: 上一步生成的`image_info.json`文件。 + - --HPSv2_checkpoint: HPSv2模型权重文件路径。 + - --clip_checkpointh: Clip模型权重文件路径。 + + 执行完成后会在屏幕打印出精度计算结果。 + + +## 量化功能【可选】 + +可使用W8A8量化功能提升性能,但可能导致精度下降。默认batch_size为1,默认分辨率为1024x1024,可支持batch_size为2、分辨率为512x512的场景(修改第4. 5.步参数即可) + + 1. 导出浮点pt模型并进行编译。 + + ```bash + # 使用unetCache, 非并行 + python3 export_ts.py --model ${model_base} --output_dir ./models --use_cache --flag 0 --soc A2 --device 0 + + # 不使用unetCache, 非并行 + python3 export_ts.py --model ${model_base} --output_dir ./models --flag 0 --soc A2 --device 0 + ``` + + 2. 量化编译。./quant/build.sh中的TorchPath需要指定为python安装torch的路径。 + + 执行命令: + + ```bash + cd quant + bash build.sh + ``` + + 3. 导出浮点unet模型的输入。执行完毕后会在当前路径下生成unet_data.npy文件。 + + 执行命令: + + ```bash + # 若使用UnetCache策略 + python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results_temp \ + --steps 50 \ + --output_dir ./models \ + --use_cache \ + --flag 0 \ + --save_unet_input + # 若不使用UnetCache策略 + python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results_temp \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --save_unet_input + ``` + + 4. 导出量化pt模型并进行编译。 + + 执行命令: + + ```bash + # 若使用unetCache, 且非并行 + python3 export_ts_quant.py --model ${model_base} --output_dir ./models_quant --use_cache --batch_size 1 --soc A2 --device 0 --height 1024 --width 1024 + + # 若不使用unetCache, 且非并行 + python3 export_ts_quant.py --model ${model_base} --output_dir ./models_quant --batch_size 1 --soc A2 --device 0 --height 1024 --width 1024 + ``` + + 参数说明: + - --model:模型权重路径 + - --output_dir:存放导出模型的目录,执行完成后在`./models_quant`目录下生成量化模型。 + - --batch_size:默认batch_size为1(可支持batch_size=2的场景, 性能受影响) + - --height:默认分辨率为1024x1024(可支持512x512的场景, 性能受影响) + - --width:默认分辨率为1024x1024(可支持512x512的场景, 性能受影响) + + 5. 开始推理验证。 + + 执行命令: + + ```bash + # 使用UnetCache策略,且非并行 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results_quant \ + --steps 50 \ + --output_dir ./models_quant \ + --flag 3 \ + --use_cache \ + --batch_size 1 \ + --height 1024 \ + --width 1024 \ + --quant + + # 不使用UnetCache策略,且非并行 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --device 0 \ + --save_dir ./results_quant \ + --steps 50 \ + --output_dir ./models_quant \ + --flag 3 \ + --batch_size 1 \ + --height 1024 \ + --width 1024 \ + --quant + ``` + + 执行完成后在`./results_quant`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + + 6. 
使用推理脚本读取Parti数据集,生成图片。 + + 执行命令: + + ```bash + # 使用UnetCache策略,且非并行 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_quant_unetCache \ + --steps 50 \ + --output_dir ./models_quant \ + --flag 3 \ + --use_cache \ + --batch_size 1 \ + --height 1024 \ + --width 1024 \ + --quant + + # 不使用UnetCache策略,且非并行 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_quant \ + --steps 50 \ + --output_dir ./models_quant \ + --flag 3 \ + --batch_size 1 \ + --height 1024 \ + --width 1024 \ + --quant + ``` + + 参数说明: + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + - --num_images_per_prompt: 每个prompt生成的图片数量。注意使用hpsv2时,设置num_images_per_prompt=1即可。 + + 不带unetCache策略,执行完成后在`./results_PartiPrompts_quant`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系,并在终端显示推理时间。 + 带unetCache策略,执行完成后在`./results_PartiPrompts_quant_unetCache`目录下生成推理图片,在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。并在终端显示推理时间。 + + 计算精度指标CLIP-score和HPSv2同浮点。 + +## Lora热切换 +### Lora热切换功能使用 + 1. lora热切使用准备 + + 1. 获取权重 + + sdxl权重地址: + ```bash + https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + ``` + sdxl lora权重地址: + ```bash + https://huggingface.co/latent-consistency/lcm-lora-sdxl + ``` + 2. 代码修改 + + 基础补丁 + + ```bash + python3 stable_diffusion_attention_patch.py + ``` + + ```bash + # 若使用unetCache + python3 stable_diffusionxl_unet_patch.py + ``` + + lora补丁 + + ```bash + python3 stable_diffusionxl_lora_patch.py + ``` + 2. 模型转换 + + 设置基础模型路径以及unet基础权重保存路径: + ```bash + # 上一步下载的模型路径 + export model_base="./stable-diffusion-xl-base-1.0" + # unet基础权重保存路径 + export baselora_path="./baselora" + ``` + + 导入环境变量: + ```bash + export MINDIE_TORCH_ENABLE_RUNTIME_BUFFER_MUTATION=true + ``` + + 执行模型转换: + ```bash + #基础模型lora热切特性转换: + python3 export_ts.py --model ${model_base} --output_dir ./models --batch_size 1 --flag 0 --soc A2 --device 0 --lorahot_support --baselora_path ${baselora_path} + #unetcahche版模型转换: + python3 export_ts.py --model ${model_base} --output_dir ./models --use_cache --batch_size 1 --flag 0 --soc A2 --device 0 --lorahot_support --baselora_path ${baselora_path} + ``` + 参数说明: + - --model 下载的模型权重路径 + - --output_dir 转换后的模型输出路径 + - --batch_size 设置batch_size, 默认值为1, 当前最大支持batch_size=2 + - --flag:默认为0。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512;2代表动态shape,height的范围为[512, 1024],width的范围是[512, 1664]。 + - --soc:只支持Duo和A2。 + - --device:推理设备ID + - --lorahot_support:生成模型支持Lora热切换功能 + - --baselora_path:仅指定lorahot_support时生效,代表Unet基础权重的保存路径,用于后续Lora权重热切换 + 3. 
推理验证 + + 设置lora权重路径: + ```bash + # 第一步下载的lora权重路径 + export newlora_path="./lora_weight" + ``` + 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 执行推理 + ```bash + #基础模型使用lora热切换功能推理 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_loraHotswitch \ + --lorabase_weight ${baselora_path} \ + --loranew_weight ${newlora_path} + #Unetcache优化接入后lora热切换功能推理 + numactl -C 0-23 python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./prompts.txt \ + --prompt_file_type plain \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache \ + --use_loraHotswitch \ + --lorabase_weight ${baselora_path} \ + --loranew_weight ${newlora_path} + ``` + 参数说明 + - --model:模型权重路径。 + - --output_dir:存放导出模型的目录。 + - --prompt_file:提示词文件。 + - --prompt_file_type: prompt文件类型,用于指定读取方式,可选plain,parti,hpsv2。 + - --save_dir:生成图片的存放目录。 + - --batch_size:模型batch size。 + - --steps:生成图片迭代次数。 + - --device:推理设备ID + - --use_cache: 推理过程中使用unetCache策略。 + - --flag:默认为0。0代表静态,只支持分辨率为1024x1024;1代表动态分档,支持的分辨率为1024x1024和512x512;2代表动态shape,height的范围为[512, 1024],width的范围是[512, 1664]。**注意**:请与导出模型时设置的flag保持一致 + - --height:与flag标志位对应的height一致 + - --width:与flag标志位对应的width一致 + - --use_loraHotswitch: 代表是否有Lora热切换功能启用 + - --lorabase_weight: 基础的Unet权重存储路径 + - --loranew_weight:第一步下载的lora权重路径 + ### Lora热切换功能精度验证 + + 通过比较冷切模型与热切模型对于同一prompt的出图余弦相似度来衡量热切方法的精度 + + 1. 模型准备: + 1. 冷切模型准备 + ```bash + #导入融合lora模型权重环境便令 + export model_new="融合后模型保存路径" + #执行如下命令进行权重融合 + python3 convert_lora_safetensors_to_diffusers.py --base_model_path ${model_base} \ + --checkpoint_path ${newlora_path} --dump_path ${model_new} + ``` + 参数说明: + - --base_model_path:基础sdxl模型权重路径 + - --checkpoint_path:下载的lora权重路径 + - --dump_path:权重融合后的模型输出路径 + + 此后运行如下命令生成新权重的pt模型: + ```bash + #不使用unetcache + python3 export_ts.py --model ${model_new} --output_dir ./models --batch_size 1 --flag 0 --soc A2 --device 0 + #使用unetcache + python3 export_ts.py --model ${model_new} --output_dir ./models --batch_size 1 --flag 0 --soc A2 --device 0 --use_cache + ``` + 2. 热切模型准备 + + 参照"Lora热切换功能使用"章节进行模型导出 + 3. 精度验证用clip网络: + ```bash + https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K + ``` + 2. 准备精度衡量数据集: + + 下载parti数据集: + ```bash + wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate + ``` + 3. 
执行推理: + + 不使用Unetcache + ```bash + #冷切模型推理: + python3 stable_diffusionxl_pipeline.py \ + --model ${model_new} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 1 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_wolorahot \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --info_file_save_path ./coldModel.json + #热切模型推理: + python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 1 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_lorahot \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + --use_loraHotswitch \ + --lorabase_weight ${baselora_path} \ + --loranew_weight ${newlora_path} \ + --info_file_save_path ./hotModel.json + ``` + 使用Unetcache + ```bash + #冷切模型推理: + python3 stable_diffusionxl_pipeline.py \ + --model ${model_new} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 1 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_wolorahot \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 \ + --use_cache \ + --info_file_save_path ./coldModel.json + #热切模型推理: + python3 stable_diffusionxl_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 1 \ + --max_num_prompts 0 \ + --device 0 \ + --save_dir ./results_PartiPrompts_lorahot \ + --steps 50 \ + --output_dir ./models \ + --flag 0 \ + --height 1024 \ + --width 1024 \ + --batch_size 1 + --use_cache \ + --use_loraHotswitch \ + --lorabase_weight ${baselora_path} \ + --loranew_weight ${newlora_path} \ + --info_file_save_path ./hotModel.json + ``` + 新增参数说明: + + - --info_file_save_path:推理任务完成后图像与promt的对应关系会以json文件存储,此参数指定json文件存储路径与存储名称 + 4. 
执行精度验证脚本: + ```bash + python3 lorahot_score.py \ + --device=cpu \ + --image_info_wo_lorahot = "image_info_wo_lorahot.json" \ + --image_info_lorahot = "image_info_lorahot.json" \ + --model_name="ViT-H-14" \ + --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --device:Clip网络推理设备 + - --image_info_wo_lorahot:上一步生成的无离线融合模型推理结果 + - --image_info_lorahot:上一步Lora热切换后模型推理结果 + - --model_name:Clip模型结果 + - --model_weights_path:Clip模型权重 + +# 模型推理性能&精度 + +调用ACL接口推理计算,性能参考下列数据。 + +### StableDiffusionxl + +| 硬件形态 | cpu规格 | batch size | 迭代次数 | 优化手段 | 平均耗时 | 精度 | 采样器 | +| :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | +| Atlas 800I A2(8*32G) | 64核(arm) | 1 | 50 | with UnetCache, w/o 量化 | 4s | clip score 0.376 | ddim | +| Atlas 800I A2(8*32G) | 64核(arm) | 1 | 50 | with UnetCache, with 量化 | 3.6s | clip score 0.371 | ddim | + +性能测试需要独占npu和cpu \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/attention_lora.patch b/MindIE/MultiModal/StableDiffusion-XL/attention_lora.patch new file mode 100644 index 0000000000000000000000000000000000000000..ae2fa8a1f6cd9ea7cfe7fa5151a34889c372e493 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/attention_lora.patch @@ -0,0 +1,12 @@ +--- attention_processor.py 2024-09-11 19:51:35.404742200 +0800 ++++ attention_processor.py 2024-09-11 19:56:20.519589500 +0800 +@@ -761,6 +761,9 @@ + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + ++ hidden_states = torch.unsqueeze(hidden_states, 0) ++ hidden_states = torch.squeeze(hidden_states, 0) ++ + # linear proj + hidden_states = attn.to_out[0](hidden_states, *args) + # dropout diff --git a/MindIE/MultiModal/StableDiffusion-XL/attention_processor.patch b/MindIE/MultiModal/StableDiffusion-XL/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..bd15281c5a3acf9752eec8a239323f66f1beadb7 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/attention_processor.patch @@ -0,0 +1,18 @@ +--- attention_processor.py 2024-07-02 07:42:32.312000000 +0000 ++++ attention_processor.py 2024-07-02 07:44:55.100000000 +0000 +@@ -205,10 +205,11 @@ + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( diff --git a/MindIE/MultiModal/StableDiffusion-XL/background_runtime_cache.py b/MindIE/MultiModal/StableDiffusion-XL/background_runtime_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..a315d13cc1e3f88992154c465affec22747232f7 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/background_runtime_cache.py @@ -0,0 +1,194 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfo + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray], skip) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + if skip: + self.sync_pipe.send('skip') + else: + self.sync_pipe.send('cache') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: list, + ) -> None: + # The sub process function + # Create a runtime + + # Tell the main function that we are ready + model_cache = torch.jit.load(model_path[0]).eval() + model_skip = torch.jit.load(model_path[1]).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + mindietorch.set_device(device_id) + # 
Tell the main function that we are ready + sync_pipe.send('') + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + return_cache = None + + # Keep looping until recived a 'STOP' + while True: + flag = sync_pipe.recv() + start = time.time() + if flag == 'STOP': + break + + if flag == 'cache': + sample, timestep, prompt_embeds, add_text_embeds, add_time_ids, return_flag = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + else: + sample, timestep, prompt_embeds, add_text_embeds, add_time_ids, return_flag = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + + sample_npu = sample.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + prompt_embeds_npu = prompt_embeds.to(torch.float32).to(f"npu:{device_id}") + add_text_embeds_npu = add_text_embeds.to(torch.float32).to(f"npu:{device_id}") + add_time_ids_npu = add_time_ids.to(torch.float32).to(f"npu:{device_id}") + flag_npu = return_flag.to(torch.int64).to(f"npu:{device_id}") + + if flag == 'cache': + with mindietorch.npu.stream(stream): + output_npu = model_cache(sample_npu, timestep_npu, prompt_embeds_npu, add_text_embeds_npu, + add_time_ids_npu, flag_npu) + stream.synchronize() + + output_cpu0 = output_npu[0].to('cpu') + output0 = output_cpu0.numpy() + output_arrays[0][:] = output0 + + return_cache = output_npu[1] + else: + with mindietorch.npu.stream(stream): + output_npu = model_skip(sample_npu, timestep_npu, prompt_embeds_npu, add_text_embeds_npu, + add_time_ids_npu, flag_npu, return_cache) + stream.synchronize() + + output_cpu0 = output_npu.to('cpu') + output0 = output_cpu0.numpy() + output_arrays[0][:] = output0 + + sync_pipe.send('') + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/StableDiffusion-XL/clip_score.py b/MindIE/MultiModal/StableDiffusion-XL/clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..069f5d6e9a9baaa61b9a3537bcab6f637605858e --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/clip_score.py @@ -0,0 +1,140 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
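+# Computes the CLIP text-image cosine similarity for every entry of image_info.json
+# and reports per-category averages. Typical invocation (paths follow the README and
+# are examples only):
+#   python3 clip_score.py --device=cpu --image_info="image_info.json" \
+#       --model_name="ViT-H-14" \
+#       --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"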
+ +import os +import json +import time +import argparse + +import open_clip +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F + + +def clip_score(model_clip, tokenizer, preprocess, prompt, image_files, device): + imgs = [] + texts = [] + for image_file in image_files: + img = preprocess(Image.open(image_file)).unsqueeze(0).to(device) + imgs.append(img) + text = tokenizer([prompt]).to(device) + texts.append(text) + + img = torch.cat(imgs) # [bs, 3, 224, 224] + text = torch.cat(texts) # [bs, 77] + + with torch.no_grad(): + text_ft = model_clip.encode_text(text).float() + img_ft = model_clip.encode_image(img).float() + score = F.cosine_similarity(img_ft, text_ft).squeeze() + + return score.cpu() + + +def main(): + args = parse_arguments() + + if args.device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(args.device) + + t_b = time.time() + print(f"Load clip model...") + model_clip, _, preprocess = open_clip.create_model_and_transforms( + args.model_name, pretrained=args.model_weights_path, device=device) + model_clip.eval() + print(f">done. elapsed time: {(time.time() - t_b):.3f} s") + + tokenizer = open_clip.get_tokenizer(args.model_name) + + with os.fdopen(os.open(args.image_info, os.O_RDONLY), "r") as f: + image_info = json.load(f) + + t_b = time.time() + print(f"Calc clip score...") + all_scores = [] + cat_scores = {} + + for i, info in enumerate(image_info): + image_files = info['images'] + category = info['category'] + prompt = info['prompt'] + + print(f"[{i + 1}/{len(image_info)}] {prompt}") + + image_scores = clip_score(model_clip, + tokenizer, + preprocess, + prompt, + image_files, + device) + if len(image_files) > 1: + best_score = max(image_scores) + else: + best_score = image_scores + + print(f"image scores: {image_scores}") + print(f"best score: {best_score}") + + all_scores.append(best_score) + if category not in cat_scores: + cat_scores[category] = [] + cat_scores[category].append(best_score) + print(f">done. 
elapsed time: {(time.time() - t_b):.3f} s") + + average_score = np.average(all_scores) + print(f"====================================") + print(f"average score: {average_score:.3f}") + print(f"category average scores:") + cat_average_scores = {} + for category, scores in cat_scores.items(): + cat_average_scores[category] = np.average(scores) + print(f"[{category}], average score: {cat_average_scores[category]:.3f}") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="device for torch.", + ) + parser.add_argument( + "--image_info", + type=str, + default="./image_info.json", + help="Image_info.json file.", + ) + parser.add_argument( + "--model_name", + type=str, + default="ViT-H-14", + help="open clip model name", + ) + parser.add_argument( + "--model_weights_path", + type=str, + default="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + help="open clip model weights", + ) + return parser.parse_args() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/compile_model.py b/MindIE/MultiModal/StableDiffusion-XL/compile_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ecc8080c24f0a19284a25113b4183db3a724989 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/compile_model.py @@ -0,0 +1,192 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
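+# Helper wrappers used by export_ts.py (via `from compile_model import *`): each
+# submodule (CLIP text encoders, VAE decoder, UNet, DDIM step) is wrapped in a
+# torch.nn.Module so it can be traced with torch.jit.trace and then compiled into
+# a .ts engine by mindietorch.compile with the FP16 precision policy.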
+ +import torch +import mindietorch +from mindietorch import _enums + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x, output_hidden_states=True, return_dict=False): + return self.clip_model(x, output_hidden_states=output_hidden_states, return_dict=return_dict) + +def compile_clip(model, inputs, clip_compiled_path, soc_version): + compiled_clip_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + min_block_size=1, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_clip_model, clip_compiled_path) + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model): + super().__init__() + self.vae_model = vae_model + + def forward(self, latents): + return self.vae_model.decoder(latents) + +def compile_vae(model, inputs, vae_compiled_path, soc_version): + compiled_vae_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_vae_model, vae_compiled_path) + +class NewScheduler(torch.nn.Module): + def __init__(self, num_train_timesteps=1000, num_inference_steps=50, alphas_cumprod=None, + guidance_scale=5.0, alpha_prod_t_prev_cache=None): + super(NewScheduler, self).__init__() + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.alphas_cumprod = alphas_cumprod + self.guidance_scale = guidance_scale + self.alpha_prod_t_prev_cache = alpha_prod_t_prev_cache + + def forward(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, step_index: int): + divide_batch = (model_output.shape[0]) // 2 + noise_pred_uncond = model_output[:divide_batch, ..., ..., ...] + noise_pred_text = model_output[divide_batch:, ..., ..., ...] 
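+        # Classifier-free guidance: the UNet was run on a doubled batch, so the first
+        # half holds the unconditional prediction and the second half the text-
+        # conditioned one; they are recombined below as
+        #   eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)
+        # The DDIM update that follows implements
+        #   x0_pred = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
+        #   x_{t-1} = sqrt(a_prev) * x0_pred + sqrt(1 - a_prev) * eps
+        # with a_t = alphas_cumprod[timestep] and a_prev read from alpha_prod_t_prev_cache.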
+ model_output = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alpha_prod_t_prev_cache[step_index] + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + return prev_sample + +class Scheduler(torch.nn.Module): + def __init__(self, num_train_timesteps=1000, num_inference_steps=50, alphas_cumprod=None, + guidance_scale=5.0, alpha_prod_t_prev_cache=None): + super(Scheduler, self).__init__() + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.alphas_cumprod = alphas_cumprod + self.guidance_scale = guidance_scale + self.alpha_prod_t_prev_cache = alpha_prod_t_prev_cache + + def forward(self, noise_pred_uncond: torch.FloatTensor, noise_pred_text: torch.FloatTensor, timestep: int, + sample: torch.FloatTensor, step_index: int): + model_output = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alpha_prod_t_prev_cache[step_index] + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + return prev_sample + +def compile_ddim(model, inputs, scheduler_compiled_path, soc_version): + compiled_scheduler_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=False, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_scheduler_model, scheduler_compiled_path) + +class UnetExport(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids, + if_skip, + inputCache=None + ): + if if_skip: + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + if_skip=if_skip, inputCache=inputCache)[0] + else: + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + if_skip=if_skip) + +def compile_unet_cache(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_unet_model, unet_compiled_path) + +def compile_unet_skip(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_unet_model, 
unet_compiled_path) + +class UnetExportInit(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + text_embeds, + time_ids + ): + return self.unet_model(sample, timestep, encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids})[0] + +def compile_unet_init(model, inputs, unet_compiled_path, soc_version): + compiled_unet_model = ( + mindietorch.compile(model, + inputs=inputs, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0)) + torch.jit.save(compiled_unet_model, unet_compiled_path) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/convert_lora_safetensors_to_diffusers.py b/MindIE/MultiModal/StableDiffusion-XL/convert_lora_safetensors_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd9d4eac95080dacede81853e87542cf7a78ebc --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/convert_lora_safetensors_to_diffusers.py @@ -0,0 +1,144 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Conversion script for the LoRA's safetensors checkpoints. 
""" + +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import StableDiffusionXLPipeline +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + + +def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha): + # load base model + pipeline = StableDiffusionXLPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + + # load LoRA weight from .safetensors + state_dict = load_file(checkpoint_path) + + visited = [] + shape4failed = 0 + shapeno4failed = 0 + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + not_found = False + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + not_found = True + break + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + if not_found: + continue + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + pair_keys.append(key.replace("lora_down.weight", "alpha")) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + pair_keys.append(key.replace("lora_up.weight", "alpha")) + + if state_dict[pair_keys[2]] == None: + ratio = 1.0 + else: + alpha = state_dict[pair_keys[2]].item() + ratio = alpha / min(state_dict[pair_keys[0]].shape[0], state_dict[pair_keys[1]].shape[1]) + # update weight + if isinstance(curr_layer, LoRACompatibleConv): + upmat = state_dict[pair_keys[0]].to(torch.float32).flatten(start_dim=1) + downmat = state_dict[pair_keys[1]].to(torch.float32).flatten(start_dim=1) + fusionupdown = torch.mm(upmat, downmat) + fusionupdown = fusionupdown.reshape(curr_layer.weight.data.shape) + curr_layer.weight.data += ratio * fusionupdown + elif isinstance(curr_layer,LoRACompatibleLinear): + upmat = state_dict[pair_keys[0]].to(torch.float32)[None, :] + downmat = state_dict[pair_keys[1]].to(torch.float32)[None, :] + fusion = torch.bmm(upmat, downmat)[0] + curr_layer.weight.data += ratio * fusion + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
+ ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + args = parser.parse_args() + + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha + + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/MindIE/MultiModal/StableDiffusion-XL/export_ts.py b/MindIE/MultiModal/StableDiffusion-XL/export_ts.py new file mode 100644 index 0000000000000000000000000000000000000000..f933eca1e44369d5832bcad1f349c40ca659fdb8 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/export_ts.py @@ -0,0 +1,770 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
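+# Export flow: every submodule is first traced to a .pt file with torch.jit.trace and
+# then compiled into a .ts engine through the compile_* helpers of compile_model.py.
+# --flag selects the input-shape mode: 0 static shapes, 1 dynamic dims (resolution
+# gears), 2 dynamic shape range; see this directory's README for the full argument list.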
+ +import os +import argparse +from argparse import Namespace + +import torch +import torch.nn as nn +from diffusers import DDIMScheduler +from diffusers import StableDiffusionXLPipeline +import math +from compile_model import * +import mindietorch +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("-steps", "--steps", type=int, default=50, help="steps.") + parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") + parser.add_argument("-p", "--parallel", action="store_true", + help="Export the unet of bs=1 for parallel inferencing.") + parser.add_argument("--soc", choices=["Duo", "A2"], default="A2", help="soc_version.") + parser.add_argument("--lorahot_support", action="store_true", help="compiled model support hot lora weight switch") + parser.add_argument( + "--baselora_path", + type=str, + default="./baseLoraPath/", + help="this para takes effect only when --lorahot_support is specified" + ) + parser.add_argument( + "--flag", + choices=[0, 1, 2], + default=0, + type=int, + help="0 is static; 1 is dynami dims; 2 is dynamic range.", + ) + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + return parser.parse_args() + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + if not os.path.exists(clip2_pt_path): + clip_export = ClipExport(encoder_2_model) + torch.jit.trace(clip_export, dummy_input).save(clip2_pt_path) + +def export_clip(sd_pipeline, args): + print("Exporting the text encoder...") + clip_path = os.path.join(args.output_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + flag, batch_size = args.flag, args.batch_size + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, f"clip2_bs{batch_size}.pt") + clip1_compiled_static_path = os.path.join(clip_path, f"clip_bs{batch_size}_compile_static_{args.height}x{args.width}.ts") + clip2_compiled_static_path = os.path.join(clip_path, f"clip2_bs{batch_size}_compile_static_{args.height}x{args.width}.ts") + clip1_compiled_path = os.path.join(clip_path, f"clip_bs{batch_size}_compile.ts") + clip2_compiled_path = os.path.join(clip_path, f"clip2_bs{batch_size}_compile.ts") + clip1_compiled_dynamic_path = os.path.join(clip_path, f"clip_compile_dynamic.ts") + clip2_compiled_dynamic_path = os.path.join(clip_path, 
f"clip2_compile_dynamic.ts") + + encoder_model = sd_pipeline.text_encoder + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path) + + # compile + if flag == 0: + if not os.path.exists(clip1_compiled_static_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_static_path, soc_version) + if not os.path.exists(clip2_compiled_static_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, soc_version) + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, soc_version) + elif flag == 2: + min_shape = (min_batch, max_position_embeddings) + max_shape = (max_batch, max_position_embeddings) + if not os.path.exists(clip1_compiled_dynamic_path): + inputs = [] + inputs.append(mindietorch.Input(min_shape=min_shape, max_shape=max_shape, dtype=mindietorch.dtype.INT64)) + model = torch.jit.load(clip_pt_path).eval() + compile_clip(model, inputs, clip1_compiled_dynamic_path, soc_version) + if not os.path.exists(clip2_compiled_dynamic_path): + inputs = [] + inputs.append(mindietorch.Input(min_shape=min_shape, max_shape=max_shape, dtype=mindietorch.dtype.INT64)) + model = torch.jit.load(clip2_pt_path).eval() + compile_clip(model, inputs, clip2_compiled_dynamic_path, soc_version) + +def export_vae(sd_pipeline, args): + print("Exporting the image decoder...") + vae_path = os.path.join(args.output_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + flag, batch_size = args.flag, args.batch_size + height_size, width_size = args.height // 8, args.width // 8 + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + vae_compiled_static_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile_static_{args.height}x{args.width}.ts") + vae_compiled_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile.ts") + vae_compiled_dynamic_path = os.path.join(vae_path, f"vae_compile_dynamic.ts") + + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + + # compile + if flag == 0: + # 静态 + if not os.path.exists(vae_compiled_static_path): + model = torch.jit.load(vae_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_static_path, soc_version) + elif flag == 1: + # 动态dims + if not os.path.exists(vae_compiled_path): + model = torch.jit.load(vae_pt_path).eval() + inputs = [] + 
for i in range(len(heights)): + inputs_gear = [ + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear) + compile_vae(model, inputs, vae_compiled_path, soc_version) + elif flag == 2: + # 动态shape + if not os.path.exists(vae_compiled_dynamic_path): + model = torch.jit.load(vae_pt_path).eval() + min_shape = (min_batch, in_channels, min_height, min_width) + max_shape = (max_batch, in_channels, max_height, max_width) + inputs = [mindietorch.Input(min_shape=min_shape, max_shape=max_shape, dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_dynamic_path, soc_version) + +def export_unet_init(sd_pipeline, args): + print("Exporting the image information creater...") + unet_path = os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + flag, batch_size = args.flag, args.batch_size * 2 + height_size, width_size = args.height // 8, args.width // 8 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_static_{args.height}x{args.width}.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile.ts") + unet_compiled_dynamic_path = os.path.join(unet_path, f"unet_compile_dynamic.ts") + + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32) + ) + unet = UnetExportInit(unet_model).eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + # compile + if flag == 0: + # 静态 + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT)] + compile_unet_init(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + for i in range(len(heights)): + inputs_gear = [ + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + 
mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear) + compile_unet_init(model, inputs, unet_compiled_path, soc_version) + elif flag == 2: + if not os.path.exists(unet_compiled_dynamic_path): + model = torch.jit.load(unet_pt_path).eval() + min_shape_1 = (min_batch * 2, in_channels, min_height, min_width) + max_shape_1 = (max_batch * 2, in_channels, max_height, max_width) + min_shape_2, max_shape_2 = (1,), (1,) + min_shape_3 = (min_batch * 2, max_position_embeddings, encoder_hidden_size) + max_shape_3 = (max_batch * 2, max_position_embeddings, encoder_hidden_size) + min_shape_4 = (min_batch * 2, encoder_hidden_size_2) + max_shape_4 = (max_batch * 2, encoder_hidden_size_2) + min_shape_5 = (min_batch * 2, 6) + max_shape_5 = (max_batch * 2, 6) + inputs = [ + mindietorch.Input(min_shape=min_shape_1, max_shape=max_shape_1, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_2, max_shape=max_shape_2, dtype=mindietorch.dtype.INT64), + mindietorch.Input(min_shape=min_shape_3, max_shape=max_shape_3, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_4, max_shape=max_shape_4, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_5, max_shape=max_shape_5, dtype=mindietorch.dtype.FLOAT), + ] + compile_unet_init(model, inputs, unet_compiled_dynamic_path, soc_version) + +def export_unet_cache(sd_pipeline, args): + print("Exporting the image information creater...") + unet_path = os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + if args.parallel: + parallel = "parallel_" + batch_size = args.batch_size + else: + parallel = "" + batch_size = args.batch_size * 2 + flag = args.flag + height_size, width_size = args.height // 8, args.width // 8 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_0.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_{parallel}compile_0_static_{args.height}x{args.width}.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_{parallel}compile_0.ts") + unet_compiled_dynamic_path = os.path.join(unet_path, f"unet_{parallel}compile_0_dynamic.ts") + + unet_model = sd_pipeline.unet + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32), + torch.zeros([1], dtype=torch.int64), + ) + unet = UnetExport(unet_model).eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + # compile + if flag == 0: + # 静态 + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 
max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + compile_unet_cache(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + # 动态dims + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + for i in range(len(heights)): + inputs_gear = [ + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + inputs.append(inputs_gear) + compile_unet_cache(model, inputs, unet_compiled_path, soc_version) + elif flag == 2: + if not os.path.exists(unet_compiled_dynamic_path): + model = torch.jit.load(unet_pt_path).eval() + if args.parallel: + min_batch_temp, max_batch_temp = min_batch, max_batch + else: + min_batch_temp, max_batch_temp = min_batch * 2, max_batch * 2 + min_shape_1 = (min_batch_temp, in_channels, min_height, min_width) + max_shape_1 = (max_batch_temp, in_channels, max_height, max_width) + min_shape_2, max_shape_2 = (1,), (1,) + min_shape_3 = (min_batch_temp, max_position_embeddings, encoder_hidden_size) + max_shape_3 = (max_batch_temp, max_position_embeddings, encoder_hidden_size) + min_shape_4 = (min_batch_temp, encoder_hidden_size_2) + max_shape_4 = (max_batch_temp, encoder_hidden_size_2) + min_shape_5 = (min_batch_temp, 6) + max_shape_5 = (max_batch_temp, 6) + min_shape_6, max_shape_6 = (1,), (1,) + inputs = [ + mindietorch.Input(min_shape=min_shape_1, max_shape=max_shape_1, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_2, max_shape=max_shape_2, dtype=mindietorch.dtype.INT64), + mindietorch.Input(min_shape=min_shape_3, max_shape=max_shape_3, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_4, max_shape=max_shape_4, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_5, max_shape=max_shape_5, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_6, max_shape=max_shape_6, dtype=mindietorch.dtype.INT64) + ] + compile_unet_cache(model, inputs, unet_compiled_dynamic_path, soc_version) + +def export_unet_skip(sd_pipeline, args): + print("Exporting the image information creater...") + unet_path = os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + if args.parallel: + parallel = "parallel_" + batch_size = args.batch_size + else: + parallel = "" + batch_size = args.batch_size * 2 + flag = args.flag + height_size, width_size = args.height // 8, args.width // 8 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_1.pt") + unet_compiled_static_path = os.path.join(unet_path, f"unet_bs{batch_size}_{parallel}compile_1_static_{args.height}x{args.width}.ts") + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_{parallel}compile_1.ts") + unet_compiled_dynamic_path = os.path.join(unet_path, f"unet_{parallel}compile_1_dynamic.ts") + + unet_model = sd_pipeline.unet + encoder_model = 
sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones( + [batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32 + ), + torch.ones([batch_size, encoder_hidden_size_2], dtype=torch.float32), + torch.ones([batch_size, 6], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, 1280, math.ceil(sample_size / 2), math.ceil(sample_size / 2)], + dtype=torch.float32) + ) + unet = UnetExport(unet_model).eval() + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + # compile + if flag == 0: + # 静态 + if not os.path.exists(unet_compiled_static_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 1280, math.ceil(height_size / 2), math.ceil(width_size / 2)), + dtype=mindietorch.dtype.FLOAT)] + compile_unet_skip(model, inputs, unet_compiled_static_path, soc_version) + elif flag == 1: + # 动态dims + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [] + for i in range(len(heights)): + inputs_gear = [ + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 1280, math.ceil(heights[i] // 8 / 2), math.ceil(widths[i] // 8 / 2)), + dtype=mindietorch.dtype.FLOAT)] + inputs.append(inputs_gear) + compile_unet_skip(model, inputs, unet_compiled_path, soc_version) + elif flag == 2: + if not os.path.exists(unet_compiled_dynamic_path): + model = torch.jit.load(unet_pt_path).eval() + if args.parallel: + min_batch_temp, max_batch_temp = min_batch, max_batch + else: + min_batch_temp, max_batch_temp = min_batch * 2, max_batch * 2 + min_shape_1 = (min_batch_temp, in_channels, min_height, min_width) + max_shape_1 = (max_batch_temp, in_channels, max_height, max_width) + min_shape_2, max_shape_2 = (1,), (1,) + min_shape_3 = (min_batch_temp, max_position_embeddings, encoder_hidden_size) + max_shape_3 = (max_batch_temp, max_position_embeddings, encoder_hidden_size) + min_shape_4 = (min_batch_temp, encoder_hidden_size_2) + max_shape_4 = (max_batch_temp, encoder_hidden_size_2) + min_shape_5 = 
(min_batch_temp, 6) + max_shape_5 = (max_batch_temp, 6) + min_shape_6, max_shape_6 = (1,), (1,) + min_shape_7 = (min_batch_temp, 1280, math.ceil(min_height / 2), math.ceil(min_width / 2)) + max_shape_7 = (max_batch_temp, 1280, math.ceil(max_height / 2), math.ceil(max_width / 2)) + inputs = [ + mindietorch.Input(min_shape=min_shape_1, max_shape=max_shape_1, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_2, max_shape=max_shape_2, dtype=mindietorch.dtype.INT64), + mindietorch.Input(min_shape=min_shape_3, max_shape=max_shape_3, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_4, max_shape=max_shape_4, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_5, max_shape=max_shape_5, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_6, max_shape=max_shape_6, dtype=mindietorch.dtype.INT64), + mindietorch.Input(min_shape=min_shape_7, max_shape=max_shape_7, dtype=mindietorch.dtype.FLOAT) + ] + compile_unet_skip(model, inputs, unet_compiled_dynamic_path, soc_version) + +def trace_ddim(sd_pipeline, args, ddim_pt_path): + batch_size = args.batch_size * 2 + if not os.path.exists(ddim_pt_path): + dummy_input = ( + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size // 2, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + timesteps = scheduler.timesteps[:args.steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = NewScheduler( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=args.guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + new_ddim.eval() + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + +def export_ddim(sd_pipeline, args): + print("Exporting the ddim...") + ddim_path = os.path.join(args.output_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o744) + flag, batch_size = args.flag, args.batch_size * 2 + height_size, width_size = args.height // 8, args.width // 8 + ddim_pt_path = os.path.join(ddim_path, f"ddim_bs{batch_size}.pt") + scheduler_compiled_static_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_compile_static_{args.height}x{args.width}.ts") + scheduler_compiled_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_compile.ts") + scheduler_compiled_dynamic_path = os.path.join(ddim_path, f"ddim_compile_dynamic.ts") + + in_channels = 4 + + # trace + trace_ddim(sd_pipeline, args, ddim_pt_path) + + # compile + if flag == 0: + # 静态 + if not os.path.exists(scheduler_compiled_static_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) + ] + 
compile_ddim(model, inputs, scheduler_compiled_static_path, soc_version) + elif flag == 1: + # 动态dims + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [] + for i in range(len(heights)): + inputs_gear = [ + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + inputs.append(inputs_gear) + compile_ddim(model, inputs, scheduler_compiled_path, soc_version) + elif flag == 2: + if not os.path.exists(scheduler_compiled_dynamic_path): + model = torch.jit.load(ddim_pt_path).eval() + min_shape_1 = (min_batch * 2, in_channels, min_height, min_width) + max_shape_1 = (max_batch * 2, in_channels, max_height, max_width) + min_shape_2, max_shape_2 = (1,), (1,) + min_shape_3 = (min_batch, in_channels, min_height, min_width) + max_shape_3 = (max_batch, in_channels, max_height, max_width) + min_shape_4, max_shape_4 = (1,), (1,) + inputs = [ + mindietorch.Input(min_shape=min_shape_1, max_shape=max_shape_1, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_2, max_shape=max_shape_2, dtype=mindietorch.dtype.INT64), + mindietorch.Input(min_shape=min_shape_3, max_shape=max_shape_3, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_4, max_shape=max_shape_4, dtype=mindietorch.dtype.INT64)] + compile_ddim(model, inputs, scheduler_compiled_dynamic_path, soc_version) + +def trace_ddim_parallel(sd_pipeline, args, ddim_pt_path): + batch_size = args.batch_size + if not os.path.exists(ddim_pt_path): + dummy_input = ( + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + timesteps = scheduler.timesteps[:args.steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = Scheduler( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=args.guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + new_ddim.eval() + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + +def export_ddim_parallel(sd_pipeline, args): + print("Exporting the ddim...") + ddim_path = os.path.join(args.output_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o744) + flag, batch_size = args.flag, args.batch_size + height_size, width_size = args.height // 8, args.width // 8 + ddim_pt_path = os.path.join(ddim_path, f"ddim_bs{batch_size}.pt") + scheduler_compiled_static_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_parallel_compile_static_{args.height}x{args.width}.ts") + scheduler_compiled_path = os.path.join(ddim_path, 
f"ddim_bs{batch_size}_parallel_compile.ts") + scheduler_compiled_dynamic_path = os.path.join(ddim_path, f"ddim_parallel_compile_dynamic.ts") + + in_channels = 4 + + # trace + trace_ddim_parallel(sd_pipeline, args, ddim_pt_path) + + # compile + if flag == 0: + # 静态 + if not os.path.exists(scheduler_compiled_static_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + compile_ddim(model, inputs, scheduler_compiled_static_path, soc_version) + elif flag == 1: + # 动态dims + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [] + for i in range(len(heights)): + inputs_gear = [ + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, in_channels, heights[i] // 8, widths[i] // 8), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + inputs.append(inputs_gear) + compile_ddim(model, inputs, scheduler_compiled_path, soc_version) + elif flag == 2: + if not os.path.exists(scheduler_compiled_dynamic_path): + model = torch.jit.load(ddim_pt_path).eval() + min_shape_1 = (min_batch, in_channels, min_height, min_width) + max_shape_1 = (max_batch, in_channels, max_height, max_width) + min_shape_2 = (min_batch, in_channels, min_height, min_width) + max_shape_2 = (max_batch, in_channels, max_height, max_width) + min_shape_3, max_shape_3 = (1,), (1,) + min_shape_4 = (min_batch, in_channels, min_height, min_width) + max_shape_4 = (max_batch, in_channels, max_height, max_width) + min_shape_5, max_shape_5 = (1,), (1,) + inputs = [ + mindietorch.Input(min_shape=min_shape_1, max_shape=max_shape_1, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_2, max_shape=max_shape_2, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_3, max_shape=max_shape_3, dtype=mindietorch.dtype.INT64), + mindietorch.Input(min_shape=min_shape_4, max_shape=max_shape_4, dtype=mindietorch.dtype.FLOAT), + mindietorch.Input(min_shape=min_shape_5, max_shape=max_shape_5, dtype=mindietorch.dtype.INT64)] + compile_ddim(model, inputs, scheduler_compiled_dynamic_path, soc_version) + + +def concat_string(string_array): + length = len(string_array) + strres ='' + for i in range(length - 2): + strres = strres + string_array[i] + '.' 
+    strres = strres + string_array[length - 2]
+    return strres
+
+
+def register_unet_buffer(sd_pipeline, baselora_path):
+    if not os.path.exists(baselora_path):
+        os.makedirs(baselora_path, mode=0o640)
+    unet_model = sd_pipeline.unet
+    save_tensor = dict()
+
+    for name in list(unet_model.state_dict().keys()):
+        name_array = name.split('.')
+        if name_array[0] in ("time_embedding", "add_embedding"):
+            continue
+        curlayer = unet_model
+        for i in range(len(name_array) - 1):
+            curlayer = getattr(curlayer, name_array[i])
+
+        if isinstance(curlayer, (LoRACompatibleLinear, LoRACompatibleConv)):
+            strres = concat_string(name_array)
+            save_tensor[strres] = curlayer.weight.data
+            curlayer.register_buffer("mindie_buffer", curlayer.weight.data)
+            curlayer.status = True
+
+    save_path = os.path.join(baselora_path, "saveTensor.pt")
+    torch.save(save_tensor, save_path)
+
+def export(args):
+    pipeline = StableDiffusionXLPipeline.from_pretrained(args.model).to('cpu')
+
+    export_clip(pipeline, args)
+    export_vae(pipeline, args)
+    # Before exporting the unet, register the base LoRA weights as buffers.
+    if args.lorahot_support:
+        register_unet_buffer(pipeline, args.baselora_path)
+    if args.use_cache:
+        export_unet_cache(pipeline, args)
+        export_unet_skip(pipeline, args)
+    else:
+        export_unet_init(pipeline, args)
+    if args.parallel:
+        export_ddim_parallel(pipeline, args)
+    else:
+        export_ddim(pipeline, args)
+
+def main():
+    args = parse_arguments()
+    mindietorch.set_device(args.device)
+    export(args)
+    print("Done.")
+    mindietorch.finalize()
+
+if __name__ == "__main__":
+    # Resolution range supported by dynamic shape compilation (--flag 2)
+    min_batch, max_batch = 1, 32
+    min_height, max_height = 512 // 8, 1024 // 8
+    min_width, max_width = 512 // 8, 1664 // 8
+    # Resolution gears supported by dynamic dims compilation (--flag 1)
+    heights = [1024, 512, 936, 768, 576]
+    widths = [1024, 512, 1664, 1360, 1024]
+
+    args = parse_arguments()
+    if args.soc == "Duo":
+        soc_version = "Ascend310P3"
+    elif args.soc == "A2":
+        soc_version = "Ascend910B4"
+    main()
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL/export_ts_quant.py b/MindIE/MultiModal/StableDiffusion-XL/export_ts_quant.py
new file mode 100644
index 0000000000000000000000000000000000000000..c600937c7fd08f63bdea47bef1f854eb039af841
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL/export_ts_quant.py
@@ -0,0 +1,460 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
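+#
+# Example invocation (an illustrative sketch based on parse_arguments() below).
+# It assumes the calibration data referenced by --unet_data_dir (a .npy dict
+# that must contain the 'use_cache' and 'parallel' keys) has already been
+# generated, and that the custom quant op library ./quant/build/libquant_ops.so
+# has been built, since main() loads it via torch.ops.load_library:
+#
+#   python3 export_ts_quant.py \
+#       --model ./stable-diffusion-xl-base-1.0 \
+#       --output_dir ./models_quant \
+#       --unet_data_dir ./unet_data.npy \
+#       --soc A2 \
+#       --device 0 \
+#       --height 1024 \
+#       --width 1024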
+ +import os +import copy +import numpy as np +from modelslim.pytorch.quant.ptq_tools import Calibrator, QuantConfig +from quant_utils import modify_model +import argparse +from argparse import Namespace +import math +import torch +import torch.nn as nn +from diffusers import DDIMScheduler +from diffusers import StableDiffusionXLPipeline +from compile_model import * +import mindietorch + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models_quant", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument("-bs", "--batch_size", type=int, default=1, help="Batch size.") + parser.add_argument("-steps", "--steps", type=int, default=50, help="steps.") + parser.add_argument("-guid", "--guidance_scale", type=float, default=5.0, help="guidance_scale") + parser.add_argument("--use_cache", action="store_true", help="Use cache during inference.") + parser.add_argument("-p", "--parallel", action="store_true", + help="Export the unet of bs=1 for parallel inferencing.") + parser.add_argument("--soc", choices=["Duo", "A2"], default="A2", help="soc_version.") + parser.add_argument( + "--unet_data_dir", + type=str, + default='./unet_data.npy', + help="save unet input for quant." + ) + parser.add_argument( + "--device", + default=0, + type=int, + help="NPU device", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + return parser.parse_args() + +def trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path): + encoder_model = sd_pipeline.text_encoder + encoder_2_model = sd_pipeline.text_encoder_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + if not os.path.exists(clip_pt_path): + clip_export = ClipExport(encoder_model) + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + if not os.path.exists(clip2_pt_path): + clip_export = ClipExport(encoder_2_model) + torch.jit.trace(clip_export, dummy_input).save(clip2_pt_path) + +def export_clip(sd_pipeline, args): + print("Exporting the text encoder...") + clip_path = os.path.join(args.output_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + batch_size = 1 + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + clip2_pt_path = os.path.join(clip_path, f"clip2_bs{batch_size}.pt") + + encoder_model = sd_pipeline.text_encoder + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + trace_clip(sd_pipeline, batch_size, clip_pt_path, clip2_pt_path) + + # compile + batch_size = args.batch_size + clip1_compiled_path = os.path.join(clip_path, f"clip_bs{batch_size}_compile_quant_{args.height}x{args.width}.ts") + clip2_compiled_path = os.path.join(clip_path, f"clip2_bs{batch_size}_compile_quant_{args.height}x{args.width}.ts") + if not os.path.exists(clip1_compiled_path): + model = torch.jit.load(clip_pt_path).eval() + inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip1_compiled_path, soc_version) + if not os.path.exists(clip2_compiled_path): + model = torch.jit.load(clip2_pt_path).eval() + 
inputs = [mindietorch.Input((batch_size, max_position_embeddings), dtype=mindietorch.dtype.INT64)] + compile_clip(model, inputs, clip2_compiled_path, soc_version) + +def export_vae(sd_pipeline, args): + print("Exporting the image decoder...") + vae_path = os.path.join(args.output_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + batch_size = 1 + height_size, width_size = args.height // 8, args.width // 8 + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + + # trace + if not os.path.exists(vae_pt_path): + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32) + vae_export = VaeExport(vae_model) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + + # compile + batch_size = args.batch_size + vae_compiled_path = os.path.join(vae_path, f"vae_bs{batch_size}_compile_quant_{args.height}x{args.width}.ts") + if not os.path.exists(vae_compiled_path): + model = torch.jit.load(vae_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), dtype=mindietorch.dtype.FLOAT)] + compile_vae(model, inputs, vae_compiled_path, soc_version) + +def trace_ddim(sd_pipeline, args, ddim_pt_path): + batch_size = 2 + if not os.path.exists(ddim_pt_path): + dummy_input = ( + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size // 2, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + timesteps = scheduler.timesteps[:args.steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = NewScheduler( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=args.guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + new_ddim.eval() + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + +def export_ddim(sd_pipeline, args): + print("Exporting the ddim...") + ddim_path = os.path.join(args.output_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o744) + batch_size = 2 + height_size, width_size = args.height // 8, args.width // 8 + ddim_pt_path = os.path.join(ddim_path, f"ddim_bs{batch_size}.pt") + + unet_model = sd_pipeline.unet + ddim_model = sd_pipeline.scheduler + sample_size = unet_model.config.sample_size + in_channels = 4 + + # trace + trace_ddim(sd_pipeline, args, ddim_pt_path) + # compile + batch_size = args.batch_size * 2 + scheduler_compiled_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_compile_quant_{args.height}x{args.width}.ts") + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + 
mindietorch.Input((batch_size // 2, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64) + ] + compile_ddim(model, inputs, scheduler_compiled_path, soc_version) + +def trace_ddim_parallel(sd_pipeline, args, ddim_pt_path): + batch_size = 1 + if not os.path.exists(ddim_pt_path): + dummy_input = ( + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size, 4, 128, 128], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(args.steps, device="cpu") + + timesteps = scheduler.timesteps[:args.steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = Scheduler( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=args.guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + new_ddim.eval() + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + +def export_ddim_parallel(sd_pipeline, args): + print("Exporting the ddim...") + ddim_path = os.path.join(args.output_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o640) + batch_size = 1 + height_size, width_size = args.height // 8, args.width // 8 + ddim_pt_path = os.path.join(ddim_path, f"ddim_bs{batch_size}.pt") + + in_channels = 4 + + # trace + trace_ddim_parallel(sd_pipeline, args, ddim_pt_path) + # compile + batch_size = args.batch_size + scheduler_compiled_path = os.path.join(ddim_path, f"ddim_bs{batch_size}_parallel_compile_quant_{args.height}x{args.width}.ts") + if not os.path.exists(scheduler_compiled_path): + model = torch.jit.load(ddim_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + compile_ddim(model, inputs, scheduler_compiled_path, soc_version) + +def trace_quant_model(model, calib_datas, input_shape, pt_path, need_calib=True): + save_path = os.path.dirname(os.path.split(pt_path)[0]) + quant_model = copy.deepcopy(model) + export_model = copy.deepcopy(model) + if need_calib: + quant_config = QuantConfig(disable_names=[], + amp_num=0, input_shape=input_shape, + act_method=0, quant_mode=0, a_signed=True, sigma=40) + calibrator = Calibrator(quant_model, quant_config, calib_data=calib_datas) + calibrator.run() + calibrator.export_param(os.path.join(save_path, 'quant_weights')) + input_scale = np.load(os.path.join(save_path, 'quant_weights', 'input_scale.npy'), allow_pickle=True).item() + input_offset = np.load(os.path.join(save_path, 'quant_weights', 'input_offset.npy'), allow_pickle=True).item() + weight_scale = np.load(os.path.join(save_path, 
'quant_weights', 'weight_scale.npy'), allow_pickle=True).item() + weight_offset = np.load(os.path.join(save_path, 'quant_weights', 'weight_offset.npy'), allow_pickle=True).item() + quant_weight = np.load(os.path.join(save_path, 'quant_weights', 'quant_weight.npy'), allow_pickle=True).item() + + export_model = modify_model(export_model, input_scale, input_offset, weight_scale, weight_offset, quant_weight) + torch.jit.trace(export_model, calib_datas[0]).save(pt_path) + +def export_unet_cache(sd_pipeline, args, input_data): + print("Exporting the image information creater...") + unet_path = os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + if input_data['parallel']: + parallel = "parallel_" + batch_size = 1 + else: + parallel = "" + batch_size = 2 + height_size, width_size = args.height // 8, args.width // 8 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_0.pt") + + unet_model = copy.deepcopy(sd_pipeline.unet) + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + calib_datas = [list(input_data['cache'])] + unet = UnetExport(unet_model) + unet.eval() + trace_quant_model(unet, calib_datas, [batch_size, in_channels, sample_size, sample_size], unet_pt_path, need_calib=True) + # compile + batch_size = args.batch_size * 2 + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_{parallel}compile_0_quant_{args.height}x{args.width}.ts") + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64)] + compile_unet_cache(model, inputs, unet_compiled_path, soc_version) + +def export_unet_skip(sd_pipeline, args, input_data): + print("Exporting the image information creater...") + unet_path = os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + if input_data['parallel']: + parallel = "parallel_" + batch_size = 1 + else: + parallel = "" + batch_size = 2 + height_size, width_size = args.height // 8, args.width // 8 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_1.pt") + + unet_model = copy.deepcopy(sd_pipeline.unet) + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + calib_datas = [list(input_data['skip'])] + unet = UnetExport(unet_model) 
+ unet.eval() + trace_quant_model(unet, calib_datas, [batch_size, in_channels, sample_size, sample_size], unet_pt_path, need_calib=False) + # compile + batch_size = args.batch_size * 2 + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_{parallel}compile_1_quant_{args.height}x{args.width}.ts") + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, 1280, math.ceil(height_size / 2), math.ceil(width_size / 2)), + dtype=mindietorch.dtype.FLOAT)] + compile_unet_skip(model, inputs, unet_compiled_path, soc_version) + + +def export_unet_init(sd_pipeline, args, input_data): + print("Exporting the image information creater...") + unet_path = os.path.join(args.output_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + batch_size = 2 + height_size, width_size = args.height // 8, args.width // 8 + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + + unet_model = copy.deepcopy(sd_pipeline.unet) + encoder_model = sd_pipeline.text_encoder + encoder_model_2 = sd_pipeline.text_encoder_2 + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size_2 = encoder_model_2.config.hidden_size + encoder_hidden_size = encoder_model.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = encoder_model.config.max_position_embeddings + + # trace + if not os.path.exists(unet_pt_path): + calib_datas = [list(input_data['no_cache'])] + unet = UnetExportInit(unet_model) + unet.eval() + trace_quant_model(unet, calib_datas, [batch_size, in_channels, sample_size, sample_size], unet_pt_path, need_calib=True) + # compile + batch_size = args.batch_size * 2 + unet_compiled_path = os.path.join(unet_path, f"unet_bs{batch_size}_compile_quant_{args.height}x{args.width}.ts") + if not os.path.exists(unet_compiled_path): + model = torch.jit.load(unet_pt_path).eval() + inputs = [ + mindietorch.Input((batch_size, in_channels, height_size, width_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, max_position_embeddings, encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, encoder_hidden_size_2), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, 6), dtype=mindietorch.dtype.FLOAT)] + compile_unet_init(model, inputs, unet_compiled_path, soc_version) + +def export(args): + pipeline = StableDiffusionXLPipeline.from_pretrained(args.model).to('cpu') + data = np.load(args.unet_data_dir, allow_pickle=True).item() + print(data.keys()) + print(data['use_cache']) + if 'use_cache' not in data or 'parallel' not in data: + raise RuntimeError(f'invalid unet data file.') + + export_clip(pipeline, args) + export_vae(pipeline, args) + if data['use_cache']: + export_unet_cache(pipeline, args, data) + export_unet_skip(pipeline, args, data) + else: + export_unet_init(pipeline, args, data) + if args.parallel: + 
export_ddim_parallel(pipeline, args) + else: + export_ddim(pipeline, args) + +def main(): + args = parse_arguments() + mindietorch.set_device(args.device) + torch.ops.load_library("./quant/build/libquant_ops.so") + export(args) + print("Done.") + mindietorch.finalize() + +if __name__ == '__main__': + args = parse_arguments() + if args.soc == "Duo": + soc_version = "Ascend310P3" + elif args.soc == "A2": + soc_version = "Ascend910B4" + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/hpsv2_benchmark_prompts.json b/MindIE/MultiModal/StableDiffusion-XL/hpsv2_benchmark_prompts.json new file mode 100644 index 0000000000000000000000000000000000000000..d73bb647aca744dcc781a797247cb1a4b9345990 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/hpsv2_benchmark_prompts.json @@ -0,0 +1,3210 @@ +{ + "anime": [ + "Spongebob depicted in the style of Dragon Ball Z.", + "Lionel Messi portrayed as a sitcom character.", + "A digital artwork depicting a cartoon illustration of a warehouse environment.", + "Two young Japanese goth cosplay girls in fishnets, corsets, chokers, and black and white makeup with full body tattoos and intricate painted details.", + "A white-haired girl in a pink sweater looks out a window in her bedroom.", + "A girl gazes at a city from a mountain at night in a colored manga illustration by Diego Facio.", + "A hamster resembling a horse.", + "The president being abducted by aliens.", + "Anime-style fighter pilot in cockpit engaged in a night air battle with explosions.", + "A hyper-realistic representation of the hypnotoad from Futurama.", + "A detailed movie poster for Invader Zim with a horror theme.", + "The image depicts Rika Furude from the video game Yume Nikki.", + "A lemon character with sunglasses on the beach.", + "A lemon wearing sunglasses on the beach.", + "A tall chicken standing next to a farmer.", + "A girl in school uniform standing in the city.", + "Image of xqc with a distinctive underbite and big, long nose.", + "A goth anime woman with a symmetrical and attractive face in a black and white watercolor headshot art on ArtStation.", + "Peter Griffin taking a selfie in a park on an iPhone 12 Pro.", + "A helmet-wearing monkey skating.", + "Anime oil painting of Rem from Re Zero.", + "Pinup girl in the style of Jab Comix.", + "A photo of Big Chungus from Looney Tunes.", + "Frontal portrait of anime girl with pink hair wearing white t-shirt and smiling.", + "Tom and Jerry are featured in Iron Maiden's album Live After Death in place of their usual mascot, Eddie.", + "A kangaroo in an orange hoodie and blue sunglasses stands on the grass in front of the Sydney Opera House holding a \"Welcome Friends\" sign.", + "A whale is reading a book about avoiding Japanese spears in an underwater library.", + "Gordon Ramsay frying minions on a pan.", + "The image depicts Peter Griffin, a cartoon character from the television show \"Family Guy.\"", + "A cartoon-style illustration of a fantasy village environment.", + "Danny DeVito appears in Jojo's Bizarre Adventure.", + "A Pixar lemon wearing sunglasses on a beach.", + "A photograph of a woman from Steven Universe with gigantic pink ringlets and a white dress.", + "An image of iCarly and Hannah Montana in a mosh pit.", + "A close-up anime portrait of Sailor Moon against a grey background with Russian panel housing in bokeh.", + "Gus Fring working as a KFC waiter.", + "Anime art featuring Hatsune Miku with symmetrical shoulders.", + "Asuka Langley dressed as a slavic priestess in a birch 
forest during spring.", + "Spongebob is wearing avant-garde Rick Owens clothing in a fashion-forward look.", + "\"A man drinking a cup of cosmic energy in a surreal anime style artwork.\"", + "A syrup bottle crashes into twin towers made of pancakes.", + "A capybara wearing sunglasses.", + "A still image from the CGI animated movie Akira.", + "Close-up of a manga still depicting the interior of a Shinkansen train with a view from a leather seat next to the side window, and a hyper-realistic film still projection of a promenade scene in the background.", + "Image of Katara from Avatar.", + "Superman with Danny DeVito's face.", + "A stylized Mr. Clean dressed in leather jacket and shades reminiscent of the Fonz from \"Happy Days.\"", + "A colorful anime painting of a sugar glider with a hiphop graffiti theme, by several artists, currently trending on Artstation.", + "The image is a music album cover featuring Vaporize Pikachu.", + "Madrid cityscape with a distinct Studio Ghibli-inspired aesthetic.", + "A beaver in formal attire stands next to a stack of books in a library.", + "The image depicts a colorful female pirate with dreadlocks, holding a sword.", + "A dog wearing a business suit and smoking a cigar.", + "An anime-style painting of a sugar glider royalty enjoying hiphop music amidst a graffiti background.", + "Male furry anthro mountain goat in a pinstripe suit and waistcoat, smoking a cigar.", + "\"A comic portrait of an Indian goddess with realistic shading and fine details set in a nighttime anime style.\"", + "A colorful illustration of a suburban neighborhood on an ancient post-apocalyptic planet featuring creatures made by Jim Henson's workshop.", + "A hyena fursona sitting in the grass in a savannah at sunset.", + "A plushy tired owl sits on a pile of antique books in a humorous illustration.", + "A dolphin swimming in front of a Studio Ghibli logo backdrop.", + "An image of a person dressed as an avocado queen.", + "Portrait of girl in black riot gear holding a whip, drawn in anime/manga style by Makoto Shinkai.", + "A lemon wearing sunglasses on the beach.", + "A puppy is driving a car in a film still.", + "Portrait of an ape wearing an astronaut helmet.", + "Baby Yoda depicted in the style of Assassination Classroom anime.", + "Pop art illustration of a surprised girl in a comic style, suitable for advertising posters, party invitations, and birthday greeting cards.", + "An illustration featuring characters from a Dungeons and Dragons game.", + "A digital painting of an anthropomorphic corgi lifting weights in a dim gym with intricate details and a dynamic pose.", + "Danny Devito dressed up as Wolverine.", + "A manga-style railgun illustration.", + "Two Sesame Street characters, Ernie and Bert, standing close to each other and looking towards the camera.", + "A close-up portrait of Rapunzel with a smile.", + "\"Fuzzy clay creatures playing a game in a heavily tilted green pasture.\"", + "Closeup of a seinen manga film still showing the interior of a shinkansen train with a leather seat and a window view, with a hyperrealistic film still from a Nepali movie projecting in the background.", + "Cartoon astronaut floating around the moon with a rocket ship and shooting stars in the background.", + "A tiger wearing a train conductor's hat and holding a skateboard decorated with a yin-yang symbol.", + "A lifelike anime girl in steel plate armor takes a selfie in a castle courtyard.", + "Android 18 is talking to Danny DeVito.", + "An attractive anime-style character of a girl with 
short brown hair and wearing a white blouse, drawn by WLOP.", + "The image depicts Karl Marx bodyslamming Friedrich Nietzsche in a WWE championship match.", + "Simplified anime-style logo featuring a railgun.", + "of a pug wearing a cowboy hat and bandana, sitting on a hay bale.", + "Peter Griffin is featured in a scene from the TV show Family Guy.", + "Patrick Bateman beating an anthropomorphic wolf cosplay.", + "The image is a mash-up of \"Stranger Things\" and \"Ghostbusters\" themes.", + "A cyan silver the hedgehog with black tipped quills wearing green-tinted sunglasses, a purple and green cape, and shoes.", + "A flat design illustration of a cheese sandwich with minimalistic line elements.", + "The image depicts Spongebob and friends at a wedding ceremony, created through AI and with a high level of detail.", + "A forest scene depicted in the morning light by Rumiko Takahashi.", + "Two gundams in an intimate and potentially romantic scene, illustrated by Yoshitaka Amano.", + "An advanced digital anime art of Lucy standing on an asteroid with a futuristic utopian city in the background and sunlight backlighting the landscape.", + "Magical girl manga cover featuring a highly detailed, fancy design.", + "A boy flies through candy planets eating pizza in a comic-style drawing by Steve Ditko and Frank Miller.", + "A red-haired girl with an eyepatch scaring little children.", + "A figurine of Walter White from Breaking Bad depicted as an anime character.", + "A fox wearing a yellow dress.", + "Hatsune Miku, a giant anime-styled woman, walking through New York.", + "Anime portrait of Cristiano Ronaldo as a shaman yedi using dark force to eliminate Messi as an anime antagonist.", + "Medium shot manga pencil drawing of Alita by Yukito Kishiro in black and white.", + "A photograph of a woman from Steven Universe with gigantic pink ringlets and a white dress.", + "A beaver in formal attire stands next to a stack of books in a library.", + "An image of banana ducks, created by adding duck beaks to peeled bananas with googly eyes.", + "A painting of a pokemon resembling a phone booth wearing clothes walking on two legs in a cobblestone street in a magical city.", + "Anime illustration of Gundam mech suit on Pixiv.", + "A girl sneaking behind a giant wooden door with archaic symbols embedded onto it, in a cave with the waterfall, illustrated in comics style.", + "Cat wearing cowboy hat rides on corgi during sunset in the Wild West.", + "An anime-style depiction of a boy that showcases impressive artistic skill.", + "Mickey Mouse performing at Woodstock.", + "A comic book illustration by John Kirby depicting a jungle fortress surrounded by dirt walls in a marketplace setting with cinematic rays of sunlight.", + "A kangaroo in an orange hoodie and blue sunglasses stands on the grass in front of the Sydney Opera House holding a sign that says Welcome Friends.", + "A girl with white hair and a school uniform, depicted in an illustration with warm clothes and a cold background.", + "An angry weeb fan destroys his monitor and smashes his keyboard after Genshin Impact shut down.", + "A cat sitting besides a rocket on a planet with a lot of cactuses.", + "A cat floating inside the International Space Station.", + "Portrait of a monkey wearing an astronaut helmet.", + "A young adult Japanese delinquent with a pompadour hairstyle named Jotaro Kujo stands in an action pose ready for a fight on the cover of a manga issue.", + "A banana duck and his pet dog are depicted in a minimalist stock illustration.", + 
"A yarn monster attacking a city.", + "The image is of a cute young owl sleeping in a tea cup.", + "Comic art featuring Walter White and Saul Goodman at sunset.", + "A squirrel is riding a skateboard.", + "A dog wearing a Burger King crown is eating a Big Mac.", + "Cartoon-style rock texture with a seamless design, created using a 4k resolution Substance material, resembling an Anime aesthetic.", + "Up close portrait of Mr. Bean's face.", + "A banana spaceship reminiscent of Homeworld.", + "Funko Pop toy of a capybara with sunglasses.", + "Photo of Violet Parr from The Incredibles in a two-piece dress at the beach.", + "A bird is speaking into a high-end microphone, wearing headphones in a recording studio.", + "A cartoon satanic priest depicted as an anthropomorphic lamb in a highly detailed 3D render.", + "Goku and Vegeta battling as Super Saiyans.", + "A digital painting of a cyberpunk anime woman with intricate and highly detailed features.", + "A hamster is eating spaghetti in a candlelit restaurant table with different poses.", + "A beaver in attire stands near books in a library.", + "A manga-style illustration of Harry Potter as a Gundam mech.", + "The image depicts John Egbert from Homestuck.", + "A beaver in formal attire stands beside books in a library.", + "Lady with an Ermine is depicted in comic art by Carl Barks.", + "A portrait of Pikachu with an army of minions, surrounded by dramatic lightning and electricity.", + "Anime-style depiction of a moment symbolizing the start of a conclusion.", + "A pencil sketch of Victoria Justice drawn in the Disney style by Milt Kahl.", + "A Rika Furude plush toy.", + "A portrait of a stylized business cat in sharp focus with a medium shot perspective, resembling boxart.", + "Hank Hill arm-wrestling One Punch Man.", + "A dog wearing a business suit smoking a cigar in a cinematic style.", + "A toast with black sunglasses and a blue flower on the top right corner.", + "Rabbit minions of snail soldiers.", + "Portrait of Bugs Bunny.", + "A cat in space with a comic style reminiscent of Moebius and Laurie Greasley.", + "Illustration of an anime maid with a pretty face and eyes, shown in a full-body upper shot.", + "King Bradley from Fullmetal Alchemist Brotherhood with a serious look, in vibrant colors on a plain background, styled like comic cover art.", + "A hamster dressed as C-3PO from Star Wars in a movie still.", + "A digital art depicting a chicken wearing a suit.", + "The image is a medium shot of a highly detailed pencil-drawn black and white manga featuring the face of Alita by Yukito Kishiro.", + "Dogs eating bones and pigs eating cabbages.", + "A portrait of a cheerful, young Latino magician holding a grand and pink coffee mug, surrounded by stars and magic.", + "An anime-style advertisement featuring a pizza and an explosion.", + "Clint Eastwood fighting with a white Michelin man costume with hippo and alien plant surroundings during a beach sunset.", + "Anime portrait of a beautiful woman from Ghost in the Shell 1985 by multiple artists.", + "Anime movie poster featuring Princess Azula.", + "A monkey wearing a jacket.", + "Goku is depicted as a Jojo's Bizarre Adventure character in a dynamic and cool pose on a manga page, drawn in the style of Hirohiko Araki.", + "An anime girl with an athletic build poses confidently while holding an assault rifle, wearing a green tank top, black shorts, and aviator sunglasses.", + "A close-up anime portrait of Sailor Moon in front of Russian panel houses with a grey backdrop.", + "Two skunks - 
one small and dark blue with yellow eyes and a larger albino one with red eyes.", + "Spongebob wearing Rick Owens clothing in an avant garde fashion look on r/Streetwear.", + "Arnold Schwarzenegger depicted as a magical anime character by Kyoto Animation.", + "The Moomins are shown in fluffy knight armor exploring an enchanted forest filled with magic trees, mushrooms, moss and glowing fairies.", + "A Dorohedoro Caiman Funko POP figurine is displayed with its accompanying box.", + "The image depicts Kazuo Koike in a catfish rogue costume surrounded by various anime artists' renderings.", + "A female Sonic the Hedgehog with black sclera and bright red pupils.", + "Bob Ross riding a brown bear in Alaska.", + "A monkey in a suit smoking a cigar looks sad in a close-up shot.", + "Three friends running a magical bakery, including a young man with glasses and black hair, a beautiful brunette woman, and a cute goth girl with red hair.", + "The image shows a friendly owl perched on a pile of books.", + "anthropomorphic Virginia opossum playing guitar.", + "A hyena fursona sits in a savannah sunset amidst the grass.", + "A sailboat emoji with a rainbow-colored sail.", + "Goth girl in a maid outfit in a cluttered bedroom.", + "Monkey giving thumbs up in a selfie.", + "The image is of an anime-style girl with a dog, featuring intricate and elegant art in the style of Arthur Rackham.", + "A robot cuddles kittens.", + "A goat explores a cat jungle on Mars.", + "Illustration from \"The Cat in the Hat\" by Dr. Seuss featuring the mischievous cat balancing various household items on his umbrella.", + "A pile of chicken eggs with a confused chicken nearby.", + "Winnie the Pooh cartoon featuring Eeyore's rejection by female squirrels in the forest.", + "A portrait of a young goth woman wearing Warhammer bikini armor.", + "A bowl of eyeballs in milk.", + "A kangaroo wearing an orange hoodie and sunglasses holds a sign in front of the Sydney Opera House.", + "A girl peers over the edge of a mountain at a giant city in the dark of night, depicted in a manga illustration by Kentaro Miura and Hiromu Arakawa.", + "Nigel Farage depicted in World of Warcraft.", + "A children's bag shaped and themed after Shrek, with a Shrek head design as a bottle.", + "A depiction of the character Gunther from the show Adventure Time.", + "Katara from Avatar smiling at camera.", + "An anthropomorphic cat riding a Harley Davidson in Arizona with sunglasses and a leather jacket.", + "A pizza with a cat on top.", + "The image depicts a muscular mouse wielding assault rifles, in a Disney art style.", + "A head and shoulder portrait of a scary black cartoon rabbit wearing a shirt with big eyes and a laughing expression in a Walt Disney retro art style.", + "Marine Le Pen wearing a 70s-style outfit and sporting an afro hairstyle.", + "A fire-type Pokemon is depicted in concept art found on ArtStation.", + "A kangaroo wearing an orange hoodie and blue sunglasses stands on the grass in front of the Sydney Opera House, holding a sign that says Welcome Friends.", + "A cute robot lifting dumbbells and drinking a shake.", + "A cozy tavern with a retro video game vibe and cinematic lighting, featuring a cartoon-style animation and detailed background art.", + "A knitted Capybara wearing sunglasses sips a Mojito at the beach during sunset.", + "A corgi wearing athletic clothes lifts weights in a gym.", + "Portrait photo of Ahsoka Tano.", + "Studio photo portrait of Lain Iwakura from Serial Experiments Lain wearing floral garlands over her 
traditional dress.", + "Milt Kahl's sketch of Cecil Turtle.", + "Jack Skellington dances with Sally in a dystopian desert Christmas scene.", + "A winter mountain landscape at deep night with snowy terrain and colorful flowers, under beautiful clouds and no people, portrayed as an anime background illustration with intricate detail and sharp focus.", + "Cargo bay interior designed in the style of Cowboy Bebop.", + "A cinematic shot of Mr. Bean in a Black Ops outfit.", + "Anime art of One Piece character with intricate details.", + "Digital art of a female marten animal cartoon character wearing jewelry with a blonde hairstyle.", + "Illustration of a sad octopus in a children's book.", + "A woman from The Simpsons is dancing in the rain in a detailed and intricate fantasy environment.", + "A mushroom playing an acoustic guitar, depicted in a manga style by artists Takehiko Inoue, Artgerm, and J.C. Leyendecker, showcasing intricate and elegant details.", + "An astronaut rides a monocycle on Saturn's ring in a vintage scifi comic-style propaganda poster inspired by the space race and cold war era.", + "Communistic poster with anime styled characters.", + "Doctor Who characters performing as a boy band.", + "A cat-tiger hybrid animal in a house.", + "An anime-style advertisement featuring a pizza with multiple explosions in the background.", + "A cat protecting her kittens from eagles in a wholesome illustration.", + "An infinitely long hotdog on a white background.", + "The image depicts characters from the animated series \"Kim Possible\" created by several artists.", + "A monkey is pictured acting as a DJ.", + "Yoko Ono flying on a broomstick in the sky.", + "Albus Dumbledore dressed up as Wonder Woman.", + "A cinematic shot of Avatar Azula.", + "A comic book cover featuring three cyberpunk hired guns posing for a portrait.", + "An anime girl shrugging on Artstation.", + "A crayon drawing of Saul Goodman created by children.", + "The image features a beet next to Dwight Schrute.", + "A yellow cartoon character with pointy ears, red cheeks, and a lightning bolt tail known as Pikachu.", + "Studio Ghibli artwork.", + "A gangster squirrel is counting his money in a low angle film still.", + "An image related to a Team Fortress 2 update.", + "Photo of Ty Lee from Avatar.", + "Spongebob wearing avant garde fashion by Rick Owens.", + "A landscape with a Walt Disney-styled building.", + "A syrup bottle crashes into twin towers made of pancakes.", + "Nendoroid figurine of a princess dressed as a serial killer.", + "A person standing on a street corner holding an umbrella as Skittles rain down around them.", + "The image is a promotional artwork for the Sailor Moon Eternal animated fantasy film from Japan.", + "A sideview of toilet bowls in a battle pose, each with tentacles and glowing red eyes, creating a complex action scene in a bathroom.", + "Photo of a chocolate-type Pokemon card.", + "A chimp wearing a suit is smoking.", + "A ps2 anime witch from madoka magicka is flying on a broom through New York causing people to run for their lives due to a terrorist attack.", + "Medium shot black and white manga pencil drawing with a highly detailed face of Alita by Yukito Kishiro.", + "Full-page scan of Twilight Sparkle concept art by Lauren Faust.", + "Monkey D Luffy in DragonBall Z by Akira Toriyama, in color.", + "A fluffy chick is nested in an antique coffee cup in a humorous illustration.", + "A Pokemon resembling a line graph.", + "A red cat is walking down outdoor stairs in warm colors.", + 
"Excited Sonic fan experiencing the franchise for the first time.", + "A pikachu in a forest illustration.", + "A mushroom house in a dark forest, with warm light emitting from its windows.", + "The image depicts a large cat snoozing on the central area of a Starcraft 2 game map.", + "A man is shown driving a homemade cardboard tank inspired by the artwork of Edmund Leighton.", + "The image depicts a crazy chef and includes the names of various artists who contributed to its creation.", + "The image is a portrait of Homer Simpson as Thanos from Infinity War, with a detailed and photorealistic face and dramatic cinematic lighting.", + "A blue bill pokemon wearing a red puffy coat.", + "Luke Skywalker with Muppets.", + "A head and shoulders portrait of a black cartoon rabbit wearing a shirt and laughing with big eyes in the style of Walt Disney animation.", + "Japanese magazine advertisements in anime style.", + "A girl in a school uniform playing an electric guitar.", + "Jawa from Star Wars (1977) drinking a beer in a pub.", + "Image description, Spromple Sploop, third brother of Sans Undertale found as a new character in Undertale.", + "A Fortnite poster featuring chibi kittens wearing cyberpunk headphones and shades, with anime-stylized art by Takeshi Murakami.", + "A man wearing a cat costume and tuxedo poses for a portrait.", + "A theater dressing room with a mirror, chair, and couch, drawn in Day of the Tentacle style by Peter Chan.", + "A comic book style portrait of a schizophrenic person living in a parallel world with intricate detail.", + "A cat dressed up as Thomas the train.", + "Young girl with red hair and freckles holding a large bong.", + "A fullbody photograph of a lifelike human anime girl in steel plate armor, surrounded by books, with pink hair and realistic features.", + "An anime alien checks her email while submerged in translucent goo in a still from a Kiyoshi Kurosawa movie.", + "An art piece depicting a Capybara wearing sunglasses.", + "A female anime character depicted as Catholic.", + "Photograph of an anime figure with detailed features.", + "A man in a tiger costume on a street fighting pose.", + "A Miyazaki-style scene featuring an expansive landscape with mountains in the background, wind-swept fields in the foreground, and a steam-powered train moving from left to right on tracks in the middle.", + "A manga-style illustration of a cyborg Doctor Strange drawn by Moebius and Stephan Martiniere.", + "Comic art featuring a capybara.", + "A black cat wearing a suit and smoking a cigar.", + "A man drinking cosmic energy depicted in an anime style digital art.", + "A West Highland white terrier holding a \"Hug me!\" sign.", + "A kangaroo wearing an orange hoodie and blue sunglasses stands in front of the Sydney Opera House holding a sign that says Welcome Friends.", + "A woman sleeping on a bed of doughnuts.", + "A demon anime girl wearing obsidian armor stands in front of a dark background with falling ash in this Masami Kurumada-inspired digital painting.", + "A cat with wings featured in a comic book story.", + "Anime-style illustration of a cloudy sky by Yoshitaka Amano.", + "Full body shot of an anthropomorphic cockroach wearing a suit with long thin antennae and vibrant colors.", + "A blue bear wearing cowboy boots.", + "An ultra-realistic manga illustration of a special forces soldier in a remote village, with cyberpunk, sci-fi, and fantasy elements.", + "Cartoon-style badger wearing a scarf against a green background.", + "A cartoon catfish with a large 
body.", + "A cute anime-style female cat girl with large eyes is pictured underwater with a simple background.", + "Kid reading a comic with a flashlight in a cozy attic filled with antiques and furniture, illustrated in the style of Studio Ghibli, Tekkon Kinkreet, Akira, Breath of the Wild, and Miyazaki.", + "Image of Tifa Lockhart in a JoJo's Bizarre Adventure style.", + "Yoda performing at Woodstock.", + "\"The Little Prince talks to the Fox in the style of The Nightmare Before Christmas.\"", + "A young man in a small Tokyo room with an open window sits at his computer surrounded by anime posters and appliances, while a small bed remains unmade in the background.", + "A crowd of pink elephants playing steampunk instruments during a grindcore show.", + "A child squeezing a party ball.", + "A cat and dog are fixing a website on a laptop.", + "A Funko Pop figure of Rika Furude.", + "A die cut sticker featuring Tony Chopper wearing a strawhat and splatter paint.", + "An anthropomorphic frog wizard wearing a cape and holding a wand.", + "Blond-haired girl depicted in anime style.", + "The image features various anime characters.", + "Magical girl manga cover with intricate detailing.", + "Anthropomorphic sun battles Darth Vader.", + "An egirl with pink hair and extensive makeup.", + "Batman is shown working as a Bagger at a grocery store.", + "The image depicts a character from the anime Demon Slayer.", + "A lemon character with sunglasses on the beach.", + "A bluejay is eating spaghetti.", + "A lizard wearing sunglasses.", + "A furry convention in a luxury hotel.", + "The image depicts Fubuki from the anime/manga series One Punch Man.", + "The image depicts a muscular man wearing a lightning vest, with sharp facial features and intense, focused eyes, set against a dungeon exterior and painted in a dark, anime-inspired style with rough brushstrokes, giving it the look of an oil painting.", + "Animation keyframes featuring a wolf's walking motion.", + "A hotdog playing golf in hell.", + "Danny DeVito and Rhea Perlman playing Link and Zelda in a cinematic still.", + "Portrait of girl from Big Trouble in Little China, illustrated by various artists.", + "Portrait of a dog astronaut wearing an astronaut helmet.", + "Comic art featuring Walter White and Saul Goodman at sunset.", + "Portrait of a man camouflaged as a waffle on a plate.", + "Marine Le Pen with afro hair and 70s fashion attire.", + "Walter White dressed as a medieval-style king.", + "A scared cat is being attacked by a giant carnivorous sandwich.", + "Wallace smoking crack from a crack pipe in a still from the short movie A Grand Day Out (1989), Wallace and Gromit, Aardman Animations, claymation.", + "The image depicts an anime woman in a futuristic cyberpunk setting, with intricate and highly detailed elements crafted in digital painting.", + "A black and white comic book panel featuring Dream (a character from the comic series \"The Sandman\") standing alone with a speech bubble above his head.", + "Portrait of a monkey wearing an astronaut helmet.", + "Mabel Pines eating a donut in a colourful digital drawing.", + "A scared cat running away from a giant carnivorous sandwich.", + "A garden gnome wearing goggles and a headscarf hanging off the back of a car in full speed in a wasteland.", + "A Pok\u00e9mon battle between Obi-Wan Kenobi and Anakin Skywalker.", + "A birthday greeting for Pungeroo.", + "A girl reading on a bench in an abandoned train station depicted in stylized anime art.", + "A drawing of a haunted house made by 
children.", + "A cute and fluffy baby caterpillar with a grin, dancing and fighting, and a screaming bioluminescent toy monster creature.", + "Bob Ross riding a brown bear in Alaska.", + "A cat wearing a war helmet.", + "Walter White as a Squishmallow on a clear background.", + "Young wizard practicing a spell while holding a spell book and a black ball in a large room, wearing intricate leather armor, in a comic cover art style with a plain background.", + "Asuka Langley Soryu from League of Legends, depicted with fluorescent skin and an art style inspired by various artists.", + "\"Tetsuo and Kaneda engage in a race through Neo Tokyo in a pencil drawing featuring a scribbled style.\"", + "A princess and a frog.", + "An illustrated man hides under an umbrella to avoid a falling rain of french fries.", + "Charlie Brown is depicted at a music festival in an illustration by Charles Schulz.", + "A hedgehog using a calculator.", + "A cute plush griffon with a lion body and seagull head.", + "King Kong body slams Kirby.", + "A green velociraptor is shown in a boxing position with a brown horse.", + "Anime girl with a medium shot, gazing away, featuring cyberpunk elements, inspired by wlop and currently trending on Artstation.", + "Image of He-Man, a muscular superhero from the Masters of the Universe franchise, holding a sword and standing confidently with a determined look on his face.", + "Bald Walter White meets a smiling anime girl with a hime cut under a tree.", + "A kitten is using chopsticks to eat sushi.", + "Image of a vodka bottle with a human face and features drawn onto it, giving it a personified appearance.", + "An anime-style Pieta depicting Michelangelo's sculpture.", + "Young Japanese boy in an action pose with an angry expression and pompadour hairstyle on a manga cover.", + "Spongebob walking in a Marvel movie.", + "Twilight Sparkle from My Little Pony.", + "A portrait of a silver and white brindle persian cat dressed as a renaissance queen, standing atop a skyscraper overlooking a city.", + "The image is a stylized medium shot portrait of a business monkey that would be suitable for box art or advertising.", + "A Pikachu superhero in a hyperrealistic style.", + "A Ken Sugimori character sheet for a gazelle Pokemon design.", + "An anthropomorphic Nintendo 64 controller devouring children.", + "A biomechanical portrait of a jester with wide eyes, created by Masamune Shirow, Wayne Barlowe, and Roger Dean in anime style.", + "A cute humanoid cat soldier wears a yellow raincoat, carries a rifle, and ventures through a dense forest, looking back over their shoulder.", + "Cthulhu hiding in a Where's Wally puzzle illustration by Martin Handford.", + "Punk rock band with banana microphone and background singers in banana costumes, photographed during a concert and later reimagined as a pencil drawing.", + "Manga page of \"The Lord of the Rings\" in the style of \"Fullmetal Alchemist.\"", + "An anime portrait of fire spirit twins with red eyes and fiery skin wearing clothes of flames, digital painting by several artists, trending on ArtStation.", + "A man drinking cosmic energy in an anime-style digital art by Park Sung-woo.", + "Tupac Shakur in an anime.", + "Retired Wile E. 
Coyote having fun.", + "Duke Nukem casually drinks his beer while his house is ablaze.", + "Super Mario selling mushrooms in an LA alley, dressed in his signature outfit.", + "Two contrasting superheroes in a cloud city with a starry background, illustrated in anime style.", + "A digital art depiction of a Pok\u00e9mon that resembles a chair.", + "Anime portrait of an Asian schoolgirl with her pet sugar glider.", + "An anthropomorphic mafia frog in a suit smoking a cigar.", + "A young boy with a pompadour hairdo is depicted in an action pose on a manga cover by artists Katsuhiro Otomo, Tetsuo Hara, Hirohiko Araki, Jotaro Kujo and Banchou.", + "Call of Duty-themed cats.", + "Sauron snowboarding.", + "Muppets dressed in military uniforms during WWII.", + "Kawaii-style alien monster game assets.", + "A lemon wearing sunglasses on the beach.", + "A cute anthropomorphic fox knight wearing a cape and crown in pale blue armor.", + "A cute Pokemon resembling a blue duck wearing a puffy red coat.", + "Anime-style vector illustration of a cloudy sky with a unique perspective.", + "A Nintendo 64 controller with anthropomorphic features consuming small children.", + "A medium shot black and white manga pencil drawing of Alita by Yukito Kishiro.", + "Three cats disguised as a human, ordering drinks at a bar.", + "Darth Vader working as a short order cook in a diner in the new Star Wars movie.", + "An overweight plumber with a toilet plunger and a Mandalorian with a lightsaber are duelling.", + "A closeup of a monkey in a suit smoking a cigar with a sad expression.", + "A close-up portrait of Sailor Moon standing in front of Russian panel houses.", + "A close-up medium shot of a frog depicted in a flat ink sketch by Jim Lee, resembling a superhero comic book character.", + "Selfie of a cosplayer at comic con.", + "A goat and goose hybrid animal.", + "A cat dressed as Thomas the Train.", + "A moai wearing headphones.", + "Darth Vader working as a short order cook in a diner in the new Star Wars movie.", + "A frontal portrait of a anime girl with chin length pink hair wearing sunglasses and a white tshirt smiling.", + "A woman is wearing a cupcake costume.", + "Momo challenge defeats Cobra Commander.", + "A colorful cartoon tent in a bazaar with a borderlands-inspired aesthetic.", + "A still image of Johnny Bravo in Twin Peaks (1990) television show.", + "A man dressed up as Space Dandy attends an anime convention.", + "A still image from Star Wars featuring C-3PO as a hamster.", + "A kangaroo wearing an orange hoodie and blue sunglasses stands in front of the Sydney Opera House holding a \"Welcome Friends\" sign.", + "A Robert Crumb comic depicting beta males being subservient to women.", + "A photo of Big Chungus from Looney Tunes.", + "A pink stuffed kangaroo with a blue shirt sitting on a couch, and geared towards babies.", + "A cute anime cyborg-girl holding a hyper-detailed baby dragon.", + "A chibi frog character surfing at the beach.", + "Moomins in space suits discover the mushroom planet while flying with jetpacks.", + "Moomins in space suits flying around with jetpacks discovering the mushroom planet.", + "A lemon character wearing sunglasses on the beach.", + "A cartoon image of bananas dressed up in punk rock style, currently popular on ArtStation.", + "A candy house on the ocean in a fantasy setting.", + "Selfie photo of Luigi and Waluigi at an outdoor market.", + "Vanellope von Schweetz in a racing game.", + "A little boy flying through space eating pizza and cheese among candy planets in a 
comic book style drawing.", + "Man looking at his phone in anime style.", + "Sora and Riku are playing a game of Disney-themed chess.", + "The image depicts wires or cables made of salami.", + "A hamster wearing a beanie and smoking a bong.", + "\"A fox in magic school uniform stands in the center of a luminous magic array.\"", + "A man drinking a cup of cosmic energy in a surreal anime style artwork by Masafumi Harada.", + "A medium shot black and white manga of Alita by Yukito Kishiro, created through pencil drawing.", + "A photo of an anthropomorphic duck wearing a suit.", + "Mickey Mouse performing at Woodstock.", + "A pencil-drawn black and white manga medium shot featuring Alita by Yukito Kishiro.", + "A cute frog baby sits in a searose cup in a humorous illustration.", + "A medium shot of Alita's highly-detailed face in a black and white pencil-drawn manga by Yukito Kishiro.", + "Jinx from League of Legends in a pin-up pose.", + "The image depicts a person playing Warhammer.", + "Neymar in a still from a 2012 anime.", + "A golden retriever representing god.", + "A moon rocket from Wallace and Grommet.", + "The Little Prince talking to the fox in an animation shot by Tim Burton's art.", + "The image is a black and white manga pencil drawing of Alita by Yukito Kishiro, with a highly detailed face.", + "A portrait of a cavewoman supermodel wearing mastodon fur and warpaint, walking down a fashion runway in an art nouveau embroidered outfit.", + "The image is a portrait featuring characters and artists from various sources.", + "Psytrance artwork by Jhonen Vasquez.", + "Photo of Marjorie Taylor Greene cosplaying as Jabba the Hutt from Return of the Jedi.", + "Goku with an undercut haircut in anime art from Bleach.", + "A Minecraft character named Herobrine standing on a grass block with arms crossed and looking directly at the viewer.", + "Photo of Ronald McDonald with a menacing expression.", + "A howler monkey holding a joint in a full body pose.", + "A winged man disguised as a mothra wearing a traffic cone hat with a ripped physique.", + "An illustration from the realistic comic book \"Tiger White\" featuring detailed artwork by a skilled illustrator.", + "Funko Pop of a capybara with sunglasses.", + "A puppy driving a car in a film still.", + "Sully from Monsters, Inc. 
eating a peanut butter and jelly sandwich.", + "A lynx dressed in a flight suit.", + "Asuna from Sword Art Online.", + "A portrait of a tanned anime girl with white hair wearing a red shirt and a black leather jacket, with a cyberpunk aesthetic.", + "A small egg with arms, legs, and goggles.", + "A male teen wearing a dark formal overcoat and anime style, depicted in a portrait photo with dark short hair and brown eyes.", + "A cactus is crossing the street.", + "Dwayne \"The Rock\" Johnson in Pixar's Ratatouille movie.", + "A tiger wearing a train conductor's hat holds a skateboard with a yin-yang symbol.", + "Two cartoon characters, one cat and one mouse, facing each other with Jerry holding a piece of cheese.", + "A new artwork depicting Pikachu as a superhero fighting villains with dramatic lightning.", + "Monkey giving thumbs up in a selfie.", + "Frog emerging from yogurt.", + "Mario and Luigi relaxing on a beach.", + "A cockroach priest rides in a retro science fiction cabriolet.", + "Little boxers in action in a detailed and intricate environment.", + "A line drawing of a dog at the park in a coloring book.", + "A clean and well-composed coloring book line art of a bubble bobble character with no background.", + "A man playing 8 musical instruments with his 8 arms.", + "An elephant is seen flying near a cliff.", + "A watercolor portrait illustration of an anime girl as a vampire wearing a vintage kimono as a skirt with a dragon print crop top, a corset belt, and platform heels.", + "A Leonardo ninja teenage mutant turtle drinks tea at a wooden desk in a sci-fi space station with a view of a planet through a nearby window.", + "A portrait of a smiling Dragonite in a sunflower field with a cloudy sky backdrop.", + "Big Chungus is dying in a shootout.", + "A father and son work together to create a small robot in a black and white anime-style illustration.", + "The image is a humorous illustration of a furry alien chick nesting in a floral cup.", + "A medium shot black and white manga pencil drawing of Alita's highly detailed face.", + "Doraemon is depicted as the Terminator using the Unreal Engine.", + "Marjorie Taylor Greene depicted as Jabba the Hutt in a film-like image with a sweaty appearance, referencing Return of the Jedi.", + "A manga style illustration of a cyborg Doctor Strange by Moebius and Stephan Martiniere.", + "Portrait of a male furry anthro mountain goat in a pinstripe suit and waistcoat, smoking a cigar.", + "\"A kangaroo wearing an orange hoodie and blue sunglasses stands in front of the Sydney Opera House holding a sign that says Welcome Friends.\"", + "A D&D party consisting of a lion humanoid with a giant axe, an elf knight, a young half-elf wizard, and a tiefling female warlock, all depicted in a children's art book and appearing happy.", + "A dragon sleeps on a couch.", + "Three friends working in a magical bakery.", + "Close up portrait of a mash-up character featuring the heads of Albert Einstein, The Hulk, and Pikachu, in a comic book style.", + "A frog wearing an anime-inspired onesie.", + "A white bichon frise puppy dog riding a black motorcycle in Hollywood at sundown with palm trees in the background.", + "Silvio Santos character in GTA V loading screen.", + "Portrait of a cat astronaut with Japanese samurai helmets.", + "Shinji Ikari has swag.", + "A miniature anthropomorphic cat knight wearing pale blue armor and a crown.", + "The image depicts a young anime giantess with long blonde hair, sky blue eyes, an evil grin, wearing a bikini and miniskirt, in 
a highly detailed and cinematic wallpaper by Stanley Artgerm Lau.", + "A portrait of a fierce cavewoman neanderthal supermodel wearing mastodon fur on a catwalk during a fashion show.", + "Bigfoot hunting deer with a shotgun.", + "The image is a digital painting of a beautiful anime teen in a cyberpunk Kowloon setting with intricate and detailed elements of sci-fi and fantasy.", + "A cute little anthropomorphic bear knight wearing a cape and crown in pale blue armor.", + "An anime painting of a schoolgirl with a sugar glider by Gaston Bussiere, J.C. Leyendecker, and Craig Mullins.", + "A person dressed in a tiger costume stands on top of a dumpster in a street fighter scene.", + "A catgirl wearing a navy admiral-inspired Japanese military uniform.", + "A Walter White funko pop figurine.", + "A blue-haired girl with soft features stares directly at the camera in an extreme close-up Instagram picture.", + "A skeleton pirate.", + "The image is of a sad muppet funeral in a graveyard, with a casket, umbrellas, headstones, and rain.", + "Danny DeVito dressed up as Wolverine.", + "A clean line art of a bubble bobble character, with no background and suitable for coloring books.", + "Anime poster of a woman wearing futuristic streetwear with spiky hair, featuring intricate eyes and a pretty face.", + "A cute digital art of a unicorn.", + "Ben 10 transforming into Fof\u00e3o.", + "VeggieTales characters depicted as the Last Supper.", + "A teenage mutant ninja turtle, Leonardo, enjoys a cup of tea at a wooden desk in a sci-fi space station orbiting a large planet visible through a window.", + "Two colorful parrots perched together eating an egg tart.", + "A Salem black cat girl in anime style with a simple background.", + "The image is of a slim anime girl in a studio with smoke and pink violet light.", + "An Ultraman preparing to take flight.", + "A digital illustration of a Pixar minion dressed as a character from the horror movie Hellraiser, created by multiple artists.", + "Lebanese Kevin O'Leary chatting with Cookie Monster in a cafe.", + "SpongeBob in Dragon Ball style.", + "A comic portrait of an Indian goddess with realistic shading and fine details in an anime style, set at night and created by various artists including Ilya Kuvshinov and Artgerm.", + "Waluigi in a plugsuit leaning on a vending machine in a Japanese city with a tense atmosphere, as an Evangelion anime-style character.", + "A comical magazine poster of an ancient golden palace.", + "A unicorn with a car on its head.", + "A werewolf man.", + "A slim anime girl in a studio, surrounded by smoke and illuminated by pink violet light.", + "An illustration of Marceline the Vampire Queen and Princess Bubblegum embracing, created by artist Artgerm Wlop.", + "Tom, a rage face, a gray yellow pony, a cat, a car and a wombat at the anime convention in the subway.", + "A French bulldog wearing a chef's hat.", + "Undertale character Spromple Sploop, third brother of Sans.", + "The hero of time sits near a fire in an anime style illustration.", + "A Delorean covered in Pokemon stickers.", + "Goku fights Joseph Joestar in a flashy anime-style fist fight.", + "A 3D rendering of Mickey Mouse as a Funko Pop with a white background and high levels of detail.", + "A baby red panda wearing cake as a hat.", + "Tom Cruise shooting lightning at Oprah.", + "Image of a viking battle in manga style with intricate details.", + "A goose goat hybrid depicted in the image.", + "A landscape with a Maya-style building and Winnie the Pooh on grass.", + "A dog 
resembling Hugh Laurie.", + "Geralt of Rivia in a Wallace and Gromit style animation.", + "Chuck Norris fighting the giant Cthulhu.", + "A happy pink dolphin flying through candy planets while eating pizza and cheese.", + "A toad baby sitting in a rose blossom, depicted in a humorous and detailed illustration.", + "FBI chasing a big orange pig through a swamp with Mar-a-Lago hotel in the background.", + "A koala bear dressed as a ninja smoking in a kayak.", + "A photo of Big Chungus from Looney Tunes, featuring a highly detailed and impressive depiction of the character.", + "A scared cat is running away from a carnivorous sandwich.", + "Hulk and Juggernaut battling in a high-energy action scene with dynamic poses and explosive effects.", + "Apes working in a Bed Bath & Beyond store wearing gold chains.", + "Cartoon image with blue sky, fluffy clouds and smoke-shaped clouds from the Lorax.", + "A young owl sits on antique books in a humorous illustration.", + "A vintage photo featuring Sonic the Hedgehog.", + "A black and white cat with a mullet is watching TV.", + "A muscular Squidward wearing blue goggles in an illustrated graphic novel featuring a colorful palette and various mediums.", + "A Pokemon duel between Kenobi and Anakin Skywalker.", + "The image features a cute, friendly magic turtle character design in light blue with ancient and welcoming elements.", + "Frontal portrait of an anime girl with chin length pink hair wearing a white t-shirt and smiling.", + "A giant cat sitting on a Starcraft map.", + "Bob Ross riding on a brown bear in Alaska.", + "The image features the Despicable Me character in the style of Kurzgesagt art.", + "A cute plush griffon with a seagull head and lion body.", + "An anthropomorphic West-Highland-White-Terrier dressed as Iron Man in a cyberpunk style.", + "A movie poster featuring chicken, cow, capybara, and pig in an epic cinematic style.", + "Cookie Monster looking unhappy as his cookie supply dwindles.", + "A grandma wearing a Sailor Moon costume in the mountains looking at clouds.", + "Portrait of a girl dressed up as Dora the Explorer.", + "Snake eating salad.", + "Digital comic art of children playing tennis created by Marvel.", + "The image features a witch with symmetrical eyes, wearing a black leather jacket and jeans, with long blonde hair.", + "A close-up, fine-detailed anime portrait of Sailor Moon, set against a post-Soviet city landscape with deep bokeh effects in the background, created in the style of Hayao Miyazaki by Studio Ghibli in 2000.", + "Nigel Thornberry character in Fortnite.", + "A medium shot black and white manga image of Alita by Yukito Kishiro.", + "A landscape featuring a Kyoto Animation-style building.", + "A woman screams and cries after an alien invasion in a manga-style illustration with green tones.", + "Frontal portrait of a pink-haired anime girl wearing a white tshirt and smiling.", + "The image shows a pirate holding up a beer in celebration.", + "An image of an anime boy with impressive artwork.", + "A kangaroo wearing an orange hoodie and blue sunglasses holding a sign in front of the Sydney Opera House.", + "Teal and purple claymation image of Splatoon action on Nintendo.", + "A manga-style illustration of a alluring soldier in a cyberpunk/fantasy setting, with intricate details and various artists collaborating.", + "Capybara Funko Pop.", + "In the image, there is a tall, yellow-skinned man wearing green and black striped cycling shorts and a long red and black striped ostrich feather boa with a manically 
smiling expression.", + "Darth Vader dressed in I Love Lucy attire.", + "A closeup portrait of a petite young cyberpunk girl with Japanese heritage, styled in an anime aesthetic.", + "Superheroes having fun in a moshpit with a metal band on stage and a sunset in the background.", + "A little orange kitten sits on a pink heart-shaped pillow.", + "Kermit is portrayed as James Bond in an anime scene with cinematic lighting.", + "A brown bear pushes a shopping cart in a grocery store.", + "An anthropomorphic lemon wearing sunglasses at the beach.", + "Tupac Shakur in an anime screenshot.", + "A portrait of Spike Spiegel, a fire mage, dressed in an obsidian suit and devil disguise on top of a volcano.", + "An image related to the children's author Dr. Seuss.", + "Anime Costa Blanca by Studio Ghibli.", + "Zelda depicted as a white rabbit with pink eyes and long ears, standing on a grassy field with a blue sky in the background.", + "A corgi dressed as a bee costume.", + "Anime girl with blood and a horn on her head.", + "Portrait of Robin Williams in a humorous and cartoonish style by Gediminas Pranckevicius H 640 and Tomasz Alen Kopera, featured on ArtStation.", + "Still of Superman farting in a subway.", + "A closeup profile portrait of a tin toy fairytale princess in a snow bikini.", + "Hulk-Yoda hybrid character in center frame.", + "A die cut sticker featuring a princess Mononoke mask with a splatter paint design.", + "An apocalyptic scene from Kenshin.", + "The image depicts Ninja Turtles surrounded by pizza.", + "Illustration of Ninja Turtles in highly detailed, sharp focus by artist Goro Fujita on Artstation.", + "A cute anthropomorphic bear knight wearing a cape and crown in pale blue armor.", + "Tifa Lockhart is depicted in a crossover artwork with JoJo's Bizarre Adventure.", + "The image is titled \"The End of the World\" by Raymond Briggs.", + "A portrait of Spike Spiegel wearing an obsidian vest and disguised as a devil atop a volcano.", + "A man wearing a Batman costume holds a green glowing orb.", + "Anime character holding a axolotl with a black mouth mask.", + "A man drinking cosmic energy in a surreal anime-style digital art piece by Masafumi Harada.", + "A portrait of an anime girl with black hair and glasses in front of a school building.", + "A creepy cartoon rabbit wearing pants and a shirt, with dramatic lightning and a cinematic atmosphere.", + "Spider-Man and Deadpool are standing next to each other facing forward in a movie scene.", + "A lemon character wearing sunglasses on the beach.", + "Illustration from \"Last Saturday I Took My Mommy and My Daddy to the Zoo\" by Dr. 
Seuss and Don Freeman featuring warm colors and an autumn theme.", + "Photo of Ronald McDonald with a menacing expression.", + "A giant cat sleeps in the middle of a Starcraft 2 map.", + "A portrait of Fox McCloud firing a blaster in anthropomorphic furry art style from the Star Fox series, illustrated by Jim Burns.", + "A white porcelain knight in oversized toilet bowl armor.", + "The image features a depiction of Kermit the Frog as James Bond in an anime-style scene with dramatic lighting.", + "A digital shaman from Attack on Titan.", + "A cute young owl sits on a pile of antique books in a humorous illustration.", + "An anime painting featuring a sugar glider princess sitting on her throne.", + "A comic book cover featuring a superhero named \"Eagle Man\" with an eagle mask and wing logo, resembling a traditional comic book cover.", + "Izuku Midoriya and Yoji Shinkawa in a chibi-style artwork.", + "Close-up manga film still of nuclear shelter interior with side window view and a hyperrealistic film projection of a wired cyborg brain map scene in the background.", + "A person with a fish head.", + "Carl Sagan appears as a character in the TV show The Simpsons.", + "Snoop Dogg with a long neck.", + "A tall, yellow-skinned man with a manic smile wearing green and black striped cycling shorts and a red feather boa in a stylized comic art design by Joshua Middleton.", + "A Japanese delinquent with a pompadour hairstyle and an angry expression is seen in an action pose on a comic book cover illustrated by Tetsuo Hara, Yusuke Murata, Jotaro Kujo, Akira Kongou featuring an action scene from a manga issue, with Metal Bat and Kuwabara hairstyle in the background.", + "A realistic anime painting of a cosmic woman wearing clothes made of universes with glowing red eyes.", + "A manga style illustration of a futuristic submachine firearm concept by Moebius and Stephan Martiniere.", + "An illustration of a character holding a sword and wearing an anime-style outfit.", + "A cute Japanese girl with small horns and sharp vampire teeth, dressed in a white coat, praying on the floor of a destroyed library.", + "Cyberpunk girl sitting in a box in advanced anime digital art.", + "A highly detailed portrait of a creepy cartoon rabbit standing up wearing clothing in the style of 1960's Walt Disney animation with dramatic lightning.", + "An anthropomorphic cat wearing sunglasses and a leather jacket rides a Harley Davidson in Arizona.", + "An action scene in Neo Tokyo featuring Tetsuo racing Kaneda on his motorcycle.", + "Ahri in a pinup pose.", + "A cute chibi werewolf Animal Crossing villager.", + "An anime jedi girl holding red lightsabers, wearing a bikini and miniskirt, and with long blond hair and blue eyes, stands in a front view, mid-shot pose in a highly-detailed, cinematic wallpaper by Stanley Artgerm Lau.", + "A portrait of a beautiful anime girl with pink hair wearing a white t-shirt and looking directly at the viewer.", + "A woman and a bird depicted in black and white outlines, featured in a comic book panel by Moebius.", + "A serene, anime-style landscape with vibrant flowers and trees, picturesque clouds, and no signs of human activity.", + "A bear in an astronaut suit sits on a rock on Mars surrounded by flowers under a starry sky.", + "An elephant carrying a house on its back.", + "A warrior in futuristic jade armor, illustrated in a detailed manga style by Moebius and Stephan Martiniere.", + "A person is playing a mobile app game where they eat apple pie.", + "A girl in a school uniform 
making a scissor hand gesture.", + "Sketch of Cecil Turtle by Milt Kahl.", + "A kawaii anime depiction of a chronobreaker in a brutalist setting.", + "A comic portrait of a modern vampire with fine details and realistic shading in an anime style, set in the night.", + "A horse is using utensils to eat food.", + "Patrick Star dressed as an Evangelion pilot.", + "A detailed product photo of an anime Nendoroid figurine of Duke Nukem.", + "A mash-up of Tom and Jerry cartoon characters and Pokemon trading cards.", + "The cabbage merchant from Avatar, The Last Airbender meeting Team Rocket in animated style.", + "\"Rio de Janeiro depicted in an anime screenshot from 2012.\"", + "A capybara wearing sunglasses and a blue cap.", + "A female orc warrior depicted in an anime style suitable for Dungeons and Dragons.", + "A wooper and magikarp swimming in a river.", + "The image is a campaign poster with the text \"Vote Steve for President\" and features the Minecraft character Steve.", + "features the popular character sitting on a cloud with a pink and blue sky in the background.", + "A 1930s-style cartoon mouse.", + "A minotaur is doing aerobic exercises.", + "A little boy flying through space eating pizza and cheese with candy planets.", + "An illustrated cat with a black spot on her trunk sits in an old house with a window overlooking a blue sky.", + "A VTuber concept art of a curvy anime girl wearing a green tank top, aviator sunglasses, and holding a cigarette, with a long hair and tanned body.", + "The image features Elmo holding a raccoon.", + "The Joker eating a chicken tender at a Popeyes restaurant.", + "A flat ink sketch of a hedgehog in the comic book style of Jim Lee.", + "A minimalist tattoo inspired by the Studio Ghibli films.", + "A pink teddy bear witch teaches a classroom of younger teddy bear witches.", + "A purple Sonic the Hedgehog drawing.", + "A bunny rabbit wearing glasses and a rose gold Rolex watch, dancing at a rave with glow sticks, surrounded by neon lights and dancing people.", + "Undertale character Spromple Sploop, third brother of Sans.", + "Ayanami Rei from Neon Genesis Evangelion holding a chainsaw.", + "Frontal portrait of an anime girl with pink hair wearing a white tshirt and smiling.", + "A 3D render of Kirby eating a banana.", + "A smiling face made of spaghetti and ketchup.", + "A young boy crying because he is in love, illustrated by multiple artists.", + "Image featuring the title of a manga, \"The End of the World\" by Eiichiro Oda.", + "A sad emoji.", + "A sloth wearing black Morpheus sunglasses holding blue and red pills in his hands.", + "Cartoon catfish depicted as overweight with muted color tones.", + "Ducks performing kung fu.", + "Anti-drug poster featuring Spyro the Dragon.", + "The image features characters from The Simpsons in a Skyrim setting.", + "Predator alien in a McDonald's restaurant with an angry expression over the wrong hamburger.", + "A realistic, anime-style portrait of Cammy set against a green cosmic background.", + "A scene featuring characters from Gravity Falls.", + "A cartoon anthropomorphic lamb depicted as a satanic priest in detailed 3D render.", + "A dragon angel pimp is being worshipped by 100 adoring fans.", + "The image features an anime woman in a cyberpunk Kowloon technoscape, with intricate and elegant details.", + "A cat taking a selfie.", + "A cat inside a rocket on a planet with cactuses.", + "Peter Griffin is in a garden.", + "A photo of a pile of yoshi smoking weed on a wooden table in a magical forest.", + "The 
image depicts a red and white circular cardboard cutout of a cartoon character with the word \"POG\" written on it.", + "The comic book cover features a superhero in a yellow costume with an eagle head mask and a left-facing eagle symbol, with a publisher logo in the upper right corner and issue number one in the lower right corner.", + "A person wearing a business suit with a banana for a head.", + "Sketch of manga girl in dramatic pose.", + "Spongebob in Dragon Ball style.", + "John Cena dressed up as The Joker.", + "Voldemort dressed as The Grinch, striking a pose for a photo.", + "A giant cat meowing at Earth planet in a cosmic exploration setting.", + "Up-shot of a scifi shanty town favela street scenery with colorful metal rooftops and wooden and concrete walls, with intricate details in an anime style.", + "A woman wearing a colorful mecha body suit with a gun, and a pretty face, has an intimidating expression and glowing eyes, in an anime style artwork by RossDraws, Artgerm and Greg Rutkowski.", + "Danny Devito battling red crabs.", + "San Goku and SpongeBob are pictured in a battle.", + "Luffy and Goku are in a battle.", + "Beautiful portrait commission of a male furry anthro Black Reindeer fursona with a tail, wings and a cute attractive face wearing stylish black and rainbow galaxy clothes in a city at night while it rains.", + "An animated man in pain holds his tight hamstrings while seeking help.", + ".\n\nA full-body shot of an anime maid with rich detail, featuring a pretty face and eyes.", + "A bread cat.", + "Portrait of a monkey wearing a spacesuit and an astronaut helmet.", + "Darth Vader playing electric guitar on top of mountain.", + "The image is an artistic rendition of Kermit the Frog dressed as James Bond, with cinematic lighting.", + "A screenshot of Rio de Janeiro from a 2012 anime.", + "A cute and fluffy Oka animal from official Disney Pixar art.", + "A chicken smoking from a bong.", + "A teenage mutant ninja turtle drinks tea at a desk in a sci-fi space station with autumn light coming through a window.", + "Ron Weasley wearing a Peruvian hat.", + "A yellow submarine with Lisa Frank-style designs.", + "Matt Groening's illustration depicts an imminent explosion in Neo Tokyo.", + "Manga art of a Japanese girl holding a knife with a loving expression.", + "A portrait of a man resembling Super Mario against a stylized background.", + "A cute turtle baby sitting in a coffee cup, captured in an humorous illustration with warm colors and a night scenery.", + "There is an anthropomorphic male wizard in the image wearing 3D cinema glasses.", + "A digital art image of an anthropomorphic chicken wearing a suit.", + "Celine Dion is riding a dinosaur at the beach.", + "A man with multiple squid stapled to his forehead, looking weary and sporting a thick rope mustache, but otherwise unidentified.", + "A kangaroo wearing an orange hoodie and blue sunglasses stands in front of the Sydney Opera House holding a sign that says Welcome Friends.", + "A chimp is pictured wearing a suit and smoking.", + "The image depicts a bee dressed as a chef, created by Greg Rutkowski under the pseudonym Artgerm.", + "Light Yagami character in Dead By Daylight.", + "A shy cartoon little komodo.", + "A grandma dressed as Sailor Moon hiking in the mountains, gazing at the clouds.", + "A yearbook photograph of a prom queen named Liz Truss who is infected with squid pox and has tentacles and weeping sores, voted most likely to initiate nuclear war.", + "Photo of a sad muppet funeral with a casket at 
the graveside outside a church in a 1980s style.", + "Crying angry cartoon man designed by Jim Davis.", + "A portrait of a cat wearing a samurai helmet.", + "Portrait of Lain Iwakura in traditional dress and floral garlands.", + "The Green Lantern featured in a TV show.", + "A diaper disposal robot overflowing in a colorful, celshaded girly bedroom.", + "A white rabbit watches the sunset on the beach.", + "An oversized white knight wearing a toilet bowl helmet.", + "Funko Pop figure of a Slime Rancher character.", + "A raccoon wearing a suit smoking a cigar.", + "A grey background studio-lit American Psycho Funko Pop figure with intricate detailing and smooth aesthetics.", + "A picture of Peppa Pig.", + "The image depicts a fat warrior man holding a large sword in a green field.", + "Spiderman fighting Venom on a comic book cover in an abandoned warehouse.", + "The image features a cartoon coconut with a ko ko nut, created using Unreal Engine 5 and Octane Render.", + "A close-up anime portrait of Sailor Moon in front of a Russian panel house landscape with a grey color scheme and deep bokeh.", + "A manga-style illustration of a submachine gun in 2050 by Moebius and Stephan Martiniere.", + "Among Us character featured in Super Smash Bros. game.", + "Shaggy wins against Superman.", + "Celine Dion appears angry at a kitten in a hot tub.", + "A cute little horse knight wearing a pale blue armor, cape, and crown.", + "Friedrich Nietzsche bodyslams Karl Marx in a vintage WWE championship match.", + "The image depicts a milkman dressed as a superhero in a comic book style with flat shading and hyper-detailed features.", + "Frontal portrait of an anime girl with pink hair and sunglasses wearing a white tshirt.", + "A photo of Homer Simpson.", + "The image is full color line art.", + "A movie poster for a 3D animated film about Matteo Salvini.", + "Anime style women in Japanese school uniforms, designed as D&D video game characters by various artists.", + "The image features a Brigitte Bardot Barbie doll with vibrant colors, vivid details, and a fantasy background.", + "Peach-colored cartoon robots in need of love.", + "Tsunade from Naruto is in the center of the frame in a medium shot.", + "Homer Simpson as an Avenger in Endgame.", + "A furry cat girl.", + "A close-up portrait of Sailor Moon standing in front of a panel house with a grey winter landscape and an Orthodox church in the background.", + "A girl in school uniform making a scissor hand gesture.", + "An illustrated robot serving coffee with coffee beans, steam, and friendly expression.", + "Surgeon operating on a baked potato.", + "A lemon wearing sunglasses on the beach.", + "A kangaroo wearing an orange hoodie and blue sunglasses stands on the grass in front of the Sydney Opera House holding a sign that says Welcome Friends.", + "Image, Ben Garrison comic.", + "Pop figure of mom playing the piano with glasses and slightly curly brown hair.", + "A lemon character wearing sunglasses on the beach.", + "A film still from the animation \"The Chinese Market\" directed by Walt Disney.", + "Anime-style character sheet of cyberpunk women in highly-detailed digital painting by artgerm, greg rutkowski, krenz cushart, and wlop.", + "A yellow striped monster panics while using a laptop.", + "Beatrice from Umineko in an official anime visual.", + "A man drinking a cup of cosmic energy in a surreal anime style by Hideaki Sorachi.", + "A joyful school girl sailor moon portrait in front of a Russian panel house with a grey winter background, closely 
captured with a deep bokeh effect.", + "Hank Hill mishandling propane in a comic strip.", + "A full-frame, adorable purple monster from Pixar.", + "A portrait of a teenage mutant ninja turtle with a cute face in a dark fantasy style." + ], + "concept-art": [ + "Product still of the new iPhone 2.0 in 2029.", + "Pippi is tethered to the international space station in her space suit amidst stars and galaxies.", + "A dragon attacking burning medieval hobbit homes in a picturesque landscape with a waterfall and bridge.", + "A samurai wields a glowing magical katana in a dynamic pose, with eastern influences and fantasy elements in a striking, detailed image.", + "Elves attacking New York City.", + "A detailed, realistic image of a biohazard lab evacuation with horror influences and multiple art styles incorporated.", + "A samurai cloaked in white with swords stands in a light beam of a dark cave, with a ruby red sorrow evident in the image.", + "A giant burning pineapple illuminates the forest and mountain backdrop in this cinematic concept art for a video game.", + "A close-up portrait of a beautiful girl with an autumn leaves headdress and melting wax.", + "A minimalistic fisherman in geometric design with isometric mountains and forest in the background and flying fish and a moon on top.", + "A vaporwave wallpaper filled with intricate patterns and shapes in gold, ruby, quartz, and sapphire.", + "Disney Concept Artists created a fugue with blunt borders following the rule of thirds.", + "Splashart of a champion composed of bubbles.", + "A digital rendering of an Asian valkyrie with snakes surrounding her, wearing an ankh necklace.", + "A black wolf standing on a fallen tree in a winter forest.", + "Pixar environment created using Renderman.", + "A hyper-realistic 3D render of the Doge meme featuring a shiba inu, portrayed in cinematic lighting.", + "\"A digital artwork titled 'Your Forger' created by Peter Mohrbacher.\"", + "A product still of metallic black and white Nike shoes with a red glowing swoosh, styled after Darth Vader.", + "A blueprint of a posthuman robot design.", + "A lion smoking a cigar in a cinematic style with dramatic lighting and high detail.", + "Image featuring the apocalyptic scene from Final Fantasy.", + "Portrait of Sailor Mars with an alien/machine face, intricate and elegant, highly detailed concept art by Artgerm, Greg Rutkowski, and Alphonse Mucha.", + "A closeup of a schoolgirl posing from behind in dramatic lighting and focus, created by artist Ayami Kojima.", + "A dark school corridor during after hours.", + "A monk in an orange robe looks out of a round window in a spaceship in dramatic lighting.", + "An alien species creates humans at the beginning of time.", + "A digital artwork with a techno theme created by Allie Brosh.", + "A spacecraft firing a railgun at a crumbling planet in a space background.", + "A modern exhibition showcasing the conceptual art of Pascal Campion, utilizing Unreal Engine 6 and ray tracing.", + "An underwater century city ruin portrayed in photorealistic detail using realistic paint and cinematic lighting, depicted as concept art on Artstation.", + "A depiction of Willa Holland as a vampire, with sharp teeth, symmetrical eyes, wearing a black leather jacket and jeans, and long black hair.", + "A digital illustration of a beautiful and alluring American SWAT team in dramatic poses, set in a post-apocalyptic cyberpunk Tokyo with overgrown vegetation and intricate sci-fi and fantasy details, created by multiple artists on 
ArtStation.", + "A desaturated cinematic portrait of a bigfoot.", + "A hyperrealistic concept art image with intricate detailing and sharp focus, showcasing a 3D perspective rendered with Octane, and features a lens flare effect.", + "Concept art of a cult of wizards in outer space.", + "A cyberpunk street filled with flying vehicles and towering corporate buildings dominates the skyline at dusk.", + "An anthropomorphic and surreal depiction of artificial intelligence's self-image.", + "The image depicts a dark and empty space.", + "A forest with blue flowers illustrated in a digital matte style by Dan Mumford and M.W Kaluta.", + "A portrait of a character in a scenic environment.", + "The image depicts a giant lovecraftian whale fighting the horrors of the unknown, with ornate scales and a round moon in the background.", + "A cyberpunk woman on a motorbike drives away down a street while wearing sunglasses.", + "A stylized portrait featuring sliced coconut, electronics, and AI in a cartoonish cute setting with a dramatic atmosphere.", + "A metal bat-shaped creature with a glowing red head and leathery golden body, wings spread wide as if about to take off.", + "A landscape with a building resembling the Iphone 4 front camera.", + "The image features a traditional girl with ornate and intricate details in rich colors, created by multiple artists and displayed on Artstation.", + "Design for a stool by Zahahadid.", + "Vintage astronaut taking photos.", + "A Japanese princess with a dragon in a dramatic combat pose in an arctic desert, depicted in intricate and highly detailed sci-fi/fantasy concept art reminiscent of trending works on ArtStation.", + "Photo of Alexandra Daddario as a cyberpunk warrior with weapons.", + "The image depicts an otherworldly landscape with a waterfall, trees, mountains, and lush greenery, under dramatic lighting.", + "Colorful scifi shanty town with metal rooftops and wooden and concrete walls in the style of Studio Ghibli and other anime influences.", + "A man with coral cyborg's parts wears Alexander McQueen style clothes in a highly detailed digital painting, set in a Soviet-style Disney Land environment.", + "A man on a boat is crossing a dark body of water in hell with creatures swimming around, depicted in \"Sea of Souls\" by Marc Simonetti.", + "\"Symmetrical portrait of a fantasy sorceress created by renowned artists including Yoshitaka Amano, Ruan Jia, Kentaro Miura, and Artgerm.\"", + "An ancient statue of a mushroom goddess wearing pagan clothes and leaves, located in a cedar forest.", + "Jason Momoa as a Viking.", + "A full samurai armor is worn by Spiderman, with fantastic details on the eyes and face, created by various artists trending on multiple art platforms.", + "The image is a front centered portrait of Elisha Cuthbert as a paladin wearing gold armor with blonde hair and dramatic lighting in the style of Eddie Mendoza, Raphael Lacoste, and Alex Ross.", + "Ben Bernanke portrayed as the villainous wizard Saruman in a piece of digital art.", + "The interior of a spaceship orbiting alpha centauri.", + "Portrait of a male furry anthro Blue wolf fursona wearing black cyberpunk clothes in a city at night while it rains.", + "Man on boat crossing a body of water with creatures in the water.", + "Front facing symmetrical portrait of Imogen Poots as a D&D Paladin character avatar, with Arcane League of Legends concept art style and global illumination lighting.", + "A biology diagram of a methane-breathing alien.", + "\"A photo of alien devices from 
alien spaceship.\"", + "A girl in a dress looks out from the edge of a mountain at a city that resembles a cat.", + "A portrayal of Dr. Manhattan as a hexagonal hypercube portal by Razorpunk, featuring art styles of Bastien Lecouffe-Deharme, Jonas De Ro, Jim Mahfood, and Luis Royo, in a sharpie fine tip embossing stamp rendering.", + "A 3D render of a cyberpunk-necromancer with intricate details and believable eyes, depicted in a front-facing, symmetrical view as epic fantasy art.", + "A digital painting of a full-bodied fat dragon dog, with highly detailed features and smooth textures, created by multiple artists, displayed on ArtStation as concept art.", + "The image is a stunning illustration of a knight warrior wearing Nordic armor and a Skyrim mask, with intricate details and dynamic lighting that make it perfect for RPG portraits and cosplay.", + "An SCP agency interior with rows of iridescent alien artifacts suspended in gold and quartz cylinders, surrounded by alien flora, with a dramatic camera angle emphasizing infinity.", + "A modular dome house, with a futuristic design and made of glass, is floating above a city in a cloudy atmosphere.", + "A mixed race D&D character.", + "A water squirrel spirit wearing a red hoodie sits under the stars, surrounded by artwork from various artists.", + "Concept art of a highly detailed landscape, centered and utilizing rule of thirds, with dynamic lighting for a cinematic effect.", + "A biblical Noah's Ark floats on turbulent waves with dark clouds and rain, depicted in a beautiful graphic propaganda poster art style from the 1970s.", + "A head-on portrait of a male orc druid in intricately detailed shaman leather armor, designed in an elegant Art Nouveau tarot card style.", + "A man is mid-air, leaping through the foreground with outstretched arms.", + "A 3D rendering of a robot screaming at a death metal concert.", + "An anime girl in a bikini merges with an orchid in an intricate and elegant portrait done in the style of Paolo Roversi, with soft lighting and high detail.", + "The image features various characters including Ghostbusters, by artists Jesper Ejsing, Rhads, Makoto Shinkai, Lois van Baarle, Ilya Kuvshinov, and Rossdraws, in a heroic world.", + "A cliff covered in blue blood.", + "Studio portraits of Innsmouth ocean-dwellers, mutant fishmen, in HP Lovecraft style with a dark monochrome atmosphere.", + "A cosmonaut in a spacesuit drinks tea at an old wooden desk in a richly decorated house.", + "A character named Edelgard from the video game Fire Emblem.", + "A portrait of bamboo living pods shaped like a sea shell embedded on the side of a cliff.", + "The image is a brandmark logo for an AI research lab in transportation, featuring vector art with no text and a trendy, hip corporate style.", + "Two girls holding hands while watching the world burn in the style of various artists.", + "A creepy long-legged monster stands on a bed of rotten flowers in a dark forest at midnight.", + "A cyberpunk Lamborghini is pictured in front of a dark and dirty cityscape, created by Nicholas Hiatt.", + "A photo of an attractive male android, half robot and half humanoid, posed like a statue on display at a museum.", + "An image of Akira, from the artist Simon Stalenhag.", + "An illustration of a fierce one armed Japanese woman in a power pose, wearing a long jacket and delinquent cap.", + "A low poly lemon in high quality rendering.", + "Metal dragon head badge with detailed relief, displaying a digital painting of concept art, illustrated by 
Giger, Rutkowski, Shimoda, Leighton, and Bowater.", + "\"Tifa Lockhart in a red cottagecore dress, illustrated portrait by Krenz Cushart, William Turner, and Wenjun Lin with rim and top lighting, set on an overcast background.\"", + "The image depicts Erebus's Titan, Revenant, and Reaver drone in high detail, created by artists Jenni Pasanen and Chris Rallis in a cinematic style.", + "A dinosaur in an exoskeleton, depicted in a highly detailed digital painting with a dark and eerie atmosphere.", + "A photorealistic Bob Odenkirk is sitting under a tree with a smiling anime girl with black hair and hime cut in a digital art anime key visual.", + "A neon-colored frog in a cyberpunk setting.", + "An AI art movie poster.", + "A cyberpunk Tom Waits character sheet created by various artists is trending on Artstation.", + "A VTuber model concept art of a beautiful girl in a black and yellow hoodie looking on a smartphone in her hand, with blue eyes, long hair, and a futuristic city background.", + "A shaggy creature resembling a mixture of a guinea pig and a battleship stands imposingly on a beach at sunset, depicted in tonalist shades of grey, blue, and red.", + "A god is meditating while floating down from heaven.", + "Black and white portrait of Thabo Mbeki with highly detailed ink lines and a cyberpunk flair, created for the Inktober challenge as part of the Cyberpunk 2020 manual coloring pages.", + "A portrait of an orc in a fantasy art style.", + "A photo of a male android, half robot and half humanoid, resembling actor Liam Hemsworth, posing stoically on display at a museum.", + "A cinematic Kingdom Hearts boss battle set during a stormy night.", + "An albino lion wearing a Mafia hat and smoking a cigar, digitally painted by multiple artists, trending on Artstation.", + "A girl looks out from the edge of a mountain onto a large city at night.", + "Robots from the 1950s playing Atari 2600 styled game.", + "An entrance to a dungeon at the base of an ancient mountain in the morning light with a Studio Ghibli inspired style, possibly done by Hayao Miyazaki.", + "The image features several female cyborg characters designed by popular artists on Artstation.", + "An elderly woman poses for a high fashion photoshoot in colorful, patterned clothes with a cyberpunk 2077 vibe.", + "A pirate is using a black dragon to light his cigar in a digital art piece that is trending on ArtStation.", + "The image features a big white cliff, a cargo favela, a wall fortress, a neon pub, and some plants, with vivid and colorful style depicted in hyperrealistic CGI.", + "The image depicts Nikola Tesla with thunderbolts and glowing white eyes.", + "A 3D rendering of Jakarta with yellowish light, resembling the real city.", + "A twisted horror head and shoulders portrait.", + "Psytrance artwork featuring a futuristic, intergalactic battle scene with intricate detail and vibrant colors, inspired by the video game Starcraft.", + "A photorealistic image of a giant floating glass sphere in a rocky landscape surrounded by a gentle mist.", + "Abstract retro-futuristic art depicting the passage of time and a clash between nostalgia and excitement for the future.", + "The image depicts Darth Jar Jar from the Star Wars franchise.", + "The image features an intricate abstract artwork of Ronaldo Nazario, created by Tooth Wu, Wlop, Beeple, and Dan Mumford, and trending on ArtStation.", + "The image depicts a celestial artificial intelligence mind in a dynamic pose, with intricate details, in a dark fantasy style.", + "A portrait 
of Princess Mononoke in armor.", + "A man screams for help as he is sucked into concrete in a creepy and realistic horror scene.", + "Solar punk vehicle in a bustling city.", + "Poster of Captain America losing the war featuring artwork from various artists.", + "A waist up portrait of an old man smiling with red balloons surrounded by gothic concept art by artists Artgerm and Wlop on ArtStation.", + "A collection of futuristic hard surface exploration shapes and form kitbash comprising props, small gadgets, and game assets of varying sizes, designed by Simon St\u00e5lenhag with insane attention to detail and modular arrangements.", + "A portrait of a skeleton possessed by a spirit with green smoke exiting its empty eyes.", + "A full-body shot of a beautiful female in an intricate dress, with a sharp focus on her perfect eyes, captured by artist Artgerm in a snowy winter setting.", + "A mage in a mask creates a burst of power in a fantasy setting against a majestic meteor.", + "Two girls holding hands while watching the world burn.", + "A game screenshot featuring Woolie Madden with dreadlocks in Mass Effect.", + "Snoop Dogg with exaggerated facial features in a realistic 3D rendering.", + "Painting of a Bladerunner spaceship in concept art style by Bougeureau.", + "The image depicts a woman drowning in heavy rain in a bedroom with cables and clothes made out of veins, inspired by Zdzislaw Beksinski's art style.", + "Portrait of goth girl in Warhammer armor.", + "The end of the world from Final Fantasy VII.", + "A skeleton sits on a throne in a mountain of bones amidst detailed concept art.", + "A winged man resembling Norm MacDonald dressed as Mothra with a traffic cone hat and a ripped physique, illustrated by Tradd Moore.", + "A neofuturistic island city depicted in a photo-realistic illustration by five artists.", + "A digital artwork featuring Cyber Ultra Instinct Goku against a chaotic fractal background, rendered in Maya with hyperdetailed features and a cinematic shot.", + "The image is a side profile painted portrait of Ryan Gosling as an arrogant, blonde elf ranger for a D&D or Gloomhaven campaign with an art nouveau style, vibrant color lines, backlit effect, and created by Kuvshinov, Krentz, and Gilleard.", + "A golden dust cat shape created through digital art with sand texture and rendered using Unreal Engine.", + "A realistic and detailed digital illustration of a serial killer's basement by various artists.", + "Digital art of Niko as a president in the game #OneshotGame.", + "Interior of Microsoft flagship store designed in Wes Anderson style.", + "A head and shoulders portrait of a demonic figure.", + "A human portrait formed out of neon rain on a galactic background.", + "A creepy eldritch monster in a Swedish forest photographed from a low angle with detailed, realistic features and soft colors, inspired by Lovecraftian horror and created by artist Simon St\u00e5lenhag.", + "Exploded view diagram of a xenomorph.", + "An angel falling towards Andromeda.", + "A digital painting of a fantasy character wearing Mandalorian armor and wielding a crossbow with steampunk and Lovecraftian elements, created by artgerm, greg rutkowski, and magali villeneuve.", + "A digital concept art of a squatting school girl in uniform, with glowing lights and intricate details by artgerm, greg rutkowski and alphonse mucha.", + "Portrait of a male furry Black Reindeer anthro wearing black and rainbow galaxy clothes, with wings and tail, in an outerspace city at night while it rains.", + "A 
full-body portrait of a female cybered shadowrunner with a dark and cyberpunk atmosphere created by Echo Chernik in the style of Shadowrun Returns PC game.", + "The image is an inventory item in League of Legends, consisting of a key icon with an outer glow on a solid background.", + "A horror-themed old advertising poster featuring a man with a soccer ball for a head in a comic drawn by Junji Ito using pastels and gradients.", + "The image depicts \"The Creeper,\" a tall, yellow-skinned man wearing green and black striped cycling shorts and a red feather boa, with stylized comic art design influences from Joshua Middleton, Mucha, and Kandinsky.", + "An alien drinking cosmic energy from a cup.", + "The image is a digital art poster sized in the style of Utamaro Kitagawa featuring Lil Wayne.", + "Adventurers standing in front of a cave entrance in a fantasy art style.", + "Tombs of the universe depicted by Shinji Kimura on Artstation.", + "A graphic poster depicting the fiery end of the world with detailed botanical illustrations and artistic influences.", + "A painting by Greg Manchess depicting an anime woman.", + "The image is of a massive cave interior filled with glowing stalactites and stalagmites.", + "Image of a teenage girl wearing black clothing with dark makeup and piercing, staring intensely at the camera with a wild and crazed expression on her face.", + "A digital illustration of Harry Potter watering a cannabis field in Hogwarts.", + "A dragon sitting on a couch in a digital illustration.", + "Spooky haunted arcade machine with alluring overhead lighting.", + "A TV set mounted on a post in a colorful and highly detailed landscape.", + "A galaxy-colored DnD dice is shown against a sunset over a sea, in artwork by Greg Rutkowski and Thomas Kinkade that is trending on Artstation.", + "A hyperrealistic mixed media image of a hand with particle teleportation around the fingertips, featuring perfect symmetry, dim volumetric lighting, and stunning 3D render inspired art by Greg Rutkowski and Unreal Engine.", + "A digital artwork showcasing a futuristic and technological theme created by James Jean.", + "A crazy looking fish swimming in Alcatraz in the style of artgerm.", + "A dark-haired man, wearing a shirt, is shown from behind in a highly detailed digital painting by three artists.", + "An enormous bear pulls a canon on wheels on the eastern front during WWII.", + "Detailed image of a creepy family in deep space, created by Richard Corben and Katsuhiro Otomo, with intricate and extremely detailed artwork.", + "Man crossing a body of water in hell with creatures and sea of souls around him.", + "The interior of a spaceship orbiting alpha centauri.", + "Portrait of a digital shaman.", + "A silver surfer in a Japanese forest at dawn.", + "The image depicts a beautiful, thick female with long white hair, wearing a black dress with a belt around her waist, silver earrings, and a black choker, with a face resembling William Dafoe, in a highly detailed digital painting by multiple artists including Artgerm and Ilya Kuvshinov.", + "A man and a woman bounce in a living room with dark energy surrounding them, as a plant sits to the side.", + "The image features a collection of detailed sci-fi spaceship parts with pastel colors and greeble patterns.", + "Minimalistic surreal interior with arches, glass 3D objects, and abstract pools around.", + "The image is a simple futuristic logo for a company called Novita, with purple and maroon colors.", + "Surrealistic digital art with a futuristic 
and optimistic theme.", + "Portrait of a Totoro woman in digital painting style with a beautiful face, created by artist Donato Giancola and inspired by Alphonse Mucha and Joseph Christian Leyendecker's art, displayed on ArtStation and featuring other similar artists like WLOP and Boris Vallejo.", + "Close-up portrait of a goblin emperor with a bone headdress, depicted in vibrant colors with intricate details.", + "A witch is casting a water spell.", + "The Kremlin ruins are engulfed in flames in a digital art illustration with a fantastical style and Morandi color scheme.", + "Concept art for the video game Mystic Unity featuring visionary characters, dark magicians with elongated arms, and barbarian buddhas in epic landscapes.", + "A detailed profile portrait of a powerful Japanese samurai with beast-like features, created by Moebius and Laurie Greasley.", + "There is a secret museum of magical items inside a crystal greenhouse palace filled with intricate bookshelves, plants, and Victorian style decor.", + "\"Front centered symmetrical portrait of Elisha Cuthbert as a D&D paladin with cinematic lighting.\"", + "A cat from the video game Stray.", + "The image depicts a fierce war wolf mount in a fantasy video game world, with bioluminescent lighting and concept sketch art by Feng Zhu and Alena Aenami.", + "Digital art of suspended animation chambers with floating people inside.", + "A futuristic android prototype with clamps for hands and boots, wearing a nanotech swimsuit, is portrayed in a portrait-style image.", + "The image shows the grim reaper in a full-body pose, wearing a purple cloak.", + "A digital painting of Martin Luther King Jr in a highly detailed and elegant style, by artgerm, Donato Giancola, and Alphonse Mucha, posted on Artstation as concept art.", + "Low poly John Travolta in GoldenEye 64.", + "A young woman smiling in the etheric hypothalamus of her mind.", + "A photograph of a man turning into a donkey, teeth growing from his eyes, spinning inside a tumble dryer, and melting onto the floor.", + "An illustration of Garfield the cat created by artist Moebius.", + "A photo of a horse with a human face and skulls underneath, in gothic style.", + "A male android modeled after soccer player Antoine Griezmann stands motionless, appearing as half robot and half humanoid, on display at a museum.", + "The image depicts a concept art of Schrodinger's cat in a box with an abstract background of waves and particles in a dynamic composition.", + "Nikola Tesla and Aphex Twin playing the moog synthesizer for a Rolling Stones cover.", + "Digital art of Prince of Roses.", + "A portrait of a native Kerala warrior wearing a sci-fi inspired armor made of wood and cloth, with intricate details and elegant sci-fi tech wear.", + "A digital portrait of an attractive man in military uniform with glowing lights and a dynamic pose.", + "A mecha jet fighter engages in an air battle with an explosion as a backdrop, set against a dark, starry sky in a highly-detailed art piece by Stephan Martiniere.", + "A portrait of a flower Fairy with intricate details and dynamic lighting, presented in cinematic style and following the rule of thirds.", + "Two complementary forces depicted in a 3D rendering.", + "A person staring into a lucid dream world with an adventure waiting.", + "A high-tech laboratory hovers above a purple ocean in a sci-fi style reminiscent of artist Greg Rutkowski, inspired by Stanislav Lem's book Solaris.", + "A film still of Darth Vader working as a short order cook in a diner in 
the new Star Wars movie.", + "A dragon.", + "A cyberpunk illustration by Stina Persson featuring a robot and two lovers creating a string figure.", + "A young Spanish man drinks coffee from a magical green cup adorned with stars in a 3D graphic by Filip Hodas.", + "A tribal elder meditates in a futuristic temple.", + "The image depicts a colorful smurf-like ghost creature with a big eye, surrounded by sushi and roots in a micro-world with fluo fishscale accents.", + "Spiderman as Wolverine with detailed muscular features and a full face, trending on multiple art platforms, created with hyperdetailed Unreal Engine, and optimized for high resolution viewing.", + "A portrait of a Chinese cyberpunk machine decorated with Chinese opera motifs.", + "A girl in a luxurious wedding dress holds a ceremonial sword with a intimidating expression and red eyes.", + "A portrait of a woman with a paper bag over her head.", + "Miranda Cosgrove as Lilo in Disney's Lilo and Stitch live-action film, in character costume.", + "An image titled \"the end of the world\" by Greg Rutkowski.", + "The image depicts a fiery boss from the video game Dark Souls, illustrated in color by artists Paul Gustave Dore and Ivan Aivazovsky.", + "Viktor Reznov holding a wooden AWP, sporting a beige coat, fedora and sunglasses.", + "A low polygon ice monster depicted in concept art.", + "Architecture render with pleasing aesthetics.", + "\"Spawn by Todd McFarlane with intricate details.\"", + "A massive frog robot wreaking havoc on a city.", + "Head-on centered portrait of Maya Ali as a black-haired RPG mage, depicted in stylized concept art for a Blizzard game, by Lois Van Baarle, Ilya Kuvshinov, and RossDraws.", + "The image depicts a yautja character that is well-detailed and has proportional features, and it is currently popular on ArtStation.", + "A young black woman stands in front of a ringed planet in space.", + "The image features a cyberpunk goddess with no mouth, intricate details, neon lighting, and sweat drops.", + "An airport lounge designed by HR Giger.", + "A photograph of a porcelain statue of Holly Herndon in a glass jar, posing in a pointe position, with a futuristic background.", + "Dr. 
Dre wearing a cap posing with Clash Royal style characters in a cinematic and highly detailed artwork.", + "A detailed portrait of an organic sci-fi gadget modeled after the asura ghost from Chinese mythology.", + "A steampunk egg-shaped mech is in a winter village during a blizzard with lightning in the background.", + "Portrait of Leonardo da Vinci as a steampunk cyborg.", + "A close-up portrait of a woman with glowing neon wires in the background, featuring work by several artists including Beksi\u0144ski, Giger, and Whelan.", + "An image of Anakin Skywalker dressed similarly to the character Patrick Bateman from American Psycho (1999).", + "A horse and astronaut in one image.", + "Cat wearing backpack from Stray video game.", + "The image is a professional cel-shaded illustration by artist Seb McKinnon featuring a flaming quarry with a fantasy, magical vibe.", + "Image of the Sandman, an ancient wizard with hour glass, casting beautiful dreams.", + "A India-style wall panel with symbols is featured in a loading screen background concept art for a Russian MMORPG by Katauri, with a visual style reminiscent of Hearthstone.", + "Symmetrical Libra zodiac art by Brian Froud in a mystic style.", + "An astronaut in white futuristic cybernetic armor running on the surface of the moon, featured in an artwork illustration on Artstation.", + "A human sitting in a white chamber, with golden glowing surfaces and an otherworldly colorful life form representing their consciousness.", + "Pippi in a Wes Anderson film.", + "A 3D render of an engine room covered in translucent brown paper bags, styled after H.R. Giger.", + "A warrior wearing triceratops-inspired metal armor.", + "Image of Earth reflected in a human eye, rendered with Octane, in high resolution.", + "The image depicts a God smashing mirrors, while a detailed unicorn-dragon is present in the scene.", + "The image is a vibrant and intricate illustration of a man, with a focus on his shoulder and head, created using inkpen and Unreal Engine technology.", + "Up-shot of a colorful, intricate scifi shanty town with metal rooftops, wooden and concrete walls in the style of Studio Ghibli, Tekkon Kinkreet, Akira, and Breath of the Wild.", + "A half-robot, half-humanoid male android, actor Liam Hemsworth, in a statue-like pose with shiny skin and a blank stare displayed at a museum.", + "The artwork features the gods of the deep in a cinematic style, created by Nekro and Tomer Hanuka with intricate details.", + "English sash window with steampunk elements.", + "The image is a digital painting portrait of Belarusian President Lukashenko depicted as a character from Warhammer 40k, surrounded by traditional motifs and symbols, and holding a potato.", + "A woman depicted in digital art with intricate details.", + "Soldier with plasma rifle walking through a portal to another dimension, art by Emmanuel Shiu.", + "Digital art of a cherry tree overlooking a valley with a waterfall at sunset.", + "The image depicts a scene from Max Payne in Tokyo, with a cyberpunk style and realistic composition.", + "Techno artwork by Wes Anderson.", + "The image is a headshot of a happy girl with white hair in a school uniform, illustrated by Ilya Kuvshinov.", + "A concept SUV designed by Dolorean driving through the African savanna.", + "A photo from a popular movie created by AI with intricate details.", + "Harley Quinn, in a tattered orange jumpsuit and garter, portrayed as an escaped prisoner, in a highly detailed, symmetrical digital painting.", + "A poster for a 
film animation titled \"The Boy Who Drew Triangles,\" featuring artwork by Dustin Nguyen, Akihiko Yoshida, Greg Tocchini, Greg Rutkowski, and Cliff Chiang.", + "A hyper realistic portrait of an intricately detailed African masked cyborg in a broad cyberpunk background with electrical cables.", + "A woman wears a white skull mask resembling a Bastien Lecouffe-Deharme marble sculpture in a gothic style.", + "An aircraft carrier perfectly encased within a glass bottle.", + "Captain Picard wearing a circuitry sombrero and digital sunglasses in a cyberpunk setting.", + "A Ukrainian survivor takes a final selfie as they flee a nuclear blast, with their damaged body bleeding and running in fear.", + "A cyberpunk street scene in Saint-Petersburg.", + "A video game character stands on a platform amidst a giant cityscape, battling dragons amidst dark magic, night fog and clouds while shooting energy beams.", + "A white dragon skeleton with moss, flowers and intricate details, resembling the style of HR Giger.", + "A humanoid metal robot with an Anubis head captured in a full body shot.", + "A film still of Luke Skywalker as a Sith Lord.", + "A Lamborghini in a cyberpunk city, created in hyperrealistic style with attention to detail and stunning artwork by Chris Labrooy.", + "A taco truck crushing teletubby zombies in a 3D render.", + "A surreal image of George Harrison as a Jedi, surrounded by a psychedelic Star Wars landscape and holding a blotter paper of LSD.", + "An alluring goddess floating through a robotic tunnel surrounded by flowing tendrils of energy and spiral mandalas.", + "The image is a 3D art piece.", + "An android girl in warframe armor with blue cyborg eyes stands on a mothership in a scifi, futuristic galaxy, depicted in a highly detailed, cinematic-style art piece.", + "A landscape made of glass mirrors.", + "A humanoid robot with a head resembling singer Iggy Pop, with 80% of its body being robotic and 20% human-like.", + "A close-up portrait of Harry Kane in a cyberpunk style, with soft studio lighting and full frontal view.", + "Concept art of a post-apocalyptic heroine by multiple artists.", + "Portrait of a hybrid Korean woman standing in front of a firetruck, created by artists HR Giger, Greg Rutkowski, Luis Royo, and Wayne Barlowe.", + "One sentence description, The image features the character Malphite from the popular online game League of Legends.", + "The image is a horror movie poster featuring the kuntilanak antapani, created by Hanung Bramantyo, Joko Anwar, and Stephen Spielberg using Unreal Engine, Blender, and Photoshop software.", + "Yoko Ono flying on a broomstick with lightning in the skies.", + "A broken 15th century warship Manowar in a large cave with stalagmites and stalactites.", + "An ultra-realistic cypherpunk stands in a computer lab scene, captured in a film still for a retrofuturistic fashion magazine.", + "A minimalistic heart drawing created using Adobe Illustrator.", + "The image depicts a realistic, detailed technological demon god with rich deep colors and influences from various artists and art styles.", + "A horse and an astronaut appear in the same image.", + "Realistic image of Monstrosity Mega God covered in spiders and centipedes by multiple artists with deep, gothic colors.", + "The image depicts a fiery world and is inspired by artists such as Andy Warhol, Matisse, and David Hockney, and can be found on Artstation.", + "Fashion portrait of a shapeshifting alien with a tentacle-like transformation, featuring soft lighting and a focus on the 
subject's eyes.", + "The image is a portrait of Imogen Poots in the role of a blonde D&D Paladin against a stylized background.", + "A goddess on a green dragon in a fantasy artwork.", + "The image is of a steampunk robot with a ripped physique, an orange mohawk, and multiple gears and bolts traveling in a vehicle.", + "The image is a pencil and ink sketch depicting a mutant creature with a Lovecraftian atmosphere, drawn in a noir, monochrome fine art style.", + "Full body image of a male angel with white hair, detailed white wings, and medieval knight's armor, surrounded by black smoke, in a terrifying and symmetrical pose.", + "Professional digital art of Godzilla with stunning detail.", + "The image is of a futuristic flying spaghetti monster with eyes on its antennae, portrayed in a hyperrealistic, hyper-detailed style with wide eyes and a psychedelic vibe, created by Roger Dean, Masamune Shirow, and Wayne Barlowe.", + "A portrait of Eva Green in Grand Theft Auto V, featuring fantasy art elements by various artists.", + "A full body portrait of a citizen in the year 2500.", + "An imperial star destroyer is being attacked by X-Wing Starfighters above a city.", + "A surreal image featuring a rainbow and neon glow with a biohazard scientist in a laboratory evacuation scene, showcasing a mix of gothic and neo-gothic styles with rich colors.", + "A hyper-realistic matte landscape of robots going to war against humanity in a grunge aesthetic with dynamic lighting.", + "The image is a cute, symmetrical logo of a prompt randomizer app, created in vector art.", + "A Penrose triangle is depicted with a numerical value indicating the number of sides.", + "A futuristic computer interface displaying holographic, transparent and neon Japanese manga elements from Ghost in the Shell and Akira.", + "A god is seen in a dream-like state at the end of time in a colorful, realistic image.", + "A futuristic cyberpunk Paris street.", + "A man playing 8 musical instruments with his multiple arms.", + "A highly detailed VFX portrait of a pretty boy wearing white glasses and short goatee, created by multiple artists.", + "A concept design of a heavily armored vehicle resembling a cat, with rocket boosters and a rollcage, in a post-apocalyptic style.", + "A horror movie inspired by Junji Ito's artwork, featuring intricate and dark ink drawings.", + "8-bit pixel art depicting data paintings from a global database.", + "Portrait of Hisoka Morow as a medieval jester with a porcelain doll-like appearance and a whimsical, happy expression.", + "A Pokemon that resembles a phone booth is gaining popularity on Artstation and Unreal Engine.", + "An orange cat wearing magical ornate armor with a backdrop of Art Nouveau-inspired design.", + "A symmetrical portrait of Pyramid Head from Silent Hill, featuring the artistic works of several individuals.", + "Portrait of a digital shaman by Peter Holme III.", + "A poster for the animated film \"Tokyo Flood\" featuring artwork by Dustin Nguyen, Akihiko Yoshida, Greg Tocchini, Greg Rutkowski, and Cliff Chiang.", + "The image is a hyper-detailed depiction of a biomechanical evangelion created by artist Greg Hildebrandt.", + "An animation key shot of a traditional city with tiled roofs, intricate architecture, and a blue sky with clouds.", + "A necklace made of many fingers around the neck of a character design.", + "Darth Vader goes ice skating in the new Star Wars movie.", + "A woman portrait wearing a paper bag over her head and holding a sword.", + "A concept art style 
photograph of a cute little creature with big eyes, in atmospheric low lighting, by Greg Rutkowski.", + "Exterior shot of a dystopian city's red light district depicting scifi futuristic vehicles, robots, neon lights, people walking around, and a police state.", + "The image is a trippy cheeseburger with warm colors, depicted in highly detailed illustration and rendered in octane, created by the award winning studio 4.", + "The image features a creature with a large eye, roots and cactus elements, and a ghostly appearance, surrounded by sushi and micro world details.", + "The image is of an alien, created by artist M\u0153bius, with colorful detailing.", + "A full body portrait of a sorceress with a long glowing hooded cloak, by Maciej Kuciara and Jason Chan.", + "Asian lightning goddess wearing a hoodie and modern clothing, looking at the viewer, in a highly detailed digital painting.", + "The image features a 3D sculpture of Tony Tony Chopper illuminated by soft studio lighting.", + "A cyberpunk woman close-up dancing in the rain in Gunma prefecture at midnight, with a style by Tomino-sama.", + "A portrait of a young, sophisticated female dark knight.", + "Clarence Thomas depicted as the devil.", + "The image is a futuristic portrait of an android wearing boots and a nanotech swimsuit.", + "An ancient Japanese temple located in a forest near a river, with dramatic lighting and a singular building centered in the image.", + "A Lego Shrek figure created with Blender and Octane render, popular on Artstation.", + "A stylized digital art image of a cherry tree overlooking a valley with a waterfall during sunset.", + "An anime girl is shown being consumed by gears and industrial lights, with a beautiful upper body shot and pretty face.", + "A geometric art poster featuring Eva Green as Space Commander Alpha from the Year 4000, in carbon black and antique gold with no text.", + "A stylized image of a 25-year-old blonde actress flying above Los Angeles at night with the LACMA lights visible.", + "An organic aztec cyborg depicted in vivid colors and dark shadows with a highly detailed, hyper-realistic style and dramatic lighting, by Brom and Bastien Lecouffe-Deharme.", + "The image features Kang the Conqueror from Marvel in a realistic style.", + "A stylized, tall yellow-skinned man with a maniacal smile, striped cycling shorts, and a long ostrich feather boa, depicted in a poster-like art style by Joshua Middleton.", + "A beautiful young goddess of nature with plant hair and antlers, set in a mystical forest with mushrooms and opal crystals.", + "A modern version of Odin in casual clothes with a long grey beard, accompanied by his two ravens Huginn and Muninn, in a professional photo session by several artists with a detailed and intricate face.", + "\"A surreal house made entirely of dogs in Pixar's highly detailed concept art.\"", + "A male android baseball player named Mike Trout posing like a statue on display at a museum with shiny skin and a blank stare.", + "The image is a wooden sculpture of a cute robot with cat ears, displayed in a contemporary art gallery.", + "A portrait of a sea woman with fish wings, confident pose, and pixie-like features.", + "A head shot of a pretty girl dressed in a cyberpunk version of Marie Antoinette's rococo style, depicted through detailed digital art and trending on Art Station.", + "A woman from the old west holding an ornately decorated revolver in each hand.", + "A symmetrical portrait of Imogen Poots as a paladin on a stylized background with comic 
book-style lighting, created by artists Lois Van Baarle, Ilya Kuvshinov, and Rossdraws.", + "A tiny planet image of Rio de Janeiro.", + "A dark surrealism digital art character rendered by Beeple, inspired by Zdzis\u0142aw Beksi\u0144ski and H.R Giger.", + "A stylized 3D CGI fortnite pirate ghost ship with the black Jolly Roger flag by RHADS.", + "The image features a magnified electron microscope view of a sprawling mega city, portrayed in an orange and blue wireframe render.", + "A solar eclipse can be seen over a field with grass and purple flowers, with a single tree amidst a windy landscape.", + "The image features grungy urban vibes with diesel smoke effects, referencing various artists and art platforms.", + "A robotic surgical arm creates organic ceramic forms in a laboratory.", + "A girl with silver hair in a post apocalyptic setting portrayed in a cinematic illustration by Yoji Shinkawa and Krenz Cushart.", + "The image is titled \"Chaos Theory Black Orange\" and features elements of synth-wave and hyperrealism.", + "A dark cavern with a labyrinthine design, inspired by the trending art style on ArtStation.", + "A portrait of a melting skull with intricate abstract details, created by Tooth Wu, Wlop Beeple, and Dan Mumford, using Octane Render.", + "A highly detailed digital portrait of the mythological figure Hephaestus, featuring a fantasy interpretation inspired by D&D and well-known French rugby player Sebastien Chabal, as illustrated by artists Artgerm, Greg Rutkowski, and Magali Villeneuve and gaining popularity on ArtStation.", + "A boy playing with a toy robot in a futuristic city.", + "A green field with flowers and pink and yellow clouds under a bright sun at sunset, illustrated by Peter Chan in a colorful Day of the Tentacle style on Artstation.", + "The image features a realistic, detailed charcoal sketch of a Kurdish samurai in a cinematic concept art style.", + "The image depicts a badly injured Ukrainian taking a selfie, trying to escape from the background of a massive nuclear explosion.", + "Grandfather clock on the moon.", + "A screenshot of the game Yume 2kki.", + "A young adult, clean-shaven plump cleric wearing a silver breastplate with religious engravings and a stressed expression.", + "A detailed zombie in medieval armor, portrayed in a symmetric portrait.", + "A tree with planets or galaxies hanging from it, on top of a calm sea, with an eye in the background that lines up with the tree's iris.", + "A brain activity diagram depicted as an intricate and highly detailed engineering drawing and blueprint, created as digital concept art by artists including Artgerm, Greg Rutkowski, Alphonse Mucha, and Wlop.", + "A digital art piece featuring a 3D representation of abyssal plant and bioluminiscent animals with a hyperbeast, displaying a comics style cover for GTA.", + "A cute astronaut stands in front of a spaceship on Mars.", + "Colorful artistic album cover design by M\u0153bius.", + "A portrait of a masked adolescent girl named Aurora, against a neon background of Santiago, Chile, created using concept art oil on canvas by Yoji Shinkawa, Ryuichi Sakamoto, Esao Andrews, and Yoshitaka Amano.", + "A fox wearing a Mafia Hat, red Tie and white shirt in fantasy concept art.", + "A werewolf in mid transformation in the style of Rick Baker.", + "Spiderman character in the game Sea of Thieves.", + "The image depicts Hwasa as a gothic female satyr, with intricate details and high elegance, in a digital painting with a fantasy concept art style.", + "A cyberpunk city 
with renaissance architecture in the style of Beksinski.", + "Concept art of the biggest ice cream in the world by multiple artists, featuring intricate details and a realistic design.", + "The Little Prince and the fox in a Tim Burton style artwork.", + "A portrait of a character in a scenic environment.", + ". \n\nPeter Parker as Spiderman in heavy rain, with a wet face - highly detailed concept art.", + "Cyberpunk-style guns creatively pieced together incorporating various design elements.", + "A depiction of Indian gods engaged in a cosmic battle with the assistance of hanuman in a technologically-inspired artwork by Simon Stalenhag.", + "A front-facing, symmetrical painted portrait of Imogen Poots in the role of a D&D Paladin character, with Arcane League of Legends influence and global illumination lighting.", + "A tarrasque is charging towards a galleon amidst a cataclysmic scene in this fantasy artstation piece.", + "A robot scorpion in a field surrounded by nanobots.", + "A digital art portrait of a cat wearing a spacesuit with a surreal background, by Krenz Cushart.", + "A man and woman ride a bicycle in a living room with a dark energy hovering in the center, accompanied by a plant and a background with multiple artistic influences and styles.", + "A red dragon flies over an erupting volcano in a highly detailed fantasy concept art.", + "A masked warrior in diamond armor holds a diamond spear in a digital art full body portrait.", + "A full body shot of an elegant, Scottish woman wearing a dress with a sharp focus on her striking eyes in a realistic and beautifully retouched art piece by Artgerm and Jason Chan.", + "Tracer game character wearing a yellow bikini with blonde hair and black eyes, standing at full height.", + "Adventurers walking along a wall beneath cliffs in a fantasy setting.", + "A person in a suit holding a sword.", + "A portrait of a female hybrid atlantean anubis alien warrior with refined and detailed digital art style.", + "Wooper Pokemon swimming in a galaxy river in a TCG-style digital art.", + "Image featuring a crystal palace with a dream-like guide line composition and a soft Monet-inspired tintal effect, created by multiple artists and trending on Artstation.", + "An anthropomorphic Loxodon wizard teaches his apprentice a new magical spell in front of a magical gateway to another universe.", + "A detailed, nightmarish image of a wrathful scientist-god surrounded by vibrant, Gothic colors and influenced by various artists.", + "Photo of mushrooms growing on an exotic planet in a galaxy far away.", + "A realistic painting of an astronaut suit holding a cyber sniper plasma rifle and glowing laptop detector, with multiple chelate appendages and diamond bubble materials.", + "Portrait of a Water Crustacean Crab Mage in a blue scuba suit under the sea.", + "Concept art of a sci-fi battle mech designed for George Washington by James Clyne.", + "Psytrance artwork featuring octane design.", + "A snow globe containing a universe as a piece of award-winning art.", + "A Japanese girl with small horns, long hair, and an elegant smile is praying on the floor of a destroyed church, portrayed in detailed artwork by Yoji Shinkawa.", + "A young woman with curly red hair, freckles, and blue eyes smiles in a portrait styled in the Artgerm style, likely representing a character from the RPGs Dungeons and Dragons.", + "A colored character study of a female geek rocker, wearing glasses.", + "Jay-Z depicted as a Jedi with a green light saber in a highly detailed, 
symmetrical, and realistic portrait, portrayed in a digital painting style as concept art with cinematic lighting by artgerm, Greg Rutkowski and Alphonse Mucha on ArtStation.", + "Portrait of a young goth girl in warhammer armor, art by Kuvshinov Ilya, Wayne Barlowe, Gustav Klimt, Artgerm, and Wlop.", + "A radical leftist political poster advocating for land rights in the Amazon, inspired by OSPAAL posters.", + "A half body portrait of an Asian cyberpunk mechanoid fashion idol wearing a neon jellyfish headdress and xenomorphic body suit.", + "A male android, half robot and half humanoid, portraying actor Liam Hemsworth, stands motionless as if on display at a museum.", + "Amphitheater filled with crowd looking at a dumpster on fire in patriotic colors.", + "The image is a digital art headshot of an owlfolk character with high detail and dramatic lighting.", + "The image depicts a person with their hands up, seemingly surrendering, in a highly detailed and immersive video game environment.", + "A digital art piece featuring a picturesque farm from Stardew Valley.", + "A digital painting of a Pok\u00e9mon named Faerow in a concept art style.", + "A highly detailed metal cover art featuring a digital painting by Alex Ross, Greg Rutkowski, and Alphonse Mucha on Artstation.", + "A Landrover drives through a rain-soaked forest in a highly-detailed digital artwork by Greg Rutkowski and Artgerm.", + "Portrait of a digital shaman by a League of Legends concept artist.", + "of mystical creatures and symbols, and various written incantations on yellowed pages, lies open on a wooden table in an old library. \n\n\"A book of dark magic with illustrations of mystical creatures and symbols lies open on a wooden table in an old library.\"", + "A human-like robot is being built in a production room, as seen in a film still from the series West World.", + "A portrait of a character in a scenic environment by James Cameron.", + "The image is an intricate artwork in a dark art style depicting the moment of a person's transition to a borderline state, influenced by the styles of various artists such as Hieronymus Bosch, Beeple, Tooth Wu, Dan Mumford, Wlop, Rossdraws, James Jean, and Yoshitaka Amano.", + "A cyberpunk underworld metropolis with dark lighting designed by Carlo Scarpa in the style of Thomas Cole.", + "Liz Truss made of nuclear warheads.", + "The image depicts a dart field from the Legend of Dragoon game.", + "A realistic painting of a pentaradial astronaut eva suit in a jumping float pose inside a futuristic space station, covered in diamond fractal lace iridescent bubble skin and camera appendage stalks with a clear helmet.", + "A 3D terminator is shopping for an avocado at a Whole Foods store.", + "A horror movie poster with a funhouse featured.", + "Imogen Poots portrayed as a D&D Paladin in a fantasy concept art by Tomer Hanuka.", + "The image depicts the superhero Ultraman Brahma, rendered in a realistic art style by Jacqueline E with coloring by Tafy Laplanche and a background by Bo Feng Lin.", + "The image depicts Pyke from League of Legends in a concept art style, created by Grafit Studio and currently trending on ArtStation.", + "A detailed image of an electric arachnid demiurge god with rich colors and elements of neo-gothic and Gothic art styles.", + "A mechanical cosmic god entity designed by HR Giger, Beksinski, and Stephan Martiniere.", + "The image depicts a female scientist holding a small spinning black hole in a laboratory, illustrated in detailed digital art style.", + "A deer 
painting a book.", + "Cameras record a couple embracing in a sci-fi scene reimagined by Industrial Light and Magic.", + "A music cover for a dark metal album featuring eerie and sinister artwork with no words or letters.", + "A mushroom growing out of a metal sphere in a rainforest with sunset lighting and intricate detail.", + "A portrait painting of a 1940s pinup transformed into an Overwatch character with bold, organic shapes and hard edges.", + "A surrealistic digital artwork of a misty forest filled with glowing monsters.", + "The image depicts a portrait of a man's head and face wearing an opalescent vest with eyes looking towards a snack while surrounded by a dungeon interior.", + "The image depicts a mount called the High Priest's Lightsworn Seeker from the online game World of Warcraft, with artwork credits from several popular artists.", + "Trondheim city rendered in Skyrim style.", + "A knight princess on a horse strikes a combat pose in an intricate sci-fi fantasy setting.", + "A Soviet cosmonaut riding a bike in a cave surrounded by art posters, suggesting a fatal singularity in space travel.", + "A colorful suburban neighborhood on a post-apocalyptic planet with detailed illustrations of creatures from Jim Henson's Creature Shop.", + "A half-masked laboratory technician man with cybernetic enhancements in a dystopian scifi outfit and mechanical parts.", + "A decrepit robot girl in Chernobyl.", + "Portrait of a digital shaman.", + "A digital art image featuring 3D abyssal bioluminiscent animals and a hyperbeast by Brock Hofer on ArtStation with a GTA cover comics style by James Gurney and beeple, utilizing global illumination Volume lighting and pastel tone mapping.", + "Pokemon characters appearing in Animal Crossing New Horizons.", + "A realistic digital art depicting a dwarven automobile.", + "The image features a goddess in a dark color scheme with high detail and smooth rendering.", + "An orc wearing orc attire and jewelry by various artists.", + "A monster coming out of a cellphone screen.", + "\"A mouse with dinosaur spines and spikes.\"", + "A path winding through a forest depicted in digital art.", + "A close-up image of a woman wearing a samurai mask, fire dancing in a dirty cyberpunk alley with smoke and mist.", + "Google Pixel 6 render.", + "The image depicts a forest with realistic gnomes and mushrooms on the ground, with warm lighting shining through the trees.", + "Black and gold egg-shaped mech suit in a winter village during a blizzard, with volumetric lighting and lightning in the background at night time.", + "A cyberpunk-style photo of Belgrade at night featuring neon lights and brutalist architecture in vivid colors.", + "Metallic brain in 3D render.", + "This is a highly detailed digital painting of Zeus, portrayed by Stephen Lang, in a fantasy portrait with lightning, trending on ArtStation and created as concept art by Artgerm, Greg Rutkowski, and Magali Villeneuve for D&D.", + "The image depicts something related to GME in black and white.", + "Super Mario 64 level in Unreal Engine.", + "The image is a highly detailed portrait of an oak in GTA V, created using Unreal Engine and featuring fantasy artwork by various artists.", + "League of Legends art of a purple nether portal in a library with wooden interior.", + "A Louis Vuitton designed costume for furry catgirls featured on a high fashion magazine cover with a symmetrical and detailed portrait.", + "A rainbow-colored monster made of gems and crystals.", + "A goblin is killed by a sigil on the 
ground inside a haunted house with an inverted cross on the wall.", + "The image features Brock Lesnar depicting Iron Man in a dynamic action pose, inspired by several artists.", + "The image features artwork by Lucas Hikaru.", + "A man on a boat crossing a hellish body of water with soul-like creatures swimming around.", + "A male android, singer Grant Knoche, poses on stage with a blank stare, half-robot and half-humanoid, with shiny skin.", + "An armored trooper carries a plasma rifle while standing in front of a walking battle tank in a fantasy art depiction.", + "A cyberpunk-style Batman in a dark city, depicted in an extremely detailed piece of artwork by Chris Labrooy.", + "The image depicts a steampunk-style Mandalorian character battling an alien mutant megafauna creature in richly-colored, highly-detailed artwork.", + "The image is a digital painting of Itachi with a highly detailed moon in the background, created as concept art and illustrated by Greg Rutkowski and Alphonse Mucha.", + "Image of a girl in a Soviet-style room with a monster hiding under the bed, featuring gothic and deep color elements by various artists.", + "The interior of the SCP agency contains numerous rows of large alien artifacts in cylindrical containers made of gold and quartz, surrounded by overgrown alien flora, and inspired by the game Control.", + "A black marker pen drawing of a man inside a squid.", + "A sand monster amidst a tornado in the desert.", + "The image portrays a surreal scene of a hybrid creature consisting of a great leviathan, cybernetic turtle, and cephalopod terrapin in a magical universe surrounded by a cozy hot springs, cave, forest, and lush plants amidst a luminous stellar sky.", + "A militarized police vehicle with mounted weapons rides through an Egyptian town as troops search the area, with futuristic pyramids in the background.", + "Assassin's Creed logo for upcoming game.", + "A digital painting of a cute baby demon in hell with intricate details and smooth focus, created for Unreal Engine 5, by artgerm, greg rutkowski, and alphonse mucha.", + "A person wearing a Spider-Man suit in the game Half-Life Alyx.", + "A red panda skydiving with a parachute in a digital illustration.", + "Personification of death in style of Junji Ito.", + "A detailed image of a prismatic rainbow in a lab with a biohazard symbol, featuring multiple artist styles and rich colors.", + "Portrait of Leonardo da Vinci as a steampunk cyborg, clockwork automaton with medieval technology.", + "A creepy figure knocking on a door.", + "A psychedelic shaman with celtic tattoos in an ancient temple.", + "A metal bat bird with a red heart head, golden body, joints, and wings as if it is taking off.", + "A penguin in an authoritarian pose is depicted in a Shepard Fairey poster.", + "A Land Rover driving through a rainy swamp in a digital painting by artists Greg Rutkowski and Artgerm on ArtStation.", + "A masked laboratory technician man with cybernetic enhancements, a dystopian scifi outfit, and mechanical features.", + "Image of female water creature wearing a rock mask, designed by various artists and trending on Artstation.", + "A realistic concept art of an acid shotgun.", + "A digital art scene depicting an action-packed attack on a Mahindra Thar by tribe members in a sunny Kerala village.", + "A 3D render of a smiling black curvaceous model with cream on her face, set against a blue background with top down POV.", + "A horse is flying over an astronaut on the ground.", + "A samurai in space.", + "Abstract 
ying-yang representation from Evangelion.", + "A digital portrait of a young, handsome Captain Kirk with intricate details and dramatic lighting.", + "Cameras film a frightened couple embracing in a sci-fi scene reimagined by Industrial Light and Magic.", + "A film still of Master Splinter during his Jedi training.", + "The image is a surreal and biomechanical CD cover artwork featuring a grid heightmap pattern and a mix of abstract mechanisms influenced by various artists' styles.", + "A praying mantis nun in a grassy field during sunset.", + "A male android portraying football player Christian McCaffrey, with robotic features and shiny skin, posing motionlessly at the gym.", + "Digital art featuring small white butterflies amidst a starry darkness.", + "Flying over Death Star surface, abstract pattern.", + "A blonde-haired, black-eyed Tracer game character stands tall wearing a yellow bikini.", + "A giant guardian wearing road sign armor, a popular character design on Artstation.", + "A firebird reading a book in a library.", + "he surface of an alien planet with twisted trees.", + "A Japanese soldier swims on an Indonesian lava mountain.", + "A portrait depicting Forest Gog, a muscular and masculine female with a clear face, in an elegant and intricate fantasy style.", + "A surreal image featuring gold, ruby, glass, water, and vaporwave elements.", + "A demonic Asian warrior with a flaming katana, depicted in concept art.", + "The Little Mermaid wearing a crown swimming in a vast ocean in a highly detailed and colorful digital art.", + "Image of apocalyptic landscape from the video game Hearthstone.", + "The image is a fight promotional poster featuring two boxers facing off.", + "A space man sat on a beach chair on the moon, pixel art.", + "A warrior standing on a psychedelic landscape.", + "An anthropomorphic knight toilet bowl of large size made with advanced technology.", + "The image depicts the two complementary forces of life.", + "Psytrance artwork by Charlie Bowater.", + "A woman in a grey hoodie confronts a red dragon in a cave in a digital art fantasy painting.", + "Nutiliti in space.", + "A drawing of a female warrior in full body pose.", + "This is a 3D isometric illustration with studio lighting.", + "Man standing in front of a giant mirror in a propaganda poster-style illustration by Dean Ellis and John Watkiss.", + "A conceptual art of a man in meditative pose with a large head, made up of colorful, twisting geometric patterns.", + "A sorceress wearing metallic robes conjuring the universe with galaxies and nebulas in the background.", + "Portrait of a robotic monster.", + "Call of Duty Advanced Warfare battle with epic fight and vast sense of scale.", + "An image depicting a proof for the meaning of life.", + "A beached sea dragon on a shore by Jaime Jones.", + "A 3D digital painting of surreal luminous awareness in the eyes of others by multiple artists.", + "Portrait of a cyberpunk gang.", + "A colorful monster character design inspired by various artists' styles.", + "Illustration of a hand holding a colorful shaurma at night with a psychedelic background.", + "A future city surrounded by water and mountains.", + "A futuristic calculator with muted colors by Jean-Baptiste Monge.", + "Poster for fantasy Japanese film \"Genshin Impact\", featuring eye-catching artwork.", + "The image depicts a young Gillian Anderson as a retro SCI-FI heroine from 1985.", + "A league of legends concept art of a young female scientist holding a small black hole in a laboratory.", + "A 
black Chinese lion dance with intricate scrollwork and a heraldic design by Peter Mohrbacher and Kentaro Miura.", + "The image depicts Max Payne standing in a blood-covered Tokyo city street with realistic buildings, cars, and people.", + "Tactical team in a fictional depiction of hell.", + "A 1940s city street at night with yellow lit windows and a man standing under a street light in front of a steam punk mystical red fog.", + "The image depicts a concept art of the biggest taco shop in the world with incredible detail and realism.", + "A hyperrealistic digital portrait of a zombie bigfoot with intricate detailing and cinematic lighting by Derek Riggs.", + "A 3D painting of Medusa, rendered with octane, featuring dramatic lighting and created by an award-winning artist.", + "A realistic, detailed and colorful image of a technological demon god with gothic and neo-gothic influences, created by multiple artists including Lisa Frank, Ayami Kojima, Amano, Karol Bak, Greg Hildebrandt and Mark Brooks, featuring elements from Beksinski paintings, and part of Adrian Ghenie and Gerhard Richter's work, as well as art by Takato Yamamoto.", + "A photo of Gigachad, created with CG technology and recognized with an award.", + "The image shows a android girl with a beautiful face wearing warframe armor in a futuristic scifi setting.", + "The image features an overweight cyberpunk corporate woman on a comics page with realistic shading and fine details, created by a group of artists including Greg Rutkowski, Diego Gisbert Llorens, Magali Villeneuve, Artgerm, Jeremy Lipkin, and Rob Rey.", + "A loot crate from the video game Apex Legends.", + "The image is titled \"Burning Memory\" and features dark, dramatic, and highly detailed artwork by multiple artists, depicting a scene from the video game Bloodborne.", + "Mr. 
Bean featured on a WWII propaganda poster holding a gun.", + "Portrait of a digital shaman.", + "US Air Force battling against the Rebellion in a radioactive environment with detailed digital artwork.", + "A depiction of Hermione Granger from the Harry Potter series as a zombie.", + "The image is titled \"holy cow that's outta this world\" by Jean-Baptiste Monge and features muted colors.", + "Image depicting a person's face composed entirely of fruits and vegetables.", + "The image depicts a blueprint of an opanim.", + "Close up of an eye with the Earth inside the pupil, inspired by Wes Anderson's art.", + "A concept art digital CG painting of a place in Bali, trending on ArtStation and created using Unreal Engine.", + "Digital art portraying the concept of humans being nature's clay, with detailed and surreal visuals featuring work by multiple artists.", + "A futuristic city with a lake, a reflection of utopia, and jungle scenery, featuring drones and androids.", + "Character Pawn from Rimworld videogame depicted in Cedric Peyravernay's sprite art with dramatic lighting.", + "A bozo clown fights dinosaurs on a tree in a cyberpunk world.", + "A tall female warrior with massive plate armor, no helmet, big white glowing wings, and long red hair, illustrated by multiple artists in insanely detailed and intricate concept art.", + "A League of Legends concept art of a girl holding a gun, with an athletic feminine body, tanned skin, and long hair, wearing an aviator sunglasses and a green tank top, in a fantasy style illustration with an empty background.", + "The image depicts a schoolgirl from the NCsoft game lineage 2, with a cinematic atmosphere and sharp focus.", + "The image depicts a snow queen in sci-fi armor, standing in a bas-relief sculpture scene and painted in great detail.", + "A vivid and intricate depiction of a terrifying god-like creature with rich, bold colors and influences from various artists.", + "An alternate dimension with psychedelic terrain, cotton candy plant life, Japanese-style temples and glowing lanterns in the Taiwanese-style architecture.", + "A portrait of a cat in a spacesuit, with a surreal backdrop, by Krenz Cushart, popular on art and design platforms.", + "A stylized image of fish resembling mythical fantasy creatures in the style of Moebius.", + "A studio photo portrait of Asuka Langley posing as a Slavic person, taken by Ross Tran and WLOP.", + "A transparent demon in clothing chasing someone around a dinner table in a dark and terrifying setting.", + "A girl wearing a Tokyo Ghoul mask gazes out at the city of Santiago at dawn.", + "A portrait of a crazy pirate, created by multiple artists.", + "A fantasy portrait of a female cat person with a blurred background, portrayed in a cinematic dystopian brutalist atmosphere.", + "A cyberpunk city street with flying vehicles and towering corporate buildings at dusk.", + "A fierce barbarian woman stands confidently in mountainous terrain, clad in iron chainmail with copper hair and tanned skin.", + "A 3d close portrait of a symmetrical rusty razor wire artwork by Stanley Donwood.", + "A hyper-realistic landscape from a Neil Blomkamp film featuring a crashed spaceship, detailed grass, and a photorealistic sky.", + "The image features a portrait of a Soviet cosmonaut with influences from artists Beksinski and Stephan Martiniere.", + "A digital concept art of a machinist in uniform with a sci-fi repair tool and glowing lights.", + "A Halloween-themed TV show room with a big screen on the wall, designed by Disney 
Concept Artists with blunt borders and following the rule of thirds.", + "A realistic depiction of a death angel by artists Tafy Laplanche and Masashi Kishimoto.", + "The image is a digital painting of an orc with intricate and highly detailed eyes, in sharp focus, by Wayne Reynolds.", + "A wizard with the face of death is conjuring the universe, surrounded by galaxies and nebulas.", + "A man on a boat crossing a hellish sea surrounded by monstrous creatures in \"Sea of Souls\" by Gainax Co.", + "A green spaceship flies over a city depicted in a futuristic artwork.", + "human body render.", + "The image is a photo of a xenomorph, incorporating elements from various artists.", + "A young actress resembling Sofia Vassilieva appears to be flying over Los Angeles at night like a superhero.", + "A symmetrical portrait of a wild-looking man by multiple artists including Yoichi Hatakenaka, Masamune Shirow, Josan Gonzales, and others.", + "A werewolf howling on a cliff at night.", + "Portrait of a digital shaman.", + "A beachfront house with a synthwave aesthetic.", + "A yellow-armored warforged character brandishing a paladin sword and shield, portrayed in dynamic lighting.", + "The image depicts Ciri from League of Legends, with fluorescent skin, and features a hyper-detailed, smooth render created using a cinematic lighting and rendered in Unreal Engine 5 and Octane.", + "A kinetic sculpture of a colorful bird with a long tail surrounded by swirling lines and shapes.", + "Unsettling AI-generated art.", + "A moth tied up with the ability to change night into day.", + "Two iridescent squid shaped buildings in a mysterious alien desert landscape.", + "A closeup portrait of a medieval goblin wearing cat helmets.", + "An image depicting the event of 9/11 as portrayed in the film \"American Psycho\" (1999).", + "A samurai wearing a red and white kimono holds a big katana while standing among wisteria trees, with elegant and symmetrical facial features.", + "Portrait of goth girl created by Kuvshinov Ilya.", + "Barbarian woman riding a red dragon, holding a broadsword, in gold armour.", + "A landscape with a building in the style of Simon Stalenhag.", + "The image is a techno artwork created by Steve Argyle.", + "Digital art of futuristic cryogenic chambers with sleeping patients inside.", + "An image depicting the popular video game character Crash Bandicoot in a movie adaptation.", + "A man with a ripped physique wears goggles and a collar with a leash in a Steampunk-inspired portrait as a detective, with a Chihuahua in a costume standing next to him.", + "The image is a digital illustration of a crumbled paper bag that has been transformed into a cathedral, inspired by artists like Edward Hopper, James Gilleard, and Zdzislaw Beksinski.", + "A giant colloidal nonila by artist Greg Grusowitzy stands in the distance against the backdrop of a Japanese city, with its content blacked out.", + "A digital painting of a realistic cyborg with precise human anatomy, striking pose, and intricate details inside a futuristic setting.", + "The image depicts a 3D cyberplant surrounded by bioluminescent animals, with a GTA cover comics style and global illumination.", + "The image depicts a symmetrical mecha version of the asura from Chinese mythology, with a ghostly appearance and intricate details, created as a digital painting by artists including Artgerm, Greg Rutkowski, and Alphonse Mucha.", + "A landscape featuring a lone magic the gathering-style building.", + "Digital artwork of a female character wearing a 
Soviet era uniform and cat ears, with a Karl Marx t-shirt, in Krenz Cushart's style.", + "Image of ghosts circling an old forgotten church by Noah Bradley, Darek Zabrocki, James Paick, and Natasha Tan.", + "A space battle between a galactic fleet and aliens illustrated in the style of Syd Mead and Chris Foss.", + "Image of the end of the world scene from Warcraft.", + "Half-length head portrait of the goddess of autumn with wheat ears on her head, depicted as dreamy and beautiful, by wlop.", + "Ruins of Moscow post-apocalypse.", + "An image of a fantastical city floating in the clouds.", + "A skeleton dressed as a samurai.", + "A symmetrical portrait of an elven woman with split-dye hair by Charlie Bowater for Dungeons and Dragons art.", + "Undead rat halloween prop with a skull face, glowing red eyes, black tattered robes, and two blue flames resembling a grim reaper.", + "Portrait of Guts from Berserk submerged in red water.", + "The image is a concept art of a highly-detailed zombie head and shoulders from Resident Evil, designed to be horror-themed and terrifying.", + "The image depicts a beautiful goddess of spring wearing a wreath and flowy green skirt, created by artist wlop.", + "A busy fantasy street depicting a single street within an old city lined with quirky shops, old buildings, cobblestones, and street life.", + "A half-robot, half-humanoid male android modeled after soccer player Antoine Griezmann, with shiny skin, posing motionlessly at a museum exhibit.", + "An old pirate ship floating in space with volumetric light and intricate details in a digital painting by Ruan Jia, Randy Vargas, and Greg Rutkowski.", + "A full-body shot of a beautiful woman wearing a dress, with sharp focus on her eyes, depicted in an elegant, intricate style by artgerm, jason chan, and mark hill.", + "A macro close-up portrait of a goddess phoenix with a mask made of ram skull, accompanied by intricate artwork of bioluminescent betta fish, jellyfish, and super intricate ornaments.", + "A woman in a sci-fi suit is battling undead creatures in a spaceship using Unreal Engine 5.", + "Maya Ali portrayed as a D&D sorcerer in an Art Nouveau style portrait.", + "Jason Statham depicted as the Hulk in a still from an action film, with dramatic lighting and visual effects.", + "Concept art of a Russian female netrunner with unique hair designs for a D&D video game by Marc Brunet and Artgerm.", + "Exterior image of a small magic items and curios shop in a busy fantasy city.", + "A stylized portrait of a cute cartoonish coconut with electronic components, created by artist Noah Bradley and trending on Artstation and Deviantart.", + "Nathan Jones portrayed as a menacing Batman in full armor and cape, illuminated from below.", + "\"A cyberpunk model in futuristic clothing walks on a catwalk.\"", + "A screenshot of the game Slay the Spire, depicting its deckbuilding gameplay.", + "Batman wearing metal gear armor holding gun with a cinematic, dramatic background.", + "The image is a concept art of sci-fi props such as panels, with hard surface shapes that are detailed and explore form, featuring pastel colors, created by artists Simon Stalenhag and Syd Mead, and displayed on ArtStation.", + "A metallic brain in 3D render.", + "An old sorcerer observing the results of his study in nuclear magic.", + "Kratos playing a flaming guitar with red stripe eye and spartan rage.", + "An angel watches over a child in a detailed digital art piece found on DeviantArt and ArtStation.", + "The image depicts the architectural 
section of a museum of emotions designed by Zaha Hadid and Toyo Ito with soft, happy shapes.", + "Portrait of a middle aged elf in a blue cloak with clock iconography, brown hair, and a raised hand.", + "A large organic crystal with vibrant blue azurite and green malachite in a swordlike shape against a grey background.", + "A vividly realistic depiction of a snowy Swedish lake at night with hyper-detailed, cinematic-level artistry showcased on ArtStation.", + "Cassandra Cain as a Tekken character portrayed in a realistic style with a character select portrait in cg animation.", + "A concept art of a steam goth viking priestess with steampunk machine parts and overgrown vegetation in the background.", + "An artwork with a technological theme by Ian McQue.", + "Nine human faces from Neanderthal to Modern Human and beyond depict the future of human appearance.", + "Goro Fujita's illustration depicts a big city on the left and a forest on the right, separated by a highway filled with cars leaving the city.", + "A city made of books in concept art for Disco Elysium by Aleksander Rostov.", + "3D human face made of holographic chrome material symbolizing cyborg and AI.", + "A vampire sits at a banquet table in a dungeon setting surrounded by plates of rats and spiders and red candles.", + "A symmetrical portrait of Imogen Poots as a warrior paladin wearing full metal armor with blonde hair and a distant, elegant expression.", + "The image depicts a fat orc pirate chef in a detailed digital painting, created as concept art and illustrated by multiple artists featured on Artstation.", + "League of legends champion splashart of hammer hitting an apple.", + "A highly detailed portrait of an orc character in GTA V, featuring fantasy art by multiple artists and rendered in Unreal Engine with global illumination and radiant light.", + "A portrait of a pirate in a cyberpunk cafe setting.", + "Concept art of scientists by Jama Jurabaev.", + "Medieval tavern and castle viewed in isometric perspective on a grey background.", + "Portrait of Martyn Ford as a dark evil Batman, wearing cape and armor.", + "A cyborg cat crashes into a gothic world planet with a fantastic landscape, depicted in bright colors and ultra-detailed hyperrealism.", + "The image displays a close-up texture of a circuit processor in a seamless pattern using Substance material.", + "A female marten wearing jewelry and a cute hairstyle, depicted in a digital art piece by Stanley Artgerm Lau, WLOP, and Rossdraws.", + "A highly detailed digital painting of a portrait of Mandalorian armor-wearing character holding a crossbow with a steampunk and Lovecraftian vibe, created by artgerm, Greg Rutkowski, and Magali Villeneuve.", + "A vivid and detailed image of a nightmare scientist god with rich and deep colors, reminiscent of gothic art and featuring elements from various artists.", + "A full body concept art of Yoruichi Shihouin with an intricate and epic composition by Artgerm, Greg Rutkowski and Alphonse Mucha.", + "An Instagram photo of a schoolgirl from Deus Ex Human Revolution with a cinematic and dramatic atmosphere, featuring sharp focus and volumetric lighting.", + "A full body shot of a gorgeous Scottish female in an elegant dress, with sharp focus on her perfect eyes, in an intricately detailed art by Artgerm and Jason Chan.", + "A 3D illustrated chubby room with studio lighting.", + "A side profile portrait of Maya Ali as a mage with intricate details, neon and sweat drops in a highly detailed digital painting.", + "An object is 
visible through mist and appears ominous.", + "A black bronze sculpture in the center of an ancient Egyptian temple, worshipped by red-robed acolytes.", + "A colorful and epic digital painting of a space marine portrait in a futuristic battlefield, created by multiple artists and centered in the frame.", + "The image features a scene where gravity is depicted in a similar style to that of Hiroaki Tsutsumi's artwork.", + "Image of Hecate, the Greek goddess, with intricate symbolic details, illustrated by Takehiko Inoue and Ilya Kuvshinov.", + "An elven queen wearing transparent silk in a fantasy character portrait.", + "Illustration of futuristic Mayan warrior robots with complex armor and surgical arms, featuring an oracle witch and cybernetic symbiosis, in a serpent-themed laboratory.", + "Humanoid descendants taking a selfie on a sci-fi planet.", + "An ethereal magical elven city.", + "The image is of an astronaut in a diamond fractal lace suit with camera appendages, jumps in a blobby holographic bubble with insectoid compound eye camera lenses.", + "Concept art of a highly detailed Resident Evil zombie head and shoulders.", + "Concept art painting of a Fire Nation colony on the coast of the Earth Kingdom.", + "A three-dimensional object created by combining different parts from various sources, known as a kitbash.", + "A sci-fi concept art piece with detailed exploration of hard surface shapes and forms, using pastel colors and featuring props and panels.", + "A male android, half robot and half humanoid, resembling soccer player Antoine Griezmann, stands still inside a museum exhibit with shiny skin and a blank stare.", + "Schrodinger's cat in a box, depicted as both dead and alive, with a black cat and glowing eyes, against an abstract background of waves and particles.", + "Poster of Star Wars Return of the Jedi featuring artwork by Dice Tsutsumi, Makoto Shinkai, Studio Ghibli.", + "Pixar concept artists created a techno artwork.", + "Saul Goodman character depicted in multiple video games.", + "A detailed, colorful image of a wrathful god with a nightmare-inducing diamond-shaped head, created in a variety of artistic styles.", + "Sci-fi illustration of a pink woman by Wayne Barlowe.", + "A cheerful girly monster collecting mushrooms in the forest illustrated by Goro Fujita.", + "A minimalist vector art brandmark for a research lab that studies attention and a wandering mind.", + "A Louis Vuitton bag designed for catgirls with a symmetrical and detailed design showcased through professional lighting.", + "A half-robot, half-humanoid male android of soccer player Antoine Griezmann stands in a museum on display, posing like a statue with a blank stare.", + "A futuristic modern house on a floating rock island surrounded by waterfalls, moons, and stars on an alien planet.", + "The image is a simple close-up portrait of a punk goddess with a mohawk and tiger skull, wearing a classical Japanese kimono and an intricately detailed crow kitsune mask, surrounded by various creatures like a betta fish, jellyfish, and phoenix, depicted through a stunning mix of bio luminescent, plasma, ice, water, and wind effects in a piece of artwork by Tooth Wu, Wlop, Beeple, and Greg Rutkowski.", + "Male video game character head designs featuring unique silhouettes and casual streetwear by Marc Brunet and Artgerm.", + "A hyperrealistic mixed media image of a proportionally sized human hand undergoing particle teleportation, with perfect symmetry and dim volumetric lighting.", + "Winter-themed vector art 
panel for CNC cutting machines with a unique winter design.", + "A bright cube hovers over a portrait of a beautiful dark woman with ice blue eyes, created by multiple artists including Artgerm and Greg Rutkowski.", + "A young engineer man with cybernetic enhancements wearing a suit and bowtie, a detailed mask, and a gloomy expression, with half of his face mechanical.", + "The image features Breton monks resembling Rasputin from The Lorax, with cinematic lighting and a shallow depth of field.", + "A hyper-realistic full body portrait of an ornate, high-tech android astronaut with laced white plastic, colorful LED lights, and cables twirling around them, rendered in a Secessionist style.", + "A male android football player, half-robot and half-humanoid, posed like a statue with a blank stare on display at a museum.", + "A digital painting portrait of a young black man cyborg with holographic computer displays.", + "Detailed texture with Greeble elements.", + "A 3D-rendered, anatomical flaming heart is held by skeletal metal hands in this expressive image.", + "A dreamlike scene with a vaporwave aesthetic.", + "A crowd of adults and children enjoy a holiday in a conceptual wooden architecture utopia in Russia on a clear day.", + "A digital painting of a fantasy character wearing Mandalorian armor and holding a crossbow, with steampunk and Lovecraftian elements.", + "A cyberpunk cyborg girl stands in a futuristic city street at sunset, styled by Tomino-Sama.", + "The image depicts stormtroopers in a hyper realistic style, with intricate and hyper detailed design, characterized by ambient and volumetric lighting, reminiscent of Star Wars concept art by George Lucas and Ralph McQuarrie, with a style similar to GTA V.", + "A giant magical gyroscope in an ethereal cave.", + "A full-body digital render of a synthwave cowgirl wearing metallic armor, featuring characters from Persona 3 and The Witcher 3, with a lunar cyberpunk city as a backdrop.", + "A humanoid alien wearing glossy black armor and a respirator, with yellow eyes, appears in a dramatic scene with harsh lighting, in a still from the movie Real Steel (2011).", + "A walking battle tank armed with cannons is parked in front of a command station in a fantasy-themed image.", + "Hyper illustration of album artwork with dark and desaturated colours, featuring a futuristic year 4000 design and a small contrasting feature.", + "Screenshot of Gigachad in Dead By Daylight.", + "Realistic painting of an astronaut in a lobster suit with diamond 3d fractal lace, camera appendage stalks, and a clear brain case helmet, floating in a holographic bubble inside a futuristic space station.", + "A digital painting of a black leather lab coat with a tarot card, created as concept art for Blizzard Entertainment by ILM and posted on Artstation.", + "A wide shot of Ben Affleck in a 15th century knight suit with a medieval background and cinematic feel.", + "The image is iridescent, psychedelic, and holographic.", + "\"Steve Buscemi portrays the Joker.\"", + "A digital portrait of Colonel Sanders wearing a military uniform and an eyepatch, created by Moebius, Tyler Edlin, and HR Giger.", + "Samuel Hyde portrayed as a Medieval King by artists Gerald Brom, Mark Arian, Stanley Artgerm Lau, and WLOP, with intricate details and a realistic style.", + "A pleiadian woman wielding a plasma gun in a dark bodysuit stands in a barren field, her long silver hair flowing behind her.", + "Yellow bikini-clad Tracer game character with blonde hair and black eyes standing at 
full height.", + "XQC featured as a GTA character on a loading screen.", + "A young engineer man with cybernetic enhancements wearing a suit and bowtie, portrayed in a cinematic, dystopian, steampunk style.", + "Cutaway diagram of a hi-tech RNA bioweapon monster with intricate details and a black background.", + "Illustration of military troops with a retro feel, inspired by watercolor paintings and trending on ArtStation.", + "A female cyborg in a rubber and gas mask is being attacked by alien brainsuckers in ancient Egypt in a big budget sci-fi movie.", + "The image depicts a young, plump, clean-shaven cleric wearing a silver breastplate with religious engravings and a stressed expression.", + "Simon Pegg wearing nanotech cuirass as Xochipilli, standing inside Machu Picchu Citadel, surrounded by shamanic god figures depicted in regal and immense portrait style.", + "A stylized 3D model of Jon Taffer in a fighting game.", + "The image is a concept art character design sheet featuring anime-style women in tek gear, French maid, pinup, cyberpunk, sci-fi, and fantasy styles by various artists.", + "Portrait of an African person wearing a futuristic gadget with decorative organic designs.", + "A young adult, chubby cleric wearing a silver and emerald breastplate with religious engravings is portrayed in a dramatic and finely detailed portrait geared towards Dungeons and Dragons enthusiasts.", + "A snowy lake in Sweden captured in a vibrant, cinematic style with intense detail and raytracing technology showcased on Artstation.", + "Illustration of a car driving on a highway with mountains in the distance, created by Goro Fujita.", + "The image depicts alluring cyborgs in a cyberpunk science fiction Tokyo red-light district, rendered in ultra-realistic style with highly intricate and detailed digital painting by Artgerm, Greg Rutkowski, and Alphonse Mucha.", + "Plato wearing VR-glasses in Plato's cave." 
+ ], + "paintings": [ + "Portrait of Archduke Franz Ferdinand by Charlotte Grimm, depicting his detailed face.", + "A head-on centered symmetrical portrait of Elisha Cuthbert as a holy paladin, wearing steel armour and with blonde hair, depicted in a highly detailed digital painting with dramatic lighting, in the style of Artgerm and Anna Podedworna.", + "Museum painting of a mouse stealing cheese artwork.", + "Male character illustration by Gaston Bussiere.", + "A painting of a Persian cat dressed as a Renaissance king, standing on a skyscraper overlooking a city.", + "A symmetrical oil painting of two waterfalls in a dense forest.", + "A Japanese poster depicts a dragon soaring over a stormy sea amidst thunder with cliffs and clouds in the backdrop.", + "A painting of the tarot card \"The Sun\" by Michelangelo Merisi da Caravaggio.", + "An installation art of Venus emerging from the sea on a giant clam shell with flowing hair and a bright future ahead.", + "Color illustration of Kate Bush with 3D shadowing.", + "A detailed ink illustration of a hedgehog.", + "A painting depicting the 6th mass extinction with elements of surrealism, created with a collaboration of Caravaggio, Matisse, and Rothko, using Unreal Engine.", + "An oil painting close-up portrait of a young black woman wearing a crown of wildflowers, surrounded by hazy golden light.", + "Image depicting the opposing forces of life in artistic style by Greg Rutkowski.", + "A close-up hyperrealistic oil painting of a nurse fashion model with red lipstick, ginger hair, freckles, in a style mixing classicism and 80s sci-fi, set in complete darkness.", + "A lion roaring on a meadow, depicted in an oil painting by William Adolphe Bouguereau.", + "Salvador Dali's \"A Dream Within a Dream.\"", + "The image depicts an angel protecting a child, created by Boris Groh, Makoto Shinkai, Thomas Kinkade, or James Gilleard, and is available on DeviantArt and ArtStation.", + "A pre-Raphaelite mixed media portrait painting with ornamental art nouveau fashion embroidery, featuring a piercing gaze and dayglo colors.", + "A surreal portrait of a woman with a giant carnation face in a flower field at sunset with colorful clouds and a large sky, created by artist Simon St\u00e5lenhag.", + "A painting depicting a wuxia character standing on a roof under a moonlit night.", + "The image is \"Akira\" by Sandro Botticelli.", + "A ceramic glass mosaic depicts Mona Lisa's smile.", + "A close-up portrait of a woman surrounded by autumnal elements, painted with intricate details and ornamental features.", + "A painted portrait of Zeus, a handsome and muscular Greek god, depicted with intricate detail and a fantastic flair.", + "Art print featuring a famous boxing match knock out by Neil Leifer.", + "A monkey in a blue top hat painted in oil by Vincent van Gogh in the 1800s.", + "A painting of the Kool-Aid Man.", + "\"An intricate, biomechanical blue planet by Bruce Pennington.\"", + "Man on boat crossing a body of water with creatures in the water, \"Sea of Souls\" artwork by Dan Witz.", + "Three small dinosaurs entering a grocery store painted by Thomas Kinkade.", + "A portrait of Enrique Tabara painted with digital art.", + "Mickey Mouse painting by Frank Frazetta.", + "Portrait of a character in a scenic environment by Etel Adnan.", + "Scary African voodoo paintings by Jean-Michel Basquiat.", + "A portrait of Yojimbo the Desert Samurai.", + "Maya Ali as a D&D Mage wearing wizard robes in the style of various artists, depicted in a head-on symmetrical painted 
portrait.", + "An abstract artwork with retro-future motifs.", + "Victorian genre painting portrait of Royal Dano, an old west character in fantasy costume, against a red background.", + "Abstract watercolor painting of midsummer in Scandinavia by artist Anders Zorn.", + "Psytrance artwork by Lee Madgwick.", + "A painting of a Chinese temple with a lone monk walking on winding steps and three red crowned cranes in the water.", + "A skull-shaped island with rocks and vegetation, painted by Ghibli with strong light and shadow.", + "Colorful illustration of a forest tunnel illuminated by sunlight and filled with wildflowers.", + "An abstract collage featuring grey and lilac colors with a touch of sparkle.", + "Oil painting of a man under a tree in the rain, by Greg Rutkowski.", + "A painting of foxes and wolves running through the forest by Jan Brueghel the Elder.", + "Maya Ali as a D&D wizard, with black hair and wearing medieval robes, in a head-on symmetrical centered painted portrait.", + "A spring landscape painting featuring a treeless mountain village with melting lake ice, winding stone steps, and fog.", + "Image of Albert Einstein created by Park Jun Seong.", + "The image features a surreal fox and skulls in highly detailed, liquid oilpaint style.", + "Ryu from Street Fighter in a Van Gogh-inspired oil painting style.", + "An oil painting of a treasure lost in a rainforest.", + "A female goth cosplayer with black hair, fishnet clothing, tattoos, red lips and light gray eyes, with a beautifully detailed face, featured in a rich, colorful painting by multiple artists.", + "The image depicts intricate gold and blue swirls, spirals, scrolls, and cloud-like formations created with fluid ink, set against a smooth and blooming background.", + "The image is a box art from the 1989 video game \"Gauntlet Legends,\" featuring a samurai in power armor designed by Keith Parkinson, with artwork by Artgerm, depicted in oil on canvas.", + "FBI raiding Mar-a-Lago, chasing a big pig through a swamp depicted in a highly-detailed oil painting.", + "\"The End of the World\" artwork by John Howe.", + "Beige canvas tents set up in an arctic landscape with no vegetation, surrounded by rolling hills - reminiscent of a romanticist painting.", + "The image is a digital painting of a 1920s style flapper girl in a speakeasy, featuring intricate details and a focus on elegance.", + "A dragon with the head of Macho Man Randy Savage by Jeff Easley.", + "A portrait of the San Antonio Spurs' coach, Greg Popovich, dressed in a military-style jacket, depicted in oil on canvas by artist William Sidney Mount.", + "Illustration of watermelons, passion fruit, yellow lemons, mint leaves, and ice cubes in a colorful and happy retro style.", + "A gouache illustration by Krenz Cushart of a girl in a school uniform standing on the edge of a tall building.", + "A seamless pattern with photorealistic Roy Lichtenstein patterns, woodgrain, and gold accents.", + "A portrait of a character with black hair and blue eyes by artist Miho Hirano.", + "A digital painting of a blue-skinned wizard with intricate and elegant details, created by multiple artists and posted on Artstation.", + "Portrait of Princess of Eternal Fire and Death, with dynamic lighting and centered composition.", + "Psytrance artwork by Naoto Hattori.", + "A painting of a Bladerunner interior room in Africa with detailed artwork.", + "An abstract representation of the yin Yang concept by Albert Bierstadt.", + "A wren bird navigates a collage of cybernetic and urban 
motifs designed by Dave McKean, Ivan Shishkin, and Yoshitaka Amano.", + "A man is depicted screaming with expressions of hate, sadness, fear, and anxiety in a painting by Agnolo Bronzino.", + "An otherworldly world depicted with vivid colors by Fuco Ueda.", + "A portrait of Mario and Luigi from Mario Bros with a detailed face and a city background, painted by Bouguereau.", + "Portrait of a sci-fi outlaw by Gerald Brom, Kim Kyoung Hwan, and Norman Rockwell.", + "Wrecked ships on beach with palm trees, white stone ruins, sharp rocks and bushes, cloudy weather, isometric view, highly detailed digital painting.", + "A gouache illustration of a school girl on the edge of a tall building, with a delicate face, in a Morandi color scheme, by Krenz Cushart.", + "Al Pacino portrayed as Gandalf in a highly detailed symmetrical pencil sketch illustration by Jim Burns.", + "An oil painting of a duck on the prowl by Pavlo Makov.", + "An angel guards a man praying in a gothic church.", + "A fantasy illustration of a cat standing on a rooftop.", + "Salvador Dali's artwork depicts the future of humanity.", + "Techno artwork by Ivan Bilibin.", + "The image is a hyperrealistic painting of animals with human-like facial expressions, featuring uncanny valley elements, deepdream maximalism, and a psychedelic, neutral color palette.", + "A detailed and symmetrical taurus artwork in mystic style by Brian Froud.", + "A matte painting of spaceship earth at Epcot at sunset, surrounded by torches designed by Frank Lloyd Wright and Zaha Hadid.", + "Symmetrical Gemini artwork with a mystic twist by Brian Froud.", + "Psytrance artwork by Zack Snyder.", + "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski.", + "Harry Potter book cover with Van Gogh-inspired design and visible book title.", + "An Albert Bierstadt landscape painting featuring mountains, lakes, and a McDonald's restaurant.", + "Portrait of a creature with bat ears, a wolf snout, eagle features, wearing a poncho and helmet, created in 1923.", + "Realism tattoo design sketch of a pirate ship.", + "The image is a digital painting of a deus ex machine with intricate detailing and elegant design, featuring dramatic lighting and rendered with Octane Render.", + "A tonalist painting of crataegus fruit goblins with expressive big eyes and visible brushstrokes.", + "A baroque pattern that seamlessly tiles.", + "A digital painting of a fantasy kitchen environment with elements of cartoons, comics, and manga.", + "Red magician holding a dead rabbit in a surrealistic-themed artwork.", + "A pineapple bean bag designed by Vladimir Kush and Ilya Kuvshinov.", + "A portrait painting of Priscilla from Claymore with intricate details and an eerie, realistic style, created by Artgerm, Greg Rutkowski, and Alphonse Mucha.", + "An atom bomb explosion in Heaven, depicted in the oil on canvas masterpiece by Thomas Cole, currently trending on ArtStation.", + "Katie McGrath is depicted as the princess of mars in a symmetrical art nouveau portrait by Gil Elvgren.", + "A pirate with a beer is illustrated in detailed digital painting.", + "A dilapidated shack hidden in a misty, overgrown Witchwood forest inhabited by evil fairies, depicted in a detailed, ink illustration by Greg Rutkowski.", + "A painting of a crystal mine featuring mechanical mining equipment.", + "A painting of an enhanced person by several artists, trending on ArtStation.", + "Artistic depiction of the human world, highly detailed and hypnotic.", + "Fractal totem artwork by 
Zdzislaw Beksinski, James Gilleard, and Edward Hopper is highly detailed and trending on ArtStation.", + "An extreme close-up art installation by Kentaro Miura depicting a fiery soul with the potential for destruction.", + "A photorealistic portrait of a colorful fantasy landscape with a hyper-realistic river, mountains, trees, and bright blue sky.", + "A building in a landscape by Ivan Aivazovsky.", + "An art piece created by Joe Fenton.", + "A digital painting by Karol Bak depicting a goddess inspired by tarot cards and Dark Souls with attention to detail and a smooth style.", + "A beached sea dragon on a shore by Jaime Jones, with dramatic lighting.", + "A full body portrait of a Kurdish bride in a beautiful dress with detailed features and ornate detail, created by Monet and Mucha and currently trending on Artstation.", + "Victorian era people floating towards each other with outstretched arms in a valley connected by human spirituality.", + "A portrait of Boromir with a sword.", + "A painting of a man with an elephant face holding a harp and wearing bard costume.", + "A wizard operates an elaborate pneumatic machine of unknown function in a fantasy painting.", + "The image is a hyperrealist painting of a man being consumed by squid tentacles.", + "A Landrover crosses a forest path in the rain in a highly-detailed digital painting by artists Greg Rutkowski and Artgerm.", + "The image depicts Saint Megawati with the PDI-P logo in the background, colored by Bo Feng Lin.", + "A glowing dry tree stands alone under a starry sky in a detailed fantasy artwork by Greg Rutsowski.", + "Two contrasting forces depicted in a painting by Qian Xuan.", + "An oil painting of flora addict illustrated through detailed smoke.", + "A man in deep meditation.", + "A business logo for an AI startup designed by Paul Rand, Saul Bass, and Rob Janoff featured on ArtStation.", + "Man wearing a black suit standing next to a Leyendecker.", + "A teddy bear inspired by Vincent van Gogh.", + "\"Abstract oil painting portraying momentum by John Berkey and Gabriel Dawe.\"", + "Oil painting portrait of a young woman dancing through a field of flowers at sunset with mountains in the background.", + "Portrait of Herzl as a florist, painted by Van Gogh.", + "Painting of a smiling young woman with red hair in front of a fabric background by Craig Mullins.", + "Friends splashing in the ocean in a fantasy digital painting by Makoto Shinkai and James Gurney.", + "The image is an artwork created by Paulo Moreira.", + "Persephone with pomegranates.", + "Portrait of young Asian woman with tanned skin and blonde curly hair, dressed in a light dress, featured in a fantasy artwork for Dungeons and Dragons or Pathfinder with a highly detailed and realistic face painted by Pino Daeni at Guildhouse.", + "Psytrance artwork by Paul Lehr.", + "A painting of a lobster wearing an astronaut suit, covered in diamond 3D fractal lace and equipped with camera appendage stalks, floating inside a futuristic space station bubble, caught in a jumping pose.", + "A painting titled \"Positivism\" by Tara McPherson featuring a smiling figure in the darkness.", + "An artwork titled \"The End of the World\" by Luis Royo.", + "An oil painting of an android woman covered by plants and crystals in a mystical forest with a symmetrical face and steampunk elements.", + "A cave with blue mushrooms, a stream, dark rocks, shadows, butterflies, and mist, painted by Jessiah Thomason.", + "A painting of a horse by Xu Beihong in a style influenced by Warhol and pop 
art.", + "A chalice sits on an altar with chains and stars in the background, surrounded by a stoic and modern atmosphere, with a light dust giving a magnificent and theatrical effect, painted by Jean Honore Fragonard and Greg Rutkowski.", + "A Lamborghini Countach in the Arizona desert, depicted in an oil painting.", + "Abstract representation of ying Yang concept from Berserk.", + "A painting of Bastet, the Egyptian cat goddess, by William Adolphe Bouguereau.", + "Abstract art of Peter Falk in the style of Salvador Dali.", + "An art print by Barry Moser.", + "A portrait painting of a muscular Indian woman with a lower-back tattoo, wearing a sari, and covered in blood.", + "A colorful abstract wallpaper with a portrait motif.", + "Portrait of a conquistador with a pet tiger in a jungle.", + "The image shows a pair of hands arranged in a pose reminiscent of the Mona Lisa.", + "A pen and ink drawing of a steam punk dragon with clean lines and crisp detail by Olivia Kemp and Julia Hill.", + "Totoro depicted in the cubist style by Pablo Picasso.", + "A portrait of an elven queen character in a fantasy setting with intricate details by Frank Frazetta, who won an ArtStation contest.", + "A majestic phoenix flies in the sky in a fantasy matte painting.", + "Hyper-realistic painting of a girl in Chris Mars style.", + "Portrait of a Chinese woman enjoying wine and mooncake, surrounded by moonfestival decorations.", + "The image depicts a cotton candy milkshake island by Tara McPherson, featuring vivid colors and intricate details.", + "A techno artwork by Gottfried Helnwein.", + "An ocelot sneaking through a bog in an oil painting.", + "Swedish lake at night with heavy snowfall depicted in hyper-realistic and detailed art.", + "A painting by Rembrandt depicts a video game tournament.", + "A romantic portrait with juicy brush strokes by two artists, featuring dark shades of black and red, with an expressive touch.", + "A hyperrealistic still life portrait of a mind contemplating itself with imaginary thought patterns and nature of mind.", + "portrait. 
\n\nThe description without modifiers, Telly Savalas depicted with devil's horns.", + "Oil painting of Independence Day Celebration in Indonesia with Indonesian Flag in the background.", + "An artwork featuring a person standing in a field with a large tree and black hexagonal clouds in the blue sky, created by artist Simon Stalenhag.", + "Mona Lisa sees Marcel Duchamp's readymade.", + "A monochrome baroque art deco woodcut of a flying crow.", + "A portrait of a red and blue-haired elven woman, painted in detailed matte oil on canvas, against an empty background by artists Charlie Bowater, Lise Deharme, and Wlop, popular in the Dungeons and Dragons art community and associated with the Critical Role franchise.", + "An oil painting of an android woman covered in plants and crystals in a mystical forest, with a symmetrical face and steampunk inspired details.", + "A digital painting of a beautiful woman with book pages for a mouth in the style of Alex Broeckel.", + "A depiction of a lunar goddess in the art styles of Michael Whelan and Gustave Dore.", + "Illustration of a cottage designed by Salvador Dali in a blooming forest during spring with a nearby stream, created by Goro Fujita.", + "A still life portrait of a mind exploding inside a temple, incorporating sacred geometry and refracting light, by Sandro Botticelli.", + "An oil painting of a child king ruling a kingdom made entirely of cheese in a surreal and comic book style.", + "Painting of Skaven by Greg Rutowski with high level of detail.", + "A dancer with swirling hair in the wind.", + "\"Regency era painting of Ringo Starr\"", + "The painting is a high detail full body fantasy portrait of Emily Blunt as a stoic barbarian woman by Justin Sweet, with a scenic background and a sombre mood.", + "Moomins enjoying a magical and fluffy wonderland depicted in a warm and cozy illustration with volumetric light.", + "A close-up oil painting of a littlest pet shop fuzzy skunk in a field.", + "An oil painting of the UNSC Pillar of Autumn.", + "A landscape with a building in matte painting style.", + "The painting depicts the fifth circle of hell with a vivid, angry red sky and demons flying overhead in a dramatic landscape.", + "Renaissance noblewoman with blue eyes and pale skin in a classical portrait pose in the art style of Ib Iwerks.", + "A highly-detailed painting of an astronaut in a jumping float pose inside a futuristic space station with an iridescent bubble skin and clear brain case.", + "Psytrance artwork by Ernst Haeckel.", + "A matte oil on canvas portrait of an elven woman with red and blue split-dyed hair by Charlie Bowater, Lise Deharme, or Wlop.", + "Passionate and delightful summer day on another planet depicted in a Salvador Dali painting.", + "A hyperrealistic acrylic portrait of a cyberpunk-necromancer with intricate details, believable eyes, front-facing and symmetrical.", + "A painting by Tara McPherson featuring a smiling figure in darkness.", + "A digital painting of a beautiful creature in intricate detail, centered in a low angle shot.", + "A grand scale painting of Batman and Joker by Howard Chaykin and Alex Grey.", + "A drawing of lips drinking beer with red hearts and a dark ambiance, expressing sadness, in a sots art style by Godfrey Blow featured on Deviantart.", + "Portrait artwork of Eren Jaeger from Attack on Titan by artist Cushart Krenz.", + "A portrait art of a necromancer, referencing DND and War craft.", + "A creepy man dressed as a chicken is frightening children in a painting by Bussiere, Mullins, 
and Leyendecker.", + "A digital painting of the legendary water city of Atlantis, featuring a Greek temple, statues, and a red flag flying from a tower, in a stylized and hyperrealistic style.", + "A painting of a female with a unique mix of features by Egon Schiele.", + "Intricately detailed image with an air of improbability.", + "A landscape featuring a Donato Giancola-style building.", + "An art piece by Wojciech Siudmak depicting an individual gazing at the vast cosmos.", + "A vintage horror illustration of a mad scientist's poisonous laboratory with spooky lighting and detailed imagery.", + "Portrait of Lucifer by various artists, trending on ArtStation.", + "English woman playing the lute, depicted in an art piece by William-Adolphe Bouguereau.", + "A painting titled \"The End of the World\" by Emilia Wilk.", + "Realistic detailed painting of a horror machine consuming a city, featuring rich deep colors and created by Byun Shi Ji and Jiang Feng.", + "A deserted garden painting featuring roses, trees, and a waterfall, created by Simon St\u00e5lenhag on ArtStation.", + "A painting of a sand statue melting onto a beach by Kinkade and Grimshaw.", + "A man holding a skull on stage in a theatrical performance depicted in an oil painting.", + "A hyperrealist painting of a horse with a leg tentacle transplant and a mane of tentacles, reminiscent of Stubbs' Whistlejacket.", + "A hyperrealist portrait of a girl emperor in long starry robes.", + "A landscape painting of a China mountain village with a turbulent blood lake.", + "Portrait of a girl in white ancient clothing with flowers by Alphonse Mucha.", + "Pencil sketch of Danny DeVito by Milt Kahl.", + "Abstract portrait created through drip painting by Dan Hillier.", + "A hyper-realistic lamb by Alex Grey.", + "A close up oil painting of a pensive young black woman with long hair in a white dress, standing amidst colorful nebula stardust galaxies and white roses.", + "An oil painting of an ancient antler deity holding a yellow rat pig and laughing in a bright room.", + "A close-up oil painting of a fashion model dressed in a black robe with a surreal waterfall instead of a head, blending elements of classicism and 80s sci-fi hyperrealism.", + "A night scene of a lavender field with a town and church in the background, reminiscent of Vincent van Gogh's style.", + "A heron silhouetted against a beautiful sunrise, created by Greg Rutkowski.", + "A watercolor painting of a psychedelic angel with intricate details, surrealism, vivid colors, and ornate patterns by Abhishek Singh.", + "A painting of a firefall cascading over a high cliff.", + "An androgynous God of the Stars smiles down upon the viewer from a dramatic angle in a professional illustration.", + "An under-painting of a spontaneous, unfinished romantic portrait with beautiful brush strokes in black and red, by Richard Schmid and Sargent, with an expressionist style.", + "The image depicts a post-apocalyptic landscape dubbed \"The End of the World\" by artist John Howe.", + "A painting by Picasso displayed next to a Logitech MX Master 3.", + "A head-on portrait of a Korean woman painted as a D&D wizard, wearing medieval robes, with intricate details and an elegant style resembling Artgerm, Anna Podedworna and Alex Ross.", + "A profile portrait in Peruvian realist style, featuring cerulean blue, cadmium red, and zinc white colors with intense key lighting and expressive shadows.", + "The image depicts a man looking at his reflection in a limited neutral palette with a painterly 
design.", + "A blond person wearing a suit, medical gloves and a skull face mask is shown in a frontal portrait by Kim Kyoung Hwan.", + "A closeup image of a tin toy retro rocket spaceship, captured in symmetrical fashion by several renowned artists.", + "A serene landscape depicting a garden of Eden with lake reflections, fruit trees, and animals, captured in vivid and psychedelic style.", + "An English woman plays the lute, depicted in great detail by William-Adolphe Bouguereau during golden hour.", + "A portrait of an elemental entity with strong rim lighting and intricate details, painted digitally by Alvaro Castagnet, Peter Mohrbacher, and Dan Mumford.", + "An image portraying Barry Lyndon.", + "A Psytrance artwork featuring Kazuya from Tekken, created by Sam Spratt.", + "A frog baby sits in a searose cup in a humorous illustration by Esao Andrews and M. W. Kaluta.", + "An oil painting by Frank Frazetta depicting a humanoid spawn of dragons with scales, big eyes, and a menacing appearance.", + "A beautiful mixed media portrait painting with piercing gaze, dayglo pink and blue, and intricate ornamentation by Kimura, Kerstens, and Rockwell.", + "Medieval painting of a rat king.", + "A digital painting of Teemo from League of Legends, wearing cyborg parts and a new skin, in a fantasy MMORPG style.", + "A vintage-style vector illustration depicting space travel.", + "A painting of a Japanese repair shop with celestial ornaments and HR Giger architecture.", + "Bust portrait of a gothic goddess crying at dawn by various artists.", + "The image depicts a painting titled \"Buy Life Before Rain\" by Bouguerau with a dynamic and realistic style, featuring lightning and available on Artstation.", + "Portrait of a woman without eyebrows in dramatic lighting.", + "A cinematic fashion portrait of a Hindu goddess standing in a beautiful garden.", + "A landscape featuring mountains, a valley, sunset light, wildlife and a gorilla, reminiscent of Bob Ross's artwork.", + "An oil painting of a tall waterfall on a rocky cliff within a forest of trees with glowing blue leaves.", + "A painting of a cityscape with a red cloud shining light over a sea and bridge, by Greg Rutkowski and Thomas Kinkade.", + "An intricate artwork featuring a psychic physicist depicted as a mathematical genius with gothic, rich colors reminiscent of beksinski and Takato Yamamoto's style.", + "A tonalist painting featuring crataegus fruit goblins with expressive eyes and visible brushstrokes.", + "A realistic painting of a bifurcated astronaut suit with a clear brain case and camera appendage stalks, covered in diamond and iridescent fractal bubble materials, in a jumping float pose.", + "The image features an ancient Chinese landscape with a mountain, waterfalls, willow trees, and arch bridges set against a blue background.", + "Realistic image of the retro Zombies Ate My Neighbors SNES game with rich, deep colors reminiscent of a Beksinski painting.", + "An intricate and captivating artwork of spirals created by Tom Haugomat, Serena Malyon, Maxim Shirkov, Alex Pogrebniak, and Robin Gundersen.", + "The image is an artwork created by Ida Rentoul Outhwaite.", + "A mechanical wren robot bird is perched on the shoulder of a Latin monk woman in an oil on canvas painting.", + "An oil painting of Audrey Hepburn portraying Cersei Lannister from Game of Thrones.", + "A portrait painting of a black South Indian woman wearing a sari with intricate details and an eerie sense of horror, created in ultra-realistic style by artgerm, Greg 
Rutkowski, and Alphonse Mucha.", + "Soft airbrush illustration of a water drop on a white background by Pater Sato.", + "The image features a collection of artwork by various artists including Les Edwards, Zdzislaw Beksinski, and John Harris.", + "An ominous oil painting of a pale alien cultist with large fish eyes, high forehead, and smooth, waxy skin.", + "Oil painting of fashion model with face tattoo in classicism style mixed with 80s Japanese sci-fi art.", + "Soft airbrush illustration of a female eye with eyeliner and long lashes on a white background, inspired by 80's airbrush art by Pater Sato.", + "a blonde girl with a red dress and red shoes, holding a small blue umbrella and standing in the rain.", + "A head-on painted portrait of Maya Ali as a D&D wizard in medieval robes, with intricate detailing and an elegant, fantasy style.", + "A full-length portrait of a woman in a navy blue gown with gold embroidery, standing in a park setting.", + "The image depicts a blond female space explorer with tribal tattoos, slicked-back hair, narrow eyes, wearing an orange safety vest and atompunk jumpsuit in a realistic, detailed oil painting.", + "A painting depicting a wuxia scene in the winter, lit by neon lights.", + "Psytrance artwork by Gottfried Helnwein.", + "An oil painting of a vintage rally car, including a yellow Porsche with smoke and dirt from drifting.", + "A detailed front view portrait of a woman with ornate growing around, including flowers and a skull.", + "The image is of a stylized Overwatch building in watercolor gouache, featuring interesting shapes and forms, located in a desolate landscape with a food stall in an Asian-style alleyway.", + "The image features artistic styles of Jendral Sudirman, Utamaro Kitagawa, and Raden Saleh.", + "Brian Froud's symmetrical and detailed artwork features artistic depictions of Pisces zodiac signs in a mystic style.", + "A piece of art displaying sorrow in San Francisco, created by Cinta Arribas with color by Sonia Alins and line by Benjamin Flouw.", + "A pin-up girl drawing with psychedelic elements inspired by Picasso's style.", + "A woman in black robes stands in a shopping mall, depicted in art by Edward Hopper, Zdislav Beksinski, and Wayne Barlowe on ArtStation.", + "The image depicts a biopunk-style underwater city inhabited by octopus and squid-like creatures, with an anglerfish and U-boat mechanism, reminiscent of HP Lovecraft's Great Old Ones, in a dark and detailed Dutch-style oil painting by Rembrandt van Rijn.", + "A simple black and white ink drawing of the word \"sunyata\" written in a flowing script, surrounded by small dots and swirls, on a white background.", + "A closeup portrait of a gray owl with spreaded wings attacking in cinematic lighting, digital painting by Greg Rutkowski used as album cover art on Artstation.", + "An oil painting of an anthropomorphic fox overlooking a village in the moor.", + "A raccoon in formal attire, carrying a bag and cane, depicted in a Rembrandt-style oil painting.", + "A serene meadow with a tree, river, bridge, and mountains in the background under a slightly overcast sunrise sky.", + "Psychedelic art featuring abstract shapes and bright colors, created by John Berkey.", + "A digital painting of a hairless, inside-out cat with intricate details and a horror theme.", + "Close-up hyperrealistic oil painting portrait of a nun fashion model looking up against a black background, with classicism and 80s sci-fi Japanese book art influences.", + "A zentangle pizza illustration with 
colorful ink.", + "A digital painting of a favela city shrouded in mystical colors with radiant god rays and vibrant hues in the style of multiple artists.", + "Portrait of a Victorian politician in a suit, sitting down, painted by Thomas Lawrence with a highly detailed face.", + "Abstract yin yang representation by Ivan Bilibin.", + "A portrait of a grey alien aristocrat in the style of a classic Dutch painting, depicted in oil on canvas by Rembrandt van Rijn.", + "Realistic portrait painting of an astronaut suit with a 3D fractal lace design and iridescent bubble texture.", + "A digital painting of a knight sitting by a campfire in a dark forest.", + "A painting of Iron Hans by the Brothers Grimm.", + "Caravaggio's painting depicts a large fish.", + "Portrait of Buddha on stylized background.", + "A painting of a vampire woman wearing a red silk dress and crown jewels by Michelangelo Merisi da Caravaggio.", + "A portrait painting of Dwayne \"The Rock\" Johnson as a velociraptor dinosaur in Miami.", + "A painting by Gediminas Pranckevicius of a person standing in the rain amidst trees.", + "A colorful, mysterious painting of a fungus-filled castle room viewed from the floor with warm sunset light.", + "A sunflower standing alone in a field on a rainy day in a symbolic style, by rj monet.", + "A painting of a beautiful princess holding a large cup of coffee, with dark hair, blue eyes, wearing a black dress, dark eye shadow, and red lips.", + "A portrait painting of Yondu Udonta in an asymmetrical profile shot, incorporating bold shapes and hard edges with a stylized street art aesthetic.", + "\"The Joy\" by Shaun Tan features a colorful and detailed artistic record jacket design.", + "A Yohitaka Amano painting of a young lion beastman with white mane, wearing complex fantasy clothing and huge paws, depicted in a medieval market on a windy day.", + "Sunset over a mountain with misty waterfall cascades and god rays, creating an otherworldly, mattepainting-like atmosphere.", + "Digital painting of a furry deer character on FurAffinity.", + "A Filipino woman gazes out at a stunning sunset in a highly-detailed digital painting.", + "Medium shot of a character in Boris Vallejo's style.", + "A head-on centered symmetrical painted portrait of an Indian woman as a D&D Mage with black hair and intricate fantasy details.", + "Close-up of a face in agony, pulled by hands, within a frame on a tiled wall.", + "A digital painting of a young goddess with flower and fruit adornments evoking symbolic metaphors.", + "Digital painting of a sinister middle-aged witch with red hair, fair skin, and a symmetrical face.", + "The image is a skull created by artist Karl Gerstner.", + "Heroine portrait with a mysterious and old-fashioned style.", + "A portrait of Rosalia painted by various artists.", + "A vintage portrait of a stunningly beautiful Asian tribal female.", + "A tonalist painting of a bipedal pony creature soldier.", + "A gouache illustration of a girl in school uniform standing on a tall building roof.", + "A detailed painting of a futuristic spaceship with ornamental features.", + "An organic cyborg with two alien skeletons and an alien centipede on a planet face, in the style of Zdzislaw Beksinski surrealism, is depicted in a poster art album cover with contrasting two skulls and deep-looking eyeballs against a simple background.", + "An astronaut floats amidst planets against a cosmic backdrop in a highly detailed, refreshing digital painting by James Jean.", + "Spontaneous and unfinished romantic 
portrait painting featuring beautiful brush strokes in a realist style by Richard Schmid and Sargent, trending on CGsociety in red.", + "Lindsey Pelas depicted as a spy in a detailed digital painting by Artgerm, Greg Rutkowski, and Alphonse Mucha.", + "A painting depicting a scenic view of Guangzhou, China as a tourist destination by David Inshaw.", + "A surrealist painting styled like Rene Magritte depicting a young couple in Art Deco fashion escaping their past.", + "Japanese Samurai character portrait art by Donato Giancola and Craig Mullins.", + "A painted portrait of Zeus, god of thunder, with white hair and a muscular, hairy upper body, wearing a flowy robe, created by Gaston Bussiere and Alphonse Mucha.", + "A painting featuring a dog by artist Koyamori.", + "Artwork depicting a futuristic car, created by Ed Roth.", + "A painting depicting a Buddhist deity with a comet in the night sky and a deer, made using various materials including lapis lazuli, malachite, cinnabar, and gold, by Piero della Francesca, Balthus, and Agnes Pelton.", + "Medieval landscape of a Kievan Rus army with Tzars and flags.", + "Draw a shape with two vertical and one horizontal line of equal thickness in the middle.", + "Ivan Bilibin's artwork depicts the two opposing forces that constitute the essence of all life.", + "A masterpiece of art depicting the weather cycle by Gerald Brom and Zdzis\u0142aw Beksin\u0301ski.", + "\"Goddess of Flowers\" by Alphonse Mucha, depicting a woman surrounded by floral patterns.", + "A fantasy matte painting of Link from Zelda in a forest during autumn with sun rays and dust.", + "Middle aged man in a suit and waistcoat holding a cane, depicted in highly detailed digital art painting by Greg Rutkowski.", + "A landscape featuring a unique digital painting-style building.", + "The Mona Lisa depicted as Marilyn Monroe.", + "The image depicts Zeus, god of thunder, with long white hair wielding lightning in a fantasy-themed digital painting.", + "A digital painting of a warrior with a crocodile face in a heroic pose, viewed from the side, by Ross Tran.", + "A surreal portrait of a young Spanish man wearing sock and titled \"Super Spy Captain\" with deep purple hair and green eyes on an orange background.", + "A high detail portrait of a royal mansion by Michelangelo Merisi da Caravaggio.", + "A sunset panorama showing a graveyard of souls, with backlight and painted by Frazetta.", + "A dilapidated cardboard fort hidden in a misty and overgrown witchwood swamp, depicted in a detailed and intricate ink illustration.", + "An abstract painting with various shapes and colors.", + "A cosmonaut otter poses for a portrait painted in intricate detail by Rembrandt.", + "A surreal cat with a smile and intricate details.", + "A child's colorful drawing of a smiling cat outside a house.", + "Louis Wain's depiction of the two complementary forces of life.", + "A painting of a girl standing on a mountain looking out at an approaching storm over the ocean, with wind blowing and ocean mist, surrounded by lightning.", + "Portrait of Herzl as a florist.", + "A detailed painting of Atlantis by multiple artists, featuring intricate detailing and vibrant colors.", + "English woman playing the lute, depicted by William-Adolphe Bouguereau.", + "The image depicts two lovers parting ways inside a broken wooden house.", + "A man on a boat crosses a body of water in Hell with creatures in the water, depicted in the painting \"Sea of Souls\" by James Gurney.", + "Animals fashioned from gems, colorful and 
shapely, depicted in natural lighting, with a slight effervescence, artist credited to Alex Ross.", + "A painting by Jacques Louis David featuring a whiskey bottle, cigar, video camera, and glass orb.", + "\"A racoon wearing a suit smoking a cigar in the style of James Gurney.\"", + "Portrait of Neo from Matrix, featuring several artists.", + "African man painting by Jean-Michel Basquiat.", + "Man in boat crossing a body of water with creatures in a surreal underworld, artwork by Alex Grey.", + "Head-on centered symmetrical painted portrait of Kareena Kapoor Khan as a D&D Mage wearing intricate fantasy robes.", + "A graphic poster featuring an avocado and raspberry observing a burning world, inspired by old botanical illustrations, Matisse, Caravaggio, Basquiat, and Japanese art.", + "A colorful painting of a female cyberpunk sorceress in the clouds.", + "A magazine collage portrait of a depressed girl made by an art student.", + "A painting of Kermit the Frog as a Catholic pope by Michelangelo Merisi da Caravaggio.", + "Portrait of Michael Jordan in intricate digital painting with smooth details by Artgerm, Greg Rutkowski, Alphonse Mucha, and William-Adolphe Bouguereau.", + "A man is depicted in full body in an artwork by Leyendecker.", + "A cobblestone street with a tree over the sea at sunset, illuminated by sun rays, in a colorful illustration by Peter Chan on Artstation.", + "A painting depicting a black woman taking a selfie in Wal-Mart while being followed by a man.", + "An oil painting by Frank Frazetta depicts a humanoid creature with scales, big menacing eyes, and dragon-like features bred with humans in a fantasy setting.", + "A bird god swings a gold metal weapon while in combat ghostly fighting pose in a traditional Chinese myth.", + "A hyperrealistic image of diamond and gems created by Alex Grey.", + "An artwork depicting a Dark Souls boss by Paul Gustave Dore and Ivan Aivazovsky.", + "The image depicts Moses debating with God, done in a realistic style by Tafy Laplanche and colored by Hiroshi Nagai.", + "A detailed male angel with white hair and wings is flying in black smoke.", + "A Victorian woman quietly sings by a lake at night surrounded by fireflies, moon, and stars, painted by Vincent van Gogh and Jacques-Louis David.", + "A painting by Rembrandt depicts a video game tournament.", + "A highly ornate, detailed ink illustration of a naraka Buddhist demon Korean female with symmetrical long head and intricate details.", + "The image is an artwork created by Frederic Edwin Church.", + "A portrait of Licorice Vampire by Alessandro Allori.", + "A painting of enlightenment by Salvador Dali.", + "A lantern floats in a dark river at night in an artwork by Thomas Kinkade.", + "A detailed watercolor illustration of rabbits.", + "A 3D rendering of angels and demons fighting at the entrance to a fractal palace in Bouguereau's painting.", + "A cute young demon princess in a forest, depicted in digital painting.", + "A pen illustration of a man wrestling his phone by Gustave Dor\u00e9 with crosshatching and pops of colorful Ben Day dots.", + "A head-on centered symmetrical painted portrait of Katrina Kaif as a D&D Mage wearing a hood and intricate fantasy robes.", + "A digital painting by James Jean depicting a goddess in a strong pose surrounded by planets in a hyper-realistic style.", + "A birch forest in autumn with falling leaves that resemble flying butterflies and dancing elves.", + "A large wave is about to crash down on three small boats filled with terrified people.", + "A 
painting of illumination by Salvador Dali.", + "Black metal cover art with eerie and sinister feel, featuring no text or letters.", + "A Maori art inspired magic turtle.", + "Portrait of a dark mystical woman, art by Artgerm, Greg Rutkowski, Alphonse Mucha, and William-Adolphe Bouguereau.", + "Colorful psychedelic paint twists.", + "A colorful vintage horror illustration of a mad scientist's poisonous laboratory, featuring machines, dials, knobs, levers, and scientists in spooky lighting.", + "A Taiwanese princess wearing a sundress and jewelry.", + "A fine art drawing of a machine that offers a painful trip into a shattered dimension and the psyche of a squid.", + "The image is a digital painting of a woman in a white and gold gown with wing motifs, following the Art Nouveau style.", + "A retro-style vector art illustration of space travel, resembling a vintage poster.", + "A painting of a landscape by Thomas Kinkade.", + "A painting of day lilies with photorealistic details.", + "Oil painting of a werewolf mid-transformation in Sanjulian's style.", + "Shohreh Aghdashloo portrayed as an Iranian woman in a refined and elaborate digital painting.", + "The image depicts Marvel's Thor, the God of Lightning, in artwork created by Nicholas Roerich.", + "Autumn birch forest with falling leaves resembling flying butterflies and dancing elves.", + "An Overwatch building with interesting shapes and a food stall, depicted in watercolor gouache paintings by Simon Stalenhag.", + "A half gold half marble statue of a beautiful woman and a skull in a renaissance style.", + "The image features a purple watercolor painting of a gown with floral accents and reflections in the water below.", + "Tsunade from Naruto in a white shirt depicted in highly detailed digital painting by artgerm, greg rutkowski, and alphonse mucha.", + "A character portrait of a random character in a smooth and detailed digital painting style, inspired by Metal Gear and various artists including Ruan Jia, Mandy Jurgens, William-Adolphe Bouguereau, and Artgerm.", + "A head-on painted portrait of a Korean woman dressed as a D&D wizard in medieval robes, with intricate and elegant details in a fantasy style.", + "A forest scene in the morning light created by Chiho Aoshima.", + "A fantasy female warrior portrait featuring a beautiful face with shining eyes, crystals, and plants, in a realistic oil painting style with dramatic and cinematic lighting.", + "A beautiful, ultra realistic cyborg figure strikes dramatic poses in a post-apocalyptic, cyberpunk Tokyo, in an intricate and highly detailed sci-fi digital painting.", + "A portrait of a woman with cat-eye glasses and a slight smile, resembling a mix of Asa Butterfield and Pam Beesly, inspired by Gustav Klimt.", + "A woman and her dog sit on a tree and watch the sunset in a digital painting by artgerm, greg rutkowski, and Alphonse Mucha.", + "A head-on centered symmetrical painted portrait of Mahira Khan as a D&D wizard wearing intricate fantasy robes.", + "A digital painting of a young pirate with sharp features and a piercing gaze.", + "A theatre access corridor with 3 doors and vibrant, impressionistic colors in a fisheye lens view.", + "A unique painting.", + "A digital painting of a biology sea monster encountered by a sailing ship in the deep and dark sea.", + "A painting by Rembrandt depicting a dark, stormy night at sea with large waves and a small pneumatic boat carrying illegal immigrants.", + "An assemblage by Boris Vallejo and Squeak Carnwath depicts a blooming orchard in a 
stormy Indonesian landscape.", + "A portrait painting of a red-haired, smiling woman in a green dress against a golden background with intricate patterns.", + "Portrait of a shaman with intense emotions.", + "A painting featuring beefy men as professors in a classroom by Gaston Bussiere, Craig Mullins, Greg Rutkowski, and Alphonse Mucha.", + "The image depicts colorful liquid ink and oil bubbles creating a harmonious marble abstraction.", + "A digital painting of Venus goddess in sci-fi armor with a style inspired by Sandro Botticelli.", + "The image is of Roman ruins featuring silver and gold artifacts, depicted in hyper-detailed art style by artists Greg Rutkowski and Gustave Dore, and has been shared on various online platforms including Artstation, Worth1000.com, CGSociety, and DeviantArt.", + "Portrait painting of Jack Elam, a rugged cowboy gunfighter, in a fantasy costume against a red background.", + "An oil painting of Aquaman doing his taxes in a dark room with a lonely candle illuminating the room.", + "A landscape with an art nouveau building.", + "A neoclassic painting of a box of radiation featured on ArtStation.", + "A Native American woman is wearing an elaborate and intricate headdress in a digital painting.", + "A yellow noir wired neon robot Kerberos without memory or feelings, portrayed as a god in an oil on canvas painting by Dave McKean and Esao Andrews.", + "Psytrance artwork by Ryohei Hase.", + "A painting depicting a person feeling sad and depressed during rainy weather.", + "An abstract wallpaper with a portrait design.", + "A painted portrait of Persephone in ancient Greece with intricate detail, iridescent coloring, and golden hour lighting.", + "Chrome spheres on a chromatic cube by Ayami Kojima and John Jude Palencar.", + "Portrait of Jason Isaacs as a florist with a long shot perspective, inspired by Van Gogh.", + "A flock of red balloons flying up against a light blue palette.", + "The image is a drawing of a skeletal, frail figure driving a chariot pulled by two skeletal animals.", + "A master gouache painting of ships docked at the harbor by Claude Monet.", + "A digital painting of a cat sitting on a bench watching a black hole in the sky.", + "A surreal image that looks like a dream.", + "A digital painting of an elegant elemental entity with vivid colors and strong rim lighting.", + "A painting of a school building done in the style of Vincent van Gogh, with a prison van parked outside.", + "A colorful and stylized portrayal of Paddington Bear inspired by the art of Keith Haring.", + "A symmetrical painted portrait of Elisha Cuthbert wearing steel armour as a templar with blonde hair and intricate detailing, in a dramatic lighting style reminiscent of Artgerm and Anna Podedworna.", + "image of a building in downtown New York designed by M.C. Escher.", + "The image is a painting of a building, with a red cloud and a tumultuous sea in the background, created by Greg Rutkowski and Thomas Kinkade.", + "A digital painting of a mockingbird on a branch by artists Jacqueline E, Tafy, and Bo Feng.", + "The image depicts the god dreaming at the end of time.", + "A dark and ominous house with looming lightning in the background.", + "Classical romantic painting of Hatsune Miku with blue hair.", + "Matte painting of a wizard's study room with dynamic centralized perspective.", + "A cubist painting by Gaston Bussiere, Craig Mullins, and J.C. 
Leyendecker.", + "Studio portraits of Innsmouth ocean-dwellers with Lovecraftian, mutant, fishmen features in a fine art, black and white style with a dark and eerie atmosphere.", + "A portrait of Clint Eastwood done by Bill Sienkiewicz.", + "A sandstorm with vibrant colors that create a psychedelic effect.", + "Galactus devouring planet earth, depicted in an artwork by Francisco Goya.", + "A fractal cyborg with biomechanical parts emitting smoke from its face, depicted by Alan Bean.", + "Airbrush drawing of the Joker in hyper-detailed style reminiscent of Greg Rutowski.", + "A man with a gray and green mohawk wearing a brown tank top and headset in a portrait by Martin Ansin.", + "The image is a painting of a wood bridge with an Atlantis Zeus statue and a Grec temple in the background, adorned with ivy plants and multicolor roses.", + "A digital art painting by Greg Rutkowski of a detailed and cinematic portrait of a male elven man in a black cloak with a fantasy vibe.", + "A painting of a vampire woman wearing a red royal dress and jewelry, with sharp fangs.", + "A stunning painting of a red-clouded morning sky over a city and bridge, reflecting on a sea below.", + "Portrait of Hulk Hogan with a city background, painted in detail and with epic lighting, by Bouguereau.", + "A disheveled owl is perched on a pine tree.", + "A woman in transparent robes standing in a shopping mall, as depicted in oil paintings by Hopper, Beksinski, and Barlowe.", + "The image is a painting of Michael Scott sitting on his desk by artists JC Leyendecker and Phil Hale, with angular brush strokes and a vintage, painterly feel.", + "The image depicts a laughing Tiefling Pirate with purple skin and intricate, highly detailed jewelry, with beautiful eyes and an elegant appearance, in a fantasy setting, as a digital painting by Douglas Shuler.", + "A Latin woman in a red riding hood costume has a mechanical wren-bird-robot on her shoulder in an oil on canvas portrait by Yoji Shinkawa and Dave McKean.", + "A painting of a Native American warrior woman with blue eyes and silver armor by Jon Foster.", + "A portrait art of a woman with red hair, without a face.", + "An artistic depiction of ghosts and paranormal.", + "A psychedelic painting of a fantasy space whale.", + "A person is hugging a large white animal in a detailed fantasy painting by Krenz Cushart, a Pixiv contest winner.", + "The image shows two circular figures, one black and one white, overlapping to represent the concept of complementary forces in life.", + "There is a huge cat beside a small house, amidst wheat field harvesting, with a large tree, under a blue sky, in a Simon Stalenhag matte painting on Art Station.", + "Oil painting of a nebula creating planets in the style of Caravaggio.", + "Portrait of an alien aristocrat in XVI century Dutch clothing, painted in oil on canvas by Rembrandt van Rijn.", + "\"Fine art exhibit in a white cube by Marcel Broodthaers.\"", + "Mona Lisa with added mustaches, works of art.", + "Psytrance artwork by Lisa Frank.", + "The image showcases a collection of stylized candy designs and RPG assets created by Takeshi Murakami.", + "Oil painting of a Victorian waif under a street lamp in a dark alley.", + "A death metal album cover by Eliran Kantor.", + "Sherlock Holmes and Watson replace the typical figures in Grant Wood's American Gothic painting.", + "Hand-drawn portrait sketch featuring Wyatt Earp, Doc Holliday, Frank Sinatra, and Brad Pitt with pencil shadows and high detail, currently trending on ArtStation.", + "A 
peaceful, nature-filled landscape with vibrant flowers and trees and a serene cloud-filled sky.", + "An Overwatch building with interesting shapes and forms, including a food stall, depicted in a detailed watercolor gouache painting by artist Simon St\u00e5lenhag.", + "An astronaut with a Chinese dragon head wears armor and a helmet, in a piece by Bouguereau.", + "A broken videogame console with a colorful and compelling painting.", + "A human skeleton in a suit is standing in a colorful meadow of flowers, depicted in a detailed painting by Ren\u00e9 Magritte.", + "A digital painting of an evil geisha in a bar.", + "A Sasquatch stands near a Native American totem pole against a scenic mountain backdrop in a digital painting.", + "A Chinese wuxia walks on stone steps towards a stone gate leading to a dark cave in a beautiful landscape painting featuring a temple, a turbulent lake, waterfall, fog, and a single rainbow.", + "Two cosplay girls with black hair, fully tattooed bodies, wearing fishnet corsets and holding whips, with symmetrical, detailed faces and painted by Tom Bagshaw.", + "The image depicts two complementary forces that make up all aspects and phenomena of life, by James Gurney.", + "A oil portrait of a young female scientist holding a small glowing black hole in a laboratory.", + "A painting of a castle room with colorful fungi, stone walls and floor visible, and a sorcerer in blue robes floating in an orb of blue light, by Greg Hildebrandt and Tim Hildebrandt with a mysterious and mystical mood.", + "A solar eclipse is depicted over a field of grass and flowers with a small forest in the distance, as a matte painting on Art Station by Simon Stalenhag.", + "The image is a piece of art created by Livia Prima and is described as wonderful and beautiful.", + "A digital painting of a cyberpunk woman in 80s fashion and high heels, with intricate details, created by artgerm, greg rutkowski, and alphonse mucha.", + "An English woman plays the lute with a slender neck and long dark hair in a painting by William-Adolphe Bouguereau.", + "An eerie, colorful film noir scene depicts a mysterious Tel Aviv street at dusk.", + "A beautiful girl posing dramatically, with stunning eyes and features, by Davinci on Pixiv.", + "A fantasy painting of a male warrior with two swords defending a castle wall against an approaching army.", + "The image depicts bicycles parked in a park, with an intricate and elegant neo-rococo expressionist style and a touch of orientalism, created through a digital photorealistic painting technique.", + "Man crosses hellish water with creatures and souls, by Zeng Fanzhi.", + "An oil painting of a pet rat as an English professor lecturing in a university classroom.", + "The image depicts a family portrait of a revolutionary leader in Chinese contemporary art by Zhang Xiaogang.", + "A painting by Jacques Louis David featuring a whiskey bottle, a cigar, a Panasonic video camera, and a glass orb.", + "A detailed digital painting of an ancient overgrown statue in a clearing, with vibrant colors and mystical lighting.", + "A doll's house featuring the symbolist theme of icosahedrons and painted in the styles of Charles Gleyre and Johannes Vermeer.", + "Male playing piano with audience surrounded by paintings by Gaston Bussiere, Craig Mullins, and J.C. 
Leyendecker.", + "Description, Techno artwork by Rob Hefferan.", + "A hand with lavender nail polish is depicted against a green wallpaper background by a famous realist painter from the 19th century.", + "A kirigami building surrounded by a jungle, featuring dichromatism and volumetric light, with intricate details, created by Remedios Varo Uranga.", + "A deep retro scifi cave with dramatic cinematic composition and beautiful lighting, portrayed in a desaturated and psychedelic oil on canvas masterpiece by artists Tim Hildebrandt, Wayne Barlowe, Bruce Pennington, Donato Giancola, and Larry Elmore, trending on Artstation and featured on Pixiv.", + "Two girls holding hands watching the world burn with fire in an old botanical illustration style.", + "A hyperrealist portrait of Jaina Proudmoore on a colorful planet.", + "Dragon with the body of a dragon and the head of Macho Man Randy Savage by Jeff Easley.", + "A digital painting of The Last Supper with cats as the characters.", + "A woman lies down surrounded by piles of paintings by Greg Rutkowski.", + "A kirigami building in a jungle, with volumetric light, designed with intricate and ornate details by Remedios Varo Uranga.", + "Person-robot hybrid painting, featuring action lines and visible brushstrokes, by artist Phil Hale.", + "a digital painting of an elegant elemental entity with vivid colors and intricate details.", + "A knight in black armour holds a detailed silver sword in front of a castle.", + "Bob Ross painting Mario on an easel in his office.", + "Aoshima's masterpiece depicts a forest illuminated by morning light.", + "A green landscape by Makoto Shinkai.", + "A painting of an epic cinematic scene from Ramayan by Beeple.", + "A dramatic sea shell artwork by Alex Grey.", + "A gouache of a giantess in school uniform standing in a city, with an anime style and created by various artists including Ilya Kuvshinov and Magali Villeneuve.", + "Leo zodiac artwork in mystic style by Brian Froud.", + "A modern painting depicting the merging of technology and nature in hyper-realistic detail.", + "A portrait of a demonic cult leader priestess wearing a crown of blood and snickerdoodles on her calves, painted with vibrant shades of gray and black brushstrokes.", + "Portrait of a person with Cthulhu features, painted by Bouguereau.", + "Description, An artistic rendering of a cosmic portal with a beach at dusk on the other side.", + "The image depicts the art style of Masao Saito.", + "A painted portrait of a blonde elf ranger in profile, with a beautifully backlit and swirly vibrant color scheme.", + "A painting of a fungus-filled castle room with colorful fungi overgrowth and glowing spores, lit by warm sunset light, by Greg and Tim Hildebrandt with a mysterious, mystical mood and highly detailed imagery.", + "Portrait of Jason Isaacs as a florist, styled similarly to Van Gogh's paintings, shown from a distance.", + "Techno artwork by Ed Roth.", + "Someone attempts to paint an artificial artwork using a new tool.", + "A portrait painting of a muscular bloodied Indian woman, wearing a sari and jewellery, with a lower back tattoo, seen in side profile and high detail.", + "A portrait painting of a muscular Indian woman with a bloodied lower back, wearing a sari and jewelry and featuring intricate details.", + "A cat sits next to a small house with a red roof and a big tree, with a blue sky in the background, in a painting by Simon Stalenhag.", + "A garden filled with colorful flowers and plants, as painted by Robert Venosa.", + "An 
illustration by Jim Woodring depicting a Frank mass grave.", + "The image is a watercolor painting of a futuristic Overwatch building with a food stall, interesting shapes and forms, and megastructures.", + "\"Pam Beesly is depicted as a surprised time traveller in Edmund Blair Leighton's painting.\"", + "A gothic black panther is depicted in a highly detailed, anatomically correct manner, with dramatic lighting, on an oil canvas.", + "Oil painting of a beautiful female cyborg with wire hair, golden details, and opal crystals, standing amidst plants in a mystical forest.", + "A bamboo artwork in the style of Hiroaki Tsutsumi.", + "A brownstone building located in a forest setting, painted by Eytan Zana.", + "A philosopher king sits on his throne lost in contemplation, depicted in a highly detailed DnD portrait by Raffaello Ossola and Ross Tran.", + "Painting of a satellite station with floral ornaments.", + "Odin fights Fenrir with a spear in a detailed and realistic painting by Andreas Rocha.", + "A coffee cup with Hundertwasser's design.", + "A head-on centered symmetrical painted portrait of Katrina Kaif as a D&D Purple Mage wearing intricate fantasy robes.", + "A masterpiece.", + "A queen with red hair and a green and black dress stands veiled in a highly detailed and elegant digital painting.", + "An anime portrait painting of an attractive Asian schoolgirl with her sugar glider by Gaston Bussiere.", + "Realistic image of cute fairies by multiple artists, featuring rich and deep colors.", + "A close-up oil painting of a fashion model looking at a melting cyborg face, dressed in black robe, with classicism and 80s sci-fi hyperrealism style.", + "The image depicts two opposing yin-yang symbols.", + "Oil painting of Tracer from the game Overwatch standing in a grassy field with a peaceful atmosphere, smiling and surrounded by light rays.", + "Life and death are depicted in a single image.", + "Portrait of Seiko Matsuda in the 80's by Sergey Kolesov on Art Station.", + "Painting of Durdle Door in Starry Night style.", + "Cross section of an apple in a limited neutral palette with a beautiful graphic design and a painterly style.", + "A silver surfer is depicted floating in space in an artwork by Edward Hopper.", + "A girl wearing a Pikachu hoodie holds a Nintendo Switch in a digital painting by artgerm, Greg Rutkowski, and Alphonse Mucha on Artstation.", + "Portrait of Jokowi by Basuki Abdullah and Raden Saleh with a Banksy and Kentaro Miura influence, currently popular on ArtStation.", + "Medieval traditional pattern.", + "Greg Manchess painted a medium shot portrait of Harley Quinn in armor as an Overwatch character with bold shapes and hard edges, resembling street art, and is trending on ArtStation.", + "An attractive male stands in winter against a neon light backdrop with elements of wuxia in a painting by Bussiere, Mullins, and Leyendecker.", + "A hyperrealist portrait of a fairy girl emperor wearing a crown and long starry robes.", + "Psytrance artwork by Andre Francois.", + "A painting by Frazetta depicting a lion as a barbarian hunter, with textured details, a cyan graveyard and a dramatic moonlit sky.", + "A digital painting of a cartoon shop environment surrounded by five fantasy environments, with a fat brush concept sketch by artist BD Enki Bilal.", + "Edelgard from Fire Emblem depicted in Artgerm's style.", + "Monalisa painting a portrait of Leonardo Da Vinci.", + "A beautiful girl stands in front of a dark wallpaper in a painting by various artists.", + "Rick and Morty 
depicted in a vintage tintype style.", + "An image depicting the concept of yin and yang.", + "A female archer elf leads a group of adventurers through a forest of crystal trees in a fantasy matte painting.", + "A soldier in orange armor with a mask holding a sniper rifle, illustrated by Frazetta and Mohrbacher.", + "A psychedelic explorer inside an ancient temple surrounded by hauntingly surreal paintings.", + "A detailed painting of an octopod astronaut in a jumping float pose holding a plasma spear with iridescent and diamond materials.", + "A digital painting of a ninja gaiden girl in an armored dieselpunk wardrobe at snowy fuji mountain moonlight.", + "A group of winged fairies playing cards on a table in a moonlit forest by a pond filled with water lilies, artwork by Ida Rentoul Outhwaite.", + "A forest painted in gouache with a morning light.", + "An illustration of an angry cat wearing a chef hat baking cookies in the style of ukiyo-e.", + "A portrait painting of Priscilla from Claymore with intricate details and an eerie feel, created by Artgerm, Greg Rutkowski, and Alphonse Mucha.", + "Portrait of a beautiful redhead archer in high fantasy style.", + "An oil painting by Rembrandt depicting a muscular cat wielding a weapon with dramatic clouds in the background.", + "The image features a purple flower with a reflective surface, surrounded by mystic and glacial elements.", + "An oil painting of a gothic horse.", + "A bald general with an angry expression in an intricately detailed and elegant digital painting.", + "The image depicts a broken heart split in half, with one side tilting upwards resembling a waving tail-coat.", + "A fluffy owl sits atop a stack of antique books in a detailed and moody illustration.", + "A digital painting of a magical ritual location with volumetric lighting and elements from various artworks and games.", + "A depressed gingerbread man painted in a lowbrow pop surrealism style by Fernando Botero, Mark Ryden, and Hikari Shimoda in Candyland.", + "An oil painting of a duck by Vasiliy Rabchenko.", + "A digital painting of a furious woman with intricate details in a cyberpunk-inspired setting featuring neon lights and sweat drops.", + "A Genhis Khan death metal image.", + "A black canvas by Karl Gerstner.", + "Artistic depiction of ghosts and paranormal.", + "A painting of Henry Cavill by Yoji Shinkawa.", + "A painting by Jean-Michel Basquiat of the head and shoulders of a strong black African man.", + "A train crosses a trestle bridge in the mountains in an optimistic and vibrant illustration.", + "Golden cows in flight.", + "A painting featuring a woman wearing virtual reality glasses and a bird, created by Dave McKean and Ivan Shishkin.", + "Portrait of Beatrice Dalle by Jeremy Phil Hale and Casey Baugh, trending on Pinterest.", + "A portrait painting of a Red Borzoi Dog wearing a red beret as an Overwatch character.", + "The image depicts two opposite forces that interconnect and govern all aspects of life, created by Wes Anderson.", + "A portrait of a humanoid frog dressed as a wizard, holding magic trinkets painted by Craig Mullins.", + "Folk horror painting of dead pines with eerie and creepy atmosphere.", + "The image features a Cubist style and is showcased on various art platforms including Instagram, with influences from artists such as Frank Stella, Beeple, Giger, Kopera, and Zawadzki.", + "The image represents the feelings of isolation, struggle and search for connection of someone who has grown up in a dysfunctional and toxic environment, 
living in a society that doesn't seem to understand or care about their struggles.", + "A gouache illustration of a girl in school uniform standing on the edge of a tall building.", + "An abstract bubble.", + "Zeus, god of thunder, poses powerfully in the ocean wearing a white robe.", + "There is an image that represents the balance between yin and yang.", + "Redhead punk girl playing electric guitar in an oil painting masterpiece.", + "A cyberpunk-inspired digital painting featuring intricate and highly detailed city architecture.", + "A blind monk wearing an orange robe stares out the window of a spaceship in a dramatic lighting as depicted in a matte painting.", + "A painting of a cyberpunk skyscraper with floral ornaments by Andreas Achenbach.", + "A cottage designed by Salvador Dali is surrounded by blooming forest, with a nearby stream in spring.", + "Yoshitaka Amano's painting of a young lion beastman with a white mane, wearing complex fantasy clothing and huge paws, at a medieval market on a windy day.", + "An unfinished romantic portrait with beautiful brush strokes in black and red by Richard Schmid and Sargent.", + "Lord Shiva creating a hybrid elephant-human figure with intricate and vibrant detailing.", + "A watercolor painting of a frog on a lily pad.", + "The portrait depicts a female figure with a slight smile and cat-eye glasses, resembling a mix between Asa Butterfield and Pam Beesly, and is created in the style of Gustav Klimt.", + "A woman in a bathing suit captured in an ink drawing by Sam Bosma with outlined and stippled details.", + "A woman wearing a raincoat stands on a city street during nighttime rain in a digital painting with a film noir and Lynchian atmosphere.", + "A Drew Struzan painting depicting Saul Goodman, Mike Ehrmantraut, Kim Wexler, and Gustavo Fring in an 80s movie poster style with rim light.", + "Norman Rockwell's western painting displays a Native American on horseback atop a hill.", + "A cyberpunk giant robot depicted in oil on canvas by Simon St\u00e5lenhag and Umberto Boccioni.", + "A head-on symmetrical painted portrait of Elisha Cuthbert as a paladin in ornate iron armor, with a tarot card and stained glass in the background, in the art nouveau style.", + "The image is a painting by William-Adolphe Bouguereau of Alfric Overguard, a calm and strong black man with alert eyes and a wide nose.", + "Artwork in the style of the archangel.", + "A smiling sorceress holds a tabby cat while wearing a winged helmet in an artwork by William-Adolphe Bouguereau.", + "A portrait of a female alien Xenomorph Queen by HR Giger, Greg Rutkowski, Luis Royo, and Wayne Barlowe.", + "A portrait of an old woman holding a stack of green paper bags, with a white paper bag over her head and dressed in red paper bags, in highly detailed artwork by Edward Hopper, Zdislav Beksinski, and Wayne Barlowe.", + "Portrait of woman painted with colorful gouache impasto.", + "Illustration of a brunette girl crossing a small river in a forest, painted by Goro Fujita.", + "A woman in overalls with a German wow helmet, glasses, long red hair, and a forest background in a highly detailed portrait digital painting by artgerm, Greg Rutkowski, and Magali Villeneuve.", + "A group of fairies playing cards on a table in a moonlit forest next to a pond filled with water lilies, artwork by Ida Rentoul Outhwaite.", + "Castle Grayskull depicted in a sunset scene by Frazetta.", + "The image is a painting by Thomas Kinkade that is atmospheric and breathtaking.", + "Mixed media art of a flooded 
grave in Italy, with dirty water and trash floating inside, surrounded by a fence.", + "A depiction of human-like creatures with scales and dragon-like features in a menacing oil painting by Frank Frazetta.", + "Jerma 985 depicted in an art piece with volumetric lighting.", + "A painting of two people standing on a checkered floor, serving as an album cover on Tumblr.", + "An abstract representation of the yin yang concept by Paul Lehr.", + "An angel is depicted falling into Andromeda in a Renaissance-style painting alongside shots of other classical painters.", + "A solar eclipse is depicted in a field in Iceland with a lone tree swaying in the wind, created as a matte painting by artist Simon St\u00e5lenhag.", + "Pippi in Keith Haring style.", + "The image features a castle surrounded by a dreamy garden with roses and a cloudy sky in the background.", + "A dystopian future city with a nuclear fallout, created through a matte painting with Octane Render.", + "Oil painting portrait of a fashion model with a black scarf covering her face and rectangle shapes over her eyes, set against a dark background with elements of classicism and 80s Japanese sci-fi art.", + "Oil painting of coffee beans by Frida Kahlo, Van Gogh, Monet, Picasso, and Dali.", + "A highly detailed goddess portrait with a focus on the eyes.", + "A painting of the sun tarot card by Michelangelo Merisi da Caravaggio.", + "An astronaut in a highly detailed digital painting depicting the end of the universe, illustrated by James Jean.", + "A woman is depicted crying amidst yellow and blue clouds in a piece by Kim Keever.", + "A digital painting of a female warrior adorned in intricate armor costumes with light and shadow effects, created by artist Wlop and shared on Art Station.", + "Abstract zen design.", + "A landscape featuring a building in the style of Peter Mohrbacher.", + "Psytrance artwork by Jhonen Vasquez.", + "A painting of Kermit the Frog depicted as a Catholic pope by Michelangelo Merisi da Caravaggio.", + "Portrait of Marilyn Monroe as a queen.", + "A digital painting of six fantasy environments in an indoor setting, depicted in a cartoonish comic book style.", + "Over the shoulder view of Leonardo da Vinci painting Mona Lisa.", + "A Nordic queen wearing an ornate cloak and crown.", + "The image is a centered portrait of Mahira Khan as a D&D wizard wearing intricate and elegant fantasy robes, created as a highly detailed digital painting in the styles of Artgerm, Anna Podedworna, and Alex Ross.", + "A black cat sits under a crescent moon at night, with multiple artists credited for its creation.", + "Portrait of a princess wearing black tar by various artists.", + "Jean-Michel Basquiat's painting depicts a strong black African man's face in head and shoulders view.", + "A portrait of a crying gothic demon goddess by multiple well-known artists, depicted in oil on canvas.", + "A painting by Dimitra Milan featuring a woman posing with a tiger against a dreamy cloud-filled backdrop.", + "Oil painting of a Virginia opossum playing guitar in the style of Michael Whelan.", + "A painting of the Macy's Thanksgiving Day Parade.", + "A picturesque medieval hobbit home surrounded by lush forest features a bridge over a creek, a chimney emitting smoke, a waterfall in the background and flowers.", + "Image of \"The Demonology of Modern Politics\" by Jean Giraud depicting a haunting and dark interpretation of the political world.", + "A beautiful Arabian angel wearing a niqab and adorned with jewelry by various artists.", + "An 
extraterrestrial celebration of new life on an ancient post-apocalyptic planet featuring vivid and colorful creatures from the Jim Henson creature shop, depicted in a cinematic oil painting with highly detailed illustrations.", + "A gouache painting by Claude Monet of ships docked at the harbor.", + "The image depicts an artwork created in the style of Masao Saito.", + "A duck oil painting.", + "A painting of a toy tiger by Murakami.", + "An attractive male in a wuxia setting illuminated by neon lights, depicted in a painting by Gaston Bussiere, Craig Mullins, and J.C. Leyendecker.", + "A satyr sits on a tree trunk in a forest in a beautiful painting by Caravaggio.", + "Techno artwork by Albert Bierstadt.", + "A cyberpunk character from the 1989 NES Commodore 64 box art, painted by Keith Parkinson in oil on canvas, appears cozy.", + "A head-on symmetrical centered painted portrait of Elisha Cuthbert as a paladin, wearing ornate iron armor and medieval robes in a fantasy tarot card style with intricate details and elegant design.", + "A gouache painting by Claude Monet showing a group of ships docked at a harbor with detailed composition.", + "Greg Manchess painted a portrait of Baby Yoda as an Overwatch character in an asymmetrical, organic style with bold shapes and hard edges, resembling street art - trending on ArtStation, created with the help of Huang Guangjian, Gil Elvgren, and Sachin Teng.", + "The image depicts the Marvel superhero Thor, who is the god of lightning, in artwork created by Nicholas Roerich.", + "An illustration of a futuristic interior hall with various furniture, sacred geometry, and a plant in watercolor gouache style.", + "A city painted in an afrofuturistic style by Gaston Bussiere, Craig Mullins, and J.C. Leyendecker.", + "A Renaissance painting depicting a man encountering the devil.", + "A portrait of Rafael Nadal in Van Gogh's style.", + "David Choe created Transylvanian folk art-inspired graffiti.", + "A beautiful artwork featuring Hotarubi in the forest of fireflies.", + "Psytrance artwork by Greg Rutkowski.", + "Portrait of a young man with scars on his brown skin wearing a black turtleneck, by Martin Ansin.", + "Close-up portrait of a girl from the 80s, featuring artwork by multiple artists.", + "Persephone with pomegranates.", + "A drawing of a seraphim.", + "Palette knife painting of a vibrant and expressive woman's face by Francoise Nielly.", + "A girl covered in blue and pink leaves and petals, backed up against illuminated light, in the style of Stefan Kostic.", + "Image of rock band playing in hell with clean face on fire, featuring neo-gothic, Beksinski-inspired painting, and art by Takato Yamamoto.", + "The image depicts the two complementary forces of life.", + "A painting of Tifa Lockhart, a character portrait in a sunflower garden.", + "A muscular woman carrying Earth on her back against a desert background with complementary colors.", + "A portrait of Sri Sultan Hamengkubuwono IX wearing a traditional Javanese blangkon and batik, done by four different artists.", + "A 3D painting of a serious female sorceress in a stormy weather, with an anaglyphy effect.", + "A red-haired girl stands among the rubble and ruins of the Chernobyl power plant, surrounded by flowers and vines, in a hyperrealistic oil painting by Greg Rutkowski.", + "A highly detailed digital painting of Edward I of England in a Dungeons & Dragons inspired fantasy portrait by Artgerm, Greg Rutkowski, and Magali Villeneuve.", + "A painting of a teen witch with red hair.", + 
"Portrait of an Egyptian queen in a gold dress with a revealing neckline.", + "A digital painting depicts a man playing the electric bass with curly hair, square glasses, and a striped t-shirt, in a cinematic still style with black and red colors.", + "Japanese hot spring interior with lanterns, koi fish and bonsai trees in a painting by Greg Rutkowski and Craig Mullins.", + "A painting of Kermit the Frog as a Catholic Pope by Michelangelo Merisi da Caravaggio.", + "A soldier looks up at a horned giant woman in a neoclassical scene.", + "The image depicts a traditional Japanese geisha in a kimono, standing by a lake with snowy mountains in the background, and a beautiful sunset reflected on the water's surface.", + "Portrait of female physicist Chienshiung Wu in uniform and equipment with strong eyes, in an artistic and hyperrealistic style.", + "A gouache illustration of a girl in a school uniform standing on a tall building's edge, in a Morandi color scheme by Krenz Cushart on Art Station.", + "A Renaissance painting depicts a group of indigenous people burning the WhatsApp logo in a tribal style, resembling the work of Veronese.", + "Drawing of the Chad stable diffusion.", + "Close-up portrait of a teen girl wearing a leather jacket, depicted in an oil painting style with dramatic lighting.", + "A close-up of \"Vampire Kiss\" artwork by multiple artists.", + "A centered waist-up portrait of an angel with vibrant colors and a bokeh background.", + "An oil painting of a man in a black robe with no face in a style that mixes classicism with 80s sci-fi.", + "A mixed media portrait painting with piercing gaze in dayglo pink and blue, featuring ornamental and elegant art nouveau fashion details.", + "A framed frontal picture of one face liking another, with artwork by multiple artists including Yoichi Hatakenaka and Masamune Shirow.", + "Close-up hyperrealistic oil portrait of a nurse fashion model with red lipstick, ginger hair, freckles, and a mix of classicism and 80s sci-fi inspired style set in complete darkness.", + "Oil portrait of Super Mario as a shaman tripping on mushrooms in a dark and detailed scene.", + "Patrick Nagel's 1980 digital fashion illustration.", + "A gouache painting of a giantess in a school uniform standing in a miniature city, with an anime style and fine details.", + "A painting depicting Alzheimer's disease by Diego Gisbert Llorens.", + "An ultra-realistic illustration of a bird god swinging a gold metal stick weapon, with a blue man face and yellow bird mouth, and intricate traditional Chinese elements.", + "Woman wearing samurai helmets, portrait by Bouguereau.", + "A high detail full body oil painting illustration of a stoic barbarian woman in a scenic background, with realistic proportions, and a sombre mood.", + "Morning light illuminates a forest in Chiho Aoshima's artwork.", + "An oil painting titled \"Broken Lord\" by Greg Rutkowski showing the character resting in a library.", + "A landscape with Kandinsky paintings inside a curvy smooth room by James Turrell.", + "A cinematic movie scene depicts Jackie Chan mutating into a botfly larva, created in a beautiful, detailed matte painting style with a horror theme.", + "A surrealistic painting showing a living room with abundant furniture by Jacek Yerka.", + "A pencil sketch of Danny Devito by Milt Kahl.", + "An oil painting of a mechanical circuit board astral bird mask with abstract surrealist forms, featuring powerful glowing eyes and mystical magic symbols by Yvonne McGillivray, Mandy Jurgens, and 
Michael Divine.", + "A woman in golden robes walks on the ocean amidst god rays and vibrant colors, in a style influenced by several renowned artists.", + "The image is a fine art portrait of a room filled with supermodel robot parts, with an art nouveau fashion embroidered style and a soft color scheme, focusing on the head in sharp detail with a fantasy element and soft blurred background light.", + "The image portrays Ophelia with a detailed and elegant face, featuring wonderful eyes, wearing an intricate dress, and created with hyperrealistic painting techniques.", + "The monk is speaking.", + "The image depicts a goddess amidst planets colliding, painted digitally with intricate detail by artist James Jean.", + "The image is a cinematic portrait of Walt Whitman depicted as a bodhisattva in the style of several famous artists, painted in oil on canvas or gouache with intricate details and desaturated colors.", + "A colorful illustration by Peter Chan featuring a green field with flowers, pink and yellow clouds at sunset.", + "A red-haired queen wearing a green and black dress and veil is depicted in an intricate and elegant digital painting.", + "A portrait of a couple in love in a living room surrounded by dark energy and a plant, with an artistic cover artwork style and various artist influences.", + "The Colosseum depicted as a garbage bin overflowing with trash in Rome, inspired by the art style of Matisse.", + "Imogen Poots depicted as a D&D Paladin RPG character, portrayed in a front-facing symmetrical painted portrait with global illumination lighting.", + "A digital painting of Homeworld, Pride of Hiiagara by Rob Cunningham.", + "A young girl with a red hat at night.", + "A landscape painting of Ship Rock by Carlos de Haes.", + "Maya Ali as a D&D sorcerer in elegant wizard robes, with a tarot card background, depicted in a head-on symmetrical portrait, painted in a cell-shaded style.", + "The image depicts a dark, haunted chamber in a royal library with various styles of artwork displayed, including anime-inspired brushstrokes and oil paintings with impasto textures.", + "A serene nighttime cityscape of Tokyo with lake reflections, fruit trees, and animals in the foreground.", + "An oil painting representing nothingness.", + "The image is an oil on canvas depiction of an atom bomb explosion in Mumbai from a grounded perspective.", + "A painting titled \"Cosmic\" by Teun van der Zalm.", + "A mixed media painting of Danny DeVito as Gollum in a dark cave.", + "Psytrance artwork by HR Giger.", + "A column of tired men in 1800s navy uniforms marches across a barren arctic landscape with no vegetation under an overcast sky.", + "A man kneels at the base of the Christian cross in a 1970s illustrated advertising art portrait with a limited, earth tone palette.", + "An art piece depicting Albert Wesker and Chris Redfield, painted by Gaston Bussiere and Jean Giraud.", + "The painting is called \"The Day's Watch\" by Rembrandt.", + "A watercolor painting of a galaxy in a jar.", + "Flag design for communist European Union featuring a hammer and sickle." 
+ ], + "photo": [ + "A man taking a drink from a water fountain.", + "Fruit in a jar filled with liquid sitting on a wooden table.", + "A bathroom sink cluttered with multiple personal care items.", + "A smiling man is cooking in his kitchen.", + "A beautiful blue and pink sky overlooking the beach.", + "A man smiles as he stirs his food in the pot.", + "Several bikers are lined up in a parking lot.", + "There is no picture or image sorry sorry", + "A small car parked in by a vespa.", + "Several people around some motorcycles on the side of a road.", + "A black and white cat looking out a window over another cat.", + "A woman in a purple top pulling food out of a oven", + "Fighter jets on display in front of a museum.", + "An empty road with buildings on each side.", + "Two vespas parked next to a light post.", + "A peak into a long bathroom with a toilet, but no shower.", + "A face car driving past a parked motorcycle.", + "A computer monitor glows on a wooden desk that has a black computer chair near it.", + "a medium sized plane on an air port run way", + "A bicycle chained to a pole on a snowy day", + "A half eaten dessert and half empty cup.", + "A blue airplane in a blue, cloudless sky", + "A corner view of a kitchen with white appliances and dark wood cabinets.", + "a cat laying on the floor of a kitchen", + "A man and his dog riding on a bike. ", + "A bathroom with a toilet and sink inside.", + "A bathroom stall containing an empty toilet in it.", + "A brown and black dog sticking its head out a window.", + "A white busted up toilet sitting on it's side.", + "a hairy man lying on a bench besides a bush", + "A bunch of people waiting in line by a rail.", + "A counter in a coffee house with choices of coffee and syrup flavors.", + "Motorcycles parked on the sidewalk next to a road.", + "A dresser in a room that is painted bright yellow.", + "A person with his head out of a window while on a train. 
", + "a tiled bathroom with a toilet and sink inside of it ", + "A man sitting in a chair, in a black and white photo.", + "A bathroom with clear glass shower door and tile floor.", + "A dog sitting in a bathroom with a urinal and a torn wall.", + "A white expensive car parked on top of a cement slab.", + "An airplane flying past the Moon in the sky.", + "A woman sitting under an umbrella in the middle of a restaurant.", + "A woman getting ready to cook some food in a small kitchen.", + "A car sitting in the middle of the grass in the rain.", + "A man and woman riding on the back of a motorcycle.", + "Foods are being put in to the mason jars", + "A man sitting on a bench in a lobby.", + "A motorcycle is parked next to the fire hydrant", + "A bike parked on top of a boat.", + "A for of four urinals mounted to a wall.", + "A couple of old fashioned oak wood dining tables.", + "A magazine with a couple of cat around a toilet on it's cover.", + "this is a very dark picture of a room with a shelf", + "Meat left out on the kitchen counter could spoil.", + "People leaning out the windows of a train as it goes through the countryside.", + "a small plane with a propellor sitting on a runway", + "An out house with the door opened sitting in a field.", + "A man sitting on a modern bench talking on a phone.", + "A woman wearing a hair net cutting a large sheet cake.", + "a toilet sits next to a shower and sink ", + "Chopped meat laid out on towels in a home kitchen, in preparation for cooking.", + "A cat standing on a toilet seat in a bathroom.", + "A small engine plane sitting on a runway.", + "A bicycle parked and leaning against a brick building.", + "there is a small kitten inside of a sink", + "A kitchen with a wooden table with a cat sleeping on top of it.", + "A small kitchen does have plenty of cabinets.", + "A white bathroom with a white toilet and sink.", + "a small propeller plane sits on a run way ", + "A white toilet sitting next to a large window.", + "A large fire place sitting next to a doorway.", + "A man wearing a black neck tie and glasses.", + "Large shower sectional of a bathroom in a brown and white photograph.", + "A bike sitting next to a brick wall in the open.", + "A group of motor bikes on a street.", + "there are many people trying to avoid the rain", + "The small, single engine airplane is parked on the tarmac. ", + "A woman eating vegetables in front of a stove.", + "The blue shower curtains are inside of the bathtub next to the toilet. 
", + "A group of people with umbrellas standing around a white car.", + "A man hanging his head out of the side of a train.", + "A couple of small rooms in a house.", + "Personal computer desk room with large glass double doors.", + "A modern style bathroom with a large tub and shower and tile floor.", + "A golden bicycle with a basket next to a brick wall.", + "A man getting food ready while people watch.", + "A wire fence containing various hair clips with a building in the background.", + "a vintage photo of some people sitting on a bench ", + "An elderly lady pours some cups of tea on a tray.", + "A white jet airliner parked on a runway at night.", + "there is a chocolate cake and ice cream on a plate", + "An outhouse sitting in the middle of a field.", + "A bunch of birds that are sitting on steps.", + "A city filled with lots of tall white buildings.", + "A bathroom sink that is under a mirror.", + "there is a mirror and a picture on the wall ", + "A man that is sitting on a couch.", + "there us a woman and a young child sitting on a bench", + "A woman that is sitting under an umbrella.", + "A woman that is standing near an open oven.", + "there is a white toilet and a sink in this bathroom ", + "A group of people posing with festive items.", + "A woman in an orange vest and blue helmet riding a horse up a flight of stairs.", + "A group of people that are sitting on bikes in the street.", + "A group of motorcycles are parked next together.", + "A black motorcycle parked on a brick sidewalk next to a road.", + "Cars, people, buildings and street lamps on a city street.", + "there is a chef making food as people watch", + "An empty kitchen with lots of tile blue counter top space.", + "there is a man sticking his head out of a train window", + "a tiled bathroom with a toilet and scale in it ", + "A bunch of airplanes lined up in a row at an airport.", + "A desk sitting next to a showroom of cars in it.", + "An elderly man is sitting on a couch.", + "a bunch of glasses with some food inside of it ", + "A city at night filled with lots of traffic.", + "Careful bicycle riders add florescents to their clothes for safety in the dark.", + "this is a dark picture of a large kitchen", + "A white toilet sitting next to a shower in a bathroom.", + "A crowd of people watching an airplane on a runway.", + "A man sitting on a black and yellow bench on the phone.", + "A woman taking a photo over the shoulder of a man on a bike.", + "this kitchen has a white and black stove in it", + "The dirt bike has seen many hill climbs in its history.", + "A plane flies in the sky in front of a silhouette of a moon.", + "a cluttered room with a table and shelf on the wall.", + "there are many men playing soccer in a field", + "A woman forks vegetables out of a bowl into her mouth. ", + "A woman taking a picture of herself in a mirror.", + "A couple of men riding a motorcycle down a street.", + "Portable toilet in a wooden box area of a field.", + "A motorcycle bike leaning against a white trailer.", + "The view of a bathroom tub, shower, and toilet.", + "A bathroom with a toilet and a scale.", + "A jet flies in the distance with the moon in the background. 
", + "there is a man wearing a suit sitting on a bench", + "A group of Navy cooks standing around a giant cake.", + "A group of people in suits standing in a kitchen.", + "A white toilet sitting under a window next to a chair.", + "A dog is staring at a picture on a flat screen tv.", + "A man is sitting on a public bench on a busy city street.", + "Man talking on personal cell phone on a yellow and black bench.", + "a counter top with food sitting on some towels", + "A bathroom with a sink, vanity and shower stall.", + "A view of a very dark lit kitchen from the other side of the room.", + "A motorcycle parked on a sidewalk near a street light.", + "There is traffic on a busy city street. ", + "a blue bicycle a blender sand and a person", + "Two people on a motorcycle with tone taking a photo", + "The cat is sitting on the old butcher block.", + "a dirt bike laying against a trailer in a grassy field", + "A tabby cat sleeping on a wooden island in an old looking kitchen.", + "view of tall city buildings with cars and people walking by", + "a man standing in front of a big display case of donuts ", + "Woman walking down the side walk of a busy night city.", + "A single propellor aircraft that is parked on an airport apron with vehicles and another plane in the background.", + "Woman eating an assortment of mixed vegetables in a bowl.", + "There are orange slices in canning jars without lids.", + "a bathroom with a stand up shower and tub.", + "some people driving down the road with their bikes ", + "A brown cat crouches and arches its back in a white sink.", + "A group of waiters standing in a line. ", + "A beach area with a bicycle that has a blender attached to the front, parked on the sand.", + "Meats being prepared for cooking on kitchen counter.", + "A woman sits under the sheet on a mattress on the floor.", + "Large dog looking at television show in living room.", + "A man driving a motorcycle with a woman holding a cell phone.", + "A young woman standing in a kitchen eats a plate of vegetables.", + "The motorcyclist in a helmet is looking over the side of a bridge. ", + "At night on a street with a group of a bicycle riders riding down the road together.", + "The woman sitting at the table looks bored.", + "The woman in the kitchen is tending to her food.", + "some people holding umbrellas and standing by a car in the rain", + "A bike parked in front of a doorway.", + "A person is riding his motorcycle on the dirt road.", + "A bathroom that has a door just for the toilet area.", + "Eight jars are being filled with orange slices. ", + "Woman with a motorcycle staring over a bridge at a wetlands. ", + "A bathroom has pink tiles and a black toilet.", + "A group of people holding umbrellas stand near a car.", + "there are two woman that are riding motorcycles ", + "Some men and women in white shirts and bow ties standing in a row.", + "A container of antibacterial wipes in a bathroom.", + "A monitor screen, printer, couch and chair in the room", + "A very dimly lite kitchen in someone's house at night.", + "this kitchen is very big and has wood cainets", + "A bunch of uncooked food on a counter.", + "this is a wood table in a cluttered kitchen", + "A bathtub that is in a bathroom under a wooden object.", + "A bunch of people standing around and posing for a picture.", + "A shelf of various cups and glasses mounted to the wall.", + "A man standing by his motorcycle is looking out to take in the view. 
", + "A light that is on above a mirror.", + "a bathroom with a tub next to a fancy shower stall ", + "A automobile with multiple bicycles on a roof rack. ", + "this small bathroom has white sink and a toilet", + "A man standing behind the counter at a doughnut shop.", + "a toilet a tub some pipes and a window", + "A view of a table with a bunch of cakes and tea on it.", + "Large sized kitchen with a dining room section.", + "There is a cyclist riding above all the pigeons.", + "there is a woman that is cutting a white cake", + "Some people are enjoying time on a beach. ", + "Pile of strings and books next to a laptop computer.", + "A man is standing in front of a case filled with pastries.\n", + "A woman marking a cake with the back of a chef's knife. ", + "A bicycle that is stored in someone's closet in the apartment. ", + "A woman eating fresh vegetables from a bowl.", + "A large kitchen with a lot of cabinets and counter space.", + " a bathroom with a picture of a bookshelf above the urinals", + "Line of men and three woman standing in front of a kitchen.", + "kitchen with a wooden kitchen island and checkered floor", + "there is a woman staring in the kitchen pouring tea", + "this man is riding a board near a field", + "A man on a motorcycle riding in the desert.", + "A dining room with hard wood floors that is very fancy. ", + "A group of young bicyclists on a city street at night.", + "The bath tub and toilet in this bathroom are black.", + "Pots and pans that are on the side of a sink.", + "there is a all black motorcycle that is parked on the street", + "a black toilet in a wood floored bathroom", + "The jars on the table are full of oranges.", + "A cat sits on an open toilet in a bathroom.", + "a female standing in the bathroom and taking a photo with her phone", + "two men on a scooter riding down the roadway", + "A bathroom, showing the shower, toilet and sink.", + "A wooden table sitting in the center of a kitchen.", + "a jet airplane sitting on a runway next to a building", + "There is an airplane on the runway in the distance.", + "a group of people sitting on the sand with a lake in the background", + "A small powder room with a sink and vanity, toilet, mirror, and an empty towel bar.", + "Various kitchen dishes are arranged on many different shelves. ", + "A man on a bicycle above spectator stands, where pigeons graze.", + "There is a cat standing on the toilet seat.", + "A man in a helmet and jacket riding a motorcycle in the desert.", + "Many objects are sitting on a counter in a kitchen.", + "a man sitting on a motorcycle in the desert", + "A line of urinals against a wall with bookshelves above.", + "A woman holding a colorful kite on top of a green field.", + "a bathroom with tiled floor and a circular window ", + "this is a bench out near a field", + "a bathroom view of a tiolet and sink ", + "A white sink sitting under a mirror next to a toilet.", + "A woman sitting at a table next to an umbrella.", + "A woman standing in a kitchen baking bread.", + "A series of shelves holding colorful glassware and dishes.", + "a guy in the desert sitting on his motorcycle", + "A kitchen in a camp with gear and coats laid out.", + "Three people sit on a bench looking out over the water. ", + "A person on a bike is next to a train on the tracks. 
", + "a group of boys playing in a field next to a forrest", + " A blender sitting on top of a table.", + "A public restroom with toilet, sink and a grab bar.", + "A very tall clock tower sitting above a building.", + "five restaurant wait staff and two mangers ", + "A plane traveling down a run way, near the highway.", + "Two people standing in a small kitchen with an arched passage.", + "A man rides a motorcycle down a dirt road. ", + "i table filled with cups and a plate of food.", + "A motorcycle parked on a stone cobble road, in the sun.", + "A man standing in front of a bunch of doughnuts.", + "Open shelves hold an assortment of glasses, cups, and bowls. ", + "Old photo of man sitting on his motorcycle", + "Two people ride motorcycles down a city street.", + "a mirror a sink a toilet and a blue basket", + "A family riding their bikes next to the streetlight. ", + "A top down view of a bathroom with a scale and toilet.", + "Dessert for two is placed on a table.", + "An intersection with cars is pictured in this image.", + "A table topped with lots of food and drinks.", + "An old motorcycle with a side car attached.", + "the kitchen has a stove and sink with pots and pans", + "The urinals are sitting below the shelves full of books,", + "A cat sitting inside a sink in a bathroom ", + "A kitchen with a lot of kitchen furniture and accessories", + "A bicycle sits parked in front of a bookstore.", + "A colorful kite is ready for launch on a blue sky day", + "A cyclist pedals past a flock of birds perched on a grating.", + "A mirror shows another light in a background of a wonderful bathroom", + "a white table with sandwiches and cups of tea and people and sivlerware", + "A bike rider traveling down a road, in the desert.", + "Two kittens are cuddling and enjoying a soft pillow", + "A man on a skateboard rides down a narrow road.", + "A kitchen with a stove, table, cabinets, and other items ", + "A plane flies in the sky passing over the moon.", + "A kitchen with many of the appliances removed with blue and white tile.", + "A bathroom with sink, toilet, and bathtub and black and white floor tiles.", + "this is a red bike on a dirt path", + "A man waits to cross the railroad tracks as two trains cross.", + "A person rides an electric bike on a desert trail.", + "A bathroom with a sink and other items. ", + "this is a toilet and trash can and a sink", + "A stop sign sits in front of a billboard in a quiet area.", + "A bike parked in front of a book shelf.", + "A view of a kitchen with a burner top stove.", + "a black toilet some toilet paper and brown tiles", + "A boy wearing a suit riding a skateboard down the road", + "a bathroom with a big mirror above the sink", + "A TV sitting on top of a counter inside of a store.", + "a bathroom with a glass sink base with a bowl on top", + "A motorcycle with a flat rear tire sits in a workshop, while a person stands behind it, facing away from the camera.", + "A bathroom with a toilet, sink and shower stall.", + "A busy street with cars and buses on it.", + "some people an airport a runway and a jet", + "A man and a woman looking at cell phones.", + "A man with a fro riding a skateboard down a road.", + "A kitchen with and island and several counters in it.", + "A couple of sinks with brown tile and a decorative mirror.", + "A view of a messy room, with shelves on the wall.", + "The motorcycle is parked on the side of the paved road. 
", + "a bathroom view of a stand up shower and toilet with a sink near by", + "A bathroom with a toilet and a scale on the floor.", + "a couple of people standing inside a kitchen.", + "There are a lot of cupboards and refrigerator in the room. ", + "Night is falling on an empty city street.", + "A vintage antique motorcycle sitting in a shop being worked on.", + "a black gray and white cat a toilet sink and mirror", + "The plane is taking off into the yellow sky.", + "A curly haired boy rides a skateboard down a road.", + "A woman is seen in the rear view mirror of a motorcycle.", + "A woman pouring coffee into cups on a counter.", + "a bike resting in the sand with a blender built on top", + "A PICTURE OF A KITCHEN WITH TILE COUNTER TOP", + "A walk in shower sitting next to a bath tub.", + "A city street filled with lots of traffic.", + "Person riding a four wheeler on a beach towards a bridge.", + "A PICTURE OF A BATHROOM WITH SLIDING SHOWER ", + "Black motorcycle with a side car in the middle of the street. ", + "Group of people standing around each other in the middle of a city street. ", + "a kitchen with a stove sitting on a hard wood floor and cabinets", + "A man walking around with his dog and sheep.", + "there are two cats that are laying inside of a tub", + "there is a small dog that is looking threw the glass", + "A man taking a picture of himself in front of three huge beer bottles", + "A PICTURE OF A MAN WITH BEER BEHIND HIM ", + "there is a small out house that is made of wood", + "Several people smile for the camera at night.", + "this is a bathroom that has a sink and toilet", + "A black dog sitting in front of a TV.", + "A PICTURE OF ALL WHITE IN A BATHROOM ", + "A person wearing a safety vest rides a horse up the staircase.", + "a man sitting in a chair on a tiled floor next to a heater", + "this is a clock on top of a tower", + "Bike leaned against a wall of books inside and establishment.", + "this is a group of people standing near a river\n", + "a bathroom with a toilet and sink and a bath tub sitting on a hardfloor", + "A person rides a vehicle on the beach.", + "this kitchen has a white stove and all white cabinets", + "Three people sit on a bench together facing away.", + "Two cats sitting together in an empty bathtub.", + "A puppy staring through a red sectioned window.", + "a toilet sitting on a tiled floor in enclosed bathroom stall", + "People are walking and cars are driving in a city.", + "A kitchen with a stove, microwave and cabinets.", + "Two teams compete at a sport in a park.", + "A bathroom with a toilet, counter, and mirror.", + "A lidless toilet is shown caked in dirt or other filth.", + "A tiled bathroom is shown with a compact style toilet.", + "there is a police man riding a tav on the beach", + "A small white car with a small white dog riding in it.", + "a toilet a sink a towel a light and a mirror", + "A man riding an ATV next to the ocean.", + "Two people riding a motorcycle near a group of people.", + "A group of people walking down a walkway.", + "a woman a white mat and pillow and white wall", + "A man riding a motorcycle down a road near a forest.", + "A little girl holding a brown stuffed animal.", + "there is a very beautiful view out of this bathroom window", + "A dog stands close to a television looking at it.", + "this bathroom is very big and has lots of room", + "A pair of cats sit in an empty bathtub.", + "A dog looks through ribbed glass in a red door.", + "there is a old black motorcycle inside of a garage", + "A motorcyclist 
parked near a railing looks out over the water.", + "Two people are looking at a truck while a dog is being walked.", + "Several people are seen sitting around and smoking.", + "this is a man sitting on a green couch", + "A juicer attached to the top of a bike.", + "A cramped bathroom with a sink in the corner.", + "A crowded street filled with British traffic and buses.", + "A small clock is seen on the side of a church.", + "this is an airplane sitting on the runway", + "A girl is holding a large kite on a grassy field.", + "Two people sitting on a motorcycle that parked on the road.", + "A composite image of an office desk, cars and buildings.", + "A man wearing a hat in front of large bottles.", + "a church with a tall tower with a clock built into it", + "A group of young men jump in the air playing a game.", + "some shelves filled with bowls and cups ", + "A dog sits in front of and watches the television.", + "An old rusting toilet with the lid up. ", + "some peeled oranges sitting in a clear blender", + "Two people standing in a kitchen near a stove.", + "A young girl walking barefoot carries a stuffed animal.", + "a man reflected in a rear view mirror of a motorcycle", + "A native American couple on a bike pose for a photo.", + "A person laying on a bathtub with their feet sticking out.", + "A wooden outhouse sitting in the grass near trees.", + "A toilet sitting in a stall on tile.", + "A kite flying in the sky on a cloudy day.", + "Two cats occupy a bathtub, one sitting and one lying down. ", + "A dog looks out through a lined window. ", + "A group of people are standing in the snow on skis.", + "A person is sitting on a motorcycle looking in the mirror.", + "a bathroom wall missing some pink wall tiles ", + "A man walking across a field holding a wand near a dog.", + "a bathroom with towels under a sink and a big mirror above it", + "Adults and children gather near a dock on the beach.", + "A green, red, yellow and blue kite fly's through the sky.", + "a man standing next to a laptop and bottles of beer", + "two cats resting side by side on a bed", + "A large jetliner sitting on top of an airport runway.", + "A sad woman laying on a mattress on a hardwood floor.", + "A red bus parked next to a crowd of people.", + "Looking through the window of showroom at car dealership.", + "Two people standing next to each other in a kitchen.", + "a bathroom view of a sink toilet on a tiled floor", + "a group of people standing in the snow with gear on", + "a church with a clock built into the side of it", + "A daytime view of a messy kitchen corner.", + "a colorful kite flying high on a cloudy day", + "A man riding a red scooter down the street.", + "An old toilet outside against an old painted wall.", + "A kitchen counter top with a white bowl sitting next to another white bowl.", + "A toilet filled with nasty grime sitting up against a bathroom wall.", + "A woman standing between a motor bike and a striped wall over a river.", + "A group of people sitting on top of a bench.", + "A group of men standing around a luggage cart.", + "a person in a bathroom having a reflection in the mirror", + "a kitchen with a microwave, a stove, and cabinets.", + "a street with cars lined with poles and wires.", + " two men and one woman standing in a kitchen", + "A dark and cluttered storage area with wood walls.", + "Three people sitting on a bench looking at the ocean.", + "people sitting on a bench facing the water.", + "a cat sitting in a sink with its eyes open", + "two cars parked side by side on a 
show room floor", + "two cats chill in the bathtub one is laying down", + "a dog who looks sad stares outside of the window of a red door ", + "a lady holding a kite and walking in a grassy area", + "An old toilet with a rotten lid next to a rusted pipe.", + "some piled oranges in a glass blender ready to be blended", + "A room with a chair and pictures mounted on the wall. ", + "a sink in a bathroom with a shaver and personal hygeine items on the counter top", + "a street with people and vehicles in the middle of it", + "a room showing a wooden table and a capboard", + "Several cars parked near a desk holding a computer.", + "A parked white car with and open door and a dog inside.", + "A black motorcycle with a sidecar parked on cobblestone.", + "a group of vehicles parked next to a firehydrant", + "A woman standing on grass holding a colorful kite.", + "A white kitchen with a gas stove and microwave.", + "a motor bike carrying very many people on the street", + "A row of urinals with a well-stocked bookshelf in front. ", + "A woman lying on a thin mattress on the floor with her knees up.", + "A plane riding down a runway of an airport.", + "a cake and two spoons on a plate", + "Toilet in a bathroom in an international location with a basket.", + "Group of horses in a field with a pinto in the foreground.", + "A dim lit room consisting of many objects put together. ", + "a man standing in a bathroom looking into a mirror", + "a street view of people walking down the sidewalk ", + "A airplane sitting on a runway at a small airport.", + "A room with a sink and a skeleton foot. ", + "Sun shining through the blinds into a white bathroom.", + "A man in a blue hat on a bike behind a train. ", + "An airplane on the runway of an airport.", + "Two motorcycles going down a city street with woman drivers", + "A black and white still life of a branch with flowers in a vase", + "A clean odd little bathroom with a white porcelain toilet.", + "A paint horse and other breeds in the background grazing in a green field.", + "A kitchen area that has items on the counter tops. ", + "a tall red bus is by the curb in a city", + "A restroom with a toilet and a mirror. ", + "A tiled floor bathroom with a red and black shower curtain.", + "An older woman pouring tea in the kitchen.", + "A bicycle is placed behind an open door.", + "A man and two dogs are riding a scooter.", + "A busy street with traffic moving in both directions and several two level buses on the street with people around.", + "A bathroom has a toilet and a scale.", + "A bathroom with outdated fixtures and a clothes hamper in the middle of the floor.", + "The tiles are falling off the wall in this old bathroom", + "A dog sits in a white car with the door open.", + "a white toilet is in the corner of a bathroom", + "Top view of a few skinned oranges inside of a blender", + "a sink well cleaned and some drawers and hand wash", + "a bunch of people are standing on a snowy hill", + "View of toilet with a dirty lid and a missing cover to it's tank", + "A street with a few people walking and cars in the road. 
", + "a coupe of people are sitting outside on a bench", + "A little girl is carrying a stuffed animal.", + "A man and a woman are riding a motorcycle.", + "a couple of bathroom items sitting on a sink", + "a couple of motorcyclists are driving down the road", + "three people sitting on a motorcycle in a street", + "a couple of vehicles are parked in a lot", + "A messy kitchen with dirty dishes and white cabinets", + "A minimalist room features white appliances and beige walls.", + "A kichen with dirty dishes in the sink.", + "A table with a plate holding several sandwiches, tea cups and condiments. ", + "A bathroom sink that is surrounded by various toiletries.", + "Two motorcycles are parked on the shoulder of a mountainous freeway.", + "An intersection is shown on a cloudy day.", + "Pink bike sits on a guard rail by the river.", + "people on the street with their cars moving", + "This is a state of the art bathroom where the appliances don't look like they should", + "The man who uses this bathroom shaved this morning", + "A bathroom with a white toilet, tub, and tile floor.", + "A man riding a scooter with a dog on it. ", + "a t.v. that is sitting on a shelf with some lights near by", + "a bath room with its door open and light on", + "a bike that is leaning up against a book rack", + "Motorcycles parked in a row in the street. ", + "A group of people are standing together at night.", + "a little white car that has a dog in it", + "A man is standing in a field with a dog and goat.", + "a bunch of different electronics all on one big pile. ", + "Two motorcycles sit on the side of a secluded road.", + "a airplane that is on a runway by some grass", + "A pink bicycle leaning against a green railing next to a canal.", + "Three people on a motor bile that is riding in a street, with one of them wearing a helmet.", + "A man looks at himself in the mirror of a motorcycle.", + "a room filled with white furniture and books on the ground. ", + "A bicycle leaned against the hallway wall in a house", + "two different kinds of lights in a bath room. ", + "a room with wood and ivory furniture inside. ", + "THERE IS A PLATE WITH SWEET DESSERTS ON THE PLATE ", + "a road sign showing stop and a vehicle moving", + "Two lights shine above a messy bathroom toilet.", + "A chair sits against a wall in a wood floored room.", + "a couple of men that are next to some boxes", + "Several people are standing around watching a band perform on stage.", + "A small bathroom has a port hole window.", + "A young girl with a stuffed toy in a park.", + "A black and whit cat sitting in a sink.", + "A person is taking a picture of a bathroom with a toilet in it.", + "Some cakes are on a white plate with spoons.", + "A man looks into the mirror as he styles his hair.", + "a couple of sinks in a bright colored bathroom", + "a man on a motorcycle that is in some grass", + "Some people are next to a pier on the sand.", + "A car is illegally parked near a fire hydrant.", + "A couple of dead, stuffed giraffe on display.", + "A quaint toilet in a room with no door, a chair sitting outside of the area.", + "A green and blue motorcycle parked on the side of a road.", + "A very simple bathroom with beige and cream colored decor.", + "a man in a room with a camera with a toilet", + "A modern bathroom with a toilet and sink area.", + "A man on his bicycle waits for two trains to pass by.", + "a bath room with a trash can next to the tolit. 
", + "A purple bicycle is parked on a fence next to a river.", + "A kitchen with a lot of counter space, a sink, stove and refrigerator in it. ", + "A man on a motorcycle is looking in his mirror.", + "A little red headed girl walking with a stuffed puppy.", + "A white towel is at the edge of a white bathtub.", + "A bathroom is shown with a glass counter and cone-shaped sink.", + "A man walking down the street with a cane while others sit on a bench.", + "a motorcycle that is parked in side a buliding", + "A view from a bus shows people on bicycles and another bus in traffic.", + "a bathroom that has a tub and a shower", + "a vase with a flower growing very well", + "a small little toilet that is in a corner", + "a couple of horse that are eating some grass", + "a man that is riding a motorcycle on a road", + "a couple of motorcycles are off the side of the street", + "A tea kettle sits on the burner of stove.", + "a black cat that is sitting in a sink", + "a room tha has a toilet and a sink in it", + "A blender filled with three peeled oranges sitting on a counter.", + "a couple of motorcycles that are next to a road", + "A man is in a yard on a motorcycle.", + "A truck traveling down the street near a fire hydrant.", + "Two small cats are sleeping on white sheets.", + "a group of people with bikes posing for a photo ", + "Toilet with raised lid with tub and chair in old bathroom. ", + "A person riding a four wheel on the beach.", + "a bright light sitting in front of a tv ", + "Clocks are brightly lit on a huge tower.", + "a group of people that are smoking on a bench", + "A bathroom with shower stall, toilet, and bathtub.", + "A man is training a sheepdog for a sheepdog trial.", + "Looking down on a stony surface shows a bowl with an orange in it and what looks like a large piece of red plastic.", + "Two motorcycles ride down a street in a city.", + "A little girl is making a huge mess with a birthday cake. ", + "A very large kitchen area in a building.", + "A yellow bike sits on a wall in the hallway.", + "This is a photo of someones bathroom in their home and there are feet hanging out the side of the tub.", + "a small little bathroom with a toilet in it", + "Asian man and woman sitting and looking at cell phones", + "Someone is juicing an orange on a juicer.", + "A bicycle leaned against an outdoor magazine stand.", + "A black and white photo of a steam of flowers inside a vase.", + "A bathroom with white toliet and sink visible", + "A kitchen with tile back splash and stainless steel appliances.", + "Looking through a door and seeing a toilet and sink.", + "Some guys are standing over an old antique truck and someone is walking a dog nearby. ", + "A large jetliner sitting on top of a tarmac.", + "A baby with a bib eats a cake.", + "A stop sign out in the middle of nowhere ", + "A group of police officer standing in front of a red bus.", + "A woman holding two rainbow slices of cake.", + "A group of Frisbee players are running around a field. ", + "A toilet that has been covered in filth.", + "The clock on the side of the metal building is gold and black. ", + "The motorcyclist has his hands at his side while riding swiftly down the road. ", + "A modern restroom with a weird looking sink, toilet, and shower.", + "Small groups of people, including a person walking a dog, are scattered about an outdoor area, encompassing some streets, that is filled with classic cars. 
", + "a bunch of crates on a air plane run way", + "A sky view looking up at a jumbo jet plane.", + "A bathroom showing toilet, sink, and shower ", + " room with a book and a white carpet", + "A scooter with a helmet hanging off it's handlebars.", + "A truck driving on a crowded street past several parked cars.", + "A bunch of people walking around in a street", + "people riding bikes near a beach and others swimming", + "Two kittens curled up in a white sheet that looks soft.", + "A cat laying on the seat of a motorcycle ", + "a piece of orange in a bowl next to a concrete edge ", + "A road lined with rock-face shows a man and a woman, both wearing hats, astride a red, white and blue decorated bike. ", + "A kitchen has white cabinets and stainless steel appliances.", + "A crowd of people walking and riding their bikes.", + "A crowd of people are gathered outdoors on the street.", + "Sheepherders move their sheep across a highway as vehicular traffic passes between their flock.", + "A bathroom with sink, toilet, and tub ", + "A crowd of people at an outdoor concert.", + "A woman sitting on a bench with cars behind her.", + "A cat is alseep on a motorcycle seat.", + "A kite flying in a partly cloudy sky ", + "white toilet and sink with mirror on white wall", + "A small kitchen is shown with a stove, dishwasher and sink.", + "A toilet that is has been colored black.", + "A small baby bird on a piece of metal.", + "a jet airliner wing that has two jet engines", + "two sinks under a mirror and a light on a wall", + "human hands juicing an orange on a counter top", + "young man looking a different image of himself in the mirror", + "A man riding a bike down a dirt road.", + "a black and white photo with a vase and flower coming out of it", + "Three people are riding down the street on one motorcycle. ", + "Seven people on a biking trip in front of a large city.", + "A wooden table sitting in the middle of a room.", + "Three bikers by a red bus on the street.", + "Three people are standing in the same kitchen area.", + "A view of wing with two jet engines are on a runway while people watch.", + "Man and dog on scooter in city street on sunny day.", + "Two Asian people inside a train looking at their mobile phones.", + "A white toilet tin a bathroom sitting next to a sink.", + "The view of a restroom toilet, and sink area.", + "The motorcycle is tilting as he turns through a cave. ", + "A view of an airplane traveling across the bright sky.", + "The kitchen counter and sink have dishes on them.", + "A tower with a clock is displayed in the evening.", + "A giraffe and fence design are painted onto the wall.", + "A man wearing a helmet posing on top of a motorcycle.", + "A very large black and gold clock mounted to the side of a building.", + "A man riding on the back of a motorcycle down a highway.", + "A huge commercial airplane goes down the landing strip.", + "A bike is chained to the post on the sidewalk", + "The black and white cat is sitting in a bathroom sink.", + "The show girl is posing on a blue motorcycle on display. ", + "Two fake looking giraffes are on display at an exhibit.", + "A young baby is eating and playing with some cake.", + "Two adults and a child ride a motorcycle together.", + "A small eating area with a table and cabinets next to a window.", + "A small wooden toy car has an elephant sitting inside.", + "Cement ledge with orange in bowl and red plastic bag below. 
", + "An old classic church is in front a big blue sky.", + "Kitchen area with modern appliances and plenty of cabinets.", + "A man with a baseball cap and glasses seated in front of three large beer bottles.", + "A bathroom with a small sink and toilet. ", + "A bathroom with mirror, toilet, and sink ", + "This is a photo of someones bathroom in their home.", + "A child in a booster chair eating a cake ", + "A woman sitting on top of a purple motorcycle.", + "A bathroom scene with a toilet and a sink.", + "The top of a steeped church building with clocks and small windows. ", + "Two messy toilet stalls with toilets where one lid is raised. ", + "a man wearing a helmet while riding a motorcycle ", + "Man with golf club and a dog and a goat", + "Two turbines on the wing of an airplane", + "An empty bench along a sidewalk in neighborhood.", + "there is a man riding a bike up the road", + "A brick ally way with an old wooden bench with people sitting and smoking on it. ", + "there is a very tall giraffe inside of a building", + "A couple of airplanes sitting on top of a runway.", + "a bathroom with a littlt tub and a clothes hamper by the toilet", + "A large jet flying through a cloudy blue sky.", + "A close up of the face of a clock on a building.", + "A man riding on a motorcycle on the road.", + "there is a very large black and gold clock on a building", + "there is a man riding a motorcycle and not holding the handles", + "there is a person making freshly squeezed orange juice", + "A man riding on the back of a motorcycle on top of a grass field.", + "A christmas wreath is hanging from the door", + "group of bikers posing for a picture ", + "A black bench that is by a sidewalk on a street.", + "A bottle of wine sitting on top of a table next to a glass of wine.", + "A dog sitting in front of an open door looking outside.", + "A toy elephant sits in a toy wooden car.", + "A group of bikers parked in the middle of a street.", + "A wreath with a red bow on it hanging on a white door.", + "A white toilet sitting in the corner of a room.", + "A lush green field with horses standing on top of it.", + "Several cars drive down the road on a cloudy day.", + "A crowd of people riding bikes down a street.", + "there is a woman sitting on a bench in front of cars", + "There is an orange in the cup and a bag in the water.", + "an empty bench sitting on the side of a sidewalk", + "A person sits on a motorcycle while wearing riding gear.", + "A plane is on display near the water.", + "A mans reflection in a side view mirror.", + "Two people wearing hats riding a motorcycle together.", + "there is a dog that is sitting in a car", + "A couple of white bathroom sinks mounted to a wall.", + "A pink bicycle leaning against a fence near a river.", + "there is a man crossing the tracks on a bike", + "tan colored bathroom with white toilet and mirror", + "A closed toilet seat in a bathroom next to a checkered curtain.", + "A bathroom vanity with a large mirror hanging on the wall", + "A colorful kite flying in a cloudy blue sky.", + "A road with two vehicles out in the middle of nowhere with animals climbing up a hill on the left.", + "The numbers and hands on the clock are gold.", + "The man on the motorcycle does not have his hands on the handlebars.", + "A white stove top oven inside of a kitchen.", + "A line of motorcycles parked on the side of a street.", + "A small elephant toy sitting inside of a wooden car.", + "some one in the bath room laying in the bath", + "Men are unloading the trolley of luggage on 
the runway.", + "A small bird sitting in a metal wheel ", + "Someone is riding a motorcycle through a grassy field. ", + "a bunch of people in a kitchen getting food ready", + "A billboard posed by the side of a street in a rural town.", + "A picture of a man sitting on a motorcycle on a dirt road.", + "A woman juicing oranges on top of a manual juicer.", + "A small cute cat sitting in the bathroom sink.", + "Several people standing next to each other that are snow skiing.", + "some cut up fruit is sitting in a blender", + "A small and plain white bathroom with a toilet and a tub.", + "A man in riding gear, riding a red motorcycle down a road.", + "A bunch of airplanes are parked on the runway. ", + "this plane has two large fans on its wings", + "This is a photo of a bathroom in someones home.", + "This is a large statue in someones living room.", + "An old propeller airplane is displayed near the water.", + "A man and a woman using their cellphones simultaneously.", + "there are many people walking along this street", + "A kitchen showing marble tile and wood cabinets.", + "A line of motorcycles are all parked next to each other.", + "A man, woman, and child preparing food in a kitchen.", + "A black and white photo of a flowing growing out of a vase.", + "A passenger jet being serviced on a runway in an airport.", + "Three people are preparing a meal in a small kitchen.", + "A pair of planes parked in a small rural airfield.", + "A bathroom with a stand alone shower and a peep window.", + "Several vehicles with pieces of luggage on them with planes off to the side.", + "a black motorcycle is parked by the side of the road", + "A small bathroom with a tub, toilet, sink, and a laundry basket are shown.", + "A bus stopped on the side of the road while people board it.", + "A bunch of people posing with some bikes.", + "a jet engine on the wing of a plane", + "A bunch of bicycles parked on the street with items sitting around them ", + "A dog standing in front of a doorway.", + "Two small planes sitting near each other on a run way.", + "there is a bus that has a bike attached to the front", + "A bird that is sitting in the rim of a tire.", + "The black motorcycle is parked on the sidewalk.", + "A corner of a rest room with a big shower.", + "a dog with a plate of food on the ground", + "there is a very large plane that is stopped at the airport ", + "Bicycles with back packs parked in a public place.", + "A white walled bathroom features beige appliances and furniture.", + "Several bicycles sit parked nest to each other.", + "Some big commercial planes all parked by each other.", + "a woman holding a plate of cake in her hand", + "yellow and red motorcycle with a man riding on it next to grass", + "A motorcycle stands in front of three people on a sidewalk.", + "classic cars on a city street with people and a dog", + "People getting on a bus in the city", + "A large commercial airliner silhoetted in the sun.", + "Residential bathroom with modern design and tile floor.", + "a bus with a view of a lot of traffic and the back of another bus with a billboard on the back end", + "A young man riding through the air on top of a skateboard.", + "A toy elephant is sitting inside a wooden car toy.", + "A motorized bicycle covered with greens and beans.", + "A man sitting at a table in front of bowls of spices.", + "there is a bathroom that has a lot of things on the floor", + "A passenger jet aircraft flying in the sky.", + "An eye level counter-view shows blue tile, a faucet, dish scrubbers, bowls, 
a squirt bottle and similar kitchen items. ", + "A TV sitting on top of a wooden stand.", + "A person sitting on a motorcycle in the grass.", + "A white toilet in a generic public bathroom stall.", + "a couple of people in uniforms are sitting together", + "A group of giraffe standing around each other.", + "Street merchant with bowls of grains and other products. ", + "A man driving a luggage cart sitting on top of a runway.", + "Residential bathroom with commode and shower and plain white walls.", + "Ornate archway inset with matching fireplace in room.", + "there is a red bus that has a mans face on it", + "a wooden skate with a toy elephant inside of it ", + "a bunch of people on skiing on a hill" + ] +} \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/hpsv2_score.py b/MindIE/MultiModal/StableDiffusion-XL/hpsv2_score.py new file mode 100644 index 0000000000000000000000000000000000000000..04e9bd8d8f82ece84c642520b001b62901286eda --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/hpsv2_score.py @@ -0,0 +1,123 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from typing import Union +import json + +from clint.textui import progress +import hpsv2 +from hpsv2.utils import root_path, hps_version_map +from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer +import huggingface_hub +from PIL import Image +import requests +import torch + + +def initialize_model(pretrained_path, device): + model, _, preprocess_val = create_model_and_transforms( + "ViT-H-14", pretrained=pretrained_path, precision='amp', + device=device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False + ) + return model, preprocess_val + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--image_info", + type=str, + default="./image_info.json", + help="Image_info.json file.", + ) + parser.add_argument( + "--HPSv2_checkpoint", + type=str, + default="./HPS_v2_compressed.pt", + help="HPS_v2 model weights", + ) + parser.add_argument( + "--clip_checkpoint", + type=str, + default="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + help="open clip model weights", + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + model, preprocess_val = initialize_model(args.clip_checkpoint, device) + + checkpoint = torch.load(args.HPSv2_checkpoint, map_location=device) + model.load_state_dict(checkpoint['state_dict']) + tokenizer = get_tokenizer('ViT-H-14') + model = model.to(device) + model.eval() + + with os.fdopen(os.open(args.image_info, os.O_RDONLY), "r") as f: + image_info = json.load(f) + + result = [] + for i, 
info in enumerate(image_info): + image_file = info['images'][0] + prompt = info['prompt'] + + # Load your image and prompt + with torch.no_grad(): + # Process the image + if isinstance(image_file, str): + image = preprocess_val(Image.open(image_file)) + elif isinstance(image_file, Image.Image): + image = preprocess_val(image_file) + else: + raise TypeError('The type of parameter img_path is illegal.') + image = image.unsqueeze(0).to(device=device, non_blocking=True) + # Process the prompt + text = tokenizer([prompt]).to(device=device, non_blocking=True) + # Calculate the HPS + with torch.cuda.amp.autocast(): + outputs = model(image, text) + image_features = outputs["image_features"] + text_features = outputs["text_features"] + logits_per_image = image_features @ text_features.T + + hps_score = torch.diagonal(logits_per_image).cpu().numpy() + print(f"image {i} hps_score: ", hps_score[0]) + + result.append(hps_score[0]) + + print('avg HPSv2 score:', sum(result) / len(result)) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/infer_pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c03e342551dafa29a1affb3652d30a8d09432b --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/infer_pipe.py @@ -0,0 +1,84 @@ +import os +import argparse +import time +from typing import Union, List +import torch +import torch_npu +import numpy as np +import cv2 + +from stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline +from stable_diffusion_xl.unet.unet_model import UNet2DConditionModel + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--path", + type=str, + default='/stable-diffusion-xl-base-1.0', + help="The path of all model weights, such as vae, unet, text_encoder, tokenizer, scheduler", + ) + parser.add_argument( + "--device_id", + type=int, + default=0, + help="NPU device id", + ) + parser.add_argument( + "--dtype", + type=lambda s: getattr(torch, s), + default=torch.float16 + ) + parser.add_argument( + "--prompts", + type=str, nargs="+", + default=["A dog, site on beach."] + ) + parser.add_argument( + "--num_image_per_prompt", + type=int, + default=1 + ) + parser.add_argument( + "--height", + type=int, + default=1024 + ) + parser.add_argument( + "--width", + type=int, + default=1024 + ) + return parser.parse_args() + + +def init_env(device_id: int): + torch.npu.set_device(device_id) + + +def init_pipe(model_path: str, dtype=torch.float16): + unet = UNet2DConditionModel.from_pretrained(os.path.join(model_path, 'unet'), cache_method="agb_cahce") + pipe = StableDiffusionXLPipeline.from_pretrained(model_path, unet=unet) + pipe.to(dtype).to("npu") + return pipe + + +def infer_prompts(pipe, prompts, height=1024, width=1024, num_image=1): + images = pipe( + prompt=prompts, + height=height, + width=width, + num_images_per_prompt=num_image + ).images + for i, img in enumerate(images): + img_bgr = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) + cv2.imwrite(f"{i}.png", img_bgr) + + +if __name__ == "__main__": + args = parse_arguments() + init_env(args.device_id) + pipe = init_pipe(args.path, args.dtype) + prompts = args.prompts + infer_prompts(pipe, prompts, args.height, args.width, args.num_image_per_prompt) diff --git a/MindIE/MultiModal/StableDiffusion-XL/lora.patch new file mode 100644 index 
0000000000000000000000000000000000000000..fc6bfd26ab89d61473224d4b1406c5425433b114 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/lora.patch @@ -0,0 +1,1016 @@ +--- lora.py 2024-09-03 20:58:26.279828700 +0800 ++++ lora.py 2024-10-07 16:19:52.446325600 +0800 +@@ -36,6 +36,948 @@ + + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name ++KeyOrderList =[ ++ "down_blocks.0.resnets.0.conv1", ++ "down_blocks.0.resnets.0.time_emb_proj", ++ "down_blocks.0.resnets.0.conv2", ++ "down_blocks.0.resnets.1.conv1", ++ "down_blocks.0.resnets.1.time_emb_proj", ++ "down_blocks.0.resnets.1.conv2", ++ "down_blocks.0.downsamplers.0.conv", ++ "down_blocks.1.resnets.0.conv1", ++ "down_blocks.1.resnets.0.time_emb_proj", ++ "down_blocks.1.resnets.0.conv2", ++ "down_blocks.1.resnets.0.conv_shortcut", ++ "down_blocks.1.attentions.0.proj_in", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2", ++ "down_blocks.1.attentions.0.proj_out", ++ "down_blocks.1.resnets.1.conv1", ++ "down_blocks.1.resnets.1.time_emb_proj", ++ "down_blocks.1.resnets.1.conv2", ++ "down_blocks.1.attentions.1.proj_in", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj", ++ "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v", ++ 
"down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj", ++ "down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2", ++ "down_blocks.1.attentions.1.proj_out", ++ "down_blocks.1.downsamplers.0.conv", ++ "down_blocks.2.resnets.0.conv1", ++ "down_blocks.2.resnets.0.time_emb_proj", ++ "down_blocks.2.resnets.0.conv2", ++ "down_blocks.2.resnets.0.conv_shortcut", ++ "down_blocks.2.attentions.0.proj_in", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.0.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.1.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.2.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.2.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.3.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.3.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.3.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_k", ++ 
"down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.4.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.4.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.4.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.5.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.5.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.5.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.6.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.6.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.6.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.7.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.7.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.7.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.8.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.8.ff.net.0.proj", ++ "down_blocks.2.attentions.0.transformer_blocks.8.ff.net.2", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn1.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_q", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_k", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_v", ++ "down_blocks.2.attentions.0.transformer_blocks.9.attn2.to_out.0", ++ "down_blocks.2.attentions.0.transformer_blocks.9.ff.net.0.proj", ++ 
"down_blocks.2.attentions.0.transformer_blocks.9.ff.net.2", ++ "down_blocks.2.attentions.0.proj_out", ++ "down_blocks.2.resnets.1.conv1", ++ "down_blocks.2.resnets.1.time_emb_proj", ++ "down_blocks.2.resnets.1.conv2", ++ "down_blocks.2.attentions.1.proj_in", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.0.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.1.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.1.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.1.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.2.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.2.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.3.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.3.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.3.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.4.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.4.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.4.ff.net.2", 
++ "down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.5.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.5.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.5.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.6.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.6.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.6.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.7.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.7.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.7.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.8.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.8.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.8.ff.net.2", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn1.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_q", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_k", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_v", ++ "down_blocks.2.attentions.1.transformer_blocks.9.attn2.to_out.0", ++ "down_blocks.2.attentions.1.transformer_blocks.9.ff.net.0.proj", ++ "down_blocks.2.attentions.1.transformer_blocks.9.ff.net.2", ++ "down_blocks.2.attentions.1.proj_out", ++ "mid_block.resnets.0.conv1", ++ "mid_block.resnets.0.time_emb_proj", ++ "mid_block.resnets.0.conv2", ++ "mid_block.attentions.0.proj_in", ++ "mid_block.attentions.0.transformer_blocks.0.attn1.to_q", ++ 
"mid_block.attentions.0.transformer_blocks.0.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.0.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.0.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.0.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.0.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.0.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.1.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.1.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.1.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.1.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.1.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.1.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.1.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.2.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.2.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.2.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.2.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.2.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.2.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.2.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.2.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.2.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.2.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.3.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.3.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.3.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.3.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.3.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.3.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.3.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.3.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.3.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.3.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.4.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.4.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.4.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.4.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.4.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.4.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.4.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.4.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.4.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.4.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.5.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.5.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.5.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.5.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.5.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.5.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.5.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.5.attn2.to_out.0", ++ 
"mid_block.attentions.0.transformer_blocks.5.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.5.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.6.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.6.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.6.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.6.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.6.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.6.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.6.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.6.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.6.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.6.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.7.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.7.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.7.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.7.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.7.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.7.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.7.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.7.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.7.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.7.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.8.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.8.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.8.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.8.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.8.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.8.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.8.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.8.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.8.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.8.ff.net.2", ++ "mid_block.attentions.0.transformer_blocks.9.attn1.to_q", ++ "mid_block.attentions.0.transformer_blocks.9.attn1.to_k", ++ "mid_block.attentions.0.transformer_blocks.9.attn1.to_v", ++ "mid_block.attentions.0.transformer_blocks.9.attn1.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.9.attn2.to_q", ++ "mid_block.attentions.0.transformer_blocks.9.attn2.to_k", ++ "mid_block.attentions.0.transformer_blocks.9.attn2.to_v", ++ "mid_block.attentions.0.transformer_blocks.9.attn2.to_out.0", ++ "mid_block.attentions.0.transformer_blocks.9.ff.net.0.proj", ++ "mid_block.attentions.0.transformer_blocks.9.ff.net.2", ++ "mid_block.attentions.0.proj_out", ++ "mid_block.resnets.1.conv1", ++ "mid_block.resnets.1.time_emb_proj", ++ "mid_block.resnets.1.conv2", ++ "up_blocks.0.resnets.0.conv1", ++ "up_blocks.0.resnets.0.time_emb_proj", ++ "up_blocks.0.resnets.0.conv2", ++ "up_blocks.0.resnets.0.conv_shortcut", ++ "up_blocks.0.attentions.0.proj_in", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ 
"up_blocks.0.attentions.0.transformer_blocks.0.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.1.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.2.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.2.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.2.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.3.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.3.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.3.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.4.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.4.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.4.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.5.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.5.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.5.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn1.to_out.0", ++ 
"up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.6.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.6.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.6.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.7.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.7.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.7.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.8.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.8.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.8.ff.net.2", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn1.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_q", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_k", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_v", ++ "up_blocks.0.attentions.0.transformer_blocks.9.attn2.to_out.0", ++ "up_blocks.0.attentions.0.transformer_blocks.9.ff.net.0.proj", ++ "up_blocks.0.attentions.0.transformer_blocks.9.ff.net.2", ++ "up_blocks.0.attentions.0.proj_out", ++ "up_blocks.0.resnets.1.conv1", ++ "up_blocks.0.resnets.1.time_emb_proj", ++ "up_blocks.0.resnets.1.conv2", ++ "up_blocks.0.resnets.1.conv_shortcut", ++ "up_blocks.0.attentions.1.proj_in", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.0.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_k", ++ 
"up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.1.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.2.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.2.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.2.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.3.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.3.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.3.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.4.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.4.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.4.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.5.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.5.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.5.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.6.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.6.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.6.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_q", ++ 
"up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.7.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.7.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.7.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.8.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.8.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.8.ff.net.2", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn1.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_q", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_k", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_v", ++ "up_blocks.0.attentions.1.transformer_blocks.9.attn2.to_out.0", ++ "up_blocks.0.attentions.1.transformer_blocks.9.ff.net.0.proj", ++ "up_blocks.0.attentions.1.transformer_blocks.9.ff.net.2", ++ "up_blocks.0.attentions.1.proj_out", ++ "up_blocks.0.resnets.2.conv1", ++ "up_blocks.0.resnets.2.time_emb_proj", ++ "up_blocks.0.resnets.2.conv2", ++ "up_blocks.0.resnets.2.conv_shortcut", ++ "up_blocks.0.attentions.2.proj_in", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.0.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.1.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_v", ++ 
"up_blocks.0.attentions.2.transformer_blocks.2.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.2.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.2.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.3.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.3.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.3.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.4.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.4.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.4.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.5.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.5.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.5.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.6.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.6.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.6.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.7.attn2.to_out.0", ++ 
"up_blocks.0.attentions.2.transformer_blocks.7.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.7.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.8.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.8.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.8.ff.net.2", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn1.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_q", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_k", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_v", ++ "up_blocks.0.attentions.2.transformer_blocks.9.attn2.to_out.0", ++ "up_blocks.0.attentions.2.transformer_blocks.9.ff.net.0.proj", ++ "up_blocks.0.attentions.2.transformer_blocks.9.ff.net.2", ++ "up_blocks.0.attentions.2.proj_out", ++ "up_blocks.0.upsamplers.0.conv", ++ "up_blocks.1.resnets.0.conv1", ++ "up_blocks.1.resnets.0.time_emb_proj", ++ "up_blocks.1.resnets.0.conv2", ++ "up_blocks.1.resnets.0.conv_shortcut", ++ "up_blocks.1.attentions.0.proj_in", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2", ++ "up_blocks.1.attentions.0.proj_out", ++ "up_blocks.1.resnets.1.conv1", ++ "up_blocks.1.resnets.1.time_emb_proj", ++ "up_blocks.1.resnets.1.conv2", ++ "up_blocks.1.resnets.1.conv_shortcut", ++ "up_blocks.1.attentions.1.proj_in", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q", ++ 
"up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2", ++ "up_blocks.1.attentions.1.proj_out", ++ "up_blocks.1.resnets.2.conv1", ++ "up_blocks.1.resnets.2.time_emb_proj", ++ "up_blocks.1.resnets.2.conv2", ++ "up_blocks.1.resnets.2.conv_shortcut", ++ "up_blocks.1.attentions.2.proj_in", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2", ++ "up_blocks.1.attentions.2.proj_out", ++ "up_blocks.1.upsamplers.0.conv", ++ "up_blocks.2.resnets.0.conv1", ++ "up_blocks.2.resnets.0.time_emb_proj", ++ "up_blocks.2.resnets.0.conv2", ++ "up_blocks.2.resnets.0.conv_shortcut", ++ "up_blocks.2.resnets.1.conv1", ++ "up_blocks.2.resnets.1.time_emb_proj", ++ "up_blocks.2.resnets.1.conv2", ++ "up_blocks.2.resnets.1.conv_shortcut", ++ "up_blocks.2.resnets.2.conv1", ++ "up_blocks.2.resnets.2.time_emb_proj", ++ "up_blocks.2.resnets.2.conv2", ++ "up_blocks.2.resnets.2.conv_shortcut" ++] ++ ++UnetSkip_key = [ ++ "down_blocks.0.resnets.0.conv1", ++ "down_blocks.0.resnets.0.time_emb_proj", ++ "down_blocks.0.resnets.0.conv2", ++ "down_blocks.0.resnets.1.conv1", ++ "down_blocks.0.resnets.1.time_emb_proj", ++ "down_blocks.0.resnets.1.conv2", ++ "down_blocks.0.downsamplers.0.conv", ++ "down_blocks.1.resnets.0.conv1", ++ "down_blocks.1.resnets.0.time_emb_proj", ++ "down_blocks.1.resnets.0.conv2", ++ "down_blocks.1.resnets.0.conv_shortcut", ++ "down_blocks.1.attentions.0.proj_in", ++ 
"down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v", ++ "down_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "down_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "down_blocks.1.attentions.0.transformer_blocks.1.ff.net.2", ++ "down_blocks.1.attentions.0.proj_out", ++ "down_blocks.1.resnets.1.conv1", ++ "down_blocks.1.resnets.1.time_emb_proj", ++ "down_blocks.1.resnets.1.conv2", ++ "down_blocks.1.attentions.1.proj_in", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj", ++ "down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v", ++ "down_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0", ++ "down_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj", ++ "down_blocks.1.attentions.1.transformer_blocks.1.ff.net.2", ++ "down_blocks.1.attentions.1.proj_out", ++ "up_blocks.1.resnets.0.conv1", ++ "up_blocks.1.resnets.0.time_emb_proj", ++ "up_blocks.1.resnets.0.conv2", ++ "up_blocks.1.resnets.0.conv_shortcut", ++ "up_blocks.1.attentions.0.proj_in", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v", ++ 
"up_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.1.attentions.0.transformer_blocks.0.ff.net.2", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_q", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_k", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_v", ++ "up_blocks.1.attentions.0.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.1.attentions.0.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.1.attentions.0.transformer_blocks.1.ff.net.2", ++ "up_blocks.1.attentions.0.proj_out", ++ "up_blocks.1.resnets.1.conv1", ++ "up_blocks.1.resnets.1.time_emb_proj", ++ "up_blocks.1.resnets.1.conv2", ++ "up_blocks.1.resnets.1.conv_shortcut", ++ "up_blocks.1.attentions.1.proj_in", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.1.attentions.1.transformer_blocks.0.ff.net.2", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_q", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_k", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_v", ++ "up_blocks.1.attentions.1.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.1.attentions.1.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.1.attentions.1.transformer_blocks.1.ff.net.2", ++ "up_blocks.1.attentions.1.proj_out", ++ "up_blocks.1.resnets.2.conv1", ++ "up_blocks.1.resnets.2.time_emb_proj", ++ "up_blocks.1.resnets.2.conv2", ++ "up_blocks.1.resnets.2.conv_shortcut", ++ "up_blocks.1.attentions.2.proj_in", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.0.attn2.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.0.proj", ++ "up_blocks.1.attentions.2.transformer_blocks.0.ff.net.2", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_q", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn1.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_q", ++ 
"up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_k", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_v", ++ "up_blocks.1.attentions.2.transformer_blocks.1.attn2.to_out.0", ++ "up_blocks.1.attentions.2.transformer_blocks.1.ff.net.0.proj", ++ "up_blocks.1.attentions.2.transformer_blocks.1.ff.net.2", ++ "up_blocks.1.attentions.2.proj_out", ++ "up_blocks.1.upsamplers.0.conv", ++ "up_blocks.2.resnets.0.conv1", ++ "up_blocks.2.resnets.0.time_emb_proj", ++ "up_blocks.2.resnets.0.conv2", ++ "up_blocks.2.resnets.0.conv_shortcut", ++ "up_blocks.2.resnets.1.conv1", ++ "up_blocks.2.resnets.1.time_emb_proj", ++ "up_blocks.2.resnets.1.conv2", ++ "up_blocks.2.resnets.1.conv_shortcut", ++ "up_blocks.2.resnets.2.conv1", ++ "up_blocks.2.resnets.2.time_emb_proj", ++ "up_blocks.2.resnets.2.conv2", ++ "up_blocks.2.resnets.2.conv_shortcut" ++] + + + def text_encoder_attn_modules(text_encoder): +@@ -295,6 +1237,7 @@ + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer ++ self.status = False + + def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + self.lora_layer = lora_layer +@@ -352,17 +1295,22 @@ + self.w_down = None + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: +- if self.lora_layer is None: +- # make sure to the functional Conv2D function as otherwise torch.compile's graph will break +- # see: https://github.com/huggingface/diffusers/pull/4315 ++ if self.status: + return F.conv2d( +- hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ++ hidden_states, self.mindie_buffer, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + else: +- original_outputs = F.conv2d( +- hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups +- ) +- return original_outputs + (scale * self.lora_layer(hidden_states)) ++ if self.lora_layer is None: ++ # make sure to the functional Conv2D function as otherwise torch.compile's graph will break ++ # see: https://github.com/huggingface/diffusers/pull/4315 ++ return F.conv2d( ++ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ++ ) ++ else: ++ original_outputs = F.conv2d( ++ hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ++ ) ++ return original_outputs + (scale * self.lora_layer(hidden_states)) + + + class LoRACompatibleLinear(nn.Linear): +@@ -373,6 +1321,7 @@ + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer ++ self.status = False + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer +@@ -426,9 +1375,13 @@ + self.w_down = None + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: +- if self.lora_layer is None: +- out = super().forward(hidden_states) ++ if self.status: ++ out = torch.nn.functional.linear(hidden_states, self.mindie_buffer, self.bias) + return out + else: +- out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) +- return out ++ if self.lora_layer is None: ++ out = super().forward(hidden_states) ++ return out ++ else: ++ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) ++ return out diff --git a/MindIE/MultiModal/StableDiffusion-XL/lorahot_score.py 
b/MindIE/MultiModal/StableDiffusion-XL/lorahot_score.py new file mode 100644 index 0000000000000000000000000000000000000000..86cfb80c24e6e1760073e8e66b354c4c191999e3 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/lorahot_score.py @@ -0,0 +1,130 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import time +import argparse +import logging + +import open_clip +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F + +printlog = logging.getLogger() +printlog.addHandler(logging.StreamHandler()) +printlog.setLevel(logging.INFO) + + +# single image +def cos_similarity(model_clip, preprocess, image_file1, image_file2, device): + img1 = preprocess(Image.open(image_file1[0])).unsqueeze(0).to(device) + img2 = preprocess(Image.open(image_file2[0])).unsqueeze(0).to(device) + + img_ft1 = model_clip.encode_image(img1).float() + img_ft2 = model_clip.encode_image(img2).float() + + score = F.cosine_similarity(img_ft1, img_ft2).squeeze() + + return score.cpu() + + +def main(): + args = parse_arguments() + + if args.device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(args.device) + + t_b = time.time() + printlog.info("Load clip model...") + model_clip, _, preprocess = open_clip.create_model_and_transforms( + args.model_name, pretrained=args.model_weights_path, device=device) + model_clip.eval() + printlog.info(f">done. elapsed time: {(time.time() - t_b):.3f} s") + + with os.fdopen(os.open(args.image_info_wo_lorahot, os.O_RDONLY), "r") as f: + image_info_wo_lorahot = json.load(f) + + with os.fdopen(os.open(args.image_info_lorahot, os.O_RDONLY), "r") as f: + image_info_lorahot = json.load(f) + + t_b = time.time() + printlog.info("Calc cos similarity score...") + all_scores = [] + info_length = len(image_info_wo_lorahot) + for i in range(info_length): + + image_file1 = image_info_wo_lorahot[i]['images'] + image_file2 = image_info_lorahot[i]['images'] + prompt = image_info_wo_lorahot[i]['prompt'] + printlog.info(f"[{i + 1}/{len(image_info_wo_lorahot)}] {prompt}") + + image_scores = cos_similarity(model_clip, + preprocess, + image_file1, + image_file2, + device) + + printlog.info(f"cos similarity scores: {image_scores}") + + all_scores.append(image_scores.item()) + printlog.info(f">done. 
elapsed time: {(time.time() - t_b):.3f} s") + + average_score = np.average(all_scores) + printlog.info("====================================") + printlog.info(f"average score: {average_score:.3f}") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="device for torch.", + ) + parser.add_argument( + "--image_info_wo_lorahot", + type=str, + default="./image_info_wo_lorahot.json", + help="Image_info_wo_lorahot.json file.", + ) + parser.add_argument( + "--image_info_lorahot", + type=str, + default="./image_info_lorahot.json", + help="Image_info_lorahot.json file.", + ) + parser.add_argument( + "--model_name", + type=str, + default="ViT-H-14", + help="open clip model name", + ) + parser.add_argument( + "--model_weights_path", + type=str, + default="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + help="open clip model weights", + ) + return parser.parse_args() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/prompts.txt b/MindIE/MultiModal/StableDiffusion-XL/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..a375a0bb63931d0d5da6c6d91df1e14f870f47d0 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/prompts.txt @@ -0,0 +1,16 @@ +Beautiful illustration of The ocean. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Islands in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Seaports in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The waves. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Grassland. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Wheat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Hut Tong. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The boat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Pine trees. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Bamboo. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The temple. 
in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Cloud in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Sun in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Spring. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Lotus. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Snow piles. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/quant/CMakeLists.txt b/MindIE/MultiModal/StableDiffusion-XL/quant/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a79cb7131ad18cab1dec02453de917944ee54ca6 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/quant/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +project(MIE_LLM_quant_ops) + +find_package(Torch REQUIRED) + +set(TORCH_ROOT "/usr/local/libtorch2.0.0" CACHE STRING "") +MESSAGE("Torch root : ${TORCH_ROOT}") + +add_library(quant_ops SHARED quant_ops.cpp) + +target_compile_features(quant_ops PRIVATE cxx_std_17) + +target_link_libraries(quant_ops PUBLIC + c10 + torch + torch_cpu) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/quant/build.sh b/MindIE/MultiModal/StableDiffusion-XL/quant/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..d743e97ccad0ba59cc49df6067ffaf7367bfa012 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/quant/build.sh @@ -0,0 +1,8 @@ +rm -r build +mkdir build +cd build + +TorchPath="torch/path/you/should/set" + +cmake .. -DTORCH_ROOT=${TorchPath} -DCMAKE_PREFIX_PATH=${TorchPath} +make -j \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/quant/quant_ops.cpp b/MindIE/MultiModal/StableDiffusion-XL/quant/quant_ops.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b223d32fc6a394dfb7cf714e565c701b39986920 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/quant/quant_ops.cpp @@ -0,0 +1,125 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <torch/torch.h>
+#include <torch/library.h>
+
+#include <ATen/ATen.h>
+
+
+at::Tensor QuantizeTensorPlaceholder(at::Tensor x, at::Tensor scale, at::Tensor offset)
+{
+    auto quant_x = x * scale;
+    quant_x = quant_x + torch::broadcast_tensors({offset, quant_x})[0];
+    quant_x = quant_x.round();
+    quant_x = quant_x.clamp(-128, 127);
+    return quant_x;
+}
+
+at::Tensor QuantizeFloatPlaceholder(at::Tensor x, double scale, double offset)
+{
+    auto quant_x = x * scale + offset;
+    quant_x = quant_x.round();
+    quant_x = quant_x.clamp(-128, 127);
+    return quant_x;
+}
+
+at::Tensor DequantizeTensorPlaceholder(at::Tensor x, at::Tensor scale)
+{
+    auto fp_x = x;
+    auto round_x = fp_x.round();
+    auto dequant_x = round_x.clamp(-128, 127);
+    return dequant_x;
+}
+
+at::Tensor DequantizeFloatPlaceholder(at::Tensor x, double scale)
+{
+    auto fp_x = x;
+    auto round_x = fp_x.round();
+    auto dequant_x = round_x.clamp(-128, 127);
+    return dequant_x;
+}
+
+at::Tensor QuantConvolutionPlaceholder(at::Tensor input, at::Tensor weight, c10::optional<at::Tensor> bias, at::IntArrayRef stride,
+    at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding,
+    int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32)
+{
+    auto fp_weight = weight.to(torch::kFloat);
+    if (bias.has_value()) {
+        auto fp_bias = bias.value().to(torch::kFloat);
+        auto output = torch::_convolution(input, fp_weight, fp_bias, stride, padding, dilation, transposed, output_padding,
+            groups, benchmark, deterministic, cudnn_enabled, allow_tf32);
+        return output;
+    } else {
+        auto output = torch::_convolution(input, fp_weight, bias, stride, padding, dilation, transposed, output_padding, groups,
+            benchmark, deterministic, cudnn_enabled, allow_tf32);
+        return output;
+    }
+}
+
+at::Tensor QuantLinearPlaceholder(at::Tensor input, at::Tensor weight, c10::optional<at::Tensor> bias)
+{
+    auto fp_weight = weight.to(torch::kFloat);
+    if (bias.has_value()) {
+        auto fp_bias = bias.value().to(torch::kFloat);
+        auto output = torch::linear(input, fp_weight, fp_bias);
+        return output;
+    } else {
+        auto output = torch::linear(input, fp_weight, bias);
+        return output;
+    }
+}
+
+// register torchscript quant ops schema to PyTorch
+TORCH_LIBRARY_FRAGMENT(MindIE, m) {
+m.def(TORCH_SELECTIVE_SCHEMA("MindIE::quantize.tensor(Tensor x, Tensor scale, Tensor offset) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("MindIE::quantize.float(Tensor x, float scale, float offset) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("MindIE::dequantize.tensor(Tensor x, Tensor scale) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("MindIE::dequantize.float(Tensor x, float scale) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("MindIE::quant_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding,\n"
+    " int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic,\n"
+    " bool cudnn_enabled, bool allow_tf32) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("MindIE::quant_linear(Tensor input, Tensor weight, Tensor?
bias = None) -> (Tensor)")); +} + +// register CPU kernel function for all_reduce +TORCH_LIBRARY_IMPL(MindIE, CPU, m) { +m.impl( + TORCH_SELECTIVE_NAME("MindIE::quantize.tensor"), + TORCH_FN(QuantizeTensorPlaceholder) +); +m.impl( + TORCH_SELECTIVE_NAME("MindIE::quantize.float"), + TORCH_FN(QuantizeFloatPlaceholder) +); +m.impl( + TORCH_SELECTIVE_NAME("MindIE::dequantize.tensor"), + TORCH_FN(DequantizeTensorPlaceholder) +); +m.impl( + TORCH_SELECTIVE_NAME("MindIE::dequantize.float"), + TORCH_FN(DequantizeFloatPlaceholder) +); +m.impl( + TORCH_SELECTIVE_NAME("MindIE::quant_convolution"), + TORCH_FN(QuantConvolutionPlaceholder) +); +m.impl( + TORCH_SELECTIVE_NAME("MindIE::quant_linear"), + TORCH_FN(QuantLinearPlaceholder) +); +} diff --git a/MindIE/MultiModal/StableDiffusion-XL/quant_utils.py b/MindIE/MultiModal/StableDiffusion-XL/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..840b48fa6c8de5747440037ebed8c6bb8203a25c --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/quant_utils.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn +import numpy as np + + +class Quantize(nn.Module): + def __init__(self, scale, offset): + super(Quantize, self).__init__() + self.offset = offset + self.scale = 1 / scale + + def forward(self, x): + x = torch.ops.MindIE.quantize(x, self.scale.item(), self.offset.item()) + return x + + +class DeQuantize(nn.Module): + def __init__(self, scale): + super(DeQuantize, self).__init__() + self.scale = scale + + def forward(self, x): + x = torch.ops.MindIE.dequantize(x, self.scale) + return x + + +class QuantConvModule(nn.Module): + def __init__(self, layer, input_scale, input_offset, quant_weight, weight_scale, deq_scale): + super(QuantConvModule, self).__init__() + self.input_scale = input_scale + self.input_offset = input_offset + self.quant = Quantize(scale=input_scale, offset=input_offset) + self.weight = torch.nn.Parameter(quant_weight, requires_grad=False) + self.set_bias = False + self.layer = layer + if self.layer.bias is not None: + self.bias = torch.nn.Parameter(torch.round(self.layer.bias / torch.squeeze(input_scale) / torch.squeeze( + weight_scale)).to(torch.int32), requires_grad=False) + else: + self.bias = None + self.de_quant = DeQuantize(deq_scale) + + def forward(self, x, scale: float = 1.0): + x = self.quant(x) + x = torch.ops.MindIE.quant_convolution(x, self.weight, self.bias, self.layer.stride, self.layer.padding, + self.layer.dilation, self.layer.transposed, self.layer.output_padding, + self.layer.groups, False, False, + False, False) + x = self.de_quant(x) + return x + + +class QuantLinearModule(nn.Module): + def __init__(self, layer, input_scale, input_offset, quant_weight, weight_scale, deq_scale): + super(QuantLinearModule, self).__init__() + self.input_scale = input_scale + self.input_offset = input_offset + self.quant = Quantize(scale=input_scale, offset=input_offset) + self.layer = layer + self.weight = torch.nn.Parameter(quant_weight, requires_grad=False) + self.set_bias = False + self.bias = layer.bias + self.de_quant = DeQuantize(deq_scale) + + def forward(self, x, scale: float = 1.0): + x = self.quant(x) + x = torch.ops.MindIE.quant_linear(x, self.weight, self.bias) + x = self.de_quant(x) + return x + + +def modify_model(model, input_scale_dict, input_offset_dict, weight_scale_dict, weight_offset_dict, quant_weight_dict): + for name, layer in model.named_modules(): + if name in input_scale_dict: + if quant_weight_dict[name] is None: + continue + input_scale = input_scale_dict[name] if 
input_scale_dict[name] is not None else torch.Tensor([1.]) + input_offset = input_offset_dict[name] if input_offset_dict[name] is not None else torch.Tensor([0.]) + quant_weight = quant_weight_dict[name].to(torch.int8) + weight_scale = weight_scale_dict[name] + + x_scale = np.array(input_scale) * np.array(weight_scale) + packed_weight_np_data = x_scale.squeeze() + float32_scale_deq = np.array(packed_weight_np_data, np.float32) + uint32_scale_deq = np.frombuffer(float32_scale_deq, np.uint32) + uint64_result = np.zeros(float32_scale_deq.shape, np.int64) + # per-tensor + if len(uint64_result.shape) == 0: + uint64_result = np.expand_dims(uint64_result, axis=0) + uint64_result |= np.int64(uint32_scale_deq) + + deq_scale = torch.Tensor(uint64_result).to(torch.int64) + if isinstance(layer, nn.Conv2d): + quant_module = QuantConvModule(layer, input_scale, input_offset, quant_weight, weight_scale, deq_scale) + elif isinstance(layer, nn.Linear): + + correction = quant_weight.to(torch.float32).sum(dim=1)*input_offset.to(torch.float32) + ori_bias = layer.bias if layer.bias is not None else 0 + int_bias = torch.nn.Parameter(torch.round(ori_bias/torch.Tensor(x_scale)-correction).to(torch.int32), + requires_grad=False) + layer.bias = int_bias + quant_module = QuantLinearModule(layer, input_scale, input_offset, quant_weight, weight_scale, + deq_scale) + else: + continue + + submodules, layer_name = name.split('.')[:-1], name.split('.')[-1] + setattr(model.get_submodule('.'.join(submodules)), layer_name, quant_module) + print(f'converter layer {name} from {layer.__class__.__name__} to {quant_module.__class__.__name__} succ') + return model diff --git a/MindIE/MultiModal/StableDiffusion-XL/requirements.txt b/MindIE/MultiModal/StableDiffusion-XL/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ac639d196ee01ac48fa05b758cf5e3a49939df4d --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/requirements.txt @@ -0,0 +1,6 @@ +setuptools==57.5.0 +torch==2.1.0 +diffusers==0.26.3 +transformers==4.46.0 +open_clip_torch==2.20.0 +onnx==1.15.0 \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_attention_patch.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_attention_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..59bfa9e822472e8740c8ba5d1666afa4b8abd2ec --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_attention_patch.py @@ -0,0 +1,28 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
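+
+# What the script below does: it locates the installed diffusers package, checks that
+# its version is exactly 0.26.3 (the only version the bundled patch targets, per the
+# assert), and then applies attention_processor.patch to models/attention_processor.py
+# in place via the system `patch` tool, which therefore must be available on PATH.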
+ +import os +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.26.3', f"Expected diffusers version 0.26.3, but got {diffusers_version}" + os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch') + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/layers/__init__.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/layers/attention.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d468563964e2ced42056e5b3cbc409130bc957b5 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/layers/attention.py @@ -0,0 +1,420 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional +import inspect + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu + + +def get_npu_device(): + # 默认获取当前设备信息 + device_name = torch_npu.npu.get_device_name() + if "3" in device_name: + return "DUO" + elif "9" in device_name: + return "A2" + else: + return "" + + +soc = get_npu_device() + + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. 
+ bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, eps=eps) + self.norm_k = nn.LayerNorm(dim_head, eps=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). 
The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + if self.context_pre_only is not None: + self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) + + # set attention processor + if processor is None: + processor = ( + AttnProcessor2_0() + ) + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+ ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = attn.inner_dim // attn.heads + + if attention_mask: + attention_mask = ~attention_mask.to(torch.bool) + q_seqlen = query.shape[1] + kv_seqlen = key.shape[1] + + if q_seqlen == kv_seqlen: + # self attention + query = query.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.reshape(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + hidden_states = torch_npu.npu_prompt_flash_attention( + query, key, value, atten_mask=attention_mask, + input_layout='BNSD', scale_value=attn.scale, + pre_tokens=65535, next_tokens=65535, num_heads=attn.heads) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.inner_dim) + else: + # cross attention + query = query.reshape(batch_size, -1, attn.heads, head_dim) + key = key.reshape(batch_size, -1, attn.heads, head_dim) + value = value.reshape(batch_size, -1, attn.heads, head_dim) + if soc == "A2": + hidden_states = torch_npu.npu_fusion_attention( + query, key, value, atten_mask=attention_mask, + input_layout='BSND', scale=attn.scale, + pre_tockens=65535, next_tockens=65535, head_num=attn.heads)[0] + else: + hidden_states 
= torch_npu.npu_prompt_flash_attention( + query, key, value, atten_mask=attention_mask, + input_layout='BSND', scale_value=attn.scale, + pre_tokens=65535, next_tokens=65535, num_heads=attn.heads) + hidden_states = hidden_states.reshape(batch_size, -1, attn.inner_dim) + + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..b186901a59abce00ba28ce5898be4c6ad777d056 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -0,0 +1,1299 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = 
StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. 
+ add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following 
part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = 
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
+ ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
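+    # In the denoising loop later in this file, the two noise predictions are combined
+    # in the standard classifier-free guidance form:
+    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+    # which is why guidance is treated as active only when `guidance_scale > 1`.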
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. 
+ num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. 
Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. 
Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + step=i, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/__init__.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/transformer_2d_model.py 
b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/transformer_2d_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccaa5ce322a5fb51ec984304217d7c73ef65bdf --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/transformer_2d_model.py @@ -0,0 +1,530 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import inspect +from typing import Optional, Dict, Any + +import torch +import torch.nn as nn + +from diffusers.models.modeling_utils import LegacyModelMixin +from diffusers.configuration_utils import LegacyConfigMixin, register_to_config +from .transformer_blocks import BasicTransformerBlock + + +class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. 
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + interpolation_scale: float = None, + use_additional_conditions: Optional[bool] = None, + ): + super().__init__() + # Validate inputs. + if patch_size is not None: + if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # Set some common variables used across the board. + self.use_linear_projection = use_linear_projection + self.interpolation_scale = interpolation_scale + self.caption_channels = caption_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = self.num_attention_heads * self.attention_head_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + + if use_additional_conditions is None: + if norm_type == "ada_norm_single" and sample_size == 128: + use_additional_conditions = True + else: + use_additional_conditions = False + self.use_additional_conditions = use_additional_conditions + + # 2. Initialize the right blocks. + # These functions follow a common structure: + # a. Initialize the input blocks. b. Initialize the transformer blocks. + # c. Initialize the output blocks and other projection blocks when necessary. + if self.is_input_continuous: + self._init_continuous_input(norm_type=norm_type) + elif self.is_input_vectorized: + self._init_vectorized_inputs(norm_type=norm_type) + elif self.is_input_patches: + self._init_patched_inputs(norm_type=norm_type) + + def _init_continuous_input(self, norm_type): + self.norm = torch.nn.GroupNorm( + num_groups=self.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection: + self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) + else: + self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.num_attention_heads, + self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + ) + for _ in range(self.num_layers) + ] + ) + + if self.use_linear_projection: + self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) + else: + self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) + + def _init_vectorized_inputs(self, norm_type): + self.height = self.sample_size + self.width = self.sample_size + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=self.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.num_attention_heads, + self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + 
double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + ) + for _ in range(self.num_layers) + ] + ) + + self.norm_out = nn.LayerNorm(self.inner_dim) + self.out = nn.Linear(self.inner_dim, self.num_vector_embeds - 1) + + def _init_patched_inputs(self, norm_type): + self.height = self.sample_size + self.width = self.sample_size + + self.patch_size = self.patch_size + interpolation_scale = ( + self.interpolation_scale + if self.interpolation_scale is not None + else max(self.sample_size // 64, 1) + ) + self.pos_embed = PatchEmbed( + height=self.sample_size, + width=self.sample_size, + patch_size=self.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.num_attention_heads, + self.attention_head_dim, + dropout=self.dropout, + cross_attention_dim=self.cross_attention_dim, + activation_fn=self.activation_fn, + num_embeds_ada_norm=self.num_embeds_ada_norm, + attention_bias=self.attention_bias, + only_cross_attention=self.only_cross_attention, + double_self_attention=self.double_self_attention, + upcast_attention=self.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.norm_elementwise_affine, + norm_eps=self.norm_eps, + attention_type=self.attention_type, + ) + for _ in range(self.num_layers) + ] + ) + + if self.norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, self.patch_size * self.patch_size * self.out_channels + ) + elif self.norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear( + self.inner_dim, self.patch_size * self.patch_size * self.out_channels + ) + + # PixArt-Alpha blocks. + self.adaln_single = None + if self.norm_type == "ada_norm_single": + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. 
+ + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. 
Input + if self.is_input_continuous: + batch_size, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + ) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + output = self._get_output_for_continuous_inputs( + hidden_states=hidden_states, + residual=residual, + batch_size=batch_size, + height=height, + width=width, + inner_dim=inner_dim, + ) + elif self.is_input_vectorized: + output = self._get_output_for_vectorized_inputs(hidden_states) + elif self.is_input_patches: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + class_labels=class_labels, + embedded_timestep=embedded_timestep, + height=height, + width=width, + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _operate_on_continuous_inputs(self, hidden_states): + batch, _, height, width = hidden_states.shape + hidden_states = self.norm(hidden_states) + + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + return hidden_states, inner_dim + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs): + batch_size = hidden_states.shape[0] + hidden_states = self.pos_embed(hidden_states) + embedded_timestep = None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 
+ ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + return hidden_states, encoder_hidden_states, timestep, embedded_timestep + + def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + return output + + def _get_output_for_vectorized_inputs(self, hidden_states): + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + logits = logits.permute(0, 2, 1) + output = F.log_softmax(logits.double(), dim=1).float() + return output + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None + ): + if self.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + return output \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/transformer_blocks.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/transformer_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..f43e316b5486733a71edac76debfa60c2cda605b --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/transformer_blocks.py @@ -0,0 +1,467 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Dict, Any + +import torch +import torch.nn as nn + +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero, AdaLayerNormContinuous +from diffusers.models.attention import FeedForward +from ..layers.attention import Attention, AttnProcessor2_0 + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. 
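Stepping back from the normalization choice discussed in the comment above: the block as a whole is a pre-norm residual stack of self-attention, optional cross-attention, and feed-forward, which the `forward` method later in this file implements with additional per-`norm_type` branches. A minimal sketch of that structure for the plain `layer_norm` case (the function and variable names below are illustrative, not part of this diff):

```python
def basic_block_sketch(block, hidden_states, encoder_hidden_states):
    # 1. Self-attention: pre-LayerNorm, then residual add.
    hidden_states = block.attn1(block.norm1(hidden_states)) + hidden_states
    # 2. Cross-attention against the text embeddings (if configured): pre-norm + residual.
    if block.attn2 is not None:
        hidden_states = block.attn2(
            block.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
        ) + hidden_states
    # 3. Feed-forward: pre-norm + residual.
    return block.ff(block.norm3(hidden_states)) + hidden_states
```

The adaptive-norm variants (`ada_norm`, `ada_norm_zero`, `ada_norm_single`) keep this same three-stage residual layout but additionally shift, scale, and gate each stage from the timestep (and class) conditioning.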
+ if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + raise ValueError(f"attention_type:{attention_type} is not supported!") + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # agb + self.enable_agb = False + + self.attn_count = 0 + self.last_attn_x = None + self.last_attn_x_before = None + self.attn_alpha_min = float('inf') + self.attn_alpha = 0 + self.attn_alpha_max = -float('inf') + self.last_attn = None + + self.cross_count = 0 + self.last_cross_x = None + self.last_cross_x_before = None + self.cross_alpha_min = float('inf') + self.cross_alpha = 0 + self.cross_alpha_max = -float('inf') + self.last_cross = None + + self.bound = [10, 2] + self.forcefresh = 6 + self.totalstep = 50 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + batch_size = hidden_states.shape[0] + # 0. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + # 1. 
Self-Attention + if self.enable_agb: + if (self.bound[0] - 0.5 < self.attn_count < self.totalstep - self.bound[1] - 0.5) and \ + (self.attn_count % self.forcefresh != 0): + broadcast_attn = True + else: + broadcast_attn = False + self.attn_count = (self.attn_count + 1) % self.totalstep + self.last_attn_x = hidden_states + if broadcast_attn: + hidden_states = self.last_attn + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + else: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + raise ValueError("'gligen' is not supported now!") + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output2 = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output2 + hidden_states + + # 4. 
Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + self.last_attn = attn_output + attn_output2 + ff_output + else: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + raise ValueError("'gligen' is not supported now!") + + # 3. 
Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + # i2vgen doesn't have this norm 🤷‍♂️ + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/unet_blocks.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..de030f72306432e4ab2ae41c851b7bdaa1e4986e --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/unet_blocks.py @@ -0,0 +1,579 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
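+from diffusers.utils import logging
+
+# Module-level logger: the block factory functions and the mid-block forward below emit
+# their configuration warnings through it (standard diffusers logging helper).
+logger = logging.get_logger(__name__)
+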
+from typing import Optional, Union, Tuple, Dict, Any +import torch +import torch.nn.functional as F +import torch.nn as nn + +from diffusers.models.resnet import ResnetBlock2D +from diffusers.models.unets.unet_2d_blocks import UpBlock2D, DownBlock2D +from diffusers.models.downsampling import Downsample2D +from diffusers.models.upsampling import Upsample2D +from .transformer_2d_model import Transformer2DModel + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = 
False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, # now is not used. 
+ use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + raise ValueError("DualTransformer2DModel is not supported now.") + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + additional_residuals: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + 
cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + raise ValueError("DualTransformer2DModel is not supported now.") + resnets.append( + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, # will not used in inference mode + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + raise ValueError("DualTransformer2DModel is not supported now.") + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, 
res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/unet_model.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/unet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..947c3f0319754bee6d136666ba19c3c813d443e6
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusion_xl/unet/unet_model.py
@@ -0,0 +1,1242 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union, Dict, Any
+import torch
+import torch.nn as nn
+import torch_npu
+
+from diffusers.models.embeddings import (
+    Timesteps,
+    TimestepEmbedding,
+    TextImageProjection,
+    ImageProjection,
+    TextTimeEmbedding,
+    TextImageTimeEmbedding,
+    ImageTimeEmbedding,
+    ImageHintTimeEmbedding,
+    GaussianFourierProjection,
+    GLIGENTextBoundingboxProjection
+)
+
+from diffusers.models.activations import get_activation
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+
+from .unet_blocks import (
+    get_down_block,
+    get_mid_block,
+    get_up_block
+)
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    The output of [`UNet2DConditionModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.Tensor = None
+
+
+class UNet2DConditionModel(
+    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
+):
+    r"""
+    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+    shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+            The tuple of downsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for the middle of the UNet; it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+            The tuple of upsample blocks to use.
+        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+            Whether to include self-attention in the basic transformer blocks, see
+            [`~models.attention.BasicTransformerBlock`].
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
+        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to None):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
+            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        encoder_hid_dim (`int`, *optional*, defaults to None):
+            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+            dimension to `cross_attention_dim`.
+        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. 
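+        cache_method (`str`, *optional*, defaults to `None`):
+            Inference-time cache strategy of this NPU port. `"agb_cache"` turns on the `enable_agb` flag of every
+            transformer block so self-attention results can be reused across neighbouring denoising steps;
+            `"static_cache"` sets `enable_unet_cache` and prepares `cache`/`cache_step` for reuse of intermediate
+            UNet features.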
+ """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + cache_method: str = None, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. 
+ num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + 
num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def enable_agb_cache(module): + for child in module.children(): + if hasattr(child, "enable_agb"): + child.enable_agb = True + if len(list(child.children())) > 0: + enable_agb_cache(child) + + if cache_method == "agb_cache": + enable_agb_cache(self.down_blocks) + enable_agb_cache(self.mid_block) + enable_agb_cache(self.up_blocks) + self.enable_unet_cache = cache_method == "static_cache" + self.cache = None + self.cache_step = [1, 2, 4, 6, 7, 9, 10, 12, 13, 14, 16, 18, 19, 21, 23, 24, 26, 27, 29, \ + 30, 31, 33, 34, 36, 37, 39, 40, 42, 43, 45, 47, 48, 49] + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 
+ ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
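+            # Cast so the class embedding matches the dtype of `sample` (the projection above always returns fp32).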
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = 
self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + **kwargs + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. 
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. 
pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + if self.enable_unet_cache: + step = kwargs.get("step", 0) + if len(self.cache_step) > 0 and (step + 1) not in self.cache_step: + for block_id, downsample_block in enumerate(self.down_blocks): + if block_id >= 2: # skip last block + break + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + sample = self.cache.detach() + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + if i == 1: + res_samples = down_block_res_samples[-4:-1] + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + if i == 2: + res_samples = down_block_res_samples[:3] + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + else: + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + if i == 0: + self.cache = sample + else: + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_lora_patch.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_lora_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..93dbbb83ab7ca3465f343cd7af2673de47c6a8c6 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_lora_patch.py @@ -0,0 +1,31 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import diffusers +import logging + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + if diffusers_version != '0.26.3': + logging.error("patch error! diffusers_version does not equal to 0.26.3") + os.system(f'patch -p0 {diffusers_path[0]}/models/lora.py lora.patch') + os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_lora.patch') + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_pipeline.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d170b50f1645573e9c21f1c3c2cc174e4249dc --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_pipeline.py @@ -0,0 +1,1196 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import csv +import json +import os +import time +from typing import Callable, List, Optional, Union +import numpy as np +import logging + +import torch +import mindietorch +from diffusers import StableDiffusionXLPipeline +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.schedulers import * +from quant_utils import modify_model +from safetensors.torch import load_file +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + elif prompt_file_type == 'hpsv2': + self.load_prompts_hpsv2(max_num_prompts) + else: + print("This operation is not supported!") + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + def load_prompts_hpsv2(self, max_num_prompts: int): + with open('hpsv2_benchmark_prompts.json', 'r') as file: + all_prompts = json.load(file) + count = 0 + for style, prompts in all_prompts.items(): + for prompt in prompts: + count 
+= 1 + if max_num_prompts and count >= max_num_prompts: + break + + if style not in self.catagories: + self.catagories.append(style) + + catagory_id = self.catagories.index(style) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusionXLPipeline(StableDiffusionXLPipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0 = self.args.device[0] + else: + self.device_0 = args.device + self.data = None + if self.args.save_unet_input: + self.data = { 'use_cache':self.args.use_cache, 'parallel':isinstance(self.args.device, list)} + + def compile_aie_model(self): + if self.is_init: + return + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + + if self.args.flag in [0, 1, 3]: + if self.args.flag == 0: + tail = f"_static_{self.args.height}x{self.args.width}" + elif self.args.flag == 1: + tail = "" + else: + tail = f"_quant_{self.args.height}x{self.args.width}" + + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim_bs{batch_size}_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + if not self.args.use_cache: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + if self.args.use_cache: + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile_1{tail}.ts") + self.compiled_unet_model_skip = torch.jit.load(unet_skip_compiled_path).eval() + + unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile_0{tail}.ts") + self.compiled_unet_model_cache = torch.jit.load(unet_cache_compiled_path).eval() + elif self.args.flag == 2: + tail = "_dynamic" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + if not self.args.use_cache: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + if self.args.use_cache: + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_compile_1{tail}.ts") + self.compiled_unet_model_skip = torch.jit.load(unet_skip_compiled_path).eval() + + unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_compile_0{tail}.ts") + self.compiled_unet_model_cache = 
torch.jit.load(unet_cache_compiled_path).eval()
+        self.is_init = True
+
+    def encode_prompt(
+        self,
+        prompt,
+        prompt_2,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+        negative_prompt_2,
+        lora_scale,
+        clip_skip
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.compiled_clip_model, self.compiled_clip_model_2] if self.compiled_clip_model is not None
+            else [self.compiled_clip_model_2]
+        )
+
+        prompt_2 = prompt_2 or prompt
+        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+        # textual inversion: process multi-vector tokens if necessary
+        prompt_embeds_list = []
+        prompts = [prompt, prompt_2]
+        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, tokenizer)
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1])
+
+            # We are only ALWAYS interested in the pooled output of the final text encoder
+            global clip_time
+            start = time.time()
+            prompt_embeds_npu = text_encoder(text_input_ids.to(f'npu:{self.device_0}'))
+
+            pooled_prompt_embeds = prompt_embeds_npu[0].to('cpu')
+            clip_time += time.time() - start
+
+            if clip_skip is None:
+                prompt_embeds = prompt_embeds_npu[2][-2].to('cpu')
+
+            else:
+                # "2" because SDXL always indexes from the penultimate layer. (to be confirmed)
+                prompt_embeds = prompt_embeds_npu.hidden_states[-(clip_skip + 2)]
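+            # How the compiled CLIP outputs are indexed above (a convention inferred from this
+            # code rather than a documented API): element 0 of the returned tuple is the pooled
+            # text embedding, and element 2 holds the per-layer hidden states, whose penultimate
+            # entry (-2) is what feeds the UNet cross-attention when `clip_skip` is unset.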
+ + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(f'npu:{self.device_0}'))[0].to('cpu') + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_prompt_embeds = [torch.from_numpy(text) for text in negative_prompt_embeds] + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + 
bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def ascendie_infer( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[dict[str, any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[tuple[int, int]] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: Optional[tuple[int, int]] = None, + negative_original_size: Optional[tuple[int, int]] = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: Optional[tuple[int, int]] = None, + clip_skip: Optional[int] = None, + skip_steps=None, + flag_ddim: int = None, + flag_cache: int = None, + use_lora_hotswitch=False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. 
The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + """ + global p1_time, p2_time, p3_time + start = time.time() + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + # 3. Encode input prompt + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + p1_time += time.time() - start + start1 = time.time() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + generator=torch.Generator("cpu").manual_seed(1) + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=self.text_encoder_2.config.projection_dim + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + # 8.1 Apply denoising_end + if ( + denoising_end is not None + and isinstance(denoising_end, float) + and denoising_end > 0 + and denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + cache = None + global unet_time + global vae_time + + skip_flag = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + if latents.is_cpu: + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t).to(f'npu:{self.device_0}') + + start = time.time() + if flag_cache: + if skip_steps[i]: + if self.data is not None and 'skip' not in self.data: + self.data['skip'] = (latent_model_input.to('cpu'), + t.to(torch.int64)[None].to('cpu'), + prompt_embeds.to('cpu'), + add_text_embeds.to('cpu'), + add_time_ids.to('cpu'), + skip_flag.to('cpu'), + cache.to('cpu')) + unet_input_skip = [ + latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + skip_flag.to(f'npu:{self.device_0}'), + cache, + ] + if use_lora_hotswitch: + unet_input_skip = unet_input_skip + [torch.tensor([])] * 149 + noise_pred = self.compiled_unet_model_skip(*unet_input_skip) + else: + if self.data is not None and 'cache' not in self.data: + self.data['cache'] = (latent_model_input.to('cpu'), + t.to(torch.int64)[None].to('cpu'), + prompt_embeds.to('cpu'), + add_text_embeds.to('cpu'), + add_time_ids.to('cpu'), + cache_flag.to('cpu')) + unet_input_cache = [ + latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + cache_flag.to(f'npu:{self.device_0}'), + ] + if use_lora_hotswitch: + unet_input_cache = unet_input_cache + [torch.tensor([])] * 788 + outputs = self.compiled_unet_model_cache(*unet_input_cache) + noise_pred = outputs[0] + cache = outputs[1] + else: + if self.data is 
not None and 'no_cache' not in self.data: + self.data['no_cache'] = (latent_model_input.to('cpu'), + t.to(torch.int64)[None].to('cpu'), + prompt_embeds.to('cpu'), + add_text_embeds.to('cpu'), + add_time_ids.to('cpu')) + unet_input = [ + latent_model_input, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + prompt_embeds.to(f'npu:{self.device_0}'), + add_text_embeds.to(f'npu:{self.device_0}'), + add_time_ids.to(f'npu:{self.device_0}'), + ] + if use_lora_hotswitch: + unet_input = unet_input + [torch.tensor([])] * 788 + noise_pred = self.compiled_unet_model(*unet_input) + unet_time += time.time() - start + + # perform guidance + if do_classifier_free_guidance: + if flag_ddim: + noise_pred = noise_pred.to('cpu') + x = np.array(i, dtype=np.int64) + y = torch.from_numpy(x).long() + + latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + t[None].to(torch.int64).to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + y[None].to(f'npu:{self.device_0}')).to('cpu') + else: + noise_pred = noise_pred.to('cpu') + noise_pred, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred + guidance_scale * (noise_pred_text - + noise_pred) + + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + p2_time = time.time() - start1 + start2 = time.time() + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents / self.vae.config.scaling_factor + + start = time.time() + latents = self.vae.post_quant_conv(latents) + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + vae_time += time.time() - start + # image = image.unsqueeze(0) + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if output_type == "pil": + image = self.numpy_to_pil(image) + + p3_time += time.time() - start2 + return (image, None) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. 
valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--scheduler", + choices=["DDIM", "Euler", "DPM", "SA-Solver", "EulerAncestral", "DPM++SDEKarras"], + default="DDIM", + help="Type of Sampling methods. Can choose from DDIM, Euler, DPM, SA-Solver", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--cache_steps", + type=str, + default="1,2,4,6,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\ + 30,31,33,34,36,37,39,40,42,43,45,47,48,49", # 17+33 + help="Steps to use cache data." + ) + parser.add_argument( + "--flag", + choices=[0, 1, 2, 3], + default=0, + type=int, + help="0 is static; 1 is dynami dims; 2 is dynamic range; 3 is quant", + ) + parser.add_argument( + "--height", + default=1024, + type=int, + help="image height", + ) + parser.add_argument( + "--width", + default=1024, + type=int, + help="image width" + ) + parser.add_argument( + "--save_unet_input", + action="store_true", + help="save unet input for quant." + ) + parser.add_argument( + "--quant", + action="store_true", + help="use quantize unet." 
+ ) + parser.add_argument( + "--use_loraHotswitch", + action="store_true", + help="use lora hot switch function" + ) + parser.add_argument( + "--lorabase_weight", + type=str, + default="./baseLoraPath/saveTensor.pt", + help="base lora weight path.", + ) + parser.add_argument( + "--loranew_weight", + type=str, + default="./newLoraPath/lora.pt", + help="new lora weight path.", + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + if args.quant: + torch.ops.load_library("./quant/build/libquant_ops.so") + save_dir = args.save_dir + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusionXLPipeline.from_pretrained(args.model).to("cpu") + + flag_ddim = 0 + if args.scheduler == "DDIM": + flag_ddim = 1 + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "Euler": + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "DPM": + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "SA-Solver": + pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "EulerAncestral": + pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "DPM++SDEKarras": + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.scheduler.config.algorithm_type = 'sde-dpmsolver++' + pipe.scheduler.config.use_karras_sigmas = True + + mindietorch.set_device(args.device) + pipe.parser_args(args) + pipe.compile_aie_model() + skip_steps = [0] * args.steps + flag_cache = 0 + if args.use_cache: + flag_cache = 1 + for i in args.cache_steps.split(','): + if not i.isdigit() or int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + + if args.use_loraHotswitch: + from diffusers.models.lora import KeyOrderList, UnetSkip_key + # first combine base model and lora model into new + base_model = torch.load(args.lorabase_weight) + new_model = load_file(args.loranew_weight) + fusionweight = dict() + visited = [] + for name in new_model.keys(): + # one circle handle a pair key and skip .alpha key + if ".alpha" in name or name in visited: + continue + # for sdxl,lora hot switch is supported for unet + curr_layer = pipe.unet + layer_infos = name.split(".")[0].split("lora_unet_")[-1].replace('_', '.').split(".") + temp_name = layer_infos.pop(0) + desstr = "" + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + desstr = desstr + temp_name + "." 
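+                        # Name resolution walk: a hypothetical key such as
+                        # "lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight" is split into
+                        # underscore-separated pieces, each piece is tried as an attribute of the UNet, and the
+                        # except-branch below glues pieces back together with '_' when lookup fails, because real
+                        # module names (e.g. "down_blocks") themselves contain underscores. `desstr` accumulates
+                        # the resolved path used as the lookup key into `base_model` and `fusionweight`.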
+                        temp_name = layer_infos.pop(0)
+                    elif len(layer_infos) == 0:
+                        desstr = desstr + temp_name
+                        break
+                except Exception:
+                    temp_name = temp_name + "_" + layer_infos.pop(0)
+            pair_keys = []
+            if "lora_down" in name:
+                pair_keys.append(name.replace("lora_down", "lora_up"))
+                pair_keys.append(name)
+                pair_keys.append(name.replace("lora_down.weight", "alpha"))
+            else:
+                pair_keys.append(name)
+                pair_keys.append(name.replace("lora_up", "lora_down"))
+                pair_keys.append(name.replace("lora_up.weight", "alpha"))
+
+            # the weight fusion differs by layer type
+            # prepare base weight
+            base_weight = base_model[desstr].to(torch.float16)
+            # prepare lora weight
+            lora_up_weight = new_model[pair_keys[0]].to(torch.float16)
+            lora_down_weight = new_model[pair_keys[1]].to(torch.float16)
+            # determine the ratio
+            if new_model[pair_keys[2]] is None:
+                ratio = 1.0
+            else:
+                alpha = new_model[pair_keys[2]].item()
+                ratio = alpha / min(new_model[pair_keys[0]].shape[0], new_model[pair_keys[1]].shape[1])
+            if isinstance(curr_layer, LoRACompatibleConv):
+                # fuse the down and up projections
+                fusionupdown = torch.mm(lora_up_weight.flatten(start_dim=1), lora_down_weight.flatten(start_dim=1))
+                fusionupdown = fusionupdown.reshape(base_weight.shape)
+                # main path + scaled LoRA bypass
+                fusionweight[desstr] = base_weight + ratio * fusionupdown
+            elif isinstance(curr_layer, LoRACompatibleLinear):
+                fusion = ratio * torch.bmm(lora_up_weight[None, :], lora_down_weight[None, :])[0]
+                fusionweight[desstr] = base_weight + fusion
+            for item in pair_keys:
+                visited.append(item)
+        # feed the fused weights to the compiled UNet graphs in their expected key order
+        if args.use_cache:
+            # skip model
+            input_skip = [
+                torch.ones(2, 4, 128, 128).to("npu"),
+                torch.ones(1).to(torch.long).to("npu"),
+                torch.ones(2, 77, 2048).to("npu"),
+                torch.ones(2, 1280).to("npu"),
+                torch.ones(2, 6).to("npu"),
+                torch.ones([1], dtype=torch.long).to("npu"),
+                torch.ones(2, 1280, 64, 64).to("npu")
+            ]
+            for name in UnetSkip_key:
+                try:
+                    input_skip.append(fusionweight[name].to(torch.float16).to("npu"))
+                except KeyError:
+                    logging.error('can not find UnetSkip_key key name:%s in fusionweight', name)
+                    return
+
+            outskip = pipe.compiled_unet_model_skip(*input_skip)
+            outskip.to("cpu")
+            # cache model
+            input_cache = [
+                torch.ones(2, 4, 128, 128).to("npu"),
+                torch.ones(1).to(torch.long).to("npu"),
+                torch.ones(2, 77, 2048).to("npu"),
+                torch.ones(2, 1280).to("npu"),
+                torch.ones(2, 6).to("npu"),
+                torch.ones([1], dtype=torch.long).to("npu")
+            ]
+            for name in KeyOrderList:
+                try:
+                    input_cache.append(fusionweight[name].to(torch.float16).to("npu"))
+                except KeyError:
+                    logging.error('can not find keyorderlist key name:%s in fusionweight', name)
+                    return
+
+            outcache = pipe.compiled_unet_model_cache(*input_cache)
+            outcache.to("cpu")
+
+        else:
+            input_update = [
+                torch.ones(2, 4, 128, 128).to("npu"),
+                torch.ones(1).to(torch.long).to("npu"),
+                torch.ones(2, 77, 2048).to("npu"),
+                torch.ones(2, 6).to("npu") if False else torch.ones(2, 1280).to("npu"),
+                torch.ones(2, 6).to("npu")
+            ]
+            for name in KeyOrderList:
+                try:
+                    input_update.append(fusionweight[name].to(torch.float16).to("npu"))
+                except KeyError:
+                    logging.error('can not find keyorderlist key name:%s in fusionweight', name)
+                    return
+
+            output = pipe.compiled_unet_model(*input_update)
+            output.to("cpu")
+
+    use_time = 0
+    prompt_loader = PromptLoader(args.prompt_file,
+                                 args.prompt_file_type,
+                                 args.batch_size,
+                                 args.num_images_per_prompt,
+                                 args.max_num_prompts)
+
+    prompts_2 = ""
+    infer_num = 0
+    image_info = []
+    current_prompt = None
+    for i, input_info in enumerate(prompt_loader):
+        prompts = 
input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + + stream = mindietorch.npu.Stream("npu:" + str(args.device)) + with mindietorch.npu.stream(stream): + images = pipe.ascendie_infer( + prompts, + prompts_2, + width=args.width, + height=args.height, + num_inference_steps=args.steps, + guidance_scale=5.0, # 7.5, + skip_steps=skip_steps, + flag_ddim=flag_ddim, + flag_cache=flag_cache, + use_lora_hotswitch=args.use_loraHotswitch, + ) + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n" + f"clip time: {clip_time / infer_num:.3f}s\n" + f"unet time: {unet_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n") + + if args.save_unet_input: + np.save('unet_data.npy', pipe.data) + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_pipeline_cache_parallel.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_pipeline_cache_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..cd024fb2503c8d51cb05ca31e97a9b8bc6261671 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_pipeline_cache_parallel.py @@ -0,0 +1,1045 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
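+#
+# SDXL text-to-image inference on Ascend NPUs with MindIE Torch (cache + parallel variant):
+#   - loads the TorchScript models (two CLIP text encoders, UNet, VAE decoder and DDIM step)
+#     compiled by the export script from --output_dir
+#   - with --use_cache the UNet is split into a full "cache" graph that also returns intermediate
+#     features and a reduced "skip" graph that reuses those cached features on selected steps
+#   - when two device ids are passed via --device, the text-conditioned UNet branch runs on the
+#     second NPU through BackgroundRuntime while the unconditional branch runs on the first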
+ +import argparse +import csv +import json +import os +import time +from typing import Callable, List, Optional, Union +import numpy as np +import hpsv2 + +import torch +import mindietorch +from diffusers import StableDiffusionXLPipeline +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.schedulers import * +from background_runtime_cache import BackgroundRuntime, RuntimeIOInfo + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + max_num_prompts: int = 0 + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file, max_num_prompts) + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file, max_num_prompts) + elif prompt_file_type == 'hpsv2': + self.load_prompts_hpsv2(max_num_prompts) + else: + print("This operation is not supported!") + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + if max_num_prompts and i == max_num_prompts: + break + + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str, max_num_prompts: int): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + if max_num_prompts and i == max_num_prompts: + break + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + def load_prompts_hpsv2(self, max_num_prompts: int): + all_prompts = hpsv2.benchmark_prompts('all') + count = 0 + for style, prompts in all_prompts.items(): + for prompt in prompts: + count += 1 + if max_num_prompts and count >= max_num_prompts: + break + + if style not in self.catagories: + self.catagories.append(style) + + catagory_id = self.catagories.index(style) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusionXLPipeline(StableDiffusionXLPipeline): + + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0, self.device_1 = args.device + else: + self.device_0 = args.device + + def compile_aie_model(self): + if self.is_init: + 
return + + in_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size_2 = self.text_encoder_2.config.hidden_size + encoder_hidden_size = self.text_encoder.config.hidden_size + encoder_hidden_size_2 + max_position_embeddings = self.text_encoder.config.max_position_embeddings + + batch_size = self.args.batch_size + if self.args.flag == 0 or self.args.flag == 1: + if self.args.flag == 0: + tail = f"_static_{self.args.height}x{self.args.width}" + elif self.args.flag == 1: + tail = "" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{batch_size}_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{batch_size}_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_bs{batch_size}_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim_bs{batch_size}_parallel_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + if not self.args.use_cache: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + + if self.args.use_cache: + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_parallel_compile_1{tail}.ts") + self.compiled_unet_model_skip = torch.jit.load(unet_skip_compiled_path).eval() + + unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_parallel_compile_0{tail}.ts") + self.compiled_unet_model_cache = torch.jit.load(unet_cache_compiled_path).eval() + elif self.args.flag == 2: + tail = "_dynamic" + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_compile{tail}.ts") + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + + clip1_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_compile{tail}.ts") + self.compiled_clip_model = torch.jit.load(clip1_compiled_path).eval() + + clip2_compiled_path = os.path.join(self.args.output_dir, f"clip/clip2_compile{tail}.ts") + self.compiled_clip_model_2 = torch.jit.load(clip2_compiled_path).eval() + + scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim_parallel_compile{tail}.ts") + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + + if not self.args.use_cache: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_compile{tail}.ts") + self.compiled_unet_model = torch.jit.load(unet_compiled_path).eval() + + if self.args.use_cache: + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_parallel_compile_1{tail}.ts") + self.compiled_unet_model_skip = torch.jit.load(unet_skip_compiled_path).eval() + + unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_parallel_compile_0{tail}.ts") + self.compiled_unet_model_cache = torch.jit.load(unet_cache_compiled_path).eval() + runtime_info_cache = RuntimeIOInfo( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (1,), + (batch_size, max_position_embeddings, encoder_hidden_size), + (batch_size, encoder_hidden_size_2), + (batch_size, 6), + (1,) + ], + input_dtypes=[np.float32, np.int64, np.float32, np.float32, np.float32, np.int64], 
######################## + output_shapes=[(batch_size, in_channels, sample_size, sample_size), + (batch_size, 1280, sample_size, sample_size)], + output_dtypes=[np.float32, np.float32] + ) + + if hasattr(self, 'device_1'): + self.unet_bg = BackgroundRuntime.clone(self.device_1, [unet_cache_compiled_path, unet_skip_compiled_path], + runtime_info_cache) + self.use_parallel_inferencing = True + + self.is_init = True + + def encode_prompt( + self, + prompt, + prompt_2, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + negative_prompt_2, + lora_scale, + clip_skip + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
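+
+        Returns:
+            Tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
+            negative_pooled_prompt_embeds)` as CPU tensors ready to be fed to the compiled UNet.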
+ """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.compiled_clip_model, self.compiled_clip_model_2] if self.compiled_clip_model is not None + else [self.compiled_clip_model_2] + ) + + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + # flag = 0 + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1]) + + # We are only ALWAYS interested in the pooled output of the final text encoder + + global clip_time + start = time.time() + prompt_embeds_npu = text_encoder(text_input_ids.to(f'npu:{self.device_0}')) + + pooled_prompt_embeds = prompt_embeds_npu[0].to('cpu') + clip_time += time.time() - start + + if clip_skip is None: + prompt_embeds = prompt_embeds_npu[2][-2].to('cpu') + + else: + # "2" because SDXL always indexes from the penultimate layer.????待定 + prompt_embeds = prompt_embeds_npu.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(f'npu:{self.device_0}'))[0].to('cpu') + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_prompt_embeds = [torch.from_numpy(text) for text in negative_prompt_embeds] + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device="cpu") + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def ascendie_infer( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[dict[str, any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[tuple[int, int]] = None, + crops_coords_top_left: tuple[int, int] = (0, 0), + target_size: Optional[tuple[int, int]] = None, + 
negative_original_size: Optional[tuple[int, int]] = None, + negative_crops_coords_top_left: tuple[int, int] = (0, 0), + negative_target_size: Optional[tuple[int, int]] = None, + clip_skip: Optional[int] = None, + skip_steps=None, + flag_ddim: int = None, + flag_cache: int = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
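+            skip_steps (`List[int]`, *optional*):
+                Per-step cache flags: 1 runs the reduced "skip" UNet graph and reuses the cached
+                features from the previous full step, 0 runs the full UNet graph and refreshes the cache.
+            flag_ddim (`int`, *optional*):
+                Set to 1 when the compiled DDIM scheduler step is used, 0 for other schedulers.
+            flag_cache (`int`, *optional*):
+                Set to 1 by the caller when the UNet cache optimization is enabled, 0 otherwise.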
+ eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + """ + global p1_time, p2_time, p3_time + start = time.time() + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + # 3. 
Encode input prompt + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + lora_scale=lora_scale, + clip_skip=clip_skip, + ) + + p1_time += time.time() - start + start1 = time.time() + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=self.text_encoder_2.config.projection_dim + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + ) + else: + negative_add_time_ids = add_time_ids + + prompt_embeds = prompt_embeds.to(device) + negative_prompt_embeds = negative_prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + negative_add_time_ids = negative_add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + # 8.1 Apply denoising_end + if ( + denoising_end is not None + and isinstance(denoising_end, float) + and denoising_end > 0 + and denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + global unet_time + global vae_time + + cache = None + skip_flag = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + # latent_model_input = self.scheduler.scale_model_input(latent_model_input, t).to(f'npu:{self.device_0}') + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.use_parallel_inferencing and do_classifier_free_guidance: + self.unet_bg.infer_asyn([ + latent_model_input.numpy(), # [1, 4, 128, 128] + t[None].numpy().astype(np.int64), + prompt_embeds.to('cpu').numpy(), # [1, 77, 2048] + add_text_embeds.to('cpu').numpy(), # [1, 1280] + add_time_ids.numpy(), # .astype(torch.long) + skip_flag.numpy(), + ], + skip_steps[i]) + + latent_model_input_npu = latent_model_input.to(f'npu:{self.device_0}') # [1, 4, 128, 128] + + start = time.time() + if skip_steps[i]: + noise_pred_npu = self.compiled_unet_model_skip(latent_model_input_npu, # [1, 4, 128, 128] + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + negative_prompt_embeds.to(f'npu:{self.device_0}'), + # [1, 77, 2048] + negative_pooled_prompt_embeds.to(f'npu:{self.device_0}'), + # [1, 1280] + negative_add_time_ids.to(f'npu:{self.device_0}'), + skip_flag.to(f'npu:{self.device_0}'), + cache, ) # if_skip, cache + noise_pred = noise_pred_npu # .to('cpu') + else: + start = time.time() + outputs = self.compiled_unet_model_cache(latent_model_input_npu, + t.to(torch.int64)[None].to(f'npu:{self.device_0}'), + negative_prompt_embeds.to(f'npu:{self.device_0}'), + negative_pooled_prompt_embeds.to(f'npu:{self.device_0}'), + negative_add_time_ids.to(f'npu:{self.device_0}'), + cache_flag.to(f'npu:{self.device_0}'), + ) + cache = outputs[1] + noise_pred = outputs[0] + + unet_time += time.time() - start + + # perform guidance + if do_classifier_free_guidance: + if self.use_parallel_inferencing: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg.wait_and_get_outputs()[0]) + + else: + out = self.unet_bg.wait_and_get_outputs() ########################################## + noise_pred_text = torch.from_numpy(out[0]) + + else: + noise_pred, noise_pred_text = noise_pred.chunk(2) + + x = np.array(i, dtype=np.int64) + y = torch.from_numpy(x).long() + + latents = self.compiled_scheduler( # 2、分别输入两类噪声预测,就可以不用额外增加concat + noise_pred.to(f'npu:{self.device_0}'), # 无条件 + noise_pred_text.to(f'npu:{self.device_0}'), # 有条件 + t[None].to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + y[None].to(f'npu:{self.device_0}')).to('cpu') + + p2_time = time.time() - start1 + start3 = time.time() + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = 
self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + latents = latents / self.vae.config.scaling_factor + latents = self.vae.post_quant_conv(latents) + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + # image = image.unsqueeze(0) + vae_time += time.time() - start + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if output_type == "pil": + image = self.numpy_to_pil(image) + + p3_time += time.time() - start3 + return (image, None) + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-xl-base-1.0", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti", "hpsv2"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=[0, 1], + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "--max_num_prompts", + default=0, + type=int, + help="Limit the number of prompts (0: no limit).", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--soc", + choices=["Duo", "A2"], + default="A2", + help="soc_version.", + ) + parser.add_argument( + "--scheduler", + choices=["DDIM", "Euler", "DPM", "SA-Solver"], + default="DDIM", + help="Type of Sampling methods. Can choose from DDIM, Euler, DPM, SA-Solver", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." 
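+        # selects the split cache/skip UNet graphs (*_parallel_compile_0/_1.ts) instead of the single full UNet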
+    )
+    parser.add_argument(
+        "--cache_steps",
+        type=str,
+        default="1,2,4,6,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\
+                30,31,33,34,36,37,39,40,42,43,45,47,48,49",  # 17+33
+        help="Steps to use cache data."
+    )
+    parser.add_argument(
+        "--flag",
+        choices=[0, 1, 2],
+        default=0,
+        type=int,
+        help="0 is static; 1 is dynamic dims; 2 is dynamic range.",
+    )
+    parser.add_argument(
+        "--height",
+        default=1024,
+        type=int,
+        help="image height",
+    )
+    parser.add_argument(
+        "--width",
+        default=1024,
+        type=int,
+        help="image width",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_arguments()
+    save_dir = args.save_dir
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    pipe = AIEStableDiffusionXLPipeline.from_pretrained(args.model).to("cpu")
+
+    flag_ddim = 0
+    if args.scheduler == "DDIM":
+        flag_ddim = 1
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "Euler":
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "DPM":
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "SA-Solver":
+        pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "EulerAncestral":
+        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+    if args.scheduler == "DPM++SDEKarras":
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+        pipe.scheduler.config.algorithm_type = 'sde-dpmsolver++'
+        pipe.scheduler.config.use_karras_sigmas = True
+
+    pipe.parser_args(args)
+    pipe.compile_aie_model()
+    mindietorch.set_device(pipe.device_0)
+    skip_steps = [0] * args.steps
+    flag_cache = 0
+    if args.use_cache:
+        flag_cache = 1
+        for i in args.cache_steps.split(','):
+            if int(i) >= args.steps:
+                continue
+            skip_steps[int(i)] = 1
+
+    use_time = 0
+    prompt_loader = PromptLoader(args.prompt_file,
+                                 args.prompt_file_type,
+                                 args.batch_size,
+                                 args.num_images_per_prompt,
+                                 args.max_num_prompts)
+
+    prompts_2 = ""
+    infer_num = 0
+    image_info = []
+    current_prompt = None
+    for i, input_info in enumerate(prompt_loader):
+        prompts = input_info['prompts']
+        catagories = input_info['catagories']
+        save_names = input_info['save_names']
+        n_prompts = input_info['n_prompts']
+
+        print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}")
+        infer_num += args.batch_size
+
+        start_time = time.time()
+        images = pipe.ascendie_infer(
+            prompts,
+            prompts_2,
+            width=args.width,
+            height=args.height,
+            num_inference_steps=args.steps,
+            guidance_scale=5.0,  # 7.5,
+            skip_steps=skip_steps,
+            flag_ddim=flag_ddim,
+            flag_cache=flag_cache,
+        )
+        if i > 4:  # do not count the time spent inferring the first 0 to 4 images
+            use_time += time.time() - start_time
+
+        for j in range(n_prompts):
+            image_save_path = os.path.join(save_dir, f"{save_names[j]}.png")
+            image = images[0][j]
+            image.save(image_save_path)
+
+            if current_prompt != prompts[j]:
+                current_prompt = prompts[j]
+                image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]})
+
+            image_info[-1]['images'].append(image_save_path)
+
+    infer_num = infer_num - 5  # do not count the time spent inferring the first 5 images
+    print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n"
+          f"average time: {use_time / infer_num:.3f}s\n"
+          f"clip time: {clip_time / infer_num:.3f}s\n"
+          f"unet time: {unet_time / infer_num:.3f}s\n"
+          f"vae time: {vae_time / infer_num:.3f}s\n"
+          f"p1 time: {p1_time / 
infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n") + + if hasattr(pipe, 'device_1'): + if (pipe.unet_bg): + pipe.unet_bg.stop() + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + mindietorch.finalize() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_unet_patch.py b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_unet_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..1b64b8a8f1152a615bfc78ad24e63dc70432563e --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/stable_diffusionxl_unet_patch.py @@ -0,0 +1,28 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.26.3', f"Expected diffusers version 0.26.3, but got {diffusers_version}" + os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_condition.py unet_2d_condition.patch') + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion-XL/unet_2d_condition.patch b/MindIE/MultiModal/StableDiffusion-XL/unet_2d_condition.patch new file mode 100644 index 0000000000000000000000000000000000000000..6935d6b2ca4d2d7ba457aac777b8b54d8803063a --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion-XL/unet_2d_condition.patch @@ -0,0 +1,247 @@ +--- unet_2d_condition.py 2024-06-04 17:04:44.033309200 +0800 ++++ unet_2d_condition_new.py 2024-06-04 17:27:41.224501400 +0800 +@@ -855,6 +855,8 @@ + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, ++ if_skip: int = 0, ++ inputCache: torch.FloatTensor = None, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. 
+@@ -1110,29 +1112,56 @@ + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + +- down_block_res_samples = (sample,) +- for downsample_block in self.down_blocks: +- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: +- # For t2i-adapter CrossAttnDownBlock2D +- additional_residuals = {} +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) +- +- sample, res_samples = downsample_block( +- hidden_states=sample, +- temb=emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- **additional_residuals, +- ) +- else: +- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- sample += down_intrablock_additional_residuals.pop(0) ++ if not if_skip: ++ down_block_res_samples = (sample,) ++ for downsample_block in self.down_blocks: ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples ++ else: ++ down_block_res_samples = (sample,) ++ for tmp, downsample_block in enumerate(self.down_blocks): ++ if tmp >= 2: ++ break ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) + +- down_block_res_samples += res_samples ++ down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () +@@ -1146,61 +1175,97 @@ + down_block_res_samples = new_down_block_res_samples + + # 4. 
mid +- if self.mid_block is not None: +- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: +- sample = self.mid_block( +- sample, +- emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = self.mid_block(sample, emb) +- +- # To support T2I-Adapter-XL +- if ( +- is_adapter +- and len(down_intrablock_additional_residuals) > 0 +- and sample.shape == down_intrablock_additional_residuals[0].shape +- ): +- sample += down_intrablock_additional_residuals.pop(0) ++ if not if_skip: ++ if self.mid_block is not None: ++ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: ++ sample = self.mid_block( ++ sample, ++ emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = self.mid_block(sample, emb) ++ ++ # To support T2I-Adapter-XL ++ if ( ++ is_adapter ++ and len(down_intrablock_additional_residuals) > 0 ++ and sample.shape == down_intrablock_additional_residuals[0].shape ++ ): ++ sample += down_intrablock_additional_residuals.pop(0) ++ else: ++ sample = inputCache + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up +- for i, upsample_block in enumerate(self.up_blocks): +- is_final_block = i == len(self.up_blocks) - 1 +- +- res_samples = down_block_res_samples[-len(upsample_block.resnets) :] +- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] +- +- # if we have not reached the final block and need to forward the +- # upsample size, we do it here +- if not is_final_block and forward_upsample_size: +- upsample_size = down_block_res_samples[-1].shape[2:] +- +- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- encoder_hidden_states=encoder_hidden_states, +- cross_attention_kwargs=cross_attention_kwargs, +- upsample_size=upsample_size, +- attention_mask=attention_mask, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- upsample_size=upsample_size, +- scale=lora_scale, +- ) ++ if not if_skip: ++ for i, upsample_block in enumerate(self.up_blocks): ++ is_final_block = i == len(self.up_blocks) - 1 ++ ++ res_samples = down_block_res_samples[-len(upsample_block.resnets) :] ++ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] ++ ++ # if we have not reached the final block and need to forward the ++ # upsample size, we do it here ++ if not is_final_block: ++ upsample_size = down_block_res_samples[-1].shape[2:] ++ ++ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ upsample_size=upsample_size, ++ 
scale=lora_scale, ++ ) ++ ++ if (not if_skip) and (i == 0): ++ inputCache = sample ++ ++ else: ++ ++ for i, upsample_block in enumerate(self.up_blocks): ++ is_final_block = i == len(self.up_blocks) - 1 ++ if not is_final_block: ++ upsample_size = down_block_res_samples[0].shape[2:] ++ ++ if i == 1: ++ res_samples = down_block_res_samples[-4:-1] ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ if i == 2: ++ res_samples = down_block_res_samples[:3] ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ upsample_size=upsample_size, ++ scale=lora_scale, ++ ) + + # 6. post-process + if self.conv_norm_out: +@@ -1215,4 +1280,7 @@ + if not return_dict: + return (sample,) + +- return UNet2DConditionOutput(sample=sample) ++ if not if_skip: ++ return (sample, inputCache) ++ else: ++ return UNet2DConditionOutput(sample=sample) diff --git a/MindIE/MultiModal/StableDiffusion/README.md b/MindIE/MultiModal/StableDiffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3963e47827d7a38b32d5ad39ae9786126c91d7b2 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/README.md @@ -0,0 +1,468 @@ +# stable-diffusion模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [准备数据集](#section183221994411) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + + +# 概述 + + stable-diffusion是一种文本到图像的扩散模型,能够在给定任何文本输入的情况下生成照片逼真的图像。有关稳定扩散函数的更多信息,请查看[Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion)。 + +- 参考实现: + ```bash + # StableDiffusion v1.5 + https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5 + # StableDiffusion v2.1 + https://huggingface.co/stabilityai/stable-diffusion-2-1-base + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1或2 +Atlas 300I Duo推理卡:支持的卡数为1,可双芯并行 + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | ------------------------- | ------------ | + | input | 1 x 77 | FLOAT32| ND| + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | output1 | 1 x 512 x 512 x 3 | FLOAT32 | NHWD | + +**注意**:该模型当前仅支持batch size为1的情况。 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 +- + | 配套 | 版本 | 环境准备指导 | + | ------------------------------------------------------------ |--------| ------------------------------------------------------------ | + | Python | 3.10.13 | - | + | torch| 2.1.0 | - | + +**注意**:本README中的StableDiffusion v1.5和v2.1模型推理方式与torch-npu冲突,需卸载torch-npu包。 + +# 快速上手 + +## 获取源码 + +1. 按照requirements.txt要求的版本安装相关依赖,避免导出模型失败。 + ```bash + pip3 install -r requirements.txt + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + + ```bash + python3 stable_diffusion_attention_patch.py + ``` + + ```bash + python3 stable_diffusion_unet_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + + 本模型输入文本信息生成图片,无需数据集。 + +## 模型推理 + +1. 模型转换。【可选】 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 0. 
获取权重(可选)
+
+       可提前下载权重,以避免执行后续步骤时下载失败。
+
+       ```bash
+       # 需要使用 git-lfs (https://git-lfs.com)
+       git lfs install
+
+       # v1.5
+       git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
+
+       # v2.1
+       git clone https://huggingface.co/stabilityai/stable-diffusion-2-1-base
+       ```
+
+   1. 导出pt模型并进行编译。(可选)
+
+       设置模型名称或路径:
+       ```bash
+       # v1.5 (执行时下载权重)
+       model_base="stable-diffusion-v1-5/stable-diffusion-v1-5"
+
+       # v1.5 (使用上一步下载的权重)
+       model_base="./stable-diffusion-v1-5"
+
+       # v2.1 (执行时下载权重)
+       model_base="stabilityai/stable-diffusion-2-1-base"
+
+       # v2.1 (使用上一步下载的权重)
+       model_base="./stable-diffusion-2-1-base"
+       ```
+
+       执行命令:
+
+       ```bash
+       # 导出pt模型
+       python3 export_ts.py --model ${model_base} --output_dir ./models \
+              --parallel \
+              --use_cache
+       ```
+
+       参数说明:
+       - --model:模型名称或本地模型目录的路径
+       - --output_dir: pt模型输出目录
+       - --parallel:【可选】模型使用双芯/双卡并行推理
+       - --use_cache: 【可选】模型使用UnetCache优化
+       - --use_cache_faster: 【可选】模型使用deepcache+faster融合方案
+
+       若不选择【--parallel】,即单卡/单芯,执行成功后会生成pt模型:
+       - ./models/clip/clip_bs1.pt
+       - ./models/vae/vae_bs1.pt
+       - ./models/ddim/ddim2.pt
+       - ./models/cat/cat.pt
+       - ./models/unet/unet_bs2.pt【不选择--use_cache】
+       - ./models/unet/unet_bs2_0.pt【选择--use_cache】
+       - ./models/unet/unet_bs2_1.pt【选择--use_cache】
+
+       若选择【--parallel】,即双卡/双芯,执行成功后会生成pt模型:
+       - ./models/clip/clip_bs1.pt
+       - ./models/vae/vae_bs1.pt
+       - ./models/ddim/ddim1.pt
+       - ./models/unet/unet_bs1.pt【不选择--use_cache】
+       - ./models/unet/unet_bs1_0.pt【选择--use_cache】
+       - ./models/unet/unet_bs1_1.pt【选择--use_cache】
+
+       **注意**:若条件允许,该模型可按双芯片并行的方式进行推理,从而获得更短的端到端耗时。具体指令的差异之处会在后面的步骤中单独说明,请留意。
+
+       使用Lora权重【可选】
+
+       在[civitai](https://civitai.com)下载base model为SD1.5和SD2.1的lora权重,一般选择safetensors格式的权重。执行转换脚本,将lora权重和model_base权重结合在一起。
+
+       执行命令:
+
+       ```bash
+       model_lora=lora权重路径
+       model_new=适配lora之后的SD权重路径
+       python3 convert_lora_safetensors_to_diffusers.py --base_model_path ${model_base} --checkpoint_path ${model_lora} --dump_path ${model_new}
+       ```
+
+       ```bash
+       # 若使用lora权重,导出pt模型
+       python3 export_ts.py --model ${model_new} --output_dir ./models_lora \
+              --parallel \
+              --use_cache
+       ```
+
+       执行成功后会在./models_lora路径下生成pt模型,文件命名与上面列出的pt模型一致。
+
+       **注意**:更换lora权重时,请手动删除models_lora路径下已生成的pt模型,重新执行转换权重脚本和导出模型命令,导出带lora权重的pt模型。
+
+
+2. 开始推理验证。
+   1. 开启cpu高性能模式
+      ```bash
+      echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+      sysctl -w vm.swappiness=0
+      sysctl -w kernel.numa_balancing=0
+      ```
+
+   2. 安装绑核工具
+      ```bash
+      apt-get update
+      apt-get install numactl
+      ```
+      查询卡的NUMA node
+      ```shell
+      lspci -vs bus-id
+      ```
+      bus-id可通过npu-smi info获得。查询到NUMA node后,在推理命令前加上numactl并绑定该node对应的CPU核(见下方示例)。
+
+      可通过lscpu查看各NUMA node对应的CPU核
+      ```shell
+      NUMA node0: 0-23
+      NUMA node1: 24-47
+      NUMA node2: 48-71
+      NUMA node3: 72-95
+      ```
+      当前查到NUMA node是0,对应CPU核0-23,推荐绑定该node对应的核以获得更好的性能。
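+
+      下面给出一个完整的查询与绑核示例(其中的bus-id、NUMA node编号和CPU核范围均为示例值,请以实际`npu-smi info`、`lspci`、`lscpu`的输出为准):
+      ```bash
+      # 1) 通过npu-smi info查询NPU卡的Bus-Id(示例值:0000:c1:00.0)
+      npu-smi info
+      # 2) 查询该卡所在的NUMA node
+      lspci -vs c1:00.0 | grep -i "NUMA node"
+      # 3) 查询该NUMA node对应的CPU核(示例:NUMA node0对应0-23)
+      lscpu | grep "NUMA node"
+      # 4) 推理时在命令前加上numactl绑定这些核,即下一步命令中的 numactl -C 0-23 前缀
+      ```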
+
+   3. 执行推理脚本。
+      ```bash
+      # 1.若不使用并行推理:
+      # 1.1不使用lora权重
+      numactl -C 0-23 python3 stable_diffusion_pipeline.py \
+              --model ${model_base} \
+              --prompt_file ./prompts.txt \
+              --device 0 \
+              --save_dir ./results \
+              --steps 50 \
+              --scheduler DDIM \
+              --soc A2 \
+              --output_dir ./models \
+              --use_cache
+      # 1.2使用带lora权重的新权重
+      numactl -C 0-23 python3 stable_diffusion_pipeline.py \
+              --model ${model_new} \
+              --prompt_file ./prompts.txt \
+              --device 0 \
+              --save_dir ./results \
+              --steps 50 \
+              --scheduler DDIM \
+              --soc A2 \
+              --output_dir ./models_lora \
+              --use_cache
+
+      # 2.若使用并行推理【Atlas 300I Duo】
+      # 2.1不使用lora权重
+      numactl -C 0-23 python3 stable_diffusion_pipeline_parallel.py \
+              --model ${model_base} \
+              --prompt_file ./prompts.txt \
+              --device 0,1 \
+              --save_dir ./results \
+              --steps 50 \
+              --scheduler DDIM \
+              --soc Duo \
+              --output_dir ./models \
+              --use_cache
+      # 2.2使用带lora权重的新权重
+      numactl -C 0-23 python3 stable_diffusion_pipeline_parallel.py \
+              --model ${model_new} \
+              --prompt_file ./prompts.txt \
+              --device 0,1 \
+              --save_dir ./results \
+              --steps 50 \
+              --scheduler DDIM \
+              --soc Duo \
+              --output_dir ./models_lora \
+              --use_cache
+      ```
+
+      参数说明:
+      - --model:模型名称或本地模型目录的路径。
+      - --prompt_file:输入的文本文件,每行一个prompt。
+      - --save_dir:生成图片的存放目录。
+      - --steps:生成图片迭代次数。
+      - --device:推理设备ID;可用逗号分隔传入两个设备ID,此时会使用并行方式进行推理。
+      - --scheduler: 【可选】推荐使用DDIM采样器。
+      - --soc: 硬件形态,根据实际硬件选择Duo或者A2。
+      - --output_dir: 编译好的模型路径。
+      - --use_cache: 【可选】推荐使用UnetCache策略。
+      - --use_cache_faster: 【可选】模型使用deepcache+faster融合方案。
+
+      执行完成后会在`./results`目录下生成推理图片,并在终端显示推理时间。
+
+      **注意**:更换lora权重时,请手动删除models_lora路径下已生成的编译产物(xxx_compile.pt),重新执行推理脚本。
+
+
+## 精度验证
+
+   由于生成的图片存在随机性,所以精度验证将使用CLIP-score来评估图片和输入文本的相关性,分数的取值范围为[-1, 1],越高越好。
+
+   注意,由于要生成的图片数量较多,进行完整的精度验证需要耗费很长的时间。
+
+   1. 下载Parti数据集
+
+      ```bash
+      wget https://raw.githubusercontent.com/google-research/parti/main/PartiPrompts.tsv --no-check-certificate
+      ```
+
+   2. 下载Clip模型权重
+
+      ```bash
+      # 安装git-lfs
+      apt install git-lfs
+      git lfs install
+      git clone https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K
+
+      # 或者访问https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/open_clip_pytorch_model.bin,将权重下载并放到这个目录下
+      ```
+
+   3. 
使用推理脚本读取Parti数据集,生成图片 + ```bash + # 1.若不使用并行推理: + # 1.1不使用lora权重 + python3 stable_diffusion_pipeline.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --scheduler DDIM \ + --soc A2 \ + --output_dir ./models \ + --use_cache + # 1.2使用带lora权重的新权重 + python3 stable_diffusion_pipeline.py \ + --model ${model_new} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --device 0 \ + --save_dir ./results \ + --steps 50 \ + --scheduler DDIM \ + --soc A2 \ + --output_dir ./models_lora \ + --use_cache + + # 2.若使用并行推理【Atlas 300I Duo】 + # 2.1不使用lora权重 + python3 stable_diffusion_pipeline_parallel.py \ + --model ${model_base} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --device 0,1 \ + --save_dir ./results \ + --steps 50 \ + --scheduler DDIM \ + --soc Duo \ + --output_dir ./models \ + --use_cache + # 2.2使用带lora权重的新权重 + python3 stable_diffusion_pipeline_parallel.py \ + --model ${model_new} \ + --prompt_file ./PartiPrompts.tsv \ + --prompt_file_type parti \ + --num_images_per_prompt 4 \ + --device 0,1 \ + --save_dir ./results \ + --steps 50 \ + --scheduler DDIM \ + --soc Duo \ + --output_dir ./models_lora \ + --use_cache + ``` + + 增加的参数说明: + - --prompt_file:输入文本文件,按行分割。 + - --prompt_file_type: prompt文件类型,用于指定读取方式。 + - --num_images_per_prompt: 每个prompt生成的图片数量。 + + 执行完成后会在`./results`目录下生成推理图片,并且会在当前目录生成一个`image_info.json`文件,记录着图片和prompt的对应关系。 + + 4. 计算CLIP-score + + ```bash + python clip_score.py \ + --device=cpu \ + --image_info="image_info.json" \ + --model_name="ViT-H-14" \ + --model_weights_path="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin" + ``` + + 参数说明: + - --device: 推理设备。 + - --image_info: 上一步生成的`image_info.json`文件。 + - --model_name: Clip模型名称。 + - --model_weights_path: Clip模型权重文件路径。 + + 执行完成后会在屏幕打印出精度计算结果。 + + +# 模型推理性能&精度 + +性能参考下列数据。 + +### StableDiffusion v1.5 + +| 硬件形态 | 迭代次数 | 平均耗时(w/o UnetCache) | 平均耗时(with UnetCache) | +| :------: |:----:|:----:|:----:| +| Atlas 300I Duo双芯 | 50 | 2.5s | 1.54s | +| Atlas 800I A2(8*32G) | 50 | 1.6s | 0.95s | + +### StableDiffusion v2.1 + +| 硬件形态 | 迭代次数 | 平均耗时(w/o UnetCache) | 平均耗时(with UnetCache) | +| :------: |:----:|:----:|:----:| +| Atlas 300I Duo双芯 | 50 | 2.3s | 1.39s | +| Atlas 800I A2(8*32G) | 50 | 1.4s | 0.85s | + +### StableDiffusion v1.5 + +迭代50次的参考精度结果如下: + + ``` + average score: 0.363 + category average scores: + [Abstract], average score: 0.280 + [Vehicles], average score: 0.363 + [Illustrations], average score: 0.359 + [Arts], average score: 0.404 + [World Knowledge], average score: 0.372 + [People], average score: 0.364 + [Animals], average score: 0.373 + [Artifacts], average score: 0.359 + [Food & Beverage], average score: 0.355 + [Produce & Plants], average score: 0.358 + [Outdoor Scenes], average score: 0.355 + [Indoor Scenes], average score: 0.368 + ``` + +### StableDiffusion v2.1 + +迭代50次的参考精度结果如下: + + ``` + average score: 0.376 + category average scores: + [Abstract], average score: 0.285 + [Vehicles], average score: 0.377 + [Illustrations], average score: 0.376 + [Arts], average score: 0.414 + [World Knowledge], average score: 0.383 + [People], average score: 0.381 + [Animals], average score: 0.389 + [Artifacts], average score: 0.369 + [Food & Beverage], average score: 0.369 + [Produce & Plants], average score: 0.364 + [Outdoor Scenes], average score: 0.366 + [Indoor Scenes], average score: 
0.381
+ ```
+
+**注意**:当前推理pipeline中未固定随机种子;固定随机种子会对clip_score分数有影响。如需固定随机种子,可做如下修改:
+
+```python
+# 在推理pipeline的main函数中加入
+generator = torch.Generator().manual_seed(xxx)
+# 在ascendie_infer函数中加入参数
+generator=generator
+```
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableDiffusion/attention_processor.patch b/MindIE/MultiModal/StableDiffusion/attention_processor.patch
new file mode 100644
index 0000000000000000000000000000000000000000..bd15281c5a3acf9752eec8a239323f66f1beadb7
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion/attention_processor.patch
@@ -0,0 +1,18 @@
+--- attention_processor.py	2024-07-02 07:42:32.312000000 +0000
++++ attention_processor.py	2024-07-02 07:44:55.100000000 +0000
+@@ -205,10 +205,11 @@
+         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+-        if processor is None:
+-            processor = (
+-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+-            )
++        # if processor is None:
++        #     processor = (
++        #         AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
++        #     )
++        processor = AttnProcessor()
+         self.set_processor(processor)
+ 
+     def set_use_memory_efficient_attention_xformers(
diff --git a/MindIE/MultiModal/StableDiffusion/background_runtime.py b/MindIE/MultiModal/StableDiffusion/background_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b07e53b2976be0bf665433989bec8063d9f226
--- /dev/null
+++ b/MindIE/MultiModal/StableDiffusion/background_runtime.py
@@ -0,0 +1,184 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
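+#
+# NOTE: BackgroundRuntime runs a compiled UNet in a child process on a second
+# NPU device; inputs and outputs are exchanged with the main process through
+# shared-memory buffers plus a multiprocessing Pipe, so a second copy of the
+# UNet can execute concurrently during --parallel (dual-die/dual-card) runs.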
+ +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfo + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i, _ in enumerate(self.input_arrays): + print(f'bg input shape: {self.input_arrays[i].shape}') + print(f'feeds shape: {feeds[i].shape}') + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send('') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: str, + ) -> None: + # The sub process function + # Create a runtime + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model = torch.jit.load(model_path).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + mindietorch.set_device(device_id) + + # Tell the main function that we are ready + sync_pipe.send('') + + infer_num = 0 + preprocess_time = 0 + infer_time = 0 + forward_time = 0 + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != 'STOP': + start = time.time() + sample, timestep, encoder_hidden_states = [ + torch.Tensor(input_array) for input_array in 
input_arrays + ] + sample_npu = sample.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}") + + preprocess_time += time.time() - start + + start2 = time.time() + with mindietorch.npu.stream(stream): + inf_start = time.time() + output_npu = model(sample_npu, timestep_npu, encoder_hidden_states_npu) + stream.synchronize() + inf_end = time.time() + + output_cpu = output_npu.to('cpu') + forward_time += inf_end - inf_start + infer_time += time.time() - start2 + + for i, _ in enumerate(output_arrays): + output = output_cpu.numpy() + output_arrays[i][:] = output[i][:] + + infer_num += 1 + sync_pipe.send('') + + infer_num /= 50 + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/StableDiffusion/background_runtime_cache.py b/MindIE/MultiModal/StableDiffusion/background_runtime_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..3e89c255c1ea8b77edf0cd9c878d23a596e9db76 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/background_runtime_cache.py @@ -0,0 +1,192 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
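+#
+# NOTE: Same background-process runtime as background_runtime.py, but for the
+# UnetCache (--use_cache) path: two compiled UNets are loaded (model_path[0]
+# is the full "cache" model, model_path[1] the "skip" model). Each request is
+# tagged 'cache' or 'skip'; the cache model additionally returns an
+# intermediate feature map that stays on the NPU and is fed back into the
+# skip model on subsequent steps.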
+ +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +@dataclass +class RuntimeIOInfoCache: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntimeCache: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfoCache + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray], skip) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + + if skip: + self.sync_pipe.send('skip') + else: + self.sync_pipe.send('cache') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfoCache, + device_id: int, + model_path: list, + ) -> None: + # The sub process function + # Create a runtime + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model_cache = torch.jit.load(model_path[0]).eval() + model_skip = torch.jit.load(model_path[1]).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + mindietorch.set_device(device_id) + + # Tell the main function that we are ready + sync_pipe.send('') + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + return_cache = None + + # Keep looping until recived a 'STOP' + while True: + flag = sync_pipe.recv() + if flag == 'STOP': + break + + if flag == 'cache': + sample, timestep, encoder_hidden_states, return_flag = [ + 
torch.Tensor(input_array) for input_array in input_arrays + ] + else: + sample, timestep, encoder_hidden_states, return_flag = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + + sample_npu = sample.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}") + flag_npu = return_flag.to(torch.int64).to(f"npu:{device_id}") + + if flag == 'cache': + with mindietorch.npu.stream(stream): + output_npu = model_cache(sample_npu, timestep_npu, encoder_hidden_states_npu, flag_npu) + stream.synchronize() + + output_cpu0 = output_npu[0].to('cpu') + output0 = output_cpu0.numpy() + output_arrays[0][:] = output0 + + return_cache = output_npu[1] + else: + with mindietorch.npu.stream(stream): + output_npu = model_skip(sample_npu, timestep_npu, encoder_hidden_states_npu, flag_npu, return_cache) + stream.synchronize() + + output_cpu0 = output_npu.to('cpu') + output0 = output_cpu0.numpy() + output_arrays[0][:] = output0 + + sync_pipe.send('') + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfoCache) -> 'BackgroundRuntimeCache': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntimeCache + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/StableDiffusion/background_runtime_cache_faster.py b/MindIE/MultiModal/StableDiffusion/background_runtime_cache_faster.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f4fbf4057a6d2bbe5ce82a7707924530dfecea --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/background_runtime_cache_faster.py @@ -0,0 +1,194 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
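+#
+# NOTE: Variant of background_runtime_cache.py for the deepcache+faster
+# fusion (--use_cache_faster): the cache model returns two cached tensors,
+# both of which are kept on the NPU and passed back into the skip model on
+# the following steps.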
+ +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +@dataclass +class RuntimeIOInfoCacheFaster: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntimeCacheFaster: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfoCacheFaster + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray], skip) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + + if skip: + self.sync_pipe.send('skip') + else: + self.sync_pipe.send('cache') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfoCacheFaster, + device_id: int, + model_path: list, + ) -> None: + # The sub process function + # Create a runtime + mindietorch.set_device(device_id) + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model_cache = torch.jit.load(model_path[0]).eval() + model_skip = torch.jit.load(model_path[1]).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + # Tell the main function that we are ready + sync_pipe.send('') + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + return_cache = None + + # Keep looping until recived a 'STOP' + while True: + flag = sync_pipe.recv() + if flag == 'STOP': + break + + if flag == 'cache': + sample, timestep, encoder_hidden_states, 
return_flag, return_faster_flag = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + else: + sample, timestep, encoder_hidden_states, return_flag, return_faster_flag = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + + sample_npu = sample.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.int64).to(f"npu:{device_id}") + encoder_hidden_states_npu = encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}") + flag_npu = return_flag.to(torch.int64).to(f"npu:{device_id}") + faster_flag_npu = return_faster_flag.to(torch.int64).to(f"npu:{device_id}") + + if flag == 'cache': + with mindietorch.npu.stream(stream): + output_npu = model_cache(sample_npu, timestep_npu, encoder_hidden_states_npu, flag_npu, faster_flag_npu) + stream.synchronize() + + output_cpu0 = output_npu[0].to('cpu') + output0 = output_cpu0.numpy() + output_arrays[0][:] = output0 + + return_cache = output_npu[1] + return_cache_faster = output_npu[2] + else: + with mindietorch.npu.stream(stream): + output_npu = model_skip(sample_npu, timestep_npu, encoder_hidden_states_npu, flag_npu, faster_flag_npu, return_cache, return_cache_faster) + stream.synchronize() + + output_cpu0 = output_npu.to('cpu') + output0 = output_cpu0.numpy() + output_arrays[0][:] = output0 + + sync_pipe.send('') + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfoCacheFaster) -> 'BackgroundRuntimeCacheFaster': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntimeCache + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/StableDiffusion/clip_score.py b/MindIE/MultiModal/StableDiffusion/clip_score.py new file mode 100644 index 0000000000000000000000000000000000000000..069f5d6e9a9baaa61b9a3537bcab6f637605858e --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/clip_score.py @@ -0,0 +1,140 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
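+#
+# NOTE: Computes CLIP scores for generated images: reads image_info.json
+# (entries with "images", "category" and "prompt" fields written by the
+# inference script), encodes prompts and images with an open_clip model,
+# keeps the best cosine similarity per prompt, and prints per-category and
+# overall averages.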
+ +import os +import json +import time +import argparse + +import open_clip +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F + + +def clip_score(model_clip, tokenizer, preprocess, prompt, image_files, device): + imgs = [] + texts = [] + for image_file in image_files: + img = preprocess(Image.open(image_file)).unsqueeze(0).to(device) + imgs.append(img) + text = tokenizer([prompt]).to(device) + texts.append(text) + + img = torch.cat(imgs) # [bs, 3, 224, 224] + text = torch.cat(texts) # [bs, 77] + + with torch.no_grad(): + text_ft = model_clip.encode_text(text).float() + img_ft = model_clip.encode_image(img).float() + score = F.cosine_similarity(img_ft, text_ft).squeeze() + + return score.cpu() + + +def main(): + args = parse_arguments() + + if args.device is None: + device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') + else: + device = torch.device(args.device) + + t_b = time.time() + print(f"Load clip model...") + model_clip, _, preprocess = open_clip.create_model_and_transforms( + args.model_name, pretrained=args.model_weights_path, device=device) + model_clip.eval() + print(f">done. elapsed time: {(time.time() - t_b):.3f} s") + + tokenizer = open_clip.get_tokenizer(args.model_name) + + with os.fdopen(os.open(args.image_info, os.O_RDONLY), "r") as f: + image_info = json.load(f) + + t_b = time.time() + print(f"Calc clip score...") + all_scores = [] + cat_scores = {} + + for i, info in enumerate(image_info): + image_files = info['images'] + category = info['category'] + prompt = info['prompt'] + + print(f"[{i + 1}/{len(image_info)}] {prompt}") + + image_scores = clip_score(model_clip, + tokenizer, + preprocess, + prompt, + image_files, + device) + if len(image_files) > 1: + best_score = max(image_scores) + else: + best_score = image_scores + + print(f"image scores: {image_scores}") + print(f"best score: {best_score}") + + all_scores.append(best_score) + if category not in cat_scores: + cat_scores[category] = [] + cat_scores[category].append(best_score) + print(f">done. 
elapsed time: {(time.time() - t_b):.3f} s") + + average_score = np.average(all_scores) + print(f"====================================") + print(f"average score: {average_score:.3f}") + print(f"category average scores:") + cat_average_scores = {} + for category, scores in cat_scores.items(): + cat_average_scores[category] = np.average(scores) + print(f"[{category}], average score: {cat_average_scores[category]:.3f}") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="device for torch.", + ) + parser.add_argument( + "--image_info", + type=str, + default="./image_info.json", + help="Image_info.json file.", + ) + parser.add_argument( + "--model_name", + type=str, + default="ViT-H-14", + help="open clip model name", + ) + parser.add_argument( + "--model_weights_path", + type=str, + default="./CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin", + help="open clip model weights", + ) + return parser.parse_args() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion/convert_lora_safetensors_to_diffusers.py b/MindIE/MultiModal/StableDiffusion/convert_lora_safetensors_to_diffusers.py new file mode 100644 index 0000000000000000000000000000000000000000..c070d642a0da423f92d02537a95ab8dc31a90594 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/convert_lora_safetensors_to_diffusers.py @@ -0,0 +1,143 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Conversion script for the LoRA's safetensors checkpoints. 
""" + +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import StableDiffusionPipeline + + +def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha): + # load base model + pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + + # load LoRA weight from .safetensors + state_dict = load_file(checkpoint_path) + + visited = [] + shape4failed = 0 + shapeno4failed = 0 + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + not_found = False + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + not_found = True + break + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + if not_found: + continue + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + try: + if len(curr_layer.weight.shape) == 2: + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) # for SD2.1 + else: + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + except Exception: + shape4failed += 1 + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + try: + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) + except Exception: + shapeno4failed += 1 + + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 
+ ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + args = parser.parse_args() + + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha + + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/MindIE/MultiModal/StableDiffusion/export_ts.py b/MindIE/MultiModal/StableDiffusion/export_ts.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba0cd0dce03936e3d247a1f9fd44b872b5612cc --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/export_ts.py @@ -0,0 +1,513 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from argparse import Namespace + +import torch +import torch.nn as nn +from diffusers import DDIMScheduler +from diffusers import StableDiffusionPipeline + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-2-1-base", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-steps", + "--steps", + type=int, + default=50, + help="steps." + ) + parser.add_argument( + "-guid", + "--guidance_scale", + type=float, + default=7.5, + help="guidance_scale" + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--use_cache_faster", + action="store_true", + help="Use cache with faster during inference." 
+ ) + parser.add_argument( + "-p", + "--parallel", + action="store_true", + help="Export the unet of bs=1 for parallel inferencing.", + ) + + return parser.parse_args() + + +class ClipExport(torch.nn.Module): + def __init__(self, clip_model): + super().__init__() + self.clip_model = clip_model + + def forward(self, x): + return self.clip_model(x)[0] + + +def export_clip(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int) -> None: + print("Exporting the text encoder...") + clip_path = os.path.join(save_dir, "clip") + if not os.path.exists(clip_path): + os.makedirs(clip_path, mode=0o640) + + clip_pt_path = os.path.join(clip_path, f"clip_bs{batch_size}.pt") + if os.path.exists(clip_pt_path): + return + + clip_model = sd_pipeline.text_encoder + + max_position_embeddings = clip_model.config.max_position_embeddings + print(f'max_position_embeddings: {max_position_embeddings}') + dummy_input = torch.ones([batch_size, max_position_embeddings], dtype=torch.int64) + + clip_export = ClipExport(clip_model) + + torch.jit.trace(clip_export, dummy_input).save(clip_pt_path) + + +class UnetExportInit(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward(self, sample, timestep, encoder_hidden_states): + return self.unet_model(sample, timestep, encoder_hidden_states)[0] + + +def export_unet_init(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt") + if os.path.exists(unet_pt_path): + return + + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + max_position_embeddings = clip_model.config.max_position_embeddings + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + ) + + unet = UnetExportInit(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + +class UnetExport(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward(self, sample, timestep, encoder_hidden_states, if_skip, inputCache=None): + if if_skip: + print("[Unetexport][forward] skip --------") + return self.unet_model(sample, timestep, encoder_hidden_states, if_skip=if_skip, inputCache=inputCache)[0] + else: + print("[Unetexport][forward] cache --------") + return self.unet_model(sample, timestep, encoder_hidden_states, if_skip=if_skip) + + +def export_unet(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int, if_skip: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_{if_skip}.pt") + if os.path.exists(unet_pt_path): + return + + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + 
max_position_embeddings = clip_model.config.max_position_embeddings + + if if_skip: + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, 320, sample_size, sample_size], dtype=torch.float32), + ) + else: + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + torch.zeros([1], dtype=torch.int64), + ) + + unet = UnetExport(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + +class UnetExportFaster(torch.nn.Module): + def __init__(self, unet_model): + super().__init__() + self.unet_model = unet_model + + def forward(self, sample, timestep, encoder_hidden_states, if_skip, if_faster, inputCache=None, inputFasterCache=None): + if if_skip: + return self.unet_model(sample, timestep, encoder_hidden_states, if_skip=if_skip, if_faster=if_faster, inputCache=inputCache, inputFasterCache=inputFasterCache)[0] + else: + return self.unet_model(sample, timestep, encoder_hidden_states, if_skip=if_skip, if_faster=if_faster) + + +def export_unet_faster(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int, if_skip: int, if_faster: int) -> None: + print("Exporting the image information creater...") + unet_path = os.path.join(save_dir, "unet") + if not os.path.exists(unet_path): + os.makedirs(unet_path, mode=0o640) + + unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}_{if_skip}_{if_faster}.pt") + if os.path.exists(unet_pt_path): + return + + unet_model = sd_pipeline.unet + clip_model = sd_pipeline.text_encoder + + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.in_channels + encoder_hidden_size = clip_model.config.hidden_size + max_position_embeddings = clip_model.config.max_position_embeddings + + if if_skip: + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, 320, sample_size, sample_size], dtype=torch.float32), + torch.ones([batch_size, 2*320, sample_size, sample_size], dtype=torch.float32), + ) + else: + dummy_input = ( + torch.ones([batch_size, in_channels, sample_size, sample_size], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.ones([batch_size, max_position_embeddings, encoder_hidden_size], dtype=torch.float32), + torch.zeros([1], dtype=torch.int64), + torch.ones([1], dtype=torch.int64), + ) + + unet = UnetExportFaster(unet_model) + unet.eval() + + torch.jit.trace(unet, dummy_input).save(unet_pt_path) + + +class CatExport(torch.nn.Module): + def __init__(self, scale_model_input): + super(CatExport, self).__init__() + self.scale_model_input = scale_model_input + + def forward(self, latents:torch.FloatTensor, t:torch.FloatTensor): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scale_model_input(latent_model_input, t) + return latent_model_input + + +def export_cat(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int) -> None: + + cat_path = 
os.path.join(save_dir, "cat") + if not os.path.exists(cat_path): + os.makedirs(cat_path, mode=0o640) + + cat_pt_path = os.path.join(cat_path, "cat.pt") + if os.path.exists(cat_pt_path): + return + + ddim_model = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + + dummy_input = ( + torch.ones([batch_size, 4, 64, 64], dtype=torch.float32), + torch.ones([1], dtype=torch.float32)) + + cat_export = CatExport(scale_model_input=ddim_model.scale_model_input) + cat_export.eval() + torch.jit.trace(cat_export, dummy_input).save(cat_pt_path) + + +class Scheduler(torch.nn.Module): + def __init__(self, num_train_timesteps=1000, num_inference_steps=50, alphas_cumprod=None, + guidance_scale=7.5, alpha_prod_t_prev_cache=None): + super(Scheduler, self).__init__() + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.alphas_cumprod = alphas_cumprod + self.guidance_scale = guidance_scale + self.alpha_prod_t_prev_cache = alpha_prod_t_prev_cache + + def forward(self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, step_index: int): + noise_pred_uncond, noise_pred_text = model_output.chunk(2) + model_output = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alpha_prod_t_prev_cache[step_index] + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + return prev_sample + + +def export_ddim(sd_pipeline: StableDiffusionPipeline, save_dir: str, steps: int, guidance_scale: float, + batch_size: int) -> None: + print("Exporting the ddim...") + ddim_path = os.path.join(save_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o640) + + ddim_pt_path = os.path.join(ddim_path, f"ddim{batch_size}.pt") + if os.path.exists(ddim_pt_path): + return + + ddim_model = sd_pipeline.scheduler + + dummy_input = ( + torch.randn([batch_size, 4, 64, 64], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size//2, 4, 64, 64], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(steps, device="cpu") + + timesteps = scheduler.timesteps[:steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = Scheduler( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + + new_ddim.eval() + + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + + +class SchedulerParallel(torch.nn.Module): + def __init__(self, num_train_timesteps=1000, num_inference_steps=50, alphas_cumprod=None, + guidance_scale=7.5, alpha_prod_t_prev_cache=None): + super(SchedulerParallel, self).__init__() + self.num_train_timesteps = num_train_timesteps + 
self.num_inference_steps = num_inference_steps + self.alphas_cumprod = alphas_cumprod + self.guidance_scale = guidance_scale + self.alpha_prod_t_prev_cache = alpha_prod_t_prev_cache + + def forward(self, noise_pred_uncond: torch.FloatTensor, noise_pred_text: torch.FloatTensor, timestep: int, + sample: torch.FloatTensor, step_index: int): + model_output = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alpha_prod_t_prev_cache[step_index] + beta_prod_t = 1 - alpha_prod_t + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + return prev_sample + + +def export_ddim_parallel(sd_pipeline: StableDiffusionPipeline, save_dir: str, steps: int, + guidance_scale: float, batch_size: int) -> None: + print("Exporting the ddim...") + ddim_path = os.path.join(save_dir, "ddim") + if not os.path.exists(ddim_path): + os.makedirs(ddim_path, mode=0o640) + + ddim_pt_path = os.path.join(ddim_path, f"ddim{batch_size}.pt") + if os.path.exists(ddim_pt_path): + return + + ddim_model = sd_pipeline.scheduler + dummy_input = ( + torch.randn([batch_size, 4, 64, 64], dtype=torch.float32), + torch.randn([batch_size, 4, 64, 64], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + torch.randn([batch_size, 4, 64, 64], dtype=torch.float32), + torch.ones([1], dtype=torch.int64), + ) + + scheduler = DDIMScheduler.from_config(sd_pipeline.scheduler.config) + scheduler.set_timesteps(steps, device="cpu") + + timesteps = scheduler.timesteps[:steps] + alpha_prod_t_prev_cache = [] + for timestep in timesteps: + prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + alpha_prod_t_prev = scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod + alpha_prod_t_prev_cache.append(alpha_prod_t_prev) + + new_ddim = SchedulerParallel( + num_train_timesteps=scheduler.config.num_train_timesteps, + num_inference_steps=scheduler.num_inference_steps, + alphas_cumprod=scheduler.alphas_cumprod, + guidance_scale=guidance_scale, + alpha_prod_t_prev_cache=torch.tensor(alpha_prod_t_prev_cache) + ) + + new_ddim.eval() + + torch.jit.trace(new_ddim, dummy_input).save(ddim_pt_path) + + +class VaeExport(torch.nn.Module): + def __init__(self, vae_model, scaling_factor): + super().__init__() + self.vae_model = vae_model + self.scaling_factor = scaling_factor + + def forward(self, latents): + latents = 1 / self.scaling_factor * latents + image = self.vae_model.decode(latents)[0] + image = (image / 2 + 0.5) + return image.permute(0, 2, 3, 1) + + +def export_vae(sd_pipeline: StableDiffusionPipeline, save_dir: str, batch_size: int) -> None: + print("Exporting the image decoder...") + + vae_path = os.path.join(save_dir, "vae") + if not os.path.exists(vae_path): + os.makedirs(vae_path, mode=0o640) + + vae_pt_path = os.path.join(vae_path, f"vae_bs{batch_size}.pt") + if os.path.exists(vae_pt_path): + return + + vae_model = sd_pipeline.vae + unet_model = sd_pipeline.unet + + scaling_factor = vae_model.config.scaling_factor + sample_size = unet_model.config.sample_size + in_channels = unet_model.config.out_channels + dummy_input = torch.ones([batch_size, in_channels, sample_size, sample_size]) + vae_export = 
VaeExport(vae_model,scaling_factor) + torch.jit.trace(vae_export, dummy_input).save(vae_pt_path) + + +def export(model_path: str, save_dir: str, batch_size: int, steps: int, guidance_scale: float, use_cache: bool, use_cache_faster: bool, parallel: bool) -> None: + pipeline = StableDiffusionPipeline.from_pretrained(model_path).to("cpu") + + export_clip(pipeline, save_dir, batch_size) + export_vae(pipeline, save_dir, batch_size) + + if use_cache: + if parallel: + # 双卡, unet_cache + export_unet(pipeline, save_dir, batch_size, 0) + # 双卡, unet_skip + export_unet(pipeline, save_dir, batch_size, 1) + else: + # 单卡, unet_cache + export_unet(pipeline, save_dir, batch_size * 2, 0) + # 单卡, unet_skip + export_unet(pipeline, save_dir, batch_size * 2, 1) + if use_cache_faster: + if parallel: + # 双卡, unet_cache带faster + export_unet_faster(pipeline, save_dir, batch_size, 0, 1) + # 双卡, unet_skip带faster + export_unet_faster(pipeline, save_dir, batch_size, 1, 1) + else: + # 单卡, unet_cache带faster + export_unet_faster(pipeline, save_dir, batch_size * 2, 0, 1) + # 单卡, unet_skip带faster + export_unet_faster(pipeline, save_dir, batch_size * 2, 1, 1) + else: + if parallel: + # 双卡不带unetcache + export_unet_init(pipeline, save_dir, batch_size) + else: + # 单卡不带unetcache + export_unet_init(pipeline, save_dir, batch_size * 2) + + if parallel: + # 双卡 + export_ddim_parallel(pipeline, save_dir, steps, guidance_scale, batch_size) + # export_cat_parallel(pipeline, save_dir, batch_size) + else: + # 单卡 + export_ddim(pipeline, save_dir, steps, guidance_scale, batch_size * 2) + export_cat(pipeline, save_dir, batch_size) + + +def main(): + args = parse_arguments() + export(args.model, args.output_dir, args.batch_size, args.steps, args.guidance_scale, args.use_cache, args.use_cache_faster, args.parallel) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableDiffusion/prompts.txt b/MindIE/MultiModal/StableDiffusion/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..a375a0bb63931d0d5da6c6d91df1e14f870f47d0 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/prompts.txt @@ -0,0 +1,16 @@ +Beautiful illustration of The ocean. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Islands in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Seaports in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The waves. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Grassland. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Wheat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Hut Tong. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The boat. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Pine trees. 
in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Bamboo. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of The temple. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Cloud in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Sun in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Spring. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Lotus. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper +Beautiful illustration of Snow piles. in a serene landscape, magic realism, narrative realism, beautiful matte painting, heavenly lighting, retrowave, 4 k hd wallpaper \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion/requirements.txt b/MindIE/MultiModal/StableDiffusion/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..6d39028a97778341c84978d9643b966f25895618 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/requirements.txt @@ -0,0 +1,5 @@ +setuptools==57.5.0 +torch==2.1.0 +diffusers==0.26.3 +transformers==4.46.0 +open_clip_torch==2.20.0 \ No newline at end of file diff --git a/MindIE/MultiModal/StableDiffusion/stable_diffusion_attention_patch.py b/MindIE/MultiModal/StableDiffusion/stable_diffusion_attention_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..59bfa9e822472e8740c8ba5d1666afa4b8abd2ec --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/stable_diffusion_attention_patch.py @@ -0,0 +1,28 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
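+#
+# NOTE: Applies attention_processor.patch to the installed diffusers 0.26.3
+# package so that Attention always uses the plain AttnProcessor instead of
+# the default AttnProcessor2_0 (torch.nn.functional.scaled_dot_product_attention);
+# see attention_processor.patch for the exact change.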
+ +import os +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.26.3', f"Expected diffusers version 0.26.3, but got {diffusers_version}" + os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch') + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion/stable_diffusion_pipeline.py b/MindIE/MultiModal/StableDiffusion/stable_diffusion_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..16d986539bc9fd3493fb072c7914bed9d1522978 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/stable_diffusion_pipeline.py @@ -0,0 +1,1131 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import csv +import json +import os +import time +from typing import Callable, List, Optional, Union +import numpy as np + +import torch +import mindietorch +from mindietorch import _enums +from diffusers import StableDiffusionPipeline +from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler, SASolverScheduler + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 +scheduler_time = 0 + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file) + + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file) + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r", encoding='utf8') as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + 
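+            # Each remaining TSV row is expected to hold the prompt in column 0 and its category name in column 1; new categories are collected as they appear.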
for i, line in enumerate(tsv_file): + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusionPipeline(StableDiffusionPipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0 = self.args.device[0] + else: + self.device_0 = args.device + + def compile_aie_model(self): + if self.is_init: + return + args = parse_arguments() + + in_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size = self.text_encoder.config.hidden_size + max_position_embeddings = self.text_encoder.config.max_position_embeddings + size = self.args.batch_size + batch_size = self.args.batch_size * 2 + + if self.args.soc == "Duo": + soc_version = "Ascend310P3" + elif self.args.soc == "A2": + soc_version = "Ascend910B4" + else: + print("unsupport soc_version, please check!") + return + + clip_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{size}_aie_compile.ts") + if os.path.exists(clip_compiled_path): + self.compiled_clip_model = torch.jit.load(clip_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"clip/clip_bs{size}.pt")).eval() + + self.compiled_clip_model = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((self.args.batch_size, + max_position_embeddings), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(self.compiled_clip_model, clip_compiled_path) + + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{size}_aie_compile.ts") + if os.path.exists(vae_compiled_path): + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"vae/vae_bs{size}.pt")).eval() + + self.compiled_vae_model = ( + mindietorch.compile(model, + inputs=[ + mindietorch.Input((self.args.batch_size, in_channels, + sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_vae_model, vae_compiled_path) + + scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim{batch_size}_aie_compile.ts") + if os.path.exists(scheduler_compiled_path): + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"ddim/ddim{batch_size}.pt")).eval() + + self.compiled_scheduler = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size // 2, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=False, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + 
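+            # Persist the freshly compiled scheduler; the os.path.exists check above loads this .ts file on later runs instead of recompiling.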
torch.jit.save(self.compiled_scheduler, scheduler_compiled_path) + + cat_compiled_path = os.path.join(self.args.output_dir, "cat/cat_aie_compile.ts") + if os.path.exists(cat_compiled_path): + self.compiled_cat = torch.jit.load(cat_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, "cat/cat.pt")).eval() + + self.compiled_cat = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size // 2, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(self.compiled_cat, cat_compiled_path) + + if args.use_cache: + unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_0.ts") + if os.path.exists(unet_cache_compiled_path): + self.compiled_unet_cache = torch.jit.load(unet_cache_compiled_path).eval() + else: + unet_cache = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_0.pt")).eval() + + self.compiled_unet_cache = ( + mindietorch.compile(unet_cache, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_cache, unet_cache_compiled_path) + + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1.ts") + if os.path.exists(unet_skip_compiled_path): + self.compiled_unet_skip = torch.jit.load(unet_skip_compiled_path).eval() + else: + unet_skip = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_1.pt")).eval() + + self.compiled_unet_skip = ( + mindietorch.compile(unet_skip, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + elif args.use_cache_faster: + unet_cache_faster_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_aie_compile_0_1.ts") + if os.path.exists(unet_cache_faster_compiled_path): + self.compiled_unet_cache_faster = torch.jit.load(unet_cache_faster_compiled_path).eval() + else: + unet_cache_faster = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_0_1.pt")).eval() + + self.compiled_unet_cache_faster = ( + mindietorch.compile(unet_cache_faster, + 
inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_cache_faster, unet_cache_faster_compiled_path) + + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1_1.ts") + if os.path.exists(unet_skip_compiled_path): + self.compiled_unet_skip = torch.jit.load(unet_skip_compiled_path).eval() + else: + unet_skip = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_1_1.pt")).eval() + + self.compiled_unet_skip = ( + mindietorch.compile(unet_skip, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + 2 * 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + else: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile.ts") + if os.path.exists(unet_compiled_path): + self.compiled_unet = torch.jit.load(unet_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}.pt")).eval() + + self.compiled_unet = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet, unet_compiled_path) + + self.is_init = True + + @torch.no_grad() + def ascendie_infer_ddim( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], + None]] = None, + callback_steps: 
Optional[int] = 1, + skip_steps=None, + flag_cache: int = None, + flag_cache_faster: int = None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (畏) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. 
Default height and width to unet + global p1_time, p2_time, p3_time, scheduler_time + start1 = time.time() + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # check compile + if not self.is_init: + self.compile_aie_model() + + # 3. Encode input prompt + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt) + + text_embeddings_dtype = text_embeddings.dtype + p1_time += time.time() - start1 + start2 = time.time() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, height, width, + text_embeddings_dtype, device, + generator, latents).to(f'npu:{self.device_0}') + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + global unet_time + global vae_time + + cache = None + cache_faster = None + skip_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + cache_flag = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + cache_faster_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + + stream = mindietorch.npu.Stream(f'npu:{self.device_0}') + for i, t in enumerate(self.progress_bar(timesteps)): + if i == 50: + break + + # expand the latents if we are doing classifier free guidance + with mindietorch.npu.stream(stream): + latent_model_input = self.compiled_cat(latents, t.to(torch.float32)[None].to(f'npu:{self.device_0}')) + + t_npu = t[None].to(f'npu:{self.device_0}') + text_embeddings_npu = text_embeddings.to(f'npu:{self.device_0}') + + start = time.time() + + if flag_cache: + if skip_steps[i]: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet_skip(latent_model_input, + t_npu, + text_embeddings_npu, + skip_flag, + cache) + else: + with mindietorch.npu.stream(stream): + outputs = self.compiled_unet_cache(latent_model_input, + t_npu, + text_embeddings_npu, + cache_flag) + noise_pred = outputs[0] + cache = outputs[1] + elif flag_cache_faster: + if skip_steps[i]: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet_skip(latent_model_input, + t_npu, + text_embeddings_npu, + skip_flag, + cache_faster_flag, + cache, + cache_faster) + else: + with mindietorch.npu.stream(stream): + outputs = self.compiled_unet_cache_faster(latent_model_input, + t_npu, + text_embeddings_npu, + cache_flag, + cache_faster_flag) + noise_pred = outputs[0] + cache = outputs[1] + cache_faster = outputs[2] + else: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet(latent_model_input, t_npu, text_embeddings_npu) + + unet_time += time.time() - start + # perform guidance + # compute the previous noisy sample x_t -> x_t-1 + start = 
time.time() + if do_classifier_free_guidance: + x = np.array(i, dtype=np.int64) + y = torch.from_numpy(x).long() + + with mindietorch.npu.stream(stream): + latents = self.compiled_scheduler( + noise_pred, + t_npu, + latents, + y[None].to(f'npu:{self.device_0}')) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + scheduler_time += time.time() - start + stream.synchronize() + + # 8. Post-processing + p2_time += time.time() - start2 + start3 = time.time() + + # run inference + start = time.time() + with mindietorch.npu.stream(stream): + image = self.compiled_vae_model(latents).to('cpu') + vae_time += time.time() - start + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.clamp(0, 1).float().numpy() + + # 9. Run safety checker + has_nsfw_concept = False + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + p3_time += time.time() - start3 + return (image, has_nsfw_concept) + + def ascendie_infer( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], + None]] = None, + callback_steps: Optional[int] = 1, + skip_steps=None, + flag_cache: int = None, + flag_cache_faster: int = None, + **kwargs, + ): + # 0. Default height and width to unet + global p1_time, p2_time, p3_time, scheduler_time + start1 = time.time() + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # check compile + if not self.is_init: + self.compile_aie_model() + + # 3. Encode input prompt + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt) + + text_embeddings_dtype = text_embeddings.dtype + p1_time += time.time() - start1 + start2 = time.time() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, height, width, + text_embeddings_dtype, device, + generator, latents) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + global unet_time + global vae_time + + cache = None + cache_faster = None + skip_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + cache_flag = torch.zeros([1], dtype=torch.long).to(f'npu:{self.device_0}') + cache_faster_flag = torch.ones([1], dtype=torch.long).to(f'npu:{self.device_0}') + + stream = mindietorch.npu.Stream(f'npu:{self.device_0}') + for i, t in enumerate(self.progress_bar(timesteps)): + if i == 50: + break + + # expand the latents if we are doing classifier free guidance + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_model_input_npu = latent_model_input.to(f'npu:{self.device_0}') + t_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + text_embeddings_npu = text_embeddings.to(f'npu:{self.device_0}') + + start = time.time() + + if flag_cache: + if skip_steps[i]: + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + skip_flag, + cache).to('cpu') + else: + outputs = self.compiled_unet_cache(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag, + ) + noise_pred = outputs[0].to('cpu') + cache = outputs[1] + elif flag_cache_faster: + if skip_steps[i]: + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + skip_flag, + cache_faster_flag, + cache, + cache_faster).to('cpu') + else: + outputs = self.compiled_unet_cache_faster(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag, + cache_faster_flag) + noise_pred = outputs[0].to('cpu') + cache = outputs[1] + cache_faster = outputs[2] + else: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet(latent_model_input_npu, + t_npu, + text_embeddings_npu).to('cpu') + + unet_time += time.time() - start + + # perform guidance + start = time.time() + if do_classifier_free_guidance: + noise_pred, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred + guidance_scale * (noise_pred_text - + noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, + **extra_step_kwargs)[0] + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + scheduler_time += time.time() - start + + # 8. Post-processing + p2_time += time.time() - start2 + start3 = time.time() + + # run inference + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + vae_time += time.time() - start + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.clamp(0, 1).float().numpy() + + # 9. Run safety checker + has_nsfw_concept = False + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + p3_time += time.time() - start3 + return (image, has_nsfw_concept) + + def _encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt") + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, + padding="max_length", + return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1:-1]) + print("[warning] The following part of your input was truncated" + " because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}") + + # run inference + self.text_encoder.eval() + global clip_time + start = time.time() + text_embeddings = self.compiled_clip_model(text_input_ids.to(f'npu:{self.device_0}')).to('cpu') + clip_time += time.time() - start + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}.") + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`.") + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer(uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt") + + # run inference + start = time.time() + uncond_embeddings = self.compiled_clip_model(uncond_input.input_ids.to(f'npu:{self.device_0}')).to('cpu') + clip_time += time.time() - start + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat( + 1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view( + batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. 
valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-2-1-base", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "--scheduler", + choices=["DDIM", "Euler", "DPM", "SA-Solver"], + default="DDIM", + help="Type of Sampling methods. Can choose from DDIM, Euler, DPM, SA-Solver", + ) + parser.add_argument( + "--soc", + choices=["Duo", "A2"], + default="A2", + help="soc_version.", + ) + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." + ) + parser.add_argument( + "--use_cache_faster", + action="store_true", + help="Use cache with faster during inference." + ) + parser.add_argument( + "--cache_steps", + type=str, + default="1,2,3,4,5,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\ + 30,31,33,34,36,37,39,40,41,43,44,45,47,48,49", + help="Steps to use cache data." 
+ ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusionPipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + if args.scheduler == "DDIM": + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "Euler": + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "DPM": + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "SA-Solver": + pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config) + pipe.compile_aie_model() + + skip_steps = [0] * args.steps + + flag_cache = 0 + if args.use_cache: + flag_cache = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + + flag_cache_faster = 0 + if args.use_cache_faster: + flag_cache_faster = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt) + + infer_num = 0 + image_info = [] + current_prompt = None + + mindietorch.set_device(args.device) + + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + start_time = time.time() + + if args.scheduler == "DDIM": + stream = mindietorch.npu.Stream("npu:" + str(args.device)) + with mindietorch.npu.stream(stream): + images = pipe.ascendie_infer_ddim( + prompts, + num_inference_steps=args.steps, + skip_steps=skip_steps, + flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, + ) + else: + images = pipe.ascendie_infer( + prompts, + num_inference_steps=args.steps, + skip_steps=skip_steps, + flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, + ) + + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n" + f"clip time: {clip_time / infer_num:.3f}s\n" + f"unet time: {unet_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + f"p3 time: {p3_time / infer_num:.3f}s\n" + f"scheduler time: {scheduler_time / infer_num:.3f}s\n") + + # Save image information to a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git 
a/MindIE/MultiModal/StableDiffusion/stable_diffusion_pipeline_parallel.py b/MindIE/MultiModal/StableDiffusion/stable_diffusion_pipeline_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..5e64cac270f5e5eb97c1230b6a6c4e46fc740f26 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/stable_diffusion_pipeline_parallel.py @@ -0,0 +1,1296 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import csv +import json +import os +import time +from typing import Callable, List, Optional, Union +import numpy as np + +import torch +import mindietorch +from mindietorch import _enums +from diffusers import StableDiffusionPipeline +from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler, DDIMScheduler, SASolverScheduler + +from background_runtime import BackgroundRuntime, RuntimeIOInfo +from background_runtime_cache import BackgroundRuntimeCache, RuntimeIOInfoCache +from background_runtime_cache_faster import BackgroundRuntimeCacheFaster, RuntimeIOInfoCacheFaster + +clip_time = 0 +unet_time = 0 +vae_time = 0 +p1_time = 0 +p2_time = 0 +p3_time = 0 +scheduler_time = 0 + + +class PromptLoader: + def __init__( + self, + prompt_file: str, + prompt_file_type: str, + batch_size: int, + num_images_per_prompt: int = 1, + ): + self.prompts = [] + self.catagories = ['Not_specified'] + self.batch_size = batch_size + self.num_images_per_prompt = num_images_per_prompt + + if prompt_file_type == 'plain': + self.load_prompts_plain(prompt_file) + + elif prompt_file_type == 'parti': + self.load_prompts_parti(prompt_file) + + self.current_id = 0 + self.inner_id = 0 + + def __len__(self): + return len(self.prompts) * self.num_images_per_prompt + + def __iter__(self): + return self + + def __next__(self): + if self.current_id == len(self.prompts): + raise StopIteration + + ret = { + 'prompts': [], + 'catagories': [], + 'save_names': [], + 'n_prompts': self.batch_size, + } + for _ in range(self.batch_size): + if self.current_id == len(self.prompts): + ret['prompts'].append('') + ret['save_names'].append('') + ret['catagories'].append('') + ret['n_prompts'] -= 1 + + else: + prompt, catagory_id = self.prompts[self.current_id] + ret['prompts'].append(prompt) + ret['catagories'].append(self.catagories[catagory_id]) + ret['save_names'].append(f'{self.current_id}_{self.inner_id}') + + self.inner_id += 1 + if self.inner_id == self.num_images_per_prompt: + self.inner_id = 0 + self.current_id += 1 + + return ret + + def load_prompts_plain(self, file_path: str): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r") as f: + for i, line in enumerate(f): + prompt = line.strip() + self.prompts.append((prompt, 0)) + + def load_prompts_parti(self, file_path: str): + with os.fdopen(os.open(file_path, os.O_RDONLY), "r", encoding='utf8') as f: + # Skip the first line + next(f) + tsv_file = csv.reader(f, delimiter="\t") + for i, line in enumerate(tsv_file): + prompt = line[0] + catagory = line[1] + if catagory not in self.catagories: + 
self.catagories.append(catagory) + + catagory_id = self.catagories.index(catagory) + self.prompts.append((prompt, catagory_id)) + + +class AIEStableDiffusionPipeline(StableDiffusionPipeline): + device_0 = None + device_1 = None + runtime = None + engines = {} + contexts = {} + buffer_bindings = {} + use_parallel_inferencing = False + unet_bg = None + unet_bg_cache = None + unet_bg_cache_faster = None + + def parser_args(self, args): + self.args = args + if isinstance(args.device, list): + self.device_0, self.device_1 = args.device + print(f'Using parallel inferencing on device {self.device_0} and {self.device_1}') + else: + self.device_0 = args.device + self.is_init = False + + def compile_aie_model(self): + if self.is_init: + return + args = parse_arguments() + + in_channels = self.unet.config.out_channels + sample_size = self.unet.config.sample_size + encoder_hidden_size = self.text_encoder.config.hidden_size + max_position_embeddings = self.text_encoder.config.max_position_embeddings + + batch_size = self.args.batch_size + + if self.args.soc == "Duo": + soc_version = "Ascend310P3" + elif self.args.soc == "A2": + soc_version = "Ascend910B4" + else: + print("unsupport soc_version, please check!") + return + + clip_compiled_path = os.path.join(self.args.output_dir, f"clip/clip_bs{batch_size}_aie_compile.ts") + if os.path.exists(clip_compiled_path): + self.compiled_clip_model = torch.jit.load(clip_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"clip/clip_bs{batch_size}.pt")).eval() + + self.compiled_clip_model = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((self.args.batch_size, + max_position_embeddings), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(self.compiled_clip_model, clip_compiled_path) + + vae_compiled_path = os.path.join(self.args.output_dir, f"vae/vae_bs{batch_size}_aie_compile.ts") + if os.path.exists(vae_compiled_path): + self.compiled_vae_model = torch.jit.load(vae_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"vae/vae_bs{batch_size}.pt")).eval() + + self.compiled_vae_model = ( + mindietorch.compile(model, + inputs=[ + mindietorch.Input((self.args.batch_size, in_channels, + sample_size, sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_vae_model, vae_compiled_path) + + scheduler_compiled_path = os.path.join(self.args.output_dir, f"ddim/ddim{batch_size}_aie_compile.ts") + if os.path.exists(scheduler_compiled_path): + self.compiled_scheduler = torch.jit.load(scheduler_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"ddim/ddim{batch_size}.pt")).eval() + + self.compiled_scheduler = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + 
mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=False, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version=soc_version, + optimization_level=0)) + torch.jit.save(self.compiled_scheduler, scheduler_compiled_path) + + if args.use_cache: + unet_cache_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_0.ts") + if os.path.exists(unet_cache_compiled_path): + self.compiled_unet_cache = torch.jit.load(unet_cache_compiled_path).eval() + else: + unet_cache = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_0.pt")).eval() + + self.compiled_unet_cache = ( + mindietorch.compile(unet_cache, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_cache, unet_cache_compiled_path) + + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1.ts") + if os.path.exists(unet_skip_compiled_path): + self.compiled_unet_skip = torch.jit.load(unet_skip_compiled_path).eval() + else: + unet_skip = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_1.pt")).eval() + + self.compiled_unet_skip = ( + mindietorch.compile(unet_skip, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + + runtime_info_cache = RuntimeIOInfoCache( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (1,), + (batch_size, max_position_embeddings, encoder_hidden_size), + (1,) + ], + input_dtypes=[np.float32, np.int64, np.float32, np.int64], + output_shapes=[(batch_size, in_channels, sample_size, sample_size), + (batch_size, 320, sample_size, sample_size)], + output_dtypes=[np.float32, np.float32] + ) + + if hasattr(self, 'device_1'): + self.unet_bg_cache = BackgroundRuntimeCache.clone(self.device_1, + [unet_cache_compiled_path, unet_skip_compiled_path], + runtime_info_cache) + self.use_parallel_inferencing = True + + elif args.use_cache_faster: + unet_cache_faster_compiled_path = os.path.join(self.args.output_dir, + f"unet/unet_bs{batch_size}_aie_compile_0_1.ts") + if os.path.exists(unet_cache_faster_compiled_path): + self.compiled_unet_cache_faster = torch.jit.load(unet_cache_faster_compiled_path).eval() + else: + unet_cache_faster = torch.jit.load( + 
os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_0_1.pt")).eval() + + self.compiled_unet_cache_faster = ( + mindietorch.compile(unet_cache_faster, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64)], + allow_tensor_replace_int=True, + require_full_compilation=False, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_cache_faster, unet_cache_faster_compiled_path) + + unet_skip_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile_1_1.ts") + if os.path.exists(unet_skip_compiled_path): + self.compiled_unet_skip = torch.jit.load(unet_skip_compiled_path).eval() + else: + unet_skip = torch.jit.load( + os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_1_1.pt")).eval() + + self.compiled_unet_skip = ( + mindietorch.compile(unet_skip, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size, + 2 * 320, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet_skip, unet_skip_compiled_path) + + runtime_info_cache_faster = RuntimeIOInfoCacheFaster( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (1,), + (batch_size, max_position_embeddings, encoder_hidden_size), + (1,), + (1,) + ], + input_dtypes=[np.float32, np.int64, np.float32, np.int64, np.int64], + output_shapes=[(batch_size, in_channels, sample_size, sample_size), + (batch_size, 320, sample_size, sample_size), + (batch_size, 2 * 320, sample_size, sample_size)], + output_dtypes=[np.float32, np.float32, np.float32] + ) + + if hasattr(self, 'device_1'): + self.unet_bg_cache_faster = BackgroundRuntimeCacheFaster.clone(self.device_1, + [unet_cache_faster_compiled_path, + unet_skip_compiled_path], + runtime_info_cache_faster) + self.use_parallel_inferencing = True + + else: + unet_compiled_path = os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}_aie_compile.ts") + if os.path.exists(unet_compiled_path): + self.compiled_unet = torch.jit.load(unet_compiled_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, f"unet/unet_bs{batch_size}.pt")).eval() + + self.compiled_unet = ( + mindietorch.compile(model, + inputs=[mindietorch.Input((batch_size, + in_channels, sample_size, + sample_size), + dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,), + dtype=mindietorch.dtype.INT64), + mindietorch.Input((batch_size, + 
max_position_embeddings, + encoder_hidden_size), + dtype=mindietorch.dtype.FLOAT)], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + soc_version=soc_version, + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + )) + torch.jit.save(self.compiled_unet, unet_compiled_path) + + runtime_info = RuntimeIOInfo( + input_shapes=[ + (batch_size, in_channels, sample_size, sample_size), + (1,), + (batch_size, max_position_embeddings, encoder_hidden_size) + ], + input_dtypes=[np.float32, np.int64, np.float32], + output_shapes=[(batch_size, in_channels, sample_size, sample_size)], + output_dtypes=[np.float32] + ) + if hasattr(self, 'device_1'): + self.unet_bg = BackgroundRuntime.clone(self.device_1, unet_compiled_path, runtime_info) + self.use_parallel_inferencing = True + + self.is_init = True + + @torch.no_grad() + def ascendie_infer_ddim( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], + None]] = None, + callback_steps: Optional[int] = 1, + skip_steps=None, + flag_cache: int = None, + flag_cache_faster: int = None, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (畏) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + global p1_time, p2_time, p3_time, scheduler_time + start1 = time.time() + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # check compile + if not self.is_init: + self.compile_aie_model() + + # 3. Encode input prompt + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt) + + text_embeddings_dtype = text_embeddings.dtype + p1_time += time.time() - start1 + start2 = time.time() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, height, width, + text_embeddings_dtype, device, + generator, latents) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + global unet_time + global vae_time + if self.use_parallel_inferencing and do_classifier_free_guidance: + # Split embeddings + text_embeddings, text_embeddings_2 = text_embeddings.chunk(2) + + cache = None + cache_faster = None + skip_flag = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + cache_faster_flag = torch.zeros([1], dtype=torch.long) + + stream = mindietorch.npu.Stream(f'npu:{self.device_0}') + for i, t in enumerate(self.progress_bar(timesteps)): + if i == 50: + break + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.use_parallel_inferencing and do_classifier_free_guidance: + if flag_cache: + self.unet_bg_cache.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + skip_flag.numpy(), + # cache_numpy, + ], + skip_steps[i]) + elif flag_cache_faster: + self.unet_bg_cache_faster.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + skip_flag.numpy(), + cache_faster_flag.numpy() + # cache_numpy, + ], + skip_steps[i]) + else: + self.unet_bg.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + ]) + + latent_model_input_npu = latent_model_input.to(f'npu:{self.device_0}') + t_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + text_embeddings_npu = text_embeddings.to(f'npu:{self.device_0}') + + start = time.time() + + if flag_cache: + with mindietorch.npu.stream(stream): + if (skip_steps[i]): + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + skip_flag.to(f'npu:{self.device_0}'), + cache) + else: + outputs = self.compiled_unet_cache(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag.to(f'npu:{self.device_0}'), + ) + noise_pred = outputs[0] + cache = outputs[1] + stream.synchronize() + elif flag_cache_faster: + with mindietorch.npu.stream(stream): + if (skip_steps[i]): + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + skip_flag.to(f'npu:{self.device_0}'), + cache_faster_flag.to(f'npu:{self.device_0}'), + cache, + cache_faster) + else: + outputs = self.compiled_unet_cache_faster(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag.to(f'npu:{self.device_0}'), + cache_faster_flag.to(f'npu:{self.device_0}'), + ) + noise_pred = outputs[0] + cache = outputs[1] + cache_faster = outputs[2] + stream.synchronize() + else: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet(latent_model_input_npu, t_npu, text_embeddings_npu) + stream.synchronize() + + unet_time += time.time() - start + + # perform guidance + # compute the previous noisy sample x_t -> x_t-1 + start = time.time() + if do_classifier_free_guidance: + if self.use_parallel_inferencing: + if flag_cache: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg_cache.wait_and_get_outputs()[0]) + else: + out = self.unet_bg_cache.wait_and_get_outputs() + noise_pred_text = torch.from_numpy(out[0]) + elif flag_cache_faster: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg_cache_faster.wait_and_get_outputs()[0]) + else: + out = 
self.unet_bg_cache_faster.wait_and_get_outputs() + noise_pred_text = torch.from_numpy(out[0]) + else: + noise_pred_text = torch.from_numpy(self.unet_bg.wait_and_get_outputs()[0]) + else: + noise_pred, noise_pred_text = noise_pred.chunk(2) + + x = np.array(i, dtype=np.int64) + y = torch.from_numpy(x).long() + + latents = self.compiled_scheduler( + noise_pred.to(f'npu:{self.device_0}'), + noise_pred_text.to(f'npu:{self.device_0}'), + t[None].to(f'npu:{self.device_0}'), + latents.to(f'npu:{self.device_0}'), + y[None].to(f'npu:{self.device_0}')).to('cpu') + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + scheduler_time += time.time() - start + + # 8. Post-processing + p2_time += time.time() - start2 + start3 = time.time() + + # run inference + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + vae_time += time.time() - start + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.clamp(0, 1).float().numpy() + + # 9. Run safety checker + has_nsfw_concept = False + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + p3_time += time.time() - start3 + return (image, has_nsfw_concept) + + def ascendie_infer( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback: Optional[Callable[[int, int, torch.FloatTensor], + None]] = None, + callback_steps: Optional[int] = 1, + skip_steps=None, + flag_cache: int = None, + flag_cache_faster: int = None, + **kwargs, + ): + # 0. Default height and width to unet + global p1_time, p2_time, p3_time, scheduler_time + start1 = time.time() + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # check compile + if not self.is_init: + self.compile_aie_model() + + # 3. Encode input prompt + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt) + + text_embeddings_dtype = text_embeddings.dtype + p1_time += time.time() - start1 + start2 = time.time() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, height, width, + text_embeddings_dtype, device, + generator, latents) + + # 6. Prepare extra step kwargs. 
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + global unet_time + global vae_time + if self.use_parallel_inferencing and do_classifier_free_guidance: + # Split embeddings + text_embeddings, text_embeddings_2 = text_embeddings.chunk(2) + + cache = None + cache_faster = None + skip_flag = torch.ones([1], dtype=torch.long) + cache_flag = torch.zeros([1], dtype=torch.long) + cache_faster_flag = torch.zeros([1], dtype=torch.long) + + for i, t in enumerate(self.progress_bar(timesteps)): + if i == 50: + break + + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if self.use_parallel_inferencing and do_classifier_free_guidance: + if flag_cache: + self.unet_bg_cache.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + skip_flag.numpy(), + # cache_numpy, + ], + skip_steps[i]) + elif flag_cache_faster: + self.unet_bg_cache_faster.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + skip_flag.numpy(), + cache_faster_flag.numpy() + # cache_numpy, + ], + skip_steps[i]) + else: + self.unet_bg.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy().astype(np.int64), + text_embeddings_2.numpy(), + ]) + + latent_model_input_npu = latent_model_input.to(f'npu:{self.device_0}') + t_npu = t.to(torch.int64)[None].to(f'npu:{self.device_0}') + text_embeddings_npu = text_embeddings.to(f'npu:{self.device_0}') + + start = time.time() + + if flag_cache: + if skip_steps[i]: + noise_pred = self.compiled_unet_skip(latent_model_input_npu, + t_npu, + text_embeddings_npu, + skip_flag.to(f'npu:{self.device_0}'), + cache).to('cpu') + else: + outputs = self.compiled_unet_cache(latent_model_input_npu, + t_npu, + text_embeddings_npu, + cache_flag.to(f'npu:{self.device_0}'), + ) + noise_pred = outputs[0].to('cpu') + cache = outputs[1] + elif flag_cache_faster: + if skip_steps[i]: + with mindietorch.npu.stream(stream): + noise_pred = self.compiled_unet_skip(latent_model_input, + t_npu, + text_embeddings_npu, + skip_flag, + cache_faster_flag, + cache, + cache_faster) + else: + with mindietorch.npu.stream(stream): + outputs = self.compiled_unet_cache_faster(latent_model_input, + t_npu, + text_embeddings_npu, + cache_flag, + cache_faster_flag) + noise_pred = outputs[0] + cache = outputs[1] + cache_faster = outputs[2] + else: + noise_pred = self.compiled_unet(latent_model_input_npu, + t_npu, + text_embeddings_npu).to('cpu') + + unet_time += time.time() - start + + # perform guidance + start = time.time() + if do_classifier_free_guidance: + if self.use_parallel_inferencing: + if flag_cache: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg_cache.wait_and_get_outputs()[0]) + else: + out = self.unet_bg_cache.wait_and_get_outputs() + noise_pred_text = torch.from_numpy(out[0]) + elif flag_cache_faster: + if (skip_steps[i]): + noise_pred_text = torch.from_numpy(self.unet_bg_cache_faster.wait_and_get_outputs()[0]) + else: + out = self.unet_bg_cache_faster.wait_and_get_outputs() + noise_pred_text = torch.from_numpy(out[0]) + else: + noise_pred_text = torch.from_numpy(self.unet_bg.wait_and_get_outputs()[0]) + else: + noise_pred, noise_pred_text = noise_pred.chunk(2) + + 
noise_pred = noise_pred + guidance_scale * (noise_pred_text - + noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, + **extra_step_kwargs)[0] + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + scheduler_time += time.time() - start + + # 8. Post-processing + p2_time += time.time() - start2 + start3 = time.time() + + # run inference + start = time.time() + image = self.compiled_vae_model(latents.to(f'npu:{self.device_0}')).to('cpu') + vae_time += time.time() - start + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.clamp(0, 1).float().numpy() + + # 9. Run safety checker + has_nsfw_concept = False + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + p3_time += time.time() - start3 + return (image, has_nsfw_concept) + + def _encode_prompt( + self, + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt") + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, + padding="max_length", + return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1:-1]) + print("[warning] The following part of your input was truncated" + " because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}") + + # run inference + self.text_encoder.eval() + global clip_time + start = time.time() + text_embeddings = self.compiled_clip_model(text_input_ids.to(f'npu:{self.device_0}')).to('cpu') + clip_time += time.time() - start + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}.") + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`.") + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer(uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt") + + # run inference + start = time.time() + uncond_embeddings = self.compiled_clip_model(uncond_input.input_ids.to(f'npu:{self.device_0}')).to('cpu') + clip_time += time.time() - start + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat( + 1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view( + batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + +def check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save compiled models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-diffusion-2-1-base", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--prompt_file", + type=str, + default="./prompts.txt", + help="A text file of prompts for generating images.", + ) + parser.add_argument( + "--prompt_file_type", + choices=["plain", "parti"], + default="plain", + help="Type of prompt file.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--info_file_save_path", + type=str, + default="./image_info.json", + help="Path to save image information file.", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of inference steps.", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=[0, 1], + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "--num_images_per_prompt", + default=1, + type=int, + help="Number of images generated for each prompt.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "--scheduler", + choices=["DDIM", "Euler", "DPM", "SA-Solver"], + default="DDIM", + help="Type of Sampling methods. Can choose from DDIM, Euler, DPM, SA-Solver", + ) + parser.add_argument( + "--soc", + choices=["Duo", "A2"], + default="Duo", + help="soc_version.", + ) + parser.add_argument( + "--use_cache", + action="store_true", + help="Use cache during inference." 
+ ) + parser.add_argument( + "--use_cache_faster", + action="store_true", + help="Use cache with faster during inference." + ) + parser.add_argument( + "--cache_steps", + type=str, + default="1,2,3,4,5,7,9,10,12,13,14,16,18,19,21,23,24,26,27,29,\ + 30,31,33,34,36,37,39,40,41,43,44,45,47,48,49", + help="Steps to use cache data." + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + pipe = AIEStableDiffusionPipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + if args.scheduler == "DDIM": + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "Euler": + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "DPM": + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + if args.scheduler == "SA-Solver": + pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config) + + pipe.compile_aie_model() + + mindietorch.set_device(pipe.device_0) + + skip_steps = [0] * args.steps + + flag_cache = 0 + if args.use_cache: + flag_cache = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + + flag_cache_faster = 0 + if args.use_cache_faster: + flag_cache_faster = 1 + for i in args.cache_steps.split(','): + if int(i) >= args.steps: + continue + skip_steps[int(i)] = 1 + + use_time = 0 + prompt_loader = PromptLoader(args.prompt_file, + args.prompt_file_type, + args.batch_size, + args.num_images_per_prompt) + + infer_num = 0 + image_info = [] + current_prompt = None + for i, input_info in enumerate(prompt_loader): + prompts = input_info['prompts'] + catagories = input_info['catagories'] + save_names = input_info['save_names'] + n_prompts = input_info['n_prompts'] + + print(f"[{infer_num + n_prompts}/{len(prompt_loader)}]: {prompts}") + infer_num += args.batch_size + + if i > 4: + start_time = time.time() + + if args.scheduler == "DDIM": + images = pipe.ascendie_infer_ddim( + prompts, + num_inference_steps=args.steps, + skip_steps=skip_steps, + flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, + ) + else: + images = pipe.ascendie_infer( + prompts, + num_inference_steps=args.steps, + skip_steps=skip_steps, + flag_cache=flag_cache, + flag_cache_faster=flag_cache_faster, + ) + + if i > 4: # do not count the time spent inferring the first 0 to 4 images + use_time += time.time() - start_time + + for j in range(n_prompts): + image_save_path = os.path.join(save_dir, f"{save_names[j]}.png") + image = images[0][j] + image.save(image_save_path) + + if current_prompt != prompts[j]: + current_prompt = prompts[j] + image_info.append({'images': [], 'prompt': current_prompt, 'category': catagories[j]}) + + image_info[-1]['images'].append(image_save_path) + + infer_num = infer_num - 5 # do not count the time spent inferring the first 5 images + print(f"[info] infer number: {infer_num}; use time: {use_time:.3f}s\n" + f"average time: {use_time / infer_num:.3f}s\n" + f"clip time: {clip_time / infer_num:.3f}s\n" + f"unet time: {unet_time / infer_num:.3f}s\n" + f"vae time: {vae_time / infer_num:.3f}s\n" + f"p1 time: {p1_time / infer_num:.3f}s\n" + f"p2 time: {p2_time / infer_num:.3f}s\n" + ) + if hasattr(pipe, 'device_1'): + if (pipe.unet_bg): + pipe.unet_bg.stop() + + if (pipe.unet_bg_cache): + pipe.unet_bg_cache.stop() + + if (pipe.unet_bg_cache_faster): + pipe.unet_bg_cache_faster.stop() + + # Save image information to 
a json file + if os.path.exists(args.info_file_save_path): + os.remove(args.info_file_save_path) + + with os.fdopen(os.open(args.info_file_save_path, os.O_RDWR | os.O_CREAT, 0o640), "w") as f: + json.dump(image_info, f) + + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableDiffusion/stable_diffusion_unet_patch.py b/MindIE/MultiModal/StableDiffusion/stable_diffusion_unet_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..4d3280950cac948c958b7d17ee06b46e6bb0cb7a --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/stable_diffusion_unet_patch.py @@ -0,0 +1,29 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import diffusers + + +def main(): + diffusers_path = diffusers.__path__ + diffusers_version = diffusers.__version__ + + assert diffusers_version == '0.26.3', f"Expected diffusers version 0.26.3, but got {diffusers_version}" + os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_condition.py unet_2d_condition.patch') + os.system(f'patch -p0 {diffusers_path[0]}/models/unets/unet_2d_blocks.py unet_2d_blocks.patch') + + +if __name__ == '__main__': + main() diff --git a/MindIE/MultiModal/StableDiffusion/unet_2d_blocks.patch b/MindIE/MultiModal/StableDiffusion/unet_2d_blocks.patch new file mode 100644 index 0000000000000000000000000000000000000000..674baeb1f882f47e1e2fd5add6a34a882377d2ce --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/unet_2d_blocks.patch @@ -0,0 +1,69 @@ +--- ./unet_2d_blocks.py 2024-06-24 09:06:04.593004325 +0800 ++++ ./unet_2d_blocks_wz.py 2024-06-24 09:33:26.052027073 +0800 +@@ -1159,6 +1159,7 @@ class CrossAttnDownBlock2D(nn.Module): + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ++ block_number: Optional[int]=None, + additional_residuals: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () +@@ -1210,6 +1211,8 @@ class CrossAttnDownBlock2D(nn.Module): + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) ++ if block_number is not None and len(output_states) == block_number + 1: ++ return hidden_states, output_states + + if self.downsamplers is not None: + for downsampler in self.downsamplers: +@@ -2364,6 +2367,7 @@ class CrossAttnUpBlock2D(nn.Module): + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ++ block_number: Optional[int]=None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( +@@ -2372,8 +2376,12 @@ class CrossAttnUpBlock2D(nn.Module): + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) +- +- for resnet, attn in zip(self.resnets, self.attentions): 
++ ++ prev_feature = [] ++ for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): ++ if block_number is not None and i < len(self.resnets) - block_number - 1: ++ continue ++ + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] +@@ -2390,8 +2398,9 @@ class CrossAttnUpBlock2D(nn.Module): + b2=self.b2, + ) + ++ prev_feature.append(hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) +- ++ + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): +@@ -2428,12 +2437,12 @@ class CrossAttnUpBlock2D(nn.Module): + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] +- ++ + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) +- +- return hidden_states ++ ++ return hidden_states, prev_feature + + + class UpBlock2D(nn.Module): diff --git a/MindIE/MultiModal/StableDiffusion/unet_2d_condition.patch b/MindIE/MultiModal/StableDiffusion/unet_2d_condition.patch new file mode 100644 index 0000000000000000000000000000000000000000..1e816cec4408082841924bc3ff05c98e33ccdd89 --- /dev/null +++ b/MindIE/MultiModal/StableDiffusion/unet_2d_condition.patch @@ -0,0 +1,242 @@ +--- ./unet_2d_condition.py 2024-06-24 09:06:04.594004325 +0800 ++++ ./unet_2d_condition_wz.py 2024-06-24 09:33:25.980027072 +0800 +@@ -855,6 +855,10 @@ class UNet2DConditionModel(ModelMixin, C + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, ++ if_skip: int = 0, ++ if_faster: int = 0, ++ inputCache: Optional[torch.FloatTensor] = None, ++ inputFasterCache: Optional[torch.FloatTensor] = None, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. 
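The two patch hunks above add a cache/skip fast path to the UNet: on a "skip" step the down and mid blocks are bypassed and a cached up-block feature (`inputCache`, optionally plus `inputFasterCache`) is reused, while `block_number` limits how much of `CrossAttnDownBlock2D`/`CrossAttnUpBlock2D` actually runs. A minimal sketch of how a caller is expected to drive the two compiled variants across denoising steps is shown below; it reuses the `compiled_unet_cache` / `compiled_unet_skip` / `skip_steps` names from the pipeline code earlier in this change, and everything else is illustrative rather than the shipped pipeline.

```python
import torch

def denoise_with_cache(compiled_unet_cache, compiled_unet_skip, scheduler,
                       latents, timesteps, text_emb, skip_steps):
    """Drive the cache/skip UNet variants across denoising steps (illustrative)."""
    cache = None
    skip_flag = torch.ones([1], dtype=torch.long)
    cache_flag = torch.zeros([1], dtype=torch.long)
    for i, t in enumerate(timesteps):
        t_in = t[None].to(torch.int64)
        if skip_steps[i] and cache is not None:
            # Skip step: down/mid blocks are bypassed and the cached
            # up-block feature (inputCache in the patch) is fed back in.
            noise_pred = compiled_unet_skip(latents, t_in, text_emb, skip_flag, cache)
        else:
            # Full step: run the whole UNet and keep the feature it returns
            # alongside the noise prediction so later skip steps can reuse it.
            noise_pred, cache = compiled_unet_cache(latents, t_in, text_emb, cache_flag)
        latents = scheduler.step(noise_pred, t, latents)[0]
    return latents
```

In the actual pipeline the conditional and unconditional halves are handled separately (optionally on a second device via `BackgroundRuntime`), so the classifier-free-guidance combination happens outside a loop like this one.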
+@@ -1110,29 +1114,60 @@ class UNet2DConditionModel(ModelMixin, C + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + +- down_block_res_samples = (sample,) +- for downsample_block in self.down_blocks: +- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: +- # For t2i-adapter CrossAttnDownBlock2D +- additional_residuals = {} +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) +- +- sample, res_samples = downsample_block( +- hidden_states=sample, +- temb=emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- **additional_residuals, +- ) ++ if not if_skip: ++ down_block_res_samples = (sample,) ++ for downsample_block in self.down_blocks: ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples ++ else: ++ if if_faster: ++ down_block_res_samples = inputFasterCache.chunk(2, dim=1) + else: +- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) +- if is_adapter and len(down_intrablock_additional_residuals) > 0: +- sample += down_intrablock_additional_residuals.pop(0) ++ down_block_res_samples = (sample,) ++ for downsample_block in self.down_blocks: + +- down_block_res_samples += res_samples ++ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: ++ # For t2i-adapter CrossAttnDownBlock2D ++ additional_residuals = {} ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) ++ ++ sample, res_samples = downsample_block( ++ hidden_states=sample, ++ temb=emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ block_number=0, ++ **additional_residuals, ++ ) ++ else: ++ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) ++ if is_adapter and len(down_intrablock_additional_residuals) > 0: ++ sample += down_intrablock_additional_residuals.pop(0) ++ ++ down_block_res_samples += res_samples ++ break + + if is_controlnet: + new_down_block_res_samples = () +@@ -1146,61 +1181,87 @@ class UNet2DConditionModel(ModelMixin, C + down_block_res_samples = new_down_block_res_samples + + # 4. 
mid +- if self.mid_block is not None: +- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: +- sample = self.mid_block( +- sample, +- emb, +- encoder_hidden_states=encoder_hidden_states, +- attention_mask=attention_mask, +- cross_attention_kwargs=cross_attention_kwargs, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = self.mid_block(sample, emb) +- +- # To support T2I-Adapter-XL +- if ( +- is_adapter +- and len(down_intrablock_additional_residuals) > 0 +- and sample.shape == down_intrablock_additional_residuals[0].shape +- ): +- sample += down_intrablock_additional_residuals.pop(0) ++ if not if_skip: ++ if self.mid_block is not None: ++ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: ++ sample = self.mid_block( ++ sample, ++ emb, ++ encoder_hidden_states=encoder_hidden_states, ++ attention_mask=attention_mask, ++ cross_attention_kwargs=cross_attention_kwargs, ++ encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = self.mid_block(sample, emb) ++ ++ # To support T2I-Adapter-XL ++ if ( ++ is_adapter ++ and len(down_intrablock_additional_residuals) > 0 ++ and sample.shape == down_intrablock_additional_residuals[0].shape ++ ): ++ sample += down_intrablock_additional_residuals.pop(0) ++ else: ++ sample = inputCache + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up +- for i, upsample_block in enumerate(self.up_blocks): +- is_final_block = i == len(self.up_blocks) - 1 +- +- res_samples = down_block_res_samples[-len(upsample_block.resnets) :] +- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] +- +- # if we have not reached the final block and need to forward the +- # upsample size, we do it here +- if not is_final_block and forward_upsample_size: +- upsample_size = down_block_res_samples[-1].shape[2:] +- +- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- encoder_hidden_states=encoder_hidden_states, +- cross_attention_kwargs=cross_attention_kwargs, +- upsample_size=upsample_size, +- attention_mask=attention_mask, +- encoder_attention_mask=encoder_attention_mask, +- ) +- else: +- sample = upsample_block( +- hidden_states=sample, +- temb=emb, +- res_hidden_states_tuple=res_samples, +- upsample_size=upsample_size, +- scale=lora_scale, +- ) ++ if not if_skip: ++ if if_faster: ++ inputFasterCache = [tmp.clone() for tmp in down_block_res_samples] ++ inputFasterCache = torch.cat(inputFasterCache[0:2], dim=1) ++ ++ for i, upsample_block in enumerate(self.up_blocks): ++ is_final_block = i == len(self.up_blocks) - 1 ++ ++ res_samples = down_block_res_samples[-len(upsample_block.resnets) :] ++ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] ++ ++ # if we have not reached the final block and need to forward the ++ # upsample size, we do it here ++ if not is_final_block and forward_upsample_size: ++ upsample_size = down_block_res_samples[-1].shape[2:] ++ ++ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: ++ sample, record_feature = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ 
encoder_attention_mask=encoder_attention_mask, ++ ) ++ else: ++ sample = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ upsample_size=upsample_size, ++ scale=lora_scale, ++ ) ++ ++ if (not if_skip) and (i == 3): ++ inputCache = record_feature[-2] ++ else: ++ for i, upsample_block in enumerate(self.up_blocks): ++ if i==3: ++ res_samples = down_block_res_samples[-2:] ++ sample, _ = upsample_block( ++ hidden_states=sample, ++ temb=emb, ++ res_hidden_states_tuple=res_samples, ++ encoder_hidden_states=encoder_hidden_states, ++ cross_attention_kwargs=cross_attention_kwargs, ++ upsample_size=upsample_size, ++ attention_mask=attention_mask, ++ encoder_attention_mask=encoder_attention_mask, ++ block_number=1, ++ ) + + # 6. post-process + if self.conv_norm_out: +@@ -1215,4 +1276,7 @@ class UNet2DConditionModel(ModelMixin, C + if not return_dict: + return (sample,) + +- return UNet2DConditionOutput(sample=sample) ++ if not if_skip: ++ return (sample, inputCache, inputFasterCache) if if_faster else (sample, inputCache) ++ else: ++ return UNet2DConditionOutput(sample=sample) diff --git a/MindIE/MultiModal/StableVideoDiffusion/README.md b/MindIE/MultiModal/StableVideoDiffusion/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5d5c88f22c5bce34b5fb92a245ae5788a6715e7 --- /dev/null +++ b/MindIE/MultiModal/StableVideoDiffusion/README.md @@ -0,0 +1,235 @@ +# stable-video-diffusion模型-推理指导 + + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#section540883920406) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [准备数据集](#section183221994411) + - [模型推理](#section741711594517) + +- [模型推理性能](#ZH-CN_TOPIC_0000001172201573) + + +# 概述 + + stable-video-diffusion是一种图像到视频的扩散模型,能够在给定任何图像输入的情况下生成与图像相对应的视频。有关稳定扩散函数的更多信息,请查看[Stable Video Diffusion blog](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)。 + +- 参考实现: + ```bash + # StableVideoDiffusion + https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt + ``` + +- 设备支持: +Atlas 800I A2推理设备:支持的卡数为1或2 + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | input | 1 x 512 x 512 x 3 | FLOAT32 | NHWC | + + +- 输出数据 + + | 输出数据 | 大小 | 数据类型 | 数据排布格式 | + | -------- | -------- | -------- | ------------ | + | output | 1 x 25 x 512 x 512 x 3 | FLOAT32 | NTHWC | + +**注意**:该模型当前仅支持batch size为1的情况。 + +# 推理环境准备 + +- 该模型需要以下插件与驱动 + + **表 1** 版本配套表 +- + | 配套 | 版本 | 备注 | + | ------------------------------------------------------------ |--------| ------------------------------------------------------------ | + | Python | 3.10.13 | - | + | torch | 2.0.0 | 导出pt模型所需版本 | + | torch | 2.1.0 | 模型编译和推理所需版本 | + + +# 快速上手 + +## 获取源码 + +1. 按照requirements.txt要求的版本安装相关依赖,避免导出模型失败。 + ```bash + pip3 install -r requirements.txt + ``` + +2. 安装mindie包 + + ```bash + # 安装mindie + chmod +x ./Ascend-mindie_xxx.run + ./Ascend-mindie_xxx.run --install + source /usr/local/Ascend/mindie/set_env.sh + ``` + +3. 代码修改 + + 执行命令: + + ```bash + python3 stable_video_diffusion_activations_patch.py + ``` + + ```bash + python3 stable_video_diffusion_attention_patch.py + ``` + + ```bash + python3 stable_video_diffusion_transformer_patch.py + ``` + +## 准备数据集 + +1. 获取原始数据集。 + + 本模型输入图像示例的下载网址为:https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png + 用户自网址自行下载后放置当前路径下,命名为 rocket.png + +## 模型推理 + +1. 
模型转换。 + 使用Pytorch导出pt模型,然后使用MindIE推理引擎转换为适配昇腾的模型。 + + 1. 获取权重 + + 可提前下载权重,以避免执行后面步骤时可能会出现下载失败。 + + ```bash + # 需要使用 git-lfs (https://git-lfs.com) + git lfs install + + # StableVideoDiffusion + git clone https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt + ``` + + 2. 导出pt模型 + + 设置模型名称或路径 + ```bash + # 执行时下载权重 + model_base="stabilityai/stable-video-diffusion-img2vid-xt" + + # 使用上一步下载的权重 + model_base="./stable-video-diffusion-img2vid-xt" + ``` + + 执行命令: + + ```bash + # 导出pt模型 + python3 export_ts.py --model ${model_base} --output_dir ./models + # 更换torch版本,执行后续的模型编译和推理 + python3 uninstall torch + python3 install torch==2.1.0 + ``` + + 参数说明: + - --model:模型名称或本地模型目录的路径 + - --output_dir: pt模型输出目录 + + 执行成功后会生成pt模型: + - ./models/image_encoder_embed/image_encoder_embed.pt + - ./models/unet/unet_bs2.pt + - ./models/vae/vae_encode.pt + - ./models/vae/vae_decode.pt + +2. 开始推理验证。 + 1. 开启cpu高性能模式 + ```bash + echo performance |tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor + sysctl -w vm.swappiness=0 + sysctl -w kernel.numa_balancing=0 + ``` + + 2. 安装绑核工具 + ```bash + apt-get update + apt-get install numactl + ``` + 查询卡的NUMA node + ```shell + lspci -vs bus-id + ``` + bus-id可通过npu-smi info获得,查询到NUMA node,在推理命令前加上对应的数字 + + 可通过lscpu获得NUMA node对应的CPU核数 + ```shell + NUMA node0: 0-23 + NUMA node1: 24-47 + NUMA node2: 48-71 + NUMA node3: 72-95 + ``` + 当前查到NUMA node是0,对应0-23,推荐绑定其中单核以获得更好的性能。 + + 3. 执行推理脚本。 + ```bash + # 0.第一次推理需要配置环境变量,使得在静态TLS块中可以分配内存: + find / -name *libGL* # 查找libGLdispatch.so.0文件的路径,记为lib_dir,例如 lib_dir="/lib/aarch64-linux-gnu" + export LD_PRELOAD=${lib_dir}/libGLdispatch.so.0:$LD_PRELOAD + + # 1.若不使用并行推理: + numactl -C 0-23 python3 stable_video_diffusion_pipeline.py \ + --model ${model_base} \ + --img_file ./rocket.png \ + --device 0 \ + --save_dir ./results \ + --num_inference_steps 25 \ + --output_dir ./models + + # 2.若使用并行推理: + numactl -C 0-23 python3 stable_video_diffusion_pipeline_parallel.py \ + --model ${model_base} \ + --img_file ./rocket.png \ + --device 0,1 \ + --save_dir ./results \ + --num_inference_steps 25 \ + --output_dir ./models + ``` + + 参数说明: + - --model:模型名称或本地模型目录的路径。 + - --img_file:输入图像文件。 + - --device:推理设备ID;可用逗号分割传入两个设备ID,此时会使用并行方式进行推理。 + - --save_dir:生成视频的存放目录。 + - --num_inference_steps:生成视频的迭代次数。 + - --output_dir: 编译好的模型路径。 + + 执行完成后在`./results`目录下生成推理视频。并在终端显示推理时间。 + + **注意**:若使用Atlas 800I A2单卡推理,则需要保证单卡的实际可用内存(最大值-无进程时初始值)> 29762MB。否则尝试重启服务器以降低无进程时初始值、更换服务器,或使用双卡并行推理。 + + +# 模型推理性能 + +性能参考下列数据。 + +### StableVideoDiffusion + +| 硬件形态 | 迭代次数 | 平均耗时 | +| :------: |:----:|:----:| +| Atlas 800I A2(8*32G) 单卡 | 25 | 28s | +| Atlas 800I A2(8*32G) 双卡 | 25 | 14.5s | + +**注意**:当前推理pipline中未固定随机种子 + +```bash +# 推理pipline main函数中加入 +generator = torch.Generator().manual_seed(xxx) +# 在ascendie_infer函数中加入参数 +generator=generator +``` \ No newline at end of file diff --git a/MindIE/MultiModal/StableVideoDiffusion/activations.patch b/MindIE/MultiModal/StableVideoDiffusion/activations.patch new file mode 100644 index 0000000000000000000000000000000000000000..a9acb1baebcf5b9eff2e44ca78e57da6252c32e9 --- /dev/null +++ b/MindIE/MultiModal/StableVideoDiffusion/activations.patch @@ -0,0 +1,23 @@ +--- activations.py 2024-05-15 08:25:05.724000000 +0000 ++++ activations_new.py 2024-05-15 08:25:05.724000000 +0000 +@@ -90,12 +90,18 @@ + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + self.proj = linear_cls(dim_in, dim_out * 2, bias=bias) ++ ++ def gelu_t(self, gate: torch.Tensor) -> torch.Tensor: ++ import math ++ return gate * 0.5 * (1.0 
+ torch.erf(gate / math.sqrt(2.0))) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": +- return F.gelu(gate) ++ return self.gelu_t(gate) ++ # return F.gelu(gate) + # mps: gelu is not implemented for float16 +- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) ++ return self.gelu_t(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) ++ # return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states, scale: float = 1.0): + args = () if USE_PEFT_BACKEND else (scale,) diff --git a/MindIE/MultiModal/StableVideoDiffusion/attention_processor.patch b/MindIE/MultiModal/StableVideoDiffusion/attention_processor.patch new file mode 100644 index 0000000000000000000000000000000000000000..ce1d0f592b1bf798c842365844db995719846a45 --- /dev/null +++ b/MindIE/MultiModal/StableVideoDiffusion/attention_processor.patch @@ -0,0 +1,152 @@ +--- attention_processor.py 2024-05-21 11:01:29.469113600 +0800 ++++ attention_processor_new.py 2024-05-21 11:01:29.469113600 +0800 +@@ -22,7 +22,7 @@ + from ..utils.import_utils import is_xformers_available + from ..utils.torch_utils import maybe_allow_in_graph + from .lora import LoRACompatibleLinear, LoRALinearLayer +- ++import math + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@@ -204,11 +204,11 @@ + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention +- # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 +- if processor is None: +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # if processor is None: ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor=AttnProcessor() + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( +@@ -366,10 +366,10 @@ + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention +- # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 +- processor = ( +- AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() +- ) ++ # processor = ( ++ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ++ # ) ++ processor = AttnProcessor() + + self.set_processor(processor) + +@@ -999,11 +999,104 @@ + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + ++ def flash_attention_causal_forward(Q, K, V, mask=None, dropout_p=0.0, is_causal=False, scale=None): ++ BLOCK_SIZE = 1024 ++ NEG_INF =-1e10 # -1e10 # -infinity ++ EPSILON = 1e-10 ++ O = torch.zeros_like(Q).float() ++ l = torch.zeros(Q.shape[:-1])[...,None].float() ++ m = (torch.ones(Q.shape[:-1])[...,None] * NEG_INF).float() ++ ++ # O = O.to(device='cuda') ++ # l = l.to(device='cuda') ++ # m = m.to(device='cuda') ++ device_tmp = Q.device ++ O = O.to(device_tmp) ++ l = l.to(device_tmp) ++ m = m.to(device_tmp) ++ ++ Q_BLOCK_SIZE = min(BLOCK_SIZE, Q.shape[-1]) ++ KV_BLOCK_SIZE = BLOCK_SIZE ++ ++ Q_BLOCKS = torch.split(Q.float(), Q_BLOCK_SIZE, dim=2) ++ K_BLOCKS = torch.split(K.float(), KV_BLOCK_SIZE, dim=2) ++ V_BLOCKS = torch.split(V.float(), KV_BLOCK_SIZE, dim=2) ++ ++ Tr = Q_BLOCKS.size[0]#len(Q_BLOCKS) ++ Tc = K_BLOCKS.size[0]#len(K_BLOCKS) ++ ++ O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2)) ++ l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2)) ++ m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2)) ++ ++ # Q_LEN = Q.shape[2] ++ # K_LEN = K.shape[2] ++ ++ # Q_RANGE = torch.arange(Q_LEN)[:,None] + (K_LEN - Q_LEN) ++ # K_RANGE = torch.arange(K_LEN)[None,:] ++ ++ # Q_RANGE = Q_RANGE.to(device='cuda') ++ # K_RANGE = K_RANGE.to(device='cuda') ++ # Q_RANGE = Q_RANGE.to(device_tmp) ++ # K_RANGE = K_RANGE.to(device_tmp) ++ ++ # Q_RANGE_BLOCKS = torch.split(Q_RANGE, Q_BLOCK_SIZE, dim=0) ++ # K_RANGE_BLOCKS = torch.split(K_RANGE, KV_BLOCK_SIZE, dim=1) ++ ++ for j in range(Tc): ++ Kj = K_BLOCKS[j] ++ Vj = V_BLOCKS[j] ++ # K_RANGE_BLOCKSj = K_RANGE_BLOCKS[j] ++ ++ for i in range(Tr): ++ Qi = Q_BLOCKS[i] ++ Oi = O_BLOCKS[i] ++ li = l_BLOCKS[i] ++ mi = m_BLOCKS[i] ++ # Q_RANGE_BLOCKSi = Q_RANGE_BLOCKS[i] ++ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale ++ ++ # scale = 1 / np.sqrt(Q.shape[-1]) ++ Qi_scaled = Qi * scale_factor ++ ++ S_ij = torch.einsum('... i d, ... j d -> ... i j', Qi_scaled, Kj) ++ ++ # print('===============>') ++ ++ # Masking ++ # causal_mask = Q_RANGE_BLOCKSi >= K_RANGE_BLOCKSj ++ # print('NEG_INF:',NEG_INF) ++ # S_ij = torch.where(causal_mask > 0, S_ij, NEG_INF) ++ ++ m_block_ij, _ = torch.max(S_ij, dim=-1, keepdim=True) ++ P_ij = torch.exp(S_ij - m_block_ij) ++ # Masking ++ # P_ij = torch.where(causal_mask > 0, P_ij, 0) ++ ++ l_block_ij = torch.sum(P_ij, dim=-1, keepdim=True) + EPSILON ++ ++ P_ij_Vj = torch.einsum('... i j, ... j d -> ... 
i d', P_ij, Vj) ++ ++ mi_new = torch.maximum(m_block_ij, mi) ++ li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij ++ li_new = li_new + EPSILON ++ O_BLOCKS[i] = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj ++ l_BLOCKS[i] = li_new ++ m_BLOCKS[i] = mi_new ++ # if torch.isnan(O_BLOCKS[i]).any(): ++ # print('O_BLOCKS=====>:',O_BLOCKS[i].dtype) ++ ++ O = torch.cat(O_BLOCKS, dim=2) ++ l = torch.cat(l_BLOCKS, dim=2) ++ m = torch.cat(m_BLOCKS, dim=2) ++ # return O, l, m ++ return O.half() ++ + # the output of sdp = (batch, num_heads, seq_len, head_dim) +- # TODO: add support for attn.scale when we move to Torch 2.1 +- hidden_states = F.scaled_dot_product_attention( +- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False +- ) ++ # hidden_states = F.scaled_dot_product_attention( ++ # query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ++ # ) ++ hidden_states = flash_attention_causal_forward(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) + + # linear proj diff --git a/MindIE/MultiModal/StableVideoDiffusion/background_runtime.py b/MindIE/MultiModal/StableVideoDiffusion/background_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..7c348699a2e730c636de22a8dfb0744a0d6d1240 --- /dev/null +++ b/MindIE/MultiModal/StableVideoDiffusion/background_runtime.py @@ -0,0 +1,168 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
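The `attention_processor.patch` above replaces `F.scaled_dot_product_attention` with a block-wise attention that accumulates an online softmax (`flash_attention_causal_forward`; note the causal-mask branches are commented out, so it effectively computes full, non-causal attention). The following self-contained check is illustrative and not part of the patch: it verifies that the same recurrence reproduces plain softmax attention.

```python
import math
import torch

def blockwise_attention(q, k, v, block=64):
    # Online-softmax accumulation over key/value blocks, mirroring the
    # recurrence used in flash_attention_causal_forward above.
    scale = 1.0 / math.sqrt(q.shape[-1])
    o = torch.zeros_like(q)
    l = torch.zeros(q.shape[:-1] + (1,))
    m = torch.full(q.shape[:-1] + (1,), -1e10)
    for kj, vj in zip(k.split(block, dim=-2), v.split(block, dim=-2)):
        s = (q * scale) @ kj.transpose(-1, -2)
        m_block = s.max(dim=-1, keepdim=True).values
        p = torch.exp(s - m_block)
        l_block = p.sum(dim=-1, keepdim=True)
        m_new = torch.maximum(m, m_block)
        l_new = torch.exp(m - m_new) * l + torch.exp(m_block - m_new) * l_block
        o = (l / l_new) * torch.exp(m - m_new) * o \
            + (torch.exp(m_block - m_new) / l_new) * (p @ vj)
        l, m = l_new, m_new
    return o

q, k, v = (torch.randn(2, 4, 256, 32) for _ in range(3))
ref = torch.softmax((q @ k.transpose(-1, -2)) / math.sqrt(32), dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4)
```

In the patch itself the same update runs in float32 on split Q/K/V blocks and the result is cast back to half precision at the end.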
+ +import multiprocessing as mp +import time +from dataclasses import dataclass +from typing import List + +import numpy as np +import torch +import mindietorch + + +@dataclass +class RuntimeIOInfo: + input_shapes: List[tuple] + input_dtypes: List[type] + output_shapes: List[tuple] + output_dtypes: List[type] + + +class BackgroundRuntime: + def __init__( + self, + device_id: int, + model_path: str, + io_info: RuntimeIOInfo + ): + # Create a pipe for process synchronization + self.sync_pipe, sync_pipe_peer = mp.Pipe(duplex=True) + + # Create shared buffers + input_spaces = self.create_shared_buffers(io_info.input_shapes, + io_info.input_dtypes) + output_spaces = self.create_shared_buffers(io_info.output_shapes, + io_info.output_dtypes) + + # Build numpy arrays on the shared buffers + self.input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + self.output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + mp.set_start_method('fork', force=True) + self.p = mp.Process(target=self.run_infer, + args=[ + sync_pipe_peer, input_spaces, output_spaces, + io_info, device_id, model_path + ]) + self.p.start() + + # Wait until the sub process is ready + self.wait() + + @staticmethod + def create_shared_buffers(shapes: List[tuple], + dtypes: List[type]) -> List[mp.RawArray]: + buffers = [] + for shape, dtype in zip(shapes, dtypes): + size = 1 + for x in shape: + size *= x + + raw_array = mp.RawArray(np.ctypeslib.as_ctypes_type(dtype), size) + buffers.append(raw_array) + + return buffers + + def infer_asyn(self, feeds: List[np.ndarray]) -> None: + for i, _ in enumerate(self.input_arrays): + self.input_arrays[i][:] = feeds[i][:] + + self.sync_pipe.send('') + + def wait(self) -> None: + self.sync_pipe.recv() + + def get_outputs(self) -> List[np.ndarray]: + return self.output_arrays + + def wait_and_get_outputs(self) -> List[np.ndarray]: + self.wait() + return self.get_outputs() + + def infer(self, feeds: List[np.ndarray]) -> List[np.ndarray]: + self.infer_asyn(feeds) + return self.wait_and_get_outputs() + + def stop(self): + # Stop the sub process + self.sync_pipe.send('STOP') + + @staticmethod + def run_infer( + sync_pipe: mp.connection.Connection, + input_spaces: List[np.ndarray], + output_spaces: List[np.ndarray], + io_info: RuntimeIOInfo, + device_id: int, + model_path: str, + ) -> None: + # The sub process function + # Create a runtime + print(f"[info] bg device id: {device_id}") + + # Tell the main function that we are ready + model = torch.jit.load(model_path).eval() + + # Build numpy arrays on the shared buffers + input_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + input_spaces, io_info.input_shapes, io_info.input_dtypes) + ] + + output_arrays = [ + np.frombuffer(b, dtype=t).reshape(s) for (b, s, t) in zip( + output_spaces, io_info.output_shapes, io_info.output_dtypes) + ] + + # Tell the main function that we are ready + sync_pipe.send('') + + stream = mindietorch.npu.Stream(f"npu:{device_id}") + + # Keep looping until recived a 'STOP' + while sync_pipe.recv() != 'STOP': + start = time.time() + sample, timestep, encoder_hidden_states, added_time_ids = [ + torch.Tensor(input_array) for input_array in input_arrays + ] + + sample_npu = sample.to(torch.float32).to(f"npu:{device_id}") + timestep_npu = timestep.to(torch.float32).to(f"npu:{device_id}") + encoder_hidden_states_npu = 
encoder_hidden_states.to(torch.float32).to(f"npu:{device_id}") + added_time_ids_npu = added_time_ids.to(torch.float32).to(f"npu:{device_id}") + + with mindietorch.npu.stream(stream): + output_npu = model(sample_npu, timestep_npu, encoder_hidden_states_npu, added_time_ids_npu) + stream.synchronize() + + output_cpu = output_npu.to('cpu') + + for i, _ in enumerate(output_arrays): + output = output_cpu.numpy() + output_arrays[i][:] = output[i][:] + + sync_pipe.send('') + + @classmethod + def clone(cls, device_id: int, model_path: str, runtime_info: RuntimeIOInfo) -> 'BackgroundRuntime': + # Get shapes, datatypes from an existed engine, + # then use them to create a BackgroundRuntime + return cls(device_id, model_path, runtime_info) diff --git a/MindIE/MultiModal/StableVideoDiffusion/export_ts.py b/MindIE/MultiModal/StableVideoDiffusion/export_ts.py new file mode 100644 index 0000000000000000000000000000000000000000..47153cc97d69db2e178267ad4f68f5933231fdd6 --- /dev/null +++ b/MindIE/MultiModal/StableVideoDiffusion/export_ts.py @@ -0,0 +1,222 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from argparse import Namespace + +import torch +import torch.nn as nn +from diffusers import DDIMScheduler +from diffusers import StableVideoDiffusionPipeline + + +def parse_arguments() -> Namespace: + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", + "--output_dir", + type=str, + default="./models", + help="Path of directory to save pt models.", + ) + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-video-diffusion-img2vid-xt", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-vp", + "--num_videos_per_prompt", + type=int, + default=1, + help="num_videos_per_prompt." + ) + parser.add_argument( + "--decode_chunk_size", + type=int, + default=8, + help="decode_chunk_size." + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=25, + help="num_inference_steps." 
+    )
+
+    return parser.parse_args()
+
+
+class Embedexport(torch.nn.Module):
+    def __init__(self, embed_model):
+        super().__init__()
+        self.embed_model = embed_model
+
+    def forward(self, image):
+        return self.embed_model(image).image_embeds
+
+
+def export_image_embeddings(svd_pipeline: StableVideoDiffusionPipeline, save_dir: str, batch_size: int) -> None:
+    print("Exporting the image embedding...")
+    embed_path = os.path.join(save_dir, "image_encoder_embed")
+    if not os.path.exists(embed_path):
+        # Directories need the execute bit, otherwise the traced model cannot be saved into them.
+        os.makedirs(embed_path, mode=0o750)
+
+    embed_pt_path = os.path.join(embed_path, "image_encoder_embed.pt")
+    if os.path.exists(embed_pt_path):
+        return
+
+    embed_model = svd_pipeline.image_encoder
+
+    dummy_input = torch.ones([batch_size, 3, 224, 224], dtype=torch.float32)
+    embed_export = Embedexport(embed_model)
+
+    torch.jit.trace(embed_export, dummy_input).save(embed_pt_path)
+
+
+class Unetexport(torch.nn.Module):
+    def __init__(self, unet_model):
+        super().__init__()
+        self.unet_model = unet_model
+
+    def forward(self, sample, timestep, encoder_hidden_states, added_time_ids):
+        return self.unet_model(sample, timestep, encoder_hidden_states, added_time_ids, False)[0]
+
+
+def export_unet(svd_pipeline: StableVideoDiffusionPipeline, save_dir: str, batch_size: int, num_videos_per_prompt: int) -> None:
+    print("Exporting the image information creator...")
+    unet_path = os.path.join(save_dir, "unet")
+    if not os.path.exists(unet_path):
+        os.makedirs(unet_path, mode=0o750)
+
+    unet_pt_path = os.path.join(unet_path, f"unet_bs{batch_size}.pt")
+    if os.path.exists(unet_pt_path):
+        return
+
+    unet_model = svd_pipeline.unet
+
+    num_frames = 25
+    vae_scale_factor = 2 ** (len(svd_pipeline.vae.config.block_out_channels) - 1)
+    height = 192
+    width = 192
+    seq_len = 1
+    vae_encode_out = 1024
+    in_channels = unet_model.config.in_channels
+
+    dummy_input = (
+        torch.ones([batch_size * num_videos_per_prompt, num_frames, in_channels, height // vae_scale_factor, width // vae_scale_factor], dtype=torch.float32),
+        torch.ones([1], dtype=torch.float32),
+        torch.ones([batch_size * num_videos_per_prompt, seq_len, vae_encode_out], dtype=torch.float32),
+        torch.ones([batch_size * num_videos_per_prompt, 3], dtype=torch.float32),
+    )
+
+    unet = Unetexport(unet_model)
+    unet.eval()
+
+    torch.jit.trace(unet, dummy_input).save(unet_pt_path)
+
+
+class VaeExportDecode(torch.nn.Module):
+    def __init__(self, vae_model):
+        super().__init__()
+        self.vae_model = vae_model
+
+    def forward(self, latents):
+        # The temporal VAE decoder needs the number of frames contained in the latent batch.
+        decode_kwargs = {"num_frames": latents.shape[0]}
+        return self.vae_model.decode(latents, **decode_kwargs).sample
+
+
+class VaeExportEncode(torch.nn.Module):
+    def __init__(self, vae_model):
+        super().__init__()
+        self.vae_model = vae_model
+
+    def forward(self, image: torch.Tensor):
+        return self.vae_model.encode(image).latent_dist.mode()
+
+
+def export_vae(svd_pipeline: StableVideoDiffusionPipeline, save_dir: str, batch_size: int, decode_chunk_size: int) -> None:
+    print("Exporting the image decoder...")
+
+    vae_path = os.path.join(save_dir, "vae")
+    if not os.path.exists(vae_path):
+        os.makedirs(vae_path, mode=0o750)
+
+    vae_pt_path = os.path.join(vae_path, "vae_encode.pt")
+    vae_pt_path_2 = os.path.join(vae_path, "vae_decode.pt")
+    if os.path.exists(vae_pt_path) and os.path.exists(vae_pt_path_2):
+        return
+
+    vae_model = svd_pipeline.vae
+    unet_model = svd_pipeline.unet
+
+    channels_latents = unet_model.config.in_channels // 2
+    vae_scale_factor = 2 ** (len(vae_model.config.block_out_channels) - 1)
+    height = 192
+    width = 192
+    height_ld = height // vae_scale_factor
+    width_ld = width // vae_scale_factor
+    dummy_input = torch.ones([1, channels_latents, height_ld, width_ld], dtype=torch.float32)
+    vae_export_decode = VaeExportDecode(vae_model)
+    trace_model_decode = torch.jit.trace(vae_export_decode, dummy_input)
+    trace_model_decode.save(vae_pt_path_2)
+
+    print("Exporting the image encoder...")
+
+    dummy_input = torch.ones([1, 3, height, width])
+    vae_export_encode = VaeExportEncode(vae_model)
+    trace_model_encode = torch.jit.trace(vae_export_encode, dummy_input)
+    trace_model_encode.save(vae_pt_path)
+
+
+def export_to_pt(model_path: str, save_dir: str, batch_size: int, num_inference_steps: int, decode_chunk_size: int, num_videos_per_prompt: int) -> None:
+    pipeline = StableVideoDiffusionPipeline.from_pretrained(model_path).to("cpu")
+
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir, mode=0o750)
+
+    print(">>>>>>>>>>>>>>>embedding!")
+    export_image_embeddings(pipeline, save_dir, batch_size)
+
+    print(">>>>>>>>>>>>>>>VAE!")
+    export_vae(pipeline, save_dir, batch_size, decode_chunk_size)
+
+    print(">>>>>>>>>>>>>>>UNET!")
+    # The UNet is traced for the classifier-free-guidance batch, hence batch_size * 2.
+    export_unet(pipeline, save_dir, batch_size * 2, num_videos_per_prompt)
+
+
+def main():
+    args = parse_arguments()
+    export_to_pt(args.model, args.output_dir, args.batch_size, args.num_inference_steps, args.decode_chunk_size, args.num_videos_per_prompt)
+    print("Done.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MultiModal/StableVideoDiffusion/requirements.txt b/MindIE/MultiModal/StableVideoDiffusion/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..66a4c6a3d5d999d7836b45ec54d3137538b44715
--- /dev/null
+++ b/MindIE/MultiModal/StableVideoDiffusion/requirements.txt
@@ -0,0 +1,7 @@
+torch==2.0.0
+torchaudio==2.0.2
+torchvision==0.15.2
+diffusers==0.26.3
+transformers==4.38.2
+open_clip_torch==2.20.0
+opencv-python==4.9.0.80
\ No newline at end of file
diff --git a/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_activations_patch.py b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_activations_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..92c36f33954be0eb940869c5ad9b484fc19d4bcf
--- /dev/null
+++ b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_activations_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
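+
+# Helper script: applies activations.patch to the activations.py module of the
+# locally installed diffusers package (the repo pins diffusers==0.26.3 in requirements.txt).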
+
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    # The bundled patch targets diffusers 0.26.3; refuse to patch any other version.
+    assert diffusers_version == '0.26.3', f"expected diffusers==0.26.3, got {diffusers_version}"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/activations.py activations.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_attention_patch.py b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_attention_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c1b369bb23389b6abc12c710f63e5b986b836d
--- /dev/null
+++ b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_attention_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Applies attention_processor.patch to the attention_processor.py module of the
+# locally installed diffusers package.
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    # The bundled patch targets diffusers 0.26.3; refuse to patch any other version.
+    assert diffusers_version == '0.26.3', f"expected diffusers==0.26.3, got {diffusers_version}"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/attention_processor.py attention_processor.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_pipeline.py b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6275dea7e48061f94c816a778d4ea0438a9e7dee
--- /dev/null
+++ b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_pipeline.py
@@ -0,0 +1,633 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
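+
+# MindIE Torch adaptation of StableVideoDiffusionPipeline: the traced image encoder,
+# VAE encoder/decoder and UNet (.pt files produced by the export script) are compiled
+# into MindIE .ts engines on first run and then executed on the NPU during inference.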
+ +import argparse +import csv +import json +import os +import time +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +import pickle +from typing import Callable, Dict, List, Optional, Union + +from diffusers import StableVideoDiffusionPipeline +import diffusers.models.transformer_temporal +from diffusers.utils import load_image, export_to_video +from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing,_compute_padding,_filter2d,_gaussian,_gaussian_blur2d,_append_dims,inspect,tensor2vid,StableVideoDiffusionPipelineOutput +import PIL.Image +import torch +from diffusers.utils.torch_utils import randn_tensor + +image_embed_time = 0 +heightS = 512 +widthS = 512 +num_framesS = 25 +Dshape = False + +print("height:{},width:{},num_frames:{},vae_decode dynamic shape:{}".format(heightS,widthS,num_framesS,Dshape)) + +class AIEStableVideoDiffusionPipeline(StableVideoDiffusionPipeline): + def parser_args(self, args): + self.args = args + self.is_init = False + if isinstance(self.args.device, list): + self.device_0 = self.args.device[0] + else: + self.device_0 = args.device + + def compile_aie_model(self): + if self.is_init: + return + + in_channels = self.unet.config.in_channels + batch_size = self.args.batch_size + num_videos_per_prompt = 1 + height = heightS + width = widthS + num_frames = num_framesS if num_framesS is not None else self.unet.config.num_frames + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + seq_len = 1 + vae_encode_out=1024 + decode_chunk_size=self.args.decode_chunk_size + num_inference_steps=self.args.num_inference_steps + res=num_frames%decode_chunk_size + channels_latents=in_channels // 2 + + image_encoder_embed_path = os.path.join(self.args.output_dir, "image_encoder_embed/image_encoder_embed.ts") + if os.path.exists(image_encoder_embed_path): + self.compiled_image_encoder_embed = torch.jit.load(image_encoder_embed_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, "image_encoder_embed/image_encoder_embed.pt")).eval() + self.compiled_image_encoder_embed = ( + mindietorch.compile( + model, + inputs=[ + mindietorch.Input((batch_size, 3, 224, 224),dtype=mindietorch.dtype.FLOAT) + ], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_image_encoder_embed, image_encoder_embed_path) + + print(">>>>>>>>>>>>>>>image_encoder_embed2ts OK!") + + vae_encode_path = os.path.join(self.args.output_dir, "vae/vae_encode.ts") + if os.path.exists(vae_encode_path): + self.compiled_vae_encode = torch.jit.load(vae_encode_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, "vae/vae_encode.pt")).eval() + self.compiled_vae_encode = ( + mindietorch.compile( + model, + inputs=[ + mindietorch.Input((batch_size, 3, height, width),dtype=mindietorch.dtype.FLOAT) + ], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_encode, vae_encode_path) + + print(">>>>>>>>>>>>>>>vae_encode2ts OK!") + + model = torch.jit.load(os.path.join(self.args.output_dir, "vae/vae_decode.pt")).eval() + if Dshape: + vae_decode_path = 
os.path.join(self.args.output_dir, "vae/vae_decode.ts") + if os.path.exists(vae_decode_path): + self.compiled_vae_decode = torch.jit.load(vae_decode_path).eval() + else: + max_shape = (decode_chunk_size,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + min_shape = (res,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + inputs_vae = [] + inputs_vae.append([mindietorch.Input(max_shape,dtype=mindietorch.dtype.FLOAT)]) + if res !=0: + inputs_vae.append([mindietorch.Input(min_shape,dtype=mindietorch.dtype.FLOAT)]) + + self.compiled_vae_decode = ( + mindietorch.compile( + model, + inputs=inputs_vae, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_decode, vae_decode_path) + else: + vae_decode_path_8 = os.path.join(self.args.output_dir, "vae/vae_decode8.ts") + vae_decode_path_1 = os.path.join(self.args.output_dir, "vae/vae_decode1.ts") + if os.path.exists(vae_decode_path_8) & os.path.exists(vae_decode_path_1): + self.compiled_vae_decode8 = torch.jit.load(vae_decode_path_8).eval() + self.compiled_vae_decode1 = torch.jit.load(vae_decode_path_1).eval() + else: + max_shape = (decode_chunk_size,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + min_shape = (res,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + inputs_vae = [] + inputs_vae.append([mindietorch.Input(max_shape,dtype=mindietorch.dtype.FLOAT)]) + self.compiled_vae_decode8 = ( + mindietorch.compile( + model, + inputs=inputs_vae, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_decode8, vae_decode_path_8) + + inputs_vae.clear() + inputs_vae.append([mindietorch.Input(min_shape,dtype=mindietorch.dtype.FLOAT)]) + self.compiled_vae_decode1 = ( + mindietorch.compile( + model, + inputs=inputs_vae, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_decode1, vae_decode_path_1) + + print(">>>>>>>>>>>>>>>vae_decode2ts OK!") + + unet_compile_path = os.path.join(self.args.output_dir, "unet/unet_bs2.ts") + if os.path.exists(unet_compile_path): + self.compiled_unet_model = torch.jit.load(unet_compile_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, "unet/unet_bs2.pt")).eval() + + self.compiled_unet_model = ( + mindietorch.compile( + model, + inputs=[ + mindietorch.Input((batch_size*2*num_videos_per_prompt,num_frames,in_channels, height//vae_scale_factor,width//vae_scale_factor),dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,),dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size*2*num_videos_per_prompt,seq_len,vae_encode_out),dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size*2*num_videos_per_prompt,3),dtype=mindietorch.dtype.FLOAT) + ], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + 
torch.jit.save(self.compiled_unet_model, unet_compile_path) + + print(">>>>>>>>>>>>>>>unet2ts OK!") + + self.is_init = True + + @torch.no_grad() + def ascendie_infer( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: int = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + self.calmse=torch.nn.MSELoss(reduction='mean') + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + # self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = max_guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width).contiguous() + noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. 
Get Added Time IDs 创建时间嵌入向量 + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 8. Denoising loop + self._num_timesteps = len(timesteps) + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimention + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + noise_pred= self.compiled_unet_model( + latent_model_input.to(f'npu:{self.device_0}'), + t[None].to(f'npu:{self.device_0}'), + image_embeddings.to(f'npu:{self.device_0}'), + added_time_ids.to(f'npu:{self.device_0}'), + ).to('cpu') + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size).to('cpu') + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. 
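+        # Note: the image-encoder engine compiled in compile_aie_model expects a fixed
+        # [batch_size, 3, 224, 224] input, which matches the 224x224 CLIP resize below.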
+ image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + # run inference + global image_embed_time + start =time.time() + + image_embeddings = self.compiled_image_encoder_embed(image.to(device=f'npu:{self.device_0}', dtype=dtype)).to('cpu') + + image_embed_time +=time.time()-start + + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image_latents = self.compiled_vae_encode(image.to(f'npu:{self.device_0}')).to('cpu') + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + if Dshape: + frame = self.compiled_vae_decode(latents[i : i + decode_chunk_size].to(f'npu:{self.device_0}')).to('cpu') + else: + if num_frames_in == decode_chunk_size: + frame = self.compiled_vae_decode8(latents[i : i + decode_chunk_size].to(f'npu:{self.device_0}')).to('cpu') + else: + frame = self.compiled_vae_decode1(latents[i : i + decode_chunk_size].to(f'npu:{self.device_0}')).to('cpu') + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + +def 
check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-video-diffusion-img2vid-xt", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--img_file", + type=str, + default="./rocket.png", + help="A png file of prompts for generating vedio.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="Path to save model pt.", + ) + parser.add_argument( + "--fps", + type=int, + default=7, + help="FPS", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=0, + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-vp", + "--num_videos_per_prompt", + type=int, + default=1, + help="num_videos_per_prompt." + ) + parser.add_argument( + "--decode_chunk_size", + type=int, + default=8, + help="decode_chunk_size." + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=25, + help="num_inference_steps." + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + decode_chunk_size=args.decode_chunk_size + num_inference_steps=args.num_inference_steps + + pipe = AIEStableVideoDiffusionPipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + + pipe.compile_aie_model() + mindietorch.set_device(args.device) + + # 加载img及预处理 + image = load_image(args.img_file) + image = image.resize((heightS, widthS)) + + print('warming up ~~~~~') + stream = mindietorch.npu.Stream("npu:" + str(args.device)) + with mindietorch.npu.stream(stream): + frames = pipe.ascendie_infer( + image, + decode_chunk_size=decode_chunk_size, + height= heightS, + width = widthS, + num_inference_steps=num_inference_steps, + num_frames = num_framesS + ).frames[0] + + use_time = 0 + with mindietorch.npu.stream(stream): + start_time = time.time() + frames = pipe.ascendie_infer( + image, + decode_chunk_size=decode_chunk_size, + height= heightS, + width = widthS, + num_inference_steps=num_inference_steps, + num_frames = num_framesS + ).frames[0] + stream.synchronize() + use_time += time.time() - start_time + + print("Stable video diffusion use time:{}. 
Save dir is {}".format(use_time/1,save_dir)) + import datetime + now=datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + export_to_video(frames, r"{}/rocket_910B4_{}.mp4".format(save_dir,now), fps=args.fps) + + mindietorch.finalize() + + +if __name__ == "__main__": + main() diff --git a/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_pipeline_parallel.py b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_pipeline_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c8a702caba0ecc0e31fd6cebcd2e31dbfa424a --- /dev/null +++ b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_pipeline_parallel.py @@ -0,0 +1,682 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import csv +import json +import os +import time +import numpy as np +import torch +import mindietorch +from mindietorch import _enums +import pickle +from typing import Callable, Dict, List, Optional, Union + +from diffusers import StableVideoDiffusionPipeline +import diffusers.models.transformer_temporal +from diffusers.utils import load_image, export_to_video +from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing,_compute_padding,_filter2d,_gaussian,_gaussian_blur2d,_append_dims,inspect,tensor2vid,StableVideoDiffusionPipelineOutput +import PIL.Image +import torch +from diffusers.utils.torch_utils import randn_tensor + +from background_runtime import BackgroundRuntime, RuntimeIOInfo + +image_embed_time = 0 +heightS = 512 +widthS = 512 +num_framesS = 25 +Dshape = False + +print("height:{},width:{},num_frames:{},vae_decode dynamic shape:{}".format(heightS,widthS,num_framesS,Dshape)) + +class AIEStableVideoDiffusionPipeline(StableVideoDiffusionPipeline): + device_0 = None + device_1 = None + runtime = None + use_parallel_inferencing = False + unet_bg = None + + def parser_args(self, args): + self.args = args + if isinstance(args.device, list): + self.device_0, self.device_1 = args.device + print(f'Using parallel inferencing on device {self.device_0} and {self.device_1}') + else: + self.device_0 = args.device + self.is_init = False + + def compile_aie_model(self): + if self.is_init: + return + + in_channels = self.unet.config.in_channels + batch_size = self.args.batch_size + num_videos_per_prompt = 1 + height = heightS + width = widthS + num_frames = num_framesS if num_framesS is not None else self.unet.config.num_frames + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + seq_len = 1 + vae_encode_out=1024 + decode_chunk_size=self.args.decode_chunk_size + num_inference_steps=self.args.num_inference_steps + res=num_frames%decode_chunk_size + channels_latents=in_channels // 2 + + image_encoder_embed_path = os.path.join(self.args.output_dir, "image_encoder_embed/image_encoder_embed.ts") + if os.path.exists(image_encoder_embed_path): + self.compiled_image_encoder_embed = torch.jit.load(image_encoder_embed_path).eval() 
+ else: + model = torch.jit.load(os.path.join(self.args.output_dir, "image_encoder_embed/image_encoder_embed.pt")).eval() + self.compiled_image_encoder_embed = ( + mindietorch.compile( + model, + inputs=[ + mindietorch.Input((batch_size, 3, 224, 224),dtype=mindietorch.dtype.FLOAT) + ], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_image_encoder_embed, image_encoder_embed_path) + + print(">>>>>>>>>>>>>>>image_encoder_embed2ts OK!") + + vae_encode_path = os.path.join(self.args.output_dir, "vae/vae_encode.ts") + if os.path.exists(vae_encode_path): + self.compiled_vae_encode = torch.jit.load(vae_encode_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, "vae/vae_encode.pt")).eval() + self.compiled_vae_encode = ( + mindietorch.compile( + model, + inputs=[ + mindietorch.Input((batch_size,3, height, width),dtype=mindietorch.dtype.FLOAT) + ], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_encode, vae_encode_path) + + print(">>>>>>>>>>>>>>>vae_encode2ts OK!") + + model = torch.jit.load(os.path.join(self.args.output_dir, "vae/vae_decode.pt")).eval() + if Dshape: + vae_decode_path = os.path.join(self.args.output_dir, "vae/vae_decode.ts") + if os.path.exists(vae_decode_path): + self.compiled_vae_decode = torch.jit.load(vae_decode_path).eval() + else: + max_shape = (decode_chunk_size,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + min_shape = (res,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + inputs_vae = [] + inputs_vae.append([mindietorch.Input(max_shape,dtype=mindietorch.dtype.FLOAT)]) + if res !=0: + inputs_vae.append([mindietorch.Input(min_shape,dtype=mindietorch.dtype.FLOAT)]) + + self.compiled_vae_decode = ( + mindietorch.compile( + model, + inputs=inputs_vae, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_decode, vae_decode_path) + else: + vae_decode_path_8 = os.path.join(self.args.output_dir, "vae/vae_decode8.ts") + vae_decode_path_1 = os.path.join(self.args.output_dir, "vae/vae_decode1.ts") + if os.path.exists(vae_decode_path_8) & os.path.exists(vae_decode_path_1): + self.compiled_vae_decode8 = torch.jit.load(vae_decode_path_8).eval() + self.compiled_vae_decode1 = torch.jit.load(vae_decode_path_1).eval() + else: + max_shape = (decode_chunk_size,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + min_shape = (res,channels_latents,height//vae_scale_factor,width//vae_scale_factor) + inputs_vae = [] + inputs_vae.append([mindietorch.Input(max_shape,dtype=mindietorch.dtype.FLOAT)]) + self.compiled_vae_decode8 = ( + mindietorch.compile( + model, + inputs=inputs_vae, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_decode8, vae_decode_path_8) + + inputs_vae.clear() + 
inputs_vae.append([mindietorch.Input(min_shape,dtype=mindietorch.dtype.FLOAT)]) + self.compiled_vae_decode1 = ( + mindietorch.compile( + model, + inputs=inputs_vae, + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_vae_decode1, vae_decode_path_1) + + print(">>>>>>>>>>>>>>>vae_decode2ts OK!") + + unet_compile_path = os.path.join(self.args.output_dir, "unet/unet_bs1.ts") + if os.path.exists(unet_compile_path): + self.compiled_unet_model = torch.jit.load(unet_compile_path).eval() + else: + model = torch.jit.load(os.path.join(self.args.output_dir, "unet/unet_bs2.pt")).eval() + + self.compiled_unet_model = ( + mindietorch.compile( + model, + inputs=[ + mindietorch.Input((batch_size*num_videos_per_prompt,num_frames,in_channels, height//vae_scale_factor,width//vae_scale_factor),dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((1,),dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size*num_videos_per_prompt,seq_len,vae_encode_out),dtype=mindietorch.dtype.FLOAT), + mindietorch.Input((batch_size*num_videos_per_prompt,3),dtype=mindietorch.dtype.FLOAT) + ], + allow_tensor_replace_int=True, + require_full_compilation=True, + truncate_long_and_double=True, + min_block_size=1, + soc_version="Ascend910B4", + precision_policy=_enums.PrecisionPolicy.FP16, + optimization_level=0 + ) + ) + torch.jit.save(self.compiled_unet_model, unet_compile_path) + + runtime_info = RuntimeIOInfo( + input_shapes=[ + (batch_size*num_videos_per_prompt,num_frames,in_channels, height//vae_scale_factor,width//vae_scale_factor), + (1,), + (batch_size*num_videos_per_prompt,seq_len,vae_encode_out), + (batch_size*num_videos_per_prompt,3) + ], + input_dtypes=[np.float32, np.float32, np.float32, np.float32], + output_shapes=[(batch_size*num_videos_per_prompt,num_frames,in_channels//2, height//vae_scale_factor,width//vae_scale_factor)], + output_dtypes=[np.float32] + ) + if hasattr(self, 'device_1'): + self.unet_bg = BackgroundRuntime.clone(self.device_1, unet_compile_path, runtime_info) + self.use_parallel_inferencing = True + + print(">>>>>>>>>>>>>>>unet2ts OK!") + + mindietorch.set_device(self.device_0) + self.is_init = True + + @torch.no_grad() + def ascendie_infer( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: int = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + self.calmse=torch.nn.MSELoss(reduction='mean') + # 0. 
Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + # self.check_inputs(image, height, width) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = max_guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width).contiguous() + noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs 创建时间嵌入向量 + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, torch.float32) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 8. 
Denoising loop + self._num_timesteps = len(timesteps) + + if self.use_parallel_inferencing and do_classifier_free_guidance: + # Split embeddings + negative_image_embeddings, image_embeddings = image_embeddings.chunk(2) + added_time_ids_1, added_time_ids = added_time_ids.chunk(2) + negative_image_latents, image_latents = image_latents.chunk(2) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + if not self.use_parallel_inferencing and do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimention [batch, num_frames, channels, height, width] + negative_latent_model_input = torch.cat([latent_model_input, negative_image_latents], dim=2) + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + if self.use_parallel_inferencing and do_classifier_free_guidance: + self.unet_bg.infer_asyn([ + latent_model_input.numpy(), + t[None].numpy(), + image_embeddings.numpy(), + added_time_ids.numpy(), + ]) + + noise_pred_uncond= self.compiled_unet_model( + negative_latent_model_input.to(f'npu:{self.device_0}'), + t[None].to(f'npu:{self.device_0}'), + negative_image_embeddings.to(f'npu:{self.device_0}'), + added_time_ids_1.to(f'npu:{self.device_0}'), + ).to('cpu') + + # perform guidance + if do_classifier_free_guidance: + if self.use_parallel_inferencing: + noise_pred_cond = torch.from_numpy(self.unet_bg.wait_and_get_outputs()[0]) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size).to('cpu') + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. 
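+        # Note: the image-encoder engine compiled in compile_aie_model expects a fixed
+        # [batch_size, 3, 224, 224] input, which matches the 224x224 CLIP resize below.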
+ image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + # run inference + global image_embed_time + start =time.time() + + image_embeddings = self.compiled_image_encoder_embed(image.to(device=f'npu:{self.device_0}', dtype=dtype)).to('cpu') + + image_embed_time +=time.time()-start + + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image_latents = self.compiled_vae_encode(image.to(f'npu:{self.device_0}')).to('cpu') + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + if Dshape: + frame = self.compiled_vae_decode(latents[i : i + decode_chunk_size].to(f'npu:{self.device_0}')).to('cpu') + else: + if num_frames_in == decode_chunk_size: + frame = self.compiled_vae_decode8(latents[i : i + decode_chunk_size].to(f'npu:{self.device_0}')).to('cpu') + else: + frame = self.compiled_vae_decode1(latents[i : i + decode_chunk_size].to(f'npu:{self.device_0}')).to('cpu') + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + +def 
check_device_range_valid(value): + # if contain , split to int list + min_value = 0 + max_value = 255 + if ',' in value: + ilist = [int(v) for v in value.split(',')] + for ivalue in ilist[:2]: + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "{} of device:{} is invalid. valid value range is [{}, {}]" + .format(ivalue, value, min_value, max_value)) + return ilist[:2] + else: + # default as single int value + ivalue = int(value) + if ivalue < min_value or ivalue > max_value: + raise argparse.ArgumentTypeError( + "device:{} is invalid. valid value range is [{}, {}]".format( + ivalue, min_value, max_value)) + return ivalue + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model", + type=str, + default="./stable-video-diffusion-img2vid-xt", + help="Path or name of the pre-trained model.", + ) + parser.add_argument( + "--img_file", + type=str, + default="./rocket.png", + help="A png file of prompts for generating vedio.", + ) + parser.add_argument( + "--save_dir", + type=str, + default="./results", + help="Path to save result images.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="./models", + help="Path to save model pt.", + ) + parser.add_argument( + "--fps", + type=int, + default=7, + help="FPS", + ) + parser.add_argument( + "--device", + type=check_device_range_valid, + default=[0, 1], + help="NPU device id. Give 2 ids to enable parallel inferencing.", + ) + parser.add_argument( + "-bs", + "--batch_size", + type=int, + default=1, + help="Batch size." + ) + parser.add_argument( + "-vp", + "--num_videos_per_prompt", + type=int, + default=1, + help="num_videos_per_prompt." + ) + parser.add_argument( + "--decode_chunk_size", + type=int, + default=8, + help="decode_chunk_size." + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=25, + help="num_inference_steps." + ) + + return parser.parse_args() + + +def main(): + args = parse_arguments() + save_dir = args.save_dir + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + decode_chunk_size=args.decode_chunk_size + num_inference_steps=args.num_inference_steps + + pipe = AIEStableVideoDiffusionPipeline.from_pretrained(args.model).to("cpu") + pipe.parser_args(args) + + pipe.compile_aie_model() + + # 加载img及预处理 + image = load_image(args.img_file) + image = image.resize((heightS, widthS)) + + print('warming up ~~~~~') + stream = mindietorch.npu.Stream("npu:" + str(args.device[0])) + with mindietorch.npu.stream(stream): + frames = pipe.ascendie_infer( + image, + decode_chunk_size=decode_chunk_size, + height= heightS, + width = widthS, + num_inference_steps=num_inference_steps, + num_frames = num_framesS + ).frames[0] + + use_time = 0 + with mindietorch.npu.stream(stream): + start_time = time.time() + frames = pipe.ascendie_infer( + image, + decode_chunk_size=decode_chunk_size, + height= heightS, + width = widthS, + num_inference_steps=num_inference_steps, + num_frames = num_framesS + ).frames[0] + stream.synchronize() + use_time += time.time() - start_time + + print("Stable video diffusion use time:{}. 
Save dir is {}".format(use_time/1, save_dir))
+    import datetime
+    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    export_to_video(frames, r"{}/rocket_910B4_{}.mp4".format(save_dir, now), fps=args.fps)
+
+    if hasattr(pipe, 'device_1'):
+        if pipe.unet_bg:
+            pipe.unet_bg.stop()
+
+    mindietorch.finalize()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_transformer_patch.py b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_transformer_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4eb3ca9b310d4bbabbd78cf7f76e9bd6c180696
--- /dev/null
+++ b/MindIE/MultiModal/StableVideoDiffusion/stable_video_diffusion_transformer_patch.py
@@ -0,0 +1,28 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Applies transformer_temporal.patch to the transformer_temporal.py module of the
+# locally installed diffusers package.
+import os
+import diffusers
+
+
+def main():
+    diffusers_path = diffusers.__path__
+    diffusers_version = diffusers.__version__
+
+    # The bundled patch targets diffusers 0.26.3; refuse to patch any other version.
+    assert diffusers_version == '0.26.3', f"expected diffusers==0.26.3, got {diffusers_version}"
+    os.system(f'patch -p0 {diffusers_path[0]}/models/transformers/transformer_temporal.py transformer_temporal.patch')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/MindIE/MultiModal/StableVideoDiffusion/transformer_temporal.patch b/MindIE/MultiModal/StableVideoDiffusion/transformer_temporal.patch
new file mode 100644
index 0000000000000000000000000000000000000000..89675cbd7e97db294b9372fba70560d7fb9de1f8
--- /dev/null
+++ b/MindIE/MultiModal/StableVideoDiffusion/transformer_temporal.patch
@@ -0,0 +1,16 @@
+--- transformer_temporal.py	2024-05-15 08:25:05.724000000 +0000
++++ transformer_temporal_new.py	2024-05-15 08:25:05.724000000 +0000
+@@ -311,9 +311,10 @@
+         time_context_first_timestep = time_context[None, :].reshape(
+             batch_size, num_frames, -1, time_context.shape[-1]
+         )[:, 0]
+-        time_context = time_context_first_timestep[None, :].broadcast_to(
+-            height * width, batch_size, 1, time_context.shape[-1]
+-        )
++        # time_context = time_context_first_timestep[None, :].broadcast_to(
++        #     height * width, batch_size, 1, time_context.shape[-1]
++        # )
++        time_context = time_context_first_timestep[None, :].expand([height * width, batch_size, 1, time_context.shape[-1]])
+         time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
+ 
+         residual = hidden_states