From b9e82a90ce22d34dfa6c17a75379562e5b5411a9 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 23 Jan 2025 00:48:11 +0000 Subject: [PATCH 1/9] initial release --- .../foundation/cogvideox-sat/.gitignore | 4 + .../foundation/cogvideox-sat/arguments.py | 122 ++ .../cogvideox-sat/configs/cogvideox_5b.yaml | 139 ++ .../cogvideox-sat/configs/inference.yaml | 16 + .../foundation/cogvideox-sat/configs/test.txt | 1 + .../cogvideox-sat/diffusion_video.py | 140 ++ .../cogvideox-sat/dit_video_concat.py | 816 +++++++++++ .../foundation/cogvideox-sat/inference.sh | 7 + .../foundation/cogvideox-sat/sample_video.py | 286 ++++ .../foundation/cogvideox-sat/sgm/__init__.py | 3 + .../cogvideox-sat/sgm/modules/__init__.py | 6 + .../cogvideox-sat/sgm/modules/attention.py | 397 ++++++ .../sgm/modules/autoencoding/__init__.py | 0 .../sgm/modules/autoencoding/temporal_ae.py | 243 ++++ .../sgm/modules/diffusionmodules/__init__.py | 6 + .../sgm/modules/diffusionmodules/denoiser.py | 72 + .../diffusionmodules/denoiser_scaling.py | 60 + .../diffusionmodules/denoiser_weighting.py | 24 + .../modules/diffusionmodules/discretizer.py | 126 ++ .../sgm/modules/diffusionmodules/guiders.py | 87 ++ .../sgm/modules/diffusionmodules/lora.py | 362 +++++ .../sgm/modules/diffusionmodules/model.py | 653 +++++++++ .../modules/diffusionmodules/openaimodel.py | 1248 +++++++++++++++++ .../sgm/modules/diffusionmodules/sampling.py | 763 ++++++++++ .../diffusionmodules/sampling_utils.py | 155 ++ .../diffusionmodules/sigma_sampling.py | 80 ++ .../sgm/modules/diffusionmodules/util.py | 328 +++++ .../sgm/modules/diffusionmodules/wrappers.py | 41 + .../sgm/modules/encoders/__init__.py | 0 .../sgm/modules/encoders/modules.py | 275 ++++ .../sgm/modules/video_attention.py | 285 ++++ .../foundation/cogvideox-sat/sgm/util.py | 318 +++++ .../cogvideox-sat/utils/cache_mgr.py | 174 +++ .../foundation/cogvideox-sat/utils/comm.py | 182 +++ .../cogvideox-sat/utils/group_coordinator.py | 631 +++++++++ .../cogvideox-sat/utils/parallel_mgr.py | 332 +++++ .../cogvideox-sat/utils/sampling_optm.py | 226 +++ .../foundation/cogvideox-sat/utils/utils.py | 147 ++ .../cogvideox-sat/vae_modules/autoencoder.py | 435 ++++++ .../cogvideox-sat/vae_modules/cp_enc_dec.py | 983 +++++++++++++ .../cogvideox-sat/vae_modules/ema.py | 82 ++ .../cogvideox-sat/vae_modules/regularizers.py | 108 ++ .../cogvideox-sat/vae_modules/utils.py | 23 + 43 files changed, 10386 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py create mode 100755 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py create mode 100644 
MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore new file mode 100644 index 0000000000..ffe2a35470 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/.gitignore @@ 
-0,0 +1,4 @@ +*__pycache__/ +.idea +output* +*.so \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py new file mode 100644 index 0000000000..1844991a54 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py @@ -0,0 +1,122 @@ +import argparse +import os +import torch +import torch_npu +import json +import warnings +import omegaconf +from omegaconf import OmegaConf +from sat.helpers import print_rank0 + + +def add_model_config_args(parser): + """Model arguments""" + + group = parser.add_argument_group("model", "model configuration") + group.add_argument("--base", type=str, nargs="*", help="config for input and saving") + group.add_argument("--device", type=int, default=-1) + group.add_argument("--log-image", type=bool, default=True) + return parser + + +def add_sampling_config_args(parser): + """Sampling configurations""" + + group = parser.add_argument_group("sampling", "Sampling Configurations") + group.add_argument("--output-dir", type=str, default="samples") + group.add_argument("--input-dir", type=str, default=None) + group.add_argument("--input-type", type=str, default="cli") + group.add_argument("--input-file", type=str, default="input.txt") + group.add_argument("--final-size", type=int, default=2048) + group.add_argument("--sdedit", action="store_true") + group.add_argument("--grid-num-rows", type=int, default=1) + group.add_argument("--force-inference", action="store_true") + group.add_argument("--lcm_steps", type=int, default=None) + group.add_argument("--sampling_image_height", type=int, default=480) + group.add_argument("--sampling_image_width", type=int, default=720) + group.add_argument("--sampling-num-frames", type=int, default=32) + group.add_argument("--sampling-fps", type=int, default=8) + group.add_argument("--only-save-latents", type=bool, default=False) + group.add_argument("--only-log-video-latents", type=bool, default=False) + group.add_argument("--latent-channels", type=int, default=32) + group.add_argument("--image2video", action="store_true") + + return parser + + +def add_inference_args(parser): + """Inference arguments.""" + + group = parser.add_argument_group('inference', 'inference configurations') + group.add_argument('--batch-size', type=int, default=4, + help='batch size on a single GPU. batch-size * world_size = total batch_size.') + group.add_argument('--mode', type=str, + default='pretrain', + choices=['pretrain', # from_scratch / load ckpt for continue pretraining. + 'finetune', # finetuning, auto-warmup 100 iters, new exp name. + 'inference' # don't train. + ], + help='what type of task to use, will influence auto-warmup, exp name, iteration') + group.add_argument('--load', type=str, default=None, + help='Path to a directory containing a model checkpoint.') + group.add_argument('--seed', type=int, default=1234, help='random seed') + group.add_argument('--zero-stage', type=int, default=0, choices=[0, 1, 2, 3], + help='deepspeed ZeRO stage. 
0 means no ZeRO.') + group.add_argument('--fp16', action='store_true', help='Run model in fp16 mode') + group.add_argument('--bf16', action='store_true', help='Run model in bf16 mode') + + return parser + + +def get_args(args_list=None, parser=None): + """Parse all the args.""" + if parser is None: + parser = argparse.ArgumentParser(description="sat") + else: + assert isinstance(parser, argparse.ArgumentParser) + parser = add_model_config_args(parser) + parser = add_sampling_config_args(parser) + parser = add_inference_args(parser) + + args = parser.parse_args(args_list) + args = process_config_to_args(args) + + args.npu = torch.npu.is_available() + + args.rank = int(os.getenv("RANK", "0")) + args.world_size = int(os.getenv("WORLD_SIZE", "1")) + args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun + + if args.device == -1: + if torch.npu.device_count() == 0: + args.device = "cpu" + else: + args.device = "npu" + + if args.rank == 0: + print_rank0("using world size: {}".format(args.world_size)) + + assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16." + + return args + +def process_config_to_args(args): + """Fetch args from only --base""" + + configs = [OmegaConf.load(cfg) for cfg in args.base] + config = OmegaConf.merge(*configs) + + args_config = config.pop("args", OmegaConf.create()) + for key in args_config: + if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig): + arg = OmegaConf.to_object(args_config[key]) + else: + arg = args_config[key] + if hasattr(args, key): + setattr(args, key, arg) + + if "model" in config: + model_config = config.pop("model", OmegaConf.create()) + args.model_config = model_config + + return args diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml new file mode 100644 index 0000000000..73c2ea63eb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/cogvideox_5b.yaml @@ -0,0 +1,139 @@ +model: + scale_factor: 0.7 + disable_first_stage_autocast: true + log_keys: + - txt + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + quantize_c_noise: False + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization + params: + shift_scale: 1.0 # different from cogvideox_2b_infer.yaml + + network_config: + target: dit_video_concat.DiffusionTransformer + params: + time_embed_dim: 512 + elementwise_affine: True + num_frames: 49 + time_compressed_rate: 4 + latent_width: 90 + latent_height: 60 + num_layers: 42 # different from cogvideox_2b_infer.yaml + patch_size: 2 + in_channels: 16 + out_channels: 16 + hidden_size: 3072 # different from cogvideox_2b_infer.yaml + adm_in_channels: 256 + num_attention_heads: 48 # different from cogvideox_2b_infer.yaml + + transformer_args: + checkpoint_activations: True + vocab_size: 1 + max_sequence_length: 64 + layernorm_order: pre + skip_init: false + model_parallel_size: 1 + is_decoder: false + + modules: + pos_embed_config: + target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml + params: + hidden_size_head: 64 + text_length: 226 + + patch_embed_config: + target: 
dit_video_concat.ImagePatchEmbeddingMixin
+          params:
+            text_hidden_size: 4096
+
+        adaln_layer_config:
+          target: dit_video_concat.AdaLNMixin
+          params:
+            qk_ln: True
+
+        final_layer_config:
+          target: dit_video_concat.FinalLayerMixin
+
+  conditioner_config:
+    target: sgm.modules.GeneralConditioner
+    params:
+      emb_models:
+        - is_trainable: false
+          input_key: txt
+          ucg_rate: 0.1
+          target: sgm.modules.encoders.modules.FrozenT5Embedder
+          params:
+            model_dir: "/home/p30061221/CogVideoX-5B-SAT/5b-cogvideo"
+            max_length: 226
+
+  first_stage_config:
+    target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
+    params:
+      cp_size: 1
+      ckpt_path: "/home/p30061221/CogVideoX-5B-SAT/5b-cogvideo/vae/3d-vae.pt"
+      ignore_keys: [ 'loss' ]
+
+      loss_config:
+        target: torch.nn.Identity
+
+      regularizer_config:
+        target: vae_modules.regularizers.DiagonalGaussianRegularizer
+
+      encoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
+        params:
+          double_z: true
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: True
+
+      decoder_config:
+        target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
+        params:
+          double_z: True
+          z_channels: 16
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [ 1, 2, 2, 4 ]
+          attn_resolutions: [ ]
+          num_res_blocks: 3
+          dropout: 0.0
+          gather_norm: False
+
+  sampler_config:
+    target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
+    params:
+      num_steps: 50
+      verbose: True
+
+      discretization_config:
+        target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
+        params:
+          shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
+
+      guider_config:
+        target: sgm.modules.diffusionmodules.guiders.DynamicCFG
+        params:
+          scale: 6
+          exp: 5
+          num_steps: 50
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml
new file mode 100644
index 0000000000..f688264012
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml
@@ -0,0 +1,16 @@
+args:
+  image2video: False # True for image2video, False for text2video
+  latent_channels: 16
+  mode: inference
+  load: "/home/p30061221/CogVideoX-5B-SAT/5b-cogvideo/transformer" # This is for Full model without lora adapter
+  batch_size: 1
+  input_type: txt
+  input_file: configs/test.txt
+  sampling_image_height: 480
+  sampling_image_width: 720
+  sampling_num_frames: 13 # Must be 13, 11 or 9
+  sampling_fps: 8
+  fp16: True # For CogVideoX-2B
+  # bf16: False # For CogVideoX-5B and CogVideoX-5B-I2V
+  output_dir: outputs/
+  force_inference: True
\ No newline at end of file
diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt
new file mode 100644
index 0000000000..fbb3261770
--- /dev/null
+++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/test.txt
@@ -0,0 +1 @@
+In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.
\ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py new file mode 100644 index 0000000000..989feb8d0e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py @@ -0,0 +1,140 @@ +import random + +import math +from typing import Any, Dict, List, Tuple, Union +from omegaconf import ListConfig +import torch.nn.functional as F + +from sat.helpers import print_rank0 +import torch +from torch import nn + +from sgm.modules import UNCONDITIONAL_CONFIG +from sgm.modules.autoencoding.temporal_ae import VideoDecoder +from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from sgm.util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, +) +import gc + + +class SATVideoDiffusionEngine(nn.Module): + def __init__(self, args, **kwargs): + super().__init__() + + model_config = args.model_config + # model args preprocess + network_config = model_config.get("network_config", None) + network_wrapper = model_config.get("network_wrapper", None) + denoiser_config = model_config.get("denoiser_config", None) + sampler_config = model_config.get("sampler_config", None) + conditioner_config = model_config.get("conditioner_config", None) + first_stage_config = model_config.get("first_stage_config", None) + scale_factor = model_config.get("scale_factor", 1.0) + latent_input = model_config.get("latent_input", False) + disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False) + compile_model = model_config.get("compile_model", False) + en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None) + + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + if args.fp16: + dtype = torch.float16 + dtype_str = "fp16" + elif args.bf16: + dtype = torch.bfloat16 + dtype_str = "bf16" + else: + dtype = torch.float32 + dtype_str = "fp32" + self.dtype = dtype + self.dtype_str = dtype_str + + network_config["params"]["dtype"] = dtype_str + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model, dtype=dtype + ) + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None + self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG)) + self._init_first_stage(first_stage_config) + self.latent_input = latent_input + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.device = args.device + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("npu", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * 
n_samples], **kwargs) + all_out.append(out) + out = torch.cat(all_out, dim=0) + return out + + @torch.no_grad() + def encode_first_stage(self, x, batch): + frame = x.shape[2] + + if frame > 1 and self.latent_input: + x = x.permute(0, 2, 1, 3, 4).contiguous() + return x * self.scale_factor # already encoded + + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("npu", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples]) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + return z + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + prefix=None, + concat_images=None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) + if hasattr(self, "seeded_noise"): + randn = self.seeded_noise(randn) + + if prefix is not None: + randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1) + + scale = None + scale_emb = None + + denoiser = lambda input, sigma, c, **addtional_model_inputs: self.denoiser( + self.model, input, sigma, c, concat_images=concat_images, **addtional_model_inputs + ) + + samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb) + samples = samples.to(self.dtype) + return samples diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py new file mode 100644 index 0000000000..6c2ac4e68c --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py @@ -0,0 +1,816 @@ +from functools import partial +from einops import rearrange, repeat +import numpy as np +import math + +import torch +import torch_npu +from torch import nn +import torch.nn.functional as F +torch.ops.load_library("./sgm/modules/libPTAExtensionOPS.so") + +from sat.model.base_model import BaseModel, non_conflict +from sat.model.mixins import BaseMixin +from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default +from sat.mpu.layers import ColumnParallelLinear +from sgm.util import instantiate_from_config + +from sgm.modules.diffusionmodules.openaimodel import Timestep +from sgm.modules.diffusionmodules.util import ( + linear, + timestep_embedding, +) +from sat.ops.layernorm import LayerNorm, RMSNorm + + +class ImagePatchEmbeddingMixin(BaseMixin): + def __init__( + self, + in_channels, + hidden_size, + patch_size, + bias=True, + text_hidden_size=None, + ): + super().__init__() + self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) + if text_hidden_size is not None: + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + else: + self.text_proj = None + + def word_embedding_forward(self, input_ids, **kwargs): + # now is 3d patch + images = kwargs["images"] # (b,t,c,h,w) + B, T = images.shape[:2] + emb = images.view(-1, *images.shape[2:]) + emb = self.proj(emb) # ((b t),d,h/2,w/2) + emb = emb.view(B, T, *emb.shape[1:]) + emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d) + emb = rearrange(emb, "b t n d -> b (t n) d") + + if self.text_proj is not None: + text_emb = self.text_proj(kwargs["encoder_outputs"]) + emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d) + + emb = emb.contiguous() + return emb # (b,n_t+t*n_i,d) + 
+ def reinit(self, parent_model=None): + w = self.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.proj.bias, 0) + del self.transformer.word_embeddings + +def attention_fn_fa_score(query_layer, key_layer, value_layer, attention_mask, + attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs): + # expand head dim to query dim, if necessary + # only useful for multi-query attention + num_query_heads = query_layer.shape[1] # [b, np, s, hn] + attn_output = torch_npu.npu_fusion_attention(query_layer, key_layer, value_layer, + input_layout="BNSD", + scale=1 / math.sqrt(query_layer.shape[-1]), + head_num=num_query_heads + )[0] + return attn_output + +def get_3d_sincos_pos_embed( + embed_dim, + grid_height, + grid_width, + t_size, + cls_token=False, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, +): + """ + grid_size: int of the grid height and width + t_size: int of the temporal size + return: + pos_embed: [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + assert embed_dim % 4 == 0 + embed_dim_spatial = embed_dim // 4 * 3 + embed_dim_temporal = embed_dim // 4 + + # spatial + grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation + grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_height, grid_width]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) + + # temporal + grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) + + # concate: [T, H, W] order + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4] + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3] + + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) + # pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D] + + return pos_embed # [T, H*W, D] + + +def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_height, dtype=np.float32) + grid_w = np.arange(grid_width, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_height, grid_width]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + 
pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Basic2DPositionEmbeddingMixin(BaseMixin): + def __init__(self, height, width, compressed_num_frames, hidden_size, text_length=0): + super().__init__() + self.height = height + self.width = width + self.spatial_length = height * width + self.pos_embedding = nn.Parameter( + torch.zeros(1, int(text_length + self.spatial_length), int(hidden_size)), requires_grad=False + ) + + def position_embedding_forward(self, position_ids, **kwargs): + return self.pos_embedding + + def reinit(self, parent_model=None): + del self.transformer.position_embeddings + pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width) + self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + +class Basic3DPositionEmbeddingMixin(BaseMixin): + def __init__( + self, + height, + width, + compressed_num_frames, + hidden_size, + text_length=0, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + ): + super().__init__() + self.height = height + self.width = width + self.text_length = text_length + self.compressed_num_frames = compressed_num_frames + self.spatial_length = height * width + self.num_patches = height * width * compressed_num_frames + self.pos_embedding = nn.Parameter( + torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False + ) + self.height_interpolation = height_interpolation + self.width_interpolation = width_interpolation + self.time_interpolation = time_interpolation + + def position_embedding_forward(self, position_ids, **kwargs): + if kwargs["images"].shape[1] == 1: + return self.pos_embedding[:, : self.text_length + self.spatial_length] + + return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]] + + def reinit(self, parent_model=None): + del self.transformer.position_embeddings + pos_embed = get_3d_sincos_pos_embed( + self.pos_embedding.shape[-1], + self.height, + self.width, + self.compressed_num_frames, + height_interpolation=self.height_interpolation, + width_interpolation=self.width_interpolation, + time_interpolation=self.time_interpolation, + ) + pos_embed = torch.from_numpy(pos_embed).float() + pos_embed = rearrange(pos_embed, "t n d -> (t n) d") + self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed) + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = 
list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +class Rotary3DPositionEmbeddingMixin(BaseMixin): + def __init__( + self, + height, + width, + compressed_num_frames, + hidden_size, + hidden_size_head, + text_length, + theta=10000, + rot_v=False, + learnable_pos_embed=False, + ): + super().__init__() + self.rot_v = rot_v + + dim_t = hidden_size_head // 4 + dim_h = hidden_size_head // 8 * 3 + dim_w = hidden_size_head // 8 * 3 + + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t)) + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w)) + + grid_t = torch.arange(compressed_num_frames, dtype=torch.float32) + grid_h = torch.arange(height, dtype=torch.float32) + grid_w = torch.arange(width, dtype=torch.float32) + + freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t) + freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h) + freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w) + + freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + freqs = rearrange(freqs, "t h w d -> (t h w) d") + + freqs = freqs.contiguous() + freqs_sin = freqs.sin() + freqs_cos = freqs.cos() + self.register_buffer("freqs_sin", freqs_sin) + self.register_buffer("freqs_cos", freqs_cos) + + self.text_length = text_length + if learnable_pos_embed: + num_patches = height * width * compressed_num_frames + text_length + self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True) + else: + self.pos_embedding = None + + def rotary(self, t, **kwargs): + seq_len = t.shape[2] + freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0) + freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0) + return torch.ops.mindie.rope_mindie_sd(t, freqs_cos, freqs_sin, mode=1) + + def position_embedding_forward(self, position_ids, **kwargs): + if self.pos_embedding is not None: + return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]] + else: + return None + + @non_conflict + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_fa_score, + **kwargs, + ): + attention_fn_default = HOOKS_DEFAULT["attention_fn"] + + query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :]) + key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :]) + if self.rot_v: + value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :]) + + return attention_fn_fa_score( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): + """ + x: (N, T/2 * S, patch_size**3 * C) + imgs: (N, T, H, W, C) + """ + if rope_position_ids is not None: + assert 
NotImplementedError + # do pix2struct unpatchify + L = x.shape[1] + x = x.reshape(shape=(x.shape[0], L, p, p, c)) + x = torch.einsum("nlpqc->ncplq", x) + imgs = x.reshape(shape=(x.shape[0], c, p, L * p)) + else: + b = x.shape[0] + imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p) + + return imgs + + +class FinalLayerMixin(BaseMixin): + def __init__( + self, + hidden_size, + time_embed_dim, + patch_size, + out_channels, + latent_width, + latent_height, + elementwise_affine, + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.out_channels = out_channels + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) + + self.spatial_length = latent_width * latent_height // patch_size**2 + self.latent_width = latent_width + self.latent_height = latent_height + + def final_forward(self, logits, **kwargs): + x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d) + shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + + return unpatchify( + x, + c=self.out_channels, + p=self.patch_size, + w=self.latent_width // self.patch_size, + h=self.latent_height // self.patch_size, + rope_position_ids=kwargs.get("rope_position_ids", None), + **kwargs, + ) + + def reinit(self, parent_model=None): + nn.init.xavier_uniform_(self.linear.weight) + nn.init.constant_(self.linear.bias, 0) + + +class SwiGLUMixin(BaseMixin): + def __init__(self, num_layers, in_features, hidden_features, bias=False): + super().__init__() + self.w2 = nn.ModuleList( + [ + ColumnParallelLinear( + in_features, + hidden_features, + gather_output=False, + bias=bias, + module=self, + name="dense_h_to_4h_gate", + ) + for i in range(num_layers) + ] + ) + + def mlp_forward(self, hidden_states, **kw_args): + x = hidden_states + origin = self.transformer.layers[kw_args["layer_id"]].mlp + x1 = origin.dense_h_to_4h(x) + x2 = self.w2[kw_args["layer_id"]](x) + hidden = origin.activation_func(x2) * x1 + x = origin.dense_4h_to_h(hidden) + return x + + +class AdaLNMixin(BaseMixin): + def __init__( + self, + width, + height, + hidden_size, + num_layers, + time_embed_dim, + compressed_num_frames, + qk_ln=True, + hidden_size_head=None, + elementwise_affine=True, + ): + super().__init__() + self.num_layers = num_layers + self.width = width + self.height = height + self.compressed_num_frames = compressed_num_frames + + self.adaLN_modulations = nn.ModuleList( + [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)] + ) + + self.qk_ln = qk_ln + if qk_ln: + self.query_layernorm_list = nn.ModuleList( + [ + LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine) + for _ in range(num_layers) + ] + ) + self.key_layernorm_list = nn.ModuleList( + [ + LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine) + for _ in range(num_layers) + ] + ) + + def layer_forward( + self, + hidden_states, + mask, + *args, + **kwargs, + ): + text_length = kwargs["text_length"] + # hidden_states (b,(n_t+t*n_i),d) + text_hidden_states = hidden_states[:, :text_length] # (b,n,d) + img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d) + + layer = self.transformer.layers[kwargs["layer_id"]] 
+ adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]] + + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + text_shift_msa, + text_scale_msa, + text_gate_msa, + text_shift_mlp, + text_scale_mlp, + text_gate_mlp, + ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1) + gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = ( + gate_msa.unsqueeze(1), + gate_mlp.unsqueeze(1), + text_gate_msa.unsqueeze(1), + text_gate_mlp.unsqueeze(1), + ) + + # self full attention (b,(t n),d) + img_attention_input = layer.input_layernorm(img_hidden_states) + text_attention_input = layer.input_layernorm(text_hidden_states) + img_attention_input = modulate(img_attention_input, shift_msa, scale_msa) + text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa) + + attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d) + attention_output = layer.attention(attention_input, mask, **kwargs) + text_attention_output = attention_output[:, :text_length] # (b,n,d) + img_attention_output = attention_output[:, text_length:] # (b,(t n),d) + if self.transformer.layernorm_order == "sandwich": + text_attention_output = layer.third_layernorm(text_attention_output) + img_attention_output = layer.third_layernorm(img_attention_output) + img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d) + text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d) + + # mlp (b,(t n),d) + img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d) + text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d) + img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp) + text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp) + mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d + mlp_output = layer.mlp(mlp_input, **kwargs) + img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d) + text_mlp_output = mlp_output[:, :text_length] # language (b,n,d) + if self.transformer.layernorm_order == "sandwich": + text_mlp_output = layer.fourth_layernorm(text_mlp_output) + img_mlp_output = layer.fourth_layernorm(img_mlp_output) + + img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d) + text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d) + + hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d) + return hidden_states + + def reinit(self, parent_model=None): + for layer in self.adaLN_modulations: + nn.init.constant_(layer[-1].weight, 0) + nn.init.constant_(layer[-1].bias, 0) + + @non_conflict + def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_fa_score, + **kwargs, + ): + if self.qk_ln: + query_layernorm = self.query_layernorm_list[kwargs["layer_id"]] + key_layernorm = self.key_layernorm_list[kwargs["layer_id"]] + query_layer = query_layernorm(query_layer) + key_layer = key_layernorm(key_layer) + + return old_impl( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + + +str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +class 
DiffusionTransformer(BaseModel): + def __init__( + self, + transformer_args, + num_frames, + time_compressed_rate, + latent_width, + latent_height, + patch_size, + in_channels, + out_channels, + hidden_size, + num_layers, + num_attention_heads, + elementwise_affine, + time_embed_dim=None, + num_classes=None, + modules={}, + input_time="adaln", + adm_in_channels=None, + parallel_output=True, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + use_SwiGLU=False, + use_RMSNorm=False, + zero_init_y_embed=False, + **kwargs, + ): + self.latent_width = latent_width + self.latent_height = latent_height + self.patch_size = patch_size + self.num_frames = num_frames + self.time_compressed_rate = time_compressed_rate + self.spatial_length = latent_width * latent_height // patch_size**2 + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.model_channels = hidden_size + self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size + self.num_classes = num_classes + self.adm_in_channels = adm_in_channels + self.input_time = input_time + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.is_decoder = transformer_args.is_decoder + self.elementwise_affine = elementwise_affine + self.height_interpolation = height_interpolation + self.width_interpolation = width_interpolation + self.time_interpolation = time_interpolation + self.inner_hidden_size = hidden_size * 4 + self.zero_init_y_embed = zero_init_y_embed + try: + self.dtype = str_to_dtype[kwargs.pop("dtype")] + except: + self.dtype = torch.float32 + + if use_SwiGLU: + kwargs["activation_func"] = F.silu + elif "activation_func" not in kwargs: + approx_gelu = nn.GELU(approximate="tanh") + kwargs["activation_func"] = approx_gelu + + if use_RMSNorm: + kwargs["layernorm"] = RMSNorm + else: + kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6) + + transformer_args.num_layers = num_layers + transformer_args.hidden_size = hidden_size + transformer_args.num_attention_heads = num_attention_heads + transformer_args.parallel_output = parallel_output + super().__init__(args=transformer_args, transformer=None, **kwargs) + + module_configs = modules + self._build_modules(module_configs) + + if use_SwiGLU: + self.add_mixin( + "swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True + ) + + def _build_modules(self, module_configs): + model_channels = self.hidden_size + # time_embed_dim = model_channels * 4 + time_embed_dim = self.time_embed_dim + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(self.num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + elif self.num_classes == "sequential": + assert self.adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(self.adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + if self.zero_init_y_embed: + 
nn.init.constant_(self.label_emb[0][2].weight, 0) + nn.init.constant_(self.label_emb[0][2].bias, 0) + else: + raise ValueError() + + pos_embed_config = module_configs["pos_embed_config"] + self.add_mixin( + "pos_embed", + instantiate_from_config( + pos_embed_config, + height=self.latent_height // self.patch_size, + width=self.latent_width // self.patch_size, + compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, + hidden_size=self.hidden_size, + ), + reinit=True, + ) + + patch_embed_config = module_configs["patch_embed_config"] + self.add_mixin( + "patch_embed", + instantiate_from_config( + patch_embed_config, + patch_size=self.patch_size, + hidden_size=self.hidden_size, + in_channels=self.in_channels, + ), + reinit=True, + ) + if self.input_time == "adaln": + adaln_layer_config = module_configs["adaln_layer_config"] + self.add_mixin( + "adaln_layer", + instantiate_from_config( + adaln_layer_config, + height=self.latent_height // self.patch_size, + width=self.latent_width // self.patch_size, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1, + hidden_size_head=self.hidden_size // self.num_attention_heads, + time_embed_dim=self.time_embed_dim, + elementwise_affine=self.elementwise_affine, + ), + ) + else: + raise NotImplementedError + + final_layer_config = module_configs["final_layer_config"] + self.add_mixin( + "final_layer", + instantiate_from_config( + final_layer_config, + hidden_size=self.hidden_size, + patch_size=self.patch_size, + out_channels=self.out_channels, + time_embed_dim=self.time_embed_dim, + latent_width=self.latent_width, + latent_height=self.latent_height, + elementwise_affine=self.elementwise_affine, + ), + reinit=True, + ) + + if "lora_config" in module_configs: + lora_config = module_configs["lora_config"] + self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True) + + return + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + b, t, d, h, w = x.shape + if x.dtype != self.dtype: + x = x.to(self.dtype) + + # This is not use in inference + if "concat_images" in kwargs and kwargs["concat_images"] is not None: + if kwargs["concat_images"].shape[0] != x.shape[0]: + concat_images = kwargs["concat_images"].repeat(2, 1, 1, 1, 1) + else: + concat_images = kwargs["concat_images"] + x = torch.cat([x, concat_images], dim=2) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + # assert y.shape[0] == x.shape[0] + assert x.shape[0] % y.shape[0] == 0 + y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) + emb = emb + self.label_emb(y) + + kwargs["seq_length"] = t * h * w // (self.patch_size**2) + kwargs["images"] = x + kwargs["emb"] = emb + kwargs["encoder_outputs"] = context + kwargs["text_length"] = context.shape[1] + + kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype) + output = super().forward(**kwargs)[0] + return output diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh new file mode 100755 index 0000000000..a1ad48c40e --- /dev/null +++ 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh @@ -0,0 +1,7 @@ +#! /bin/bash +export TASK_QUEUE_ENABLE=2 +export HCCL_OP_EXPANSION_MODE="AIV" +torchrun --nnodes=1 --nproc_per_node 1 --master_port 29503 \ + sample_video.py \ + --base configs/cogvideox_5b.yaml configs/inference.yaml \ + --seed 42 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py new file mode 100644 index 0000000000..2d845211df --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -0,0 +1,286 @@ +import os +import math +import argparse +from arguments import get_args +from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_state, get_sequence_parallel_rank +from sat.arguments import set_random_seed + +from typing import List, Union +from tqdm import tqdm +from omegaconf import ListConfig +import imageio + +import torch +import torch_npu +import numpy as np +from einops import rearrange +import torchvision.transforms as TT + + +from sat.model.base_model import get_model +from sat.training.model_io import load_checkpoint + +from diffusion_video import SATVideoDiffusionEngine + +from torchvision.transforms.functional import center_crop, resize +from torchvision.transforms import InterpolationMode +from PIL import Image + + +def read_from_cli(): + cnt = 0 + try: + while True: + x = input("Please input English text (Ctrl-D quit): ") + yield x.strip(), cnt + cnt += 1 + except EOFError as e: + pass + + +def read_from_file(p): + with open(p, "r") as fin: + cnt = -1 + for l in fin: + cnt += 1 + yield l.strip(), cnt + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="npu"): + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() + batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None): + os.makedirs(save_path, exist_ok=True) + + for i, vid in enumerate(video_batch): + gif_frames = [] + for frame in vid: + frame = rearrange(frame, "c h w -> h w c") + frame = (255.0 * frame).cpu().numpy().astype(np.uint8) + gif_frames.append(frame) + now_save_path = os.path.join(save_path, f"{i:06d}.mp4") + with imageio.get_writer(now_save_path, fps=fps) as writer: + for frame in gif_frames: + writer.append_data(frame) + + +def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"): + if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: + arr = resize( + arr, + size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], + interpolation=InterpolationMode.BICUBIC, + ) + else: + arr = resize( + arr, + size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], + interpolation=InterpolationMode.BICUBIC, + ) + + h, w = arr.shape[2], arr.shape[3] + arr = arr.squeeze(0) + + delta_h = h - image_size[0] 
+ delta_w = w - image_size[1] + + if reshape_mode == "random" or reshape_mode == "none": + top = np.random.randint(0, delta_h + 1) + left = np.random.randint(0, delta_w + 1) + elif reshape_mode == "center": + top, left = delta_h // 2, delta_w // 2 + else: + raise NotImplementedError + arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) + return arr + + +def sampling_main(args, model_cls): + torch.npu.config.allow_internal_format = False + if args.world_size > 1: + sp_degree = args.world_size // 2 + parallel_config = ParallelConfig(sp_degree=sp_degree, use_cfg_parallel=True, world_size=args.world_size) + init_parallel_env(parallel_config) + else: + parallel_config = ParallelConfig(sp_degree=1, use_cfg_parallel=False, world_size=args.world_size) + init_parallel_env(parallel_config) + + rank = get_sequence_parallel_rank() if get_sequence_parallel_state() else 0 + args.seed = args.seed + rank + set_random_seed(args.seed) + + if isinstance(model_cls, type): + model = get_model(args, model_cls) + else: + model = model_cls + + load_checkpoint(model, args) + model.eval() + + if args.input_type == "cli": + data_iter = read_from_cli() + elif args.input_type == "txt": + data_iter = read_from_file(args.input_file) + else: + raise NotImplementedError + + image_size = [args.sampling_image_height, args.sampling_image_width] + + if args.image2video: + chained_trainsforms = [] + chained_trainsforms.append(TT.ToTensor()) + transform = TT.Compose(chained_trainsforms) + + sample_func = model.sample + T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, args.sampling_fps + num_samples = [1] + force_uc_zero_embeddings = ["txt"] + + with torch.no_grad(): + for text, cnt in tqdm(data_iter): + if args.image2video: + text, image_path = text.split("@@") + assert os.path.exists(image_path), image_path + image = Image.open(image_path).convert("RGB") + image = transform(image).unsqueeze(0).to("npu") + image = resize_for_rectangle_crop(image, image_size, reshape_mode="center").unsqueeze(0) + image = image * 2.0 - 1.0 + image = image.unsqueeze(2).to(torch.bfloat16) + image = model.encode_first_stage(image, None) + image = image.permute(0, 2, 1, 3, 4).contiguous() + pad_shape = (image.shape[0], T - 1, C, H // F, W // F) + image = torch.concat([image, torch.zeros(pad_shape).to(image.device).to(image.dtype)], dim=1) + else: + image = None + + value_dict = { + "prompt": text, + "negative_prompt": "", + "num_frames": torch.tensor(T).unsqueeze(0), + } + + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("npu"), (c, uc)) + + if args.image2video and image is not None: + c["concat"] = image + uc["concat"] = image + + first_stage_model = model.first_stage_model + for index in range(args.batch_size): + count = 5 + experimental_config = torch_npu.profiler._ExperimentalConfig( + export_type=torch_npu.profiler.ExportType.Text, + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + 
profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False + ) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], #默认CPU、NPU均打开 + with_stack=True, #采集torch op的函数调用栈的开关,默认关闭 + record_shapes=True, #采集torch op的input shape和input type的开关,默认关闭 + profile_memory=True, #采集memory相关数据的开关,默认关闭 + schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1), + experimental_config=experimental_config, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_unet") #导出tensorboard可呈现的数据形式 + ) as prof: + for i in range(count): + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H // F, W // F), + ) + prof.step() + samples_z = sample_func( + c, + uc=uc, + batch_size=1, + shape=(T, C, H // F, W // F), + ) + samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() + latent = 1.0 / model.scale_factor * samples_z + + # Decode latent serial to save GPU memory + recons = [] + loop_num = (T - 1) // 2 + for i in range(loop_num): + if i == 0: + start_frame, end_frame = 0, 3 + else: + start_frame, end_frame = i * 2 + 1, i * 2 + 3 + if i == loop_num - 1: + clear_fake_cp_cache = True + else: + clear_fake_cp_cache = False + with torch.no_grad(): + recon = first_stage_model.decode( + latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache + ) + + recons.append(recon) + + recon = torch.cat(recons, dim=2).to(torch.float32) + samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() + + save_path = os.path.join( + args.output_dir, str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:120], str(index) + ) + if rank == 0: + save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) + + finalize_parallel_env() + + +if __name__ == "__main__": + py_parser = argparse.ArgumentParser(add_help=False) + known, args_list = py_parser.parse_known_args() + + args = get_args(args_list) + args = argparse.Namespace(**vars(args), **vars(known)) + + args.model_config.first_stage_config.params.cp_size = 1 + args.model_config.network_config.params.transformer_args.model_parallel_size = 1 + args.model_config.network_config.params.transformer_args.checkpoint_activations = False + + sampling_main(args, model_cls=SATVideoDiffusionEngine) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py new file mode 100644 index 0000000000..5f8ebd9225 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/__init__.py @@ -0,0 +1,3 @@ +from .util import get_configs_path, instantiate_from_config + +__version__ = "0.1.0" diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py new file mode 100644 index 0000000000..0db1d7716a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/__init__.py @@ -0,0 +1,6 @@ +from .encoders.modules import GeneralConditioner + +UNCONDITIONAL_CONFIG = { + "target": "sgm.modules.GeneralConditioner", + "params": {"emb_models": []}, +} diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py new file mode 100644 index 0000000000..7b7d7dfb38 --- /dev/null +++ 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py @@ -0,0 +1,397 @@ +import math +from inspect import isfunction +from typing import Any, Optional + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from packaging import version +from torch import nn + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + 
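# Illustrative sketch, not from the patch: what the GEGLU feed-forward variant
# defined above computes. One projection yields 2 * dim_out features, split into a
# value half and a gate half; the value is scaled by GELU(gate). Sizes are arbitrary.
import torch
import torch.nn.functional as F
from torch import nn

proj = nn.Linear(64, 128 * 2)           # dim_in=64, dim_out=128 (toy sizes)
x = torch.randn(2, 10, 64)              # (batch, tokens, dim_in)
value, gate = proj(x).chunk(2, dim=-1)  # two (2, 10, 128) halves
print((value * F.gelu(gate)).shape)     # torch.Size([2, 10, 128])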
query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.backend = backend + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat(k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + v = repeat(v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + out = torch_npu.npu_fusion_attention(q, k, v, atten_mask=mask, + input_layout="BNSD", + scale=1 / math.sqrt(q.shape[-1]), + head_num=h)[0] + + del q, k, v + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update({"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}) + + # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): + x = ( + self.attn1( + 
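# Hedged stand-in, not from the patch: off-NPU, the fused attention call used by
# CrossAttention.forward above corresponds to ordinary scaled-dot-product attention
# over tensors in (batch, heads, seq, head_dim) layout -- the "BNSD" input_layout it
# requests -- with the default 1/sqrt(head_dim) scaling.
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 77, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(out.shape)  # torch.Size([2, 8, 77, 64]), like npu_fusion_attention(...)[0]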
self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + ) + + x + ) + x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerSingleLayerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context) + x + x = self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + ): + super().__init__() + print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") + from omegaconf import ListConfig + + if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + print( + f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " + f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." + ) + # depth does not match context dims. 
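# Small sketch with toy shapes (not part of the patch) of the reshape round trip
# SpatialTransformer performs: image-like features are flattened to a token sequence
# for the transformer blocks, then folded back to (b, c, h, w) for the residual add.
import torch
from einops import rearrange

x = torch.randn(1, 320, 32, 48)                # (b, c, h, w)
tokens = rearrange(x, "b c h w -> b (h w) c")  # (1, 1536, 320) token sequence
x_back = rearrange(tokens, "b (h w) c -> b c h w", h=32, w=48)
print(torch.equal(x, x_back))                  # True: the round trip is lossless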
+ assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py new file mode 100644 index 0000000000..890bcfa0c9 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py @@ -0,0 +1,243 @@ +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +from sgm.modules.diffusionmodules.model import ( + AttnBlock, + Decoder, + ResnetBlock, +) +from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding +from sgm.modules.video_attention import VideoTransformerBlock +from sgm.util import partialclass + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif 
self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + if timesteps is None: + timesteps = self.timesteps + + b, c, h, w = x.shape + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps, skip_video=False): + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class VideoBlock(AttnBlock): + def __init__(self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_video=False): + if skip_video: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return 
torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + if attn_type == "vanilla": + assert attn_kwargs is None + return partialclass(VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy) + else: + return NotImplementedError() + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return self.conv_out.time_mix_conv.weight if not skip_time_mix else self.conv_out.weight + + def _make_attn(self) -> Callable: + if self.time_mode not in ["conv-only", "only-last-conv"]: + return partialclass( + make_time_attn, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_attn() + + def _make_conv(self) -> Callable: + if self.time_mode != "attn-only": + return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + else: + return Conv2DWrapper + + def _make_resblock(self) -> Callable: + if self.time_mode not in ["attn-only", "only-last-conv"]: + return partialclass( + VideoResBlock, + video_kernel_size=self.video_kernel_size, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_resblock() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000..fccebf954f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/__init__.py @@ -0,0 +1,6 @@ +from .denoiser import Denoiser +from .discretizer import Discretization +from .model import Decoder, Encoder, Model +from .openaimodel import UNetModel +from .sampling import BaseDiffusionSampler +from .wrappers import OpenAIWrapper diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py new file mode 100644 index 0000000000..3cc01e36f8 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py @@ -0,0 +1,72 @@ +from typing import Dict, Union + +import torch +import torch.nn as nn + +from ...util import append_dims, instantiate_from_config + + +class Denoiser(nn.Module): + def __init__(self, weighting_config, scaling_config): + super().__init__() + + self.weighting = instantiate_from_config(weighting_config) + self.scaling = instantiate_from_config(scaling_config) + + def possibly_quantize_sigma(self, sigma): + return sigma + + def 
possibly_quantize_c_noise(self, c_noise): + return c_noise + + def w(self, sigma): + return self.weighting(sigma) + + def forward( + self, + network: nn.Module, + input: torch.Tensor, + sigma: torch.Tensor, + cond: Dict, + **additional_model_inputs, + ) -> torch.Tensor: + sigma = self.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, input.ndim) + c_skip, c_out, c_in, c_noise = self.scaling(sigma, **additional_model_inputs) + c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + input * c_skip + + +class DiscreteDenoiser(Denoiser): + def __init__( + self, + weighting_config, + scaling_config, + num_idx, + discretization_config, + do_append_zero=False, + quantize_c_noise=True, + flip=True, + ): + super().__init__(weighting_config, scaling_config) + sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + self.sigmas = sigmas + # self.register_buffer("sigmas", sigmas) + self.quantize_c_noise = quantize_c_noise + + def sigma_to_idx(self, sigma): + dists = sigma - self.sigmas.to(sigma.device)[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def idx_to_sigma(self, idx): + return self.sigmas.to(idx.device)[idx] + + def possibly_quantize_sigma(self, sigma): + return self.idx_to_sigma(self.sigma_to_idx(sigma)) + + def possibly_quantize_c_noise(self, c_noise): + if self.quantize_c_noise: + return self.sigma_to_idx(c_noise) + else: + return c_noise diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py new file mode 100644 index 0000000000..2cb9643014 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from typing import Any, Tuple + +import torch + + +class DenoiserScaling(ABC): + @abstractmethod + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pass + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EpsScaling: + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = torch.ones_like(sigma, device=sigma.device) + c_out = -sigma + c_in = 1 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScaling: + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScalingWithEDMcNoise(DenoiserScaling): + def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + 
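# Sketch, not from the patch: Denoiser.forward above applies the preconditioning
#   D(x, sigma) = c_out * network(c_in * x, c_noise) + c_skip * x,
# with coefficients supplied by the configured scaling object. For the v-prediction
# scaling (VScaling) at sigma = 1.0 the coefficients work out as follows.
import torch

sigma = torch.tensor(1.0)
c_skip = 1.0 / (sigma**2 + 1.0)
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
print(c_skip.item(), c_out.item(), c_in.item())  # 0.5 -0.7071... 0.7071...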
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class VideoScaling: # similar to VScaling + def __call__( + self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = alphas_cumprod_sqrt + c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5) + c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device) + c_noise = additional_model_inputs["idx"].clone() + return c_skip, c_out, c_in, c_noise diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py new file mode 100644 index 0000000000..b8b03ca58f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -0,0 +1,24 @@ +import torch + + +class UnitWeighting: + def __call__(self, sigma): + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting: + def __call__(self, sigma): + return sigma**-2.0 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py new file mode 100644 index 0000000000..a86b7d8dfb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py @@ -0,0 +1,126 @@ +from abc import abstractmethod +from functools import partial + +import numpy as np +import torch + +from ...modules.diffusionmodules.util import make_beta_schedule +from ...util import append_zero + + +def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] + + +class Discretization: + def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False): + if return_idx: + sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx) + else: + sigmas = self.get_sigmas(n, device=device, return_idx=return_idx) + sigmas = append_zero(sigmas) if do_append_zero else sigmas + if return_idx: + return sigmas if not flip else torch.flip(sigmas, (0,)), idx + else: + return sigmas if not flip else torch.flip(sigmas, (0,)) + + @abstractmethod + def get_sigmas(self, n, device): + pass + + +class EDMDiscretization(Discretization): + def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def get_sigmas(self, n, device="cpu"): + ramp = torch.linspace(0, 1, n, device=device) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return sigmas + + +class LegacyDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + ): + super().__init__() + self.num_timesteps = num_timesteps + betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, 
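# Toy illustration with made-up alpha_bar values (not from the patch) of the
# conversion used by LegacyDDPMDiscretization.get_sigmas:
#   sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t),
# flipped so the schedule runs from high noise to low noise.
import torch

alphas_cumprod = torch.tensor([0.999, 0.9, 0.5, 0.1, 0.01])
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
print(torch.flip(sigmas, (0,)))  # tensor([9.9499, 3.0000, 1.0000, 0.3333, 0.0316])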
linear_end=linear_end) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + def get_sigmas(self, n, device="cpu"): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + return torch.flip(sigmas, (0,)) # sigma_t: 14.4 -> 0.029 + + +class ZeroSNRDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) + keep_start=False, + post_shift=False, + ): + super().__init__() + if keep_start and not post_shift: + linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start) + self.num_timesteps = num_timesteps + betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + # SNR shift + if not post_shift: + self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod) + + self.post_shift = post_shift + self.shift_scale = shift_scale + + def get_sigmas(self, n, device="cpu", return_idx=False): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, dtype=torch.float32, device=device) + alphas_cumprod = to_torch(alphas_cumprod) + alphas_cumprod_sqrt = alphas_cumprod.sqrt() + alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone() + alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone() + + alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T + alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T) + + if self.post_shift: + alphas_cumprod_sqrt = ( + alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) + ) ** 0.5 + + if return_idx: + return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps + else: + return torch.flip(alphas_cumprod_sqrt, (0,)) # sqrt(alpha_t): 0 -> 0.99 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py new file mode 100644 index 0000000000..7ce657c392 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py @@ -0,0 +1,87 @@ +import logging +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union +from functools import partial +import math + +import torch +from einops import rearrange, repeat + +from ...util import append_dims, default, instantiate_from_config + + +class Guider(ABC): + @abstractmethod + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + pass + + def prepare_inputs(self, x: torch.Tensor, s: float, c: Dict, uc: Dict) -> Tuple[torch.Tensor, float, Dict]: + pass + + +class VanillaCFG: + """ + implements 
parallelized CFG + """ + + def __init__(self, scale, dyn_thresh_config=None): + self.scale = scale + scale_schedule = lambda scale, sigma: scale # independent of step + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, + ) + ) + + def __call__(self, x, sigma, scale=None): + x_u, x_c = x.chunk(2) + scale_value = default(scale, self.scale_schedule(sigma)) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class DynamicCFG(VanillaCFG): + def __init__(self, scale, exp, num_steps, dyn_thresh_config=None): + super().__init__(scale, dyn_thresh_config) + scale_schedule = ( + lambda scale, sigma, step_index: 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2 + ) + self.scale_schedule = partial(scale_schedule, scale) + self.dyn_thresh = instantiate_from_config( + default( + dyn_thresh_config, + {"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"}, + ) + ) + + def __call__(self, x, sigma, step_index, scale=None): + x_u, x_c = x.chunk(2) + scale_value = self.scale_schedule(sigma, step_index.item()) + x_pred = self.dyn_thresh(x_u, x_c, scale_value) + return x_pred + + +class IdentityGuider: + def __call__(self, x, sigma): + return x + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + c_out[k] = c[k] + + return x, s, c_out diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py new file mode 100644 index 0000000000..7ccd72a19f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py @@ -0,0 +1,362 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import nn + + +class LoRALinearLayer(nn.Module): + def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
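# Hedged sketch, not from the patch, of the guidance math behind the guiders above.
# VanillaCFG splits the doubled batch into unconditional/conditional halves and, via
# its thresholding helper, is expected to apply the usual combination
# x_u + scale * (x_c - x_u); DynamicCFG additionally ramps the scale with a cosine
# curve over the sampling steps (formula copied from its scale_schedule; the exp
# value below is a placeholder, in practice it comes from the config).
import math
import torch

def cfg_combine(x_u, x_c, scale):
    return x_u + scale * (x_c - x_u)

def dynamic_scale(scale, step_index, num_steps, exp=5.0):
    return 1 + scale * (1 - math.cos(math.pi * (step_index / num_steps) ** exp)) / 2

x_u, x_c = torch.zeros(1, 4), torch.ones(1, 4)
print(cfg_combine(x_u, x_c, 6.0))  # tensor([[6., 6., 6., 6.]])
print(dynamic_scale(6.0, 25, 50))  # scale partway through a 50-step schedule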
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + self.out_features = out_features + self.in_features = in_features + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRAConv2dLayer(nn.Module): + def __init__( + self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + ): + super().__init__() + + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + # according to the official kohya_ss trainer kernel_size are always fixed for the up layer + # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 + self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) + + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states): + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + + +class LoRACompatibleConv(nn.Conv2d): + """ + A convolutional layer that can be used with LoRA. 
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, scale: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + self.scale = scale + + def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale=1.0): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) + fusion = fusion.reshape((w_orig.shape)) + fused_weight = w_orig + (lora_scale * fusion) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.data.dtype, fused_weight.data.device + + self.w_up = self.w_up.to(device=device).float() + self.w_down = self.w_down.to(device).float() + + fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) + fusion = fusion.reshape((fused_weight.shape)) + unfused_weight = fused_weight.float() - (self._lora_scale * fusion) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states, scale: float = None): + if scale is None: + scale = self.scale + if self.lora_layer is None: + # make sure to the functional Conv2D function as otherwise torch.compile's graph will break + # see: https://github.com/huggingface/diffusers/pull/4315 + return F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + else: + return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + + +class LoRACompatibleLinear(nn.Linear): + """ + A Linear layer that can be used with LoRA. 
+ """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, scale: float = 1.0, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + self.scale = scale + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale=1.0): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (hasattr(self, "w_up") and hasattr(self, "w_down")): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states, scale: float = None): + if scale is None: + scale = self.scale + if self.lora_layer is None: + out = super().forward(hidden_states) + return out + else: + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + return out + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoRACompatibleLinear, + LoRACompatibleConv, + LoRALinearLayer, + LoRAConv2dLayer, + ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = (module for module in model.modules() if module.__class__.__name__ in ancestor_class) + else: + # this, incase you want to naively iterate over all modules. 
+ ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + flag = False + while path: + try: + parent = parent.get_submodule(path.pop(0)) + except: + flag = True + break + if flag: + continue + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any([isinstance(parent, _class) for _class in exclude_children_of]): + continue + # Otherwise, yield it + yield parent, name, module + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = None, + rank: int = 4, + scale: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + lora_layer = LoRALinearLayer( + in_features=_child_module.in_features, + out_features=_child_module.out_features, + rank=rank, + ) + _tmp = ( + LoRACompatibleLinear( + _child_module.in_features, + _child_module.out_features, + lora_layer=lora_layer, + scale=scale, + ) + .to(weight.dtype) + .to(weight.device) + ) + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + lora_layer = LoRAConv2dLayer( + in_features=_child_module.in_channels, + out_features=_child_module.out_channels, + rank=rank, + kernel_size=_child_module.kernel_size, + stride=_child_module.stride, + padding=_child_module.padding, + ) + _tmp = ( + LoRACompatibleConv( + _child_module.in_channels, + _child_module.out_channels, + kernel_size=_child_module.kernel_size, + stride=_child_module.stride, + padding=_child_module.padding, + lora_layer=lora_layer, + scale=scale, + ) + .to(weight.dtype) + .to(weight.device) + ) + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + else: + continue + + _module._modules[name] = _tmp + # print('injecting lora layer to', _module, name) + + return + + +def update_lora_scale( + model: nn.Module, + target_module: Set[str] = None, + scale: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv] + ): + _child_module.scale = scale + + return diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000..2aa75acb1b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py @@ -0,0 +1,653 @@ +# pytorch_diffusion + derived encoder decoder +import math +from typing import Any, Callable, Optional + +import numpy as np +import torch +import torch_npu +import torch.nn as nn +from einops import rearrange +from packaging import version + +from ..attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. 
+ This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = 
torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)) + h_ = torch_npu.npu_fusion_attention(q, k, v, atten_mask=None, + input_layout="BNSD", + scale=1 / math.sqrt(c), + head_num=1)[0] + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.attention_op: Optional[Any] = None + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) + return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + 
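# Sketch with toy numbers (not part of the patch) of how the Model/Encoder classes
# here derive per-level channels and resolutions from ch and ch_mult: level i maps
# ch * in_ch_mult[i] -> ch * ch_mult[i] channels, halving the resolution after every
# level except the last.
ch, ch_mult, resolution = 128, (1, 2, 4, 8), 256
in_ch_mult = (1,) + tuple(ch_mult)
curr_res = resolution
for i_level, mult in enumerate(ch_mult):
    print(i_level, ch * in_ch_mult[i_level], "->", ch * mult, "at", curr_res)
    if i_level != len(ch_mult) - 1:
        curr_res //= 2
# prints: 0 128->128 at 256, 1 128->256 at 128, 2 256->512 at 64, 3 512->1024 at 32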
if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), 
temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = 
give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + make_attn_cls = self._make_attn() + make_resblock_cls = self._make_resblock() + make_conv_cls = self._make_conv() + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) + self.mid.block_2 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + make_resblock_cls( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn_cls(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def _make_attn(self) -> Callable: + return make_attn + + def _make_resblock(self) -> Callable: + return ResnetBlock + + def _make_conv(self) -> Callable: + return torch.nn.Conv2d + + def get_last_layer(self, **kwargs): + return self.conv_out.weight + + def forward(self, z, **kwargs): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000..f61ebe5812 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1248 @@ +import os +import math +from abc import abstractmethod +from functools import partial +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch as th +import torch.nn as nn +import 
torch.nn.functional as F +from einops import rearrange + +from ...modules.attention import SpatialTransformer +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from ...modules.diffusionmodules.lora import inject_trainable_lora_extended, update_lora_scale +from ...modules.video_attention import SpatialVideoTransformer +from ...util import default, exists + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + context: Optional[th.Tensor] = None, + image_only_indicator: Optional[th.Tensor] = None, + time_context: Optional[int] = None, + num_video_frames: Optional[int] = None, + ): + from ...modules.diffusionmodules.video_model import VideoResBlock + + for layer in self: + module = layer + + if isinstance(module, TimestepBlock) and not isinstance(module, VideoResBlock): + x = layer(x, emb) + elif isinstance(module, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, SpatialVideoTransformer): + x = layer( + x, + context, + time_context, + num_video_frames, + image_only_indicator, + ) + elif isinstance(module, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
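+ :param out_channels: if specified, the number of output channels of the optional convolution.
+ :param third_up: if True and dims is 3, the temporal (third) axis is upsampled by 2 as well;
+ otherwise only the inner two spatial dimensions are upsampled.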
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.third_up = third_up + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + t_factor = 1 if not self.third_up else 2 + x = F.interpolate( + x, + (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest", + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) + if use_conv: + print(f"Building a Downsample layer with {dims} dims.") + print( + f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " + f"kernel-size: 3, stride: {stride}, padding: {padding}" + ) + if dims == 3: + print(f" --> Downsampling third axis (time): {third_down}") + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, Iterable): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels + if self.skip_t_emb: + print(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd( + dims, + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = th.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, **kwargs): + # TODO add crossframe attention and use mixed checkpoint + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +str_to_dtype = {"fp32": th.float32, "fp16": th.float16, "bf16": th.bfloat16} + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + dtype="fp32", + lora_init=False, + lora_rank=4, + lora_scale=1.0, + lora_weight_path=None, + ): + super().__init__() + from omegaconf.listconfig import ListConfig + + self.dtype = str_to_dtype[dtype] + + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert num_heads != -1, "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + elif isinstance(transformer_depth, ListConfig): + transformer_depth = list(transformer_depth) + transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1]) + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + # self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + if use_fp16: + print("WARNING: use_fp16 was dropped and has no effect anymore.") + # self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + assert use_fairscale_checkpoint != use_checkpoint or not (use_checkpoint or use_fairscale_checkpoint) + + self.use_fairscale_checkpoint = False + checkpoint_wrapper_fn = ( + partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu) + if self.use_fairscale_checkpoint + else lambda x: x + ) + + time_embed_dim = model_channels * 4 + self.time_embed = checkpoint_wrapper_fn( + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = checkpoint_wrapper_fn( + nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch 
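+ # record this level's channel count so the matching output block can size its skip connection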
+ input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ), + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + checkpoint_wrapper_fn( + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + checkpoint_wrapper_fn( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if not use_spatial_transformer + else checkpoint_wrapper_fn( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + checkpoint_wrapper_fn( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + ) + if resblock_updown + else 
Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + ) + if self.predict_codebook_ids: + self.id_predictor = checkpoint_wrapper_fn( + nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + ) + + if lora_init: + self._init_lora(lora_rank, lora_scale, lora_weight_path) + + def _init_lora(self, rank, scale, ckpt_dir=None): + inject_trainable_lora_extended(self, target_replace_module=None, rank=rank, scale=scale) + + if ckpt_dir is not None: + with open(os.path.join(ckpt_dir, "latest")) as latest_file: + latest = latest_file.read().strip() + ckpt_path = os.path.join(ckpt_dir, latest, "mp_rank_00_model_states.pt") + print(f"loading lora from {ckpt_path}") + sd = th.load(ckpt_path)["module"] + sd = { + key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model") + } + self.load_state_dict(sd, strict=False) + + def _update_scale(self, scale): + update_lora_scale(self, scale) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + assert False, "not supported anymore. what the f*** are you doing?" + else: + return self.out(h) + + +class NoTimeUNetModel(UNetModel): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + timesteps = th.zeros_like(timesteps) + return super().forward(x, timesteps, context, y, **kwargs) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. 
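+ :param pool: pooling strategy applied to the final features, one of
+ "adaptive", "attention", "spatial" or "spatial_v2".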
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + ) + elif pool == 
"spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + + +if __name__ == "__main__": + + class Dummy(nn.Module): + def __init__(self, in_channels=3, model_channels=64): + super().__init__() + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))] + ) + + model = UNetModel( + use_checkpoint=True, + image_size=64, + in_channels=4, + out_channels=4, + model_channels=128, + attention_resolutions=[4, 2], + num_res_blocks=2, + channel_mult=[1, 2, 4], + num_head_channels=64, + use_spatial_transformer=False, + use_linear_in_transformer=True, + transformer_depth=1, + legacy=False, + ).npu() + x = th.randn(11, 4, 64, 64).npu() + t = th.randint(low=0, high=10, size=(11,), device="npu") + o = model(x, t) + print("done.") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py new file mode 100644 index 0000000000..55e889bfd3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py @@ -0,0 +1,763 @@ +""" +Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py +""" + +from typing import Dict, Union + +import torch +from omegaconf import ListConfig, OmegaConf +from tqdm import tqdm + +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) +from ...util import append_dims, default, instantiate_from_config +from ...util import SeededNoise + +from .guiders import DynamicCFG + +DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + + +class BaseDiffusionSampler: + def __init__( + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: str = "npu", + ): + self.num_steps = num_steps + self.discretization = instantiate_from_config(discretization_config) 
+ self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) + self.verbose = verbose + self.device = device + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device) + uc = default(uc, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]).float() + + return x, s_in, sigmas, num_sigmas, cond, uc + + def denoise(self, x, denoiser, sigma, cond, uc): + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_sigma_gen(self, num_sigmas): + sigma_generator = range(num_sigmas - 1) + if self.verbose: + print("#" * 30, " Sampling setting ", "#" * 30) + print(f"Sampler: {self.__class__.__name__}") + print(f"Discretization: {self.discretization.__class__.__name__}") + print(f"Guider: {self.guider.__class__.__name__}") + sigma_generator = tqdm( + sigma_generator, + total=num_sigmas, + desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", + ) + return sigma_generator + + +class SingleStepDiffusionSampler(BaseDiffusionSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): + raise NotImplementedError + + def euler_step(self, x, d, dt): + return x + dt * d + + +class EDMSampler(SingleStepDiffusionSampler): + def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.s_churn = s_churn + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class DDIMSampler(SingleStepDiffusionSampler): + def __init__(self, s_noise=0.1, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.s_noise = s_noise + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + d = to_d(x, sigma, denoised) + dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) + + euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) + + x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in 
* sigmas[i + 1], + denoiser, + x, + cond, + uc, + self.s_noise, + ) + + return x + + +class AncestralSampler(SingleStepDiffusionSampler): + def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta = eta + self.s_noise = s_noise + self.noise_sampler = lambda x: torch.randn_like(x) + + def ancestral_euler_step(self, x, denoised, sigma, sigma_down): + d = to_d(x, sigma, denoised) + dt = append_dims(sigma_down - sigma, x.ndim) + + return self.euler_step(x, d, dt) + + def ancestral_step(self, x, sigma, next_sigma, sigma_up): + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, + x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), + x, + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) + + return x + + +class LinearMultistepSampler(BaseDiffusionSampler): + def __init__( + self, + order=4, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.order = order + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + ds = [] + sigmas_cpu = sigmas.detach().cpu().numpy() + for i in self.get_sigma_gen(num_sigmas): + sigma = s_in * sigmas[i] + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs) + denoised = self.guider(denoised, sigma) + d = to_d(x, sigma, denoised) + ds.append(d) + if len(ds) > self.order: + ds.pop(0) + cur_order = min(i + 1, self.order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + + return x + + +class EulerEDMSampler(EDMSampler): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + return euler_step + + +class HeunEDMSampler(EDMSampler): + def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc): + if torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 + return euler_step + else: + denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) + d_new = to_d(euler_step, next_sigma, denoised) + d_prime = (d + d_new) / 2.0 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step) + return x + + +class EulerAncestralSampler(AncestralSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + + return x + + +class DPMPP2SAncestralSampler(AncestralSampler): + def get_variables(self, sigma, sigma_down): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] + h = t_next - t + s = t + 0.5 * h + return h, s, t, t_next + + def get_mult(self, h, s, t, t_next): + mult1 = to_sigma(s) / to_sigma(t) + mult2 = (-0.5 * h).expm1() + mult3 = to_sigma(t_next) / to_sigma(t) + mult4 = (-h).expm1() + + return mult1, mult2, mult3, mult4 + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): + 
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + + if torch.sum(sigma_down) < 1e-14: + # Save a network evaluation if all noise levels are 0 + x = x_euler + else: + h, s, t, t_next = self.get_variables(sigma, sigma_down) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)] + + x2 = mult[0] * x - mult[1] * denoised + denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) + x_dpmpp2s = mult[2] * x - mult[3] * denoised2 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) + + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + return x + + +class DPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) + mult2 = (-h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + # apply correction if noise level is not 0 and not first step + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x + + +class SDEDPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) * (-h).exp() + mult2 = (-2 * h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, 
uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)] + mult_noise = append_dims(next_sigma * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + # apply correction if noise level is not 0 and not first step + x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x + + +class SdeditEDMSampler(EulerEDMSampler): + def __init__(self, edit_ratio=0.5, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.edit_ratio = edit_ratio + + def __call__(self, denoiser, image, randn, cond, uc=None, num_steps=None, edit_ratio=None): + randn_unit = randn.clone() + randn, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(randn, cond, uc, num_steps) + + if num_steps is None: + num_steps = self.num_steps + if edit_ratio is None: + edit_ratio = self.edit_ratio + x = None + + for i in self.get_sigma_gen(num_sigmas): + if i / num_steps < edit_ratio: + continue + if x is None: + x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) + + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class VideoDDIMSampler(BaseDiffusionSampler): + def __init__(self, fixed_frames=0, sdedit=False, **kwargs): + super().__init__(**kwargs) + self.fixed_frames = fixed_frames + self.sdedit = sdedit + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + alpha_cumprod_sqrt, timesteps = self.discretization( + self.num_steps if num_steps is None else num_steps, + device=self.device, + return_idx=True, + do_append_zero=False, + ) + alpha_cumprod_sqrt = torch.cat([alpha_cumprod_sqrt, alpha_cumprod_sqrt.new_ones([1])]) + timesteps = torch.cat([torch.tensor(list(timesteps)).new_zeros([1]) - 1, torch.tensor(list(timesteps))]) + + uc = default(uc, cond) + + num_sigmas = len(alpha_cumprod_sqrt) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps + + def denoise(self, x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep=None, idx=None, scale=None, scale_emb=None): + additional_model_inputs = {} + + if isinstance(scale, torch.Tensor) == False and scale == 1: + additional_model_inputs["idx"] = x.new_ones([x.shape[0]]) * timestep + if scale_emb is not None: + additional_model_inputs["scale_emb"] = scale_emb + denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) + else: + additional_model_inputs["idx"] = 
torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + denoised = denoiser( + *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs + ).to(torch.float32) + if isinstance(self.guider, DynamicCFG): + denoised = self.guider( + denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale + ) + else: + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) + return denoised + + def sampler_step( + self, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + + a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t + + x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + + return x + + +class VPSDEDPMPP2MSampler(VideoDDIMSampler): + def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): + alpha_cumprod = alpha_cumprod_sqrt**2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + h = lamb_next - lamb + + if previous_alpha_cumprod_sqrt is not None: + previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): + mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt + + if previous_alpha_cumprod_sqrt is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + denoised = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + if idx == 1: + return denoised, denoised + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if 
old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + x = x_advanced + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, scale_emb=None): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + if self.fixed_frames > 0: + prefix_frames = x[:, : self.fixed_frames] + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + if self.fixed_frames > 0: + if self.sdedit: + rd = torch.randn_like(prefix_frames) + noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( + s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) + ) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1) + else: + x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + + if self.fixed_frames > 0: + x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + + return x + + +class VPODEDPMPP2MSampler(VideoDDIMSampler): + def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): + alpha_cumprod = alpha_cumprod_sqrt**2 + lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() + next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() + h = lamb_next - lamb + + if previous_alpha_cumprod_sqrt is not None: + previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): + mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + mult2 = (-h).expm1() * next_alpha_cumprod_sqrt + + if previous_alpha_cumprod_sqrt is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + ): + denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) + if idx == 1: + return denoised, denoised + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * 
denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + x = x_advanced + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, scale=None, **kwargs): + x, s_in, alpha_cumprod_sqrt, num_sigmas, cond, uc, timesteps = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + ) + + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py new file mode 100644 index 0000000000..9bb1fa8296 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py @@ -0,0 +1,155 @@ +import torch +from scipy import integrate + +from ...util import append_dims +from einops import rearrange + + +class NoDynamicThresholding: + def __call__(self, uncond, cond, scale): + scale = append_dims(scale, cond.ndim) if isinstance(scale, torch.Tensor) else scale + return uncond + scale * (cond - uncond) + + +class StaticThresholding: + def __call__(self, uncond, cond, scale): + result = uncond + scale * (cond - uncond) + result = torch.clamp(result, min=-1.0, max=1.0) + return result + + +def dynamic_threshold(x, p=0.95): + N, T, C, H, W = x.shape + x = rearrange(x, "n t c h w -> n c (t h w)") + l, r = x.quantile(q=torch.tensor([1 - p, p], device=x.device), dim=-1, keepdim=True) + s = torch.maximum(-l, r) + threshold_mask = (s > 1).expand(-1, -1, H * W * T) + if threshold_mask.any(): + x = torch.where(threshold_mask, x.clamp(min=-1 * s, max=s), x) + x = rearrange(x, "n c (t h w) -> n t c h w", t=T, h=H, w=W) + return x + + +def dynamic_thresholding2(x0): + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) + x0 = torch.clamp(x0, -s, s) # / s + return x0.to(origin_dtype) + + +def latent_dynamic_thresholding(x0): + p = 0.9995 + origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0), p, dim=2) + s = append_dims(s, x0.dim()) + x0 = torch.clamp(x0, -s, s) / s + return x0.to(origin_dtype) + + +def dynamic_thresholding3(x0): + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. 
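# [Editor's note: illustrative sketch, not part of the patch.] The Imagen-style
# dynamic thresholding helpers in this file clamp each sample to the p-th
# quantile s of its absolute values (with s floored at 1), so a few extreme
# latent values cannot saturate the output. A minimal, self-contained
# reproduction of that idea on a dummy tensor, assuming only PyTorch:
import torch

def imagen_dynamic_threshold(x0: torch.Tensor, p: float = 0.995) -> torch.Tensor:
    # per-sample p-quantile of |x0| over all non-batch dimensions
    s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), p, dim=1)
    s = torch.clamp(s, min=1.0).view(-1, *([1] * (x0.ndim - 1)))
    # the variants in this file optionally divide by s afterwards
    return torch.clamp(x0, -s, s)

example_latent = 3.0 * torch.randn(2, 4, 8, 8)
print(imagen_dynamic_threshold(example_latent).abs().max())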
+ origin_dtype = x0.dtype + x0 = x0.to(torch.float32) + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = append_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), x0.dim()) + x0 = torch.clamp(x0, -s, s) # / s + return x0.to(origin_dtype) + + +class DynamicThresholding: + def __call__(self, uncond, cond, scale): + mean = uncond.mean() + std = uncond.std() + result = uncond + scale * (cond - uncond) + result_mean, result_std = result.mean(), result.std() + result = (result - result_mean) / result_std * std + # result = dynamic_thresholding3(result) + return result + + +class DynamicThresholdingV1: + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def __call__(self, uncond, cond, scale): + result = uncond + scale * (cond - uncond) + unscaled_result = result / self.scale_factor + B, T, C, H, W = unscaled_result.shape + flattened = rearrange(unscaled_result, "b t c h w -> b c (t h w)") + means = flattened.mean(dim=2).unsqueeze(2) + recentered = flattened - means + magnitudes = recentered.abs().max() + normalized = recentered / magnitudes + thresholded = latent_dynamic_thresholding(normalized) + denormalized = thresholded * magnitudes + uncentered = denormalized + means + unflattened = rearrange(uncentered, "b c (t h w) -> b t c h w", t=T, h=H, w=W) + scaled_result = unflattened * self.scale_factor + return scaled_result + + +class DynamicThresholdingV2: + def __call__(self, uncond, cond, scale): + B, T, C, H, W = uncond.shape + diff = cond - uncond + mim_target = uncond + diff * 4.0 + cfg_target = uncond + diff * 8.0 + + mim_flattened = rearrange(mim_target, "b t c h w -> b c (t h w)") + cfg_flattened = rearrange(cfg_target, "b t c h w -> b c (t h w)") + mim_means = mim_flattened.mean(dim=2).unsqueeze(2) + cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) + mim_centered = mim_flattened - mim_means + cfg_centered = cfg_flattened - cfg_means + + mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) + cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) + + cfg_renormalized = cfg_centered / cfg_scaleref * mim_scaleref + + result = cfg_renormalized + cfg_means + unflattened = rearrange(result, "b c (t h w) -> b t c h w", t=T, h=H, w=W) + + return unflattened + + +def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + if order - 1 > i: + raise ValueError(f"Order {order} too high for step {i}") + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + if not eta: + return sigma_to, 0.0 + sigma_up = torch.minimum( + sigma_to, + eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +def to_d(x, sigma, denoised): + return (x - denoised) / append_dims(sigma, x.ndim) + + +def to_neg_log_sigma(sigma): + return sigma.log().neg() + + +def to_sigma(neg_log_sigma): + return neg_log_sigma.neg().exp() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py new file mode 100644 index 0000000000..770de4254e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -0,0 +1,80 @@ +import torch +import torch.distributed 
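# [Editor's note: illustrative sketch, not part of the patch.] get_ancestral_step
# in sampling_utils.py above splits an ancestral step into a deterministic part
# (sigma_down) and a fresh-noise part (sigma_up) such that
# sigma_down**2 + sigma_up**2 == sigma_to**2. A self-contained check of that
# invariant, re-deriving the same formulas locally with plain PyTorch:
import torch

def ancestral_split(sigma_from: torch.Tensor, sigma_to: torch.Tensor, eta: float = 1.0):
    sigma_up = torch.minimum(
        sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    )
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up

down, up = ancestral_split(torch.tensor(10.0), torch.tensor(4.0))
assert torch.allclose(down**2 + up**2, torch.tensor(16.0))  # noise split preserves sigma_to**2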
+ +from sat import mpu + +from ...util import default, instantiate_from_config + + +class EDMSampling: + def __init__(self, p_mean=-1.2, p_std=1.2): + self.p_mean = p_mean + self.p_std = p_std + + def __call__(self, n_samples, rand=None): + log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) + return log_sigma.exp() + + +class DiscreteSampling: + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): + self.num_idx = num_idx + self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) + world_size = mpu.get_data_parallel_world_size() + self.uniform_sampling = uniform_sampling + if self.uniform_sampling: + i = 1 + while True: + if world_size % i != 0 or num_idx % (world_size // i) != 0: + i += 1 + else: + self.group_num = world_size // i + break + + assert self.group_num > 0 + assert world_size % self.group_num == 0 + self.group_width = world_size // self.group_num # the number of rank in one group + self.sigma_interval = self.num_idx // self.group_num + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None, return_idx=False): + if self.uniform_sampling: + rank = mpu.get_data_parallel_rank() + group_index = rank // self.group_width + idx = default( + rand, + torch.randint( + group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval, (n_samples,) + ), + ) + else: + idx = default( + rand, + torch.randint(0, self.num_idx, (n_samples,)), + ) + if return_idx: + return self.idx_to_sigma(idx), idx + else: + return self.idx_to_sigma(idx) + + +class PartialDiscreteSampling: + def __init__(self, discretization_config, total_num_idx, partial_num_idx, do_append_zero=False, flip=True): + self.total_num_idx = total_num_idx + self.partial_num_idx = partial_num_idx + self.sigmas = instantiate_from_config(discretization_config)( + total_num_idx, do_append_zero=do_append_zero, flip=flip + ) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None): + idx = default( + rand, + # torch.randint(self.total_num_idx-self.partial_num_idx, self.total_num_idx, (n_samples,)), + torch.randint(0, self.partial_num_idx, (n_samples,)), + ) + return self.idx_to_sigma(idx) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000..dcd20e8e7f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py @@ -0,0 +1,328 @@ +""" +adopted from +https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +and +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +and +https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py + +thanks! 
+""" + +import math +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +def make_beta_schedule( + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, +): + if schedule == "linear": + betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 + return betas.numpy() + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def mixed_checkpoint(func, inputs: dict, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function + borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that + it also works with non-tensor inputs + :param func: the function to evaluate. + :param inputs: the argument dictionary to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] + tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)] + non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)] + non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)] + args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) + return MixedCheckpointFunction.apply( + func, + len(tensor_inputs), + len(non_tensor_inputs), + tensor_keys, + non_tensor_keys, + *args, + ) + else: + return func(**inputs) + + +class MixedCheckpointFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, + ): + ctx.end_tensors = length_tensors + ctx.end_non_tensors = length_tensors + length_non_tensors + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors + + ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))} + ctx.input_non_tensors = { + key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])) + } + ctx.run_function = run_function + ctx.input_params = list(args[ctx.end_non_tensors :]) + + with torch.no_grad(): + output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} + ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors} + + with torch.enable_grad(), torch.npu.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors} + # shallow_copies.update(additional_args) + output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) + input_grads = torch.autograd.grad( + output_tensors, + list(ctx.input_tensors.values()) + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return ( + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors :] + ) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.npu.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, dtype=torch.float32): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding.to(dtype) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images": + self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + else: + raise NotImplementedError + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py new file mode 100644 index 0000000000..d0b78ffd50 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/wrappers.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn +from packaging import version + +OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" + + +class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, compile_model: bool = False, dtype: torch.dtype = torch.float32): + super().__init__() + compile = ( + torch.compile + if (version.parse(torch.__version__) >= version.parse("2.0.0")) and compile_model + else lambda x: x + ) + self.diffusion_model = compile(diffusion_model) + self.dtype = dtype + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class OpenAIWrapper(IdentityWrapper): + def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: + for key in c: + c[key] = c[key].to(self.dtype) + + if x.dim() == 4: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + elif x.dim() == 5: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2) + else: + raise ValueError("Input tensor must be 4D or 5D") + + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, + ) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py new file mode 100644 index 0000000000..a85159d20f --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py @@ -0,0 +1,275 @@ +import math +from contextlib import nullcontext +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import kornia +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import ListConfig +from torch.utils.checkpoint import checkpoint +from transformers import ( + T5EncoderModel, + T5Tokenizer, +) + +from ...util import ( + append_dims, + autocast, + count_params, + default, + disabled_train, + expand_dims_like, + instantiate_from_config, +) + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: + return self._input_key + + @is_trainable.setter + def 
is_trainable(self, value: bool): + self._is_trainable = value + + @ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + +class GeneralConditioner(nn.Module): + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + + def __init__(self, emb_models: Union[List, ListConfig], cor_embs=[], cor_p=[]): + super().__init__() + embedders = [] + for n, embconfig in enumerate(emb_models): + embedder = instantiate_from_config(embconfig) + assert isinstance( + embedder, AbstractEmbModel + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.is_trainable = False + embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) + embedder.train = disabled_train + for param in embedder.parameters(): + param.requires_grad = False + embedder.eval() + + if "input_key" in embconfig: + embedder.input_key = embconfig["input_key"] + elif "input_keys" in embconfig: + embedder.input_keys = embconfig["input_keys"] + else: + raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}") + + embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) + if embedder.legacy_ucg_val is not None: + embedder.ucg_prng = np.random.RandomState() + + embedders.append(embedder) + self.embedders = nn.ModuleList(embedders) + + if len(cor_embs) > 0: + assert len(cor_p) == 2 ** len(cor_embs) + self.cor_embs = cor_embs + self.cor_p = cor_p + + def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: + assert embedder.legacy_ucg_val is not None + p = embedder.ucg_rate + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if embedder.ucg_prng.choice(2, p=[1 - p, p]): + batch[embedder.input_key][i] = val + return batch + + def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict, cond_or_not) -> Dict: + assert embedder.legacy_ucg_val is not None + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if cond_or_not[i]: + batch[embedder.input_key][i] = val + return batch + + def get_single_embedding( + self, + embedder, + batch, + output, + cond_or_not: Optional[np.ndarray] = None, + force_zero_embeddings: Optional[List] = None, + ): + with torch.no_grad(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + if cond_or_not is None: + batch = self.possibly_get_ucg_val(embedder, batch) + else: + batch = self.surely_get_ucg_val(embedder, batch, cond_or_not) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + for emb in emb_out: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + if cond_or_not is None: + emb = ( + expand_dims_like( + torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), + emb, + ) + 
* emb + ) + else: + emb = ( + expand_dims_like( + torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), + emb, + ) + * emb + ) + if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings: + emb = torch.zeros_like(emb) + if out_key in output: + output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key]) + else: + output[out_key] = emb + return output + + def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict: + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + + if len(self.cor_embs) > 0: + batch_size = len(batch[list(batch.keys())[0]]) + rand_idx = np.random.choice(len(self.cor_p), size=(batch_size,), p=self.cor_p) + for emb_idx in self.cor_embs: + cond_or_not = rand_idx % 2 + rand_idx //= 2 + output = self.get_single_embedding( + self.embedders[emb_idx], + batch, + output=output, + cond_or_not=cond_or_not, + force_zero_embeddings=force_zero_embeddings, + ) + + for i, embedder in enumerate(self.embedders): + if i in self.cor_embs: + continue + output = self.get_single_embedding( + embedder, batch, output=output, force_zero_embeddings=force_zero_embeddings + ) + return output + + def get_unconditional_conditioning(self, batch_c, batch_uc=None, force_uc_zero_embeddings=None): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + cor_embs = self.cor_embs + cor_p = self.cor_p + self.cor_embs = [] + self.cor_p = [] + + c = self(batch_c) + uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) + + for embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + self.cor_embs = cor_embs + self.cor_p = cor_p + + return c, uc + + +class FrozenT5Embedder(AbstractEmbModel): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, + model_dir="google/t5-v1_1-xxl", + device="npu", + max_length=77, + freeze=True, + cache_dir=None, + ): + super().__init__() + if model_dir != "google/t5-v1_1-xxl": + self.tokenizer = T5Tokenizer.from_pretrained(model_dir) + self.transformer = T5EncoderModel.from_pretrained(model_dir) + else: + self.tokenizer = T5Tokenizer.from_pretrained(model_dir, cache_dir=cache_dir) + self.transformer = T5EncoderModel.from_pretrained(model_dir, cache_dir=cache_dir) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + # @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("npu", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py new file mode 100644 index 0000000000..bb6c26b378 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py @@ -0,0 +1,285 @@ +import torch + +from ..modules.attention import * +from 
..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding + + +class TimeMixSequential(nn.Sequential): + def forward(self, x, context=None, timesteps=None): + for layer in self: + x = layer(x, context, timesteps) + + return x + + +class VideoTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, + ): + super().__init__() + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + assert int(n_heads * d_head) == inner_dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff) + + self.timesteps = timesteps + self.disable_self_attn = disable_self_attn + if self.disable_self_attn: + self.attn1 = CrossAttention( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + dropout=dropout, + ) # is a cross-attention + else: + self.attn1 = CrossAttention( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + self.norm2 = nn.LayerNorm(inner_dim) + if switch_temporal_ca_to_sa: + self.attn2 = CrossAttention( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + else: + self.attn2 = CrossAttention( + query_dim=inner_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + + self.norm1 = nn.LayerNorm(inner_dim) + self.norm3 = nn.LayerNorm(inner_dim) + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward(self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None) -> torch.Tensor: + if self.checkpoint: + return checkpoint(self._forward, x, context, timesteps) + else: + return self._forward(x, context, timesteps=timesteps) + + def _forward(self, x, context=None, timesteps=None): + assert self.timesteps or timesteps + assert not (self.timesteps and timesteps) or self.timesteps == timesteps + timesteps = self.timesteps or timesteps + B, S, C = x.shape + x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + if self.disable_self_attn: + x = self.attn1(self.norm1(x), context=context) + x + else: + x = self.attn1(self.norm1(x)) + x + + if self.attn2 is not None: + if self.switch_temporal_ca_to_sa: + x = self.attn2(self.norm2(x)) + x + else: + x = self.attn2(self.norm2(x), context=context) + x + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps) + return x + + def get_last_layer(self): + return self.ff.net[-1].weight + + +str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, 
+ dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype="fp32", + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + VideoTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy) + self.dtype = str_to_dtype[dtype] + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert context.ndim == 3, f"n dims of spatial context should be 3 but are {context.ndim}" + + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... 
-> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + dtype=self.dtype, + ) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)): + x = block( + x, + context=spatial_context, + ) + + x_mix = x + x_mix = x_mix + emb + + x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator, + ) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py new file mode 100644 index 0000000000..b549bcd621 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py @@ -0,0 +1,318 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors + + +class SafeConv3d(torch.nn.Conv3d): + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if memory_count > 2: + # print(f"WARNING: Conv3d with {memory_count:.2f}GB") + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(SafeConv3d, self).forward(input) + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. + If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. 
If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. + + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.npu.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config, **extra_kwargs): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. + For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}") + + +def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): + """ + Will return the result of a recursive get attribute call. + E.g.: + a.b.c + = getattr(getattr(a, "b"), "c") + = get_nested_attribute(a, "b.c") + If any part of the attribute call is an integer x with current obj a, will + try to call a[x] instead of a.x first. 
+ """ + attributes = attribute_path.split(".") + if depth is not None and depth > 0: + attributes = attributes[:depth] + assert len(attributes) > 0, "At least one attribute should be selected" + current_attribute = obj + current_key = None + for level, attribute in enumerate(attributes): + current_key = ".".join(attributes[: level + 1]) + try: + id_ = int(attribute) + current_attribute = current_attribute[id_] + except ValueError: + current_attribute = getattr(current_attribute, attribute) + + return (current_attribute, current_key) if return_key else current_attribute + + +from math import sqrt + + +class SeededNoise: + def __init__(self, seeds, weights): + self.seeds = seeds + self.weights = weights + weight_square_sum = 0 + for weight in weights: + weight_square_sum += weight**2 + self.weight_square_sum_sqrt = sqrt(weight_square_sum) + self.cnt = 0 + + def __call__(self, x): + self.cnt += 1 + randn_combined = torch.zeros_like(x) + for seed, weight in zip(self.seeds, self.weights): + randn = np.random.RandomState(seed + self.cnt).randn(*x.shape) + randn = torch.from_numpy(randn, dtype=x.dtype, device=x.device) + randn_combined += randn * weight + randn_combined /= self.weight_square_sum_sqrt + return randn_combined diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py new file mode 100644 index 0000000000..0cc69f00c3 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, asdict +from typing import Union + +import torch +import torch.nn as nn + + +class CacheConfig(): + def __init__(self, method=None): + self.method = method + + +class CacheAgentConfig(CacheConfig): + """ + The DitCache Config. + """ + def __init__(self, policy_dir: str): + """ + Args: + policy_dir: The file containing the policy. + """ + super().__init__(method="CacheAgent") + self.policy_dir = policy_dir + + +class DitCacheConfig(CacheConfig): + """ + The DitCache Config. + """ + def __init__(self, step_start: int, step_interval: int, block_start: int, num_blocks: int): + """ + Args: + step_start: The starting step for caching. + step_interval: The interval at which caching should occur. + block_start: The starting block index for caching. + num_blocks: The number of blocks to cache. + """ + super().__init__(method="DitCache") + self.step_start = step_start + self.step_interval = step_interval + self.block_start = block_start + self.num_blocks = num_blocks + + +class CacheManager: + """ + The CacheManager class is interface to manage the cache algorithm. + """ + def __init__( + self, + config:CacheConfig + ): + """ + Args: + config: The configuration for the cache algorithm. 
+ """ + if isinstance(config, CacheConfig): + self.method = config.method + self.cache_cls = Cache_cls[self.method](**vars(config)) + + def __call__(self, block, time_step, block_idx, hidden_states, *args, **kwargs + ): + """ + Args: + block: The block in the DiT module. + time_step: The current time step. + block_idx: The index of the block. + hidden_states: The hidden states. + *args: Additional arguments. + **kwargs: Additional keyword arguments. + """ + if not self._use_cache(time_step, block_idx): + old_hidden_states = hidden_states + if isinstance(block, list): + for blk in block: + hidden_states = blk(hidden_states, *args, **kwargs) + else: + hidden_states = block(hidden_states, *args, **kwargs) + self._update_cache(hidden_states, old_hidden_states, time_step, block_idx) + else: + hidden_states += self._get_cache(time_step, block_idx) + return hidden_states + + def _use_cache(self, time_step, block_idx): + return self.cache_cls.use_cache(time_step, block_idx) + + def _get_cache(self, time_step, block_idx): + return self.cache_cls.get_cache(time_step, block_idx) + + def _update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): + self.cache_cls.update_cache(hidden_states, old_hidden_states, time_step, block_idx) + + +class CacheAgent(): + def __init__(self, policy_dir, **kwargs): + self.policy_dir = policy_dir + self.cache = [None for _ in range(32)] + + self.policy = torch.load(policy_dir) + + def use_cache(self, time_step, block_idx): + if time_step == 0: + return False + else: + return self.policy[time_step - 1, block_idx] + + def get_cache(self, time_step, block_idx): + return self.cache[block_idx] + + def update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): + delta = hidden_states - old_hidden_states + self.cache[block_idx] = delta + + +class DitCache: + def __init__(self, step_start, step_interval, block_start, num_blocks, **kwargs): + self.step_start = step_start + self.step_interval = step_interval + self.block_start = block_start + self.num_blocks = num_blocks + self.block_end = block_start + num_blocks - 1 + self.cache = None + self.time_cache = {} + + def use_cache(self, time_step, block_idx): + if time_step < self.step_start: + return False + else: + diftime = time_step - self.step_start + if diftime not in self.time_cache: + self.time_cache[diftime] = diftime % self.step_interval == 0 + if self.time_cache[diftime]: + return False + elif block_idx < self.block_start or block_idx > self.block_end: + return False + else: + return True + + def get_cache(self, time_step, block_idx): + if block_idx == self.block_start: + return self.cache + else: + return 0 + + def update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): + diftime = time_step - self.step_start + # when (time_step - self.step_start) % self.step_interval == 0: + if time_step >= self.step_start and self.time_cache[diftime]: + if block_idx == self.block_start: + self.cache = old_hidden_states + elif block_idx == self.block_end: + self.cache = hidden_states - self.cache + + +Cache_cls = { + "CacheAgent" : CacheAgent, + "DitCache" : DitCache +} diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py new file mode 100644 index 0000000000..33f79ad1e5 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist + +from .parallel_mgr import get_sequence_parallel_world_size, get_sp_group + + +def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=process_group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def split_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + if world_size == 1: + return input_ + + if pad > 0: + pad_size = list(input_.shape) + pad_size[dim] = pad + input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim) + + dim_size = input_.size(dim) + if dim_size % world_size != 0: + raise ValueError( + f"The th{dim} dimensions of input_:{input_.size()} is not divisible by world_size:{world_size}.") + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + return output + + +def gather_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim) + + if pad > 0: + output = output.narrow(dim, 0, output.size(dim) - pad) + + return output + + +# ====== +# Pad +# ====== + +SPTIAL_PAD = 0 +TEMPORAL_PAD = 0 + + +def set_spatial_pad(dim_size: int): + sp_size = get_sequence_parallel_world_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global SPTIAL_PAD + SPTIAL_PAD = pad + + +def get_spatial_pad() -> int: + return SPTIAL_PAD + + +def set_temporal_pad(dim_size: int): + sp_size = get_sequence_parallel_world_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global TEMPORAL_PAD + TEMPORAL_PAD = pad + + +def get_temporal_pad() -> int: + return TEMPORAL_PAD + + +def all_to_all_with_pad( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + **kwargs +): + scatter_dim = kwargs.get("scatter_dim", 2) + gather_dim = kwargs.get("gather_dim", 1) + scatter_pad = kwargs.get("scatter_pad", 0) + gather_pad = kwargs.get("gather_pad", 0) + + if scatter_pad > 0: + pad_shape = list(input_.shape) + pad_shape[scatter_dim] = scatter_pad + pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype) + input_ = torch.cat([input_, pad_tensor], dim=scatter_dim) + + world_size = dist.get_world_size(process_group) + if input_.shape[scatter_dim] % world_size != 0: + raise 
ValueError( + f"The scatter_dim:{scatter_dim} of input_:{input_.shape} is not divisible by world_size:{world_size}.") + + input_ = _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim) + + if gather_pad > 0: + input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) + + return input_ + + +def all_to_all( + tensor: torch.Tensor, + world_size: int, + scatter_dim: int, + gather_dim: int, + process_group: dist.ProcessGroup = None, +): + if process_group is None: + process_group = dist.group.WORLD + return _all_to_all_func(tensor, world_size, process_group, scatter_dim, gather_dim) + + +def all_to_all_sbh( + input_: torch.Tensor, + scatter_dim: int = 1, + gather_dim: int = 0, + async_op: bool = False, +): + return single_all_to_all(input_, scatter_dim, gather_dim, async_op) + + +def single_all_to_all( + input_: torch.Tensor, + scatter_dim: int, + gather_dim: int, + async_op: bool = False, +): + sp_size = get_sequence_parallel_world_size() + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size + if scatter_dim < 1: + input_t = input_.reshape( + [sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ) + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = input_.reshape( + [-1, sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ).transpose(0, 1).contiguous() + + output = torch.empty_like(input_t) + + req = dist.all_to_all_single(output, input_t, group=get_sp_group().device_group, async_op=async_op) + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_dim < 1: + output = output.transpose(0, 1).contiguous() + + return req, output.reshape( + inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:]) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py new file mode 100644 index 0000000000..0adaccdd1d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py @@ -0,0 +1,631 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." 
+ ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
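+            # Note: `new_group` is a collective call, so every rank creates every
+            # group listed in `group_ranks`; a rank only records the one group it
+            # belongs to (see the `if self.rank in ranks` branch below).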
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
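+        # The collective gathers along dim 0; for any other `dim`, the result is
+        # reshaped to [world_size, *input_size] and the gathered axis is moved
+        # into place with `movedim` below.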
+ torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. 
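+        # Standard torch.distributed semantics: `obj_list` is modified in place
+        # on non-source ranks and returned unchanged on the source rank.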
+ torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." + ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
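+                    # (The receiving side rebuilds empty tensors from their
+                    # metadata alone, so no collective is issued for them.)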
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
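+                    # Empty tensors are rebuilt locally from metadata; there is
+                    # nothing to receive for them.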
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_world_size = self.ring_world_size = 1 + self.ulysses_rank = self.ring_rank = 0 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py new file mode 100644 index 0000000000..f86c2ae45a --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
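+
+# Illustrative usage sketch (the launch script and the degrees below are
+# assumptions for this example, not part of the patch): on 4 NPUs launched
+# e.g. via torchrun,
+#
+#     config = ParallelConfig(sp_degree=2, use_cfg_parallel=True, world_size=4)
+#     init_parallel_env(config)    # builds the world, cfg, sp and VAE groups
+#     ...                          # run the pipeline
+#     finalize_parallel_env()      # tears the model-parallel groups and the
+#                                  # distributed environment back down
+#
+# If torch.distributed is not yet initialized, init_parallel_env() falls back to
+# init_distributed_environment() with the default "hccl" backend and "env://"
+# init method.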
+ +import os +from typing import List, Optional +from dataclasses import dataclass +import torch.distributed as dist +import torch_npu +from .utils import RankGenerator, generate_masked_orthogonal_rank_groups +from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + +_WORLD: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None +_VAE_PARALLEL_GROUP = None +_VAE_PARALLEL_SIZE = None + + +def is_vae_parallel_initialized(): + global _VAE_PARALLEL_GROUP + if _VAE_PARALLEL_GROUP is None: + return False + else: + return True + +def initialize_vae_parallel(): + global _VAE_PARALLEL_GROUP + global _VAE_PARALLEL_SIZE + + assert _VAE_PARALLEL_GROUP is None, "vae parallel group is already initialized" + _VAE_PARALLEL_SIZE = dist.get_world_size() + + ranks = range(_VAE_PARALLEL_SIZE) + group = dist.new_group(ranks) + _VAE_PARALLEL_GROUP = group + +def get_vae_parallel_group(): + assert _VAE_PARALLEL_GROUP is not None, "vae parallel group is not initialized" + + return _VAE_PARALLEL_GROUP + +def get_vae_parallel_world_size(): + if not is_vae_parallel_initialized(): + return 1 + return _VAE_PARALLEL_SIZE + +def get_vae_parallel_rank(): + if not is_vae_parallel_initialized(): + return 0 + rank = dist.get_rank() + vae_rank = rank % _VAE_PARALLEL_SIZE + return vae_rank + +def get_vae_parallel_group_rank(): + if not is_vae_parallel_initialized(): + return 0 + rank = dist.get_rank() + vae_group_rank = rank // _VAE_PARALLEL_SIZE + return vae_group_rank + + +@dataclass +class ParallelConfig: + sp_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + assert self.sp_degree * self.cfg_degree <= self.world_size, ( + "sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + assert ( + self.world_size % (self.sp_degree * self.cfg_degree) == 0 + ), "world_size must be divisible by sp_degree * cfg_degree" + + +# * QUERY +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + assert _SP is not None, "pipeline model parallel group is not initialized" + return _SP + +def get_sequence_parallel_state(): + """Return state for the sequence parallel group.""" + return _SP is not None + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 + return get_sp_group().rank_in_group + +# CFG +def get_cfg_group() -> GroupCoordinator: + assert ( + _CFG is not None + ), "classifier_free_guidance parallel group is not initialized" + return _CFG + +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 + return get_cfg_group().world_size + + +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + if not 
get_cfg_state(): + return 0 + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + ) + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert ( + _WORLD.world_size == dist.get_world_size() + ), "world group already initialized with a different world size" + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + ) + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + assert parallel_mode in [ + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. 
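+    # For example, with world_size=4, sequence_parallel_degree=2 and
+    # classifier_free_guidance_degree=2, the RankGenerator below (order "sp-cfg")
+    # yields sp groups [0, 1], [2, 3] and cfg groups [0, 2], [1, 3].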
+ assert dist.is_initialized() + world_size: int = dist.get_world_size() + backend = backend or dist.get_backend(get_world_group().device_group) + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x" + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x" + ) + + rank_generator: RankGenerator = RankGenerator( + sequence_parallel_degree, + classifier_free_guidance_degree, + "sp-cfg", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + assert _SP is None, "sequence parallel group is already initialized" + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ) + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): + dist.destroy_process_group() + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logger.warning("Model parallel is not initialized, initializing...") + if not dist.is_initialized(): + init_distributed_environment() + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ) + initialize_vae_parallel() + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py new file mode 100644 index 0000000000..395fe93120 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from functools import reduce + +import torch + + +class AdaStep: + """ + The Adastep class is designed to optimize the sampling process in a diffusion model, + """ + def __init__(self, skip_thr=0.015, max_skip_steps=4, decay_ratio=0.99, + device="npu", forward_value=None, step_value=None): + """ + Args: + skip_thr (float): The threshold for determining whether to skip a step based on the change in latent variables. + Recommended values are between 0.01 and 0.015. Default is 0.015. 
+ max_skip_steps (int): The maximum number of consecutive steps that can be skipped. + Recommended values are between 3 and 4. Default is 4. + decay_ratio (float): The decay ratio for the skip threshold, which is used to dynamically adjust + the threshold over time. Recommended values are between 0.95 and 0.99. Default is 0.99. + device (str): The device on which the computations will be performed. Default is "npu". + """ + + # recommand 0.01(skip around 35 steps) ~ 0.015(skip around 50 steps) + self.skip_thr = skip_thr + # recommand 3 ~ 4 + self.max_skip_steps = max_skip_steps + # recommand 0.95 ~ 0.99 + self.decay_ratio = decay_ratio + self.device = device + self.if_skip = self.max_skip_steps > 0 + self.reset_status() + + self.forwardretype = forward_value + self.stepretype = step_value + + def __call__(self, transformer, *model_args, **model_kwargs): + """ + Args: + transformer (Module): The Module that works as the DiT and returns the noise prediction. + model_args (tuple): The arguments to be passed to the transformer Module forword. + model_kwargs (dict): The keyword arguments to be passed to the transformer Module forword. + Returns: + The noise prediction from the transformer function. + """ + if self.if_skip and torch.all(self.skip_vote): + return self._return_output(self.skip_noise_pred, self.forwardretype) + + noise_pred = transformer(*model_args, **model_kwargs) + if not self.forwardretype: + if isinstance(noise_pred, tuple): + self.forwardretype = tuple + elif isinstance(noise_pred, torch.Tensor): + self.forwardretype = torch.Tensor + else: + raise(ValueError, "transformer need return a tuple which the first element is the result" + "or a Tensor. Otherwith please input the `forward_value`") + self.skip_noise_pred = self._get_input(noise_pred, self.forwardretype) + return noise_pred + + @staticmethod + def _get_input(input_value, inp_type): + if isinstance(inp_type, type): + if inp_type is tuple: + return input_value[0] + else: + return input_value + else: + return input_value[inp_type] + + @staticmethod + def _return_output(output, outptype): + if isinstance(outptype, type): + if outptype is tuple: + return (output,) + else: + return output + elif isinstance(outptype, str): + return {outptype : output} + else: + return (0,) * outptype + (output,) + + def set_param(self, skip_thr=None, max_skip_steps=None, decay_ratio=None, device=None): + """ + To set the parameters of the AdaStep class. + """ + self.skip_thr = skip_thr or self.skip_thr + self.max_skip_steps = max_skip_steps or self.max_skip_steps + self.decay_ratio = decay_ratio or self.decay_ratio + if device : + self.device = device + self.skip_vote.to(self.device) + self.if_skip = self.max_skip_steps > 0 + + def reset_status(self): + """ + Reset the status of the AdaStep class. + """ + self.skip_mask = [False] + self.skip_latents_diff = [] + self.skip_noise_pred = None + self.skip_prev_latents = 0 + self.skip_vote = torch.tensor([False], dtype=torch.bool, device=self.device) + + def update_strategy(self, latents, sequence_parallel=False, sp_group=None): + """ + Update the strategy for skipping steps based on the change in latents. + """ + if not self.stepretype: + if isinstance(latents, tuple): + self.stepretype = tuple + elif isinstance(latents, torch.Tensor): + self.stepretype = torch.Tensor + else: + raise(ValueError, "step need return a tuple which the first element is the result" + "or a Tensor. 
Otherwith please input the `step_value`") + if self.if_skip: + latents = self._get_input(latents, self.stepretype) + diff = latents - self.skip_prev_latents + self.skip_latents_diff.append(diff.abs().mean()) + if len(self.skip_latents_diff) >= 3: + self.skip_mask.append(self._estimate()) + + self.skip_prev_latents = latents + + mask_value = self.skip_mask[-1] + mask_value = torch.tensor([mask_value], dtype=torch.bool, device=self.device) + if sequence_parallel: + skip_vote = torch.zeros(torch.distributed.get_world_size(sp_group), + dtype=torch.bool, device=self.device) + torch.distributed.all_gather_into_tensor(skip_vote, mask_value, group=sp_group) + else: + skip_vote = mask_value + self.skip_vote = skip_vote + + def _estimate(self): + # `self.skip_latents_diff[-1]` refers to the most recent difference in latent variables. + cur_diff = self.skip_latents_diff[-1] + # `self.skip_latents_diff[-2]` refers to the second most recent difference in latent variables. + prev_diff = self.skip_latents_diff[-2] + # `self.skip_latents_diff[-3]` refers to the third most recent difference in latent variables. + prev_prev_diff = self.skip_latents_diff[-3] + + self.skip_thr = self.skip_thr * self.decay_ratio + + if len(self.skip_mask) >= self.max_skip_steps and \ + all(self.skip_mask[-self.max_skip_steps:]): + return False + + if abs((cur_diff + prev_prev_diff) / 2 - prev_diff) <= prev_diff * self.skip_thr: + return True + return False + + +class SamplingOptm: + def __init__(self, pipe, dit_forward="transformer.forward", scheduler_step="scheduler.step", + forward_value=None, step_value=None, parallel=False, group=None, config=None): + self.parallel = parallel + self.group = group + self.skip = False + if config and config["method"] == "AdaStep": + self.skip = True + ditforward_lst = dit_forward.split(".") + schedulerstep_lst = scheduler_step.split(".") + self.pipe = pipe + + self.ori_forward = reduce(getattr, ditforward_lst, self.pipe) # getattr(self.pipe, )dit_forward.split(".") + self.forward = ditforward_lst.pop() + self.ori_dit = reduce(getattr, ditforward_lst, self.pipe) + + self.ori_step = reduce(getattr, schedulerstep_lst, self.pipe) + self.step = schedulerstep_lst.pop() + self.ori_scheduler = reduce(getattr, schedulerstep_lst, self.pipe) + + shik_thr = config.get("skip_thr", 0.015) + max_skip_steps = config.get("max_skip_steps", 4) + decay_ratio = config.get("decay_ratio", 0.99) + self.skip_strategy = AdaStep(skip_thr=shik_thr, max_skip_steps=max_skip_steps, decay_ratio=decay_ratio) + + def __enter__(self): + if self.skip: + self._sub_forward() + self._sub_step() + + def __exit__(self, t, v, trace): + if self.skip: + self._revert_forward() + self._revert_step() + + def _sub_forward(self): + @functools.wraps(self.ori_forward) + def _optm_forward(*args, **kwargs): + noise_pred = self.skip_strategy(self.ori_forward, *args, **kwargs) + return noise_pred + setattr(self.ori_dit, self.forward, _optm_forward) + + def _sub_step(self): + @functools.wraps(self.ori_step) + def _optm_step(*args, **kwargs): + latents = self.ori_step(*args, **kwargs) + self.skip_strategy.update_strategy(latents, self.parallel, self.group) + return latents + setattr(self.ori_scheduler, self.step, _optm_step) + + def _revert_forward(self): + setattr(self.ori_dit, self.forward, self.ori_forward) + + def _revert_step(self): + setattr(self.ori_scheduler, self.step, self.ori_step) + diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py 
new file mode 100644 index 0000000000..ca1cae0436 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py @@ -0,0 +1,147 @@ +from typing import List + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + assert ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." 
+ ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py new file mode 100644 index 0000000000..b9a69ff860 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py @@ -0,0 +1,435 @@ +import logging +import math +import re +import random +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.distributed as dist +import torch.nn as nn +from einops import rearrange +from packaging import version + +from vae_modules.ema import LitEma + +from utils.parallel_mgr import ( + is_vae_parallel_initialized, + initialize_vae_parallel, + get_vae_parallel_group, + get_vae_parallel_group_rank, +) + + +from sgm.util import ( + instantiate_from_config, + get_obj_from_str, + default, +) +from vae_modules.cp_enc_dec import _conv_split, _conv_gather + +logpy = logging.getLogger(__name__) + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. 
+ """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + logpy.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + logpy.info(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict())) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. 
+ """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.automatic_optimization = False # pytorch lightning + self.encoder = instantiate_from_config(encoder_config) + self.decoder = instantiate_from_config(decoder_config) + self.regularization = instantiate_from_config(regularizer_config) + self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"}) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + self.additional_decode_keys = set(default(additional_decode_keys, [])) + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 1 and in channels-first + # format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = [] + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_discriminator_params(self) -> list: + params = [] + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) + return x + + def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log + + def get_param_groups( + self, parameter_names: List[List[str]], optimizer_args: List[dict] + ) -> Tuple[List[Dict[str, Any]], int]: + groups = [] + num_params = 0 + for names, args in zip(parameter_names, optimizer_args): + params = [] + for pattern_ in names: + pattern_params = [] + pattern = re.compile(pattern_) + for p_name, param in self.named_parameters(): + if re.match(pattern, p_name): + pattern_params.append(param) + num_params += param.numel() + if len(pattern_params) == 0: + logpy.warn(f"Did not find parameters for pattern {pattern_}") + params.extend(pattern_params) + groups.append({"params": params, **args}) + return groups, num_params + + +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ckpt_engine = 
kwargs.pop("ckpt_engine", None) + super().__init__( + encoder_config={ + "target": "sgm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "sgm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + + def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) + + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return x + + +class VideoAutoencodingEngine(AutoencodingEngine): + def __init__( + self, + ckpt_path: Union[None, str] = None, + ignore_keys: Union[Tuple, list] = (), + image_video_weights=[1, 1], + only_train_decoder=False, + context_parallel_size=0, + **kwargs, + ): + super().__init__(**kwargs) + self.context_parallel_size = context_parallel_size + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict: + return self.log_images(batch, additional_log_kwargs, **kwargs) + + def get_input(self, batch: dict) -> torch.Tensor: + if self.context_parallel_size > 0: + if not is_vae_parallel_initialized(): + initialize_vae_parallel(self.context_parallel_size) + + batch = batch[self.input_key] + + global_src_rank = get_vae_parallel_group_rank() * self.context_parallel_size + torch.distributed.broadcast(batch, src=global_src_rank, group=get_vae_parallel_group()) + + batch = _conv_split(batch, dim=2, kernel_size=1) + return batch + + return batch[self.input_key] + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + self.init_from_ckpt(ckpt) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + 
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False) + print("Missing keys: ", missing_keys) + print("Unexpected keys: ", unexpected_keys) + print(f"Restored from {path}") + + +class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): + def __init__( + self, + cp_size=0, + *args, + **kwargs, + ): + self.cp_size = cp_size + return super().__init__(*args, **kwargs) + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + input_cp: bool = False, + output_cp: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.cp_size > 0 and not input_cp: + if not is_vae_parallel_initialized(): + initialize_vae_parallel(self.cp_size) + + global_src_rank = get_vae_parallel_group_rank() * self.cp_size + torch.distributed.broadcast(x, src=global_src_rank, group=get_vae_parallel_group()) + + x = _conv_split(x, dim=2, kernel_size=1) + + if return_reg_log: + z, reg_log = super().encode(x, return_reg_log, unregularized) + else: + z = super().encode(x, return_reg_log, unregularized) + + if self.cp_size > 0 and not output_cp: + z = _conv_gather(z, dim=2, kernel_size=1) + + if return_reg_log: + return z, reg_log + return z + + def decode( + self, + z: torch.Tensor, + input_cp: bool = False, + output_cp: bool = False, + split_kernel_size: int = 1, + **kwargs, + ): + if self.cp_size > 0 and not input_cp: + if not is_vae_parallel_initialized(): + initialize_vae_parallel(self.cp_size) + + global_src_rank = get_vae_parallel_group_rank() * self.cp_size + torch.distributed.broadcast(z, src=global_src_rank, group=get_vae_parallel_group()) + + z = _conv_split(z, dim=2, kernel_size=split_kernel_size) + + x = super().decode(z, **kwargs) + + if self.cp_size > 0 and not output_cp: + x = _conv_gather(x, dim=2, kernel_size=split_kernel_size) + + return x + + def forward( + self, + x: torch.Tensor, + input_cp: bool = False, + latent_cp: bool = False, + output_cp: bool = False, + **additional_decode_kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp) + dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs) + return z, dec, reg_log diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py new file mode 100644 index 0000000000..32b926a1c7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py @@ -0,0 +1,983 @@ +import math +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from beartype import beartype +from beartype.typing import Union, Tuple, Optional, List +from einops import rearrange + +from utils.parallel_mgr import ( + get_vae_parallel_group, + get_vae_parallel_rank, + get_vae_parallel_world_size, + get_vae_parallel_group_rank, +) + +from vae_modules.utils import SafeConv3d as Conv3d + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) else ((t,) * length) + + +def divisible_by(num, den): + return (num % den) == 0 + + +def is_odd(n): + return not divisible_by(n, 2) + + +def exists(v): + return v is not None + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. 
+ This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def leaky_relu(p=0.1): + return nn.LeakyReLU(p) + + +def _split(input_, dim): + cp_world_size = get_vae_parallel_world_size() + + if cp_world_size == 1: + return input_ + + cp_rank = get_vae_parallel_rank() + + # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) + + inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + dim_size = input_.size()[dim] // cp_world_size + + input_list = torch.split(input_, dim_size, dim=dim) + output = input_list[cp_rank] + + if cp_rank == 0: + output = torch.cat([inpu_first_frame_, output], dim=dim) + output = output.contiguous() + + # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) + + return output + + +def _gather(input_, dim): + cp_world_size = get_vae_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + + # print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + + if cp_rank == 0: + input_ = torch.cat([input_first_frame_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) + + return output + + +def _conv_split(input_, dim, kernel_size): + cp_world_size = get_vae_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) + + cp_rank = get_vae_parallel_rank() + + dim_size = (input_.size()[dim] - kernel_size) // cp_world_size + + if cp_rank == 0: + output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) + else: + # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) + output = input_.transpose(dim, 0)[ + cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size + ].transpose(dim, 0) + output = output.contiguous() + + # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output + + +def _conv_gather(input_, dim, kernel_size): + cp_world_size = get_vae_parallel_world_size() + + # Bypass the function if context parallel is 1 + if cp_world_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + + # print('in 
_conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape) + + input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(0, dim).contiguous() + if cp_rank == 0: + input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() + else: + input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous() + + tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ + torch.empty_like(input_) for _ in range(cp_world_size - 1) + ] + if cp_rank == 0: + input_ = torch.cat([input_first_kernel_, input_], dim=dim) + + tensor_list[cp_rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim).contiguous() + + # print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape) + + return output + + +def _pass_from_previous_rank(input_, dim, kernel_size): + # Bypass the function if kernel size is 1 + if kernel_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + cp_group_rank = get_vae_parallel_group_rank() + cp_world_size = get_vae_parallel_world_size() + + # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + global_rank = torch.distributed.get_rank() + global_world_size = torch.distributed.get_world_size() + + input_ = input_.transpose(0, dim) + + # pass from last rank + send_rank = global_rank + 1 + recv_rank = global_rank - 1 + if send_rank % cp_world_size == 0: + send_rank -= cp_world_size + if recv_rank % cp_world_size == cp_world_size - 1: + recv_rank += cp_world_size + + if cp_rank < cp_world_size - 1: + req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + if cp_rank > 0: + recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() + req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + + if cp_rank == 0: + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + else: + req_recv.wait() + input_ = torch.cat([recv_buffer, input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + + # print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + return input_ + + +def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=None): + # Bypass the function if kernel size is 1 + if kernel_size == 1: + return input_ + + group = get_vae_parallel_group() + cp_rank = get_vae_parallel_rank() + cp_group_rank = get_vae_parallel_group_rank() + cp_world_size = get_vae_parallel_world_size() + + # print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape) + + global_rank = torch.distributed.get_rank() + global_world_size = torch.distributed.get_world_size() + + input_ = input_.transpose(0, dim) + + # pass from last rank + send_rank = global_rank + 1 + recv_rank = global_rank - 1 + if send_rank % cp_world_size == 0: + send_rank -= cp_world_size + if recv_rank % cp_world_size == cp_world_size - 1: + recv_rank += cp_world_size + + # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() + # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group) + # req_recv.wait() + recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() + if cp_rank < cp_world_size - 1: + req_send = 
torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + if cp_rank > 0: + req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) + # req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) + + if cp_rank == 0: + if cache_padding is not None: + input_ = torch.cat([cache_padding.transpose(0, dim).to(input_.device), input_], dim=0) + else: + input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0) + else: + req_recv.wait() + input_ = torch.cat([recv_buffer, input_], dim=0) + + input_ = input_.transpose(0, dim).contiguous() + return input_ + + +def _drop_from_previous_rank(input_, dim, kernel_size): + input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) + return input_ + + +class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _conv_split(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _conv_gather(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _ConvolutionPassFromPreviousRank(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _pass_from_previous_rank(input_, dim, kernel_size) + + @staticmethod + def backward(ctx, grad_output): + return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None + + +class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function): + @staticmethod + def forward(ctx, input_, dim, kernel_size, cache_padding): + ctx.dim = dim + ctx.kernel_size = kernel_size + return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding) + + @staticmethod + def backward(ctx, grad_output): + return _drop_from_previous_rank(grad_output, ctx.dim, ctx.kernel_size), None, None, None + + +def conv_scatter_to_context_parallel_region(input_, dim, kernel_size): + return _ConvolutionScatterToContextParallelRegion.apply(input_, dim, kernel_size) + + +def conv_gather_from_context_parallel_region(input_, dim, kernel_size): + return _ConvolutionGatherFromContextParallelRegion.apply(input_, dim, kernel_size) + + +def conv_pass_from_last_rank(input_, dim, kernel_size): + return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size) + + +def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding): + return _FakeCPConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size, cache_padding) + + +class ContextParallelCausalConv3d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], stride=1, **kwargs): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + time_pad = time_kernel_size - 1 + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + 
self.width_pad = width_pad + self.time_pad = time_pad + self.time_kernel_size = time_kernel_size + self.temporal_dim = 2 + + stride = (stride, stride, stride) + dilation = (1, 1, 1) + self.conv = Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + self.cache_padding = None + + def forward(self, input_, clear_cache=True): + # if input_.shape[2] == 1: # handle image + # # first frame padding + # input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) + # else: + # input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size) + + # padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + # input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0) + + # output_parallel = self.conv(input_parallel) + # output = output_parallel + # return output + + input_parallel = fake_cp_pass_from_previous_rank( + input_, self.temporal_dim, self.time_kernel_size, self.cache_padding + ) + + del self.cache_padding + self.cache_padding = None + if not clear_cache: + cp_rank, cp_world_size = get_vae_parallel_rank(), get_vae_parallel_world_size() + global_rank = torch.distributed.get_rank() + if cp_world_size == 1: + self.cache_padding = ( + input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + ) + else: + if cp_rank == cp_world_size - 1: + torch.distributed.isend( + input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(), + global_rank + 1 - cp_world_size, + group=get_vae_parallel_group(), + ) + if cp_rank == 0: + recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous() + torch.distributed.recv( + recv_buffer, global_rank - 1 + cp_world_size, group=get_vae_parallel_group() + ) + self.cache_padding = recv_buffer.contiguous().detach().clone().cpu() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0) + + output_parallel = self.conv(input_parallel) + output = output_parallel + return output + + +class ContextParallelGroupNorm(torch.nn.GroupNorm): + def forward(self, input_): + gather_flag = input_.shape[2] > 1 + if gather_flag: + input_ = conv_gather_from_context_parallel_region(input_, dim=2, kernel_size=1) + output = super().forward(input_) + if gather_flag: + output = conv_scatter_to_context_parallel_region(output, dim=2, kernel_size=1) + return output + + +def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D + if gather: + return ContextParallelGroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + else: + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialNorm3D(nn.Module): + def __init__( + self, + f_channels, + zq_channels, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", + gather=False, + **norm_layer_params, + ): + super().__init__() + if gather: + self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels, **norm_layer_params) + else: + self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, **norm_layer_params) + # self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params) + if freeze_norm_layer: + for p in self.norm_layer.parameters: + p.requires_grad = False + + self.add_conv = add_conv + if add_conv: + self.conv = ContextParallelCausalConv3d( + chan_in=zq_channels, + chan_out=zq_channels, + kernel_size=3, + ) + + self.conv_y = 
ContextParallelCausalConv3d( + chan_in=zq_channels, + chan_out=f_channels, + kernel_size=1, + ) + self.conv_b = ContextParallelCausalConv3d( + chan_in=zq_channels, + chan_out=f_channels, + kernel_size=1, + ) + + def forward(self, f, zq, clear_fake_cp_cache=True): + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] + zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") + zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest") + zq = torch.cat([zq_first, zq_rest], dim=2) + else: + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") + + if self.add_conv: + zq = self.conv(zq, clear_cache=clear_fake_cp_cache) + + # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1) + norm_f = self.norm_layer(f) + # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1) + + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +def Normalize3D( + in_channels, + zq_ch, + add_conv, + gather=False, +): + return SpatialNorm3D( + in_channels, + zq_ch, + gather=gather, + freeze_norm_layer=False, + add_conv=add_conv, + num_groups=32, + eps=1e-6, + affine=True, + ) + + +class Upsample3D(nn.Module): + def __init__( + self, + in_channels, + with_conv, + compress_time=False, + ): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time and x.shape[2] > 1: + if x.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + else: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + + else: + # only interpolate 2D + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.with_conv: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class DownSample3D(nn.Module): + def __init__(self, in_channels, with_conv, compress_time=False, out_channels=None): + super().__init__() + self.with_conv = with_conv + if out_channels is None: + out_channels = in_channels + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time and x.shape[2] > 1: + h, w = x.shape[-2:] + x = rearrange(x, "b c t h w -> (b h w) c t") + + if x.shape[-1] % 2 == 1: + # split first frame + x_first, x_rest = x[..., 0], x[..., 1:] + + if x_rest.shape[-1] > 0: + x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) + else: + x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + x = rearrange(x, "(b h w) c t -> b c t h w", h=h, 
w=w) + + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.conv(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + else: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + +class ContextParallelResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + add_conv=False, + gather_norm=False, + normalization=Normalize, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = normalization( + in_channels, + zq_ch=zq_ch, + add_conv=add_conv, + gather=gather_norm, + ) + + self.conv1 = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=out_channels, + kernel_size=3, + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = normalization( + out_channels, + zq_ch=zq_ch, + add_conv=add_conv, + gather=gather_norm, + ) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = ContextParallelCausalConv3d( + chan_in=out_channels, + chan_out=out_channels, + kernel_size=3, + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=out_channels, + kernel_size=3, + ) + else: + self.nin_shortcut = Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, x, temb, zq=None, clear_fake_cp_cache=True): + h = x + + # if isinstance(self.norm1, torch.nn.GroupNorm): + # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) + if zq is not None: + h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + else: + h = self.norm1(h) + # if isinstance(self.norm1, torch.nn.GroupNorm): + # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) + + h = nonlinearity(h) + h = self.conv1(h, clear_cache=clear_fake_cp_cache) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] + + # if isinstance(self.norm2, torch.nn.GroupNorm): + # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) + if zq is not None: + h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + else: + h = self.norm2(h) + # if isinstance(self.norm2, torch.nn.GroupNorm): + # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) + + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h, clear_cache=clear_fake_cp_cache) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache) + else: + x = self.nin_shortcut(x) + + return x + h + + +class ContextParallelEncoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution 
= resolution + self.in_channels = in_channels + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + self.conv_in = ContextParallelCausalConv3d( + chan_in=in_channels, + chan_out=self.ch, + kernel_size=3, + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + temb_channels=self.temb_ch, + gather_norm=gather_norm, + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level < self.temporal_compress_level: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=True) + else: + down.downsample = DownSample3D(block_in, resamp_with_conv, compress_time=False) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + gather_norm=gather_norm, + ) + + self.mid.block_2 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + gather_norm=gather_norm, + ) + + # end + self.norm_out = Normalize(block_in, gather=gather_norm) + + self.conv_out = ContextParallelCausalConv3d( + chan_in=block_in, + chan_out=2 * z_channels if double_z else z_channels, + kernel_size=3, + ) + + def forward(self, x, **kwargs): + # timestep embedding + temb = None + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.block_2(h, temb) + + # end + # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) + h = self.norm_out(h) + # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) + + h = nonlinearity(h) + h = self.conv_out(h) + + return h + + +class ContextParallelDecoder3D(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # log2 of temporal_compress_times + self.temporal_compress_level = int(np.log2(temporal_compress_times)) + + if zq_ch is None: + zq_ch = z_channels + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z 
of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + self.conv_in = ContextParallelCausalConv3d( + chan_in=z_channels, + chan_out=block_in, + kernel_size=3, + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + + self.mid.block_2 = ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ContextParallelResnetBlock3D( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + zq_ch=zq_ch, + add_conv=add_conv, + normalization=Normalize3D, + gather_norm=gather_norm, + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level < self.num_resolutions - self.temporal_compress_level: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=False) + else: + up.upsample = Upsample3D(block_in, with_conv=resamp_with_conv, compress_time=True) + self.up.insert(0, up) + + self.norm_out = Normalize3D(block_in, zq_ch, add_conv=add_conv, gather=gather_norm) + + self.conv_out = ContextParallelCausalConv3d( + chan_in=block_in, + chan_out=out_ch, + kernel_size=3, + ) + + def forward(self, z, clear_fake_cp_cache=True, **kwargs): + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + t = z.shape[2] + # z to block_in + + zq = z + h = self.conv_in(z, clear_cache=clear_fake_cp_cache) + + # middle + h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, zq) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = nonlinearity(h) + h = self.conv_out(h, clear_cache=clear_fake_cp_cache) + + return h + + def get_last_layer(self): + return self.conv_out.conv.weight diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py new file mode 100644 index 0000000000..9f1f7606c2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) if use_num_upates else 
torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py new file mode 100644 index 0000000000..205bd4a941 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/regularizers.py @@ -0,0 +1,108 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + # x = self.mean + self.std * torch.randn(self.mean.shape).to( + # device=self.parameters.device + # ) + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +class AbstractRegularizer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + raise NotImplementedError() + + @abstractmethod + def get_trainable_parameters(self) -> Any: + raise NotImplementedError() + + +class IdentityRegularizer(AbstractRegularizer): + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, dict() + + def get_trainable_parameters(self) -> Any: + yield from () + + +def measure_perplexity(predicted_indices: torch.Tensor, num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]: + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + + +class DiagonalGaussianRegularizer(AbstractRegularizer): + def __init__(self, sample: bool = True): + super().__init__() + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py new file mode 100644 index 0000000000..8cab5c195b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py @@ -0,0 +1,23 @@ +import torch + + +class SafeConv3d(torch.nn.Conv3d): + def forward(self, input): + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [ + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + for i in range(1, len(input_chunks)) + ] + + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super(SafeConv3d, self).forward(input_chunk)) + output = torch.cat(output_chunks, dim=2) + return output + else: + return super(SafeConv3d, self).forward(input) \ No newline at end of file -- Gitee From d19b412b857e9dff80c251fec9dcb1f74b299472 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 23 Jan 2025 09:20:32 +0800 Subject: [PATCH 2/9] cleancode --- .../foundation/cogvideox-sat/arguments.py | 15 +- .../cogvideox-sat/diffusion_video.py | 33 +-- .../cogvideox-sat/dit_video_concat.py | 241 +++++++++--------- .../foundation/cogvideox-sat/sample_video.py | 28 +- .../cogvideox-sat/sgm/modules/attention.py | 117 +++++---- .../sgm/modules/autoencoding/temporal_ae.py | 44 ++-- .../sgm/modules/diffusionmodules/denoiser.py | 30 +-- .../diffusionmodules/denoiser_scaling.py | 26 +- .../diffusionmodules/denoiser_weighting.py | 4 +- .../modules/diffusionmodules/discretizer.py | 27 +- .../sgm/modules/diffusionmodules/lora.py | 44 ++-- .../sgm/modules/diffusionmodules/model.py | 161 ++++-------- .../modules/diffusionmodules/openaimodel.py | 200 +++++++-------- .../sgm/modules/diffusionmodules/sampling.py | 169 ++++++------ .../diffusionmodules/sampling_utils.py | 4 +- .../diffusionmodules/sigma_sampling.py | 21 +- .../sgm/modules/diffusionmodules/util.py | 63 ++--- .../sgm/modules/encoders/modules.py | 57 ++--- .../sgm/modules/video_attention.py | 88 +++---- .../foundation/cogvideox-sat/sgm/util.py | 14 +- .../cogvideox-sat/utils/cache_mgr.py | 174 ------------- .../cogvideox-sat/utils/group_coordinator.py | 54 ++-- .../cogvideox-sat/utils/parallel_mgr.py | 81 +++--- .../cogvideox-sat/utils/sampling_optm.py | 50 ++-- .../foundation/cogvideox-sat/utils/utils.py | 14 +- .../cogvideox-sat/vae_modules/autoencoder.py | 121 +++++---- 
.../cogvideox-sat/vae_modules/cp_enc_dec.py | 149 ++++++----- .../cogvideox-sat/vae_modules/utils.py | 6 +- 28 files changed, 890 insertions(+), 1145 deletions(-) delete mode 100644 MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py index 1844991a54..074eceb5b9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/arguments.py @@ -2,8 +2,6 @@ import argparse import os import torch import torch_npu -import json -import warnings import omegaconf from omegaconf import OmegaConf from sat.helpers import print_rank0 @@ -52,19 +50,19 @@ def add_inference_args(parser): help='batch size on a single GPU. batch-size * world_size = total batch_size.') group.add_argument('--mode', type=str, default='pretrain', - choices=['pretrain', # from_scratch / load ckpt for continue pretraining. - 'finetune', # finetuning, auto-warmup 100 iters, new exp name. - 'inference' # don't train. + choices=['pretrain', # from_scratch / load ckpt for continue pretraining. + 'finetune', # finetuning, auto-warmup 100 iters, new exp name. + 'inference' # don't train. ], help='what type of task to use, will influence auto-warmup, exp name, iteration') group.add_argument('--load', type=str, default=None, help='Path to a directory containing a model checkpoint.') group.add_argument('--seed', type=int, default=1234, help='random seed') - group.add_argument('--zero-stage', type=int, default=0, choices=[0, 1, 2, 3], - help='deepspeed ZeRO stage. 0 means no ZeRO.') + group.add_argument('--zero-stage', type=int, default=0, choices=[0, 1, 2, 3], + help='deepspeed ZeRO stage. 
0 means no ZeRO.') group.add_argument('--fp16', action='store_true', help='Run model in fp16 mode') group.add_argument('--bf16', action='store_true', help='Run model in bf16 mode') - + return parser @@ -100,6 +98,7 @@ def get_args(args_list=None, parser=None): return args + def process_config_to_args(args): """Fetch args from only --base""" diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py index 989feb8d0e..e9efbdb0db 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/diffusion_video.py @@ -1,11 +1,6 @@ -import random - import math -from typing import Any, Dict, List, Tuple, Union -from omegaconf import ListConfig -import torch.nn.functional as F +from typing import Dict, List, Tuple, Union -from sat.helpers import print_rank0 import torch from torch import nn @@ -17,9 +12,7 @@ from sgm.util import ( disabled_train, get_obj_from_str, instantiate_from_config, - log_txt_as_img, ) -import gc class SATVideoDiffusionEngine(nn.Module): @@ -83,10 +76,10 @@ class SATVideoDiffusionEngine(nn.Module): with torch.autocast("npu", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds): if isinstance(self.first_stage_model.decoder, VideoDecoder): - kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + kwargs = {"timesteps": len(z[n * n_samples: (n + 1) * n_samples])} else: kwargs = {} - out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs) + out = self.first_stage_model.decode(z[n * n_samples: (n + 1) * n_samples], **kwargs) all_out.append(out) out = torch.cat(all_out, dim=0) return out @@ -104,7 +97,7 @@ class SATVideoDiffusionEngine(nn.Module): all_out = [] with torch.autocast("npu", enabled=not self.disable_first_stage_autocast): for n in range(n_rounds): - out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples]) + out = self.first_stage_model.encode(x[n * n_samples: (n + 1) * n_samples]) all_out.append(out) z = torch.cat(all_out, dim=0) z = self.scale_factor * z @@ -112,21 +105,21 @@ class SATVideoDiffusionEngine(nn.Module): @torch.no_grad() def sample( - self, - cond: Dict, - uc: Union[Dict, None] = None, - batch_size: int = 16, - shape: Union[None, Tuple, List] = None, - prefix=None, - concat_images=None, - **kwargs, + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + prefix=None, + concat_images=None, + **kwargs, ): randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device) if hasattr(self, "seeded_noise"): randn = self.seeded_noise(randn) if prefix is not None: - randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1) + randn = torch.cat([prefix, randn[:, prefix.shape[1]:]], dim=1) scale = None scale_emb = None diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py index 6c2ac4e68c..a0678d1648 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py @@ -7,11 +7,11 @@ import torch import torch_npu from torch import nn import torch.nn.functional as F + torch.ops.load_library("./sgm/modules/libPTAExtensionOPS.so") from sat.model.base_model import BaseModel, non_conflict from sat.model.mixins import BaseMixin 
-from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default from sat.mpu.layers import ColumnParallelLinear from sgm.util import instantiate_from_config @@ -25,12 +25,12 @@ from sat.ops.layernorm import LayerNorm, RMSNorm class ImagePatchEmbeddingMixin(BaseMixin): def __init__( - self, - in_channels, - hidden_size, - patch_size, - bias=True, - text_hidden_size=None, + self, + in_channels, + hidden_size, + patch_size, + bias=True, + text_hidden_size=None, ): super().__init__() self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias) @@ -62,27 +62,29 @@ class ImagePatchEmbeddingMixin(BaseMixin): nn.init.constant_(self.proj.bias, 0) del self.transformer.word_embeddings + def attention_fn_fa_score(query_layer, key_layer, value_layer, attention_mask, attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs): # expand head dim to query dim, if necessary # only useful for multi-query attention - num_query_heads = query_layer.shape[1] # [b, np, s, hn] + num_query_heads = query_layer.shape[1] # [b, np, s, hn] attn_output = torch_npu.npu_fusion_attention(query_layer, key_layer, value_layer, - input_layout="BNSD", - scale=1 / math.sqrt(query_layer.shape[-1]), - head_num=num_query_heads - )[0] + input_layout="BNSD", + scale=1 / math.sqrt(query_layer.shape[-1]), + head_num=num_query_heads + )[0] return attn_output + def get_3d_sincos_pos_embed( - embed_dim, - grid_height, - grid_width, - t_size, - cls_token=False, - height_interpolation=1.0, - width_interpolation=1.0, - time_interpolation=1.0, + embed_dim, + grid_height, + grid_width, + t_size, + cls_token=False, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, ): """ grid_size: int of the grid height and width @@ -157,7 +159,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) + omega = 1.0 / 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product @@ -185,20 +187,20 @@ class Basic2DPositionEmbeddingMixin(BaseMixin): def reinit(self, parent_model=None): del self.transformer.position_embeddings pos_embed = get_2d_sincos_pos_embed(self.pos_embedding.shape[-1], self.height, self.width) - self.pos_embedding.data[:, -self.spatial_length :].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + self.pos_embedding.data[:, -self.spatial_length:].copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) class Basic3DPositionEmbeddingMixin(BaseMixin): def __init__( - self, - height, - width, - compressed_num_frames, - hidden_size, - text_length=0, - height_interpolation=1.0, - width_interpolation=1.0, - time_interpolation=1.0, + self, + height, + width, + compressed_num_frames, + hidden_size, + text_length=0, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, ): super().__init__() self.height = height @@ -233,7 +235,7 @@ class Basic3DPositionEmbeddingMixin(BaseMixin): ) pos_embed = torch.from_numpy(pos_embed).float() pos_embed = rearrange(pos_embed, "t n d -> (t n) d") - self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed) + self.pos_embedding.data[:, -self.num_patches:].copy_(pos_embed) def broadcat(tensors, dim=-1): @@ -257,16 +259,16 @@ def broadcat(tensors, dim=-1): class Rotary3DPositionEmbeddingMixin(BaseMixin): def __init__( - self, - height, - width, - compressed_num_frames, - 
hidden_size, - hidden_size_head, - text_length, - theta=10000, - rot_v=False, - learnable_pos_embed=False, + self, + height, + width, + compressed_num_frames, + hidden_size, + hidden_size_head, + text_length, + theta=10000, + rot_v=False, + learnable_pos_embed=False, ): super().__init__() self.rot_v = rot_v @@ -321,23 +323,22 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): @non_conflict def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - attention_dropout=None, - log_attention_weights=None, - scaling_attention_score=True, - old_impl=attention_fn_fa_score, - **kwargs, + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_fa_score, + **kwargs, ): - attention_fn_default = HOOKS_DEFAULT["attention_fn"] - query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :]) - key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :]) + query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:]) + key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:]) if self.rot_v: - value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :]) + value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:]) return attention_fn_fa_score( query_layer, @@ -376,14 +377,14 @@ def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs): class FinalLayerMixin(BaseMixin): def __init__( - self, - hidden_size, - time_embed_dim, - patch_size, - out_channels, - latent_width, - latent_height, - elementwise_affine, + self, + hidden_size, + time_embed_dim, + patch_size, + out_channels, + latent_width, + latent_height, + elementwise_affine, ): super().__init__() self.hidden_size = hidden_size @@ -393,12 +394,12 @@ class FinalLayerMixin(BaseMixin): self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True)) - self.spatial_length = latent_width * latent_height // patch_size**2 + self.spatial_length = latent_width * latent_height // patch_size ** 2 self.latent_width = latent_width self.latent_height = latent_height def final_forward(self, logits, **kwargs): - x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d) + x, emb = logits[:, kwargs["text_length"]:, :], kwargs["emb"] # x:(b,(t n),d) shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) @@ -447,16 +448,16 @@ class SwiGLUMixin(BaseMixin): class AdaLNMixin(BaseMixin): def __init__( - self, - width, - height, - hidden_size, - num_layers, - time_embed_dim, - compressed_num_frames, - qk_ln=True, - hidden_size_head=None, - elementwise_affine=True, + self, + width, + height, + hidden_size, + num_layers, + time_embed_dim, + compressed_num_frames, + qk_ln=True, + hidden_size_head=None, + elementwise_affine=True, ): super().__init__() self.num_layers = num_layers @@ -484,11 +485,11 @@ class AdaLNMixin(BaseMixin): ) def layer_forward( - self, - hidden_states, - mask, - *args, - **kwargs, + self, + hidden_states, + mask, + *args, + **kwargs, ): text_length = kwargs["text_length"] # hidden_states (b,(n_t+t*n_i),d) @@ -561,16 +562,16 @@ class AdaLNMixin(BaseMixin): @non_conflict def attention_fn( - self, - query_layer, - key_layer, - 
value_layer, - attention_mask, - attention_dropout=None, - log_attention_weights=None, - scaling_attention_score=True, - old_impl=attention_fn_fa_score, - **kwargs, + self, + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=None, + log_attention_weights=None, + scaling_attention_score=True, + old_impl=attention_fn_fa_score, + **kwargs, ): if self.qk_ln: query_layernorm = self.query_layernorm_list[kwargs["layer_id"]] @@ -595,39 +596,39 @@ str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bflo class DiffusionTransformer(BaseModel): def __init__( - self, - transformer_args, - num_frames, - time_compressed_rate, - latent_width, - latent_height, - patch_size, - in_channels, - out_channels, - hidden_size, - num_layers, - num_attention_heads, - elementwise_affine, - time_embed_dim=None, - num_classes=None, - modules={}, - input_time="adaln", - adm_in_channels=None, - parallel_output=True, - height_interpolation=1.0, - width_interpolation=1.0, - time_interpolation=1.0, - use_SwiGLU=False, - use_RMSNorm=False, - zero_init_y_embed=False, - **kwargs, + self, + transformer_args, + num_frames, + time_compressed_rate, + latent_width, + latent_height, + patch_size, + in_channels, + out_channels, + hidden_size, + num_layers, + num_attention_heads, + elementwise_affine, + time_embed_dim=None, + num_classes=None, + modules={}, + input_time="adaln", + adm_in_channels=None, + parallel_output=True, + height_interpolation=1.0, + width_interpolation=1.0, + time_interpolation=1.0, + use_SwiGLU=False, + use_RMSNorm=False, + zero_init_y_embed=False, + **kwargs, ): self.latent_width = latent_width self.latent_height = latent_height self.patch_size = patch_size self.num_frames = num_frames self.time_compressed_rate = time_compressed_rate - self.spatial_length = latent_width * latent_height // patch_size**2 + self.spatial_length = latent_width * latent_height // patch_size ** 2 self.in_channels = in_channels self.out_channels = out_channels self.hidden_size = hidden_size @@ -794,7 +795,7 @@ class DiffusionTransformer(BaseModel): x = torch.cat([x, concat_images], dim=2) assert (y is not None) == ( - self.num_classes is not None + self.num_classes is not None ), "must specify y if and only if the model is class-conditional" t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) emb = self.time_embed(t_emb) @@ -805,7 +806,7 @@ class DiffusionTransformer(BaseModel): y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0) emb = emb + self.label_emb(y) - kwargs["seq_length"] = t * h * w // (self.patch_size**2) + kwargs["seq_length"] = t * h * w // (self.patch_size ** 2) kwargs["images"] = x kwargs["emb"] = emb kwargs["encoder_outputs"] = context diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py index 2d845211df..cc9066079f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -2,7 +2,8 @@ import os import math import argparse from arguments import get_args -from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_state, get_sequence_parallel_rank +from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_state, \ + get_sequence_parallel_rank from sat.arguments import set_random_seed from typing import List, Union 
@@ -16,7 +17,6 @@ import numpy as np from einops import rearrange import torchvision.transforms as TT - from sat.model.base_model import get_model from sat.training.model_io import load_checkpoint @@ -125,7 +125,7 @@ def sampling_main(args, model_cls): else: parallel_config = ParallelConfig(sp_degree=1, use_cfg_parallel=False, world_size=args.world_size) init_parallel_env(parallel_config) - + rank = get_sequence_parallel_rank() if get_sequence_parallel_state() else 0 args.seed = args.seed + rank set_random_seed(args.seed) @@ -203,7 +203,7 @@ def sampling_main(args, model_cls): if args.image2video and image is not None: c["concat"] = image uc["concat"] = image - + first_stage_model = model.first_stage_model for index in range(args.batch_size): count = 5 @@ -215,15 +215,17 @@ def sampling_main(args, model_cls): data_simplification=False ) with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], #默认CPU、NPU均打开 - with_stack=True, #采集torch op的函数调用栈的开关,默认关闭 - record_shapes=True, #采集torch op的input shape和input type的开关,默认关闭 - profile_memory=True, #采集memory相关数据的开关,默认关闭 - schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1), - experimental_config=experimental_config, - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_unet") #导出tensorboard可呈现的数据形式 + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + # 默认CPU、NPU均打开 + with_stack=True, # 采集torch op的函数调用栈的开关,默认关闭 + record_shapes=True, # 采集torch op的input shape和input type的开关,默认关闭 + profile_memory=True, # 采集memory相关数据的开关,默认关闭 + schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1), + experimental_config=experimental_config, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_unet") + # 导出tensorboard可呈现的数据形式 ) as prof: - for i in range(count): + for i in range(count): samples_z = sample_func( c, uc=uc, @@ -268,7 +270,7 @@ def sampling_main(args, model_cls): ) if rank == 0: save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) - + finalize_parallel_env() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py index 7b7d7dfb38..46ce4753c3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/attention.py @@ -1,11 +1,10 @@ import math from inspect import isfunction -from typing import Any, Optional import torch +import torch_npu import torch.nn.functional as F from einops import rearrange, repeat -from packaging import version from torch import nn @@ -129,19 +128,19 @@ class SpatialSelfAttention(nn.Module): class CrossAttention(nn.Module): def __init__( - self, - query_dim, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - backend=None, + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, ): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head**-0.5 + self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) @@ -152,12 +151,12 @@ class CrossAttention(nn.Module): self.backend = backend def forward( - self, - x, - context=None, - mask=None, - additional_tokens=None, - n_times_crossframe_attn_in_self=0, + self, + x, + context=None, + mask=None, + 
additional_tokens=None, + n_times_crossframe_attn_in_self=0, ): h = self.heads @@ -181,9 +180,9 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) out = torch_npu.npu_fusion_attention(q, k, v, atten_mask=mask, - input_layout="BNSD", - scale=1 / math.sqrt(q.shape[-1]), - head_num=h)[0] + input_layout="BNSD", + scale=1 / math.sqrt(q.shape[-1]), + head_num=h)[0] del q, k, v out = rearrange(out, "b h n d -> b n (h d)", h=h) @@ -196,17 +195,17 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - disable_self_attn=False, - attn_mode="softmax", - sdp_backend=None, + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, ): super().__init__() self.disable_self_attn = disable_self_attn @@ -251,13 +250,13 @@ class BasicTransformerBlock(nn.Module): def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): x = ( - self.attn1( - self.norm1(x), - context=context if self.disable_self_attn else None, - additional_tokens=additional_tokens, - n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, - ) - + x + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0, + ) + + x ) x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x x = self.ff(self.norm3(x)) + x @@ -266,15 +265,15 @@ class BasicTransformerBlock(nn.Module): class BasicTransformerSingleLayerBlock(nn.Module): def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - attn_mode="softmax", + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", ): super().__init__() self.attn1 = CrossAttention( @@ -309,19 +308,19 @@ class SpatialTransformer(nn.Module): """ def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - disable_self_attn=False, - use_linear=False, - attn_type="softmax", - use_checkpoint=True, - # sdp_backend=SDPBackend.FLASH_ATTENTION - sdp_backend=None, + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, ): super().__init__() print(f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py index 890bcfa0c9..d2ae0ebfc2 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/autoencoding/temporal_ae.py @@ -15,14 +15,14 @@ from sgm.util import partialclass class VideoResBlock(ResnetBlock): def __init__( - self, - out_channels, - *args, - dropout=0.0, - video_kernel_size=3, - alpha=0.0, - 
merge_strategy="learned", - **kwargs, + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, ): super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) if video_kernel_size is None: @@ -159,7 +159,7 @@ class VideoBlock(AttnBlock): return x_in + x def get_alpha( - self, + self, ): if self.merge_strategy == "fixed": return self.mix_factor @@ -170,11 +170,11 @@ class VideoBlock(AttnBlock): def make_time_attn( - in_channels, - attn_type="vanilla", - attn_kwargs=None, - alpha: float = 0, - merge_strategy: str = "learned", + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", ): if attn_type == "vanilla": assert attn_kwargs is None @@ -192,20 +192,20 @@ class VideoDecoder(Decoder): available_time_modes = ["all", "conv-only", "attn-only"] def __init__( - self, - *args, - video_kernel_size: Union[int, list] = 3, - alpha: float = 0.0, - merge_strategy: str = "learned", - time_mode: str = "conv-only", - **kwargs, + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, ): self.video_kernel_size = video_kernel_size self.alpha = alpha self.merge_strategy = merge_strategy self.time_mode = time_mode assert ( - self.time_mode in self.available_time_modes + self.time_mode in self.available_time_modes ), f"time_mode parameter has to be in {self.available_time_modes}" super().__init__(*args, **kwargs) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py index 3cc01e36f8..3b192ed410 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser.py @@ -1,4 +1,4 @@ -from typing import Dict, Union +from typing import Dict import torch import torch.nn as nn @@ -23,12 +23,12 @@ class Denoiser(nn.Module): return self.weighting(sigma) def forward( - self, - network: nn.Module, - input: torch.Tensor, - sigma: torch.Tensor, - cond: Dict, - **additional_model_inputs, + self, + network: nn.Module, + input: torch.Tensor, + sigma: torch.Tensor, + cond: Dict, + **additional_model_inputs, ) -> torch.Tensor: sigma = self.possibly_quantize_sigma(sigma) sigma_shape = sigma.shape @@ -40,14 +40,14 @@ class Denoiser(nn.Module): class DiscreteDenoiser(Denoiser): def __init__( - self, - weighting_config, - scaling_config, - num_idx, - discretization_config, - do_append_zero=False, - quantize_c_noise=True, - flip=True, + self, + weighting_config, + scaling_config, + num_idx, + discretization_config, + do_append_zero=False, + quantize_c_noise=True, + flip=True, ): super().__init__(weighting_config, scaling_config) sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py index 2cb9643014..7f803db1e0 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -1,5 
+1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Tuple +from typing import Tuple import torch @@ -15,9 +15,9 @@ class EDMScaling: self.sigma_data = sigma_data def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 c_noise = 0.25 * sigma.log() return c_skip, c_out, c_in, c_noise @@ -26,35 +26,35 @@ class EpsScaling: def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = torch.ones_like(sigma, device=sigma.device) c_out = -sigma - c_in = 1 / (sigma**2 + 1.0) ** 0.5 + c_in = 1 / (sigma ** 2 + 1.0) ** 0.5 c_noise = sigma.clone() return c_skip, c_out, c_in, c_noise class VScaling: def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = 1.0 / (sigma**2 + 1.0) - c_out = -sigma / (sigma**2 + 1.0) ** 0.5 - c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_skip = 1.0 / (sigma ** 2 + 1.0) + c_out = -sigma / (sigma ** 2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma ** 2 + 1.0) ** 0.5 c_noise = sigma.clone() return c_skip, c_out, c_in, c_noise class VScalingWithEDMcNoise(DenoiserScaling): def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = 1.0 / (sigma**2 + 1.0) - c_out = -sigma / (sigma**2 + 1.0) ** 0.5 - c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_skip = 1.0 / (sigma ** 2 + 1.0) + c_out = -sigma / (sigma ** 2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma ** 2 + 1.0) ** 0.5 c_noise = 0.25 * sigma.log() return c_skip, c_out, c_in, c_noise class VideoScaling: # similar to VScaling def __call__( - self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs + self, alphas_cumprod_sqrt: torch.Tensor, **additional_model_inputs ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: c_skip = alphas_cumprod_sqrt - c_out = -((1 - alphas_cumprod_sqrt**2) ** 0.5) + c_out = -((1 - alphas_cumprod_sqrt ** 2) ** 0.5) c_in = torch.ones_like(alphas_cumprod_sqrt, device=alphas_cumprod_sqrt.device) c_noise = additional_model_inputs["idx"].clone() return c_skip, c_out, c_in, c_noise diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py index b8b03ca58f..18285bb095 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -11,7 +11,7 @@ class EDMWeighting: self.sigma_data = sigma_data def __call__(self, sigma): - return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + return (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 class VWeighting(EDMWeighting): @@ -21,4 +21,4 @@ class VWeighting(EDMWeighting): class EpsWeighting: def __call__(self, sigma): - return sigma**-2.0 + return sigma ** -2.0 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py index a86b7d8dfb..a579eab64d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/discretizer.py @@ -45,10 +45,10 @@ class EDMDiscretization(Discretization): class LegacyDDPMDiscretization(Discretization): def __init__( - self, - linear_start=0.00085, - linear_end=0.0120, - num_timesteps=1000, + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, ): super().__init__() self.num_timesteps = num_timesteps @@ -73,13 +73,13 @@ class LegacyDDPMDiscretization(Discretization): class ZeroSNRDDPMDiscretization(Discretization): def __init__( - self, - linear_start=0.00085, - linear_end=0.0120, - num_timesteps=1000, - shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) - keep_start=False, - post_shift=False, + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + shift_scale=1.0, # noise schedule t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale) + keep_start=False, + post_shift=False, ): super().__init__() if keep_start and not post_shift: @@ -117,8 +117,9 @@ class ZeroSNRDDPMDiscretization(Discretization): if self.post_shift: alphas_cumprod_sqrt = ( - alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2) - ) ** 0.5 + alphas_cumprod_sqrt ** 2 / ( + self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt ** 2) + ) ** 0.5 if return_idx: return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py index 7ccd72a19f..ff6371587d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/lora.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import List, Optional, Set, Type import torch import torch.nn.functional as F @@ -50,7 +50,7 @@ class LoRALinearLayer(nn.Module): class LoRAConv2dLayer(nn.Module): def __init__( - self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None ): super().__init__() @@ -215,8 +215,8 @@ class LoRACompatibleLinear(nn.Linear): def _find_children( - model, - search_class: List[Type[nn.Module]] = [nn.Linear], + model, + search_class: List[Type[nn.Module]] = [nn.Linear], ): """ Find all modules of a certain class (or union of classes). 
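For context on the lora.py hunks around this point: _find_children and _find_modules_v2 only locate the nn.Linear / nn.Conv2d children that inject_trainable_lora_extended later swaps for LoRA-compatible replacements. The Python sketch below illustrates that traversal pattern only; it is a simplified stand-in under stated assumptions, not the repository's implementation, and make_lora_linear in the usage comment is a hypothetical helper.

    import torch.nn as nn

    def find_linear_children(model: nn.Module, ancestor_class_names=None):
        # Pick the ancestor modules whose subtrees should be searched; when no
        # ancestor classes are given, search the whole model.
        if ancestor_class_names is None:
            ancestors = [model]
        else:
            ancestors = [m for m in model.modules() if m.__class__.__name__ in ancestor_class_names]
        for ancestor in ancestors:
            for full_name, child in ancestor.named_modules():
                if full_name and isinstance(child, nn.Linear):
                    # Resolve the direct parent so the child can be replaced in place.
                    *parent_path, attr_name = full_name.split(".")
                    parent = ancestor.get_submodule(".".join(parent_path)) if parent_path else ancestor
                    yield parent, attr_name, child

    # Usage sketch: swap each located layer for a LoRA-injected replacement, e.g.
    # for parent, name, child in find_linear_children(model, {"BasicTransformerBlock"}):
    #     setattr(parent, name, make_lora_linear(child))  # make_lora_linear is hypothetical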
@@ -232,15 +232,15 @@ def _find_children( def _find_modules_v2( - model, - ancestor_class: Optional[Set[str]] = None, - search_class: List[Type[nn.Module]] = [nn.Linear], - exclude_children_of: Optional[List[Type[nn.Module]]] = [ - LoRACompatibleLinear, - LoRACompatibleConv, - LoRALinearLayer, - LoRAConv2dLayer, - ], + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoRACompatibleLinear, + LoRACompatibleConv, + LoRALinearLayer, + LoRAConv2dLayer, + ], ): """ Find all modules of a certain class (or union of classes) that are direct or @@ -284,13 +284,13 @@ _find_modules = _find_modules_v2 def inject_trainable_lora_extended( - model: nn.Module, - target_replace_module: Set[str] = None, - rank: int = 4, - scale: float = 1.0, + model: nn.Module, + target_replace_module: Set[str] = None, + rank: int = 4, + scale: float = 1.0, ): for _module, name, _child_module in _find_modules( - model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] ): if _child_module.__class__ == nn.Linear: weight = _child_module.weight @@ -350,12 +350,12 @@ def inject_trainable_lora_extended( def update_lora_scale( - model: nn.Module, - target_module: Set[str] = None, - scale: float = 1.0, + model: nn.Module, + target_module: Set[str] = None, + scale: float = 1.0, ): for _module, name, _child_module in _find_modules( - model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv] + model, target_module, search_class=[LoRACompatibleLinear, LoRACompatibleConv] ): _child_module.scale = scale diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py index 2aa75acb1b..0bd79d2a5e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/model.py @@ -1,13 +1,11 @@ -# pytorch_diffusion + derived encoder decoder import math -from typing import Any, Callable, Optional +from typing import Callable import numpy as np import torch import torch_npu import torch.nn as nn from einops import rearrange -from packaging import version from ..attention import LinearAttention @@ -76,13 +74,13 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, ): super().__init__() self.in_channels = in_channels @@ -165,55 +163,6 @@ class AttnBlock(nn.Module): return x + h_ -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - 
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.attention_op: Optional[Any] = None - - def attention(self, h_: torch.Tensor) -> torch.Tensor: - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C) - return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) - - def forward(self, x, **kwargs): - h_ = x - h_ = self.attention(h_) - h_ = self.proj_out(h_) - return x + h_ - - def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in [ "vanilla", @@ -232,20 +181,20 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): class Model(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", ): super().__init__() if use_linear_attn: @@ -401,22 +350,22 @@ class Model(nn.Module): class Encoder(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs, + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, ): super().__init__() if use_linear_attn: @@ -516,23 +465,23 @@ class Encoder(nn.Module): class Decoder(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs, + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, ): super().__init__() if use_linear_attn: diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py index f61ebe5812..10a7a0da30 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/openaimodel.py @@ -2,7 +2,7 @@ import os import math from abc import abstractmethod from functools import 
partial -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional import numpy as np import torch as th @@ -40,14 +40,14 @@ class AttentionPool2d(nn.Module): """ def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -83,13 +83,13 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ def forward( - self, - x: th.Tensor, - emb: th.Tensor, - context: Optional[th.Tensor] = None, - image_only_indicator: Optional[th.Tensor] = None, - time_context: Optional[int] = None, - num_video_frames: Optional[int] = None, + self, + x: th.Tensor, + emb: th.Tensor, + context: Optional[th.Tensor] = None, + image_only_indicator: Optional[th.Tensor] = None, + time_context: Optional[int] = None, + num_video_frames: Optional[int] = None, ): from ...modules.diffusionmodules.video_model import VideoResBlock @@ -222,20 +222,20 @@ class ResBlock(TimestepBlock): """ def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - kernel_size=3, - exchange_temb_dims=False, - skip_t_emb=False, + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, ): super().__init__() self.channels = channels @@ -353,12 +353,12 @@ class AttentionBlock(nn.Module): """ def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, ): super().__init__() self.channels = channels @@ -366,7 +366,7 @@ class AttentionBlock(nn.Module): self.num_heads = num_heads else: assert ( - channels % num_head_channels == 0 + channels % num_head_channels == 0 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint @@ -413,7 +413,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. 
- matmul_ops = 2 * b * (num_spatial**2) * c + matmul_ops = 2 * b * (num_spatial ** 2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -524,44 +524,44 @@ class UNetModel(nn.Module): """ def __init__( - self, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - disable_self_attentions=None, - num_attention_blocks=None, - disable_middle_self_attn=False, - use_linear_in_transformer=False, - spatial_transformer_attn_type="softmax", - adm_in_channels=None, - use_fairscale_checkpoint=False, - offload_to_cpu=False, - transformer_depth_middle=None, - dtype="fp32", - lora_init=False, - lora_rank=4, - lora_scale=1.0, - lora_weight_path=None, + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + spatial_transformer_attn_type="softmax", + adm_in_channels=None, + use_fairscale_checkpoint=False, + offload_to_cpu=False, + transformer_depth_middle=None, + dtype="fp32", + lora_init=False, + lora_rank=4, + lora_scale=1.0, + lora_weight_path=None, ): super().__init__() from omegaconf.listconfig import ListConfig @@ -570,7 +570,7 @@ class UNetModel(nn.Module): if use_spatial_transformer: assert ( - context_dim is not None + context_dim is not None ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." if context_dim is not None: @@ -942,7 +942,7 @@ class UNetModel(nn.Module): print(f"loading lora from {ckpt_path}") sd = th.load(ckpt_path)["module"] sd = { - key[len("model.diffusion_model") :]: sd[key] for key in sd if key.startswith("model.diffusion_model") + key[len("model.diffusion_model"):]: sd[key] for key in sd if key.startswith("model.diffusion_model") } self.load_state_dict(sd, strict=False) @@ -975,7 +975,7 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. 
""" assert (y is not None) == ( - self.num_classes is not None + self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype) @@ -1014,28 +1014,28 @@ class EncoderUNetModel(nn.Module): """ def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs, + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, ): super().__init__() @@ -1219,7 +1219,6 @@ class EncoderUNetModel(nn.Module): if __name__ == "__main__": - class Dummy(nn.Module): def __init__(self, in_channels=3, model_channels=64): super().__init__() @@ -1227,6 +1226,7 @@ if __name__ == "__main__": [TimestepEmbedSequential(conv_nd(2, in_channels, model_channels, 3, padding=1))] ) + model = UNetModel( use_checkpoint=True, image_size=64, diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py index 55e889bfd3..b385ca8431 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py @@ -1,7 +1,3 @@ -""" -Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py -""" - from typing import Dict, Union import torch @@ -16,7 +12,6 @@ from ...modules.diffusionmodules.sampling_utils import ( to_sigma, ) from ...util import append_dims, default, instantiate_from_config -from ...util import SeededNoise from .guiders import DynamicCFG @@ -25,12 +20,12 @@ DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider class BaseDiffusionSampler: def __init__( - self, - discretization_config: Union[Dict, ListConfig, OmegaConf], - num_steps: Union[int, None] = None, - guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, - verbose: bool = False, - device: str = "npu", + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: str = "npu", ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) @@ -95,7 +90,7 @@ class EDMSampler(SingleStepDiffusionSampler): sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) d = to_d(x, sigma_hat, denoised) @@ -110,7 +105,7 @@ class EDMSampler(SingleStepDiffusionSampler): for i in 
self.get_sigma_gen(num_sigmas): gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) x = self.sampler_step( s_in * sigmas[i], @@ -134,7 +129,7 @@ class DDIMSampler(SingleStepDiffusionSampler): def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, s_noise=0.0): denoised = self.denoise(x, denoiser, sigma, cond, uc) d = to_d(x, sigma, denoised) - dt = append_dims(next_sigma * (1 - s_noise**2) ** 0.5 - sigma, x.ndim) + dt = append_dims(next_sigma * (1 - s_noise ** 2) ** 0.5 - sigma, x.ndim) euler_step = x + dt * d + s_noise * append_dims(next_sigma, x.ndim) * torch.randn_like(x) @@ -198,10 +193,10 @@ class AncestralSampler(SingleStepDiffusionSampler): class LinearMultistepSampler(BaseDiffusionSampler): def __init__( - self, - order=4, - *args, - **kwargs, + self, + order=4, + *args, + **kwargs, ): super().__init__(*args, **kwargs) @@ -319,15 +314,15 @@ class DPMPP2MSampler(BaseDiffusionSampler): return mult1, mult2 def sampler_step( - self, - old_denoised, - previous_sigma, - sigma, - next_sigma, - denoiser, - x, - cond, - uc=None, + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) @@ -390,15 +385,15 @@ class SDEDPMPP2MSampler(BaseDiffusionSampler): return mult1, mult2 def sampler_step( - self, - old_denoised, - previous_sigma, - sigma, - next_sigma, - denoiser, - x, - cond, - uc=None, + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) @@ -461,7 +456,7 @@ class SdeditEDMSampler(EulerEDMSampler): x = image + randn_unit * append_dims(s_in * sigmas[i], len(randn_unit.shape)) gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) x = self.sampler_step( s_in * sigmas[i], @@ -515,30 +510,30 @@ class VideoDDIMSampler(BaseDiffusionSampler): ).to(torch.float32) if isinstance(self.guider, DynamicCFG): denoised = self.guider( - denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, step_index=self.num_steps - timestep, scale=scale + denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, step_index=self.num_steps - timestep, scale=scale ) else: - denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt**2) ** 0.5, scale=scale) + denoised = self.guider(denoised, (1 - alpha_cumprod_sqrt ** 2) ** 0.5, scale=scale) return denoised def sampler_step( - self, - alpha_cumprod_sqrt, - next_alpha_cumprod_sqrt, - denoiser, - x, - cond, - uc=None, - idx=None, - timestep=None, - scale=None, - scale_emb=None, + self, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, ): denoised = self.denoise( x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb ).to(torch.float32) - a_t = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + a_t = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 b_t = next_alpha_cumprod_sqrt - alpha_cumprod_sqrt * a_t x = append_dims(a_t, x.ndim) * x + append_dims(b_t, x.ndim) * denoised @@ -568,14 +563,14 @@ class VideoDDIMSampler(BaseDiffusionSampler): class 
VPSDEDPMPP2MSampler(VideoDDIMSampler): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt**2 + alpha_cumprod = alpha_cumprod_sqrt ** 2 lamb = ((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() h_last = lamb - lamb_previous r = h_last / h @@ -584,7 +579,7 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 * (-h).exp() + mult1 = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 * (-h).exp() mult2 = (-2 * h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: @@ -595,19 +590,19 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): return mult1, mult2 def sampler_step( - self, - old_denoised, - previous_alpha_cumprod_sqrt, - alpha_cumprod_sqrt, - next_alpha_cumprod_sqrt, - denoiser, - x, - cond, - uc=None, - idx=None, - timestep=None, - scale=None, - scale_emb=None, + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, ): denoised = self.denoise( x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb @@ -622,7 +617,7 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): append_dims(mult, x.ndim) for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) ] - mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + mult_noise = append_dims((1 - next_alpha_cumprod_sqrt ** 2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: @@ -651,9 +646,9 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): noised_prefix_frames = alpha_cumprod_sqrt[i] * prefix_frames + rd * append_dims( s_in * (1 - alpha_cumprod_sqrt[i] ** 2) ** 0.5, len(prefix_frames.shape) ) - x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) else: - x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) x, old_denoised = self.sampler_step( old_denoised, None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], @@ -670,21 +665,21 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): ) if self.fixed_frames > 0: - x = torch.cat([prefix_frames, x[:, self.fixed_frames :]], dim=1) + x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) return x class VPODEDPMPP2MSampler(VideoDDIMSampler): def get_variables(self, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt=None): - alpha_cumprod = alpha_cumprod_sqrt**2 + alpha_cumprod = alpha_cumprod_sqrt ** 2 lamb = 
((alpha_cumprod / (1 - alpha_cumprod)) ** 0.5).log() - next_alpha_cumprod = next_alpha_cumprod_sqrt**2 + next_alpha_cumprod = next_alpha_cumprod_sqrt ** 2 lamb_next = ((next_alpha_cumprod / (1 - next_alpha_cumprod)) ** 0.5).log() h = lamb_next - lamb if previous_alpha_cumprod_sqrt is not None: - previous_alpha_cumprod = previous_alpha_cumprod_sqrt**2 + previous_alpha_cumprod = previous_alpha_cumprod_sqrt ** 2 lamb_previous = ((previous_alpha_cumprod / (1 - previous_alpha_cumprod)) ** 0.5).log() h_last = lamb - lamb_previous r = h_last / h @@ -693,7 +688,7 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): return h, None, lamb, lamb_next def get_mult(self, h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt): - mult1 = ((1 - next_alpha_cumprod_sqrt**2) / (1 - alpha_cumprod_sqrt**2)) ** 0.5 + mult1 = ((1 - next_alpha_cumprod_sqrt ** 2) / (1 - alpha_cumprod_sqrt ** 2)) ** 0.5 mult2 = (-h).expm1() * next_alpha_cumprod_sqrt if previous_alpha_cumprod_sqrt is not None: @@ -704,17 +699,17 @@ class VPODEDPMPP2MSampler(VideoDDIMSampler): return mult1, mult2 def sampler_step( - self, - old_denoised, - previous_alpha_cumprod_sqrt, - alpha_cumprod_sqrt, - next_alpha_cumprod_sqrt, - denoiser, - x, - cond, - uc=None, - idx=None, - timestep=None, + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, ): denoised = self.denoise(x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx).to(torch.float32) if idx == 1: diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py index 9bb1fa8296..0d1b106741 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling_utils.py @@ -137,9 +137,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.0): return sigma_to, 0.0 sigma_up = torch.minimum( sigma_to, - eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5, ) - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 return sigma_down, sigma_up diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py index 770de4254e..009de3cc03 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sigma_sampling.py @@ -1,8 +1,6 @@ import torch import torch.distributed -from sat import mpu - from ...util import default, instantiate_from_config @@ -20,29 +18,18 @@ class DiscreteSampling: def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, uniform_sampling=False): self.num_idx = num_idx self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip) - world_size = mpu.get_data_parallel_world_size() self.uniform_sampling = uniform_sampling if self.uniform_sampling: - i = 1 - while True: - if world_size % i != 0 or num_idx % (world_size // i) != 0: - i += 1 - 
else: - self.group_num = world_size // i - break - - assert self.group_num > 0 - assert world_size % self.group_num == 0 - self.group_width = world_size // self.group_num # the number of rank in one group - self.sigma_interval = self.num_idx // self.group_num + self.group_num = 1 + self.group_width = 1 + self.sigma_interval = self.num_idx def idx_to_sigma(self, idx): return self.sigmas[idx] def __call__(self, n_samples, rand=None, return_idx=False): if self.uniform_sampling: - rank = mpu.get_data_parallel_rank() - group_index = rank // self.group_width + group_index = 0 idx = default( rand, torch.randint( diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py index dcd20e8e7f..82a9dff8f2 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/util.py @@ -1,14 +1,3 @@ -""" -adopted from -https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -and -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -and -https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py - -thanks! -""" - import math from typing import Optional @@ -18,13 +7,13 @@ from einops import rearrange, repeat def make_beta_schedule( - schedule, - n_timestep, - linear_start=1e-4, - linear_end=2e-2, + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, ): if schedule == "linear": - betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2 + betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 return betas.numpy() @@ -67,13 +56,13 @@ def mixed_checkpoint(func, inputs: dict, params, flag): class MixedCheckpointFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - run_function, - length_tensors, - length_non_tensors, - tensor_keys, - non_tensor_keys, - *args, + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, ): ctx.end_tensors = length_tensors ctx.end_non_tensors = length_tensors + length_non_tensors @@ -86,10 +75,10 @@ class MixedCheckpointFunction(torch.autograd.Function): ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))} ctx.input_non_tensors = { - key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])) + key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors: ctx.end_non_tensors])) } ctx.run_function = run_function - ctx.input_params = list(args[ctx.end_non_tensors :]) + ctx.input_params = list(args[ctx.end_non_tensors:]) with torch.no_grad(): output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors) @@ -117,10 +106,10 @@ class MixedCheckpointFunction(torch.autograd.Function): del ctx.input_params del output_tensors return ( - (None, None, None, None, None) - + input_grads[: ctx.end_tensors] - + (None,) * (ctx.end_non_tensors - ctx.end_tensors) - + input_grads[ctx.end_tensors :] + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors:] ) @@ -282,10 +271,10 @@ 
class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] def __init__( - self, - alpha: float, - merge_strategy: str = "learned_with_images", - rearrange_pattern: str = "b t -> (b t) 1 1", + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", ): super().__init__() self.merge_strategy = merge_strategy @@ -318,10 +307,10 @@ class AlphaBlender(nn.Module): return alpha def forward( - self, - x_spatial: torch.Tensor, - x_temporal: torch.Tensor, - image_only_indicator: Optional[torch.Tensor] = None, + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: alpha = self.get_alpha(image_only_indicator) x = alpha.to(x_spatial.dtype) * x_spatial + (1.0 - alpha).to(x_spatial.dtype) * x_temporal diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py index a85159d20f..b4fd8ccec1 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/encoders/modules.py @@ -1,25 +1,15 @@ -import math -from contextlib import nullcontext -from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union -import kornia import numpy as np import torch import torch.nn as nn -from einops import rearrange, repeat from omegaconf import ListConfig -from torch.utils.checkpoint import checkpoint from transformers import ( T5EncoderModel, T5Tokenizer, ) from ...util import ( - append_dims, - autocast, - count_params, - default, disabled_train, expand_dims_like, instantiate_from_config, @@ -126,12 +116,12 @@ class GeneralConditioner(nn.Module): return batch def get_single_embedding( - self, - embedder, - batch, - output, - cond_or_not: Optional[np.ndarray] = None, - force_zero_embeddings: Optional[List] = None, + self, + embedder, + batch, + output, + cond_or_not: Optional[np.ndarray] = None, + force_zero_embeddings: Optional[List] = None, ): with torch.no_grad(): if hasattr(embedder, "input_key") and (embedder.input_key is not None): @@ -153,19 +143,20 @@ class GeneralConditioner(nn.Module): if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: if cond_or_not is None: emb = ( - expand_dims_like( - torch.bernoulli((1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), - emb, - ) - * emb + expand_dims_like( + torch.bernoulli( + (1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)), + emb, + ) + * emb ) else: emb = ( - expand_dims_like( - torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), - emb, - ) - * emb + expand_dims_like( + torch.tensor(1 - cond_or_not, dtype=emb.dtype, device=emb.device), + emb, + ) + * emb ) if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings: emb = torch.zeros_like(emb) @@ -229,12 +220,12 @@ class FrozenT5Embedder(AbstractEmbModel): """Uses the T5 transformer encoder for text""" def __init__( - self, - model_dir="google/t5-v1_1-xxl", - device="npu", - max_length=77, - freeze=True, - cache_dir=None, + self, + model_dir="google/t5-v1_1-xxl", + device="npu", + max_length=77, + freeze=True, + cache_dir=None, ): super().__init__() if model_dir != "google/t5-v1_1-xxl": diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py index bb6c26b378..d742ffffda 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/video_attention.py @@ -1,5 +1,5 @@ -import torch - +from einops import rearrange, repeat +from typing import Optional from ..modules.attention import * from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding @@ -14,21 +14,21 @@ class TimeMixSequential(nn.Sequential): class VideoTransformerBlock(nn.Module): def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - timesteps=None, - ff_in=False, - inner_dim=None, - attn_mode="softmax", - disable_self_attn=False, - disable_temporal_crossattention=False, - switch_temporal_ca_to_sa=False, + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, ): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -134,27 +134,27 @@ str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bflo class SpatialVideoTransformer(SpatialTransformer): def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - use_linear=False, - context_dim=None, - use_spatial_context=False, - timesteps=None, - merge_strategy: str = "fixed", - merge_factor: float = 0.5, - time_context_dim=None, - ff_in=False, - checkpoint=False, - time_depth=1, - attn_mode="softmax", - disable_self_attn=False, - disable_temporal_crossattention=False, - max_time_embed_period: int = 10000, - dtype="fp32", + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype="fp32", ): super().__init__( in_channels, @@ -217,12 +217,12 @@ class SpatialVideoTransformer(SpatialTransformer): self.dtype = str_to_dtype[dtype] def forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - time_context: Optional[torch.Tensor] = None, - timesteps: Optional[int] = None, - image_only_indicator: Optional[torch.Tensor] = None, + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, ) -> torch.Tensor: _, _, h, w = x.shape x_in = x diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py index b549bcd621..30760bd3e7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/util.py @@ -13,7 +13,7 @@ from safetensors.torch import load_file as load_safetensors class SafeConv3d(torch.nn.Conv3d): def forward(self, input): - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 
1024**3 + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 if memory_count > 2: # print(f"WARNING: Conv3d with {memory_count:.2f}GB") kernel_size = self.kernel_size[0] @@ -21,7 +21,7 @@ class SafeConv3d(torch.nn.Conv3d): input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks)) ] @@ -75,9 +75,9 @@ def is_power_of_two(n): def autocast(f, enabled=True): def do_autocast(*args, **kwargs): with torch.npu.amp.autocast( - enabled=enabled, - dtype=torch.get_autocast_gpu_dtype(), - cache_enabled=torch.is_autocast_cache_enabled(), + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), ): return f(*args, **kwargs) @@ -102,7 +102,7 @@ def log_txt_as_img(wh, xc, size=10): text_seq = xc[bi][0] else: text_seq = xc[bi] - lines = "\n".join(text_seq[start : start + nc] for start in range(0, len(text_seq), nc)) + lines = "\n".join(text_seq[start: start + nc] for start in range(0, len(text_seq), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -303,7 +303,7 @@ class SeededNoise: self.weights = weights weight_square_sum = 0 for weight in weights: - weight_square_sum += weight**2 + weight_square_sum += weight ** 2 self.weight_square_sum_sqrt = sqrt(weight_square_sum) self.cnt = 0 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py deleted file mode 100644 index 0cc69f00c3..0000000000 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/cache_mgr.py +++ /dev/null @@ -1,174 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, asdict -from typing import Union - -import torch -import torch.nn as nn - - -class CacheConfig(): - def __init__(self, method=None): - self.method = method - - -class CacheAgentConfig(CacheConfig): - """ - The DitCache Config. - """ - def __init__(self, policy_dir: str): - """ - Args: - policy_dir: The file containing the policy. - """ - super().__init__(method="CacheAgent") - self.policy_dir = policy_dir - - -class DitCacheConfig(CacheConfig): - """ - The DitCache Config. - """ - def __init__(self, step_start: int, step_interval: int, block_start: int, num_blocks: int): - """ - Args: - step_start: The starting step for caching. - step_interval: The interval at which caching should occur. - block_start: The starting block index for caching. - num_blocks: The number of blocks to cache. 
- """ - super().__init__(method="DitCache") - self.step_start = step_start - self.step_interval = step_interval - self.block_start = block_start - self.num_blocks = num_blocks - - -class CacheManager: - """ - The CacheManager class is interface to manage the cache algorithm. - """ - def __init__( - self, - config:CacheConfig - ): - """ - Args: - config: The configuration for the cache algorithm. - """ - if isinstance(config, CacheConfig): - self.method = config.method - self.cache_cls = Cache_cls[self.method](**vars(config)) - - def __call__(self, block, time_step, block_idx, hidden_states, *args, **kwargs - ): - """ - Args: - block: The block in the DiT module. - time_step: The current time step. - block_idx: The index of the block. - hidden_states: The hidden states. - *args: Additional arguments. - **kwargs: Additional keyword arguments. - """ - if not self._use_cache(time_step, block_idx): - old_hidden_states = hidden_states - if isinstance(block, list): - for blk in block: - hidden_states = blk(hidden_states, *args, **kwargs) - else: - hidden_states = block(hidden_states, *args, **kwargs) - self._update_cache(hidden_states, old_hidden_states, time_step, block_idx) - else: - hidden_states += self._get_cache(time_step, block_idx) - return hidden_states - - def _use_cache(self, time_step, block_idx): - return self.cache_cls.use_cache(time_step, block_idx) - - def _get_cache(self, time_step, block_idx): - return self.cache_cls.get_cache(time_step, block_idx) - - def _update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): - self.cache_cls.update_cache(hidden_states, old_hidden_states, time_step, block_idx) - - -class CacheAgent(): - def __init__(self, policy_dir, **kwargs): - self.policy_dir = policy_dir - self.cache = [None for _ in range(32)] - - self.policy = torch.load(policy_dir) - - def use_cache(self, time_step, block_idx): - if time_step == 0: - return False - else: - return self.policy[time_step - 1, block_idx] - - def get_cache(self, time_step, block_idx): - return self.cache[block_idx] - - def update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): - delta = hidden_states - old_hidden_states - self.cache[block_idx] = delta - - -class DitCache: - def __init__(self, step_start, step_interval, block_start, num_blocks, **kwargs): - self.step_start = step_start - self.step_interval = step_interval - self.block_start = block_start - self.num_blocks = num_blocks - self.block_end = block_start + num_blocks - 1 - self.cache = None - self.time_cache = {} - - def use_cache(self, time_step, block_idx): - if time_step < self.step_start: - return False - else: - diftime = time_step - self.step_start - if diftime not in self.time_cache: - self.time_cache[diftime] = diftime % self.step_interval == 0 - if self.time_cache[diftime]: - return False - elif block_idx < self.block_start or block_idx > self.block_end: - return False - else: - return True - - def get_cache(self, time_step, block_idx): - if block_idx == self.block_start: - return self.cache - else: - return 0 - - def update_cache(self, hidden_states, old_hidden_states, time_step, block_idx): - diftime = time_step - self.step_start - # when (time_step - self.step_start) % self.step_interval == 0: - if time_step >= self.step_start and self.time_cache[diftime]: - if block_idx == self.block_start: - self.cache = old_hidden_states - elif block_idx == self.block_end: - self.cache = hidden_states - self.cache - - -Cache_cls = { - "CacheAgent" : CacheAgent, - "DitCache" : DitCache -} diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py index 0adaccdd1d..afb14d1e22 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/group_coordinator.py @@ -20,7 +20,7 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) def _split_tensor_dict( - tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced @@ -98,10 +98,10 @@ class GroupCoordinator: device_group: ProcessGroup # group for device communication def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], ): self.rank = torch.distributed.get_rank() @@ -207,14 +207,14 @@ class GroupCoordinator: return input_ def all_gather( - self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False ) -> Union[torch.Tensor, List[torch.Tensor]]: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert ( - -input_.dim() <= dim < input_.dim() + -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. @@ -260,7 +260,7 @@ class GroupCoordinator: if world_size == 1: return input_ assert ( - -input_.dim() <= dim < input_.dim() + -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. @@ -320,7 +320,7 @@ class GroupCoordinator: return recv[0] def broadcast_object_list( - self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None ): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. @@ -370,7 +370,7 @@ class GroupCoordinator: assert src < self.world_size, f"Invalid src rank ({src})" assert ( - src != self.rank + src != self.rank ), "Invalid source rank. Source rank is the same as the current rank." size_tensor = torch.empty(1, dtype=torch.long, device="cpu") @@ -392,7 +392,7 @@ class GroupCoordinator: ) assert ( - rank_object == rank_size + rank_object == rank_size ), "Received object sender rank does not match the size sender rank." obj = pickle.loads(object_tensor.numpy().tobytes()) @@ -400,11 +400,11 @@ class GroupCoordinator: return obj def broadcast_tensor_dict( - self, - tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None, + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. 
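The group_coordinator.py hunks below mostly reflow signatures, but GroupCoordinator itself is a thin wrapper over torch.distributed collectives. As a point of reference, the Python sketch below shows the all_gather-along-a-dim pattern that GroupCoordinator.all_gather provides; it assumes an already initialized process group and is a simplified stand-in rather than the repository's exact code.

    import torch
    import torch.distributed as dist

    def all_gather_along_dim(t: torch.Tensor, dim: int = 0, separate_tensors: bool = False):
        world_size = dist.get_world_size()
        # Bypass the collective when there is nothing to gather.
        if world_size == 1:
            return t
        gathered = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(gathered, t)          # one tensor per rank, same shape as t
        if separate_tensors:
            return gathered                   # e.g. for per-rank post-processing
        return torch.cat(gathered, dim=dim)   # single tensor, ranks stacked on `dim`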
@@ -480,9 +480,9 @@ class GroupCoordinator: return tensor_dict def send_tensor_dict( - self, - tensor_dict: Dict[str, Union[torch.Tensor, Any]], - dst: Optional[int] = None, + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Send the input tensor dictionary. NOTE: `dst` is the local rank of the source rank. @@ -522,7 +522,7 @@ class GroupCoordinator: return None def recv_tensor_dict( - self, src: Optional[int] = None + self, src: Optional[int] = None ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: """Recv the input tensor dictionary. NOTE: `src` is the local rank of the source rank. @@ -586,7 +586,7 @@ class GroupCoordinator: ) def recv( - self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None ) -> torch.Tensor: """Receives a tensor from the src rank.""" """NOTE: `src` is the rank_in_group of the source rank.""" @@ -616,11 +616,11 @@ class GroupCoordinator: class SequenceParallelGroupCoordinator(GroupCoordinator): def __init__( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - **kwargs, + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, ): super().__init__( group_ranks=group_ranks, @@ -628,4 +628,4 @@ class SequenceParallelGroupCoordinator(GroupCoordinator): torch_distributed_backend=torch_distributed_backend, ) self.ulysses_world_size = self.ring_world_size = 1 - self.ulysses_rank = self.ring_rank = 0 \ No newline at end of file + self.ulysses_rank = self.ring_rank = 0 diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py index f86c2ae45a..8a13525aca 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py @@ -22,7 +22,7 @@ from typing import List, Optional from dataclasses import dataclass import torch.distributed as dist import torch_npu -from .utils import RankGenerator, generate_masked_orthogonal_rank_groups +from .utils import RankGenerator from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator from diffusers.utils import logging @@ -43,6 +43,7 @@ def is_vae_parallel_initialized(): else: return True + def initialize_vae_parallel(): global _VAE_PARALLEL_GROUP global _VAE_PARALLEL_SIZE @@ -54,16 +55,19 @@ def initialize_vae_parallel(): group = dist.new_group(ranks) _VAE_PARALLEL_GROUP = group + def get_vae_parallel_group(): assert _VAE_PARALLEL_GROUP is not None, "vae parallel group is not initialized" return _VAE_PARALLEL_GROUP + def get_vae_parallel_world_size(): if not is_vae_parallel_initialized(): return 1 return _VAE_PARALLEL_SIZE + def get_vae_parallel_rank(): if not is_vae_parallel_initialized(): return 0 @@ -71,6 +75,7 @@ def get_vae_parallel_rank(): vae_rank = rank % _VAE_PARALLEL_SIZE return vae_rank + def get_vae_parallel_group_rank(): if not is_vae_parallel_initialized(): return 0 @@ -95,7 +100,7 @@ class ParallelConfig: "world_size because of classifier free guidance" ) assert ( - self.world_size % (self.sp_degree * self.cfg_degree) == 0 + self.world_size % (self.sp_degree * self.cfg_degree) == 0 ), "world_size must be divisible by sp_degree * cfg_degree" @@ -104,15 +109,18 @@ def get_world_group() -> 
GroupCoordinator: assert _WORLD is not None, "world group is not initialized" return _WORLD + # SP def get_sp_group() -> SequenceParallelGroupCoordinator: assert _SP is not None, "pipeline model parallel group is not initialized" return _SP + def get_sequence_parallel_state(): """Return state for the sequence parallel group.""" return _SP is not None + def get_sequence_parallel_world_size(): """Return world size for the sequence parallel group.""" if not get_sequence_parallel_state(): @@ -126,17 +134,20 @@ def get_sequence_parallel_rank(): return 0 return get_sp_group().rank_in_group + # CFG def get_cfg_group() -> GroupCoordinator: assert ( - _CFG is not None + _CFG is not None ), "classifier_free_guidance parallel group is not initialized" return _CFG + def get_cfg_state(): """Return state for the sequence parallel group.""" return _CFG is not None + def get_classifier_free_guidance_world_size(): """Return world size for the classifier_free_guidance parallel group.""" if not get_cfg_state(): @@ -147,12 +158,12 @@ def get_classifier_free_guidance_world_size(): def get_classifier_free_guidance_rank(): """Return my rank for the classifier_free_guidance parallel group.""" if not get_cfg_state(): - return 0 + return 0 return get_cfg_group().rank_in_group def init_world_group( - ranks: List[int], local_rank: int, backend: str + ranks: List[int], local_rank: int, backend: str ) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], @@ -162,11 +173,11 @@ def init_world_group( def init_distributed_environment( - world_size: int = -1, - rank: int = -1, - distributed_init_method: str = "env://", - local_rank: int = -1, - backend: str = "hccl", + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "hccl", ): logger.debug( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", @@ -205,23 +216,24 @@ def init_distributed_environment( _WORLD = init_world_group(ranks, local_rank, backend) else: assert ( - _WORLD.world_size == dist.get_world_size() + _WORLD.world_size == dist.get_world_size() ), "world group already initialized with a different world size" def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" return ( - _CFG is not None - and _SP is not None + _CFG is not None + and _SP is not None ) + def init_model_parallel_group( - group_ranks: List[List[int]], - local_rank: int, - backend: str, - parallel_mode: str, - **kwargs, + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, ) -> GroupCoordinator: assert parallel_mode in [ "sequence", @@ -241,10 +253,11 @@ def init_model_parallel_group( torch_distributed_backend=backend, ) + def initialize_model_parallel( - classifier_free_guidance_degree: int = 1, - sequence_parallel_degree: int = 1, - backend: Optional[str] = None, + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + backend: Optional[str] = None, ) -> None: """ Initialize model parallel groups. 
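`initialize_model_parallel` carves two orthogonal groups out of the same world: one for sequence parallelism and one for classifier-free guidance. A rough sketch of the resulting layout, assuming the conventional ordering in which the first dimension of the "sp-cfg" order varies fastest (the authoritative grouping comes from `RankGenerator.get_ranks`):

    def sp_cfg_rank_groups(sp_degree: int, cfg_degree: int):
        # Hypothetical helper for illustration only; the real layout is produced
        # by RankGenerator / generate_masked_orthogonal_rank_groups.
        world_size = sp_degree * cfg_degree
        ranks = list(range(world_size))
        # "sp-cfg": sp ranks are contiguous, cfg partners stride by sp_degree.
        sp_groups = [ranks[i * sp_degree:(i + 1) * sp_degree] for i in range(cfg_degree)]
        cfg_groups = [ranks[i::sp_degree] for i in range(sp_degree)]
        return sp_groups, cfg_groups

    # world_size=8, sp_degree=4, cfg_degree=2 ->
    #   sp groups:  [[0, 1, 2, 3], [4, 5, 6, 7]]
    #   cfg groups: [[0, 4], [1, 5], [2, 6], [3, 7]]

Under such a layout, an eight-rank launch could give each CFG branch its own four-way sequence-parallel group.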
@@ -260,9 +273,9 @@ def initialize_model_parallel( backend = backend or dist.get_backend(get_world_group().device_group) if ( - world_size - != classifier_free_guidance_degree - * sequence_parallel_degree + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree ): raise RuntimeError( f"world_size ({world_size}) is not equal to " @@ -270,13 +283,13 @@ def initialize_model_parallel( f"classifier_free_guidance_degree " f"({classifier_free_guidance_degree}) x" ) - + rank_generator: RankGenerator = RankGenerator( sequence_parallel_degree, classifier_free_guidance_degree, "sp-cfg", ) - + global _CFG assert _CFG is None, "classifier_free_guidance group is already initialized" _CFG = init_model_parallel_group( @@ -289,11 +302,12 @@ def initialize_model_parallel( global _SP assert _SP is None, "sequence parallel group is already initialized" _SP = init_model_parallel_group( - group_ranks=rank_generator.get_ranks("sp"), - local_rank=get_world_group().local_rank, - backend=backend, - parallel_mode="sequence", - ) + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ) + def destroy_model_parallel(): """Set the groups to none and destroy them.""" @@ -307,6 +321,7 @@ def destroy_model_parallel(): _SP.destroy() _SP = None + def destroy_distributed_environment(): global _WORLD if _WORLD: @@ -315,6 +330,7 @@ def destroy_distributed_environment(): if dist.is_initialized(): dist.destroy_process_group() + def init_parallel_env(parallel_config: ParallelConfig): if not model_parallel_is_initialized(): logger.warning("Model parallel is not initialized, initializing...") @@ -326,7 +342,8 @@ def init_parallel_env(parallel_config: ParallelConfig): ) initialize_vae_parallel() + def finalize_parallel_env(): if model_parallel_is_initialized(): destroy_model_parallel() - destroy_distributed_environment() \ No newline at end of file + destroy_distributed_environment() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py index 395fe93120..d54a4137f7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/sampling_optm.py @@ -24,6 +24,7 @@ class AdaStep: """ The Adastep class is designed to optimize the sampling process in a diffusion model, """ + def __init__(self, skip_thr=0.015, max_skip_steps=4, decay_ratio=0.99, device="npu", forward_value=None, step_value=None): """ @@ -38,9 +39,9 @@ class AdaStep: """ # recommand 0.01(skip around 35 steps) ~ 0.015(skip around 50 steps) - self.skip_thr = skip_thr + self.skip_thr = skip_thr # recommand 3 ~ 4 - self.max_skip_steps = max_skip_steps + self.max_skip_steps = max_skip_steps # recommand 0.95 ~ 0.99 self.decay_ratio = decay_ratio self.device = device @@ -61,19 +62,19 @@ class AdaStep: """ if self.if_skip and torch.all(self.skip_vote): return self._return_output(self.skip_noise_pred, self.forwardretype) - + noise_pred = transformer(*model_args, **model_kwargs) - if not self.forwardretype: + if not self.forwardretype: if isinstance(noise_pred, tuple): self.forwardretype = tuple elif isinstance(noise_pred, torch.Tensor): self.forwardretype = torch.Tensor else: - raise(ValueError, "transformer need return a tuple which the first element is the result" - "or a Tensor. 
Otherwith please input the `forward_value`") + raise (ValueError, "transformer need return a tuple which the first element is the result" + "or a Tensor. Otherwith please input the `forward_value`") self.skip_noise_pred = self._get_input(noise_pred, self.forwardretype) return noise_pred - + @staticmethod def _get_input(input_value, inp_type): if isinstance(inp_type, type): @@ -83,7 +84,7 @@ class AdaStep: return input_value else: return input_value[inp_type] - + @staticmethod def _return_output(output, outptype): if isinstance(outptype, type): @@ -92,10 +93,10 @@ class AdaStep: else: return output elif isinstance(outptype, str): - return {outptype : output} + return {outptype: output} else: - return (0,) * outptype + (output,) - + return (0,) * outptype + (output,) + def set_param(self, skip_thr=None, max_skip_steps=None, decay_ratio=None, device=None): """ To set the parameters of the AdaStep class. @@ -103,7 +104,7 @@ class AdaStep: self.skip_thr = skip_thr or self.skip_thr self.max_skip_steps = max_skip_steps or self.max_skip_steps self.decay_ratio = decay_ratio or self.decay_ratio - if device : + if device: self.device = device self.skip_vote.to(self.device) self.if_skip = self.max_skip_steps > 0 @@ -128,8 +129,8 @@ class AdaStep: elif isinstance(latents, torch.Tensor): self.stepretype = torch.Tensor else: - raise(ValueError, "step need return a tuple which the first element is the result" - "or a Tensor. Otherwith please input the `step_value`") + raise (ValueError, "step need return a tuple which the first element is the result" + "or a Tensor. Otherwith please input the `step_value`") if self.if_skip: latents = self._get_input(latents, self.stepretype) diff = latents - self.skip_prev_latents @@ -137,12 +138,12 @@ class AdaStep: if len(self.skip_latents_diff) >= 3: self.skip_mask.append(self._estimate()) - self.skip_prev_latents = latents + self.skip_prev_latents = latents mask_value = self.skip_mask[-1] mask_value = torch.tensor([mask_value], dtype=torch.bool, device=self.device) if sequence_parallel: - skip_vote = torch.zeros(torch.distributed.get_world_size(sp_group), + skip_vote = torch.zeros(torch.distributed.get_world_size(sp_group), dtype=torch.bool, device=self.device) torch.distributed.all_gather_into_tensor(skip_vote, mask_value, group=sp_group) else: @@ -160,17 +161,17 @@ class AdaStep: self.skip_thr = self.skip_thr * self.decay_ratio if len(self.skip_mask) >= self.max_skip_steps and \ - all(self.skip_mask[-self.max_skip_steps:]): + all(self.skip_mask[-self.max_skip_steps:]): return False if abs((cur_diff + prev_prev_diff) / 2 - prev_diff) <= prev_diff * self.skip_thr: return True return False - + class SamplingOptm: - def __init__(self, pipe, dit_forward="transformer.forward", scheduler_step="scheduler.step", - forward_value=None, step_value=None, parallel=False, group=None, config=None): + def __init__(self, pipe, dit_forward="transformer.forward", scheduler_step="scheduler.step", + forward_value=None, step_value=None, parallel=False, group=None, config=None): self.parallel = parallel self.group = group self.skip = False @@ -180,11 +181,11 @@ class SamplingOptm: schedulerstep_lst = scheduler_step.split(".") self.pipe = pipe - self.ori_forward = reduce(getattr, ditforward_lst, self.pipe) # getattr(self.pipe, )dit_forward.split(".") + self.ori_forward = reduce(getattr, ditforward_lst, self.pipe) # getattr(self.pipe, )dit_forward.split(".") self.forward = ditforward_lst.pop() self.ori_dit = reduce(getattr, ditforward_lst, self.pipe) - self.ori_step = reduce(getattr, 
schedulerstep_lst, self.pipe) + self.ori_step = reduce(getattr, schedulerstep_lst, self.pipe) self.step = schedulerstep_lst.pop() self.ori_scheduler = reduce(getattr, schedulerstep_lst, self.pipe) @@ -208,6 +209,7 @@ class SamplingOptm: def _optm_forward(*args, **kwargs): noise_pred = self.skip_strategy(self.ori_forward, *args, **kwargs) return noise_pred + setattr(self.ori_dit, self.forward, _optm_forward) def _sub_step(self): @@ -216,11 +218,11 @@ class SamplingOptm: latents = self.ori_step(*args, **kwargs) self.skip_strategy.update_strategy(latents, self.parallel, self.group) return latents + setattr(self.ori_scheduler, self.step, _optm_step) - + def _revert_forward(self): setattr(self.ori_dit, self.forward, self.ori_forward) def _revert_step(self): setattr(self.ori_scheduler, self.step, self.ori_step) - diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py index ca1cae0436..f7e638880f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/utils.py @@ -2,7 +2,7 @@ from typing import List def generate_masked_orthogonal_rank_groups( - world_size: int, parallel_size: List[int], mask: List[bool] + world_size: int, parallel_size: List[int], mask: List[bool] ) -> List[List[int]]: """Generate orthogonal parallel groups based on the parallel size and mask. @@ -48,7 +48,7 @@ def generate_masked_orthogonal_rank_groups( # stride is a prefix_product result. And the value of stride[-1] # is not used. assert ( - sum([x * y for x, y in zip(idx, stride[:-1])]) == index + sum([x * y for x, y in zip(idx, stride[:-1])]) == index ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) return idx @@ -80,11 +80,11 @@ def generate_masked_orthogonal_rank_groups( class RankGenerator(object): def __init__( - self, - sp: int, - cfg: int, - order: str, - rank_offset: int = 0, + self, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, ) -> None: self.sp = sp self.cfg = cfg diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py index b9a69ff860..e68644b1ec 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py @@ -1,17 +1,13 @@ import logging import math import re -import random from abc import abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import pytorch_lightning as pl import torch import torch.distributed as dist -import torch.nn as nn -from einops import rearrange from packaging import version from vae_modules.ema import LitEma @@ -23,7 +19,6 @@ from utils.parallel_mgr import ( get_vae_parallel_group_rank, ) - from sgm.util import ( instantiate_from_config, get_obj_from_str, @@ -42,10 +37,10 @@ class AbstractAutoencoder(pl.LightningModule): """ def __init__( - self, - ema_decay: Union[None, float] = None, - monitor: Union[None, str] = None, - input_key: str = "jpg", + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", ): super().__init__() @@ -127,24 +122,24 @@ class AutoencodingEngine(AbstractAutoencoder): """ def __init__( - self, - *args, - encoder_config: Dict, - decoder_config: Dict, - loss_config: Dict, 
- regularizer_config: Dict, - optimizer_config: Union[Dict, None] = None, - lr_g_factor: float = 1.0, - trainable_ae_params: Optional[List[List[str]]] = None, - ae_optimizer_args: Optional[List[dict]] = None, - trainable_disc_params: Optional[List[List[str]]] = None, - disc_optimizer_args: Optional[List[dict]] = None, - disc_start_iter: int = 0, - diff_boost_factor: float = 3.0, - ckpt_engine: Union[None, str, dict] = None, - ckpt_path: Optional[str] = None, - additional_decode_keys: Optional[List[str]] = None, - **kwargs, + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, ): super().__init__(*args, **kwargs) self.automatic_optimization = False # pytorch lightning @@ -181,10 +176,10 @@ class AutoencodingEngine(AbstractAutoencoder): return self.decoder.get_last_layer() def encode( - self, - x: torch.Tensor, - return_reg_log: bool = False, - unregularized: bool = False, + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: z = self.encoder(x) if unregularized: @@ -204,7 +199,7 @@ class AutoencodingEngine(AbstractAutoencoder): return z, dec, reg_log def get_param_groups( - self, parameter_names: List[List[str]], optimizer_args: List[dict] + self, parameter_names: List[List[str]], optimizer_args: List[dict] ) -> Tuple[List[Dict[str, Any]], int]: groups = [] num_params = 0 @@ -265,7 +260,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine): n_batches = int(math.ceil(N / bs)) z = list() for i_batch in range(n_batches): - z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.encoder(x[i_batch * bs: (i_batch + 1) * bs]) z_batch = self.quant_conv(z_batch) z.append(z_batch) z = torch.cat(z, 0) @@ -285,7 +280,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine): n_batches = int(math.ceil(N / bs)) dec = list() for i_batch in range(n_batches): - dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.post_quant_conv(z[i_batch * bs: (i_batch + 1) * bs]) dec_batch = self.decoder(dec_batch, **decoder_kwargs) dec.append(dec_batch) dec = torch.cat(dec, 0) @@ -309,13 +304,13 @@ class IdentityFirstStage(AbstractAutoencoder): class VideoAutoencodingEngine(AutoencodingEngine): def __init__( - self, - ckpt_path: Union[None, str] = None, - ignore_keys: Union[Tuple, list] = (), - image_video_weights=[1, 1], - only_train_decoder=False, - context_parallel_size=0, - **kwargs, + self, + ckpt_path: Union[None, str] = None, + ignore_keys: Union[Tuple, list] = (), + image_video_weights=[1, 1], + only_train_decoder=False, + context_parallel_size=0, + **kwargs, ): super().__init__(**kwargs) self.context_parallel_size = context_parallel_size @@ -361,21 +356,21 @@ class VideoAutoencodingEngine(AutoencodingEngine): class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): def __init__( - self, - cp_size=0, - *args, - **kwargs, + self, + cp_size=0, + *args, + **kwargs, ): self.cp_size = cp_size return 
super().__init__(*args, **kwargs) def encode( - self, - x: torch.Tensor, - return_reg_log: bool = False, - unregularized: bool = False, - input_cp: bool = False, - output_cp: bool = False, + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + input_cp: bool = False, + output_cp: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.cp_size > 0 and not input_cp: if not is_vae_parallel_initialized(): @@ -399,12 +394,12 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): return z def decode( - self, - z: torch.Tensor, - input_cp: bool = False, - output_cp: bool = False, - split_kernel_size: int = 1, - **kwargs, + self, + z: torch.Tensor, + input_cp: bool = False, + output_cp: bool = False, + split_kernel_size: int = 1, + **kwargs, ): if self.cp_size > 0 and not input_cp: if not is_vae_parallel_initialized(): @@ -423,12 +418,12 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): return x def forward( - self, - x: torch.Tensor, - input_cp: bool = False, - latent_cp: bool = False, - output_cp: bool = False, - **additional_decode_kwargs, + self, + x: torch.Tensor, + input_cp: bool = False, + latent_cp: bool = False, + output_cp: bool = False, + **additional_decode_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor, dict]: z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp) dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py index 32b926a1c7..968cd50d65 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py @@ -5,7 +5,6 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np -from beartype import beartype from beartype.typing import Union, Tuple, Optional, List from einops import rearrange @@ -146,8 +145,8 @@ def _conv_split(input_, dim, kernel_size): else: # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) output = input_.transpose(dim, 0)[ - cp_rank * dim_size + kernel_size : (cp_rank + 1) * dim_size + kernel_size - ].transpose(dim, 0) + cp_rank * dim_size + kernel_size: (cp_rank + 1) * dim_size + kernel_size + ].transpose(dim, 0) output = output.contiguous() # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) @@ -171,7 +170,7 @@ def _conv_gather(input_, dim, kernel_size): if cp_rank == 0: input_ = input_.transpose(0, dim)[kernel_size:].transpose(0, dim).contiguous() else: - input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0) :].transpose(0, dim).contiguous() + input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(0, dim).contiguous() tensor_list = [torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))] + [ torch.empty_like(input_) for _ in range(cp_world_size - 1) @@ -216,9 +215,9 @@ def _pass_from_previous_rank(input_, dim, kernel_size): recv_rank += cp_world_size if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) if cp_rank > 0: - recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() + recv_buffer = 
torch.empty_like(input_[-kernel_size + 1:]).contiguous() req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) if cp_rank == 0: @@ -263,9 +262,9 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non # recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() # req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group) # req_recv.wait() - recv_buffer = torch.empty_like(input_[-kernel_size + 1 :]).contiguous() + recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous() if cp_rank < cp_world_size - 1: - req_send = torch.distributed.isend(input_[-kernel_size + 1 :].contiguous(), send_rank, group=group) + req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) if cp_rank > 0: req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group) # req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group) @@ -285,7 +284,7 @@ def _fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding=Non def _drop_from_previous_rank(input_, dim, kernel_size): - input_ = input_.transpose(0, dim)[kernel_size - 1 :].transpose(0, dim) + input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim) return input_ @@ -402,17 +401,17 @@ class ContextParallelCausalConv3d(nn.Module): global_rank = torch.distributed.get_rank() if cp_world_size == 1: self.cache_padding = ( - input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu() + input_parallel[:, :, -self.time_kernel_size + 1:].contiguous().detach().clone().cpu() ) else: if cp_rank == cp_world_size - 1: torch.distributed.isend( - input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous(), + input_parallel[:, :, -self.time_kernel_size + 1:].contiguous(), global_rank + 1 - cp_world_size, group=get_vae_parallel_group(), ) if cp_rank == 0: - recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1 :]).contiguous() + recv_buffer = torch.empty_like(input_parallel[:, :, -self.time_kernel_size + 1:]).contiguous() torch.distributed.recv( recv_buffer, global_rank - 1 + cp_world_size, group=get_vae_parallel_group() ) @@ -446,14 +445,14 @@ def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D class SpatialNorm3D(nn.Module): def __init__( - self, - f_channels, - zq_channels, - freeze_norm_layer=False, - add_conv=False, - pad_mode="constant", - gather=False, - **norm_layer_params, + self, + f_channels, + zq_channels, + freeze_norm_layer=False, + add_conv=False, + pad_mode="constant", + gather=False, + **norm_layer_params, ): super().__init__() if gather: @@ -507,10 +506,10 @@ class SpatialNorm3D(nn.Module): def Normalize3D( - in_channels, - zq_ch, - add_conv, - gather=False, + in_channels, + zq_ch, + add_conv, + gather=False, ): return SpatialNorm3D( in_channels, @@ -526,10 +525,10 @@ def Normalize3D( class Upsample3D(nn.Module): def __init__( - self, - in_channels, - with_conv, - compress_time=False, + self, + in_channels, + with_conv, + compress_time=False, ): super().__init__() self.with_conv = with_conv @@ -609,17 +608,17 @@ class DownSample3D(nn.Module): class ContextParallelResnetBlock3D(nn.Module): def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - zq_ch=None, - add_conv=False, - gather_norm=False, - normalization=Normalize, + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + zq_ch=None, + 
add_conv=False, + gather_norm=False, + normalization=Normalize, ): super().__init__() self.in_channels = in_channels @@ -711,23 +710,23 @@ class ContextParallelResnetBlock3D(nn.Module): class ContextParallelEncoder3D(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - pad_mode="first", - temporal_compress_times=4, - gather_norm=False, - **ignore_kwargs, + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignore_kwargs, ): super().__init__() self.ch = ch @@ -834,25 +833,25 @@ class ContextParallelEncoder3D(nn.Module): class ContextParallelDecoder3D(nn.Module): def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - zq_ch=None, - add_conv=False, - pad_mode="first", - temporal_compress_times=4, - gather_norm=False, - **ignorekwargs, + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + zq_ch=None, + add_conv=False, + pad_mode="first", + temporal_compress_times=4, + gather_norm=False, + **ignorekwargs, ): super().__init__() self.ch = ch diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py index 8cab5c195b..26402ac9fa 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/utils.py @@ -3,14 +3,14 @@ import torch class SafeConv3d(torch.nn.Conv3d): def forward(self, input): - memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3 + memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024 ** 3 if memory_count > 2: kernel_size = self.kernel_size[0] part_num = int(memory_count / 2) + 1 input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW if kernel_size > 1: input_chunks = [input_chunks[0]] + [ - torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) + torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:], input_chunks[i]), dim=2) for i in range(1, len(input_chunks)) ] @@ -20,4 +20,4 @@ class SafeConv3d(torch.nn.Conv3d): output = torch.cat(output_chunks, dim=2) return output else: - return super(SafeConv3d, self).forward(input) \ No newline at end of file + return super(SafeConv3d, self).forward(input) -- Gitee From 9866bacfd4d26d93b568fdaf605f5007fd0a6918 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 23 Jan 2025 01:43:09 +0000 Subject: [PATCH 3/9] rm profiling --- .../foundation/cogvideox-sat/sample_video.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py index cc9066079f..6bf3ba9fdb 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -206,33 +206,33 @@ def 
sampling_main(args, model_cls): first_stage_model = model.first_stage_model for index in range(args.batch_size): - count = 5 - experimental_config = torch_npu.profiler._ExperimentalConfig( - export_type=torch_npu.profiler.ExportType.Text, - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False - ) - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], - # 默认CPU、NPU均打开 - with_stack=True, # 采集torch op的函数调用栈的开关,默认关闭 - record_shapes=True, # 采集torch op的input shape和input type的开关,默认关闭 - profile_memory=True, # 采集memory相关数据的开关,默认关闭 - schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1), - experimental_config=experimental_config, - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_unet") - # 导出tensorboard可呈现的数据形式 - ) as prof: - for i in range(count): - samples_z = sample_func( - c, - uc=uc, - batch_size=1, - shape=(T, C, H // F, W // F), - ) - prof.step() + # count = 5 + # experimental_config = torch_npu.profiler._ExperimentalConfig( + # export_type=torch_npu.profiler.ExportType.Text, + # aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + # profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + # l2_cache=False, + # data_simplification=False + # ) + # with torch_npu.profiler.profile( + # activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + # # 默认CPU、NPU均打开 + # with_stack=True, # 采集torch op的函数调用栈的开关,默认关闭 + # record_shapes=True, # 采集torch op的input shape和input type的开关,默认关闭 + # profile_memory=True, # 采集memory相关数据的开关,默认关闭 + # schedule=torch_npu.profiler.schedule(wait=2, warmup=2, active=1, repeat=1), + # experimental_config=experimental_config, + # on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./prof_unet") + # # 导出tensorboard可呈现的数据形式 + # ) as prof: + # for i in range(count): + # samples_z = sample_func( + # c, + # uc=uc, + # batch_size=1, + # shape=(T, C, H // F, W // F), + # ) + # prof.step() samples_z = sample_func( c, uc=uc, -- Gitee From dbeebc0ef431b4fd908d0130705d44be67a2d44b Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 23 Jan 2025 05:59:52 +0000 Subject: [PATCH 4/9] cfg parallel --- .../foundation/cogvideox-sat/inference.sh | 2 +- .../sgm/modules/diffusionmodules/guiders.py | 55 ++++++++++++++++--- .../sgm/modules/diffusionmodules/sampling.py | 3 +- .../foundation/cogvideox-sat/utils/comm.py | 16 ------ .../cogvideox-sat/utils/parallel_mgr.py | 19 +------ 5 files changed, 51 insertions(+), 44 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh index a1ad48c40e..c324aecf24 100755 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh @@ -1,7 +1,7 @@ #! 
/bin/bash export TASK_QUEUE_ENABLE=2 export HCCL_OP_EXPANSION_MODE="AIV" -torchrun --nnodes=1 --nproc_per_node 1 --master_port 29503 \ +torchrun --nnodes=1 --nproc_per_node 2 --master_port 29503 \ sample_video.py \ --base configs/cogvideox_5b.yaml configs/inference.yaml \ --seed 42 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py index 7ce657c392..71bf91b072 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/guiders.py @@ -8,6 +8,8 @@ import torch from einops import rearrange, repeat from ...util import append_dims, default, instantiate_from_config +from utils.parallel_mgr import get_classifier_free_guidance_world_size, get_classifier_free_guidance_rank, \ + get_sequence_parallel_rank, get_cfg_group class Guider(ABC): @@ -36,21 +38,49 @@ class VanillaCFG: ) def __call__(self, x, sigma, scale=None): - x_u, x_c = x.chunk(2) + if get_classifier_free_guidance_world_size() == 1: + x_u, x_c = x.chunk(2) + elif get_classifier_free_guidance_world_size() == 2: + x_u, x_c = get_cfg_group().all_gather( + x, separate_tensors=True + ) + else: + raise ValueError("Invalid classifier free guidance world size") scale_value = default(scale, self.scale_schedule(sigma)) x_pred = self.dyn_thresh(x_u, x_c, scale_value) return x_pred def prepare_inputs(self, x, s, c, uc): c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"]: - c_out[k] = torch.cat((uc[k], c[k]), 0) + if get_classifier_free_guidance_world_size() == 1: + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + x_out = torch.cat([x] * 2) + s_out = torch.cat([s] * 2) + else: + x_out = x + s_out = s + if get_classifier_free_guidance_rank() == 0: + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = uc[k] + else: + assert c[k] == uc[k] + c_out[k] = c[k] + elif get_classifier_free_guidance_rank() == 1: + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = c[k] + else: + assert c[k] == uc[k] + c_out[k] = c[k] else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out + raise ValueError("Invalid classifier free guidance rank") + return x_out, s_out, c_out class DynamicCFG(VanillaCFG): @@ -68,7 +98,14 @@ class DynamicCFG(VanillaCFG): ) def __call__(self, x, sigma, step_index, scale=None): - x_u, x_c = x.chunk(2) + if get_classifier_free_guidance_world_size() == 1: + x_u, x_c = x.chunk(2) + elif get_classifier_free_guidance_world_size() == 2: + x_u, x_c = get_cfg_group().all_gather( + x, separate_tensors=True + ) + else: + raise ValueError("Invalid classifier free guidance world size") scale_value = self.scale_schedule(sigma, step_index.item()) x_pred = self.dyn_thresh(x_u, x_c, scale_value) return x_pred diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py index b385ca8431..59c7d7bc59 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py @@ 
-14,6 +14,7 @@ from ...modules.diffusionmodules.sampling_utils import ( from ...util import append_dims, default, instantiate_from_config from .guiders import DynamicCFG +from utils.parallel_mgr import get_classifier_free_guidance_world_size DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} @@ -504,7 +505,7 @@ class VideoDDIMSampler(BaseDiffusionSampler): additional_model_inputs["scale_emb"] = scale_emb denoised = denoiser(x, alpha_cumprod_sqrt, cond, **additional_model_inputs).to(torch.float32) else: - additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * 2) + additional_model_inputs["idx"] = torch.cat([x.new_ones([x.shape[0]]) * timestep] * (2 // get_classifier_free_guidance_world_size())) denoised = denoiser( *self.guider.prepare_inputs(x, alpha_cumprod_sqrt, cond, uc), **additional_model_inputs ).to(torch.float32) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py index 33f79ad1e5..6ce7ed3eaa 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py @@ -1,21 +1,5 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. import torch import torch.distributed as dist diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py index 8a13525aca..56a5c75afd 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py @@ -1,21 +1,5 @@ #!/usr/bin/env python # coding=utf-8 -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
import os from typing import List, Optional @@ -49,7 +33,8 @@ def initialize_vae_parallel(): global _VAE_PARALLEL_SIZE assert _VAE_PARALLEL_GROUP is None, "vae parallel group is already initialized" - _VAE_PARALLEL_SIZE = dist.get_world_size() + _VAE_PARALLEL_SIZE = 1 + # _VAE_PARALLEL_SIZE = dist.get_world_size() ranks = range(_VAE_PARALLEL_SIZE) group = dist.new_group(ranks) -- Gitee From 1cee5a56fbebf580592d69ddaa89a20785323834 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 23 Jan 2025 07:09:28 +0000 Subject: [PATCH 5/9] sampling_optimize --- .../foundation/cogvideox-sat/sample_video.py | 8 +- .../sgm/modules/diffusionmodules/sampling.py | 114 +++++++++++++++--- 2 files changed, 105 insertions(+), 17 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py index 6bf3ba9fdb..7ac6afb833 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -1,5 +1,6 @@ import os import math +import time import argparse from arguments import get_args from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_state, \ @@ -156,9 +157,10 @@ def sampling_main(args, model_cls): T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, args.sampling_fps num_samples = [1] force_uc_zero_embeddings = ["txt"] - with torch.no_grad(): for text, cnt in tqdm(data_iter): + start = time.time() + torch.npu.synchronize() if args.image2video: text, image_path = text.split("@@") assert os.path.exists(image_path), image_path @@ -270,7 +272,9 @@ def sampling_main(args, model_cls): ) if rank == 0: save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) - + torch.npu.synchronize() + end = time.time() + print(f"use time: {end - start}s") finalize_parallel_env() diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py index 59c7d7bc59..db0f6acf8b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sgm/modules/diffusionmodules/sampling.py @@ -3,6 +3,8 @@ from typing import Dict, Union import torch from omegaconf import ListConfig, OmegaConf from tqdm import tqdm +from mindiesd.pipeline.sampling_optm import SamplingOptm +from mindiesd.pipeline.sampling_optm import AdaStep from ...modules.diffusionmodules.sampling_utils import ( get_ancestral_step, @@ -17,7 +19,7 @@ from .guiders import DynamicCFG from utils.parallel_mgr import get_classifier_free_guidance_world_size DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} - +SAMPLING_OPTM = True class BaseDiffusionSampler: def __init__( @@ -590,6 +592,60 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): else: return mult1, mult2 + def pred_noise( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + denoiser, + x, + cond, + uc=None, + idx=None, + timestep=None, + scale=None, + scale_emb=None, + ): + noise = self.denoise( + x, denoiser, alpha_cumprod_sqrt, cond, uc, timestep, idx, scale=scale, scale_emb=scale_emb + ).to(torch.float32) + return noise + + def schedule_step( + self, + old_denoised, + previous_alpha_cumprod_sqrt, + 
alpha_cumprod_sqrt, + next_alpha_cumprod_sqrt, + x, + noise=None, + idx=None, + ): + if idx == 1: + return noise, noise + + h, r, lamb, lamb_next = self.get_variables( + alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt + ) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, alpha_cumprod_sqrt, next_alpha_cumprod_sqrt, previous_alpha_cumprod_sqrt) + ] + mult_noise = append_dims((1 - next_alpha_cumprod_sqrt**2) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5, x.ndim) + denoised = noise + x_standard = mult[0] * x - mult[1] * denoised + mult_noise * torch.randn_like(x) + if old_denoised is None or torch.sum(next_alpha_cumprod_sqrt) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + mult_noise * torch.randn_like(x) + + x = x_advanced + + return x, denoised + def sampler_step( self, old_denoised, @@ -640,6 +696,7 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): if self.fixed_frames > 0: prefix_frames = x[:, : self.fixed_frames] old_denoised = None + skip_strategy = AdaStep(skip_thr=0.004, max_skip_steps=1, decay_ratio=0.99, device="npu") for i in self.get_sigma_gen(num_sigmas): if self.fixed_frames > 0: if self.sdedit: @@ -650,20 +707,47 @@ class VPSDEDPMPP2MSampler(VideoDDIMSampler): x = torch.cat([noised_prefix_frames, x[:, self.fixed_frames:]], dim=1) else: x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) - x, old_denoised = self.sampler_step( - old_denoised, - None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], - s_in * alpha_cumprod_sqrt[i], - s_in * alpha_cumprod_sqrt[i + 1], - denoiser, - x, - cond, - uc=uc, - idx=self.num_steps - i, - timestep=timesteps[-(i + 1)], - scale=scale, - scale_emb=scale_emb, - ) + if SAMPLING_OPTM: + noise = skip_strategy( + self.pred_noise, + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) + x, old_denoised = self.schedule_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + x, + noise=noise, + idx=self.num_steps - i, + ) + skip_strategy.update_strategy(x) + else: + x, old_denoised = self.sampler_step( + old_denoised, + None if i == 0 else s_in * alpha_cumprod_sqrt[i - 1], + s_in * alpha_cumprod_sqrt[i], + s_in * alpha_cumprod_sqrt[i + 1], + denoiser, + x, + cond, + uc=uc, + idx=self.num_steps - i, + timestep=timesteps[-(i + 1)], + scale=scale, + scale_emb=scale_emb, + ) if self.fixed_frames > 0: x = torch.cat([prefix_frames, x[:, self.fixed_frames:]], dim=1) -- Gitee From da4a20c74479ed96158d5dd0c09a6ee1a6421c39 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 24 Jan 2025 02:49:33 +0000 Subject: [PATCH 6/9] sp+dp --- .../cogvideox-sat/dit_video_concat.py | 138 ++++++++++++++++-- .../foundation/cogvideox-sat/inference.sh | 2 +- .../foundation/cogvideox-sat/utils/comm.py | 2 +- 3 files changed, 124 insertions(+), 18 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py index a0678d1648..494a3786af 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py +++ 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/dit_video_concat.py @@ -21,7 +21,10 @@ from sgm.modules.diffusionmodules.util import ( timestep_embedding, ) from sat.ops.layernorm import LayerNorm, RMSNorm +from utils.comm import all_to_all_bsh +from utils.parallel_mgr import get_sp_group, get_sequence_parallel_world_size, get_sequence_parallel_rank +SP_PAD = 0 class ImagePatchEmbeddingMixin(BaseMixin): def __init__( @@ -334,22 +337,102 @@ class Rotary3DPositionEmbeddingMixin(BaseMixin): old_impl=attention_fn_fa_score, **kwargs, ): - - query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:]) - key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:]) - if self.rot_v: - value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:]) - - return attention_fn_fa_score( - query_layer, - key_layer, - value_layer, - attention_mask, - attention_dropout=attention_dropout, - log_attention_weights=log_attention_weights, - scaling_attention_score=scaling_attention_score, - **kwargs, - ) + if get_sequence_parallel_world_size() > 1: + batch_size = query_layer.shape[0] + num_heads = query_layer.shape[1] + head_dim = query_layer.shape[-1] + h_size = num_heads * head_dim + sp_size = get_sequence_parallel_world_size() + h_size_sp = h_size // sp_size + query_layer = rearrange(query_layer, "b n s d -> (b s) n d") + key_layer = rearrange(key_layer, "b n s d -> (b s) n d") + value_layer = rearrange(value_layer, "b n s d -> (b s) n d") + + _, query_layer = all_to_all_bsh(query_layer, scatter_dim=1, gather_dim=0) + _, key_layer = all_to_all_bsh(key_layer, scatter_dim=1, gather_dim=0) + _, value_layer = all_to_all_bsh(value_layer, scatter_dim=1, gather_dim=0) + + query_layer = query_layer.view(batch_size, -1, h_size_sp) + key_layer = key_layer.view(batch_size, -1, h_size_sp) + value_layer = value_layer.view(batch_size, -1, h_size_sp) + query_layer = torch.chunk(query_layer, sp_size, dim=1) + text_layer = query_layer[0][:, :self.text_length] + stack = [text_layer] + for image_slice in query_layer: + stack.append(image_slice[:, self.text_length:]) + query_layer = torch.cat(stack, dim=1) + + key_layer = torch.chunk(key_layer, sp_size, dim=1) + text_layer = key_layer[0][:, :self.text_length] + stack = [text_layer] + for image_slice in key_layer: + stack.append(image_slice[:, self.text_length:]) + key_layer = torch.cat(stack, dim=1) + + value_layer = torch.chunk(value_layer, sp_size, dim=1) + text_layer = value_layer[0][:, :self.text_length] + stack = [text_layer] + for image_slice in value_layer: + stack.append(image_slice[:, self.text_length:]) + value_layer = torch.cat(stack, dim=1) + + if SP_PAD > 0: + query_layer = query_layer[:, :(-SP_PAD)] + key_layer = key_layer[:, :(-SP_PAD)] + value_layer = value_layer[:, :(-SP_PAD)] + + query_layer = rearrange(query_layer, "b s (n d) -> b n s d", n=num_heads // sp_size, d=head_dim) + key_layer = rearrange(key_layer, "b s (n d) -> b n s d", n=num_heads // sp_size, d=head_dim) + value_layer = rearrange(value_layer, "b s (n d) -> b n s d", n=num_heads // sp_size, d=head_dim) + + query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:]) + key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:]) + if self.rot_v: + value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:]) + attention_output = attention_fn_fa_score( + query_layer, + key_layer, + value_layer, + attention_mask, + 
attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + if SP_PAD > 0: + padding_shape = (*attention_output.shape[:2], SP_PAD, attention_output.shape[3]) + padding = torch.zeros(padding_shape, dtype=attention_output.dtype, device=attention_output.device) + attention_output = torch.cat([attention_output, padding], dim=2) + text_layer = attention_output[:, :, :self.text_length] + image_layer = attention_output[:, :, self.text_length:] + image_layer_ls = torch.chunk(image_layer, sp_size, dim=2) + stack = [] + for i in range(sp_size): + stack.append(text_layer.clone()) + stack.append(image_layer_ls[i]) + attention_output = torch.cat(stack, dim=2) + attention_output = rearrange(attention_output, "b n s d -> (b s) n d") + attention_output = attention_output.view(-1, num_heads // sp_size, head_dim).contiguous() + _, attention_output = all_to_all_bsh(attention_output, scatter_dim=0, gather_dim=1) + attention_output = attention_output.view(batch_size, -1, h_size) + attention_output = rearrange(attention_output, "b s (n d) -> b n s d", n=num_heads, d=head_dim) + else: + query_layer[:, :, self.text_length:] = self.rotary(query_layer[:, :, self.text_length:]) + key_layer[:, :, self.text_length:] = self.rotary(key_layer[:, :, self.text_length:]) + if self.rot_v: + value_layer[:, :, self.text_length:] = self.rotary(value_layer[:, :, self.text_length:]) + + attention_output = attention_fn_fa_score( + query_layer, + key_layer, + value_layer, + attention_mask, + attention_dropout=attention_dropout, + log_attention_weights=log_attention_weights, + scaling_attention_score=scaling_attention_score, + **kwargs, + ) + return attention_output def modulate(x, shift, scale): @@ -521,6 +604,18 @@ class AdaLNMixin(BaseMixin): ) # self full attention (b,(t n),d) + if get_sequence_parallel_world_size() > 1: + global SP_PAD + sp_world_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + img_length = img_hidden_states.shape[1] + SP_PAD = math.ceil(img_length / sp_world_size) * sp_world_size - img_length + if SP_PAD > 0: + padding_shape = (img_hidden_states.shape[0], SP_PAD, *img_hidden_states.shape[2:]) + padding = torch.zeros(padding_shape, dtype=img_hidden_states.dtype, device=img_hidden_states.device) + img_hidden_states = torch.cat([img_hidden_states, padding], dim=1) + img_hidden_states = torch.chunk(img_hidden_states, sp_world_size, dim=1)[sp_rank] + img_attention_input = layer.input_layernorm(img_hidden_states) text_attention_input = layer.input_layernorm(text_hidden_states) img_attention_input = modulate(img_attention_input, shift_msa, scale_msa) @@ -553,6 +648,17 @@ class AdaLNMixin(BaseMixin): text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d) hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d) + if get_sequence_parallel_world_size() > 1: + sp_world_size = get_sequence_parallel_world_size() + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) + text_hidden_states = hidden_states[:, :text_length] + stack = [text_hidden_states] + hidden_states_ls = torch.chunk(hidden_states, sp_world_size, dim=1) + for hidden_states_slice in hidden_states_ls: + stack.append(hidden_states_slice[:, text_length:]) + hidden_states = torch.cat(stack, dim=1) + if SP_PAD > 0: + hidden_states = hidden_states[:, :(-SP_PAD)] return hidden_states def reinit(self, parent_model=None): diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh index c324aecf24..6683119da2 100755 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/inference.sh @@ -1,7 +1,7 @@ #! /bin/bash export TASK_QUEUE_ENABLE=2 export HCCL_OP_EXPANSION_MODE="AIV" -torchrun --nnodes=1 --nproc_per_node 2 --master_port 29503 \ +torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ sample_video.py \ --base configs/cogvideox_5b.yaml configs/inference.yaml \ --seed 42 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py index 6ce7ed3eaa..6d3f4973d8 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/comm.py @@ -125,7 +125,7 @@ def all_to_all( return _all_to_all_func(tensor, world_size, process_group, scatter_dim, gather_dim) -def all_to_all_sbh( +def all_to_all_bsh( input_: torch.Tensor, scatter_dim: int = 1, gather_dim: int = 0, -- Gitee From 18782f38fcf362345513c9e31fab4313dfba6a95 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Sun, 26 Jan 2025 10:31:35 +0000 Subject: [PATCH 7/9] vae parallel --- .../cogvideox-sat/configs/inference.yaml | 4 +- .../foundation/cogvideox-sat/sample_video.py | 57 ++++---- .../cogvideox-sat/utils/parallel_mgr.py | 10 +- .../cogvideox-sat/vae_modules/autoencoder.py | 34 ++--- .../cogvideox-sat/vae_modules/cp_enc_dec.py | 130 ++++++++++-------- 5 files changed, 127 insertions(+), 108 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml index f688264012..c1d882fb1b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/configs/inference.yaml @@ -10,7 +10,7 @@ args: args.sampling_image_width: 720 sampling_num_frames: 13 # Must be 13, 11 or 9 sampling_fps: 8 - fp16: True # For CogVideoX-2B - # bf16: False # For CogVideoX-5B and CoGVideoX-5B-I2V + # fp16: True # For CogVideoX-2B + bf16: True # For CogVideoX-5B and CoGVideoX-5B-I2V output_dir: outputs/ force_inference: True \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py index 7ac6afb833..e38d2ccaa9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -14,6 +14,7 @@ import imageio import torch import torch_npu +import torch.distributed as dist import numpy as np from einops import rearrange import torchvision.transforms as TT @@ -242,29 +243,39 @@ def sampling_main(args, model_cls): shape=(T, C, H // F, W // F), ) samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous() - latent = 1.0 / model.scale_factor * samples_z - - # Decode latent serial to save GPU memory - recons = [] - loop_num = (T - 1) // 2 - for i in range(loop_num): - if i == 0: - start_frame, end_frame = 0, 3 - else: - start_frame, end_frame = i * 2 + 1, i * 2 + 3 - if i == loop_num - 1: - clear_fake_cp_cache = True - else: - clear_fake_cp_cache = False - with torch.no_grad(): - recon = first_stage_model.decode( - latent[:, :, 
start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache - ) - - recons.append(recon) - - recon = torch.cat(recons, dim=2).to(torch.float32) - samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() + world_size = dist.get_world_size() + if world_size <= 2: + # VAE Tiling to save memory + latent = 1.0 / model.scale_factor * samples_z + # Decode latent serial to save GPU memory + recons = [] + loop_num = (T - 1) // 2 + for i in range(loop_num): + if i == 0: + start_frame, end_frame = 0, 3 + else: + start_frame, end_frame = i * 2 + 1, i * 2 + 3 + if i == loop_num - 1: + clear_fake_cp_cache = True + else: + clear_fake_cp_cache = False + with torch.no_grad(): + recon = first_stage_model.decode( + latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache + ) + recons.append(recon) + recon = torch.cat(recons, dim=2).to(torch.float32) + samples_x = recon.permute(0, 2, 1, 3, 4).contiguous() + else: + # VAE Parallel + decode_pad = math.ceil((samples_z.shape[2] - 1) / world_size) * world_size - samples_z.shape[2] + 1 + padding_shape = (*samples_z.shape[:2], decode_pad, *samples_z.shape[3:]) + padding = torch.zeros(padding_shape, dtype=samples_z.dtype, device=samples_z.device) + samples_z = torch.cat([samples_z, padding], dim=2) + samples_x = model.decode_first_stage(samples_z).to(torch.float32) + if decode_pad > 0: + samples_x = samples_x[:, :, :args.model_config.network_config.params.num_frames] + samples_x = samples_x.permute(0, 2, 1, 3, 4).contiguous() samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu() save_path = os.path.join( diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py index 56a5c75afd..41047606b6 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/utils/parallel_mgr.py @@ -28,13 +28,12 @@ def is_vae_parallel_initialized(): return True -def initialize_vae_parallel(): +def initialize_vae_parallel(parallel_size): global _VAE_PARALLEL_GROUP global _VAE_PARALLEL_SIZE assert _VAE_PARALLEL_GROUP is None, "vae parallel group is already initialized" - _VAE_PARALLEL_SIZE = 1 - # _VAE_PARALLEL_SIZE = dist.get_world_size() + _VAE_PARALLEL_SIZE = parallel_size ranks = range(_VAE_PARALLEL_SIZE) group = dist.new_group(ranks) @@ -42,8 +41,8 @@ def initialize_vae_parallel(): def get_vae_parallel_group(): - assert _VAE_PARALLEL_GROUP is not None, "vae parallel group is not initialized" - + if _VAE_PARALLEL_GROUP is None: + return 0 return _VAE_PARALLEL_GROUP @@ -325,7 +324,6 @@ def init_parallel_env(parallel_config: ParallelConfig): classifier_free_guidance_degree=parallel_config.cfg_degree, sequence_parallel_degree=parallel_config.sp_degree, ) - initialize_vae_parallel() def finalize_parallel_env(): diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py index e68644b1ec..b5b5fc8a59 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/autoencoder.py @@ -313,7 +313,7 @@ class VideoAutoencodingEngine(AutoencodingEngine): **kwargs, ): super().__init__(**kwargs) - self.context_parallel_size = context_parallel_size + self.context_parallel_size = dist.get_world_size() if ckpt_path is not None: 
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) @@ -328,7 +328,7 @@ class VideoAutoencodingEngine(AutoencodingEngine): batch = batch[self.input_key] global_src_rank = get_vae_parallel_group_rank() * self.context_parallel_size - torch.distributed.broadcast(batch, src=global_src_rank, group=get_vae_parallel_group()) + dist.broadcast(batch, src=global_src_rank, group=get_vae_parallel_group()) batch = _conv_split(batch, dim=2, kernel_size=1) return batch @@ -361,7 +361,7 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): *args, **kwargs, ): - self.cp_size = cp_size + self.cp_size = dist.get_world_size() return super().__init__(*args, **kwargs) def encode( @@ -372,23 +372,11 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): input_cp: bool = False, output_cp: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - if self.cp_size > 0 and not input_cp: - if not is_vae_parallel_initialized(): - initialize_vae_parallel(self.cp_size) - - global_src_rank = get_vae_parallel_group_rank() * self.cp_size - torch.distributed.broadcast(x, src=global_src_rank, group=get_vae_parallel_group()) - - x = _conv_split(x, dim=2, kernel_size=1) - if return_reg_log: z, reg_log = super().encode(x, return_reg_log, unregularized) else: z = super().encode(x, return_reg_log, unregularized) - if self.cp_size > 0 and not output_cp: - z = _conv_gather(z, dim=2, kernel_size=1) - if return_reg_log: return z, reg_log return z @@ -399,21 +387,23 @@ class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine): input_cp: bool = False, output_cp: bool = False, split_kernel_size: int = 1, + use_cp: bool = True, **kwargs, ): - if self.cp_size > 0 and not input_cp: + if self.cp_size <= 2: + use_cp = False + if self.cp_size > 0 and use_cp and not input_cp: if not is_vae_parallel_initialized(): initialize_vae_parallel(self.cp_size) global_src_rank = get_vae_parallel_group_rank() * self.cp_size - torch.distributed.broadcast(z, src=global_src_rank, group=get_vae_parallel_group()) - - z = _conv_split(z, dim=2, kernel_size=split_kernel_size) + dist.broadcast(z, src=global_src_rank, group=get_vae_parallel_group()) + z = _conv_split(z, dim=2, kernel_size=1) - x = super().decode(z, **kwargs) + x = super().decode(z, use_cp=use_cp, **kwargs) - if self.cp_size > 0 and not output_cp: - x = _conv_gather(x, dim=2, kernel_size=split_kernel_size) + if self.cp_size > 0 and use_cp and not output_cp: + x = _conv_gather(x, dim=2, kernel_size=1) return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py index 968cd50d65..9be0a4c62a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py @@ -76,8 +76,6 @@ def _split(input_, dim): cp_rank = get_vae_parallel_rank() - # print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape) - inpu_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() dim_size = input_.size()[dim] // cp_world_size @@ -89,8 +87,6 @@ def _split(input_, dim): output = torch.cat([inpu_first_frame_, output], dim=dim) output = output.contiguous() - # print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape) - return output @@ -104,8 +100,6 @@ def _gather(input_, dim): group = get_vae_parallel_group() cp_rank = get_vae_parallel_rank() - # 
print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape) - input_first_frame_ = input_.transpose(0, dim)[:1].transpose(0, dim).contiguous() if cp_rank == 0: input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous() @@ -122,8 +116,6 @@ def _gather(input_, dim): output = torch.cat(tensor_list, dim=dim).contiguous() - # print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape) - return output @@ -134,8 +126,6 @@ def _conv_split(input_, dim, kernel_size): if cp_world_size == 1: return input_ - # print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape) - cp_rank = get_vae_parallel_rank() dim_size = (input_.size()[dim] - kernel_size) // cp_world_size @@ -143,14 +133,11 @@ def _conv_split(input_, dim, kernel_size): if cp_rank == 0: output = input_.transpose(dim, 0)[: dim_size + kernel_size].transpose(dim, 0) else: - # output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0) output = input_.transpose(dim, 0)[ cp_rank * dim_size + kernel_size: (cp_rank + 1) * dim_size + kernel_size ].transpose(dim, 0) output = output.contiguous() - # print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape) - return output @@ -483,24 +470,33 @@ class SpatialNorm3D(nn.Module): kernel_size=1, ) - def forward(self, f, zq, clear_fake_cp_cache=True): - if f.shape[2] > 1 and f.shape[2] % 2 == 1: + def forward(self, f, zq, clear_fake_cp_cache=True, fake_cp=True): + if f.shape[2] > 1 and get_vae_parallel_rank() == 0 and fake_cp: f_first, f_rest = f[:, :, :1], f[:, :, 1:] f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:] zq_first = torch.nn.functional.interpolate(zq_first, size=f_first_size, mode="nearest") - zq_rest = torch.nn.functional.interpolate(zq_rest, size=f_rest_size, mode="nearest") + + zq_rest_splits = torch.split(zq_rest, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits + ] + + zq_rest = torch.cat(interpolated_splits, dim=1) zq = torch.cat([zq_first, zq_rest], dim=2) else: - zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:], mode="nearest") + f_size = f.shape[-3:] + + zq_splits = torch.split(zq, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, size=f_size, mode="nearest") for split in zq_splits + ] + zq = torch.cat(interpolated_splits, dim=1) if self.add_conv: zq = self.conv(zq, clear_cache=clear_fake_cp_cache) - # f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1) norm_f = self.norm_layer(f) - # norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f @@ -536,23 +532,38 @@ class Upsample3D(nn.Module): self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.compress_time = compress_time - def forward(self, x): + def forward(self, x, fake_cp=True): if self.compress_time and x.shape[2] > 1: - if x.shape[2] % 2 == 1: + if get_vae_parallel_rank() == 0 and fake_cp: # split first frame x_first, x_rest = x[:, :, 0], x[:, :, 1:] x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0, mode="nearest") - x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0, mode="nearest") + + splits = torch.split(x_rest, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + ] + 
x_rest = torch.cat(interpolated_splits, dim=1) x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) else: - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + splits = torch.split(x, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + ] + x = torch.cat(interpolated_splits, dim=1) else: # only interpolate 2D t = x.shape[2] x = rearrange(x, "b c t h w -> (b t) c h w") - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + + splits = torch.split(x, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.interpolate(split, scale_factor=2.0, mode="nearest") for split in splits + ] + x = torch.cat(interpolated_splits, dim=1) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) if self.with_conv: @@ -574,21 +585,30 @@ class DownSample3D(nn.Module): self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) self.compress_time = compress_time - def forward(self, x): + def forward(self, x, fake_cp=True): if self.compress_time and x.shape[2] > 1: h, w = x.shape[-2:] x = rearrange(x, "b c t h w -> (b h w) c t") - if x.shape[-1] % 2 == 1: + if get_vae_parallel_rank() == 0 and fake_cp: # split first frame x_first, x_rest = x[..., 0], x[..., 1:] if x_rest.shape[-1] > 0: - x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + splits = torch.split(x_rest, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits + ] + x_rest = torch.cat(interpolated_splits, dim=1) x = torch.cat([x_first[..., None], x_rest], dim=-1) x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) else: - x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + # x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + splits = torch.split(x, 32, dim=1) + interpolated_splits = [ + torch.nn.functional.avg_pool1d(split, kernel_size=2, stride=2) for split in splits + ] + x = torch.cat(interpolated_splits, dim=1) x = rearrange(x, "(b h w) c t -> b c t h w", h=h, w=w) if self.with_conv: @@ -668,17 +688,13 @@ class ContextParallelResnetBlock3D(nn.Module): padding=0, ) - def forward(self, x, temb, zq=None, clear_fake_cp_cache=True): + def forward(self, x, temb, zq=None, clear_fake_cp_cache=True, fake_cp=True): h = x - # if isinstance(self.norm1, torch.nn.GroupNorm): - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) if zq is not None: - h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp) else: h = self.norm1(h) - # if isinstance(self.norm1, torch.nn.GroupNorm): - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) h = nonlinearity(h) h = self.conv1(h, clear_cache=clear_fake_cp_cache) @@ -686,14 +702,10 @@ class ContextParallelResnetBlock3D(nn.Module): if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None] - # if isinstance(self.norm2, torch.nn.GroupNorm): - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) if zq is not None: - h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=fake_cp) else: h = self.norm2(h) - # if isinstance(self.norm2, torch.nn.GroupNorm): - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) h = nonlinearity(h) h = self.dropout(h) @@ -802,29 +814,33 @@ class 
ContextParallelEncoder3D(nn.Module): kernel_size=3, ) - def forward(self, x, **kwargs): + def forward(self, x, use_cp=True): + global _USE_CP + if torch.distributed.get_world_size() <= 2: + use_cp = False + _USE_CP = use_cp + # timestep embedding temb = None # downsampling - h = self.conv_in(x) + hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) + h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) + hs.append(h) if i_level != self.num_resolutions - 1: - h = self.down[i_level].downsample(h) + hs.append(self.down[i_level].downsample(hs[-1])) # middle + h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.block_2(h, temb) # end - # h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1) h = self.norm_out(h) - # h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1) - h = nonlinearity(h) h = self.conv_out(h) @@ -869,11 +885,9 @@ class ContextParallelDecoder3D(nn.Module): zq_ch = z_channels # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) self.conv_in = ContextParallelCausalConv3d( chan_in=z_channels, @@ -943,7 +957,11 @@ class ContextParallelDecoder3D(nn.Module): kernel_size=3, ) - def forward(self, z, clear_fake_cp_cache=True, **kwargs): + def forward(self, z, clear_fake_cp_cache=True, use_cp=True): + global _USE_CP + if torch.distributed.get_world_size() <= 2: + use_cp = False + _USE_CP = use_cp self.last_z_shape = z.shape # timestep embedding @@ -956,23 +974,25 @@ class ContextParallelDecoder3D(nn.Module): h = self.conv_in(z, clear_cache=clear_fake_cp_cache) # middle - h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) - h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.mid.block_1(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) + h = self.mid.block_2(h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.up[i_level].block[i_block]( + h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp + ) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h, zq) if i_level != 0: - h = self.up[i_level].upsample(h) + h = self.up[i_level].upsample(h, fake_cp=use_cp) # end if self.give_pre_end: return h - h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache) + h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache, fake_cp=use_cp) h = nonlinearity(h) h = self.conv_out(h, clear_cache=clear_fake_cp_cache) -- Gitee From 13836b97d309fbee25b21a3aac89eb1247355dac Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 6 Feb 2025 04:21:37 +0000 Subject: [PATCH 8/9] fix vae_parallel bugs --- .../built-in/foundation/cogvideox-sat/sample_video.py | 1 + .../cogvideox-sat/vae_modules/cp_enc_dec.py | 11 +++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py 
b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py index e38d2ccaa9..cb7b0a2127 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -283,6 +283,7 @@ def sampling_main(args, model_cls): ) if rank == 0: save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps) + torch.npu.empty_cache() torch.npu.synchronize() end = time.time() print(f"use time: {end - start}s") diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py index 9be0a4c62a..3aca7e7e8f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/vae_modules/cp_enc_dec.py @@ -376,10 +376,13 @@ class ContextParallelCausalConv3d(nn.Module): # output_parallel = self.conv(input_parallel) # output = output_parallel # return output - - input_parallel = fake_cp_pass_from_previous_rank( - input_, self.temporal_dim, self.time_kernel_size, self.cache_padding - ) + if input_.shape[2] == 1: # handle image + # first frame padding + input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2) + else: + input_parallel = fake_cp_pass_from_previous_rank( + input_, self.temporal_dim, self.time_kernel_size, self.cache_padding + ) del self.cache_padding self.cache_padding = None -- Gitee From 72e148571581d698ab02f3d9839881bf75f8c0e2 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 6 Feb 2025 13:48:52 +0000 Subject: [PATCH 9/9] fix seed bug --- .../built-in/foundation/cogvideox-sat/sample_video.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py index cb7b0a2127..79d3211ed3 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py +++ b/MindIE/MindIE-Torch/built-in/foundation/cogvideox-sat/sample_video.py @@ -129,8 +129,6 @@ def sampling_main(args, model_cls): init_parallel_env(parallel_config) rank = get_sequence_parallel_rank() if get_sequence_parallel_state() else 0 - args.seed = args.seed + rank - set_random_seed(args.seed) if isinstance(model_cls, type): model = get_model(args, model_cls) @@ -160,6 +158,7 @@ def sampling_main(args, model_cls): force_uc_zero_embeddings = ["txt"] with torch.no_grad(): for text, cnt in tqdm(data_iter): + set_random_seed(args.seed) start = time.time() torch.npu.synchronize() if args.image2video: -- Gitee
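Note on the sequence-parallel attention hunks (AdaLNMixin): they follow a pad / chunk / all-gather / trim pattern — the image-token sequence is zero-padded until it divides evenly by the sequence-parallel world size, each rank keeps one chunk, and after the block the full sequence is all-gathered and the padding (SP_PAD) is stripped again. The sketch below isolates that round trip with plain torch.distributed collectives; the helper names and the bare all_gather are illustrative only — the patch additionally interleaves the replicated text tokens per rank and redistributes attention heads via all_to_all_bsh, which this sketch omits.

import math
import torch
import torch.distributed as dist


def scatter_image_tokens(img_hidden_states, sp_world_size, sp_rank):
    # Zero-pad the token dim (dim=1) so it splits evenly, then keep this rank's chunk.
    seq_len = img_hidden_states.shape[1]
    pad = math.ceil(seq_len / sp_world_size) * sp_world_size - seq_len
    if pad > 0:
        padding = torch.zeros(
            img_hidden_states.shape[0], pad, *img_hidden_states.shape[2:],
            dtype=img_hidden_states.dtype, device=img_hidden_states.device,
        )
        img_hidden_states = torch.cat([img_hidden_states, padding], dim=1)
    return torch.chunk(img_hidden_states, sp_world_size, dim=1)[sp_rank], pad


def gather_image_tokens(local_chunk, pad, group=None):
    # All-gather the per-rank chunks along the token dim, then drop the padding.
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(local_chunk) for _ in range(world_size)]
    dist.all_gather(gathered, local_chunk.contiguous(), group=group)
    full = torch.cat(gathered, dim=1)
    return full[:, :-pad] if pad > 0 else full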
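For the VAE, PATCH 7/9 replaces the frame-serial tiling decode with a context-parallel path whenever the world size is above 2: the latent is padded along the frame axis so that, past the causal first frame, the remaining frames split evenly across ranks, the padded latent is decoded once through decode_first_stage, and the extra frames are dropped afterwards. A stand-alone sketch of the padding arithmetic used in sample_video.py (the helper name is hypothetical):

import math
import torch


def pad_latent_for_parallel_decode(samples_z, world_size):
    # Pad dim 2 (frames) so (T - 1) divides evenly by world_size; return the pad length.
    t = samples_z.shape[2]
    pad = math.ceil((t - 1) / world_size) * world_size - t + 1
    if pad > 0:
        padding = torch.zeros(
            *samples_z.shape[:2], pad, *samples_z.shape[3:],
            dtype=samples_z.dtype, device=samples_z.device,
        )
        samples_z = torch.cat([samples_z, padding], dim=2)
    return samples_z, pad

With the 13 latent frames (sampling_num_frames) and 8 ranks (nproc_per_node) configured here, this pads to 17 frames — one causal frame plus two per rank — and the decoded output is then trimmed back to the configured num_frames, which is what the patch does when decode_pad > 0.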
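PATCH 8/9 guards ContextParallelCausalConv3d against single-frame inputs: when the temporal dim is 1 (image conditioning), the frame is replicated time_kernel_size times instead of going through fake_cp_pass_from_previous_rank, so the causal temporal conv still sees a full receptive field. A simplified illustration of that causal-padding idea follows; the general branch here repeats frame 0 locally and is only a stand-in for the real cross-rank / cached padding in cp_enc_dec.py.

import torch


def causal_time_pad(x, time_kernel_size):
    # x has shape (b, c, t, h, w); pad along dim 2 so a temporal conv of size
    # time_kernel_size yields t output frames without looking at future frames.
    if x.shape[2] == 1:
        # Single image: replicate the frame, matching the patch's special case.
        return torch.cat([x] * time_kernel_size, dim=2)
    first = x[:, :, :1].repeat(1, 1, time_kernel_size - 1, 1, 1)
    return torch.cat([first, x], dim=2)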