diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ffe2a35470f245abe8953a451940be404c773560 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore @@ -0,0 +1,4 @@ +*__pycache__/ +.idea +output* +*.so \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..074eceb5b96b59dbb913d737e16ee126aa75eb20 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py @@ -0,0 +1,121 @@ +import argparse +import os +import torch +import torch_npu +import omegaconf +from omegaconf import OmegaConf +from sat.helpers import print_rank0 + + +def add_model_config_args(parser): + """Model arguments""" + + group = parser.add_argument_group("model", "model configuration") + group.add_argument("--base", type=str, nargs="*", help="config for input and saving") + group.add_argument("--device", type=int, default=-1) + group.add_argument("--log-image", type=bool, default=True) + return parser + + +def add_sampling_config_args(parser): + """Sampling configurations""" + + group = parser.add_argument_group("sampling", "Sampling Configurations") + group.add_argument("--output-dir", type=str, default="samples") + group.add_argument("--input-dir", type=str, default=None) + group.add_argument("--input-type", type=str, default="cli") + group.add_argument("--input-file", type=str, default="input.txt") + group.add_argument("--final-size", type=int, default=2048) + group.add_argument("--sdedit", action="store_true") + group.add_argument("--grid-num-rows", type=int, default=1) + group.add_argument("--force-inference", action="store_true") + group.add_argument("--lcm_steps", type=int, default=None) + group.add_argument("--sampling_image_height", type=int, default=480) + group.add_argument("--sampling_image_width", type=int, default=720) + group.add_argument("--sampling-num-frames", type=int, default=32) + group.add_argument("--sampling-fps", type=int, default=8) + group.add_argument("--only-save-latents", type=bool, default=False) + group.add_argument("--only-log-video-latents", type=bool, default=False) + group.add_argument("--latent-channels", type=int, default=32) + group.add_argument("--image2video", action="store_true") + + return parser + + +def add_inference_args(parser): + """Inference arguments.""" + + group = parser.add_argument_group('inference', 'inference configurations') + group.add_argument('--batch-size', type=int, default=4, + help='batch size on a single GPU. batch-size * world_size = total batch_size.') + group.add_argument('--mode', type=str, + default='pretrain', + choices=['pretrain', # from_scratch / load ckpt for continue pretraining. + 'finetune', # finetuning, auto-warmup 100 iters, new exp name. + 'inference' # don't train. + ], + help='what type of task to use, will influence auto-warmup, exp name, iteration') + group.add_argument('--load', type=str, default=None, + help='Path to a directory containing a model checkpoint.') + group.add_argument('--seed', type=int, default=1234, help='random seed') + group.add_argument('--zero-stage', type=int, default=0, choices=[0, 1, 2, 3], + help='deepspeed ZeRO stage. 
0 means no ZeRO.') + group.add_argument('--fp16', action='store_true', help='Run model in fp16 mode') + group.add_argument('--bf16', action='store_true', help='Run model in bf16 mode') + + return parser + + +def get_args(args_list=None, parser=None): + """Parse all the args.""" + if parser is None: + parser = argparse.ArgumentParser(description="sat") + else: + assert isinstance(parser, argparse.ArgumentParser) + parser = add_model_config_args(parser) + parser = add_sampling_config_args(parser) + parser = add_inference_args(parser) + + args = parser.parse_args(args_list) + args = process_config_to_args(args) + + args.npu = torch.npu.is_available() + + args.rank = int(os.getenv("RANK", "0")) + args.world_size = int(os.getenv("WORLD_SIZE", "1")) + args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun + + if args.device == -1: + if torch.npu.device_count() == 0: + args.device = "cpu" + else: + args.device = "npu" + + if args.rank == 0: + print_rank0("using world size: {}".format(args.world_size)) + + assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16." + + return args + + +def process_config_to_args(args): + """Fetch args from only --base""" + + configs = [OmegaConf.load(cfg) for cfg in args.base] + config = OmegaConf.merge(*configs) + + args_config = config.pop("args", OmegaConf.create()) + for key in args_config: + if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig): + arg = OmegaConf.to_object(args_config[key]) + else: + arg = args_config[key] + if hasattr(args, key): + setattr(args, key, arg) + + if "model" in config: + model_config = config.pop("model", OmegaConf.create()) + args.model_config = model_config + + return args diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..73c2ea63eba3d18573e12538d50c70d7d191e476 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml @@ -0,0 +1,139 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 # different from cogvideox_2b_infer.yaml + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 3072 # different from cogvideox_2b_infer.yaml + adm_in_channels: 256 + num_attention_heads: 48 # different from cogvideox_2b_infer.yaml + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml + params: + hidden_size_head: 64 + text_length: 226 
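+          # text_length is the number of prompt tokens prepended to the video tokens; it matches max_length of the FrozenT5Embedder text encoder configured below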
+ + patch_embed_config: + target: dit_video_concat.ImagePatchEmbeddingMixin + params: + text_hidden_size: 4096 + + adaln_layer_config: + target: dit_video_concat.AdaLNMixin + params: + qk_ln: True + + final_layer_config: + target: dit_video_concat.FinalLayerMixin + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + input_key: txt + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.FrozenT5Embedder + params: + model_dir: "/home/p30061221/CogVideoX-5B-SAT/5b-cogvideo" + max_length: 226 + + first_stage_config: + target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper + params: + cp_size: 1 + ckpt_path: "/home/p30061221/CogVideoX-5B-SAT/5b-cogvideo/vae/3d-vae.pt" + ignore_keys: [ 'loss' ] + + loss_config: + target: torch.nn.Identity + + regularizer_config: + target: vae_modules.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: vae_modules.cp_enc_dec.ContextParallelEncoder3D + params: + double_z: true + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: True + + decoder_config: + target: vae_modules.cp_enc_dec.ContextParallelDecoder3D + params: + double_z: True + z_channels: 16 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 2, 4 ] + attn_resolutions: [ ] + num_res_blocks: 3 + dropout: 0.0 + gather_norm: False + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler + params: + num_steps: 50 + verbose: True + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + guider_config: + target: sgm.modules.diffusionmodules.guiders.DynamicCFG + params: + scale: 6 + exp: 5 + num_steps: 50 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1d882fb1b4cebbed110d00067ad6ae6aabce845 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml @@ -0,0 +1,16 @@ +args: + image2video: False # True for image2video, False for text2video + latent_channels: 16 + mode: inference + load: "/home/p30061221/CogVideoX-5B-SAT/5b-cogvideo/transformer" # This is for the full model without a LoRA adapter + batch_size: 1 + input_type: txt + input_file: configs/test.txt + sampling_image_height: 480 + sampling_image_width: 720 + sampling_num_frames: 13 # Must be 13, 11 or 9 + sampling_fps: 8 + # fp16: True # For CogVideoX-2B + bf16: True # For CogVideoX-5B and CogVideoX-5B-I2V + output_dir: outputs/ + force_inference: True \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt new file mode 100644 index 0000000000000000000000000000000000000000..fbb326177040e9b34abcb606731f6cd0e2b62da1 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt @@ -0,0 +1 @@ +In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. 
Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict. \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py new file mode 100644 index 0000000000000000000000000000000000000000..e9efbdb0db2f6f6d9d75d9a174ce5f34a4c0bc38 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py @@ -0,0 +1,133 @@ +import math +from typing import Dict, List, Tuple, Union + +import torch +from torch import nn + +from sgm.modules import UNCONDITIONAL_CONFIG +from sgm.modules.autoencoding.temporal_ae import VideoDecoder +from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from sgm.util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, +) + + +class SATVideoDiffusionEngine(nn.Module): + def __init__(self, args, **kwargs): + super().__init__() + + model_config = args.model_config + # model args preprocess + network_config = model_config.get("network_config", None) + network_wrapper = model_config.get("network_wrapper", None) + denoiser_config = model_config.get("denoiser_config", None) + sampler_config = model_config.get("sampler_config", None) + conditioner_config = model_config.get("conditioner_config", None) + first_stage_config = model_config.get("first_stage_config", None) + scale_factor = model_config.get("scale_factor", 1.0) + latent_input = model_config.get("latent_input", False) + disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False) + compile_model = model_config.get("compile_model", False) + en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None) + + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + if args.fp16: + dtype = torch.float16 + dtype_str = "fp16" + elif args.bf16: + dtype = torch.bfloat16 + dtype_str = "bf16" + else: + dtype = torch.float32 + dtype_str = "fp32" + self.dtype = dtype + self.dtype_str = dtype_str + + network_config["params"]["dtype"] = dtype_str + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model, dtype=dtype + ) + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None + self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG)) + self._init_first_stage(first_stage_config) + self.latent_input = latent_input + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.device = args.device + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("npu", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples: (n + 1) * n_samples])} + else: + kwargs = {} + out = 
self.first_stage_model.decode(z[n * n_samples: (n + 1) * n_samples], **kwargs) + all_out.append(out) + out = torch.cat(all_out, dim=0) + return out + + @torch.no_grad() + def encode_first_stage(self, x, batch): + frame = x.shape[2] + + if frame > 1 and self.latent_input: + x = x.permute(0, 2, 1, 3, 4).contiguous() + return x * self.scale_factor # already encoded + + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("npu", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode(x[n * n_samples: (n + 1) * n_samples]) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + return z + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + prefix=None, + concat_images=None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) + if hasattr(self, "seeded_noise"): + randn = self.seeded_noise(randn) + + if prefix is not None: + randn = torch.cat([prefix, randn[:, prefix.shape[1]:]], dim=1) + + scale = None + scale_emb = None + + denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser( + self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs + ) + + samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb) + samples = samples.to(self.dtype) + return samples diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py new file mode 100644 index 0000000000000000000000000000000000000000..494a3786afc332d43300dd6680d95ffff078c1e0 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py @@ -0,0 +1,923 @@ +from functools import partial +from einops import rearrange, repeat +import numpy as np +import math + +import torch +import torch_npu +from torch import nn +import torch.nn.functional as F + +torch.ops.load_library("./sgm/modules/libPTAExtensionOPS.so") + +from sat.model.base_model import BaseModel, non_conflict +from sat.model.mixins import BaseMixin +from sat.mpu.layers import ColumnParallelLinear +from sgm.util import instantiate_from_config + +from sgm.modules.diffusionmodules.openaimodel import Timestep +from sgm.modules.diffusionmodules.util import ( + linear, + timestep_embedding, +) +from sat.ops.layernorm import LayerNorm, RMSNorm +from utils.comm import all_to_all_bsh +from utils.parallel_mgr import get_sp_group, get_sequence_parallel_world_size, get_sequence_parallel_rank + +SP_PAD = 0 + +class ImagePatchEmbeddingMixin(BaseMixin): + def __init__( + self, + in_channels, + hidden_size, + patch_size, + bias=True, + text_hidden_size=None, + ): + super().__init__() + self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) + if text_hidden_size is not None: + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + else: + self.text_proj = None + + def word_embedding_forward(self, input_ids, **kwargs): + # now is 3d patch + images = kwargs["images"] # (b,t,c,h,w) + B, T = images.shape[:2] + emb = images.view(-1, *images.shape[2:]) + emb = self.proj(emb) # ((b t),d,h/2,w/2) + emb = emb.view(B, T, *emb.shape[1:]) + emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d) + emb = rearrange(emb, "b t n d -> b (t n) d") + + if 
self.text_proj is not None: + text_emb = self.text_proj(kwargs["encoder_outputs"]) + emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d) + + emb = emb.contiguous() + return emb # (b,n_t+t*n_i,d) + + def reinit(self, parent_model=None): + w = self.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.proj.bias, 0) + del self.transformer.word_embeddings + + +def attention_fn_fa_score(query_layer, key_layer, value_layer, attention_mask, + attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs): + # expand head dim to query dim, if necessary + # only useful for multi-query attention + num_query_heads = query_layer.shape[1] # [b, np, s, hn] + attn_output = torch_npu.npu_fusion_attention(query_layer, key_layer, value_layer, + input_layout="BNSD", + scale=1 / math.sqrt(query_layer.shape[-1]), + head_num=num_query_heads + )[0] + return attn_output + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_height, + grid_width, + t_size, + cls_token=False, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, +): + """ + grid_size: int of the grid height and width + t_size: int of the temporal size + return: + pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + assert embed_dim % 4 == 0 + embed_dim_spatial = embed_dim // 4 * 3 + embed_dim_temporal = embed_dim // 4 + + # spatial + grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation + grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_height, grid_width]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # temporal + grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # concate: [T, H, W] order + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4] + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) + # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] + + return pos_embed # [T, H*W, D] + + +def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_height, dtype=np.float32) + grid_w = np.arange(grid_width, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_height, grid_width]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, 
grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Basic2DPositionEmbeddingMixin(BaseMixin): + def __init__(self, height, width, compressed_num_frames, hidden_size, text_length=0): + super().__init__() + self.height = height + self.width = width + self.spatial_length = height * width + self.pos_embedding = nn.Parameter( + torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False + ) + + def position_embedding_forward(self, position_ids, **kwargs): + return self.pos_embedding + + def reinit(self, parent_model=None): + del self.transformer.position_embeddings + pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width) + self.pos_embedding.data[:, -self.spatial_length:].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + +class Basic3DPositionEmbeddingMixin(BaseMixin): + def __init__( + self, + height, + width, + compressed_num_frames, + hidden_size, + text_length=0, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + ): + super().__init__() + self.height = height + self.width = width + self.text_length = text_length + self.compressed_num_frames = compressed_num_frames + self.spatial_length = height * width + self.num_patches = height * width * compressed_num_frames + self.pos_embedding = nn.Parameter( + torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False + ) + self.height_interpolation = height_interpolation + self.width_interpolation = width_interpolation + self.time_interpolation = time_interpolation + + def position_embedding_forward(self, position_ids, **kwargs): + if kwargs["images"].shape[1] == 1: + return self.pos_embedding[:, : self.text_length + self.spatial_length] + + return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]] + + def reinit(self, parent_model=None): + del self.transformer.position_embeddings + pos_embed = get_3d_sincos_pos_embed( + self.pos_embedding.shape[-1], + self.height, + self.width, + self.compressed_num_frames, + height_interpolation=self.height_interpolation, + width_interpolation=self.width_interpolation, + time_interpolation=self.time_interpolation, + ) + pos_embed = torch.from_numpy(pos_embed).float() + pos_embed = rearrange(pos_embed, "t n d -> (t n) d") + self.pos_embedding.data[:, -self.num_patches:].copy_(pos_embed) + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + 
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +class Rotary3DPositionEmbeddingMixin(BaseMixin): + def __init__( + self, + height, + width, + compressed_num_frames, + hidden_size, + hidden_size_head, + text_length, + theta=10000, + rot_v=False, + learnable_pos_embed=False, + ): + super().__init__() + self.rot_v = rot_v + + dim_t = hidden_size_head // 4 + dim_h = hidden_size_head // 8 * 3 + dim_w = hidden_size_head // 8 * 3 + + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + + grid_t = torch.arange(compressed_num_frames, dtype=torch.float32) + grid_h = torch.arange(height, dtype=torch.float32) + grid_w = torch.arange(width, dtype=torch.float32) + + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + freqs = rearrange(freqs, "t h w d -> (t h w) d") + + freqs = freqs.contiguous() + freqs_sin = freqs.sin() + freqs_cos = freqs.cos() + self.register_buffer("freqs_sin", freqs_sin) + self.register_buffer("freqs_cos", freqs_cos) + + self.text_length = text_length + if learnable_pos_embed: + num_patches = height * width * compressed_num_frames + text_length + self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True) + else: + self.pos_embedding = None + + def rotary(self, t, **kwargs): + seq_len = t.shape[2] + freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0) + freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0) + return torch.ops.mindie.rope_mindie_sd(t, freqs_cos, freqs_sin, mode=1) + + def position_embedding_forward(self, position_ids, **kwargs): + if self.pos_embedding is not None: + return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]] + else: + return None + + @non_conflict + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_fa_score, + **kwargs, + ): + if get_sequence_parallel_world_size() > 1: + batch_size = query_layer.shape[0] + num_heads = query_layer.shape[1] + head_dim = query_layer.shape[-1] + h_size = num_heads * head_dim + sp_size = get_sequence_parallel_world_size() + h_size_sp = h_size // sp_size + query_layer = rearrange(query_layer, "b n s d -> (b s) n d") + key_layer = rearrange(key_layer, "b n s d -> (b s) n d") + value_layer = rearrange(value_layer, "b n s d -> (b s) n d") + + _, query_layer = all_to_all_bsh(query_layer, scatter_dim=1, gather_dim=0) + _, key_layer = all_to_all_bsh(key_layer, scatter_dim=1, gather_dim=0) + _, value_layer = all_to_all_bsh(value_layer, scatter_dim=1, gather_dim=0) + + 
query_layer = query_layer.view(batch_size, -1, h_size_sp) + key_layer = key_layer.view(batch_size, -1, h_size_sp) + value_layer = value_layer.view(batch_size, -1, h_size_sp) + query_layer = torch.chunk(query_layer, sp_size, dim=1) + text_layer = query_layer[0][:, :self.text_length] + stack = [text_layer] + for image_slice in query_layer: + stack.append(image_slice[:, self.text_length:]) + query_layer = torch.cat(stack, dim=1) + + key_layer = torch.chunk(key_layer, sp_size, dim=1) + text_layer = key_layer[0][:, :self.text_length] + stack = [text_layer] + for image_slice in key_layer: + stack.append(image_slice[:, self.text_length:]) + key_layer = torch.cat(stack, dim=1) + + value_layer = torch.chunk(value_layer, sp_size, dim=1) + text_layer = value_layer[0][:, :self.text_length] + stack = [text_layer] + for image_slice in value_layer: + stack.append(image_slice[:, self.text_length:]) + value_layer = torch.cat(stack, dim=1) + + if SP_PAD > 0: + query_layer = query_layer[:, :(-SP_PAD)] + key_layer = key_layer[:, :(-SP_PAD)] + value_layer = value_layer[:, :(-SP_PAD)] + + query_layer = rearrange(query_layer, "b s (n d) -> b n s d", n=num_heads // sp_size, d=head_dim) + key_layer = rearrange(key_layer, "b s (n d) -> b n s d", n=num_heads // sp_size, d=head_dim) + value_layer = rearrange(value_layer, "b s (n d) -> b n s d", n=num_heads // sp_size, d=head_dim) + + query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:]) + key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:]) + if self.rot_v: + value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:]) + attention_output = attention_fn_fa_score( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + if SP_PAD > 0: + padding_shape = (*attention_output.shape[:2], SP_PAD, attention_output.shape[3]) + padding = torch.zeros(padding_shape, dtype=attention_output.dtype, device=attention_output.device) + attention_output = torch.cat([attention_output, padding], dim=2) + text_layer = attention_output[:, :, :self.text_length] + image_layer = attention_output[:, :, self.text_length:] + image_layer_ls = torch.chunk(image_layer, sp_size, dim=2) + stack = [] + for i in range(sp_size): + stack.append(text_layer.clone()) + stack.append(image_layer_ls[i]) + attention_output = torch.cat(stack, dim=2) + attention_output = rearrange(attention_output, "b n s d -> (b s) n d") + attention_output = attention_output.view(-1, num_heads // sp_size, head_dim).contiguous() + _, attention_output = all_to_all_bsh(attention_output, scatter_dim=0, gather_dim=1) + attention_output = attention_output.view(batch_size, -1, h_size) + attention_output = rearrange(attention_output, "b s (n d) -> b n s d", n=num_heads, d=head_dim) + else: + query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:]) + key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:]) + if self.rot_v: + value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:]) + + attention_output = attention_fn_fa_score( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + return attention_output + + +def modulate(x, shift, scale): + 
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): + """ + x: (N, T/2 * S, patch_size**3 * C) + imgs: (N, T, H, W, C) + """ + if rope_position_ids is not None: + assert NotImplementedError + # do pix2struct unpatchify + L = x.shape[1] + x = x.reshape(shape=(x.shape[0], L, p, p, c)) + x = torch.einsum("nlpqc->ncplq", x) + imgs = x.reshape(shape=(x.shape[0], c, p, L * p)) + else: + b = x.shape[0] + imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p) + + return imgs + + +class FinalLayerMixin(BaseMixin): + def __init__( + self, + hidden_size, + time_embed_dim, + patch_size, + out_channels, + latent_width, + latent_height, + elementwise_affine, + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.out_channels = out_channels + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) + + self.spatial_length = latent_width * latent_height // patch_size ** 2 + self.latent_width = latent_width + self.latent_height = latent_height + + def final_forward(self, logits, **kwargs): + x, emb = logits[:, kwargs["text_length"]:, :], kwargs["emb"] # x:(b,(t n),d) + shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + return unpatchify( + x, + c=self.out_channels, + p=self.patch_size, + w=self.latent_width // self.patch_size, + h=self.latent_height // self.patch_size, + rope_position_ids=kwargs.get("rope_position_ids", None), + **kwargs, + ) + + def reinit(self, parent_model=None): + nn.init.xavier_uniform_(self.linear.weight) + nn.init.constant_(self.linear.bias, 0) + + +class SwiGLUMixin(BaseMixin): + def __init__(self, num_layers, in_features, hidden_features, bias=False): + super().__init__() + self.w2 = nn.ModuleList( + [ + ColumnParallelLinear( + in_features, + hidden_features, + gather_output=False, + bias=bias, + module=self, + name="dense_h_to_4h_gate", + ) + for i in range(num_layers) + ] + ) + + def mlp_forward(self, hidden_states, **kw_args): + x = hidden_states + origin = self.transformer.layers[kw_args["layer_id"]].mlp + x1 = origin.dense_h_to_4h(x) + x2 = self.w2[kw_args["layer_id"]](x) + hidden = origin.activation_func(x2) * x1 + x = origin.dense_4h_to_h(hidden) + return x + + +class AdaLNMixin(BaseMixin): + def __init__( + self, + width, + height, + hidden_size, + num_layers, + time_embed_dim, + compressed_num_frames, + qk_ln=True, + hidden_size_head=None, + elementwise_affine=True, + ): + super().__init__() + self.num_layers = num_layers + self.width = width + self.height = height + self.compressed_num_frames = compressed_num_frames + + self.adaLN_modulations = nn.ModuleList( + [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)] + ) + + self.qk_ln = qk_ln + if qk_ln: + self.query_layernorm_list = nn.ModuleList( + [ + LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine) + for _ in range(num_layers) + ] + ) + self.key_layernorm_list = nn.ModuleList( + [ + LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine) + for _ in range(num_layers) + ] + ) + + def layer_forward( + self, + hidden_states, + mask, + *args, + **kwargs, + ): + text_length 
= kwargs["text_length"] + # hidden_states (b,(n_t+t*n_i),d) + text_hidden_states = hidden_states[:, :text_length] # (b,n,d) + img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d) + + layer = self.transformer.layers[kwargs["layer_id"]] + adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]] + + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + text_shift_msa, + text_scale_msa, + text_gate_msa, + text_shift_mlp, + text_scale_mlp, + text_gate_mlp, + ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1) + gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = ( + gate_msa.unsqueeze(1), + gate_mlp.unsqueeze(1), + text_gate_msa.unsqueeze(1), + text_gate_mlp.unsqueeze(1), + ) + + # self full attention (b,(t n),d) + if get_sequence_parallel_world_size() > 1: + global SP_PAD + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + img_length = img_hidden_states.shape[1] + SP_PAD = math.ceil(img_length / sp_world_size) * sp_world_size - img_length + if SP_PAD > 0: + padding_shape = (img_hidden_states.shape[0], SP_PAD, *img_hidden_states.shape[2:]) + padding = torch.zeros(padding_shape, dtype=img_hidden_states.dtype, device=img_hidden_states.device) + img_hidden_states = torch.cat([img_hidden_states, padding], dim=1) + img_hidden_states = torch.chunk(img_hidden_states, sp_world_size, dim=1)[sp_rank] + + img_attention_input = layer.input_layernorm(img_hidden_states) + text_attention_input = layer.input_layernorm(text_hidden_states) + img_attention_input = modulate(img_attention_input, shift_msa, scale_msa) + text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa) + + attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d) + attention_output = layer.attention(attention_input, mask, **kwargs) + text_attention_output = attention_output[:, :text_length] # (b,n,d) + img_attention_output = attention_output[:, text_length:] # (b,(t n),d) + if self.transformer.layernorm_order == "sandwich": + text_attention_output = layer.third_layernorm(text_attention_output) + img_attention_output = layer.third_layernorm(img_attention_output) + img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d) + text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d) + + # mlp (b,(t n),d) + img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d) + text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d) + img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp) + text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp) + mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d + mlp_output = layer.mlp(mlp_input, **kwargs) + img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d) + text_mlp_output = mlp_output[:, :text_length] # language (b,n,d) + if self.transformer.layernorm_order == "sandwich": + text_mlp_output = layer.fourth_layernorm(text_mlp_output) + img_mlp_output = layer.fourth_layernorm(img_mlp_output) + + img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d) + text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d) + + hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d) + if get_sequence_parallel_world_size() > 1: + sp_world_size = get_sequence_parallel_world_size() + hidden_states = 
get_sp_group().all_gather(hidden_states, dim=1) + text_hidden_states = hidden_states[:, :text_length] + stack = [text_hidden_states] + hidden_states_ls = torch.chunk(hidden_states, sp_world_size, dim=1) + for hidden_states_slice in hidden_states_ls: + stack.append(hidden_states_slice[:, text_length:]) + hidden_states = torch.cat(stack, dim=1) + if SP_PAD > 0: + hidden_states = hidden_states[:, :(-SP_PAD)] + return hidden_states + + def reinit(self, parent_model=None): + for layer in self.adaLN_modulations: + nn.init.constant_(layer[-1].weight, 0) + nn.init.constant_(layer[-1].bias, 0) + + @non_conflict + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_fa_score, + **kwargs, + ): + if self.qk_ln: + query_layernorm = self.query_layernorm_list[kwargs["layer_id"]] + key_layernorm = self.key_layernorm_list[kwargs["layer_id"]] + query_layer = query_layernorm(query_layer) + key_layer = key_layernorm(key_layer) + + return old_impl( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + + +str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +class DiffusionTransformer(BaseModel): + def __init__( + self, + transformer_args, + num_frames, + time_compressed_rate, + latent_width, + latent_height, + patch_size, + in_channels, + out_channels, + hidden_size, + num_layers, + num_attention_heads, + elementwise_affine, + time_embed_dim=None, + num_classes=None, + modules={}, + input_time="adaln", + adm_in_channels=None, + parallel_output=True, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + use_SwiGLU=False, + use_RMSNorm=False, + zero_init_y_embed=False, + **kwargs, + ): + self.latent_width = latent_width + self.latent_height = latent_height + self.patch_size = patch_size + self.num_frames = num_frames + self.time_compressed_rate = time_compressed_rate + self.spatial_length = latent_width * latent_height // patch_size ** 2 + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.model_channels = hidden_size + self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size + self.num_classes = num_classes + self.adm_in_channels = adm_in_channels + self.input_time = input_time + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.is_decoder = transformer_args.is_decoder + self.elementwise_affine = elementwise_affine + self.height_interpolation = height_interpolation + self.width_interpolation = width_interpolation + self.time_interpolation = time_interpolation + self.inner_hidden_size = hidden_size * 4 + self.zero_init_y_embed = zero_init_y_embed + try: + self.dtype = str_to_dtype[kwargs.pop("dtype")] + except: + self.dtype = torch.float32 + + if use_SwiGLU: + kwargs["activation_func"] = F.silu + elif "activation_func" not in kwargs: + approx_gelu = nn.GELU(approximate="tanh") + kwargs["activation_func"] = approx_gelu + + if use_RMSNorm: + kwargs["layernorm"] = RMSNorm + else: + kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6) + + transformer_args.num_layers = num_layers + transformer_args.hidden_size = hidden_size + transformer_args.num_attention_heads = num_attention_heads + 
transformer_args.parallel_output = parallel_output + super().__init__(args=transformer_args, transformer=None, **kwargs) + + module_configs = modules + self._build_modules(module_configs) + + if use_SwiGLU: + self.add_mixin( + "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True + ) + + def _build_modules(self, module_configs): + model_channels = self.hidden_size + # time_embed_dim = model_channels * 4 + time_embed_dim = self.time_embed_dim + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + elif self.num_classes == "sequential": + assert self.adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(self.adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + if self.zero_init_y_embed: + nn.init.constant_(self.label_emb[0][2].weight, 0) + nn.init.constant_(self.label_emb[0][2].bias, 0) + else: + raise ValueError() + + pos_embed_config = module_configs["pos_embed_config"] + self.add_mixin( + "pos_embed", + instantiate_from_config( + pos_embed_config, + height=self.latent_height // self.patch_size, + width=self.latent_width // self.patch_size, + compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, + hidden_size=self.hidden_size, + ), + reinit=True, + ) + + patch_embed_config = module_configs["patch_embed_config"] + self.add_mixin( + "patch_embed", + instantiate_from_config( + patch_embed_config, + patch_size=self.patch_size, + hidden_size=self.hidden_size, + in_channels=self.in_channels, + ), + reinit=True, + ) + if self.input_time == "adaln": + adaln_layer_config = module_configs["adaln_layer_config"] + self.add_mixin( + "adaln_layer", + instantiate_from_config( + adaln_layer_config, + height=self.latent_height // self.patch_size, + width=self.latent_width // self.patch_size, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, + hidden_size_head=self.hidden_size // self.num_attention_heads, + time_embed_dim=self.time_embed_dim, + elementwise_affine=self.elementwise_affine, + ), + ) + else: + raise NotImplementedError + + final_layer_config = module_configs["final_layer_config"] + self.add_mixin( + "final_layer", + instantiate_from_config( + final_layer_config, + hidden_size=self.hidden_size, + patch_size=self.patch_size, + out_channels=self.out_channels, + time_embed_dim=self.time_embed_dim, + latent_width=self.latent_width, + latent_height=self.latent_height, + elementwise_affine=self.elementwise_affine, + ), + reinit=True, + ) + + if "lora_config" in module_configs: + lora_config = module_configs["lora_config"] + self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) + + return + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + b, t, d, h, w = x.shape + if x.dtype != self.dtype: + x = x.to(self.dtype) + + # 
This is not use in inference + if "concat_images" in kwargs and kwargs["concat_images"] is not None: + if kwargs["concat_images"].shape[0] != x.shape[0]: + concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1) + else: + concat_images = kwargs["concat_images"] + x = torch.cat([x, concat_images], dim=2) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + # assert y.shape[0] == x.shape[0] + assert x.shape[0] % y.shape[0] == 0 + y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) + emb = emb + self.label_emb(y) + + kwargs["seq_length"] = t * h * w // (self.patch_size ** 2) + kwargs["images"] = x + kwargs["emb"] = emb + kwargs["encoder_outputs"] = context + kwargs["text_length"] = context.shape[1] + + kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) + output = super().forward(**kwargs)[0] + return output diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh new file mode 100755 index 0000000000000000000000000000000000000000..6683119da2806f1a2f0cdb44831b7e84aa3cb918 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh @@ -0,0 +1,7 @@ +#! /bin/bash +export TASK_QUEUE_ENABLE=2 +export HCCL_OP_EXPANSION_MODE="AIV" +torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ + sample_video.py \ + --base configs/cogvideox_5b.yaml configs/inference.yaml \ + --seed 42 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py new file mode 100644 index 0000000000000000000000000000000000000000..79d3211ed3939c4fb11c363f14d4df811f7506fa --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -0,0 +1,303 @@ +import os +import math +import time +import argparse +from arguments import get_args +from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_state, \ + get_sequence_parallel_rank +from sat.arguments import set_random_seed + +from typing import List, Union +from tqdm import tqdm +from omegaconf import ListConfig +import imageio + +import torch +import torch_npu +import torch.distributed as dist +import numpy as np +from einops import rearrange +import torchvision.transforms as TT + +from sat.model.base_model import get_model +from sat.training.model_io import load_checkpoint + +from diffusion_video import SATVideoDiffusionEngine + +from torchvision.transforms.functional import center_crop, resize +from torchvision.transforms import InterpolationMode +from PIL import Image + + +def read_from_cli(): + cnt = 0 + try: + while True: + x = input("Please input English text (Ctrl-D quit): ") + yield x.strip(), cnt + cnt += 1 + except EOFError as e: + pass + + +def read_from_file(p): + with open(p, "r") as fin: + cnt = -1 + for l in fin: + cnt += 1 + yield l.strip(), cnt + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="npu"): + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + 
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() + batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None): + os.makedirs(save_path, exist_ok=True) + + for i, vid in enumerate(video_batch): + gif_frames = [] + for frame in vid: + frame = rearrange(frame, "c h w -> h w c") + frame = (255.0 * frame).cpu().numpy().astype(np.uint8) + gif_frames.append(frame) + now_save_path = os.path.join(save_path, f"{i:06d}.mp4") + with imageio.get_writer(now_save_path, fps=fps) as writer: + for frame in gif_frames: + writer.append_data(frame) + + +def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] + delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + +def sampling_main(args, model_cls): + torch.npu.config.allow_internal_format = False + if args.world_size > 1: + sp_degree = args.world_size // 2 + parallel_config = ParallelConfig(sp_degree=sp_degree, use_cfg_parallel=True, world_size=args.world_size) + init_parallel_env(parallel_config) + else: + parallel_config = ParallelConfig(sp_degree=1, use_cfg_parallel=False, world_size=args.world_size) + init_parallel_env(parallel_config) + + rank = get_sequence_parallel_rank() if get_sequence_parallel_state() else 0 + + if isinstance(model_cls, type): + model = get_model(args, model_cls) + else: + model = model_cls + + load_checkpoint(model, args) + model.eval() + + if args.input_type == "cli": + data_iter = read_from_cli() + elif args.input_type == "txt": + data_iter = read_from_file(args.input_file) + else: + raise NotImplementedError + + image_size = [args.sampling_image_height, args.sampling_image_width] + + if args.image2video: + chained_trainsforms = [] + chained_trainsforms.append(TT.ToTensor()) + transform = TT.Compose(chained_trainsforms) + + sample_func = model.sample + T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, args.sampling_fps + num_samples = [1] + force_uc_zero_embeddings = ["txt"] + with torch.no_grad(): + for text, cnt in tqdm(data_iter): + set_random_seed(args.seed) + start = time.time() + torch.npu.synchronize() + if args.image2video: + text, image_path = text.split("@@") + assert os.path.exists(image_path), image_path + image = Image.open(image_path).convert("RGB") + image = transform(image).unsqueeze(0).to("npu") + image = 
resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0) + image = image * 2.0 - 1.0 + image = image.unsqueeze(2).to(torch.bfloat16) + image = model.encode_first_stage(image, None) + image = image.permute(0, 2, 1, 3, 4).contiguous() + pad_shape = (image.shape[0], T - 1, C, H // F, W // F) + image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) + else: + image = None + + value_dict = { + "prompt": text, + "negative_prompt": "", + "num_frames": torch.tensor(T).unsqueeze(0), + } + + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("npu"), (c, uc)) + + if args.image2video and image is not None: + c["concat"] = image + uc["concat"] = image + + first_stage_model = model.first_stage_model + for index in range(args.batch_size): + # count = 5 + # experimental_config = torch_npu.profiler._ExperimentalConfig( + # export_type=torch_npu.profiler.ExportType.Text, + # aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + # profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + # l2_cache=False, + # data_simplification=False + # ) + # with torch_npu.profiler.profile( + # activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + # # 默认CPU、NPU均打开 + # with_stack=True, # 采集torch op的函数调用栈的开关,默认关闭 + # record_shapes=True, # 采集torch op的input shape和input type的开关,默认关闭 + # profile_memory=True, # 采集memory相关数据的开关,默认关闭 + # schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1), + # experimental_config=experimental_config, + # on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_unet") + # # 导出tensorboard可呈现的数据形式 + # ) as prof: + # for i in range(count): + # samples_z = sample_func( + # c, + # uc=uc, + # batch_size=1, + # shape=(T, C, H // F, W // F), + # ) + # prof.step() + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H // F, W // F), + ) + samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() + world_size = dist.get_world_size() + if world_size <= 2: + # VAE Tiling to save memory + latent = 1.0 / model.scale_factor * samples_z + # Decode latent serial to save GPU memory + recons = [] + loop_num = (T - 1) // 2 + for i in range(loop_num): + if i == 0: + start_frame, end_frame = 0, 3 + else: + start_frame, end_frame = i * 2 + 1, i * 2 + 3 + if i == loop_num - 1: + clear_fake_cp_cache = True + else: + clear_fake_cp_cache = False + with torch.no_grad(): + recon = first_stage_model.decode( + latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache + ) + recons.append(recon) + recon = torch.cat(recons, dim=2).to(torch.float32) + samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() + else: + # VAE Parallel + decode_pad = math.ceil((samples_z.shape[2] - 1) / world_size) * world_size - samples_z.shape[2] + 1 + padding_shape = (*samples_z.shape[:2], decode_pad, *samples_z.shape[3:]) + padding = torch.zeros(padding_shape, dtype=samples_z.dtype, device=samples_z.device) + samples_z = 
torch.cat([samples_z, padding], dim=2) + samples_x = model.decode_first_stage(samples_z).to(torch.float32) + if decode_pad > 0: + samples_x = samples_x[:, :, :args.model_config.network_config.params.num_frames] + samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous() + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() + + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + if rank == 0: + save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) + torch.npu.empty_cache() + torch.npu.synchronize() + end = time.time() + print(f"use time: {end - start}s") + finalize_parallel_env() + + +if __name__ == "__main__": + py_parser = argparse.ArgumentParser(add_help=False) + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.model_config.first_stage_config.params.cp_size = 1 + args.model_config.network_config.params.transformer_args.model_parallel_size = 1 + args.model_config.network_config.params.transformer_args.checkpoint_activations = False + + sampling_main(args, model_cls=SATVideoDiffusionEngine) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f8ebd92251f02aafb7ac967d843e4bd307b82c2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py @@ -0,0 +1,3 @@ +from .util import get_configs_path, instantiate_from_config + +__version__ = "0.1.0" diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0db1d7716a6e48f77b86a4b59c9289d6fb76b50b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py @@ -0,0 +1,6 @@ +from .encoders.modules import GeneralConditioner + +UNCONDITIONAL_CONFIG = { + "target": "sgm.modules.GeneralConditioner", + "params": {"emb_models": []}, +} diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..46ce4753c37e03db8cfdcdabe3788141716bd0e6 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py @@ -0,0 +1,396 @@ +import math +from inspect import isfunction + +import torch +import torch_npu +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + 
inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.backend = backend + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # 
reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + out = torch_npu.npu_fusion_attention(q, k, v, atten_mask=mask, + input_layout="BNSD", + scale=1 / math.sqrt(q.shape[-1]), + head_num=h)[0] + + del q, k, v + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}) + + # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + ) + + x + ) + x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerSingleLayerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = 
self.attn1(self.norm1(x), context=context) + x + x = self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + ): + super().__init__() + print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") + from omegaconf import ListConfig + + if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + print( + f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " + f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." + ) + # depth does not match context dims. + assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ae0ebfc2a72ab589c39ce16010e65d4a942c76 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py @@ -0,0 +1,243 @@ +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +from sgm.modules.diffusionmodules.model import ( + AttnBlock, + Decoder, + ResnetBlock, +) +from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding +from sgm.modules.video_attention import VideoTransformerBlock +from sgm.util import partialclass + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + if timesteps is None: + timesteps = self.timesteps + + b, c, h, w = x.shape + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps, skip_video=False): + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class VideoBlock(AttnBlock): + def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + 
torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_video=False): + if skip_video: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + if attn_type == "vanilla": + assert attn_kwargs is None + return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy) + else: + return NotImplementedError() + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight + + def _make_attn(self) -> Callable: + if self.time_mode not in ["conv-only", "only-last-conv"]: + return partialclass( + make_time_attn, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_attn() + + def _make_conv(self) -> Callable: + if self.time_mode != "attn-only": + return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + else: + return Conv2DWrapper + + def _make_resblock(self) -> Callable: + if self.time_mode not in ["attn-only", "only-last-conv"]: + return partialclass( + VideoResBlock, + video_kernel_size=self.video_kernel_size, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_resblock() diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fccebf954f5760fa559b17755e743c41daa1a824 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py @@ -0,0 +1,6 @@ +from .denoiser import Denoiser +from .discretizer import Discretization +from .model import Decoder, Encoder, Model +from .openaimodel import UNetModel +from .sampling import BaseDiffusionSampler +from .wrappers import OpenAIWrapper diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..3b192ed410ec14c39c095a99e992ccdaf8de2c8f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py @@ -0,0 +1,72 @@ +from typing import Dict + +import torch +import torch.nn as nn + +from ...util import append_dims, instantiate_from_config + + +class Denoiser(nn.Module): + def __init__(self, weighting_config, scaling_config): + super().__init__() + + self.weighting = instantiate_from_config(weighting_config) + self.scaling = instantiate_from_config(scaling_config) + + def possibly_quantize_sigma(self, sigma): + return sigma + + def possibly_quantize_c_noise(self, c_noise): + return c_noise + + def w(self, sigma): + return self.weighting(sigma) + + def forward( + self, + network: nn.Module, + input: torch.Tensor, + sigma: torch.Tensor, + cond: Dict, + **additional_model_inputs, + ) -> torch.Tensor: + sigma = self.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, input.ndim) + c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs) + c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip + + +class DiscreteDenoiser(Denoiser): + def __init__( + self, + weighting_config, + scaling_config, + num_idx, + discretization_config, + do_append_zero=False, + quantize_c_noise=True, + flip=True, + ): + super().__init__(weighting_config, scaling_config) + sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + self.sigmas = sigmas + # self.register_buffer("sigmas", sigmas) + self.quantize_c_noise = quantize_c_noise + + def sigma_to_idx(self, sigma): + dists = sigma - self.sigmas.to(sigma.device)[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def idx_to_sigma(self, idx): + return self.sigmas.to(idx.device)[idx] + + def possibly_quantize_sigma(self, sigma): + return self.idx_to_sigma(self.sigma_to_idx(sigma)) + + def possibly_quantize_c_noise(self, c_noise): + if self.quantize_c_noise: + return self.sigma_to_idx(c_noise) + else: + return c_noise diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..7f803db1e0dab1317fd3caf7b8255b056230f3e5 --- /dev/null +++ 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch + + +class DenoiserScaling(ABC): + @abstractmethod + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pass + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EpsScaling: + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = torch.ones_like(sigma, device=sigma.device) + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScaling: + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma ** 2 + 1.0) + c_out = -sigma / (sigma ** 2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma ** 2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScalingWithEDMcNoise(DenoiserScaling): + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma ** 2 + 1.0) + c_out = -sigma / (sigma ** 2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma ** 2 + 1.0) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class VideoScaling: # similar to VScaling + def __call__( + self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = alphas_cumprod_sqrt + c_out = -((1 - alphas_cumprod_sqrt ** 2) ** 0.5) + c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device) + c_noise = additional_model_inputs["idx"].clone() + return c_skip, c_out, c_in, c_noise diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..18285bb095d97e547ec54e01e789a05e0e144d44 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -0,0 +1,24 @@ +import torch + + +class UnitWeighting: + def __call__(self, sigma): + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + return (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting: + def __call__(self, sigma): + return sigma ** -2.0 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py new file mode 100644 index 
0000000000000000000000000000000000000000..a579eab64d0a2922a56dda12bd29ae67441e3ce4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py @@ -0,0 +1,127 @@ +from abc import abstractmethod +from functools import partial + +import numpy as np +import torch + +from ...modules.diffusionmodules.util import make_beta_schedule +from ...util import append_zero + + +def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] + + +class Discretization: + def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False): + if return_idx: + sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx) + else: + sigmas = self.get_sigmas(n, device=device, return_idx=return_idx) + sigmas = append_zero(sigmas) if do_append_zero else sigmas + if return_idx: + return sigmas if not flip else torch.flip(sigmas, (0,)), idx + else: + return sigmas if not flip else torch.flip(sigmas, (0,)) + + @abstractmethod + def get_sigmas(self, n, device): + pass + + +class EDMDiscretization(Discretization): + def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def get_sigmas(self, n, device="cpu"): + ramp = torch.linspace(0, 1, n, device=device) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return sigmas + + +class LegacyDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + ): + super().__init__() + self.num_timesteps = num_timesteps + betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + def get_sigmas(self, n, device="cpu"): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029 + + +class ZeroSNRDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) + keep_start=False, + post_shift=False, + ): + super().__init__() + if keep_start and not post_shift: + linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start) + self.num_timesteps = num_timesteps + betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + # SNR shift + if not post_shift: + self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod) + + self.post_shift = post_shift + self.shift_scale = shift_scale + + def get_sigmas(self, n, device="cpu", return_idx=False): + if n < self.num_timesteps: + 
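+ # subsample n roughly equally spaced timestep indices from the full num_timesteps training schedule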
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + alphas_cumprod = to_torch(alphas_cumprod) + alphas_cumprod_sqrt = alphas_cumprod.sqrt() + alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone() + alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone() + + alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T + alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) + + if self.post_shift: + alphas_cumprod_sqrt = ( + alphas_cumprod_sqrt ** 2 / ( + self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt ** 2) + ) ** 0.5 + + if return_idx: + return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps + else: + return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..71bf91b07299bb494836084df5e2f9dc87220bb3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py @@ -0,0 +1,124 @@ +import logging +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union +from functools import partial +import math + +import torch +from einops import rearrange, repeat + +from ...util import append_dims, default, instantiate_from_config +from utils.parallel_mgr import get_classifier_free_guidance_world_size, get_classifier_free_guidance_rank, \ + get_sequence_parallel_rank, get_cfg_group + + +class Guider(ABC): + @abstractmethod + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + pass + + def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]: + pass + + +class VanillaCFG: + """ + implements parallelized CFG + """ + + def __init__(self, scale, dyn_thresh_config=None): + self.scale = scale + scale_schedule = lambda scale, sigma: scale # independent of step + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, + ) + ) + + def __call__(self, x, sigma, scale=None): + if get_classifier_free_guidance_world_size() == 1: + x_u, x_c = x.chunk(2) + elif get_classifier_free_guidance_world_size() == 2: + x_u, x_c = get_cfg_group().all_gather( + x, separate_tensors=True + ) + else: + raise ValueError("Invalid classifier free guidance world size") + scale_value = default(scale, self.scale_schedule(sigma)) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + if get_classifier_free_guidance_world_size() == 1: + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + x_out = torch.cat([x] * 2) + s_out = torch.cat([s] * 2) + else: + x_out = x + s_out = s + if get_classifier_free_guidance_rank() == 0: + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = uc[k] + else: + assert c[k] == uc[k] + c_out[k] = c[k] + elif get_classifier_free_guidance_rank() == 1: + for k in c: + if 
k in ["vector", "crossattn", "concat"]: + c_out[k] = c[k] + else: + assert c[k] == uc[k] + c_out[k] = c[k] + else: + raise ValueError("Invalid classifier free guidance rank") + return x_out, s_out, c_out + + +class DynamicCFG(VanillaCFG): + def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): + super().__init__(scale, dyn_thresh_config) + scale_schedule = ( + lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 + ) + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, + ) + ) + + def __call__(self, x, sigma, step_index, scale=None): + if get_classifier_free_guidance_world_size() == 1: + x_u, x_c = x.chunk(2) + elif get_classifier_free_guidance_world_size() == 2: + x_u, x_c = get_cfg_group().all_gather( + x, separate_tensors=True + ) + else: + raise ValueError("Invalid classifier free guidance world size") + scale_value = self.scale_schedule(sigma, step_index.item()) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + +class IdentityGuider: + def __call__(self, x, sigma): + return x + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + c_out[k] = c[k] + + return x, s, c_out diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6371587da9b40450173dc7e6c331d1e5aa73e4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py @@ -0,0 +1,362 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Set, Type + +import torch +import torch.nn.functional as F +from torch import nn + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + self.out_features = out_features + self.in_features = in_features + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRAConv2dLayer(nn.Module): + def __init__( + self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + ): + super().__init__() + + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + # according to the official kohya_ss trainer kernel_size are always fixed for the up layer + # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 + self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) + + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRACompatibleConv(nn.Conv2d): + """ + A convolutional layer that can be used with LoRA. 
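+ When a lora_layer is attached, forward() returns the base convolution output plus scale * lora_layer(x); _fuse_lora() folds the low-rank update directly into the convolution weight so the extra matmul disappears at inference.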
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + self.scale = scale + + def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale=1.0): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) + fusion = fusion.reshape((w_orig.shape)) + fused_weight = w_orig + (lora_scale * fusion) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.data.dtype, fused_weight.data.device + + self.w_up = self.w_up.to(device=device).float() + self.w_down = self.w_down.to(device).float() + + fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) + fusion = fusion.reshape((fused_weight.shape)) + unfused_weight = fused_weight.float() - (self._lora_scale * fusion) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states, scale: float = None): + if scale is None: + scale = self.scale + if self.lora_layer is None: + # make sure to the functional Conv2D function as otherwise torch.compile's graph will break + # see: https://github.com/huggingface/diffusers/pull/4315 + return F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + else: + return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + + +class LoRACompatibleLinear(nn.Linear): + """ + A Linear layer that can be used with LoRA. 
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + self.scale = scale + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale=1.0): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states, scale: float = None): + if scale is None: + scale = self.scale + if self.lora_layer is None: + out = super().forward(hidden_states) + return out + else: + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + return out + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoRACompatibleLinear, + LoRACompatibleConv, + LoRALinearLayer, + LoRAConv2dLayer, + ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class) + else: + # this, incase you want to naively iterate over all modules. 
+ ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + flag = False + while path: + try: + parent = parent.get_submodule(path.pop(0)) + except: + flag = True + break + if flag: + continue + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]): + continue + # Otherwise, yield it + yield parent, name, module + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = None, + rank: int = 4, + scale: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + lora_layer = LoRALinearLayer( + in_features=_child_module.in_features, + out_features=_child_module.out_features, + rank=rank, + ) + _tmp = ( + LoRACompatibleLinear( + _child_module.in_features, + _child_module.out_features, + lora_layer=lora_layer, + scale=scale, + ) + .to(weight.dtype) + .to(weight.device) + ) + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + lora_layer = LoRAConv2dLayer( + in_features=_child_module.in_channels, + out_features=_child_module.out_channels, + rank=rank, + kernel_size=_child_module.kernel_size, + stride=_child_module.stride, + padding=_child_module.padding, + ) + _tmp = ( + LoRACompatibleConv( + _child_module.in_channels, + _child_module.out_channels, + kernel_size=_child_module.kernel_size, + stride=_child_module.stride, + padding=_child_module.padding, + lora_layer=lora_layer, + scale=scale, + ) + .to(weight.dtype) + .to(weight.device) + ) + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + else: + continue + + _module._modules[name] = _tmp + # print('injecting lora layer to', _module, name) + + return + + +def update_lora_scale( + model: nn.Module, + target_module: Set[str] = None, + scale: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv] + ): + _child_module.scale = scale + + return diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd79d2a5e4e445b77482122e84b5f9f7b6188ff --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py @@ -0,0 +1,602 @@ +import math +from typing import Callable + +import numpy as np +import torch +import torch_npu +import torch.nn as nn +from einops import rearrange + +from ..attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. 
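+ The output concatenates a sin half and a cos half: emb[i] = [sin(t_i * w_0), ..., sin(t_i * w_{d/2-1}), cos(t_i * w_0), ..., cos(t_i * w_{d/2-1})], where w_j = exp(-j * ln(10000) / (d/2 - 1)).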
+ This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = 
torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + h_ = torch_npu.npu_fusion_attention(q, k, v, atten_mask=None, + input_layout="BNSD", + scale=1 / math.sqrt(c), + head_num=1)[0] + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = 
ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != 
self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + make_attn_cls = self._make_attn() + make_resblock_cls = self._make_resblock() + make_conv_cls = self._make_conv() + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) + self.mid.block_2 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + make_resblock_cls( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn_cls(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = 
Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def _make_attn(self) -> Callable: + return make_attn + + def _make_resblock(self) -> Callable: + return ResnetBlock + + def _make_conv(self) -> Callable: + return torch.nn.Conv2d + + def get_last_layer(self, **kwargs): + return self.conv_out.weight + + def forward(self, z, **kwargs): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..10a7a0da30362daadd9842532d454e5f52a14e16 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1248 @@ +import os +import math +from abc import abstractmethod +from functools import partial +from typing import Iterable, Optional + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...modules.attention import SpatialTransformer +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from ...modules.diffusionmodules.lora import inject_trainable_lora_extended, update_lora_scale +from ...modules.video_attention import SpatialVideoTransformer +from ...util import default, exists + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep 
embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + context: Optional[th.Tensor] = None, + image_only_indicator: Optional[th.Tensor] = None, + time_context: Optional[int] = None, + num_video_frames: Optional[int] = None, + ): + from ...modules.diffusionmodules.video_model import VideoResBlock + + for layer in self: + module = layer + + if isinstance(module, TimestepBlock) and not isinstance(module, VideoResBlock): + x = layer(x, emb) + elif isinstance(module, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, SpatialVideoTransformer): + x = layer( + x, + context, + time_context, + num_video_frames, + image_only_indicator, + ) + elif isinstance(module, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.third_up = third_up + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + t_factor = 1 if not self.third_up else 2 + x = F.interpolate( + x, + (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest", + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
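+    Example (illustrative sketch only, added for clarity; assumes a 2D NCHW input and the
+    convolutional path):
+        down = Downsample(channels=64, use_conv=True, dims=2)
+        y = down(th.randn(1, 64, 32, 32))  # spatial dims halve -> (1, 64, 16, 16)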
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) + if use_conv: + print(f"Building a Downsample layer with {dims} dims.") + print( + f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " + f"kernel-size: 3, stride: {stride}, padding: {padding}" + ) + if dims == 3: + print(f" --> Downsampling third axis (time): {third_down}") + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, Iterable): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels + if self.skip_t_emb: + print(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd( + dims, + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: 
+ self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = th.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, **kwargs): + # TODO add crossframe attention and use mixed checkpoint + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. 
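+    # Illustrative check (added, not from the original source): for b=1, a 16x16 feature
+    # map (num_spatial=256) and c=64 channels this counts 2 * 1 * 256**2 * 64 = 8,388,608 ops.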
+ matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +str_to_dtype = {"fp32": th.float32, "fp16": th.float16, "bf16": th.bfloat16} + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. 
+ :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + dtype="fp32", + lora_init=False, + lora_rank=4, + lora_scale=1.0, + lora_weight_path=None, + ): + super().__init__() + from omegaconf.listconfig import ListConfig + + self.dtype = str_to_dtype[dtype] + + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 
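+        # (added note) cross-attention needs both settings together, e.g. a hypothetical
+        # UNetModel(..., use_spatial_transformer=True, context_dim=768); values are for illustration only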
+ if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + elif isinstance(transformer_depth, ListConfig): + transformer_depth = list(transformer_depth) + transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1]) + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + if use_fp16: + print("WARNING: use_fp16 was dropped and has no effect anymore.") + # self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint) + + self.use_fairscale_checkpoint = False + checkpoint_wrapper_fn = ( + partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) + if self.use_fairscale_checkpoint + else lambda x: x + ) + + time_embed_dim = model_channels * 4 + self.time_embed = checkpoint_wrapper_fn( + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = checkpoint_wrapper_fn( + nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch 
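+                # (comment added for clarity) input_block_chans records the channel width after each
+                # encoder block so the decoder loop below can pop a matching width for its skip concat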
+ input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ), + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + ) + if resblock_updown + else 
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + ) + if self.predict_codebook_ids: + self.id_predictor = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + ) + + if lora_init: + self._init_lora(lora_rank, lora_scale, lora_weight_path) + + def _init_lora(self, rank, scale, ckpt_dir=None): + inject_trainable_lora_extended(self, target_replace_module=None, rank=rank, scale=scale) + + if ckpt_dir is not None: + with open(os.path.join(ckpt_dir, "latest")) as latest_file: + latest = latest_file.read().strip() + ckpt_path = os.path.join(ckpt_dir, latest, "mp_rank_00_model_states.pt") + print(f"loading lora from {ckpt_path}") + sd = th.load(ckpt_path)["module"] + sd = { + key[len("model.diffusion_model"):]: sd[key] for key in sd if key.startswith("model.diffusion_model") + } + self.load_state_dict(sd, strict=False) + + def _update_scale(self, scale): + update_lora_scale(self, scale) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + assert False, "not supported anymore. what the f*** are you doing?" + else: + return self.out(h) + + +class NoTimeUNetModel(UNetModel): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + timesteps = th.zeros_like(timesteps) + return super().forward(x, timesteps, context, y, **kwargs) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. 
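+    Example (hypothetical values, for illustration only; output shape follows the pooled head):
+        enc = EncoderUNetModel(image_size=64, in_channels=4, model_channels=128,
+                               out_channels=512, num_res_blocks=2,
+                               attention_resolutions=[4, 2], pool="adaptive")
+        feats = enc(th.randn(2, 4, 64, 64), th.randint(0, 1000, (2,)))  # -> (2, 512)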
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + ) + elif pool == 
"spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + + +if __name__ == "__main__": + class Dummy(nn.Module): + def __init__(self, in_channels=3, model_channels=64): + super().__init__() + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))] + ) + + + model = UNetModel( + use_checkpoint=True, + image_size=64, + in_channels=4, + out_channels=4, + model_channels=128, + attention_resolutions=[4, 2], + num_res_blocks=2, + channel_mult=[1, 2, 4], + num_head_channels=64, + use_spatial_transformer=False, + use_linear_in_transformer=True, + transformer_depth=1, + legacy=False, + ).npu() + x = th.randn(11, 4, 64, 64).npu() + t = th.randint(low=0, high=10, size=(11,), device="npu") + o = model(x, t) + print("done.") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..db0f6acf8b62fbd16f0e4d9db88e37de33b6a3ab --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py @@ -0,0 +1,843 @@ +from typing import Dict, Union + +import torch +from omegaconf import ListConfig, OmegaConf +from tqdm import tqdm +from mindiesd.pipeline.sampling_optm import SamplingOptm +from mindiesd.pipeline.sampling_optm import AdaStep + +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) +from ...util import append_dims, default, instantiate_from_config + +from .guiders import DynamicCFG +from utils.parallel_mgr import get_classifier_free_guidance_world_size + +DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} +SAMPLING_OPTM = True + +class BaseDiffusionSampler: + def __init__( + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: 
str = "npu", + ): + self.num_steps = num_steps + self.discretization = instantiate_from_config(discretization_config) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) + self.verbose = verbose + self.device = device + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) + uc = default(uc, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]).float() + + return x, s_in, sigmas, num_sigmas, cond, uc + + def denoise(self, x, denoiser, sigma, cond, uc): + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_sigma_gen(self, num_sigmas): + sigma_generator = range(num_sigmas - 1) + if self.verbose: + print("#" * 30, " Sampling setting ", "#" * 30) + print(f"Sampler: {self.__class__.__name__}") + print(f"Discretization: {self.discretization.__class__.__name__}") + print(f"Guider: {self.guider.__class__.__name__}") + sigma_generator = tqdm( + sigma_generator, + total=num_sigmas, + desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", + ) + return sigma_generator + + +class SingleStepDiffusionSampler(BaseDiffusionSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): + raise NotImplementedError + + def euler_step(self, x, d, dt): + return x + dt * d + + +class EDMSampler(SingleStepDiffusionSampler): + def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.s_churn = s_churn + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5 + + denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class DDIMSampler(SingleStepDiffusionSampler): + def __init__(self, s_noise=0.1, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + d = to_d(x, sigma, denoised) + dt = append_dims(next_sigma * (1 - s_noise ** 2) ** 0.5 - sigma, x.ndim) + + euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) + + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = 
self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + self.s_noise, + ) + + return x + + +class AncestralSampler(SingleStepDiffusionSampler): + def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta = eta + self.s_noise = s_noise + self.noise_sampler = lambda x: torch.randn_like(x) + + def ancestral_euler_step(self, x, denoised, sigma, sigma_down): + d = to_d(x, sigma, denoised) + dt = append_dims(sigma_down - sigma, x.ndim) + + return self.euler_step(x, d, dt) + + def ancestral_step(self, x, sigma, next_sigma, sigma_up): + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, + x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), + x, + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) + + return x + + +class LinearMultistepSampler(BaseDiffusionSampler): + def __init__( + self, + order=4, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.order = order + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + ds = [] + sigmas_cpu = sigmas.detach().cpu().numpy() + for i in self.get_sigma_gen(num_sigmas): + sigma = s_in * sigmas[i] + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) + denoised = self.guider(denoised, sigma) + d = to_d(x, sigma, denoised) + ds.append(d) + if len(ds) > self.order: + ds.pop(0) + cur_order = min(i + 1, self.order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + + return x + + +class EulerEDMSampler(EDMSampler): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + return euler_step + + +class HeunEDMSampler(EDMSampler): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + if torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 + return euler_step + else: + denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) + d_new = to_d(euler_step, next_sigma, denoised) + d_prime = (d + d_new) / 2.0 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) + return x + + +class EulerAncestralSampler(AncestralSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + + return x + + +class DPMPP2SAncestralSampler(AncestralSampler): + def get_variables(self, sigma, sigma_down): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] + h = t_next - t + s = t + 0.5 * h + return h, s, t, t_next + + def get_mult(self, h, s, t, t_next): + mult1 = to_sigma(s) / to_sigma(t) + mult2 = (-0.5 * h).expm1() + mult3 = to_sigma(t_next) / to_sigma(t) + 
mult4 = (-h).expm1() + + return mult1, mult2, mult3, mult4 + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + + if torch.sum(sigma_down) < 1e-14: + # Save a network evaluation if all noise levels are 0 + x = x_euler + else: + h, s, t, t_next = self.get_variables(sigma, sigma_down) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] + + x2 = mult[0] * x - mult[1] * denoised + denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) + x_dpmpp2s = mult[2] * x - mult[3] * denoised2 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) + + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + return x + + +class DPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) + mult2 = (-h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + # apply correction if noise level is not 0 and not first step + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x + + +class SDEDPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() + mult2 = (-2 * h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + 
old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + # apply correction if noise level is not 0 and not first step + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x + + +class SdeditEDMSampler(EulerEDMSampler): + def __init__(self, edit_ratio=0.5, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.edit_ratio = edit_ratio + + def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None): + randn_unit = randn.clone() + randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps) + + if num_steps is None: + num_steps = self.num_steps + if edit_ratio is None: + edit_ratio = self.edit_ratio + x = None + + for i in self.get_sigma_gen(num_sigmas): + if i / num_steps < edit_ratio: + continue + if x is None: + x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class VideoDDIMSampler(BaseDiffusionSampler): + def __init__(self, fixed_frames=0, sdedit=False, **kwargs): + super().__init__(**kwargs) + self.fixed_frames = fixed_frames + self.sdedit = sdedit + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + alpha_cumprod_sqrt, timesteps = self.discretization( + self.num_steps if num_steps is None else num_steps, + device=self.device, + return_idx=True, + do_append_zero=False, + ) + alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) + timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]) + + uc = default(uc, cond) + + num_sigmas = len(alpha_cumprod_sqrt) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps + + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None): + additional_model_inputs = {} + + if isinstance(scale, torch.Tensor) == False and scale == 1: + additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep + if scale_emb is not None: + additional_model_inputs["scale_emb"] = 
scale_emb + denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) + else: + additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * (2 // get_classifier_free_guidance_world_size())) + denoised = denoiser( + *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs + ).to(torch.float32) + if isinstance(self.guider, DynamicCFG): + denoised = self.guider( + denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep, scale=scale + ) + else: + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, scale=scale) + return denoised + + def sampler_step( + self, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + + a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 + b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t + + x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + + return x + + +class VPSDEDPMPP2MSampler(VideoDDIMSampler): + def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): + alpha_cumprod = alpha_cumprod_sqrt ** 2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + h = lamb_next - lamb + + if previous_alpha_cumprod_sqrt is not None: + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): + mult1 = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt + + if previous_alpha_cumprod_sqrt is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def pred_noise( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + noise = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + return noise + + def schedule_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + x, + noise=None, + idx=None, + ): + if idx == 1: + return noise, noise + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, 
previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + denoised = noise + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + x = x_advanced + + return x, denoised + + def sampler_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + if idx == 1: + return denoised, denoised + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + mult_noise = append_dims((1 - next_alpha_cumprod_sqrt ** 2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + x = x_advanced + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + if self.fixed_frames > 0: + prefix_frames = x[:, : self.fixed_frames] + old_denoised = None + skip_strategy = AdaStep(skip_thr=0.004, max_skip_steps=1, decay_ratio=0.99, device="npu") + for i in self.get_sigma_gen(num_sigmas): + if self.fixed_frames > 0: + if self.sdedit: + rd = torch.randn_like(prefix_frames) + noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( + s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) + ) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) + else: + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) + if SAMPLING_OPTM: + noise = skip_strategy( + self.pred_noise, + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + x, old_denoised = self.schedule_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + x, + noise=noise, + idx=self.num_steps - i, + ) + skip_strategy.update_strategy(x) + else: + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else 
s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + + if self.fixed_frames > 0: + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) + + return x + + +class VPODEDPMPP2MSampler(VideoDDIMSampler): + def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): + alpha_cumprod = alpha_cumprod_sqrt ** 2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + h = lamb_next - lamb + + if previous_alpha_cumprod_sqrt is not None: + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): + mult1 = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 + mult2 = (-h).expm1() * next_alpha_cumprod_sqrt + + if previous_alpha_cumprod_sqrt is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + ): + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) + if idx == 1: + return denoised, denoised + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + x = x_advanced + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + ) + + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1b10674199f7ea12261ff02f2d97ce6dff31a5 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py @@ -0,0 +1,155 @@ +import torch +from scipy import 
integrate + +from ...util import append_dims +from einops import rearrange + + +class NoDynamicThresholding: + def __call__(self, uncond, cond, scale): + scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale + return uncond + scale * (cond - uncond) + + +class StaticThresholding: + def __call__(self, uncond, cond, scale): + result = uncond + scale * (cond - uncond) + result = torch.clamp(result, min=-1.0, max=1.0) + return result + + +def dynamic_threshold(x, p=0.95): + N, T, C, H, W = x.shape + x = rearrange(x, "n t c h w -> n c (t h w)") + l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True) + s = torch.maximum(-l, r) + threshold_mask = (s > 1).expand(-1, -1, H * W * T) + if threshold_mask.any(): + x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x) + x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W) + return x + + +def dynamic_thresholding2(x0): + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) + x0 = torch.clamp(x0, -s, s) # / s + return x0.to(origin_dtype) + + +def latent_dynamic_thresholding(x0): + p = 0.9995 + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0), p, dim=2) + s = append_dims(s, x0.dim()) + x0 = torch.clamp(x0, -s, s) / s + return x0.to(origin_dtype) + + +def dynamic_thresholding3(x0): + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) + x0 = torch.clamp(x0, -s, s) # / s + return x0.to(origin_dtype) + + +class DynamicThresholding: + def __call__(self, uncond, cond, scale): + mean = uncond.mean() + std = uncond.std() + result = uncond + scale * (cond - uncond) + result_mean, result_std = result.mean(), result.std() + result = (result - result_mean) / result_std * std + # result = dynamic_thresholding3(result) + return result + + +class DynamicThresholdingV1: + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def __call__(self, uncond, cond, scale): + result = uncond + scale * (cond - uncond) + unscaled_result = result / self.scale_factor + B, T, C, H, W = unscaled_result.shape + flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)") + means = flattened.mean(dim=2).unsqueeze(2) + recentered = flattened - means + magnitudes = recentered.abs().max() + normalized = recentered / magnitudes + thresholded = latent_dynamic_thresholding(normalized) + denormalized = thresholded * magnitudes + uncentered = denormalized + means + unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W) + scaled_result = unflattened * self.scale_factor + return scaled_result + + +class DynamicThresholdingV2: + def __call__(self, uncond, cond, scale): + B, T, C, H, W = uncond.shape + diff = cond - uncond + mim_target = uncond + diff * 4.0 + cfg_target = uncond + diff * 8.0 + + mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)") + cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)") + mim_means = mim_flattened.mean(dim=2).unsqueeze(2) + cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) + mim_centered = mim_flattened - mim_means + cfg_centered = cfg_flattened - cfg_means + + 
mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) + cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) + + cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref + + result = cfg_renormalized + cfg_means + unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W) + + return unflattened + + +def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + if order - 1 > i: + raise ValueError(f"Order {order} too high for step {i}") + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + if not eta: + return sigma_to, 0.0 + sigma_up = torch.minimum( + sigma_to, + eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5, + ) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +def to_d(x, sigma, denoised): + return (x - denoised) / append_dims(sigma, x.ndim) + + +def to_neg_log_sigma(sigma): + return sigma.log().neg() + + +def to_sigma(neg_log_sigma): + return neg_log_sigma.neg().exp() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..009de3cc035cce9cded059648756d265b32f2fad --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -0,0 +1,67 @@ +import torch +import torch.distributed + +from ...util import default, instantiate_from_config + + +class EDMSampling: + def __init__(self, p_mean=-1.2, p_std=1.2): + self.p_mean = p_mean + self.p_std = p_std + + def __call__(self, n_samples, rand=None): + log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) + return log_sigma.exp() + + +class DiscreteSampling: + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): + self.num_idx = num_idx + self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + self.uniform_sampling = uniform_sampling + if self.uniform_sampling: + self.group_num = 1 + self.group_width = 1 + self.sigma_interval = self.num_idx + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None, return_idx=False): + if self.uniform_sampling: + group_index = 0 + idx = default( + rand, + torch.randint( + group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) + ), + ) + else: + idx = default( + rand, + torch.randint(0, self.num_idx, (n_samples,)), + ) + if return_idx: + return self.idx_to_sigma(idx), idx + else: + return self.idx_to_sigma(idx) + + +class PartialDiscreteSampling: + def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): + self.total_num_idx = total_num_idx + self.partial_num_idx = partial_num_idx + self.sigmas = instantiate_from_config(discretization_config)( + total_num_idx, do_append_zero=do_append_zero, flip=flip + ) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None): + idx = default( + rand, + # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)), + torch.randint(0, 
self.partial_num_idx, (n_samples,)), + ) + return self.idx_to_sigma(idx) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..82a9dff8f25d3f6157bf47bab405932ad425bed7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py @@ -0,0 +1,317 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +def make_beta_schedule( + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, +): + if schedule == "linear": + betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + return betas.numpy() + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def mixed_checkpoint(func, inputs: dict, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function + borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that + it also works with non-tensor inputs + :param func: the function to evaluate. + :param inputs: the argument dictionary to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] + tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)] + non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)] + non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)] + args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) + return MixedCheckpointFunction.apply( + func, + len(tensor_inputs), + len(non_tensor_inputs), + tensor_keys, + non_tensor_keys, + *args, + ) + else: + return func(**inputs) + + +class MixedCheckpointFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, + ): + ctx.end_tensors = length_tensors + ctx.end_non_tensors = length_tensors + length_non_tensors + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors + + ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))} + ctx.input_non_tensors = { + key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors: ctx.end_non_tensors])) + } + ctx.run_function = run_function + ctx.input_params = list(args[ctx.end_non_tensors:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} 
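+ # Recompute the forward pass under the autocast state captured in forward():
+ # the stashed tensor inputs are detached and re-marked as requiring grad, the
+ # wrapped function is re-run with gradients enabled, and torch.autograd.grad
+ # then differentiates this recomputed graph instead of a cached one.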
+ ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors} + + with torch.enable_grad(), torch.npu.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors} + # shallow_copies.update(additional_args) + output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) + input_grads = torch.autograd.grad( + output_tensors, + list(ctx.input_tensors.values()) + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return ( + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors:] + ) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.npu.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
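+ :param repeat_only: if True, skip the sinusoidal projection and simply
+ repeat each timestep value across the embedding dimension.
+ :param dtype: dtype of the returned embedding tensor.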
+ """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding.to(dtype) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + else: + raise NotImplementedError + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b78ffd502fc67238752fbd5de7ee7c6661ce05 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from packaging import version + +OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" + + +class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32): + super().__init__() + compile = ( + torch.compile + if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model + else lambda x: x + ) + self.diffusion_model = compile(diffusion_model) + self.dtype = dtype + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class OpenAIWrapper(IdentityWrapper): + def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: + for key in c: + c[key] = c[key].to(self.dtype) + + if x.dim() == 4: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + elif x.dim() == 5: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2) + else: + raise ValueError("Input tensor must be 4D or 5D") + + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, + ) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fd8ccec1970cc66005a5ef545cd9d051f73ba3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py @@ -0,0 +1,266 @@ +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from omegaconf import ListConfig +from transformers import ( + T5EncoderModel, + T5Tokenizer, +) + +from ...util import ( + disabled_train, + expand_dims_like, + instantiate_from_config, +) + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: + return self._input_key + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + 
@ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + +class GeneralConditioner(nn.Module): + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + + def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]): + super().__init__() + embedders = [] + for n, embconfig in enumerate(emb_models): + embedder = instantiate_from_config(embconfig) + assert isinstance( + embedder, AbstractEmbModel + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.is_trainable = False + embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) + embedder.train = disabled_train + for param in embedder.parameters(): + param.requires_grad = False + embedder.eval() + + if "input_key" in embconfig: + embedder.input_key = embconfig["input_key"] + elif "input_keys" in embconfig: + embedder.input_keys = embconfig["input_keys"] + else: + raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + + embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) + if embedder.legacy_ucg_val is not None: + embedder.ucg_prng = np.random.RandomState() + + embedders.append(embedder) + self.embedders = nn.ModuleList(embedders) + + if len(cor_embs) > 0: + assert len(cor_p) == 2 ** len(cor_embs) + self.cor_embs = cor_embs + self.cor_p = cor_p + + def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: + assert embedder.legacy_ucg_val is not None + p = embedder.ucg_rate + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if embedder.ucg_prng.choice(2, p=[1 - p, p]): + batch[embedder.input_key][i] = val + return batch + + def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict: + assert embedder.legacy_ucg_val is not None + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if cond_or_not[i]: + batch[embedder.input_key][i] = val + return batch + + def get_single_embedding( + self, + embedder, + batch, + output, + cond_or_not: Optional[np.ndarray] = None, + force_zero_embeddings: Optional[List] = None, + ): + with torch.no_grad(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + if cond_or_not is None: + batch = self.possibly_get_ucg_val(embedder, batch) + else: + batch = self.surely_get_ucg_val(embedder, batch, cond_or_not) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + for emb in emb_out: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + if cond_or_not is None: + emb = ( + expand_dims_like( + torch.bernoulli( + (1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), + emb, + ) + * emb + ) + else: + emb = ( + expand_dims_like( + 
torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), + emb, + ) + * emb + ) + if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings: + emb = torch.zeros_like(emb) + if out_key in output: + output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key]) + else: + output[out_key] = emb + return output + + def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict: + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + + if len(self.cor_embs) > 0: + batch_size = len(batch[list(batch.keys())[0]]) + rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p) + for emb_idx in self.cor_embs: + cond_or_not = rand_idx % 2 + rand_idx //= 2 + output = self.get_single_embedding( + self.embedders[emb_idx], + batch, + output=output, + cond_or_not=cond_or_not, + force_zero_embeddings=force_zero_embeddings, + ) + + for i, embedder in enumerate(self.embedders): + if i in self.cor_embs: + continue + output = self.get_single_embedding( + embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings + ) + return output + + def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + cor_embs = self.cor_embs + cor_p = self.cor_p + self.cor_embs = [] + self.cor_p = [] + + c = self(batch_c) + uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) + + for embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + self.cor_embs = cor_embs + self.cor_p = cor_p + + return c, uc + + +class FrozenT5Embedder(AbstractEmbModel): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, + model_dir="google/t5-v1_1-xxl", + device="npu", + max_length=77, + freeze=True, + cache_dir=None, + ): + super().__init__() + if model_dir != "google/t5-v1_1-xxl": + self.tokenizer = T5Tokenizer.from_pretrained(model_dir) + self.transformer = T5EncoderModel.from_pretrained(model_dir) + else: + self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir) + self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + # @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("npu", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d742ffffda523580d5564a879daf8fc79381d8c2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py @@ -0,0 +1,285 @@ +from einops import rearrange, repeat +from typing import Optional 
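+ # The temporal blocks below run attention across frames by folding the frame
+ # axis out of the batch with einops: a hidden state of shape ((b*t), s, c) is
+ # regrouped via
+ # x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
+ # before temporal attention and restored to ((b*t), s, c) afterwards.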
+from ..modules.attention import * +from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding + + +class TimeMixSequential(nn.Sequential): + def forward(self, x, context=None, timesteps=None): + for layer in self: + x = layer(x, context, timesteps) + + return x + + +class VideoTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, + ): + super().__init__() + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + assert int(n_heads * d_head) == inner_dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff) + + self.timesteps = timesteps + self.disable_self_attn = disable_self_attn + if self.disable_self_attn: + self.attn1 = CrossAttention( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + dropout=dropout, + ) # is a cross-attention + else: + self.attn1 = CrossAttention( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + self.norm2 = nn.LayerNorm(inner_dim) + if switch_temporal_ca_to_sa: + self.attn2 = CrossAttention( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + else: + self.attn2 = CrossAttention( + query_dim=inner_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + + self.norm1 = nn.LayerNorm(inner_dim) + self.norm3 = nn.LayerNorm(inner_dim) + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor: + if self.checkpoint: + return checkpoint(self._forward, x, context, timesteps) + else: + return self._forward(x, context, timesteps=timesteps) + + def _forward(self, x, context=None, timesteps=None): + assert self.timesteps or timesteps + assert not (self.timesteps and timesteps) or self.timesteps == timesteps + timesteps = self.timesteps or timesteps + B, S, C = x.shape + x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + if self.disable_self_attn: + x = self.attn1(self.norm1(x), context=context) + x + else: + x = self.attn1(self.norm1(x)) + x + + if self.attn2 is not None: + if self.switch_temporal_ca_to_sa: + x = self.attn2(self.norm2(x)) + x + else: + x = self.attn2(self.norm2(x), context=context) + x + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps) + return x + + def get_last_layer(self): + return self.ff.net[-1].weight + + +str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + 
in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype="fp32", + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + VideoTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy) + self.dtype = str_to_dtype[dtype] + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}" + + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... 
-> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + dtype=self.dtype, + ) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)): + x = block( + x, + context=spatial_context, + ) + + x_mix = x + x_mix = x_mix + emb + + x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator, + ) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..30760bd3e7e8fa34afbbb89d035f690201b0643c --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py @@ -0,0 +1,318 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors + + +class SafeConv3d(torch.nn.Conv3d): + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 + if memory_count > 2: + # print(f"WARNING: Conv3d with {memory_count:.2f}GB") + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(SafeConv3d, self).forward(input) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. 
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. + + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.npu.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join(text_seq[start: start + nc] for start in range(0, len(text_seq), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
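+ E.g. a tensor of shape (N, C, H, W) is reduced to a tensor of shape (N,).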
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config, **extra_kwargs): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. + For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}") + + +def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): + """ + Will return the result of a recursive get attribute call. + E.g.: + a.b.c + = getattr(getattr(a, "b"), "c") + = get_nested_attribute(a, "b.c") + If any part of the attribute call is an integer x with current obj a, will + try to call a[x] instead of a.x first. 
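+ E.g. get_nested_attribute(a, "b.0.c") resolves to getattr(a, "b")[0].c.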
+ """ + attributes = attribute_path.split(".") + if depth is not None and depth > 0: + attributes = attributes[:depth] + assert len(attributes) > 0, "At least one attribute should be selected" + current_attribute = obj + current_key = None + for level, attribute in enumerate(attributes): + current_key = ".".join(attributes[: level + 1]) + try: + id_ = int(attribute) + current_attribute = current_attribute[id_] + except ValueError: + current_attribute = getattr(current_attribute, attribute) + + return (current_attribute, current_key) if return_key else current_attribute + + +from math import sqrt + + +class SeededNoise: + def __init__(self, seeds, weights): + self.seeds = seeds + self.weights = weights + weight_square_sum = 0 + for weight in weights: + weight_square_sum += weight ** 2 + self.weight_square_sum_sqrt = sqrt(weight_square_sum) + self.cnt = 0 + + def __call__(self, x): + self.cnt += 1 + randn_combined = torch.zeros_like(x) + for seed, weight in zip(self.seeds, self.weights): + randn = np.random.RandomState(seed + self.cnt).randn(*x.shape) + randn = torch.from_numpy(randn, dtype=x.dtype, device=x.device) + randn_combined += randn * weight + randn_combined /= self.weight_square_sum_sqrt + return randn_combined diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..6d3f4973d811ad6f123a6f3a9f0dd3c6e4a6c1e9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# coding=utf-8 + +import torch +import torch.distributed as dist + +from .parallel_mgr import get_sequence_parallel_world_size, get_sp_group + + +def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=process_group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def split_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + if world_size == 1: + return input_ + + if pad > 0: + pad_size = list(input_.shape) + pad_size[dim] = pad + input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim) + + dim_size = input_.size(dim) + if dim_size % world_size != 0: + raise ValueError( + f"The th{dim} dimensions of input_:{input_.size()} is not divisible by world_size:{world_size}.") + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + return output + + +def gather_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim) + + if pad > 0: + output = output.narrow(dim, 0, output.size(dim) - pad) + + return output + + +# ====== +# Pad +# ====== + +SPTIAL_PAD = 0 +TEMPORAL_PAD = 0 + + +def set_spatial_pad(dim_size: int): + sp_size = get_sequence_parallel_world_size() + pad = (sp_size - 
(dim_size % sp_size)) % sp_size + global SPTIAL_PAD + SPTIAL_PAD = pad + + +def get_spatial_pad() -> int: + return SPTIAL_PAD + + +def set_temporal_pad(dim_size: int): + sp_size = get_sequence_parallel_world_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global TEMPORAL_PAD + TEMPORAL_PAD = pad + + +def get_temporal_pad() -> int: + return TEMPORAL_PAD + + +def all_to_all_with_pad( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + **kwargs +): + scatter_dim = kwargs.get("scatter_dim", 2) + gather_dim = kwargs.get("gather_dim", 1) + scatter_pad = kwargs.get("scatter_pad", 0) + gather_pad = kwargs.get("gather_pad", 0) + + if scatter_pad > 0: + pad_shape = list(input_.shape) + pad_shape[scatter_dim] = scatter_pad + pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype) + input_ = torch.cat([input_, pad_tensor], dim=scatter_dim) + + world_size = dist.get_world_size(process_group) + if input_.shape[scatter_dim] % world_size != 0: + raise ValueError( + f"The scatter_dim:{scatter_dim} of input_:{input_.shape} is not divisible by world_size:{world_size}.") + + input_ = _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim) + + if gather_pad > 0: + input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) + + return input_ + + +def all_to_all( + tensor: torch.Tensor, + world_size: int, + scatter_dim: int, + gather_dim: int, + process_group: dist.ProcessGroup = None, +): + if process_group is None: + process_group = dist.group.WORLD + return _all_to_all_func(tensor, world_size, process_group, scatter_dim, gather_dim) + + +def all_to_all_bsh( + input_: torch.Tensor, + scatter_dim: int = 1, + gather_dim: int = 0, + async_op: bool = False, +): + return single_all_to_all(input_, scatter_dim, gather_dim, async_op) + + +def single_all_to_all( + input_: torch.Tensor, + scatter_dim: int, + gather_dim: int, + async_op: bool = False, +): + sp_size = get_sequence_parallel_world_size() + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size + if scatter_dim < 1: + input_t = input_.reshape( + [sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ) + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = input_.reshape( + [-1, sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ).transpose(0, 1).contiguous() + + output = torch.empty_like(input_t) + + req = dist.all_to_all_single(output, input_t, group=get_sp_group().device_group, async_op=async_op) + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_dim < 1: + output = output.transpose(0, 1).contiguous() + + return req, output.reshape( + inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:]) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..afb14d1e2274e205fa36e5a2c042586eba94c8e2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py @@ -0,0 +1,631 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). 
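+ Example (hypothetical ranks and backend), splitting four global ranks into
+ two groups of two, where local_rank is this process's device index:
+ GroupCoordinator(group_ranks=[[0, 1], [2, 3]], local_rank=local_rank,
+ torch_distributed_backend="hccl")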
+ """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> 
torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. 
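# --- Illustrative sketch, not part of the patch: expected shapes from
# GroupCoordinator.all_gather on a 2-rank group. `coord` is a placeholder for an
# already-initialized GroupCoordinator and is not defined in this file.
def _all_gather_shapes(coord):
    x = torch.randn(4, 8, device=coord.device)                  # per-rank shard
    full = coord.all_gather(x, dim=0)                            # (8, 8): shards stacked along dim 0
    parts = coord.all_gather(x, dim=0, separate_tensors=True)    # list of two (4, 8) views, one per rank
    return full, parts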
+ if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." + ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
+ if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
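# --- Illustrative sketch, not part of the patch: a point-to-point hand-off
# between adjacent ranks using the tensor-dict helpers above, followed by the
# CPU-group barrier. `coord` is a placeholder for an initialized two-rank
# GroupCoordinator.
def _handoff_example(coord):
    if coord.is_first_rank:
        # dst defaults to group_next_rank
        coord.send_tensor_dict({"noise_pred": torch.ones(1, 16, device=coord.device), "step": 10})
    else:
        # src defaults to group_prev_rank; non-tensor values travel via the metadata list
        received = coord.recv_tensor_dict()
        assert received["step"] == 10
    coord.barrier()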
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_world_size = self.ring_world_size = 1 + self.ulysses_rank = self.ring_rank = 0 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py new file mode 100644 index 0000000000000000000000000000000000000000..41047606b6cc3de1d16653ae80752a6ced2cf778 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python +# coding=utf-8 + +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +from .utils import RankGenerator +from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + +_WORLD: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None +_VAE_PARALLEL_GROUP = None +_VAE_PARALLEL_SIZE = None + + +def is_vae_parallel_initialized(): + global _VAE_PARALLEL_GROUP + if _VAE_PARALLEL_GROUP is None: + return False + else: + return True + + +def initialize_vae_parallel(parallel_size): + global _VAE_PARALLEL_GROUP + global _VAE_PARALLEL_SIZE + + assert _VAE_PARALLEL_GROUP is None, "vae parallel group is already initialized" + _VAE_PARALLEL_SIZE = parallel_size + + ranks = range(_VAE_PARALLEL_SIZE) + group = dist.new_group(ranks) + _VAE_PARALLEL_GROUP = group + + +def get_vae_parallel_group(): + if _VAE_PARALLEL_GROUP is None: + return 0 + return _VAE_PARALLEL_GROUP + + +def get_vae_parallel_world_size(): + if not is_vae_parallel_initialized(): + return 1 + return _VAE_PARALLEL_SIZE + + +def get_vae_parallel_rank(): + if not is_vae_parallel_initialized(): + return 0 + rank = dist.get_rank() + vae_rank = rank % _VAE_PARALLEL_SIZE + return vae_rank + + +def get_vae_parallel_group_rank(): + if 
not is_vae_parallel_initialized(): + return 0 + rank = dist.get_rank() + vae_group_rank = rank // _VAE_PARALLEL_SIZE + return vae_group_rank + + +@dataclass +class ParallelConfig: + sp_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + assert self.sp_degree * self.cfg_degree <= self.world_size, ( + "sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + assert ( + self.world_size % (self.sp_degree * self.cfg_degree) == 0 + ), "world_size must be divisible by sp_degree * cfg_degree" + + +# * QUERY +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + assert _SP is not None, "pipeline model parallel group is not initialized" + return _SP + + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + + +# CFG +def get_cfg_group() -> GroupCoordinator: + assert ( + _CFG is not None + ), "classifier_free_guidance parallel group is not initialized" + return _CFG + + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = 
init_world_group(ranks, local_rank, backend) + else: + assert ( + _WORLD.world_size == dist.get_world_size() + ), "world group already initialized with a different world size" + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + assert parallel_mode in [ + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. + assert dist.is_initialized() + world_size: int = dist.get_world_size() + backend = backend or dist.get_backend(get_world_group().device_group) + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x" + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x" + ) + + rank_generator: RankGenerator = RankGenerator( + sequence_parallel_degree, + classifier_free_guidance_degree, + "sp-cfg", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + assert _SP is None, "sequence parallel group is already initialized" + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ) + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logger.warning("Model parallel is not initialized, initializing...") + if not dist.is_initialized(): + init_distributed_environment() + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ) + + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py new file mode 100644 index 0000000000000000000000000000000000000000..d54a4137f73ad539503c58fed65fcc414ec7196b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from functools import reduce + +import torch + + +class AdaStep: + """ + The Adastep class is designed to optimize the sampling process in a diffusion model, + """ + + def __init__(self, skip_thr=0.015, max_skip_steps=4, decay_ratio=0.99, + device="npu", forward_value=None, step_value=None): + """ + Args: + skip_thr (float): The threshold for determining whether to skip a step based on the change in latent variables. + Recommended values are between 0.01 and 0.015. Default is 0.015. + max_skip_steps (int): The maximum number of consecutive steps that can be skipped. + Recommended values are between 3 and 4. Default is 4. + decay_ratio (float): The decay ratio for the skip threshold, which is used to dynamically adjust + the threshold over time. Recommended values are between 0.95 and 0.99. Default is 0.99. + device (str): The device on which the computations will be performed. Default is "npu". + """ + + # recommand 0.01(skip around 35 steps) ~ 0.015(skip around 50 steps) + self.skip_thr = skip_thr + # recommand 3 ~ 4 + self.max_skip_steps = max_skip_steps + # recommand 0.95 ~ 0.99 + self.decay_ratio = decay_ratio + self.device = device + self.if_skip = self.max_skip_steps > 0 + self.reset_status() + + self.forwardretype = forward_value + self.stepretype = step_value + + def __call__(self, transformer, *model_args, **model_kwargs): + """ + Args: + transformer (Module): The Module that works as the DiT and returns the noise prediction. + model_args (tuple): The arguments to be passed to the transformer Module forword. + model_kwargs (dict): The keyword arguments to be passed to the transformer Module forword. + Returns: + The noise prediction from the transformer function. + """ + if self.if_skip and torch.all(self.skip_vote): + return self._return_output(self.skip_noise_pred, self.forwardretype) + + noise_pred = transformer(*model_args, **model_kwargs) + if not self.forwardretype: + if isinstance(noise_pred, tuple): + self.forwardretype = tuple + elif isinstance(noise_pred, torch.Tensor): + self.forwardretype = torch.Tensor + else: + raise (ValueError, "transformer need return a tuple which the first element is the result" + "or a Tensor. 
Otherwith please input the `forward_value`") + self.skip_noise_pred = self._get_input(noise_pred, self.forwardretype) + return noise_pred + + @staticmethod + def _get_input(input_value, inp_type): + if isinstance(inp_type, type): + if inp_type is tuple: + return input_value[0] + else: + return input_value + else: + return input_value[inp_type] + + @staticmethod + def _return_output(output, outptype): + if isinstance(outptype, type): + if outptype is tuple: + return (output,) + else: + return output + elif isinstance(outptype, str): + return {outptype: output} + else: + return (0,) * outptype + (output,) + + def set_param(self, skip_thr=None, max_skip_steps=None, decay_ratio=None, device=None): + """ + To set the parameters of the AdaStep class. + """ + self.skip_thr = skip_thr or self.skip_thr + self.max_skip_steps = max_skip_steps or self.max_skip_steps + self.decay_ratio = decay_ratio or self.decay_ratio + if device: + self.device = device + self.skip_vote.to(self.device) + self.if_skip = self.max_skip_steps > 0 + + def reset_status(self): + """ + Reset the status of the AdaStep class. + """ + self.skip_mask = [False] + self.skip_latents_diff = [] + self.skip_noise_pred = None + self.skip_prev_latents = 0 + self.skip_vote = torch.tensor([False], dtype=torch.bool, device=self.device) + + def update_strategy(self, latents, sequence_parallel=False, sp_group=None): + """ + Update the strategy for skipping steps based on the change in latents. + """ + if not self.stepretype: + if isinstance(latents, tuple): + self.stepretype = tuple + elif isinstance(latents, torch.Tensor): + self.stepretype = torch.Tensor + else: + raise (ValueError, "step need return a tuple which the first element is the result" + "or a Tensor. Otherwith please input the `step_value`") + if self.if_skip: + latents = self._get_input(latents, self.stepretype) + diff = latents - self.skip_prev_latents + self.skip_latents_diff.append(diff.abs().mean()) + if len(self.skip_latents_diff) >= 3: + self.skip_mask.append(self._estimate()) + + self.skip_prev_latents = latents + + mask_value = self.skip_mask[-1] + mask_value = torch.tensor([mask_value], dtype=torch.bool, device=self.device) + if sequence_parallel: + skip_vote = torch.zeros(torch.distributed.get_world_size(sp_group), + dtype=torch.bool, device=self.device) + torch.distributed.all_gather_into_tensor(skip_vote, mask_value, group=sp_group) + else: + skip_vote = mask_value + self.skip_vote = skip_vote + + def _estimate(self): + # `self.skip_latents_diff[-1]` refers to the most recent difference in latent variables. + cur_diff = self.skip_latents_diff[-1] + # `self.skip_latents_diff[-2]` refers to the second most recent difference in latent variables. + prev_diff = self.skip_latents_diff[-2] + # `self.skip_latents_diff[-3]` refers to the third most recent difference in latent variables. 
+ prev_prev_diff = self.skip_latents_diff[-3] + + self.skip_thr = self.skip_thr * self.decay_ratio + + if len(self.skip_mask) >= self.max_skip_steps and \ + all(self.skip_mask[-self.max_skip_steps:]): + return False + + if abs((cur_diff + prev_prev_diff) / 2 - prev_diff) <= prev_diff * self.skip_thr: + return True + return False + + +class SamplingOptm: + def __init__(self, pipe, dit_forward="transformer.forward", scheduler_step="scheduler.step", + forward_value=None, step_value=None, parallel=False, group=None, config=None): + self.parallel = parallel + self.group = group + self.skip = False + if config and config["method"] == "AdaStep": + self.skip = True + ditforward_lst = dit_forward.split(".") + schedulerstep_lst = scheduler_step.split(".") + self.pipe = pipe + + self.ori_forward = reduce(getattr, ditforward_lst, self.pipe) # getattr(self.pipe, )dit_forward.split(".") + self.forward = ditforward_lst.pop() + self.ori_dit = reduce(getattr, ditforward_lst, self.pipe) + + self.ori_step = reduce(getattr, schedulerstep_lst, self.pipe) + self.step = schedulerstep_lst.pop() + self.ori_scheduler = reduce(getattr, schedulerstep_lst, self.pipe) + + shik_thr = config.get("skip_thr", 0.015) + max_skip_steps = config.get("max_skip_steps", 4) + decay_ratio = config.get("decay_ratio", 0.99) + self.skip_strategy = AdaStep(skip_thr=shik_thr, max_skip_steps=max_skip_steps, decay_ratio=decay_ratio) + + def __enter__(self): + if self.skip: + self._sub_forward() + self._sub_step() + + def __exit__(self, t, v, trace): + if self.skip: + self._revert_forward() + self._revert_step() + + def _sub_forward(self): + @functools.wraps(self.ori_forward) + def _optm_forward(*args, **kwargs): + noise_pred = self.skip_strategy(self.ori_forward, *args, **kwargs) + return noise_pred + + setattr(self.ori_dit, self.forward, _optm_forward) + + def _sub_step(self): + @functools.wraps(self.ori_step) + def _optm_step(*args, **kwargs): + latents = self.ori_step(*args, **kwargs) + self.skip_strategy.update_strategy(latents, self.parallel, self.group) + return latents + + setattr(self.ori_scheduler, self.step, _optm_step) + + def _revert_forward(self): + setattr(self.ori_dit, self.forward, self.ori_forward) + + def _revert_step(self): + setattr(self.ori_scheduler, self.step, self.ori_step) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e638880fea8203f0f46272141a0e08bef97343 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py @@ -0,0 +1,147 @@ +from typing import List + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. 
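# --- Illustrative sketch, not part of the patch, for the SamplingOptm context
# manager defined in sampling_optm.py above: wrapping a pipeline's denoising loop
# so AdaStep can skip near-redundant DiT forwards. `pipe`, its attribute paths and
# the prompt are placeholders; only the config keys shown are read by SamplingOptm.
def _run_with_adastep(pipe, prompt):
    optm_config = {"method": "AdaStep", "skip_thr": 0.015,
                   "max_skip_steps": 4, "decay_ratio": 0.99}
    with SamplingOptm(pipe,
                      dit_forward="transformer.forward",
                      scheduler_step="scheduler.step",
                      config=optm_config):
        # Inside the context, transformer.forward is routed through AdaStep and
        # every scheduler.step result is fed back into update_strategy.
        return pipe(prompt)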
For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + assert ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. 
For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b5fc8a5932ea4197cb47c7b8c3eb08206e03fc --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py @@ -0,0 +1,420 @@ +import logging +import math +import re +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +import torch +import torch.distributed as dist +from packaging import version + +from vae_modules.ema import LitEma + +from utils.parallel_mgr import ( + is_vae_parallel_initialized, + initialize_vae_parallel, + get_vae_parallel_group, + get_vae_parallel_group_rank, +) + +from sgm.util import ( + instantiate_from_config, + get_obj_from_str, + default, +) +from vae_modules.cp_enc_dec import _conv_split, _conv_gather + +logpy = logging.getLogger(__name__) + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. 
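# --- Illustrative sketch, not part of the patch: the rank groups produced by the
# RankGenerator above for a 4-rank world with sp_degree=2 and cfg_degree=2
# (order "sp-cfg"); the listed values follow from tracing
# generate_masked_orthogonal_rank_groups.
def _show_rank_groups():
    gen = RankGenerator(sp=2, cfg=2, order="sp-cfg")
    print(gen.get_ranks("sp"))   # [[0, 1], [2, 3]]  -> sequence-parallel groups
    print(gen.get_ranks("cfg"))  # [[0, 2], [1, 3]]  -> cond/uncond pairs across sp groups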
+ """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + logpy.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + logpy.info(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. 
+ """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.automatic_optimization = False # pytorch lightning + self.encoder = instantiate_from_config(encoder_config) + self.decoder = instantiate_from_config(decoder_config) + self.regularization = instantiate_from_config(regularizer_config) + self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"}) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + self.additional_decode_keys = set(default(additional_decode_keys, [])) + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 1 and in channels-first + # format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = [] + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_discriminator_params(self) -> list: + params = [] + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) + return x + + def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log + + def get_param_groups( + self, parameter_names: List[List[str]], optimizer_args: List[dict] + ) -> Tuple[List[Dict[str, Any]], int]: + groups = [] + num_params = 0 + for names, args in zip(parameter_names, optimizer_args): + params = [] + for pattern_ in names: + pattern_params = [] + pattern = re.compile(pattern_) + for p_name, param in self.named_parameters(): + if re.match(pattern, p_name): + pattern_params.append(param) + num_params += param.numel() + if len(pattern_params) == 0: + logpy.warn(f"Did not find parameters for pattern {pattern_}") + params.extend(pattern_params) + groups.append({"params": params, **args}) + return groups, num_params + + +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ckpt_engine = 
kwargs.pop("ckpt_engine", None) + super().__init__( + encoder_config={ + "target": "sgm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "sgm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + + def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs: (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) + + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs: (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return x + + +class VideoAutoencodingEngine(AutoencodingEngine): + def __init__( + self, + ckpt_path: Union[None, str] = None, + ignore_keys: Union[Tuple, list] = (), + image_video_weights=[1, 1], + only_train_decoder=False, + context_parallel_size=0, + **kwargs, + ): + super().__init__(**kwargs) + self.context_parallel_size = dist.get_world_size() + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + return self.log_images(batch, additional_log_kwargs, **kwargs) + + def get_input(self, batch: dict) -> torch.Tensor: + if self.context_parallel_size > 0: + if not is_vae_parallel_initialized(): + initialize_vae_parallel(self.context_parallel_size) + + batch = batch[self.input_key] + + global_src_rank = get_vae_parallel_group_rank() * self.context_parallel_size + dist.broadcast(batch, src=global_src_rank, group=get_vae_parallel_group()) + + batch = _conv_split(batch, dim=2, kernel_size=1) + return batch + + return batch[self.input_key] + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing_keys, 
unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") + + +class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): + def __init__( + self, + cp_size=0, + *args, + **kwargs, + ): + self.cp_size = dist.get_world_size() + return super().__init__(*args, **kwargs) + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + input_cp: bool = False, + output_cp: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if return_reg_log: + z, reg_log = super().encode(x, return_reg_log, unregularized) + else: + z = super().encode(x, return_reg_log, unregularized) + + if return_reg_log: + return z, reg_log + return z + + def decode( + self, + z: torch.Tensor, + input_cp: bool = False, + output_cp: bool = False, + split_kernel_size: int = 1, + use_cp: bool = True, + **kwargs, + ): + if self.cp_size <= 2: + use_cp = False + if self.cp_size > 0 and use_cp and not input_cp: + if not is_vae_parallel_initialized(): + initialize_vae_parallel(self.cp_size) + + global_src_rank = get_vae_parallel_group_rank() * self.cp_size + dist.broadcast(z, src=global_src_rank, group=get_vae_parallel_group()) + z = _conv_split(z, dim=2, kernel_size=1) + + x = super().decode(z, use_cp=use_cp, **kwargs) + + if self.cp_size > 0 and use_cp and not output_cp: + x = _conv_gather(x, dim=2, kernel_size=1) + + return x + + def forward( + self, + x: torch.Tensor, + input_cp: bool = False, + latent_cp: bool = False, + output_cp: bool = False, + **additional_decode_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp) + dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs) + return z, dec, reg_log diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py new file mode 100644 index 0000000000000000000000000000000000000000..3aca7e7e8f86d32a75f3b7ea621c16303596b947 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py @@ -0,0 +1,1005 @@ +import math +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from beartype.typing import Union, Tuple, Optional, List +from einops import rearrange + +from utils.parallel_mgr import ( + get_vae_parallel_group, + get_vae_parallel_rank, + get_vae_parallel_world_size, + get_vae_parallel_group_rank, +) + +from vae_modules.utils import SafeConv3d as Conv3d + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def exists(v): + return v is not None + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
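# --- Illustrative sketch, not part of the patch: shape and layout of the
# sinusoidal embedding produced by get_timestep_embedding. The timestep values
# are arbitrary.
def _timestep_embedding_shape():
    t = torch.tensor([0, 250, 500, 999])
    emb = get_timestep_embedding(t, embedding_dim=512)
    # emb.shape == (4, 512): the first 256 channels are sin terms and the last 256
    # are cos terms over geometrically spaced frequencies; an odd embedding_dim is
    # zero-padded by one channel.
    return emb.shape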
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def leaky_relu(p=0.1): + return nn.LeakyReLU(p) + + +def _split(input_, dim): + cp_world_size = get_vae_parallel_world_size() + + if cp_world_size == 1: + return input_ + + cp_rank = get_vae_parallel_rank() + + inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + dim_size = input_.size()[dim] // cp_world_size + + input_list = torch.split(input_, dim_size, dim=dim) + output = input_list[cp_rank] + + if cp_rank == 0: + output = torch.cat([inpu_first_frame_, output], dim=dim) + output = output.contiguous() + + return output + + +def _gather(input_, dim): + cp_world_size = get_vae_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + + input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + + if cp_rank == 0: + input_ = torch.cat([input_first_frame_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +def _conv_split(input_, dim, kernel_size): + cp_world_size = get_vae_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + cp_rank = get_vae_parallel_rank() + + dim_size = (input_.size()[dim] - kernel_size) // cp_world_size + + if cp_rank == 0: + output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) + else: + output = input_.transpose(dim, 0)[ + cp_rank * dim_size + kernel_size: (cp_rank + 1) * dim_size + kernel_size + ].transpose(dim, 0) + output = output.contiguous() + + return output + + +def _conv_gather(input_, dim, kernel_size): + cp_world_size = get_vae_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + + # print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() + else: + input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + if cp_rank == 0: + input_ = torch.cat([input_first_kernel_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat 
already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output + + +def _pass_from_previous_rank(input_, dim, kernel_size): + # Bypass the function if kernel size is 1 + if kernel_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + cp_group_rank = get_vae_parallel_group_rank() + cp_world_size = get_vae_parallel_world_size() + + # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + global_rank = torch.distributed.get_rank() + global_world_size = torch.distributed.get_world_size() + + input_ = input_.transpose(0, dim) + + # pass from last rank + send_rank = global_rank + 1 + recv_rank = global_rank - 1 + if send_rank % cp_world_size == 0: + send_rank -= cp_world_size + if recv_rank % cp_world_size == cp_world_size - 1: + recv_rank += cp_world_size + + if cp_rank < cp_world_size - 1: + req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + if cp_rank > 0: + recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + + if cp_rank == 0: + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + else: + req_recv.wait() + input_ = torch.cat([recv_buffer, input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + + # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + return input_ + + +def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None): + # Bypass the function if kernel size is 1 + if kernel_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + cp_group_rank = get_vae_parallel_group_rank() + cp_world_size = get_vae_parallel_world_size() + + # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + global_rank = torch.distributed.get_rank() + global_world_size = torch.distributed.get_world_size() + + input_ = input_.transpose(0, dim) + + # pass from last rank + send_rank = global_rank + 1 + recv_rank = global_rank - 1 + if send_rank % cp_world_size == 0: + send_rank -= cp_world_size + if recv_rank % cp_world_size == cp_world_size - 1: + recv_rank += cp_world_size + + # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group) + # req_recv.wait() + recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + if cp_rank < cp_world_size - 1: + req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + if cp_rank > 0: + req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + + if cp_rank == 0: + if cache_padding is not None: + input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0) + else: + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + else: + req_recv.wait() + input_ = torch.cat([recv_buffer, input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + return input_ + + +def 
_drop_from_previous_rank(input_, dim, kernel_size): + input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim) + return input_ + + +class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _conv_split(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _conv_gather(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _ConvolutionPassFromPreviousRank(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _pass_from_previous_rank(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size, cache_padding): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding) + + @staticmethod + def backward(ctx, grad_output): + return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None + + +def conv_scatter_to_context_parallel_region(input_, dim, kernel_size): + return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size) + + +def conv_gather_from_context_parallel_region(input_, dim, kernel_size): + return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size) + + +def conv_pass_from_last_rank(input_, dim, kernel_size): + return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size) + + +def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding): + return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding) + + +class ContextParallelCausalConv3d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + time_pad = time_kernel_size - 1 + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_kernel_size = time_kernel_size + self.temporal_dim = 2 + + stride = (stride, stride, stride) + dilation = (1, 1, 1) + self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.cache_padding = None + + def forward(self, input_, clear_cache=True): + # if input_.shape[2] == 1: # handle image + # # first frame padding + # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) + # else: + # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size) + + # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + # input_parallel = F.pad(input_parallel, 
padding_2d, mode = 'constant', value = 0) + + # output_parallel = self.conv(input_parallel) + # output = output_parallel + # return output + if input_.shape[2] == 1: # handle image + # first frame padding + input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) + else: + input_parallel = fake_cp_pass_from_previous_rank( + input_, self.temporal_dim, self.time_kernel_size, self.cache_padding + ) + + del self.cache_padding + self.cache_padding = None + if not clear_cache: + cp_rank, cp_world_size = get_vae_parallel_rank(), get_vae_parallel_world_size() + global_rank = torch.distributed.get_rank() + if cp_world_size == 1: + self.cache_padding = ( + input_parallel[:, :, -self.time_kernel_size + 1:].contiguous().detach().clone().cpu() + ) + else: + if cp_rank == cp_world_size - 1: + torch.distributed.isend( + input_parallel[:, :, -self.time_kernel_size + 1:].contiguous(), + global_rank + 1 - cp_world_size, + group=get_vae_parallel_group(), + ) + if cp_rank == 0: + recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1:]).contiguous() + torch.distributed.recv( + recv_buffer, global_rank - 1 + cp_world_size, group=get_vae_parallel_group() + ) + self.cache_padding = recv_buffer.contiguous().detach().clone().cpu() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) + + output_parallel = self.conv(input_parallel) + output = output_parallel + return output + + +class ContextParallelGroupNorm(torch.nn.GroupNorm): + def forward(self, input_): + gather_flag = input_.shape[2] > 1 + if gather_flag: + input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1) + output = super().forward(input_) + if gather_flag: + output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1) + return output + + +def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D + if gather: + return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + else: + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialNorm3D(nn.Module): + def __init__( + self, + f_channels, + zq_channels, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", + gather=False, + **norm_layer_params, + ): + super().__init__() + if gather: + self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params) + else: + self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params) + # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + + self.add_conv = add_conv + if add_conv: + self.conv = ContextParallelCausalConv3d( + chan_in=zq_channels, + chan_out=zq_channels, + kernel_size=3, + ) + + self.conv_y = ContextParallelCausalConv3d( + chan_in=zq_channels, + chan_out=f_channels, + kernel_size=1, + ) + self.conv_b = ContextParallelCausalConv3d( + chan_in=zq_channels, + chan_out=f_channels, + kernel_size=1, + ) + + def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True): + if f.shape[2] > 1 and get_vae_parallel_rank() == 0 and fake_cp: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] + zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") + + zq_rest_splits = 
torch.split(zq_rest, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits + ] + + zq_rest = torch.cat(interpolated_splits, dim=1) + zq = torch.cat([zq_first, zq_rest], dim=2) + else: + f_size = f.shape[-3:] + + zq_splits = torch.split(zq, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits + ] + zq = torch.cat(interpolated_splits, dim=1) + + if self.add_conv: + zq = self.conv(zq, clear_cache=clear_fake_cp_cache) + + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +def Normalize3D( + in_channels, + zq_ch, + add_conv, + gather=False, +): + return SpatialNorm3D( + in_channels, + zq_ch, + gather=gather, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True, + ) + + +class Upsample3D(nn.Module): + def __init__( + self, + in_channels, + with_conv, + compress_time=False, + ): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.compress_time = compress_time + + def forward(self, x, fake_cp=True): + if self.compress_time and x.shape[2] > 1: + if get_vae_parallel_rank() == 0 and fake_cp: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") + + splits = torch.split(x_rest, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + ] + x_rest = torch.cat(interpolated_splits, dim=1) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + else: + splits = torch.split(x, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + ] + x = torch.cat(interpolated_splits, dim=1) + + else: + # only interpolate 2D + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + + splits = torch.split(x, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + ] + x = torch.cat(interpolated_splits, dim=1) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.with_conv: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class DownSample3D(nn.Module): + def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None): + super().__init__() + self.with_conv = with_conv + if out_channels is None: + out_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.compress_time = compress_time + + def forward(self, x, fake_cp=True): + if self.compress_time and x.shape[2] > 1: + h, w = x.shape[-2:] + x = rearrange(x, "b c t h w -> (b h w) c t") + + if get_vae_parallel_rank() == 0 and fake_cp: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + splits = torch.split(x_rest, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits + ] + x_rest = torch.cat(interpolated_splits, dim=1) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = 
rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + else: + # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + splits = torch.split(x, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits + ] + x = torch.cat(interpolated_splits, dim=1) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + else: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class ContextParallelResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + gather_norm=False, + normalization=Normalize, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = normalization( + in_channels, + zq_ch=zq_ch, + add_conv=add_conv, + gather=gather_norm, + ) + + self.conv1 = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=out_channels, + kernel_size=3, + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = normalization( + out_channels, + zq_ch=zq_ch, + add_conv=add_conv, + gather=gather_norm, + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = ContextParallelCausalConv3d( + chan_in=out_channels, + chan_out=out_channels, + kernel_size=3, + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=out_channels, + kernel_size=3, + ) + else: + self.nin_shortcut = Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True): + h = x + + if zq is not None: + h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp) + else: + h = self.norm1(h) + + h = nonlinearity(h) + h = self.conv1(h, clear_cache=clear_fake_cp_cache) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp) + else: + h = self.norm2(h) + + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h, clear_cache=clear_fake_cp_cache) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache) + else: + x = self.nin_shortcut(x) + + return x + h + + +class ContextParallelEncoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # log2 of 
temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + self.conv_in = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=self.ch, + kernel_size=3, + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + temb_channels=self.temb_ch, + gather_norm=gather_norm, + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) + else: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + gather_norm=gather_norm, + ) + + self.mid.block_2 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + gather_norm=gather_norm, + ) + + # end + self.norm_out = Normalize(block_in, gather=gather_norm) + + self.conv_out = ContextParallelCausalConv3d( + chan_in=block_in, + chan_out=2 * z_channels if double_z else z_channels, + kernel_size=3, + ) + + def forward(self, x, use_cp=True): + global _USE_CP + if torch.distributed.get_world_size() <= 2: + use_cp = False + _USE_CP = use_cp + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + return h + + +class ContextParallelDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + self.conv_in = ContextParallelCausalConv3d( + chan_in=z_channels, + chan_out=block_in, + kernel_size=3, + ) + + # 
middle + self.mid = nn.Module() + self.mid.block_1 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + + self.mid.block_2 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) + self.up.insert(0, up) + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) + + self.conv_out = ContextParallelCausalConv3d( + chan_in=block_in, + chan_out=out_ch, + kernel_size=3, + ) + + def forward(self, z, clear_fake_cp_cache=True, use_cp=True): + global _USE_CP + if torch.distributed.get_world_size() <= 2: + use_cp = False + _USE_CP = use_cp + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + h = self.conv_in(z, clear_cache=clear_fake_cp_cache) + + # middle + h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) + h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h, fake_cp=use_cp) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) + h = nonlinearity(h) + h = self.conv_out(h, clear_cache=clear_fake_cp_cache) + + return h + + def get_last_layer(self): + return self.conv_out.conv.weight diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1f7606c2c9b68ebd2302215a9e08f9f31ed8ab --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + 
torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py new file mode 100644 index 0000000000000000000000000000000000000000..205bd4a9415f0eb350b04508545a25362c6d0449 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py @@ -0,0 +1,108 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + # x = self.mean + self.std * torch.randn(self.mean.shape).to( + # device=self.parameters.device + # ) + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +class AbstractRegularizer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + raise NotImplementedError() + + @abstractmethod + def get_trainable_parameters(self) -> Any: + raise NotImplementedError() + + +class IdentityRegularizer(AbstractRegularizer): + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, dict() + + def get_trainable_parameters(self) -> Any: + yield from () + + +def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally
+    encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+    avg_probs = encodings.mean(0)
+    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+    cluster_use = torch.sum(avg_probs > 0)
+    return perplexity, cluster_use
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+    def __init__(self, sample: bool = True):
+        super().__init__()
+        self.sample = sample
+
+    def get_trainable_parameters(self) -> Any:
+        yield from ()
+
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+        log = dict()
+        posterior = DiagonalGaussianDistribution(z)
+        if self.sample:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        kl_loss = posterior.kl()
+        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+        log["kl_loss"] = kl_loss
+        return z, log
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..26402ac9facb0710e9c30e3cc113191f35a7b761
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py
@@ -0,0 +1,26 @@
+import torch
+
+
+class SafeConv3d(torch.nn.Conv3d):
+    def forward(self, input):
+        # Rough activation size in GiB, assuming 2-byte (fp16/bf16) elements.
+        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3
+        if memory_count > 2:
+            # Fall back to a chunked convolution along the temporal axis (dim=2, NCTHW).
+            kernel_size = self.kernel_size[0]
+            part_num = int(memory_count / 2) + 1
+            input_chunks = torch.chunk(input, part_num, dim=2)  # NCTHW
+            if kernel_size > 1:
+                # Overlap chunks by kernel_size - 1 frames so no temporal window is lost.
+                input_chunks = [input_chunks[0]] + [
+                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2)
+                    for i in range(1, len(input_chunks))
+                ]
+
+            output_chunks = []
+            for input_chunk in input_chunks:
+                output_chunks.append(super(SafeConv3d, self).forward(input_chunk))
+            output = torch.cat(output_chunks, dim=2)
+            return output
+        else:
+            return super(SafeConv3d, self).forward(input)
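A minimal standalone sketch of the chunk-and-overlap idea used by SafeConv3d above: the chunk count, tensor shape, lack of padding, and tolerance below are illustrative assumptions (SafeConv3d itself derives its part count from the ~2 GiB estimate), but the splitting rule is the same one in the diff, namely splitting the input along the temporal axis and prepending the last kernel_size - 1 frames of the previous chunk to each piece.

    import torch

    conv = torch.nn.Conv3d(4, 4, kernel_size=3)  # no padding, as in the chunked path
    x = torch.randn(1, 4, 16, 8, 8)              # N, C, T, H, W

    full = conv(x)                               # reference: one full pass

    kernel_t = conv.kernel_size[0]
    chunks = list(torch.chunk(x, 4, dim=2))      # 4 chunks is an arbitrary choice here
    # Prepend the last kernel_t - 1 frames of the previous chunk so every temporal
    # window the kernel slides over is still available after splitting.
    chunks = [chunks[0]] + [
        torch.cat((chunks[i - 1][:, :, -kernel_t + 1:], chunks[i]), dim=2)
        for i in range(1, len(chunks))
    ]
    pieces = torch.cat([conv(c) for c in chunks], dim=2)

    print(torch.allclose(full, pieces, atol=1e-5))  # expected: True

Because each chunk carries the trailing kernel_t - 1 frames of its predecessor, every temporal window appears in exactly one chunk, so the concatenated chunk outputs match the unchunked convolution while only one chunk's activations need to be resident at a time.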