From 863ba62121f6420144ee480245d700179b5c6941 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 10 Jan 2025 07:01:12 +0000 Subject: [PATCH 01/19] initial release --- .../foundation/open_sora_planv1_2/README.md | 4 + .../open_sora_planv1_2/inference.py | 144 ++ .../opensora/models/__init__.py | 1 + .../models/causalvideovae/__init__.py | 32 + .../models/causalvideovae/model/__init__.py | 3 + .../model/causal_vae/__init__.py | 29 + .../model/causal_vae/modeling_causalvae.py | 600 +++++++++ .../model/modeling_videobase.py | 50 + .../causalvideovae/model/modules/__init__.py | 6 + .../causalvideovae/model/modules/attention.py | 230 ++++ .../causalvideovae/model/modules/block.py | 5 + .../causalvideovae/model/modules/conv.py | 117 ++ .../causalvideovae/model/modules/normalize.py | 101 ++ .../causalvideovae/model/modules/ops.py | 47 + .../model/modules/resnet_block.py | 126 ++ .../model/modules/updownsample.py | 350 +++++ .../causalvideovae/model/utils/__init__.py | 0 .../model/utils/distrib_utils.py | 42 + .../model/utils/module_utils.py | 17 + .../opensora/models/diffusion/__init__.py | 2 + .../models/diffusion/opensora/__init__.py | 0 .../diffusion/opensora/modeling_opensora.py | 571 ++++++++ .../models/diffusion/opensora/modules.py | 1175 +++++++++++++++++ .../models/diffusion/opensora/rope.py | 88 ++ .../opensora/sample/pipeline_opensora_sp.py | 793 +++++++++++ .../open_sora_planv1_2/requirements.txt | 17 + .../foundation/open_sora_planv1_2/run.sh | 18 + .../open_sora_planv1_2/utils/comm.py | 180 +++ .../open_sora_planv1_2/utils/parallel_mgr.py | 72 + 29 files changed, 4820 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/README.md create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py create mode 100644 
MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/updownsample.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/__init__.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/requirements.txt create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/README.md b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/README.md new file mode 100644 index 0000000000..3c3ca6b8d4 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/README.md @@ -0,0 +1,4 @@ +/.--- +license: apache-2.0 +--- + diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py new file mode 100644 index 0000000000..21f85bab8b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py @@ -0,0 +1,144 @@ +import argparse +import logging +import os + +import imageio +import torch +import torch_npu +from diffusers.schedulers import EulerAncestralDiscreteScheduler +from transformers import T5Tokenizer, MT5EncoderModel + +from opensora.models.causalvideovae import ae_stride_config, CausalVAEModelWrapper +from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V +from opensora.sample.pipeline_opensora_sp import OpenSoraPipeline +from utils.parallel_mgr import init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def load_t2v_checkpoint(model_path): + logger.info(f'load_t2v_checkpoint, {model_path}') + transformer_model = OpenSoraT2V.from_pretrained(model_path, cache_dir=args.cache_dir, + low_cpu_mem_usage=False, device_map=None, + torch_dtype=weight_dtype).to("npu") + transformer_model.eval() + pipeline = OpenSoraPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer_model).to("npu") + if args.algorithm == "dit_cache": + from mindiesd.layers.cache_mgr import CacheManager, DitCacheConfig + config = DitCacheConfig(step_start=20, step_interval=2, block_start=7, 
num_blocks=21) + cache = CacheManager(config) + pipeline.transformer.cache = cache + return pipeline + + +def run_model_and_save_images(pipeline): + if not isinstance(args.text_prompt, list): + args.text_prompt = [args.text_prompt] + if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'): + text_prompt = open(args.text_prompt[0], 'r').readlines() + args.text_prompt = [i.strip() for i in text_prompt] + + positive_prompt = """ + (masterpiece), (best quality), (ultra-detailed), (unwatermarked), + {}. + emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, + sharp focus, high budget, cinemascope, moody, epic, gorgeous + """ + + negative_prompt = """ + nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, + low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. + """ + kwargs = {} + if args.algorithm == "sampling_optimize": + kwargs["sampling_optimize"] = True + + for index, prompt in enumerate(args.text_prompt): + videos = pipeline(positive_prompt.format(prompt), + negative_prompt=negative_prompt, + num_frames=args.num_frames, + height=args.height, + width=args.width, + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale, + num_images_per_prompt=1, + mask_feature=True, + max_sequence_length=args.max_sequence_length, + **kwargs + ).images + logger.info(videos.shape) + + if get_sequence_parallel_rank() <= 0: + try: + imageio.mimwrite( + os.path.join( + args.save_img_path, + f'EulerAncestralDiscrete_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4' + ), videos[0], + fps=args.fps, quality=6, codec='libx264', + output_params=['-threads', '20']) # highest quality is 10, lowest is 0 + except: + logger.error('Error when saving {}'.format(prompt)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0') + parser.add_argument("--version", type=str, default=None, choices=[None, '65x512x512', '65x256x256', '17x256x256']) + parser.add_argument("--num_frames", type=int, default=93) + parser.add_argument("--height", type=int, default=720) + parser.add_argument("--width", type=int, default=1280) + parser.add_argument('--dtype', type=str, default='bf16', help='Data type used in inference') + parser.add_argument("--cache_dir", type=str, default='./cache_dir') + parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8') + parser.add_argument("--ae_path", type=str, default='CausalVAEModel_4x8x8') + parser.add_argument("--text_encoder_name", type=str, default='google/mt5-xxl') + parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v") + parser.add_argument("--guidance_scale", type=float, default=7.5) + parser.add_argument("--num_sampling_steps", type=int, default=50) + parser.add_argument("--fps", type=int, default=24) + parser.add_argument("--max_sequence_length", type=int, default=512) + parser.add_argument("--text_prompt", nargs='+') + parser.add_argument("--tile_overlap_factor", type=float, default=0.25) + parser.add_argument("--algorithm", type=str, default=None, choices=[None, 'dit_cache', 'sampling_optimize']) + args = parser.parse_args() + + if args.dtype not in ['bf16', 'fp16']: + logger.error("Not supported.") + weight_dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float16 + + local_rank = int(os.getenv('RANK', 0)) + world_size = int(os.getenv('WORLD_SIZE', 1)) + if world_size > 0: 
+ init_parallel_env(enable_sequence_parallelism=True) + + vae = CausalVAEModelWrapper(args.ae_path, dtype=torch.float16).to("npu") + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + vae.vae.tile_sample_min_size = 256 + vae.vae.tile_latent_min_size = 32 + vae.vae.tile_sample_min_size_t = 29 + vae.vae.tile_latent_min_size_t = 8 + vae.vae_scale_factor = ae_stride_config[args.ae] + vae.eval() + + text_encoder = MT5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir, + low_cpu_mem_usage=True, torch_dtype=weight_dtype).to("npu") + tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) + text_encoder.eval() + + scheduler = EulerAncestralDiscreteScheduler() + + if not os.path.exists(args.save_img_path): + os.makedirs(args.save_img_path, exist_ok=True) + + pipeline = load_t2v_checkpoint(args.model_path) + logger.info('load model') + run_model_and_save_images(pipeline) + if world_size > 0: + finalize_parallel_env() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/__init__.py new file mode 100644 index 0000000000..3eca32c739 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/__init__.py @@ -0,0 +1 @@ +from .causalvideovae.model import CausalVAEModelWrapper \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/__init__.py new file mode 100644 index 0000000000..bc1357393e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/__init__.py @@ -0,0 +1,32 @@ +from torchvision.transforms import Lambda +from .model.causal_vae import CausalVAEModelWrapper + +ae_stride_config = { + 'CausalVAEModel_D4_2x8x8': [2, 8, 8], + 'CausalVAEModel_D8_2x8x8': [2, 8, 8], + 'CausalVAEModel_D4_4x8x8': [4, 8, 8], + 'CausalVAEModel_D8_4x8x8': [4, 8, 8], +} + + +ae_channel_config = { + 'CausalVAEModel_D4_2x8x8': 4, + 'CausalVAEModel_D8_2x8x8': 8, + 'CausalVAEModel_D4_4x8x8': 4, + 'CausalVAEModel_D8_4x8x8': 8, +} + + +ae_denorm = { + 'CausalVAEModel_D4_2x8x8': lambda x: (x + 1.) / 2., + 'CausalVAEModel_D8_2x8x8': lambda x: (x + 1.) / 2., + 'CausalVAEModel_D4_4x8x8': lambda x: (x + 1.) / 2., + 'CausalVAEModel_D8_4x8x8': lambda x: (x + 1.) / 2., +} + +ae_norm = { + 'CausalVAEModel_D4_2x8x8': Lambda(lambda x: 2. * x - 1.), + 'CausalVAEModel_D8_2x8x8': Lambda(lambda x: 2. * x - 1.), + 'CausalVAEModel_D4_4x8x8': Lambda(lambda x: 2. * x - 1.), + 'CausalVAEModel_D8_4x8x8': Lambda(lambda x: 2. 
* x - 1.), +} \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/__init__.py new file mode 100644 index 0000000000..84918b27c2 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/__init__.py @@ -0,0 +1,3 @@ +from .causal_vae import ( + CausalVAEModel, CausalVAEModelWrapper +) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py new file mode 100644 index 0000000000..c1aeb800d7 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py @@ -0,0 +1,29 @@ +from .modeling_causalvae import CausalVAEModel + +from einops import rearrange +from torch import nn + +class CausalVAEModelWrapper(nn.Module): + def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): + super(CausalVAEModelWrapper, self).__init__() + # if os.path.exists(ckpt): + # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) + self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) + if use_ema: + self.vae.init_from_ema(model_path) + self.vae = self.vae.ema + def encode(self, x): # b c t h w + # x = self.vae.encode(x).sample() + x = self.vae.encode(x).sample().mul_(0.18215) + return x + def decode(self, x): + # x = self.vae.decode(x) + x = self.vae.decode(x / 0.18215) + x = rearrange(x, 'b c t h w -> b t c h w').contiguous() + return x + + def dtype(self): + return self.vae.dtype + # + # def device(self): + # return self.vae.device \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py new file mode 100644 index 0000000000..299cc071ae --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py @@ -0,0 +1,600 @@ +import os +from copy import deepcopy +from typing import Tuple + +import torch +import torch_npu +import torch.nn as nn +from diffusers.configuration_utils import register_to_config + +from ..modules import Normalize +from ..modules.ops import nonlinearity +from ..modeling_videobase import VideoBaseAE +from ..utils.distrib_utils import DiagonalGaussianDistribution +from ..utils.module_utils import resolve_str_to_obj, Module + + +class Encoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: Module = "Conv2d", + conv_out: Module = "CasualConv3d", + attention: Module = "AttnBlock", + resnet_blocks: Tuple[Module] = ( + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock2D", + "ResnetBlock3D", + ), + spatial_downsample: Tuple[Module] = ( + "Downsample", + "Downsample", + "Downsample", + "", + ), + temporal_downsample: Tuple[Module] = ("", "", "TimeDownsampleRes2x", ""), + mid_resnet: Module = "ResnetBlock3D", + dropout: float = 0.0, + 
resolution: int = 256, + num_res_blocks: int = 2, + double_z: bool = True, + ) -> None: + super().__init__() + assert len(resnet_blocks) == len(hidden_size_mult), print( + hidden_size_mult, resnet_blocks + ) + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + self.conv_in = resolve_str_to_obj(conv_in)( + 3, hidden_size, kernel_size=3, stride=1, padding=1 + ) + + # ---- Downsample ---- + curr_res = resolution + in_ch_mult = (1,) + tuple(hidden_size_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = hidden_size * in_ch_mult[i_level] + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if spatial_downsample[i_level]: + down.downsample = resolve_str_to_obj(spatial_downsample[i_level])( + block_in, block_in + ) + curr_res = curr_res // 2 + if temporal_downsample[i_level]: + down.time_downsample = resolve_str_to_obj(temporal_downsample[i_level])( + block_in, block_in + ) + self.down.append(down) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if hasattr(self.down[i_level], "downsample"): + hs.append(self.down[i_level].downsample(hs[-1])) + if hasattr(self.down[i_level], "time_downsample"): + hs_down = self.down[i_level].time_downsample(hs[-1]) + hs.append(hs_down) + + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels: int, + hidden_size: int, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = (16,), + conv_in: Module = "Conv2d", + conv_out: Module = "CasualConv3d", + attention: Module = "AttnBlock", + resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + spatial_upsample: Tuple[Module] = ( + "", + "SpatialUpsample2x", + "SpatialUpsample2x", + "SpatialUpsample2x", + ), + temporal_upsample: Tuple[Module] = ("", "", "", "TimeUpsampleRes2x"), + mid_resnet: Module = "ResnetBlock3D", + dropout: float = 0.0, + resolution: int = 256, + num_res_blocks: int = 2, + ): + super().__init__() + # ---- Config ---- + self.num_resolutions = len(hidden_size_mult) + self.resolution = 
resolution + self.num_res_blocks = num_res_blocks + + # ---- In ---- + block_in = hidden_size * hidden_size_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.conv_in = resolve_str_to_obj(conv_in)( + z_channels, block_in, kernel_size=3, padding=1 + ) + + # ---- Mid ---- + self.mid = nn.Module() + self.mid.block_1 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + self.mid.attn_1 = resolve_str_to_obj(attention)(block_in) + self.mid.block_2 = resolve_str_to_obj(mid_resnet)( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + ) + + # ---- Upsample ---- + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = hidden_size * hidden_size_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + resolve_str_to_obj(resnet_blocks[i_level])( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(resolve_str_to_obj(attention)(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if spatial_upsample[i_level]: + up.upsample = resolve_str_to_obj(spatial_upsample[i_level])( + block_in, block_in + ) + curr_res = curr_res * 2 + if temporal_upsample[i_level]: + up.time_upsample = resolve_str_to_obj(temporal_upsample[i_level])( + block_in, block_in + ) + self.up.insert(0, up) + + # ---- Out ---- + self.norm_out = Normalize(block_in) + self.conv_out = resolve_str_to_obj(conv_out)( + block_in, 3, kernel_size=3, padding=1 + ) + + def forward(self, z): + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + if hasattr(self.up[i_level], "time_upsample"): + h = self.up[i_level].time_upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h_dtype = h.dtype + h = self.conv_out(h) + return h + + +class CausalVAEModel(VideoBaseAE): + + @register_to_config + def __init__( + self, + hidden_size: int = 128, + z_channels: int = 4, + hidden_size_mult: Tuple[int] = (1, 2, 4, 4), + attn_resolutions: Tuple[int] = [], + dropout: float = 0.0, + resolution: int = 256, + double_z: bool = True, + embed_dim: int = 4, + num_res_blocks: int = 2, + q_conv: str = "CausalConv3d", + encoder_conv_in: Module = "CausalConv3d", + encoder_conv_out: Module = "CausalConv3d", + encoder_attention: Module = "AttnBlock3D", + encoder_resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + encoder_spatial_downsample: Tuple[Module] = ( + "SpatialDownsample2x", + "SpatialDownsample2x", + "SpatialDownsample2x", + "", + ), + encoder_temporal_downsample: Tuple[Module] = ( + "", + "TimeDownsample2x", + "TimeDownsample2x", + "", + ), + encoder_mid_resnet: Module = "ResnetBlock3D", + decoder_conv_in: Module = "CausalConv3d", + decoder_conv_out: Module = "CausalConv3d", + decoder_attention: Module = "AttnBlock3D", + decoder_resnet_blocks: Tuple[Module] = ( + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + "ResnetBlock3D", + ), + decoder_spatial_upsample: Tuple[Module] = ( + "", + "SpatialUpsample2x", + 
"SpatialUpsample2x", + "SpatialUpsample2x", + ), + decoder_temporal_upsample: Tuple[Module] = ("", "", "TimeUpsample2x", "TimeUpsample2x"), + decoder_mid_resnet: Module = "ResnetBlock3D", + use_quant_layer: bool = True + ) -> None: + super().__init__() + + self.tile_sample_min_size = 256 + self.tile_sample_min_size_t = 33 + self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) + + # t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] + # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1 + self.tile_latent_min_size_t = 16 + self.tile_overlap_factor = 0.125 + self.use_tiling = False + + self.use_quant_layer = use_quant_layer + + self.encoder = Encoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=encoder_conv_in, + conv_out=encoder_conv_out, + attention=encoder_attention, + resnet_blocks=encoder_resnet_blocks, + spatial_downsample=encoder_spatial_downsample, + temporal_downsample=encoder_temporal_downsample, + mid_resnet=encoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + double_z=double_z, + ) + + self.decoder = Decoder( + z_channels=z_channels, + hidden_size=hidden_size, + hidden_size_mult=hidden_size_mult, + attn_resolutions=attn_resolutions, + conv_in=decoder_conv_in, + conv_out=decoder_conv_out, + attention=decoder_attention, + resnet_blocks=decoder_resnet_blocks, + spatial_upsample=decoder_spatial_upsample, + temporal_upsample=decoder_temporal_upsample, + mid_resnet=decoder_mid_resnet, + dropout=dropout, + resolution=resolution, + num_res_blocks=num_res_blocks, + ) + if self.use_quant_layer: + quant_conv_cls = resolve_str_to_obj(q_conv) + self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1) + self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1) + + def get_encoder(self): + if self.use_quant_layer: + return [self.quant_conv, self.encoder] + return [self.encoder] + + def get_decoder(self): + if self.use_quant_layer: + return [self.post_quant_conv, self.decoder] + return [self.decoder] + + def encode(self, x): + if self.use_tiling and ( + x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + or x.shape[-3] > self.tile_sample_min_size_t + ): + return self.tiled_encode(x) + h = self.encoder(x) + if self.use_quant_layer: + h = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(h) + return posterior + + def decode(self, z): + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + or z.shape[-3] > self.tile_latent_min_size_t + ): + return self.tiled_decode(z) + if self.use_quant_layer: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def on_train_start(self): + self.ema = deepcopy(self) if self.save_ema == True else None + + def get_last_layer(self): + if hasattr(self.decoder.conv_out, "conv"): + return self.decoder.conv_out.conv.weight + else: + return self.decoder.conv_out.weight + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, 
:, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + moments = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start: end] + if idx != 0: + moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:] + else: + moment = self.tiled_encode2d(chunk_x, return_moments=True) + moments.append(moment) + moments = torch.cat(moments, dim=2) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def tiled_decode(self, x): + t = x.shape[2] + t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t - 1)] + if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0: + t_chunk_start_end = [[0, t]] + else: + t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i + 1] + 1] for i in range(len(t_chunk_idx) - 1)] + if t_chunk_start_end[-1][-1] > t: + t_chunk_start_end[-1][-1] = t + elif t_chunk_start_end[-1][-1] < t: + last_start_end = [t_chunk_idx[-1], t] + t_chunk_start_end.append(last_start_end) + dec_ = [] + for idx, (start, end) in enumerate(t_chunk_start_end): + chunk_x = x[:, :, start: end] + if idx != 0: + dec = self.tiled_decode2d(chunk_x)[:, :, 1:] + else: + dec = self.tiled_decode2d(chunk_x) + dec_.append(dec) + dec_ = torch.cat(dec_, dim=2) + return dec_ + + def tiled_encode2d(self, x, return_moments=False): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. 
+ rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i: i + self.tile_sample_min_size, + j: j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.use_quant_layer: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + posterior = DiagonalGaussianDistribution(moments) + if return_moments: + return moments + return posterior + + def tiled_decode2d(self, z): + + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i: i + self.tile_latent_min_size, + j: j + self.tile_latent_min_size, + ] + if self.use_quant_layer: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def enable_tiling(self, use_tiling: bool = True): + self.use_tiling = use_tiling + + def disable_tiling(self): + self.enable_tiling(False) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu") + print("init from " + path) + + if "ema_state_dict" in sd and len(sd['ema_state_dict']) > 0 and os.environ.get("NOT_USE_EMA_MODEL", 0) == 0: + print("Load from ema model!") + sd = sd["ema_state_dict"] + sd = {key.replace("module.", ""): value for key, value in sd.items()} + elif "state_dict" in sd: + print("Load from normal model!") + if "gen_model" in sd["state_dict"]: + sd = sd["state_dict"]["gen_model"] + else: + sd = sd["state_dict"] + + keys = list(sd.keys()) + + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + miss, unexpected = self.load_state_dict(sd, strict=False) + assert len(miss) == 0, f"miss key: {miss}" + if len(unexpected) > 0: + for i in unexpected: + assert 'loss' in i, "unexpected key: {i}" diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py new file mode 100644 index 0000000000..f7686a0f27 --- 
/dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py @@ -0,0 +1,50 @@ +import torch +import os +import json +from diffusers.configuration_utils import ConfigMixin +from diffusers.models.modeling_utils import ModelMixin +from typing import Optional, Union +import glob + + +class VideoBaseAE(ModelMixin, ConfigMixin): + config_name = "config.json" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def encode(self, x: torch.Tensor, *args, **kwargs): + pass + + def decode(self, encoding: torch.Tensor, *args, **kwargs): + pass + + @property + def num_training_steps(self) -> int: + """Total training steps inferred from datamodule and devices.""" + if self.trainer.max_steps: + return self.trainer.max_steps + + limit_batches = self.trainer.limit_train_batches + batches = len(self.train_dataloader()) + batches = min(batches, limit_batches) if isinstance(limit_batches, int) else int(limit_batches * batches) + + num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) + if self.trainer.tpu_cores: + num_devices = max(num_devices, self.trainer.tpu_cores) + + effective_accum = self.trainer.accumulate_grad_batches * num_devices + return (batches // effective_accum) * self.trainer.max_epochs + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + ckpt_files = glob.glob(os.path.join(pretrained_model_name_or_path, '*.ckpt')) + if ckpt_files: + # Adapt to checkpoint + last_ckpt_file = ckpt_files[-1] + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + model = cls.from_config(config_file) + model.init_from_ckpt(last_ckpt_file) + return model + else: + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py new file mode 100644 index 0000000000..7ef41edd27 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py @@ -0,0 +1,6 @@ +from .block import Block +from .attention import * +from .conv import * +from .normalize import * +from .resnet_block import * +from .updownsample import * diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py new file mode 100644 index 0000000000..ac4c96e3b0 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py @@ -0,0 +1,230 @@ +import torch.nn as nn +import torch_npu +from .normalize import Normalize +from .conv import CausalConv3d +import torch +from einops import rearrange +from .block import Block +from .ops import video_to_image + +class LinearAttention(Block): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b 
heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock3D(Block): + """Compatible with old versions, there are issues, use with caution.""" + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b * t, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, t, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + +class AttnBlock3DFix(nn.Module): + """ + Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. 
+ """ + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c) + b, c, t, h, w = q.shape + q = q.permute(0, 2, 1, 3, 4) + q = q.reshape(b * t, c, h * w) + q = q.permute(0, 2, 1) + + # k: (b c t h w) -> (b t c h w) -> (b*t c h*w) + k = k.permute(0, 2, 1, 3, 4) + k = k.reshape(b * t, c, h * w) + + # w: (b*t hw hw) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + # v: (b c t h w) -> (b t c h w) -> (bt c hw) + # w_: (bt hw hw) -> (bt hw hw) + v = v.permute(0, 2, 1, 3, 4) + v = v.reshape(b * t, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) + h_ = h_.reshape(b, t, c, h, w) + h_ = h_.permute(0, 2, 1, 3 ,4) + + h_ = self.proj_out(h_) + + return x + h_ + + +class AttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class TemporalAttnBlock(Block): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, t, h, w = q.shape + q = rearrange(q, "b c t h w -> (b h w) t c") + k = rearrange(k, "b c t h w -> (b h w) c t") + v = rearrange(v, "b c t h w -> (b h w) c t") + w_ = torch.bmm(q, 
k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w) + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + print(attn_type) + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "vanilla3D": + return AttnBlock3D(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py new file mode 100644 index 0000000000..e93672d20b --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py @@ -0,0 +1,5 @@ +import torch.nn as nn + +class Block(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py new file mode 100644 index 0000000000..629715b535 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py @@ -0,0 +1,117 @@ +import torch.nn as nn +from typing import Union, Tuple +import torch.nn.functional as F +import torch +import torch_npu +from .block import Block +from .ops import cast_tuple +from einops import rearrange +from .ops import video_to_image +from torch.utils.checkpoint import checkpoint + +class Conv2d(nn.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]] = 3, + stride: Union[int, Tuple[int]] = 1, + padding: Union[str, int, Tuple[int]] = 0, + dilation: Union[int, Tuple[int]] = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + padding_mode, + device, + dtype, + ) + + @video_to_image + def forward(self, x): + return super().forward(x) + + + +class CausalConv3d(nn.Module): + def __init__( + self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.time_kernel_size = self.kernel_size[0] + self.chan_in = chan_in + self.chan_out = chan_out + stride = kwargs.pop("stride", 1) + padding = kwargs.pop("padding", 0) + padding = list(cast_tuple(padding, 3)) + padding[0] = 0 + stride = cast_tuple(stride, 3) + self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding) + self.pad = nn.ReplicationPad2d((0, 0, self.time_kernel_size - 1, 0)) + self._init_weights(init_method) + + def _init_weights(self, init_method): + ks = torch.tensor(self.kernel_size) + if init_method == "avg": + assert ( + self.kernel_size[1] == 1 and 
self.kernel_size[2] == 1 + ), "only support temporal up/down sample" + assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" + weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) + + eyes = torch.concat( + [ + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + ], + dim=-1, + ) + weight[:, :, :, 0, 0] = eyes + + self.conv.weight = nn.Parameter( + weight, + requires_grad=True, + ) + elif init_method == "zero": + self.conv.weight = nn.Parameter( + torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)), + requires_grad=True, + ) + if self.conv.bias is not None: + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + x_dtype = x.dtype + # 1 + 16 16 as video, 1 as image + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 + x = self.conv.to(device=x.device, dtype=torch.float16)(x.to(torch.float16)) + x = x.to(x_dtype) + return torch_npu.npu_format_cast(x, 2) + + +class CausalConv3d_GC(CausalConv3d): + def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]], init_method="random", **kwargs): + super().__init__(chan_in, chan_out, kernel_size, init_method, **kwargs) + def forward(self, x): + # 1 + 16 16 as video, 1 as image + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) # b c t h w + x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16 + return checkpoint(self.conv, x) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py new file mode 100644 index 0000000000..7c8c05f0fa --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from .block import Block + +class GroupNorm(Block): + def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.norm = torch.nn.GroupNorm( + num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True + ) + def forward(self, x): + return self.norm(x) + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if 
len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py new file mode 100644 index 0000000000..e298cceebf --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py @@ -0,0 +1,47 @@ +import torch +import torch_npu +import torch.nn as nn +from einops import rearrange + + +def video_to_image(func): + def wrapper(self, x, *args, **kwargs): + if x.dim() == 5: + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = func(self, x, *args, **kwargs) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + + return wrapper + + +def nonlinearity(x): + return nn.functional.silu(x) + + +def cast_tuple(t, length=1): + return t if isinstance(t, tuple) or isinstance(t, list) else ((t,) * length) + + +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + dims = list(range(n_dims)) + del dims[src_dim] + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py new file mode 100644 index 0000000000..bd225a81af --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +from einops import rearrange, pack, unpack +from .normalize import Normalize +from .ops import nonlinearity, video_to_image +from .conv import CausalConv3d, CausalConv3d_GC +from .block import Block +from torch.utils.checkpoint import checkpoint +import torch_npu + +class ResnetBlock2D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None 
else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + @video_to_image + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + x = x + h + return x + +class ResnetBlock3D(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h + + +class ResnetBlock3D_GC(Block): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1) + else: + self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0) + + def forward(self, x): + return checkpoint(self._forward, x, use_reentrant=True) + + def _forward(self, x): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + return x + h \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/updownsample.py 
b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/updownsample.py new file mode 100644 index 0000000000..9308b77a97 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/updownsample.py @@ -0,0 +1,350 @@ +from typing import Union, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .attention import TemporalAttnBlock +from .block import Block +from .conv import CausalConv3d, CausalConv3d_GC +from .normalize import Normalize +from .ops import cast_tuple, video_to_image +from .resnet_block import ResnetBlock3D + + +class Upsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.with_conv = True + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + @video_to_image + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(Block): + def __init__(self, in_channels, out_channels, undown=False): + super().__init__() + self.with_conv = True + self.undown = undown + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + if self.undown: + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.conv = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=0) + + @video_to_image + def forward(self, x): + if self.with_conv: + if self.undown: + x = self.conv(x) + else: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class SpatialDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (2, 2), + **kwargs + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 2) + stride = cast_tuple(stride, 2) + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.conv = CausalConv3d( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1,) + stride, + padding=0 + ) + + def forward(self, x): + pad = (0, 1, 0, 1, 0, 0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class SpatialUpsample2x_GC(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (1, 1), + unup=False, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.unup = unup + self.conv = CausalConv3d_GC( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1,) + stride, + padding=1 + ) + + def forward(self, x): + if not self.unup: + t = x.shape[2] + x = rearrange(x, "b c t h w -> b (c t) h w") + x = F.interpolate(x, scale_factor=(2, 2), mode="nearest") + x = rearrange(x, "b (c t) h w -> b c t h w", t=t) + x = self.conv(x) + return x + + +class SpatialUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: Union[int, Tuple[int]] = (3, 3), + stride: Union[int, Tuple[int]] = (1, 1), + unup=False, + ): + super().__init__() + self.chan_in = chan_in + self.chan_out = chan_out + self.kernel_size = kernel_size + self.unup 
= unup + self.conv = CausalConv3d( + self.chan_in, + self.chan_out, + (1,) + self.kernel_size, + stride=(1,) + stride, + padding=1 + ) + + def forward(self, x): + if not self.unup: + t = x.shape[2] + x = rearrange(x, "b c t h w -> b (c t) h w") + x = F.interpolate(x, scale_factor=(2, 2), mode="nearest") + x = rearrange(x, "b (c t) h w -> b c t h w", t=t) + x = self.conv(x) + return x + + +class TimeDownsample2x(Block): + def __init__( + self, + chan_in, + chan_out, + kernel_size: int = 3 + ): + super().__init__() + self.kernel_size = kernel_size + self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1)) + self.pad = nn.ReplicationPad3d((0, 0, 0, 0, self.kernel_size - 1, 0)) + + def forward(self, x): + n, c, d, h, w = x.shape + x = self.pad(x) + x = x.view(n * c, -1, h * w) + pooled = self.avg_pool(x) + output = pooled.view(n, c, -1, h, w) + return output + + +class TimeUpsample2x(Block): + def __init__( + self, + chan_in, + chan_out + ): + super().__init__() + + def forward(self, x): + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return x + + +class TimeDownsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2.0, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool2d((kernel_size, 1), stride=(2, 1)) + self.pad = nn.ReplicationPad3d((0, 0, 0, 0, kernel_size - 1, 0)) + self.conv = nn.Conv3d( + in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1) + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + n, c, d, h, w = x.shape + x = self.pad(x) + pad_x = x.view(n, c, -1, h, w) + avg_x = self.avg_pool(x.view(n * c, -1, h * w)).view(n, c, -1, h, w).to(x_dtype) + conv_x = self.conv(pad_x) + return alpha * avg_x + (1 - alpha) * conv_x + + +class TimeUpsampleRes2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 2.0, + ): + super().__init__() + self.conv = CausalConv3d( + in_channels, out_channels, kernel_size, padding=1 + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + alpha = torch.sigmoid(self.mix_factor) + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + return alpha * x + (1 - alpha) * self.conv(x) + + +class TimeDownsampleResAdv2x(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 1.5, + ): + super().__init__() + self.kernel_size = cast_tuple(kernel_size, 3) + self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1)) + self.attn = TemporalAttnBlock(in_channels) + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.conv = nn.Conv3d( + in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1) + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.kernel_size[0] - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + alpha = torch.sigmoid(self.mix_factor) + return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x)))) + + +class TimeUpsampleResAdv2x(nn.Module): + 
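TimeUpsample2x and the Res variants below upsample time causally: the first frame passes through unchanged and only the remaining frames are interpolated, so T frames become 1 + 2*(T - 1). A small sketch of just that step, with illustrative sizes and without the sigmoid-gated mix_factor blend the Res blocks add on top:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 9, 8, 8)               # b c t h w, t = 9 (illustrative)
first, rest = x[:, :, :1], x[:, :, 1:]
rest = F.interpolate(rest, scale_factor=(2, 1, 1), mode="trilinear")
y = torch.cat([first, rest], dim=2)          # keep frame 0, double the rest in time
print(y.shape)                               # torch.Size([1, 4, 17, 8, 8]) = 1 + 2 * 8 frames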
def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + mix_factor: float = 1.5, + ): + super().__init__() + self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0) + self.attn = TemporalAttnBlock(in_channels) + self.norm = Normalize(in_channels=in_channels) + self.conv = CausalConv3d( + in_channels, out_channels, kernel_size, padding=1 + ) + self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor])) + + def forward(self, x): + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode='trilinear') + x = torch.concat([x, x_], dim=2) + alpha = torch.sigmoid(self.mix_factor) + return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x))) + + +class Spatial2xTime2x3DDownsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=2) + + def forward(self, x): + pad = (0, 1, 0, 1, 0, 0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Spatial2x3DDownsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=0, stride=(1, 2, 2)) + + def forward(self, x): + pad = (0, 1, 0, 1, 0, 0) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Spatial2x3DUpsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') + return self.conv(x) + + +class Spatial2xTime2x3DUpsample(Block): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x): + if x.size(2) > 1: + x, x_ = x[:, :, :1], x[:, :, 1:] + x_ = F.interpolate(x_, scale_factor=(2, 2, 2), mode='trilinear') + x = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') + x = torch.concat([x, x_], dim=2) + else: + x = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear') + return self.conv(x) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py new file mode 100644 index 0000000000..760c0673fe --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py @@ -0,0 +1,42 @@ +import torch +import numpy as np + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = 
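The stride-2 downsample blocks above rely on one-sided spatial padding so that a kernel-3 convolution with padding=0 halves H and W exactly. A short sketch of the padding arithmetic only, with made-up sizes; the causal padding along time is handled inside CausalConv3d and is not reproduced here:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 16, 16)                       # b c t h w (illustrative)
padded = F.pad(x, (0, 1, 0, 1, 0, 0))                  # (W_left, W_right, H_top, H_bottom, T_front, T_back)
print(padded.shape)                                    # torch.Size([1, 4, 5, 17, 17])
# A kernel-3, stride-2, padding-0 convolution then yields floor((17 - 3) / 2) + 1 = 8,
# i.e. H and W go from 16 to 8 despite the odd padded size.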
torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py new file mode 100644 index 0000000000..3f843e4f3d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py @@ -0,0 +1,17 @@ +import importlib + +Module = str +MODULES_BASE = "opensora.models.causalvideovae.model.modules." + +def resolve_str_to_obj(str_val, append=True): + if append: + str_val = MODULES_BASE + str_val + module_name, class_name = str_val.rsplit('.', 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + +def create_instance(module_class_str: str, **kwargs): + module_name, class_name = module_class_str.rsplit('.', 1) + module = importlib.import_module(module_name) + class_ = getattr(module, class_name) + return class_(**kwargs) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/__init__.py new file mode 100644 index 0000000000..b0dac65653 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/__init__.py @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py new file mode 100644 index 0000000000..36b9037b53 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py @@ -0,0 +1,571 @@ +import collections +from typing import Any, Dict, Optional + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.models.transformer_2d import Transformer2DModelOutput +from diffusers.utils import deprecate, logging +from einops import rearrange, repeat +from torch import nn +from 
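An illustrative use of the DiagonalGaussianDistribution defined above, assuming the package is importable as laid out in this patch: an encoder output with 2*C channels is split into mean and log-variance, sampled via the reparameterization trick, and scored with the closed-form KL against a standard normal. The tensor sizes are invented for the example.

import torch
from opensora.models.causalvideovae.model.utils.distrib_utils import DiagonalGaussianDistribution

params = torch.randn(2, 8, 32, 32)       # (B, 2*C, H, W): first half mean, second half logvar
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                   # mean + std * eps, shape (2, 4, 32, 32)
kl = posterior.kl()                      # per-sample KL, reduced over dims [1, 2, 3]
print(z.shape, kl.shape)                 # torch.Size([2, 4, 32, 32]) torch.Size([2])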
torch.nn import functional as F + +from opensora.models.diffusion.opensora.modules import PatchEmbed2D, BasicTransformerBlock, LayerNorm +from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size + +logger = logging.get_logger(__name__) + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +class OpenSoraT2V(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + sample_size_t: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + interpolation_scale_h: float = None, + interpolation_scale_w: float = None, + interpolation_scale_t: float = None, + use_additional_conditions: Optional[bool] = None, + attention_mode: str = 'xformers', + downsampler: str = None, + use_rope: bool = False, + use_stable_fp32: bool = False, + ): + super().__init__() + + # Validate inputs. 
+ if patch_size is not None: + if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]: + raise NotImplementedError( + f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'." + ) + elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None: + raise ValueError( + f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." + ) + + # Set some common variables used across the board. + self.use_rope = use_rope + self.use_linear_projection = use_linear_projection + self.interpolation_scale_t = interpolation_scale_t + self.interpolation_scale_h = interpolation_scale_h + self.interpolation_scale_w = interpolation_scale_w + self.downsampler = downsampler + self.caption_channels = caption_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.gradient_checkpointing = False + self.use_additional_conditions = False + self.cache = None + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + assert in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # 2. Initialize the right blocks. + # Initialize the output blocks and other projection blocks when necessary. 
+ self._init_patched_inputs(norm_type=norm_type) + + def _init_patched_inputs(self, norm_type): + assert self.config.sample_size_t is not None, "OpenSoraT2V over patched input must provide sample_size_t" + assert self.config.sample_size is not None, "OpenSoraT2V over patched input must provide sample_size" + + self.num_frames = self.config.sample_size_t + self.config.sample_size = to_2tuple(self.config.sample_size) + self.height = self.config.sample_size[0] + self.width = self.config.sample_size[1] + self.patch_size_t = self.config.patch_size_t + self.patch_size = self.config.patch_size + interpolation_scale_t = (( + self.config.sample_size_t - 1) // 16 + 1) if self.config.sample_size_t % 2 == 1 else self.config.sample_size_t / 16 + interpolation_scale_t = ( + self.config.interpolation_scale_t if self.config.interpolation_scale_t is not None else interpolation_scale_t + ) + interpolation_scale = ( + self.config.interpolation_scale_h if self.config.interpolation_scale_h is not None else + self.config.sample_size[0] / 30, + self.config.interpolation_scale_w if self.config.interpolation_scale_w is not None else + self.config.sample_size[1] / 40, + ) + self.pos_embed = PatchEmbed2D( + num_frames=self.config.sample_size_t, + height=self.config.sample_size[0], + width=self.config.sample_size[1], + patch_size_t=self.config.patch_size_t, + patch_size=self.config.patch_size, + in_channels=self.in_channels, + embed_dim=self.inner_dim, + interpolation_scale=interpolation_scale, + interpolation_scale_t=interpolation_scale_t, + use_abs_pos=not self.config.use_rope, + ) + interpolation_scale_thw = (interpolation_scale_t, *interpolation_scale) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.config.num_attention_heads, + self.config.attention_head_dim, + dropout=self.config.dropout, + cross_attention_dim=self.config.cross_attention_dim, + activation_fn=self.config.activation_fn, + num_embeds_ada_norm=self.config.num_embeds_ada_norm, + attention_bias=self.config.attention_bias, + only_cross_attention=self.config.only_cross_attention, + double_self_attention=self.config.double_self_attention, + upcast_attention=self.config.upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=self.config.norm_elementwise_affine, + norm_eps=self.config.norm_eps, + attention_type=self.config.attention_type, + attention_mode=self.config.attention_mode, + downsampler=self.config.downsampler, + use_rope=self.config.use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + for _ in range(self.config.num_layers) + ] + ) + + if self.config.norm_type != "ada_norm_single": + self.norm_out = LayerNorm(self.inner_dim, eps=1e-6) + self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim) + self.proj_out_2 = nn.Linear( + self.inner_dim, + self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels + ) + elif self.config.norm_type == "ada_norm_single": + self.norm_out = LayerNorm(self.inner_dim, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim ** 0.5) + self.proj_out = nn.Linear( + self.inner_dim, + self.config.patch_size_t * self.config.patch_size * self.config.patch_size * self.out_channels + ) + + # PixArt-Alpha blocks. 
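A worked example of the interpolation-scale arithmetic in _init_patched_inputs, using hypothetical latent sizes (sample_size_t=29, sample_size=(90, 160)) purely to show the odd/even branch for time and the fixed 30/40 spatial base resolution:

sample_size_t, height, width = 29, 90, 160   # hypothetical latent grid, not a shipped config

interpolation_scale_t = ((sample_size_t - 1) // 16 + 1) if sample_size_t % 2 == 1 else sample_size_t / 16
interpolation_scale = (height / 30, width / 40)

print(interpolation_scale_t)     # 2, from the odd branch: (29 - 1) // 16 + 1
print(interpolation_scale)       # (3.0, 4.0)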
+ self.adaln_single = None + if self.config.norm_type == "ada_norm_single": + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=self.use_additional_conditions + ) + + self.caption_projection = None + if self.caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, hidden_size=self.inner_dim + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + step_id: int = 0, + use_image_num: Optional[int] = 0, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
+ """ + batch_size, c, frame, h, w = hidden_states.shape + frame = frame - use_image_num # 21-4=17 + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + attention_mask_vid, attention_mask_img = None, None + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask.to(self.dtype) + if get_sequence_parallel_state(): + attention_mask_vid = attention_mask[:, :frame * get_sequence_parallel_size()] # b, frame, h, w + attention_mask_img = attention_mask[:, frame * get_sequence_parallel_size():] # b, use_image_num, h, w + else: + attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w + attention_mask_img = attention_mask[:, frame:] # b, use_image_num, h, w + + if attention_mask_vid.numel() > 0: + attention_mask_vid_first_frame = attention_mask_vid[:, :1].repeat(1, self.patch_size_t - 1, 1, 1) + attention_mask_vid = torch.cat([attention_mask_vid_first_frame, attention_mask_vid], dim=1) + attention_mask_vid = attention_mask_vid.unsqueeze(1) # b 1 t h w + attention_mask_vid = F.max_pool3d(attention_mask_vid, + kernel_size=(self.patch_size_t, self.patch_size, self.patch_size), + stride=(self.patch_size_t, self.patch_size, self.patch_size)) + attention_mask_vid = rearrange(attention_mask_vid, 'b 1 t h w -> (b 1) 1 (t h w)') + if attention_mask_img.numel() > 0: + attention_mask_img = F.max_pool2d(attention_mask_img, kernel_size=(self.patch_size, self.patch_size), + stride=(self.patch_size, self.patch_size)) + attention_mask_img = rearrange(attention_mask_img, 'b i h w -> (b i) 1 (h w)') + + attention_mask_vid = (1 - attention_mask_vid.bool().to( + self.dtype)) * -10000.0 if attention_mask_vid.numel() > 0 else None + attention_mask_img = (1 - attention_mask_img.bool().to( + self.dtype)) * -10000.0 if attention_mask_img.numel() > 0 else None + + if frame == 1 and use_image_num == 0 and not get_sequence_parallel_state(): + attention_mask_img = attention_mask_vid + attention_mask_vid = None + # convert encoder_attention_mask to a bias the same way we do for attention_mask + encoder_attention_mask_vid, encoder_attention_mask_img = None, None + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: + # b, 1+use_image_num, l -> a video with images + # b, 1, l -> only images + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + in_t = encoder_attention_mask.shape[1] + encoder_attention_mask_vid = encoder_attention_mask[:, :in_t - use_image_num] # b, 1, l + encoder_attention_mask_vid = rearrange(encoder_attention_mask_vid, + 'b 1 l -> (b 1) 1 l') if encoder_attention_mask_vid.numel() > 0 else None + + encoder_attention_mask_img = encoder_attention_mask[:, in_t - use_image_num:] # b, use_image_num, l + encoder_attention_mask_img = rearrange(encoder_attention_mask_img, + 'b i l -> (b i) 1 l') if encoder_attention_mask_img.numel() > 0 else None + + if frame == 1 and use_image_num == 0 and not get_sequence_parallel_state(): + encoder_attention_mask_img = encoder_attention_mask_vid + encoder_attention_mask_vid = None + + if attention_mask_vid is not None: + attention_mask_vid = attention_mask_vid.to(torch.bool) + attention_mask_vid = attention_mask_vid.repeat(1, attention_mask_vid.shape[-1], 1) + encoder_attention_mask_vid = encoder_attention_mask_vid.to(torch.bool) + encoder_attention_mask_vid = encoder_attention_mask_vid.repeat(1, attention_mask_vid.shape[-2], 1) + if 
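A compact sketch of the video-mask handling in the forward pass above: a per-pixel keep mask (1 = keep, 0 = discard) is pooled down to the latent patch grid with max_pool3d and converted into an additive attention bias in which discarded positions receive -10000. Patch sizes, dtype, and tensor sizes here are illustrative, and the first-frame replication used when patch_size_t > 1 is omitted.

import torch
import torch.nn.functional as F
from einops import rearrange

patch_size_t, patch_size = 1, 2
mask = torch.ones(1, 8, 16, 16)                       # b, t, h, w (illustrative)
mask[:, :, :, 8:] = 0                                 # pretend the right half is padding

mask = mask.unsqueeze(1)                              # b 1 t h w
mask = F.max_pool3d(mask, kernel_size=(patch_size_t, patch_size, patch_size),
                    stride=(patch_size_t, patch_size, patch_size))
mask = rearrange(mask, 'b 1 t h w -> (b 1) 1 (t h w)')
bias = (1 - mask.bool().to(torch.float32)) * -10000.0
print(bias.shape, bias.min().item())                  # torch.Size([1, 1, 512]) -10000.0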
attention_mask_img is not None: + attention_mask_img = attention_mask_img.to(torch.bool) + attention_mask_img = attention_mask_img.repeat(1, attention_mask_vid.shape[-1], 1) + encoder_attention_mask_img = encoder_attention_mask_img.to(torch.bool) + encoder_attention_mask_img = encoder_attention_mask_img.repeat(1, attention_mask_vid.shape[-2], 1) + + # 1. Input + frame = ((frame - 1) // self.patch_size_t + 1) if frame % 2 == 1 else frame // self.patch_size_t + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + hidden_states_vid, hidden_states_img, encoder_hidden_states_vid, encoder_hidden_states_img, \ + timestep_vid, timestep_img, embedded_timestep_vid, embedded_timestep_img = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, frame, use_image_num + ) + # 2. Blocks + if get_sequence_parallel_state(): + if hidden_states_vid is not None: + hidden_states_vid = rearrange(hidden_states_vid, 'b s h -> s b h', b=batch_size).contiguous() + encoder_hidden_states_vid = rearrange(encoder_hidden_states_vid, 'b s h -> s b h', + b=batch_size).contiguous() + timestep_vid = timestep_vid.view(batch_size, 6, -1).transpose(0, 1).contiguous() + + for i, block in enumerate(self.transformer_blocks): + if hidden_states_vid is not None: + if self.cache: + hidden_states_vid = self.cache(block, step_id, i, + hidden_states_vid, + attention_mask=attention_mask_vid, + encoder_hidden_states=encoder_hidden_states_vid, + encoder_attention_mask=encoder_attention_mask_vid, + timestep=timestep_vid, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + frame=frame, + height=height, + width=width, + ) + else: + hidden_states_vid = block( + hidden_states_vid, + attention_mask=attention_mask_vid, + encoder_hidden_states=encoder_hidden_states_vid, + encoder_attention_mask=encoder_attention_mask_vid, + timestep=timestep_vid, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + frame=frame, + height=height, + width=width, + ) + if hidden_states_img is not None: + if self.cache: + hidden_states_img = self.cache(block, step_id, i, + hidden_states_img, + attention_mask=attention_mask_img, + encoder_hidden_states=encoder_hidden_states_img, + encoder_attention_mask=encoder_attention_mask_img, + timestep=timestep_img, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + frame=1, + height=height, + width=width, + ) + else: + hidden_states_img = block( + hidden_states_img, + attention_mask=attention_mask_img, + encoder_hidden_states=encoder_hidden_states_img, + encoder_attention_mask=encoder_attention_mask_img, + timestep=timestep_img, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + frame=1, + height=height, + width=width, + ) + + if get_sequence_parallel_state(): + if hidden_states_vid is not None: + hidden_states_vid = rearrange(hidden_states_vid, 's b h -> b s h', b=batch_size).contiguous() + + # 3. 
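A worked example of the patch-grid bookkeeping right above, with hypothetical latent sizes: an odd frame count follows the causal formula (frame - 1) // patch_size_t + 1, while height and width are simply divided by the spatial patch size.

patch_size_t, patch_size = 1, 2
frame, latent_h, latent_w = 17, 90, 160      # hypothetical latent sizes

frame = ((frame - 1) // patch_size_t + 1) if frame % 2 == 1 else frame // patch_size_t
height, width = latent_h // patch_size, latent_w // patch_size
print(frame, height, width)                  # 17 45 80, i.e. 17 * 45 * 80 video tokens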
Output + output_vid, output_img = None, None + if hidden_states_vid is not None: + output_vid = self._get_output_for_patched_inputs( + hidden_states=hidden_states_vid, + timestep=timestep_vid, + class_labels=class_labels, + embedded_timestep=embedded_timestep_vid, + num_frames=frame, + height=height, + width=width, + ) # b c t h w + if hidden_states_img is not None: + output_img = self._get_output_for_patched_inputs( + hidden_states=hidden_states_img, + timestep=timestep_img, + class_labels=class_labels, + embedded_timestep=embedded_timestep_img, + num_frames=1, + height=height, + width=width, + ) # b c 1 h w + if use_image_num != 0: + output_img = rearrange(output_img, '(b i) c 1 h w -> b c i h w', i=use_image_num) + + if output_vid is not None and output_img is not None: + output = torch.cat([output_vid, output_img], dim=2) + elif output_vid is not None: + output = output_vid + elif output_img is not None: + output = output_img + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs, batch_size, + frame, use_image_num): + hidden_states_vid, hidden_states_img = self.pos_embed(hidden_states.to(self.dtype), frame) + timestep_vid, timestep_img = None, None + embedded_timestep_vid, embedded_timestep_img = None, None + encoder_hidden_states_vid, encoder_hidden_states_img = None, None + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=self.dtype + ) # b 6d, b d + if hidden_states_vid is None: + timestep_img = timestep + embedded_timestep_img = embedded_timestep + else: + timestep_vid = timestep + embedded_timestep_vid = embedded_timestep + if hidden_states_img is not None: + timestep_img = repeat(timestep, 'b d -> (b i) d', i=use_image_num).contiguous() + embedded_timestep_img = repeat(embedded_timestep, 'b d -> (b i) d', i=use_image_num).contiguous() + + if self.caption_projection is not None: + encoder_hidden_states = self.caption_projection( + encoder_hidden_states) # b, 1+use_image_num, l, d or b, 1, l, d + if hidden_states_vid is None: + encoder_hidden_states_img = rearrange(encoder_hidden_states, 'b 1 l d -> (b 1) l d') + else: + encoder_hidden_states_vid = rearrange(encoder_hidden_states[:, :1], 'b 1 l d -> (b 1) l d') + if hidden_states_img is not None: + encoder_hidden_states_img = rearrange(encoder_hidden_states[:, 1:], 'b i l d -> (b i) l d') + + return hidden_states_vid, hidden_states_img, encoder_hidden_states_vid, encoder_hidden_states_img, timestep_vid, timestep_img, embedded_timestep_vid, embedded_timestep_img + + def _get_output_for_patched_inputs( + self, hidden_states, timestep, class_labels, embedded_timestep, num_frames, height=None, width=None + ): + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=self.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states, scale=(1 + scale), shift=shift) + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + 
hidden_states = self.norm_out(hidden_states, scale=(1 + scale), shift=shift) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, num_frames, height, width, self.patch_size_t, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nthwopqc->nctohpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, self.out_channels, num_frames * self.patch_size_t, height * self.patch_size, + width * self.patch_size) + ) + + return output + + +def OpenSoraT2V_S_122(**kwargs): + return OpenSoraT2V(num_layers=28, attention_head_dim=96, num_attention_heads=16, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1536, **kwargs) + + +def OpenSoraT2V_B_122(**kwargs): + return OpenSoraT2V(num_layers=32, attention_head_dim=96, num_attention_heads=16, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1920, **kwargs) + + +def OpenSoraT2V_L_122(**kwargs): + return OpenSoraT2V(num_layers=40, attention_head_dim=128, num_attention_heads=16, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=2048, **kwargs) + + +def OpenSoraT2V_ROPE_L_122(**kwargs): + return OpenSoraT2V(num_layers=32, attention_head_dim=96, num_attention_heads=24, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=2304, **kwargs) + + +OpenSora_models = { + "OpenSoraT2V-S/122": OpenSoraT2V_S_122, # 1.1B + "OpenSoraT2V-B/122": OpenSoraT2V_B_122, + "OpenSoraT2V-L/122": OpenSoraT2V_L_122, + "OpenSoraT2V-ROPE-L/122": OpenSoraT2V_ROPE_L_122, +} + +OpenSora_models_class = { + "OpenSoraT2V-S/122": OpenSoraT2V, + "OpenSoraT2V-B/122": OpenSoraT2V, + "OpenSoraT2V-L/122": OpenSoraT2V, + "OpenSoraT2V-ROPE-L/122": OpenSoraT2V, +} diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py new file mode 100644 index 0000000000..770ce41776 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -0,0 +1,1175 @@ +import math +import re +from typing import Any, Dict, Optional +from typing import Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import torch_npu +from diffusers.models.attention import FeedForward, GatedSelfAttentionDense +from diffusers.models.attention_processor import Attention as Attention_ +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from utils.comm import all_to_all_sbh +from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size, get_sequence_parallel_rank +from .rope import PositionGetter3D, RoPE3D + +logger = logging.get_logger(__name__) + + +def get_3d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16, +): + """ + grid_size: int of the grid height and width return: 
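A minimal sketch of the unpatchify step in _get_output_for_patched_inputs above: per-token patch predictions are reshaped onto the (t, h, w) grid and the patch dimensions are interleaved back into the full latent. Sizes are illustrative.

import torch

num_frames, height, width = 4, 6, 8
patch_size_t, patch_size, out_channels = 1, 2, 8

tokens = torch.randn(1, num_frames * height * width,
                     patch_size_t * patch_size * patch_size * out_channels)
x = tokens.reshape(-1, num_frames, height, width, patch_size_t, patch_size, patch_size, out_channels)
x = torch.einsum("nthwopqc->nctohpwq", x)                 # same permutation as above
video = x.reshape(-1, out_channels, num_frames * patch_size_t,
                  height * patch_size, width * patch_size)
print(video.shape)                                        # torch.Size([1, 8, 4, 12, 16])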
pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_t = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0] + grid_h = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1] + grid_w = np.arange(grid_size[2], dtype=np.float32) / (grid_size[2] / base_size[2]) / interpolation_scale[2] + grid = np.meshgrid(grid_w, grid_h, grid_t) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([3, 1, grid_size[2], grid_size[1], grid_size[0]]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use 1/3 of dimensions to encode grid_t/h/w + emb_t = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (T*H*W, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (T*H*W, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (T*H*W, D/3) + + emb = np.concatenate([emb_t, emb_h, emb_w], axis=1) # (T*H*W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16, +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size[0]) / interpolation_scale[0] + grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size[1]) / interpolation_scale[1] + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use 1/3 of dimensions to encode grid_t/h/w + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed( + embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16, +): + """ + grid_size: int of the grid return: pos_embed: [grid_size, embed_dim] or + [1+grid_size, embed_dim] (w/ or w/o cls_token) + """ + + grid = np.arange(grid_size, dtype=np.float32) / (grid_size / base_size) / interpolation_scale + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) # (H*W, D/2) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, 
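An illustrative call of get_2d_sincos_pos_embed as defined above, assuming it is imported from this modules file (which also imports torch_npu, so this only runs in the NPU environment the patch targets): for a 4x4 grid and a 16-dim embedding it returns one sin/cos position vector per grid cell.

from opensora.models.diffusion.opensora.modules import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(
    embed_dim=16, grid_size=(4, 4), base_size=(4, 4), interpolation_scale=(1.0, 1.0)
)
print(pos_embed.shape)    # (16, 16): one row per grid cell, embed_dim columns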
dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed2D(nn.Module): + """2D Image to Patch Embedding but with 3D position embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=True, + ): + super().__init__() + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + + self.height, self.width = height // patch_size, width // patch_size + self.base_size = (height // patch_size, width // patch_size) + self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.interpolation_scale_t = interpolation_scale_t + temp_pos_embed = get_1d_sincos_pos_embed(embed_dim, self.num_frames, base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t) + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0])) + + def forward(self, latent, num_frames): + b, _, _, _, _ = latent.shape + # b c 1 h w + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + latent = rearrange(latent, 'b c t h w -> (b t) c h w') + latent = self.proj(latent) + + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + if self.layer_norm: + latent = self.norm(latent) + + if self.use_abs_pos: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + if self.num_frames != num_frames: + if get_sequence_parallel_state(): + sp_size = get_sequence_parallel_size() + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim=self.temp_pos_embed.shape[-1], + grid_size=num_frames * sp_size, + base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t, + ) + rank = get_sequence_parallel_rank() % sp_size + st_frame = rank * num_frames + ed_frame = st_frame + num_frames + temp_pos_embed = temp_pos_embed[st_frame: ed_frame] + else: + temp_pos_embed = 
get_1d_sincos_pos_embed( + embed_dim=self.temp_pos_embed.shape[-1], + grid_size=num_frames, + base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t, + ) + temp_pos_embed = torch.from_numpy(temp_pos_embed) + temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device) + else: + temp_pos_embed = self.temp_pos_embed + + latent = (latent + pos_embed).to(latent.dtype) + + latent = rearrange(latent, '(b t) n c -> b t n c', b=b) + video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] + + if self.use_abs_pos: + # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() + temp_pos_embed = temp_pos_embed.unsqueeze(2) + video_latent = (video_latent + temp_pos_embed).to( + video_latent.dtype) if video_latent is not None and video_latent.numel() > 0 else None + image_latent = (image_latent + temp_pos_embed[:, :1]).to( + image_latent.dtype) if image_latent is not None and image_latent.numel() > 0 else None + + video_latent = rearrange(video_latent, + 'b t n c -> b (t n) c') if video_latent is not None and video_latent.numel() > 0 else None + image_latent = rearrange(image_latent, + 'b t n c -> (b t) n c') if image_latent is not None and image_latent.numel() > 0 else None + + if num_frames == 1 and image_latent is None and not get_sequence_parallel_state(): + image_latent = video_latent + video_latent = None + return video_latent, image_latent + + +class OverlapPatchEmbed3D(nn.Module): + """2D Image to Patch Embedding but with 3D position embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=True, + ): + super().__init__() + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv3d( + in_channels, embed_dim, kernel_size=(patch_size_t, patch_size, patch_size), + stride=(patch_size_t, patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + + self.height, self.width = height // patch_size, width // patch_size + self.base_size = (height // patch_size, width // patch_size) + self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.interpolation_scale_t = interpolation_scale_t + temp_pos_embed = get_1d_sincos_pos_embed(embed_dim, self.num_frames, base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t) + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent, num_frames): + b, _, _, _, _ = latent.shape + # b c 1 h w + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + latent = self.proj(latent) + + if self.flatten: + latent = rearrange(latent, 'b c t h w -> 
(b t) (h w) c ') + if self.layer_norm: + latent = self.norm(latent) + + if self.use_abs_pos: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + if self.num_frames != num_frames: + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim=self.temp_pos_embed.shape[-1], + grid_size=num_frames, + base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t, + ) + temp_pos_embed = torch.from_numpy(temp_pos_embed) + temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device) + else: + temp_pos_embed = self.temp_pos_embed + + latent = (latent + pos_embed).to(latent.dtype) + + latent = rearrange(latent, '(b t) n c -> b t n c', b=b) + video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] + + if self.use_abs_pos: + temp_pos_embed = temp_pos_embed.unsqueeze(2) + video_latent = (video_latent + temp_pos_embed).to( + video_latent.dtype) if video_latent is not None and video_latent.numel() > 0 else None + image_latent = (image_latent + temp_pos_embed[:, :1]).to( + image_latent.dtype) if image_latent is not None and image_latent.numel() > 0 else None + + video_latent = rearrange(video_latent, + 'b t n c -> b (t n) c') if video_latent is not None and video_latent.numel() > 0 else None + image_latent = rearrange(image_latent, + 'b t n c -> (b t) n c') if image_latent is not None and image_latent.numel() > 0 else None + + if num_frames == 1 and image_latent is None: + image_latent = video_latent + video_latent = None + return video_latent, image_latent + + +class OverlapPatchEmbed2D(nn.Module): + """2D Image to Patch Embedding but with 3D position embedding""" + + def __init__( + self, + num_frames=1, + height=224, + width=224, + patch_size_t=1, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=(1, 1), + interpolation_scale_t=1, + use_abs_pos=True, + ): + super().__init__() + assert patch_size_t == 1 + self.use_abs_pos = use_abs_pos + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size), bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size_t = patch_size_t + self.patch_size = patch_size + + self.height, self.width = height // patch_size, width // patch_size + self.base_size = (height // patch_size, width // patch_size) + self.interpolation_scale = (interpolation_scale[0], interpolation_scale[1]) + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (self.height, self.width), base_size=self.base_size, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + self.num_frames = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.base_size_t = (num_frames - 1) // patch_size_t + 1 if num_frames % 2 == 1 else num_frames // patch_size_t + self.interpolation_scale_t = interpolation_scale_t + temp_pos_embed = get_1d_sincos_pos_embed(embed_dim, self.num_frames, base_size=self.base_size_t, + 
interpolation_scale=self.interpolation_scale_t) + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent, num_frames): + b, _, _, _, _ = latent.shape + # b c 1 h w + height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size + latent = rearrange(latent, 'b c t h w -> (b t) c h w') + latent = self.proj(latent) + + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BT C H W -> BT N C + if self.layer_norm: + latent = self.norm(latent) + + if self.use_abs_pos: + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + if self.num_frames != num_frames: + temp_pos_embed = get_1d_sincos_pos_embed( + embed_dim=self.temp_pos_embed.shape[-1], + grid_size=num_frames, + base_size=self.base_size_t, + interpolation_scale=self.interpolation_scale_t, + ) + temp_pos_embed = torch.from_numpy(temp_pos_embed) + temp_pos_embed = temp_pos_embed.float().unsqueeze(0).to(latent.device) + else: + temp_pos_embed = self.temp_pos_embed + + latent = (latent + pos_embed).to(latent.dtype) + + latent = rearrange(latent, '(b t) n c -> b t n c', b=b) + video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] + + if self.use_abs_pos: + # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() + temp_pos_embed = temp_pos_embed.unsqueeze(2) + video_latent = (video_latent + temp_pos_embed).to( + video_latent.dtype) if video_latent is not None and video_latent.numel() > 0 else None + image_latent = (image_latent + temp_pos_embed[:, :1]).to( + image_latent.dtype) if image_latent is not None and image_latent.numel() > 0 else None + + video_latent = rearrange(video_latent, + 'b t n c -> b (t n) c') if video_latent is not None and video_latent.numel() > 0 else None + image_latent = rearrange(image_latent, + 'b t n c -> (b t) n c') if image_latent is not None and image_latent.numel() > 0 else None + + if num_frames == 1 and image_latent is None: + image_latent = video_latent + video_latent = None + return video_latent, image_latent + + +class Attention(Attention_): + def __init__(self, downsampler, attention_mode, use_rope, interpolation_scale_thw, **kwags): + processor = AttnProcessor2_0(attention_mode=attention_mode, use_rope=use_rope, + interpolation_scale_thw=interpolation_scale_thw) + super().__init__(processor=processor, **kwags) + self.downsampler = None + if downsampler: # downsampler k155_s122 + downsampler_ker_size = list(re.search(r'k(\d{2,3})', downsampler).group(1)) # 122 + down_factor = list(re.search(r's(\d{2,3})', downsampler).group(1)) + downsampler_ker_size = [int(i) for i in downsampler_ker_size] + downsampler_padding = [(i - 1) // 2 for i in downsampler_ker_size] + down_factor = [int(i) for i in down_factor] + + if len(downsampler_ker_size) == 2: + self.downsampler = DownSampler2d(kwags['query_dim'], kwags['query_dim'], + kernel_size=downsampler_ker_size, stride=1, + padding=downsampler_padding, groups=kwags['query_dim'], + down_factor=down_factor, + down_shortcut=True) + elif len(downsampler_ker_size) == 3: + self.downsampler = DownSampler3d(kwags['query_dim'], kwags['query_dim'], + kernel_size=downsampler_ker_size, stride=1, + 
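A short sketch of how the downsampler spec string is parsed in Attention.__init__ above, using the "k155_s122" form mentioned in its inline comment: the digits after k are per-axis kernel sizes and the digits after s are per-axis down factors, with padding derived from the kernel.

import re

downsampler = "k155_s122"                                  # spec form from the comment above
kernel = [int(i) for i in re.search(r'k(\d{2,3})', downsampler).group(1)]
down_factor = [int(i) for i in re.search(r's(\d{2,3})', downsampler).group(1)]
padding = [(k - 1) // 2 for k in kernel]
print(kernel, down_factor, padding)                        # [1, 5, 5] [1, 2, 2] [0, 2, 2]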
padding=downsampler_padding, groups=kwags['query_dim'], + down_factor=down_factor, + down_shortcut=True) + + def prepare_attention_mask( + self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if get_sequence_parallel_state(): + head_size = head_size // get_sequence_parallel_size() + if attention_mask is None: + return None + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + +class DownSampler3d(nn.Module): + def __init__(self, *args, **kwargs): + ''' Required kwargs: down_factor, downsampler''' + super().__init__() + self.down_factor = kwargs.pop('down_factor') + self.down_shortcut = kwargs.pop('down_shortcut') + self.layer = nn.Conv3d(*args, **kwargs) + + def forward(self, x, attention_mask, t, h, w): + b = x.shape[0] + x = rearrange(x, 'b (t h w) d -> b d t h w', t=t, h=h, w=w) + + x_dtype = x.dtype + x = self.layer(x) + (x if self.down_shortcut else 0) + + self.t = t // self.down_factor[0] + self.h = h // self.down_factor[1] + self.w = w // self.down_factor[2] + x = rearrange(x, 'b d (t dt) (h dh) (w dw) -> (b dt dh dw) (t h w) d', + t=t // self.down_factor[0], h=h // self.down_factor[1], w=w // self.down_factor[2], + dt=self.down_factor[0], dh=self.down_factor[1], dw=self.down_factor[2]) + + attention_mask = rearrange(attention_mask, 'b 1 (t h w) -> b 1 t h w', t=t, h=h, w=w) + attention_mask = rearrange(attention_mask, 'b 1 (t dt) (h dh) (w dw) -> (b dt dh dw) 1 (t h w)', + t=t // self.down_factor[0], h=h // self.down_factor[1], w=w // self.down_factor[2], + dt=self.down_factor[0], dh=self.down_factor[1], dw=self.down_factor[2]) + return x, attention_mask + + def reverse(self, x, t, h, w): + x = rearrange(x, '(b dt dh dw) (t h w) d -> b (t dt h dh w dw) d', + t=t, h=h, w=w, + dt=self.down_factor[0], dh=self.down_factor[1], dw=self.down_factor[2]) + return x + + +class DownSampler2d(nn.Module): + def __init__(self, *args, **kwargs): + ''' Required kwargs: down_factor, downsampler''' + super().__init__() + self.down_factor = kwargs.pop('down_factor') + self.down_shortcut = kwargs.pop('down_shortcut') + self.layer = nn.Conv2d(*args, **kwargs) + + def forward(self, x, attention_mask, t, h, w): + b = x.shape[0] + x = rearrange(x, 'b (t h w) d -> (b t) d h w', t=t, h=h, w=w) + x = self.layer(x) + (x if self.down_shortcut else 0) + + self.t = 1 + self.h = h // self.down_factor[0] + self.w = w // self.down_factor[1] + + x = rearrange(x, 'b d (h dh) (w dw) -> (b dh dw) (h w) d', + h=h // self.down_factor[0], w=w // self.down_factor[1], + dh=self.down_factor[0], 
dw=self.down_factor[1]) + + attention_mask = rearrange(attention_mask, 'b 1 (t h w) -> (b t) 1 h w', h=h, w=w) + attention_mask = rearrange(attention_mask, 'b 1 (h dh) (w dw) -> (b dh dw) 1 (h w)', + h=h // self.down_factor[0], w=w // self.down_factor[1], + dh=self.down_factor[0], dw=self.down_factor[1]) + return x, attention_mask + + def reverse(self, x, t, h, w): + x = rearrange(x, '(b t dh dw) (h w) d -> b (t h dh w dw) d', + t=t, h=h, w=w, + dh=self.down_factor[0], dw=self.down_factor[1]) + return x + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, attention_mode='xformers', use_rope=False, interpolation_scale_thw=(1, 1, 1)): + self.use_rope = use_rope + self.interpolation_scale_thw = interpolation_scale_thw + if self.use_rope: + self._init_rope(interpolation_scale_thw) + self.attention_mode = attention_mode + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def _init_rope(self, interpolation_scale_thw): + self.rope = RoPE3D(interpolation_scale_thw=interpolation_scale_thw) + self.position_getter = PositionGetter3D() + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + frame: int = 8, + height: int = 16, + width: int = 16, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = ("The `scale` argument is deprecated and will be ignored. Please remove it, " + "as passing it will raise an error in the future. 
`scale` should directly be " + "passed while calling the underlying pipeline component i.e., " + "via `cross_attention_kwargs`.") + deprecate("scale", "1.0.0", deprecation_message) + + if attn.downsampler is not None: + hidden_states, attention_mask = attn.downsampler(hidden_states, attention_mask, t=frame, h=height, w=width) + frame, height, width = attn.downsampler.t, attn.downsampler.h, attn.downsampler.w + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + if get_sequence_parallel_state(): + sequence_length, batch_size, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + else: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, 1, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + if get_sequence_parallel_state(): + query = query.view(-1, attn.heads, head_dim) # [s // sp, b, h * d] -> [s // sp * b, h, d] + key = key.view(-1, attn.heads, head_dim) + value = value.view(-1, attn.heads, head_dim) + h_size = attn.heads * head_dim + sp_size = get_sequence_parallel_size() + h_size_sp = h_size // sp_size + # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d] + query = all_to_all_sbh(query, scatter_dim=1, gather_dim=0).view(-1, batch_size, h_size_sp) + key = all_to_all_sbh(key, scatter_dim=1, gather_dim=0).view(-1, batch_size, h_size_sp) + value = all_to_all_sbh(value, scatter_dim=1, gather_dim=0).view(-1, batch_size, h_size_sp) + if self.use_rope: + query = query.view(-1, batch_size, attn.heads // sp_size, head_dim) + key = key.view(-1, batch_size, attn.heads // sp_size, head_dim) + # require the shape of (batch_size x nheads x ntokens x dim) + pos_thw = self.position_getter(batch_size, t=frame * sp_size, h=height, w=width, + device=query.device) + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) + query = query.view(-1, batch_size, h_size_sp) # SBH + key = key.view(-1, batch_size, h_size_sp) # SBH + value = value.view(-1, batch_size, h_size_sp) # SBH + query = rearrange(query, "s b (n d) -> b n s d", d=head_dim) # BNSD + key = rearrange(key, "s b (n d) -> b n s d", d=head_dim) # BNSD + value = rearrange(value, "s b (n d) -> b n s d", d=head_dim) # BNSD + hidden_states = torch_npu.npu_fusion_attention(query, key, value, + atten_mask=attention_mask, + input_layout="BNSD", + scale=1 / math.sqrt(head_dim), + head_num=attn.heads // sp_size)[0] + hidden_states = rearrange(hidden_states, "b n s d -> s b (n d)") + hidden_states = hidden_states.view(-1, attn.heads // sp_size, head_dim) + + # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] + hidden_states = 
all_to_all_sbh(hidden_states, scatter_dim=0, gather_dim=1).view(-1, batch_size, h_size) + else: + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + if self.use_rope: + # require the shape of (batch_size x nheads x ntokens x dim) + pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) + query = self.rope(query, pos_thw) + key = self.rope(key, pos_thw) + query = rearrange(query, "b s n d -> b n s d") # BNSD + key = rearrange(key, "b s n d -> b n s d") # BNSD + value = rearrange(value, "b s n d -> b n s d") # BNSD + hidden_states = torch_npu.npu_fusion_attention(query, key, value, + atten_mask=attention_mask, + input_layout="BNSD", + scale=1 / math.sqrt(head_dim), + head_num=attn.heads)[0] + hidden_states = rearrange(hidden_states, "b n s d -> b s (n d)") + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + if attn.downsampler is not None: + hidden_states = attn.downsampler.reverse(hidden_states, t=frame, h=height, w=width) + return hidden_states + + +class FeedForward_Conv3d(nn.Module): + def __init__(self, downsampler, dim, hidden_features, bias=True): + super(FeedForward_Conv3d, self).__init__() + + self.bias = bias + + self.project_in = nn.Linear(dim, hidden_features, bias=bias) + + self.dwconv = nn.ModuleList([ + nn.Conv3d(hidden_features, hidden_features, kernel_size=(5, 5, 5), stride=1, padding=(2, 2, 2), dilation=1, + groups=hidden_features, bias=bias), + nn.Conv3d(hidden_features, hidden_features, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1), dilation=1, + groups=hidden_features, bias=bias), + nn.Conv3d(hidden_features, hidden_features, kernel_size=(1, 1, 1), stride=1, padding=(0, 0, 0), dilation=1, + groups=hidden_features, bias=bias) + ]) + + self.project_out = nn.Linear(hidden_features, dim, bias=bias) + + def forward(self, x, t, h, w): + x_dtype = x.dtype + x = self.project_in(x) + x = rearrange(x, 'b (t h w) d -> b d t h w', t=t, h=h, w=w) + x = F.gelu(x) + out = x + for module in self.dwconv: + out = out + module(x) + out = rearrange(out, 'b d t h w -> b (t h w) d', t=t, h=h, w=w) + x = self.project_out(out) + return x + + +class FeedForward_Conv2d(nn.Module): + def __init__(self, downsampler, dim, hidden_features, bias=True): + super(FeedForward_Conv2d, self).__init__() + + self.bias = bias + + self.project_in = nn.Linear(dim, hidden_features, bias=bias) + + self.dwconv = nn.ModuleList([ + nn.Conv2d(hidden_features, hidden_features, kernel_size=(5, 5), stride=1, padding=(2, 2), dilation=1, + groups=hidden_features, bias=bias), + nn.Conv2d(hidden_features, hidden_features, kernel_size=(3, 3), stride=1, padding=(1, 1), dilation=1, + groups=hidden_features, bias=bias), + nn.Conv2d(hidden_features, hidden_features, kernel_size=(1, 1), stride=1, padding=(0, 0), dilation=1, + groups=hidden_features, bias=bias) + ]) + + self.project_out = nn.Linear(hidden_features, dim, bias=bias) + + def forward(self, x, t, h, w): + # import ipdb;ipdb.set_trace() + x = self.project_in(x) + x = rearrange(x, 'b (t h w) d -> (b t) d h w', t=t, h=h, w=w) + x = F.gelu(x) + out = x + for module in self.dwconv: + 
out = out + module(x) + out = rearrange(out, '(b t) d h w -> b (t h w) d', t=t, h=h, w=w) + x = self.project_out(out) + return x + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + attention_mode: str = "xformers", + downsampler: str = None, + use_rope: bool = False, + interpolation_scale_thw: Tuple[int] = (1, 1, 1), + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.downsampler = downsampler + + # We keep these boolean flags for backward-compatibility. 
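+        # Note: of the normalization paths handled below, the sequence-parallel
+        # timestep handling added in forward() only covers the "ada_norm_single"
+        # (PixArt-style adaLN-single) branch, which is the norm type the OpenSoraT2V
+        # transformer is typically configured with; the other flags are kept for
+        # diffusers compatibility.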
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if norm_type == "ada_norm": + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_zero": + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + attention_mode=attention_mode, + downsampler=downsampler, + use_rope=use_rope, + interpolation_scale_thw=interpolation_scale_thw, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + attention_mode=attention_mode, + downsampler=False, + use_rope=False, + interpolation_scale_thw=interpolation_scale_thw, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. 
Feed-forward + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + + elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + + if downsampler: + self.ff = FeedForward_Conv2d( + downsampler, + dim, + 2 * dim, + bias=ff_bias, + ) + else: + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if norm_type == "ada_norm_single": + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + frame: int = None, + height: int = None, + width: int = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + if get_sequence_parallel_state(): + batch_size = hidden_states.shape[1] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[:, None] + timestep.reshape(6, batch_size, -1) + ).chunk(6, dim=0) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. 
Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, frame=frame, height=height, width=width, + **cross_attention_kwargs, + ) + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1.2 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self.downsampler: + ff_output = self.ff(norm_hidden_states, t=frame, h=height, w=width) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + def forward(self, x, shift, scale): + return torch_npu.npu_layer_norm_eval( + x, normalized_shape=[self.hidden_size], weight=scale, bias=shift, eps=self.eps + ) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py new file mode 100644 index 0000000000..e9d0c7c4eb --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py @@ -0,0 +1,88 @@ +import torch +import torch_npu + +from utils.parallel_mgr import 
get_sequence_parallel_state + + +class PositionGetter3D(object): + """ return positions of patches """ + + def __init__(self, ): + self.cache_positions = {} + + def __call__(self, b, t, h, w, device): + if not (b, t, h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + z = torch.arange(t, device=device) + pos = torch.cartesian_prod(z, y, x) + if get_sequence_parallel_state(): + pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, -1, 1).contiguous().expand(3, -1, b).clone() + else: + pos = pos.reshape(t * h * w, 3).transpose(0, 1).reshape(3, 1, -1).contiguous().expand(3, b, -1).clone() + poses = (pos[0].contiguous(), pos[1].contiguous(), pos[2].contiguous()) + max_poses = (int(poses[0].max()), int(poses[1].max()), int(poses[2].max())) + + self.cache_positions[b, t, h, w] = (poses, max_poses) + pos = self.cache_positions[b, t, h, w] + + return pos + + +class RoPE3D(torch.nn.Module): + + def __init__(self, freq=10000.0, F0=1.0, interpolation_scale_thw=(1, 1, 1)): + super().__init__() + self.base = freq + self.F0 = F0 + self.interpolation_scale_t = interpolation_scale_thw[0] + self.interpolation_scale_h = interpolation_scale_thw[1] + self.interpolation_scale_w = interpolation_scale_thw[2] + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype, interpolation_scale=1): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) / interpolation_scale + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + # for (batch_size x ntokens x nheads x dim) + cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] + + return torch_npu.npu_rotary_mul(tokens, cos, sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 3 (t, y and x position of each token) + output: + * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) + """ + assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" + D = tokens.size(3) // 3 + poses, max_poses = positions + assert len(poses) == 3 and poses[0].ndim == 2 # Batch, Seq, 3 + cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) + cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) + cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) + # split features into three along the feature dimension, and apply rope1d on each half + t, y, x = tokens.chunk(3, dim=-1) + t = self.apply_rope1d(t, poses[0], cos_t, sin_t) + y = self.apply_rope1d(y, poses[1], cos_y, sin_y) + x = self.apply_rope1d(x, poses[2], cos_x, sin_x) + tokens = torch.cat((t, y, x), dim=-1) + return tokens diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py 
b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py new file mode 100644 index 0000000000..b143cc371d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py @@ -0,0 +1,793 @@ +import html +import inspect +import math +from tqdm import tqdm +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import torch +from diffusers.models import AutoencoderKL, Transformer2DModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import EulerAncestralDiscreteScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from mindiesd.pipeline.sampling_optm import SamplingOptm +from transformers import T5EncoderModel, T5Tokenizer + +from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size, get_sequence_parallel_rank + +logger = logging.get_logger(__name__) + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OpenSoraPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Sigma. 
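+
+    Example (illustrative sketch only; the component objects are assumed to be
+    constructed and loaded elsewhere and are not defined here):
+
+        >>> pipe = OpenSoraPipeline(tokenizer=tokenizer, text_encoder=text_encoder,
+        ...                         vae=vae, transformer=transformer, scheduler=scheduler)
+        >>> videos = pipe(prompt="a panda surfing a wave", num_frames=93, height=720,
+        ...               width=1280, num_inference_steps=100, guidance_scale=7.5).images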
+ """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: Transformer2DModel, + scheduler: EulerAncestralDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. 
+ """ + + if device is None: + device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('npu') + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + 
else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", + # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." 
+ caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + return caption.strip() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, + latents=None): + shape = ( + batch_size, + num_channels_latents, + (math.ceil((int(num_frames) - 1) / self.vae.vae_scale_factor[0]) + 1) if int( + num_frames) % 2 == 1 else math.ceil(int(num_frames) / self.vae.vae_scale_factor[0]), + math.ceil(int(height) / self.vae.vae_scale_factor[1]), + math.ceil(int(width) / self.vae.vae_scale_factor[2]), + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 300, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0] + height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] + width = width or self.transformer.config.sample_size[1] * self.vae.vae_scale_factor[2] + self.check_inputs( + prompt, + num_frames, + height, + width, + negative_prompt, + callback_steps, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. 
Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = getattr(self, '_execution_device', None) or getattr(self, 'device', None) or torch.device('npu') + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + world_size = get_sequence_parallel_size() + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if get_sequence_parallel_state(): + prompt_embeds = rearrange(prompt_embeds, 'b (n x) h -> b n x h', n=world_size, + x=prompt_embeds.shape[1] // world_size).contiguous() + rank = get_sequence_parallel_rank() + prompt_embeds = prompt_embeds[:, rank, :, :] + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + def dit_process(latents, prompt_embeds, prompt_attention_mask): + for i, t in enumerate(tqdm(timesteps)): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = torch.tensor([t], dtype=torch.float64, device=latent_model_input.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + if prompt_attention_mask.ndim == 2: + prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l + + # prepare attention_mask. 
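+                # Illustrative shapes (assumed example values): with the 4x8x8 causal
+                # VAE, a 93x720x1280 request gives 24 latent frames on a 90x160 grid,
+                # so with classifier-free guidance `latent_model_input` is
+                # (2*b, C, 24, 90, 160) and the all-ones mask built below is
+                # (2*b, 24, 90, 160). Under sequence parallelism each rank only holds
+                # its slice of the frames, so the mask is tiled `world_size` times
+                # along the frame axis to cover the full (padded) sequence.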
+ # b c t h w -> b t h w + attention_mask = torch.ones_like(latent_model_input)[:, 0] + if get_sequence_parallel_state(): + attention_mask = attention_mask.repeat(1, world_size, 1, 1) + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + step_id=i, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + return latents + + if kwargs.get("sampling_optimize", None): + config = {"method": "AdaStep", "skip_thr": 0.015, "max_skip_steps": 4, "decay_ratio": 0.99} + with SamplingOptm(self, + "transformer.forward", + "scheduler.step", + parallel=get_sequence_parallel_state(), + config=config): + latents = dit_process(latents, prompt_embeds, prompt_attention_mask) + else: + latents = dit_process(latents, prompt_embeds, prompt_attention_mask) + + world_size = get_sequence_parallel_size() + if get_sequence_parallel_state(): + latents_shape = list(latents.shape) + full_shape = [latents_shape[0] * world_size] + latents_shape[1:] + all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device) + torch.distributed.all_gather_into_tensor(all_latents, latents) + latents_list = list(all_latents.chunk(world_size, dim=0)) + latents = torch.cat(latents_list, dim=2) + + if not output_type == "latent": + # b t h w c + image = self.decode_latents(latents) + image = image[:, :num_frames, :height, :width] + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def decode_latents(self, latents): + video = self.vae.decode(latents.to(self.vae.vae.dtype)) + video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, + 2).contiguous() # b t h w c + return video diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/requirements.txt b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/requirements.txt new file mode 100644 index 0000000000..44097a4d0e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/requirements.txt @@ -0,0 +1,17 @@ +torch==2.1.0 +diffusers==0.28.0 +transformers==4.44.2 +open_clip_torch==2.20.0 +av==12.0.0 +tqdm==4.66.1 +tensorboard==2.11.0 +pre-commit==3.8.0 +mmengine==0.10.4 +ftfy==6.1.3 +accelerate==0.26.1 +bs4 +torchvision==0.16.0 +einops +numpy==1.24.0 +imageio +imageio-ffmpeg diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh new file mode 100644 index 0000000000..50fb03c198 --- /dev/null +++ 
b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh @@ -0,0 +1,18 @@ +torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ + inference.py \ + --model_path /path/to/checkpoint-xxx/model_ema \ + --num_frames 93 \ + --height 720 \ + --width 1280 \ + --cache_dir "../cache_dir" \ + --text_encoder_name google/mt5-xxl \ + --text_prompt examples/prompt_list_0.txt \ + --ae CausalVAEModel_D4_4x8x8 \ + --ae_path "/path/to/causalvideovae" \ + --save_img_path "./sample_video_test" \ + --fps 24 \ + --guidance_scale 7.5 \ + --num_sampling_steps 100 \ + --tile_overlap_factor 0.125 \ + --max_sequence_length 512 \ + --dtype bf16 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py new file mode 100644 index 0000000000..bac2b5244e --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist + +from .parallel_mgr import get_sequence_parallel_size, get_sequence_parallel_group + + +def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1): + input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] + output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] + dist.all_to_all(output_list, input_list, group=process_group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +def split_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + if world_size == 1: + return input_ + + if pad > 0: + pad_size = list(input_.shape) + pad_size[dim] = pad + input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim) + + dim_size = input_.size(dim) + if dim_size % world_size != 0: + raise ValueError( + f"The th{dim} dimensions of input_:{input_.size()} is not divisible by world_size:{world_size}.") + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + output = tensor_list[rank].contiguous() + return output + + +def gather_sequence(input_, process_group: dist.ProcessGroup, dim: int, pad: int): + input_ = input_.contiguous() + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim) + + if pad > 0: + output = output.narrow(dim, 0, output.size(dim) - pad) + + return output + + +# ====== +# Pad +# ====== + +SPTIAL_PAD = 0 
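The helpers above (split_sequence / gather_sequence) and the pad bookkeeping that follows only make a dimension divisible by the sequence-parallel world size before sharding it across ranks. A minimal single-process sketch of that arithmetic, with made-up shapes and no process group involved:

import torch

sp_size = 4                                    # assumed sequence-parallel world size
frames = 29                                    # a length that is not divisible by sp_size
pad = (sp_size - frames % sp_size) % sp_size   # same formula as set_temporal_pad below -> 3

x = torch.randn(2, frames, 64)                        # (batch, frames, hidden)
x = torch.cat([x, x.new_zeros(2, pad, 64)], dim=1)    # pad to 32 frames, as split_sequence does
shards = torch.split(x, x.size(1) // sp_size, dim=1)  # four shards of 8 frames, one per rank

# gather_sequence() concatenates the shards again and narrows the padded tail back off
full = torch.cat(shards, dim=1).narrow(1, 0, frames)
assert full.shape == (2, frames, 64)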
+TEMPORAL_PAD = 0 + + +def set_spatial_pad(dim_size: int): + sp_size = get_sequence_parallel_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global SPTIAL_PAD + SPTIAL_PAD = pad + + +def get_spatial_pad() -> int: + return SPTIAL_PAD + + +def set_temporal_pad(dim_size: int): + sp_size = get_sequence_parallel_size() + pad = (sp_size - (dim_size % sp_size)) % sp_size + global TEMPORAL_PAD + TEMPORAL_PAD = pad + + +def get_temporal_pad() -> int: + return TEMPORAL_PAD + + +def all_to_all_with_pad( + input_: torch.Tensor, + process_group: dist.ProcessGroup, + **kwargs +): + scatter_dim = kwargs.get("scatter_dim", 2) + gather_dim = kwargs.get("gather_dim", 1) + scatter_pad = kwargs.get("scatter_pad", 0) + gather_pad = kwargs.get("gather_pad", 0) + + if scatter_pad > 0: + pad_shape = list(input_.shape) + pad_shape[scatter_dim] = scatter_pad + pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype) + input_ = torch.cat([input_, pad_tensor], dim=scatter_dim) + + world_size = dist.get_world_size(process_group) + if input_.shape[scatter_dim] % world_size != 0: + raise ValueError( + f"The scatter_dim:{scatter_dim} of input_:{input_.shape} is not divisible by world_size:{world_size}.") + + input_ = _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim) + + if gather_pad > 0: + input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad) + + return input_ + + +def all_to_all( + tensor: torch.Tensor, + world_size: int, + scatter_dim: int, + gather_dim: int, + process_group: dist.ProcessGroup = None, +): + if process_group is None: + process_group = dist.group.WORLD + return _all_to_all_func(tensor, world_size, process_group, scatter_dim, gather_dim) + + +def all_to_all_sbh( + input_: torch.Tensor, + scatter_dim: int = 1, + gather_dim: int = 0, +): + return single_all_to_all(input_, scatter_dim, gather_dim) + + +def single_all_to_all( + input_: torch.Tensor, + scatter_dim: int, + gather_dim: int, +): + sp_size = get_sequence_parallel_size() + inp_shape = list(input_.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size + if scatter_dim < 1: + input_t = input_.reshape( + [sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ) + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = input_.reshape( + [-1, sp_size, inp_shape[scatter_dim]] + \ + inp_shape[scatter_dim + 1:] + ).transpose(0, 1).contiguous() + + output = torch.empty_like(input_t) + + dist.all_to_all_single(output, input_t, group=get_sequence_parallel_group()) + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_dim < 1: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:]) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py new file mode 100644 index 0000000000..817743fe75 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import torch.distributed as dist +import torch_npu + + +class ParallelManager(): + def __init__(self, world_size=1, rank=0, group=None): + self.sp_size = world_size + self.sp_group = group + self.enable_sp = world_size > 1 + self.rank = rank + + +PARALLEL_MANAGER = ParallelManager() + + +def set_parallel_manager(world_size, rank, group): + global PARALLEL_MANAGER + PARALLEL_MANAGER = ParallelManager(world_size, rank, group) + + +def get_sequence_parallel_group(): + return PARALLEL_MANAGER.sp_group + + +def get_sequence_parallel_size(): + return PARALLEL_MANAGER.sp_size + + +def get_sequence_parallel_state(): + return PARALLEL_MANAGER.enable_sp + + +def get_sequence_parallel_rank(): + return PARALLEL_MANAGER.rank + + +def init_parallel_env(enable_sequence_parallelism): + rank = int(os.getenv('RANK', 0)) + world_size = int(os.getenv('WORLD_SIZE', 1)) + torch_npu.npu.set_device(rank) + dist.init_process_group( + backend='hccl', init_method='env://', + world_size=world_size, rank=rank + ) + if enable_sequence_parallelism: + set_parallel_manager(world_size, rank, dist.group.WORLD) + + +def finalize_parallel_env(): + if get_sequence_parallel_state and get_sequence_parallel_size() > 1: + dist.destroy_process_group() -- Gitee From df1e5d24cc1f93ac33bf22032b6796454d9cf807 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Mon, 13 Jan 2025 07:14:24 +0000 Subject: [PATCH 02/19] comm_calc_overlap --- .../models/diffusion/opensora/modules.py | 56 ++++++++++++------- .../open_sora_planv1_2/utils/comm.py | 8 ++- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index 770ce41776..0d5b01aa1c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -671,29 +671,37 @@ class AttnProcessor2_0: if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - if get_sequence_parallel_state(): - query = query.view(-1, attn.heads, head_dim) # [s // sp, b, h * d] -> [s // sp * b, h, d] - key = key.view(-1, attn.heads, head_dim) - value = value.view(-1, attn.heads, head_dim) + query = attn.to_q(hidden_states) + inner_dim = query.shape[-1] + head_dim = inner_dim // attn.heads h_size = attn.heads * head_dim sp_size = get_sequence_parallel_size() h_size_sp = h_size // sp_size + 
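            # The lines added below overlap communication with computation: the all-to-all for the query
            # is issued with async_op=True, the key and value projections run while that transfer is still
            # in flight, and each returned handle is waited on only right before its tensor is reshaped.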
+            query = query.view(-1, attn.heads, head_dim) # [s // sp, b, h * d] -> [s // sp * b, h, d]
             # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d]
-            query = all_to_all_sbh(query, scatter_dim=1, gather_dim=0).view(-1, batch_size, h_size_sp)
-            key = all_to_all_sbh(key, scatter_dim=1, gather_dim=0).view(-1, batch_size, h_size_sp)
-            value = all_to_all_sbh(value, scatter_dim=1, gather_dim=0).view(-1, batch_size, h_size_sp)
+            req_q, query = all_to_all_sbh(query, scatter_dim=1, gather_dim=0, async_op=True)
+
+            if encoder_hidden_states is None:
+                encoder_hidden_states = hidden_states
+            elif attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+            key = attn.to_k(encoder_hidden_states)
+            key = key.view(-1, attn.heads, head_dim)
+            req_k, key = all_to_all_sbh(key, scatter_dim=1, gather_dim=0, async_op=True)
+
+            value = attn.to_v(encoder_hidden_states)
+            value = value.view(-1, attn.heads, head_dim)
+            req_v, value = all_to_all_sbh(value, scatter_dim=1, gather_dim=0, async_op=True)
+
+            req_q.wait()
+            query = query.view(-1, batch_size, h_size_sp)
+            req_k.wait()
+            key = key.view(-1, batch_size, h_size_sp)
+            req_v.wait()
+            value = value.view(-1, batch_size, h_size_sp)
             if self.use_rope:
                 query = query.view(-1, batch_size, attn.heads // sp_size, head_dim)
                 key = key.view(-1, batch_size, attn.heads // sp_size, head_dim)
@@ -717,8 +725,18 @@ class AttnProcessor2_0:
 
             hidden_states = hidden_states.view(-1, attn.heads // sp_size, head_dim)
             # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d]
-            hidden_states = all_to_all_sbh(hidden_states, scatter_dim=0, gather_dim=1).view(-1, batch_size, h_size)
+            _, hidden_states = all_to_all_sbh(hidden_states, scatter_dim=0, gather_dim=1)
+            hidden_states = hidden_states.view(-1, batch_size, h_size)
         else:
+            query = attn.to_q(hidden_states)
+            if encoder_hidden_states is None:
+                encoder_hidden_states = hidden_states
+            elif attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+            inner_dim = query.shape[-1]
+            head_dim = inner_dim // attn.heads
             query = query.view(batch_size, -1, attn.heads, head_dim)
             key = key.view(batch_size, -1, attn.heads, head_dim)
             value = value.view(batch_size, -1, attn.heads, head_dim)
diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py
index bac2b5244e..75840a7fd0 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py
@@ -145,14 +145,16 @@ def all_to_all_sbh(
     input_: torch.Tensor,
     scatter_dim: int = 1,
     gather_dim: int = 0,
+    async_op: bool = False,
 ):
-    return single_all_to_all(input_, scatter_dim, gather_dim)
+    return single_all_to_all(input_, scatter_dim, gather_dim, async_op)
 
 
 def single_all_to_all(
     input_: torch.Tensor,
     scatter_dim: int,
     gather_dim: int,
+    async_op: bool = False,
 ):
     sp_size = get_sequence_parallel_size()
     inp_shape = list(input_.shape)
@@ -171,10 +173,10 @@ def single_all_to_all(
 
     output = torch.empty_like(input_t)
 
-    dist.all_to_all_single(output, input_t, group=get_sequence_parallel_group())
+    req = dist.all_to_all_single(output, input_t, group=get_sequence_parallel_group(), async_op=async_op)
    # if scattering the seq-dim, transpose the heads back to the original dimension
     if scatter_dim < 1:
         output = output.transpose(0, 1).contiguous()
 
-    return output.reshape(
+    return req, output.reshape(
         inp_shape[: gather_dim] + [inp_shape[gather_dim] * sp_size, ] + inp_shape[gather_dim + 1:])
--
Gitee


From d2579a5f3cd95e20924d83fa0059d291d8c0aeaa Mon Sep 17 00:00:00 2001
From: pu-zhe
Date: Mon, 13 Jan 2025 11:10:17 +0000
Subject: [PATCH 03/19] multi-batch infer

---
 .../open_sora_planv1_2/inference.py           | 43 ++++++++++++-------
 1 file changed, 28 insertions(+), 15 deletions(-)

diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py
index 21f85bab8b..2f8802a115 100644
--- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py
+++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py
@@ -37,12 +37,6 @@ def load_t2v_checkpoint(model_path):
 
 
 def run_model_and_save_images(pipeline):
-    if not isinstance(args.text_prompt, list):
-        args.text_prompt = [args.text_prompt]
-    if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
-        text_prompt = open(args.text_prompt[0], 'r').readlines()
-        args.text_prompt = [i.strip() for i in text_prompt]
-
     positive_prompt = """
     (masterpiece), (best quality), (ultra-detailed), (unwatermarked),
     {}.
@@ -54,12 +48,30 @@ def run_model_and_save_images(pipeline):
     nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality,
     low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry.
     """
+
+    if not isinstance(args.text_prompt, list):
+        args.text_prompt = [positive_prompt.format(args.text_prompt)]
+    if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
+        text_prompt = open(args.text_prompt[0], 'r').readlines()
+        args.text_prompt = [positive_prompt.format(i.strip()) for i in text_prompt]
+
+    if args.batch_size > 1:
+        prompt_list = []
+        group = len(args.text_prompt) // args.batch_size
+        tail = len(args.text_prompt) % args.batch_size
+        for index in range(group):
+            prompt_list.append(args.text_prompt[index * args.batch_size: (index + 1) * args.batch_size])
+        if tail > 0:
+            prompt_list.append(args.text_prompt[-tail:])
+    else:
+        prompt_list = args.text_prompt
+
     kwargs = {}
     if args.algorithm == "sampling_optimize":
         kwargs["sampling_optimize"] = True
 
-    for index, prompt in enumerate(args.text_prompt):
-        videos = pipeline(positive_prompt.format(prompt),
+    for index, prompt in enumerate(prompt_list):
+        videos = pipeline(prompt,
                           negative_prompt=negative_prompt,
                           num_frames=args.num_frames,
                           height=args.height,
@@ -75,13 +87,14 @@ def run_model_and_save_images(pipeline):
 
         if get_sequence_parallel_rank() <= 0:
             try:
-                imageio.mimwrite(
-                    os.path.join(
-                        args.save_img_path,
-                        f'EulerAncestralDiscrete_{index}_gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4'
-                    ), videos[0],
-                    fps=args.fps, quality=6, codec='libx264',
-                    output_params=['-threads', '20'])  # highest quality is 10, lowest is 0
+                for i in range(len(prompt)):
+                    imageio.mimwrite(
+                        os.path.join(
+                            args.save_img_path,
+                            f'EulerAncestralDiscrete_{index * args.batch_size + i}_final__gs{args.guidance_scale}_s{args.num_sampling_steps}.mp4'
+                        ), videos[i],
+                        fps=args.fps, quality=6, codec='libx264',
+                        output_params=['-threads', '20'])  # highest quality is 10, lowest is 0
             except:
                 logger.error('Error when saving {}'.format(prompt))
 
--
Gitee


From 5f16407515c8593473d1b36b270674334db9b4b3 Mon Sep 17 00:00:00 2001
From: pu-zhe
Date: Tue, 14 Jan 2025 06:23:15 +0000
Subject: [PATCH 04/19]
dit_cache_search --- .../open_sora_planv1_2/dit_cache_search.py | 186 ++++++++++++++++++ .../open_sora_planv1_2/inference.py | 1 + 2 files changed, 187 insertions(+) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py new file mode 100644 index 0000000000..619b27eb85 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py @@ -0,0 +1,186 @@ +import argparse +import logging +import os + +import imageio +import torch +import torch_npu +from diffusers.schedulers import EulerAncestralDiscreteScheduler +from transformers import T5Tokenizer, MT5EncoderModel + +from opensora.models.causalvideovae import ae_stride_config, CausalVAEModelWrapper +from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V +from opensora.sample.pipeline_opensora_sp import OpenSoraPipeline +from utils.parallel_mgr import init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank +from mindiesd.layers.cache_mgr import CacheManager, DitCacheConfig +from msmodelslim.pytorch.multimodal import DitCacheSearcherConfig, DitCacheSearcher + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def load_t2v_checkpoint(model_path): + logger.info(f'load_t2v_checkpoint, {model_path}') + transformer_model = OpenSoraT2V.from_pretrained(model_path, cache_dir=args.cache_dir, + low_cpu_mem_usage=False, device_map=None, + torch_dtype=weight_dtype).to("npu") + transformer_model.eval() + pipeline = OpenSoraPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer_model).to("npu") + + return pipeline + + +def run_model_and_save_images(config: DitCacheSearcherConfig, pipeline): + positive_prompt = """ + (masterpiece), (best quality), (ultra-detailed), (unwatermarked), + {}. + emotional, harmonious, vignette, 4k epic detailed, shot on kodak, 35mm photo, + sharp focus, high budget, cinemascope, moody, epic, gorgeous + """ + + negative_prompt = """ + nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, + low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry. 
+    """
+
+    if not isinstance(args.text_prompt, list):
+        args.text_prompt = [positive_prompt.format(args.text_prompt)]
+    if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
+        text_prompt = open(args.text_prompt[0], 'r').readlines()
+        args.text_prompt = [positive_prompt.format(i.strip()) for i in text_prompt]
+
+    if args.batch_size > 1:
+        prompt_list = []
+        group = len(args.text_prompt) // args.batch_size
+        tail = len(args.text_prompt) % args.batch_size
+        for index in range(group):
+            prompt_list.append(args.text_prompt[index * args.batch_size: (index + 1) * args.batch_size])
+        if tail > 0:
+            prompt_list.append(args.text_prompt[-tail:])
+    else:
+        prompt_list = args.text_prompt
+
+    kwargs = {}
+
+    # Fix the random seed per rank
+    local_rank = int(os.getenv('RANK', 0))
+    if local_rank >= 0:
+        torch.manual_seed(1234 + local_rank)
+    # Set pipeline.transformer's member variables from the searcher config fields
+    dit_cache_config = DitCacheConfig(
+        step_start=config.cache_step_start,
+        step_interval=config.cache_step_interval,
+        block_start=config.cache_dit_block_start,
+        num_blocks=config.cache_num_dit_blocks
+    )
+    cache = CacheManager(dit_cache_config)
+    pipeline.transformer.cache = cache
+
+    for index, prompt in enumerate(prompt_list):
+        videos = pipeline(prompt,
+                          negative_prompt=negative_prompt,
+                          num_frames=args.num_frames,
+                          height=args.height,
+                          width=args.width,
+                          num_inference_steps=args.num_sampling_steps,
+                          guidance_scale=args.guidance_scale,
+                          num_images_per_prompt=1,
+                          mask_feature=True,
+                          max_sequence_length=args.max_sequence_length,
+                          **kwargs
+                          ).images
+        logger.info(videos.shape)
+
+        if config.cache_num_dit_blocks == 0:
+            video_path = os.path.join(save_path, f'sample_{index:04d}_no_cache.mp4')
+        else:
+            video_path = os.path.join(save_path,
+                                      f'sample_{index:04d}_{config.cache_dit_block_start}_{config.cache_step_interval}_{config.cache_num_dit_blocks}_{config.cache_step_start}.mp4')
+
+        if get_sequence_parallel_rank() <= 0:
+            try:
+                for i in range(len(prompt)):
+                    imageio.mimwrite(
+                        video_path, videos[i],
+                        fps=args.fps, quality=6, codec='libx264',
+                        output_params=['-threads', '20'])  # highest quality is 10, lowest is 0
+            except:
+                logger.error('Error when saving {}'.format(prompt))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, default='LanguageBind/Open-Sora-Plan-v1.0.0')
+    parser.add_argument("--version", type=str, default=None, choices=[None, '65x512x512', '65x256x256', '17x256x256'])
+    parser.add_argument("--num_frames", type=int, default=93)
+    parser.add_argument("--height", type=int, default=720)
+    parser.add_argument("--width", type=int, default=1280)
+    parser.add_argument('--dtype', type=str, default='bf16', help='Data type used in inference')
+    parser.add_argument("--cache_dir", type=str, default='./cache_dir')
+    parser.add_argument("--ae", type=str, default='CausalVAEModel_4x8x8')
+    parser.add_argument("--ae_path", type=str, default='CausalVAEModel_4x8x8')
+    parser.add_argument("--text_encoder_name", type=str, default='google/mt5-xxl')
+    parser.add_argument("--save_img_path", type=str, default="./sample_videos/t2v")
+    parser.add_argument("--guidance_scale", type=float, default=7.5)
+    parser.add_argument("--num_sampling_steps", type=int, default=50)
+    parser.add_argument("--fps", type=int, default=24)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--max_sequence_length", type=int, default=512)
+    parser.add_argument("--text_prompt", nargs='+')
+    parser.add_argument("--tile_overlap_factor", type=float,
default=0.25) + parser.add_argument("--algorithm", type=str, default=None, choices=[None, 'dit_cache', 'sampling_optimize']) + args = parser.parse_args() + + if args.dtype not in ['bf16', 'fp16']: + logger.error("Not supported.") + weight_dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float16 + + local_rank = int(os.getenv('RANK', 0)) + world_size = int(os.getenv('WORLD_SIZE', 1)) + if world_size > 0: + init_parallel_env(enable_sequence_parallelism=True) + + vae = CausalVAEModelWrapper(args.ae_path, dtype=torch.float16).to("npu") + vae.vae.enable_tiling() + vae.vae.tile_overlap_factor = args.tile_overlap_factor + vae.vae.tile_sample_min_size = 256 + vae.vae.tile_latent_min_size = 32 + vae.vae.tile_sample_min_size_t = 29 + vae.vae.tile_latent_min_size_t = 8 + vae.vae_scale_factor = ae_stride_config[args.ae] + vae.eval() + + text_encoder = MT5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir, + low_cpu_mem_usage=True, torch_dtype=weight_dtype).to("npu") + tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) + text_encoder.eval() + + scheduler = EulerAncestralDiscreteScheduler() + + if not os.path.exists(args.save_img_path): + os.makedirs(args.save_img_path, exist_ok=True) + + pipeline = load_t2v_checkpoint(args.model_path) + logger.info('load model') + + config = DitCacheSearcherConfig( + dit_block_num=32, + prompts_num=1, + num_sampling_steps=100, + cache_ratio=1.5, + search_cache_path=args.save_img_path, + cache_step_start=0, + cache_dit_block_start=0, + cache_num_dit_blocks=0 + ) + search_handler = DitCacheSearcher(config, pipeline, run_model_and_save_images) + cache_final_list = search_handler.search() + print(f'****************cache_final_list in \ + [cache_dit_block_start, cache_step_interval, cache_num_dit_blocks, cache_step_start] order: {cache_final_list}') + + if world_size > 0: + finalize_parallel_env() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py index 2f8802a115..d7946558e8 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py @@ -115,6 +115,7 @@ if __name__ == "__main__": parser.add_argument("--guidance_scale", type=float, default=7.5) parser.add_argument("--num_sampling_steps", type=int, default=50) parser.add_argument("--fps", type=int, default=24) + parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--max_sequence_length", type=int, default=512) parser.add_argument("--text_prompt", nargs='+') parser.add_argument("--tile_overlap_factor", type=float, default=0.25) -- Gitee From 6d3da0da8b75dd6977d24a66b42999c13852fe0d Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 14 Jan 2025 08:23:38 +0000 Subject: [PATCH 05/19] fix_bug --- .../foundation/open_sora_planv1_2/dit_cache_search.py | 6 +++--- .../built-in/foundation/open_sora_planv1_2/inference.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py index 619b27eb85..2f3c2c73de 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py @@ -96,14 +96,14 @@ def run_model_and_save_images(config: DitCacheSearcherConfig, 
pipeline): logger.info(videos.shape) if config.cache_num_dit_blocks == 0: - video_path = os.path.join(save_path, f'sample_{index:04d}_no_cache.mp4') + video_path = os.path.join(args.save_img_path, f'sample_{index:04d}_no_cache.mp4') else: - video_path = os.path.join(save_path, + video_path = os.path.join(args.save_img_path, f'sample_{index:04d}_{config.cache_dit_block_start}_{config.cache_step_interval}_{config.cache_num_dit_blocks}_{config.cache_step_start}.mp4') if get_sequence_parallel_rank() <= 0: try: - for i in range(len(prompt)): + for i in range(len(prompt) if args.batch_size > 1 else 1): imageio.mimwrite( video_path, videos[i], fps=args.fps, quality=6, codec='libx264', diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py index d7946558e8..f0e06d954c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py @@ -87,7 +87,7 @@ def run_model_and_save_images(pipeline): if get_sequence_parallel_rank() <= 0: try: - for i in range(len(prompt)): + for i in range(len(prompt) if args.batch_size > 1 else 1): imageio.mimwrite( os.path.join( args.save_img_path, -- Gitee From 4112d6cfc95ca4bd0c2a3d255a3ba2ab0e2e6860 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 14 Jan 2025 10:41:57 +0000 Subject: [PATCH 06/19] dp+sp --- .../open_sora_planv1_2/inference.py | 14 +- .../diffusion/opensora/modeling_opensora.py | 6 +- .../models/diffusion/opensora/modules.py | 6 +- .../opensora/sample/pipeline_opensora_sp.py | 78 ++- .../foundation/open_sora_planv1_2/run.sh | 3 +- .../open_sora_planv1_2/utils/comm.py | 10 +- .../utils/group_coordinator.py | 631 ++++++++++++++++++ .../open_sora_planv1_2/utils/parallel_mgr.py | 263 +++++++- .../open_sora_planv1_2/utils/utils.py | 147 ++++ 9 files changed, 1094 insertions(+), 64 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py index f0e06d954c..e7c9415ddd 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py @@ -11,7 +11,7 @@ from transformers import T5Tokenizer, MT5EncoderModel from opensora.models.causalvideovae import ae_stride_config, CausalVAEModelWrapper from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V from opensora.sample.pipeline_opensora_sp import OpenSoraPipeline -from utils.parallel_mgr import init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank +from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -95,8 +95,8 @@ def run_model_and_save_images(pipeline): ), videos[i], fps=args.fps, quality=6, codec='libx264', output_params=['-threads', '20']) # highest quality is 10, lowest is 0 - except: - logger.error('Error when saving {}'.format(prompt)) + except Exception as e: + logger.error(f'Error when saving {prompt}, the error is {e}') if __name__ == "__main__": @@ -120,16 +120,18 @@ if __name__ == "__main__": parser.add_argument("--text_prompt", 
nargs='+') parser.add_argument("--tile_overlap_factor", type=float, default=0.25) parser.add_argument("--algorithm", type=str, default=None, choices=[None, 'dit_cache', 'sampling_optimize']) + parser.add_argument("--use_cfg_parallel", action='store_true') args = parser.parse_args() if args.dtype not in ['bf16', 'fp16']: logger.error("Not supported.") weight_dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float16 - local_rank = int(os.getenv('RANK', 0)) world_size = int(os.getenv('WORLD_SIZE', 1)) - if world_size > 0: - init_parallel_env(enable_sequence_parallelism=True) + if world_size > 1: + sp_degree = world_size // 2 if args.use_cfg_parallel else world_size + parallel_config = ParallelConfig(sp_degree=sp_degree, use_cfg_parallel=args.use_cfg_parallel, world_size=world_size) + init_parallel_env(parallel_config) vae = CausalVAEModelWrapper(args.ae_path, dtype=torch.float16).to("npu") vae.vae.enable_tiling() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py index 36b9037b53..05bc9a05fb 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py @@ -13,7 +13,7 @@ from torch import nn from torch.nn import functional as F from opensora.models.diffusion.opensora.modules import PatchEmbed2D, BasicTransformerBlock, LayerNorm -from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size +from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_world_size logger = logging.get_logger(__name__) @@ -295,8 +295,8 @@ class OpenSoraT2V(ModelMixin, ConfigMixin): if attention_mask is not None and attention_mask.ndim == 4: attention_mask = attention_mask.to(self.dtype) if get_sequence_parallel_state(): - attention_mask_vid = attention_mask[:, :frame * get_sequence_parallel_size()] # b, frame, h, w - attention_mask_img = attention_mask[:, frame * get_sequence_parallel_size():] # b, use_image_num, h, w + attention_mask_vid = attention_mask[:, :frame * get_sequence_parallel_world_size()] # b, frame, h, w + attention_mask_img = attention_mask[:, frame * get_sequence_parallel_world_size():] # b, use_image_num, h, w # b, use_image_num, h, w else: attention_mask_vid = attention_mask[:, :frame] # b, frame, h, w attention_mask_img = attention_mask[:, frame:] # b, use_image_num, h, w diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index 0d5b01aa1c..ee6d6045ba 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -17,7 +17,7 @@ from einops import rearrange from torch import nn from utils.comm import all_to_all_sbh -from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size, get_sequence_parallel_rank +from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_world_size, get_sequence_parallel_rank from .rope import PositionGetter3D, RoPE3D logger = logging.get_logger(__name__) @@ 
-201,7 +201,7 @@ class PatchEmbed2D(nn.Module): if self.num_frames != num_frames: if get_sequence_parallel_state(): - sp_size = get_sequence_parallel_size() + sp_size = get_sequence_parallel_world_size() temp_pos_embed = get_1d_sincos_pos_embed( embed_dim=self.temp_pos_embed.shape[-1], grid_size=num_frames * sp_size, @@ -517,7 +517,7 @@ class Attention(Attention_): """ head_size = self.heads if get_sequence_parallel_state(): - head_size = head_size // get_sequence_parallel_size() + head_size = head_size // get_sequence_parallel_world_size() if attention_mask is None: return None diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py index b143cc371d..ee54ccaba2 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py @@ -7,6 +7,7 @@ import urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union import torch +import torch_npu from diffusers.models import AutoencoderKL, Transformer2DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput from diffusers.schedulers import EulerAncestralDiscreteScheduler @@ -21,7 +22,15 @@ from einops import rearrange from mindiesd.pipeline.sampling_optm import SamplingOptm from transformers import T5EncoderModel, T5Tokenizer -from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_size, get_sequence_parallel_rank +from utils.parallel_mgr import ( + get_sequence_parallel_state, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + get_sp_group +) logger = logging.get_logger(__name__) @@ -477,6 +486,35 @@ class OpenSoraPipeline(DiffusionPipeline): caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) return caption.strip() + + def _process_cfg_split_batch( + self, + negative_embeds: torch.Tensor, + embeds: torch.Tensor, + negative_embdes_mask: torch.Tensor = None, + embeds_mask: torch.Tensor = None, + ): + if get_classifier_free_guidance_world_size() == 1: + embeds = torch.cat([negative_embeds, embeds], dim=0) + elif get_classifier_free_guidance_rank() == 0: + embeds = negative_embeds + elif get_classifier_free_guidance_rank() == 1: + embeds = embeds + else: + raise ValueError("Invalid classifier free guidance rank") + + if negative_embdes_mask is None: + return embeds + + if get_classifier_free_guidance_world_size() == 1: + embeds_mask = torch.cat([negative_embdes_mask, embeds_mask], dim=0) + elif get_classifier_free_guidance_rank() == 0: + embeds_mask = negative_embdes_mask + elif get_classifier_free_guidance_rank() == 1: + embeds_mask = embeds_mask + else: + raise ValueError("Invalid classifier free guidance rank") + return embeds, embeds_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, @@ -661,19 +699,26 @@ class OpenSoraPipeline(DiffusionPipeline): max_sequence_length=max_sequence_length, ) if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = 
torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if negative_prompt_attention_mask is not None: + prompt_embeds, prompt_attention_mask = self._process_cfg_split_batch( + negative_prompt_embeds, + prompt_embeds, + negative_prompt_attention_mask, + prompt_attention_mask + ) + else: + prompt_embeds = self._process_cfg_split_batch(negative_prompt_embeds, prompt_embeds) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels - world_size = get_sequence_parallel_size() + world_size = get_sequence_parallel_world_size() latents = self.prepare_latents( batch_size * num_images_per_prompt, latent_channels, - (num_frames + world_size - 1) // world_size if get_sequence_parallel_state() else num_frames, + (num_frames + world_size - 1) // world_size if world_size > 1 else num_frames, height, width, prompt_embeds.dtype, @@ -698,7 +743,9 @@ class OpenSoraPipeline(DiffusionPipeline): def dit_process(latents, prompt_embeds, prompt_attention_mask): for i, t in enumerate(tqdm(timesteps)): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat( + [latents] * (2 // get_classifier_free_guidance_world_size()) + ) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) current_timestep = torch.tensor([t], dtype=torch.float64, device=latent_model_input.device) @@ -730,7 +777,12 @@ class OpenSoraPipeline(DiffusionPipeline): # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + if get_classifier_free_guidance_world_size() == 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + elif get_classifier_free_guidance_world_size() == 2: + noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather( + noise_pred, separate_tensors=True + ) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # learned sigma @@ -752,7 +804,7 @@ class OpenSoraPipeline(DiffusionPipeline): return latents if kwargs.get("sampling_optimize", None): - config = {"method": "AdaStep", "skip_thr": 0.015, "max_skip_steps": 4, "decay_ratio": 0.99} + config = {"method": "AdaStep", "skip_thr": 0.012, "max_skip_steps": 3, "decay_ratio": 0.99} with SamplingOptm(self, "transformer.forward", "scheduler.step", @@ -762,14 +814,8 @@ class OpenSoraPipeline(DiffusionPipeline): else: latents = dit_process(latents, prompt_embeds, prompt_attention_mask) - world_size = get_sequence_parallel_size() - if get_sequence_parallel_state(): - latents_shape = list(latents.shape) - full_shape = [latents_shape[0] * world_size] + latents_shape[1:] - all_latents = torch.zeros(full_shape, dtype=latents.dtype, device=latents.device) - torch.distributed.all_gather_into_tensor(all_latents, latents) - latents_list = list(all_latents.chunk(world_size, dim=0)) - latents = torch.cat(latents_list, dim=2) + if get_sequence_parallel_world_size() > 1: + latents = get_sp_group().all_gather(latents, dim=2) if not output_type == "latent": # b t h w c diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh index 50fb03c198..262ccc6f09 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh @@ -15,4 +15,5 
@@ torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ --num_sampling_steps 100 \ --tile_overlap_factor 0.125 \ --max_sequence_length 512 \ - --dtype bf16 \ No newline at end of file + --dtype bf16 \ + --use_cfg_parallel \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py index 75840a7fd0..33f79ad1e5 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py @@ -20,7 +20,7 @@ import torch import torch.distributed as dist -from .parallel_mgr import get_sequence_parallel_size, get_sequence_parallel_group +from .parallel_mgr import get_sequence_parallel_world_size, get_sp_group def _all_to_all_func(input_, world_size, process_group, scatter_dim=2, gather_dim=1): @@ -79,7 +79,7 @@ TEMPORAL_PAD = 0 def set_spatial_pad(dim_size: int): - sp_size = get_sequence_parallel_size() + sp_size = get_sequence_parallel_world_size() pad = (sp_size - (dim_size % sp_size)) % sp_size global SPTIAL_PAD SPTIAL_PAD = pad @@ -90,7 +90,7 @@ def get_spatial_pad() -> int: def set_temporal_pad(dim_size: int): - sp_size = get_sequence_parallel_size() + sp_size = get_sequence_parallel_world_size() pad = (sp_size - (dim_size % sp_size)) % sp_size global TEMPORAL_PAD TEMPORAL_PAD = pad @@ -156,7 +156,7 @@ def single_all_to_all( gather_dim: int, async_op: bool = False, ): - sp_size = get_sequence_parallel_size() + sp_size = get_sequence_parallel_world_size() inp_shape = list(input_.shape) inp_shape[scatter_dim] = inp_shape[scatter_dim] // sp_size if scatter_dim < 1: @@ -173,7 +173,7 @@ def single_all_to_all( output = torch.empty_like(input_t) - req = dist.all_to_all_single(output, input_t, group=get_sequence_parallel_group(), async_op=async_op) + req = dist.all_to_all_single(output, input_t, group=get_sp_group().device_group, async_op=async_op) # if scattering the seq-dim, transpose the heads back to the original dimension if scatter_dim < 1: output = output.transpose(0, 1).contiguous() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py new file mode 100644 index 0000000000..0adaccdd1d --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py @@ -0,0 +1,631 @@ +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from collections import namedtuple +from typing import Any, Dict, List, Optional, Tuple, Union +import pickle + +import torch +import torch_npu +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +from diffusers.utils import logging + +logger = logging.get_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". 
+ """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "npu:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (prefix + key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and npu graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.npu.is_available(): + self.device = torch.device(f"npu:{local_rank}") + else: + self.device = torch.device("cpu") + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. 
+ torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.view(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. 
+ torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." + ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
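broadcast_tensor_dict and send/recv_tensor_dict above rely on splitting a dict into picklable metadata (moved over the CPU group) and raw tensors (moved over the device group). The real _split_tensor_dict and TensorMetadata live in utils/comm.py; the snippet below is only a simplified stand-in to show the shape of that split:

import torch

def split_tensor_dict_sketch(tensor_dict):
    # The real helper records TensorMetadata(device, dtype, size) placeholders;
    # a plain tuple stands in for it here.
    metadata, tensors = [], []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            metadata.append((key, (str(value.device), value.dtype, value.size())))
            tensors.append(value)
        else:
            metadata.append((key, value))
    return metadata, tensors

metadata_list, tensor_list = split_tensor_dict_sketch(
    {"latents": torch.zeros(2, 4), "timestep_index": 7}
)
# metadata_list is cheap to pickle and broadcast; tensor_list follows as
# (optionally async) broadcasts or point-to-point sends.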
+ _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + ) + self.ulysses_world_size = self.ring_world_size = 1 + self.ulysses_rank = self.ring_rank = 0 \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py index 817743fe75..dc93029af4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py @@ -18,55 +18,258 @@ # LICENSE file in the root directory of this source tree. 
import os - +from typing import List, Optional +from dataclasses import dataclass import torch.distributed as dist import torch_npu +from .utils import RankGenerator, generate_masked_orthogonal_rank_groups +from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator +from diffusers.utils import logging -class ParallelManager(): - def __init__(self, world_size=1, rank=0, group=None): - self.sp_size = world_size - self.sp_group = group - self.enable_sp = world_size > 1 - self.rank = rank - +logger = logging.get_logger(__name__) -PARALLEL_MANAGER = ParallelManager() +_WORLD: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None -def set_parallel_manager(world_size, rank, group): - global PARALLEL_MANAGER - PARALLEL_MANAGER = ParallelManager(world_size, rank, group) +@dataclass +class ParallelConfig: + sp_degree: int = 1 + use_cfg_parallel: bool = False + world_size: int = 1 + def __post_init__(self): + if self.use_cfg_parallel: + self.cfg_degree = 2 + else: + self.cfg_degree = 1 + assert self.sp_degree * self.cfg_degree <= self.world_size, ( + "sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + assert ( + self.world_size % (self.sp_degree * self.cfg_degree) == 0 + ), "world_size must be divisible by sp_degree * cfg_degree" -def get_sequence_parallel_group(): - return PARALLEL_MANAGER.sp_group +# * QUERY +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD -def get_sequence_parallel_size(): - return PARALLEL_MANAGER.sp_size - +# SP +def get_sp_group() -> SequenceParallelGroupCoordinator: + assert _SP is not None, "pipeline model parallel group is not initialized" + return _SP def get_sequence_parallel_state(): - return PARALLEL_MANAGER.enable_sp + """Return state for the sequence parallel group.""" + return _SP is not None + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return get_sp_group().world_size def get_sequence_parallel_rank(): - return PARALLEL_MANAGER.rank + """Return my rank for the sequence parallel group.""" + return get_sp_group().rank_in_group + +# CFG +def get_cfg_group() -> GroupCoordinator: + assert ( + _CFG is not None + ), "classifier_free_guidance parallel group is not initialized" + return _CFG + + +def get_classifier_free_guidance_world_size(): + """Return world size for the classifier_free_guidance parallel group.""" + return get_cfg_group().world_size -def init_parallel_env(enable_sequence_parallelism): - rank = int(os.getenv('RANK', 0)) - world_size = int(os.getenv('WORLD_SIZE', 1)) - torch_npu.npu.set_device(rank) - dist.init_process_group( - backend='hccl', init_method='env://', - world_size=world_size, rank=rank +def get_classifier_free_guidance_rank(): + """Return my rank for the classifier_free_guidance parallel group.""" + return get_cfg_group().rank_in_group + + +def init_world_group( + ranks: List[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, ) - if enable_sequence_parallelism: - set_parallel_manager(world_size, rank, dist.group.WORLD) -def finalize_parallel_env(): - if get_sequence_parallel_state and get_sequence_parallel_size() > 1: +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + 
local_rank: int = -1, + backend: str = "hccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not dist.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + # this backend is used for WORLD + dist.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = int(os.getenv('RANK', 0)) + torch_npu.npu.set_device(local_rank) + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(dist.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert ( + _WORLD.world_size == dist.get_world_size() + ), "world group already initialized with a different world size" + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ( + _CFG is not None + and _SP is not None + ) + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + assert parallel_mode in [ + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + **kwargs, + ) + else: + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + ) + +def initialize_model_parallel( + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. + backend: distributed backend of pytorch collective comm. + """ + # Get world size and rank. Ensure some consistencies. 
+ assert dist.is_initialized() + world_size: int = dist.get_world_size() + backend = backend or dist.get_backend(get_world_group().device_group) + + if ( + world_size + != classifier_free_guidance_degree + * sequence_parallel_degree + ): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"sequence_parallel_degree ({sequence_parallel_degree}) x" + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x" + ) + + rank_generator: RankGenerator = RankGenerator( + sequence_parallel_degree, + classifier_free_guidance_degree, + "sp-cfg", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + + global _SP + assert _SP is None, "sequence parallel group is already initialized" + _SP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ) + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _CFG + if _CFG: + _CFG.destroy() + _CFG = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if dist.is_initialized(): dist.destroy_process_group() + +def init_parallel_env(parallel_config: ParallelConfig): + if not model_parallel_is_initialized(): + logger.warning("Model parallel is not initialized, initializing...") + if not dist.is_initialized(): + init_distributed_environment() + initialize_model_parallel( + classifier_free_guidance_degree=parallel_config.cfg_degree, + sequence_parallel_degree=parallel_config.sp_degree, + ) + +def finalize_parallel_env(): + if model_parallel_is_initialized(): + destroy_model_parallel() + destroy_distributed_environment() \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py new file mode 100644 index 0000000000..ca1cae0436 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py @@ -0,0 +1,147 @@ +from typing import List + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: List[int], mask: List[bool] +) -> List[List[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. For example, + if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then + the generated group is the `tp-dp` group, if the mask = [False, True, False], then the + generated group is the `pp` group. 
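Putting the parallel_mgr API above together, a typical driver-side call sequence looks like the sketch below. The values are illustrative (8 NPUs launched with torchrun, SP degree 4, CFG parallel enabled), and it assumes RANK/WORLD_SIZE are already set in the environment:

import os

from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env

world_size = int(os.getenv("WORLD_SIZE", "1"))
# __post_init__ derives cfg_degree (2 when use_cfg_parallel) and checks that
# sp_degree * cfg_degree divides world_size.
config = ParallelConfig(sp_degree=4, use_cfg_parallel=True, world_size=world_size)
init_parallel_env(config)  # sets up _WORLD, then the SP and CFG coordinators
try:
    ...  # run the pipeline here
finally:
    finalize_parallel_env()  # destroys the SP/CFG groups and the world group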
+ """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """ + This function solve the math problem below: + There is an equation: + index = sum(idx[i] * stride[i]) + And given the value of index, stride. + Return the idx. + This function will used to get the pp/dp/pp_rank + from group_index and rank_in_group. + """ + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + # stride is a prefix_product result. And the value of stride[-1] + # is not used. + assert ( + sum([x * y for x, y in zip(idx, stride[:-1])]) == index + ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + # get indices from unmaksed for group_index. + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + # get indices from masked for rank_in_group. + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride) + ) + ranks.append(rank) + return ranks + + +class RankGenerator(object): + def __init__( + self, + sp: int, + cfg: int, + order: str, + rank_offset: int = 0, + ) -> None: + self.sp = sp + self.cfg = cfg + self.rank_offset = rank_offset + self.world_size = sp * cfg + + self.name_to_size = { + "sp": self.sp, + "cfg": self.cfg, + } + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError( + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + ) + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + ordered_token = order.split("-") + token = token.split("-") + mask = [False] * len(ordered_token) + for t in token: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Arguments: + token (str): + Specify the ranks type that want to get. If we want + to obtain multiple parallel types, we can use a hyphen + '-' to separate them. For example, if we want to obtain + the TP_DP group, the token should be 'tp-dp'. + + independent_ep (bool: True): + This flag controls whether we treat EP and DP independently. + EP shares ranks with DP, if we want to get ranks related to + EP, we should set the flag. For example, get_ranks('dp', True) + will get DP modulo EP group, and get_ranks('dp', False) will + get full DP group. 
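For the sp-cfg layout used in this repo, the grouping these helpers produce can be spelled out concretely. With 8 ranks, sp=4, cfg=2, and order "sp-cfg", the expected result is (a small check, runnable against this module):

from utils.utils import RankGenerator

rank_gen = RankGenerator(sp=4, cfg=2, order="sp-cfg")
# Contiguous blocks of four ranks form the sequence-parallel groups...
assert rank_gen.get_ranks("sp") == [[0, 1, 2, 3], [4, 5, 6, 7]]
# ...while ranks a stride of 4 apart pair up for classifier-free-guidance parallelism.
assert rank_gen.get_ranks("cfg") == [[0, 4], [1, 5], [2, 6], [3, 7]]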
+ """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups( + self.world_size, self.ordered_size, mask + ) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks -- Gitee From fa49a0499558a44c42408f101449820fea039cc9 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 14 Jan 2025 10:48:19 +0000 Subject: [PATCH 07/19] fix bug --- .../opensora/models/diffusion/opensora/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index ee6d6045ba..56f53ea01f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -676,7 +676,7 @@ class AttnProcessor2_0: inner_dim = query.shape[-1] head_dim = inner_dim // attn.heads h_size = attn.heads * head_dim - sp_size = get_sequence_parallel_size() + sp_size = get_sequence_parallel_world_size() h_size_sp = h_size // sp_size query = query.view(-1, attn.heads, head_dim) # [s // sp, b, h * d] -> [s // sp * b, h, d] -- Gitee From fd818a2ddf8dcd3b64f779c036ae3d3a6f17f898 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 14 Jan 2025 11:56:14 +0000 Subject: [PATCH 08/19] fix bug --- .../open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py index ee54ccaba2..af3481bdd4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py @@ -647,6 +647,9 @@ class OpenSoraPipeline(DiffusionPipeline): If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images """ + + torch.manual_seed(get_sequence_parallel_rank() + 1234) + # 1. Check inputs. 
Raise error if not correct num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0] height = height or self.transformer.config.sample_size[0] * self.vae.vae_scale_factor[1] -- Gitee From c2a002d36674584dd106b258fe1802f854a7530c Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Wed, 15 Jan 2025 07:43:48 +0000 Subject: [PATCH 09/19] add rope cache --- .../models/diffusion/opensora/modules.py | 2 +- .../opensora/models/diffusion/opensora/rope.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index 56f53ea01f..24e3301ecc 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -701,7 +701,7 @@ class AttnProcessor2_0: req_k.wait() key = key.view(-1, batch_size, h_size_sp) req_v.wait() - req_v, value.view(-1, batch_size, h_size_sp) + value = value.view(-1, batch_size, h_size_sp) if self.use_rope: query = query.view(-1, batch_size, attn.heads // sp_size, head_dim) key = key.view(-1, batch_size, attn.heads // sp_size, head_dim) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py index e9d0c7c4eb..da862ede22 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py @@ -51,18 +51,18 @@ class RoPE3D(torch.nn.Module): self.cache[D, seq_len, device, dtype] = (cos, sin) return self.cache[D, seq_len, device, dtype] - @staticmethod - def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - def apply_rope1d(self, tokens, pos1d, cos, sin): assert pos1d.ndim == 2 # for (batch_size x ntokens x nheads x dim) - cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] - sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] - - return torch_npu.npu_rotary_mul(tokens, cos, sin) + if (pos1d, cos) not in self.cache or (pos1d, sin) not in self.cache: + _cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] + _sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] + self.cache[pos1d, cos] = _cos + self.cache[pos1d, sin] = _sin + else: + _cos = self.cache[pos1d, cos] + _sin = self.cache[pos1d, sin] + return torch_npu.npu_rotary_mul(tokens, _cos, _sin) def forward(self, tokens, positions): """ -- Gitee From d1cf66973520d0c2d734eb7e82d8e8f242da3635 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Thu, 16 Jan 2025 01:12:24 +0000 Subject: [PATCH 10/19] add torch.npu.config.allow_internal_format = False --- .../built-in/foundation/open_sora_planv1_2/inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py index e7c9415ddd..10abaaf06a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py @@ -126,6 
+126,7 @@ if __name__ == "__main__": if args.dtype not in ['bf16', 'fp16']: logger.error("Not supported.") weight_dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float16 + torch.npu.config.allow_internal_format = False world_size = int(os.getenv('WORLD_SIZE', 1)) if world_size > 1: -- Gitee From c4a55c6f886083fffd9d5d7402e04ed9a335b09a Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 17 Jan 2025 05:05:09 +0000 Subject: [PATCH 11/19] optimize mask logic --- .../diffusion/opensora/modeling_opensora.py | 4 +- .../models/diffusion/opensora/modules.py | 6 +++ .../opensora/sample/pipeline_opensora_sp.py | 18 ++++--- .../foundation/open_sora_planv1_2/run.sh | 2 +- .../open_sora_planv1_2/t2v_sora.txt | 48 +++++++++++++++++++ 5 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/t2v_sora.txt diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py index 05bc9a05fb..e65e774fc7 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py @@ -344,13 +344,13 @@ class OpenSoraT2V(ModelMixin, ConfigMixin): if attention_mask_vid is not None: attention_mask_vid = attention_mask_vid.to(torch.bool) attention_mask_vid = attention_mask_vid.repeat(1, attention_mask_vid.shape[-1], 1) + if encoder_attention_mask_vid is not None: encoder_attention_mask_vid = encoder_attention_mask_vid.to(torch.bool) - encoder_attention_mask_vid = encoder_attention_mask_vid.repeat(1, attention_mask_vid.shape[-2], 1) if attention_mask_img is not None: attention_mask_img = attention_mask_img.to(torch.bool) attention_mask_img = attention_mask_img.repeat(1, attention_mask_vid.shape[-1], 1) + if encoder_attention_mask_img is not None: encoder_attention_mask_img = encoder_attention_mask_img.to(torch.bool) - encoder_attention_mask_img = encoder_attention_mask_img.repeat(1, attention_mask_vid.shape[-2], 1) # 1. 
Input frame = ((frame - 1) // self.patch_size_t + 1) if frame % 2 == 1 else frame // self.patch_size_t diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index 24e3301ecc..eacfce8ba0 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -713,6 +713,10 @@ class AttnProcessor2_0: query = query.view(-1, batch_size, h_size_sp) # SBH key = key.view(-1, batch_size, h_size_sp) # SBH value = value.view(-1, batch_size, h_size_sp) # SBH + + if attention_mask is not None and attention_mask.shape[-2] == 1: + attention_mask = attention_mask.repeat(1, 1, query.shape[0], 1) + query = rearrange(query, "s b (n d) -> b n s d", d=head_dim) # BNSD key = rearrange(key, "s b (n d) -> b n s d", d=head_dim) # BNSD value = rearrange(value, "s b (n d) -> b n s d", d=head_dim) # BNSD @@ -740,6 +744,8 @@ class AttnProcessor2_0: query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, attn.heads, head_dim) value = value.view(batch_size, -1, attn.heads, head_dim) + if attention_mask is not None and attention_mask.shape[-2] == 1: + attention_mask = attention_mask.repeat(1, 1, query.shape[1], 1) if self.use_rope: # require the shape of (batch_size x nheads x ntokens x dim) pos_thw = self.position_getter(batch_size, t=frame, h=height, w=width, device=query.device) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py index af3481bdd4..a8905e3879 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py @@ -76,7 +76,7 @@ def retrieve_timesteps( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + scheduler.set_timesteps(num_inference_steps, timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: @@ -516,6 +516,9 @@ class OpenSoraPipeline(DiffusionPipeline): raise ValueError("Invalid classifier free guidance rank") return embeds, embeds_mask + def set_seed(self, seed: int): + torch.manual_seed(seed) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None): @@ -648,7 +651,7 @@ class OpenSoraPipeline(DiffusionPipeline): returned where the first element is a list with the generated images """ - torch.manual_seed(get_sequence_parallel_rank() + 1234) + self.set_seed(get_sequence_parallel_rank() + 1234) # 1. Check inputs. 
Raise error if not correct num_frames = num_frames or self.transformer.config.sample_size_t * self.vae.vae_scale_factor[0] @@ -762,10 +765,13 @@ class OpenSoraPipeline(DiffusionPipeline): prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l # prepare attention_mask. - # b c t h w -> b t h w - attention_mask = torch.ones_like(latent_model_input)[:, 0] - if get_sequence_parallel_state(): - attention_mask = attention_mask.repeat(1, world_size, 1, 1) + attention_mask = None + if self.transformer.config.downsampler is not None: + # b c t h w -> b t h w + attention_mask = torch.ones_like(latent_model_input)[:, 0] + if get_sequence_parallel_state(): + attention_mask = attention_mask.repeat(1, world_size, 1, 1) + # predict noise model_output noise_pred = self.transformer( latent_model_input, diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh index 262ccc6f09..fa22afdbad 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh @@ -6,7 +6,7 @@ torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ --width 1280 \ --cache_dir "../cache_dir" \ --text_encoder_name google/mt5-xxl \ - --text_prompt examples/prompt_list_0.txt \ + --text_prompt ./t2v_sora.txt \ --ae CausalVAEModel_D4_4x8x8 \ --ae_path "/path/to/causalvideovae" \ --save_img_path "./sample_video_test" \ diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/t2v_sora.txt b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/t2v_sora.txt new file mode 100644 index 0000000000..eeb887b186 --- /dev/null +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/t2v_sora.txt @@ -0,0 +1,48 @@ +A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. +Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field. +A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors. +Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway. +Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. 
The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image. +A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures. +This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird’s head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird’s striking appearance. +Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. +A young man at his 20s is sitting on a piece of cloud in the sky, reading a book. +Historical footage of California during the gold rush. +A close up view of a glass sphere that has a zen garden within it. There is a small dwarf in the sphere who is raking the zen garden and creating patterns in the sand. +Extreme close up of a 24 year old woman’s eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic +A cartoon kangaroo disco dances. +A beautiful homemade video showing the people of Lagos, Nigeria in the year 2056. Shot with a mobile phone camera. +A petri dish with a bamboo forest growing within it that has tiny red pandas running around. +The camera rotates around a large stack of vintage televisions all showing different programs — 1950s sci-fi movies, horror movies, news, static, a 1970s sitcom, etc, set inside a large New York museum gallery. +3D animation of a small, round, fluffy creature with big, expressive eyes explores a vibrant, enchanted forest. The creature, a whimsical blend of a rabbit and a squirrel, has soft blue fur and a bushy, striped tail. It hops along a sparkling stream, its eyes wide with wonder. The forest is alive with magical elements: flowers that glow and change colors, trees with leaves in shades of purple and silver, and small floating lights that resemble fireflies. The creature stops to interact playfully with a group of tiny, fairy-like beings dancing around a mushroom ring. The creature looks up in awe at a large, glowing tree that seems to be the heart of the forest. +The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. +Reflections in the window of a train traveling through the Tokyo suburbs. 
+A drone camera circles around a beautiful historic church built on a rocky outcropping along the Amalfi Coast, the view showcases historic and magnificent architectural details and tiered pathways and patios, waves are seen crashing against the rocks below as the view overlooks the horizon of the coastal waters and hilly landscapes of the Amalfi Coast Italy, several distant people are seen walking and enjoying vistas on patios of the dramatic ocean views, the warm glow of the afternoon sun creates a magical and romantic feeling to the scene, the view is stunning captured with beautiful photography. +A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect. +A flock of paper airplanes flutters through a dense jungle, weaving around trees as if they were migrating birds. +A cat waking up its sleeping owner demanding breakfast. The owner tries to ignore the cat, but the cat tries new tactics and finally the owner pulls out a secret stash of treats from under the pillow to hold the cat off a little longer. +Borneo wildlife on the Kinabatangan River +A Chinese Lunar New Year celebration video with Chinese Dragon. +Tour of an art gallery with many beautiful works of art in different styles. +Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes. +A stop motion animation of a flower growing out of the windowsill of a suburban house. +The story of a robot’s life in a cyberpunk setting. +An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film. +A beautiful silhouette animation shows a wolf howling at the moon, feeling lonely, until it finds its pack. +New York City submerged like Atlantis. Fish, whales, sea turtles and sharks swim through the streets of New York. +A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in. +Step-printing scene of a person running, cinematic film shot in 35mm. +Five gray wolf pups frolicking and chasing each other around a remote gravel road, surrounded by grass. The pups run and leap, chasing each other, and nipping at each other, playing. +Basketball through hoop then explodes. 
+Archeologists discover a generic plastic chair in the desert, excavating and dusting it with great care. +A grandmother with neatly combed grey hair stands behind a colorful birthday cake with numerous candles at a wood dining room table, expression is one of pure joy and happiness, with a happy glow in her eye. She leans forward and blows out the candles with a gentle puff, the cake has pink frosting and sprinkles and the candles cease to flicker, the grandmother wears a light blue blouse adorned with floral patterns, several happy friends and family sitting at the table can be seen celebrating, out of focus. The scene is beautifully captured, cinematic, showing a 3/4 view of the grandmother and the dining room. Warm color tones and soft lighting enhance the mood. +The camera directly faces colorful buildings in Burano Italy. An adorable dalmation looks through a window on a building on the ground floor. Many people are walking and cycling along the canal streets in front of the buildings. +An adorable happy otter confidently stands on a surfboard wearing a yellow lifejacket, riding along turquoise tropical waters near lush tropical islands, 3D digital render art style. +This close-up shot of a chameleon showcases its striking color changing capabilities. The background is blurred, drawing attention to the animal’s striking appearance. +A corgi vlogging itself in tropical Maui. +A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field. +Aerial view of Santorini during the blue hour, showcasing the stunning architecture of white Cycladic buildings with blue domes. The caldera views are breathtaking, and the lighting creates a beautiful, serene atmosphere. +Tiltshift of a construction site filled with workers, equipment, and heavy machinery. +A giant, towering cloud in the shape of a man looms over the earth. The cloud man shoots lighting bolts down to the earth. +A Samoyed and a Golden Retriever dog are playfully romping through a futuristic neon city at night. The neon lights emitted from the nearby buildings glistens off of their fur. +The Glenfinnan Viaduct is a historic railway bridge in Scotland, UK, that crosses over the west highland line between the towns of Mallaig and Fort William. It is a stunning sight as a steam train leaves the bridge, traveling over the arch-covered viaduct. The landscape is dotted with lush greenery and rocky mountains, creating a picturesque backdrop for the train journey. The sky is blue and the sun is shining, making for a beautiful day to explore this majestic spot. 
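Returning to the attention-mask handling added to modules.py in this commit: when the incoming mask has a singleton query-length axis, it is repeated up to the full [B, N, S_q, S_k] shape before the attention call, presumably so the fused NPU kernel sees an explicit mask rather than relying on broadcasting. A CPU-only illustration of that expansion (shapes are arbitrary):

import torch

B, N, S_q, S_k = 2, 4, 16, 24
# A per-key boolean mask, broadcast over heads, with a singleton query axis.
attention_mask = torch.ones(B, N, 1, S_k, dtype=torch.bool)
attention_mask[..., S_k // 2:] = False  # e.g. padded text tokens masked out
# Mirrors `attention_mask.repeat(1, 1, query.shape[1], 1)` in AttnProcessor2_0.
expanded = attention_mask.repeat(1, 1, S_q, 1)
assert expanded.shape == (B, N, S_q, S_k)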
-- Gitee From 273c746a02a2b666760e8529fa91b9d2ebadd08f Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 17 Jan 2025 05:11:43 +0000 Subject: [PATCH 12/19] add env --- .../built-in/foundation/open_sora_planv1_2/run.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh index fa22afdbad..3b18fe5614 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh @@ -1,3 +1,5 @@ +export TASK_QUEUE_ENABLE=2 +export HCCL_OP_EXPANSION_MODE="AIV" torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ inference.py \ --model_path /path/to/checkpoint-xxx/model_ema \ @@ -12,7 +14,7 @@ torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ --save_img_path "./sample_video_test" \ --fps 24 \ --guidance_scale 7.5 \ - --num_sampling_steps 100 \ + --num_sampling_steps 3 \ --tile_overlap_factor 0.125 \ --max_sequence_length 512 \ --dtype bf16 \ -- Gitee From 8d4de45ced85f0c475562f9e84e74cf5d4dd32df Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 17 Jan 2025 07:29:59 +0000 Subject: [PATCH 13/19] remove attention mask --- .../opensora/sample/pipeline_opensora_sp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py index a8905e3879..003ee0cea4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py @@ -765,12 +765,12 @@ class OpenSoraPipeline(DiffusionPipeline): prompt_attention_mask = prompt_attention_mask.unsqueeze(1) # b l -> b 1 l # prepare attention_mask. + # b c t h w -> b t h w + # Mask is need in sparse attention mode. 
+ # attention_mask = torch.ones_like(latent_model_input)[:, 0] + # if get_sequence_parallel_state(): + # attention_mask = attention_mask.repeat(1, world_size, 1, 1) attention_mask = None - if self.transformer.config.downsampler is not None: - # b c t h w -> b t h w - attention_mask = torch.ones_like(latent_model_input)[:, 0] - if get_sequence_parallel_state(): - attention_mask = attention_mask.repeat(1, world_size, 1, 1) # predict noise model_output noise_pred = self.transformer( -- Gitee From b086ebc027a4f5e1bd36fc6b0ffa9f76eea0c6b5 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 17 Jan 2025 08:55:55 +0000 Subject: [PATCH 14/19] run.sh 100 step --- .../MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh index 3b18fe5614..b3739a6a38 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/run.sh @@ -14,7 +14,7 @@ torchrun --nnodes=1 --nproc_per_node 8 --master_port 29503 \ --save_img_path "./sample_video_test" \ --fps 24 \ --guidance_scale 7.5 \ - --num_sampling_steps 3 \ + --num_sampling_steps 100 \ --tile_overlap_factor 0.125 \ --max_sequence_length 512 \ --dtype bf16 \ -- Gitee From 7be2e5fbf83af59a0577d93abd3aec6965d3f69a Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Wed, 22 Jan 2025 04:56:56 +0000 Subject: [PATCH 15/19] fix bug in world_size 1 inference --- .../open_sora_planv1_2/utils/parallel_mgr.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py index dc93029af4..3cedc7fe16 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py @@ -70,11 +70,15 @@ def get_sequence_parallel_state(): def get_sequence_parallel_world_size(): """Return world size for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 1 return get_sp_group().world_size def get_sequence_parallel_rank(): """Return my rank for the sequence parallel group.""" + if not get_sequence_parallel_state(): + return 0 return get_sp_group().rank_in_group # CFG @@ -84,14 +88,21 @@ def get_cfg_group() -> GroupCoordinator: ), "classifier_free_guidance parallel group is not initialized" return _CFG +def get_cfg_state(): + """Return state for the sequence parallel group.""" + return _CFG is not None def get_classifier_free_guidance_world_size(): """Return world size for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 1 return get_cfg_group().world_size def get_classifier_free_guidance_rank(): """Return my rank for the classifier_free_guidance parallel group.""" + if not get_cfg_state(): + return 0 return get_cfg_group().rank_in_group -- Gitee From a05976d67209361122a429947b8a31cca42b307b Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 24 Jan 2025 09:05:00 +0000 Subject: [PATCH 16/19] cleancode --- .../open_sora_planv1_2/dit_cache_search.py | 4 +- .../model/causal_vae/__init__.py | 15 ++-- .../model/causal_vae/modeling_causalvae.py | 38 ++++---- .../model/modeling_videobase.py | 7 +- .../causalvideovae/model/modules/__init__.py | 26 ++++-- .../causalvideovae/model/modules/attention.py | 14 
++- .../causalvideovae/model/modules/block.py | 1 + .../causalvideovae/model/modules/conv.py | 28 +++--- .../causalvideovae/model/modules/normalize.py | 20 +++-- .../causalvideovae/model/modules/ops.py | 6 +- .../model/modules/resnet_block.py | 6 +- .../model/utils/distrib_utils.py | 3 +- .../model/utils/module_utils.py | 2 + .../diffusion/opensora/modeling_opensora.py | 9 +- .../models/diffusion/opensora/modules.py | 7 +- .../models/diffusion/opensora/rope.py | 12 ++- .../opensora/sample/pipeline_opensora_sp.py | 16 ++-- .../open_sora_planv1_2/utils/comm.py | 19 ---- .../utils/group_coordinator.py | 48 ++-------- .../open_sora_planv1_2/utils/parallel_mgr.py | 87 +++++++++---------- .../open_sora_planv1_2/utils/utils.py | 2 +- 21 files changed, 175 insertions(+), 195 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py index 2f3c2c73de..b2541b6bba 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py @@ -108,8 +108,8 @@ def run_model_and_save_images(config: DitCacheSearcherConfig, pipeline): video_path, videos[i], fps=args.fps, quality=6, codec='libx264', output_params=['-threads', '20']) # highest quality is 10, lowest is 0 - except: - logger.error('Error when saving {}'.format(prompt)) + except Exception as e: + logger.error(f'Error when saving {prompt}, the error is {e}') if __name__ == "__main__": diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py index c1aeb800d7..7616fb093c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/__init__.py @@ -1,29 +1,24 @@ -from .modeling_causalvae import CausalVAEModel - from einops import rearrange from torch import nn +from .modeling_causalvae import CausalVAEModel + class CausalVAEModelWrapper(nn.Module): def __init__(self, model_path, subfolder=None, cache_dir=None, use_ema=False, **kwargs): super(CausalVAEModelWrapper, self).__init__() - # if os.path.exists(ckpt): - # self.vae = CausalVAEModel.load_from_checkpoint(ckpt) self.vae = CausalVAEModel.from_pretrained(model_path, subfolder=subfolder, cache_dir=cache_dir, **kwargs) if use_ema: self.vae.init_from_ema(model_path) self.vae = self.vae.ema + def encode(self, x): # b c t h w - # x = self.vae.encode(x).sample() x = self.vae.encode(x).sample().mul_(0.18215) return x + def decode(self, x): - # x = self.vae.decode(x) x = self.vae.decode(x / 0.18215) x = rearrange(x, 'b c t h w -> b t c h w').contiguous() return x def dtype(self): - return self.vae.dtype - # - # def device(self): - # return self.vae.device \ No newline at end of file + return self.vae.dtype \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py index 299cc071ae..a7fe2e64e5 100644 --- 
a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py @@ -6,6 +6,7 @@ import torch import torch_npu import torch.nn as nn from diffusers.configuration_utils import register_to_config +from diffusers.utils import logging from ..modules import Normalize from ..modules.ops import nonlinearity @@ -13,6 +14,8 @@ from ..modeling_videobase import VideoBaseAE from ..utils.distrib_utils import DiagonalGaussianDistribution from ..utils.module_utils import resolve_str_to_obj, Module +logger = logging.get_logger(__name__) + class Encoder(nn.Module): def __init__( @@ -44,9 +47,8 @@ class Encoder(nn.Module): double_z: bool = True, ) -> None: super().__init__() - assert len(resnet_blocks) == len(hidden_size_mult), print( - hidden_size_mult, resnet_blocks - ) + if not len(resnet_blocks) == len(hidden_size_mult): + logger.error(f"hidden_size_mult is not equal to resnet_blocks") # ---- Config ---- self.num_resolutions = len(hidden_size_mult) self.resolution = resolution @@ -262,7 +264,7 @@ class CausalVAEModel(VideoBaseAE): hidden_size: int = 128, z_channels: int = 4, hidden_size_mult: Tuple[int] = (1, 2, 4, 4), - attn_resolutions: Tuple[int] = [], + attn_resolutions: Tuple[int] = None, dropout: float = 0.0, resolution: int = 256, double_z: bool = True, @@ -316,14 +318,13 @@ class CausalVAEModel(VideoBaseAE): self.tile_sample_min_size_t = 33 self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1))) - # t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0] - # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1 self.tile_latent_min_size_t = 16 self.tile_overlap_factor = 0.125 self.use_tiling = False self.use_quant_layer = use_quant_layer - + if attn_resolutions is None: + ignore_keys = [] self.encoder = Encoder( z_channels=z_channels, hidden_size=hidden_size, @@ -374,11 +375,10 @@ class CausalVAEModel(VideoBaseAE): return [self.decoder] def encode(self, x): - if self.use_tiling and ( - x.shape[-1] > self.tile_sample_min_size - or x.shape[-2] > self.tile_sample_min_size - or x.shape[-3] > self.tile_sample_min_size_t - ): + shape_status = (x.shape[-1] > self.tile_sample_min_size + or x.shape[-2] > self.tile_sample_min_size + or x.shape[-3] > self.tile_sample_min_size_t) + if self.use_tiling and shape_status: return self.tiled_encode(x) h = self.encoder(x) if self.use_quant_layer: @@ -398,8 +398,8 @@ class CausalVAEModel(VideoBaseAE): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) + def forward(self, input_x, sample_posterior=True): + posterior = self.encode(input_x) if sample_posterior: z = posterior.sample() else: @@ -570,7 +570,9 @@ class CausalVAEModel(VideoBaseAE): def disable_tiling(self): self.enable_tiling(False) - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): + if ignore_keys is None: + ignore_keys = list() sd = torch.load(path, map_location="cpu") print("init from " + path) @@ -594,7 +596,5 @@ class CausalVAEModel(VideoBaseAE): del sd[k] miss, unexpected = self.load_state_dict(sd, strict=False) - assert len(miss) == 0, f"miss key: {miss}" - if len(unexpected) > 0: - for i in unexpected: - assert 'loss' in i, "unexpected key: {i}" + if not len(miss) == 0: + 
logger.error(f"miss key: {miss}") diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py index f7686a0f27..650e850e75 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modeling_videobase.py @@ -1,10 +1,11 @@ -import torch import os import json +import glob +from typing import Optional, Union + +import torch from diffusers.configuration_utils import ConfigMixin from diffusers.models.modeling_utils import ModelMixin -from typing import Optional, Union -import glob class VideoBaseAE(ModelMixin, ConfigMixin): diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py index 7ef41edd27..498d042c59 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/__init__.py @@ -1,6 +1,22 @@ from .block import Block -from .attention import * -from .conv import * -from .normalize import * -from .resnet_block import * -from .updownsample import * +from .attention import LinearAttention, LinAttnBlock, AttnBlock3D, AttnBlock3DFix, AttnBlock, TemporalAttnBlock +from .conv import Conv2d, CausalConv3d, CausalConv3d_GC +from .normalize import GroupNorm, ActNorm, Normalize +from .resnet_block import ResnetBlock2D, ResnetBlock3D, ResnetBlock3D_GC +from .updownsample import ( + Upsample, + Downsample, + SpatialDownsample2x, + SpatialUpsample2x_GC, + SpatialUpsample2x, + TimeDownsample2x, + TimeUpsample2x, + TimeDownsampleRes2x, + TimeUpsampleRes2x, + TimeDownsampleResAdv2x, + TimeUpsampleResAdv2x, + Spatial2xTime2x3DDownsample, + Spatial2x3DDownsample, + Spatial2x3DUpsample, + Spatial2xTime2x3DUpsample +) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py index ac4c96e3b0..f386b6ab93 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py @@ -1,12 +1,16 @@ +import torch import torch.nn as nn import torch_npu +from einops import rearrange +from diffusers.utils import logging from .normalize import Normalize from .conv import CausalConv3d -import torch -from einops import rearrange from .block import Block from .ops import video_to_image +logger = logging.get_logger(__name__) + + class LinearAttention(Block): def __init__(self, dim, heads=4, dim_head=32): super().__init__() @@ -75,6 +79,7 @@ class AttnBlock3D(Block): return x + h_ + class AttnBlock3DFix(nn.Module): """ Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172. 
@@ -122,7 +127,7 @@ class AttnBlock3DFix(nn.Module): # h_: (b*t c hw) -> (b t c h w) -> (b c t h w) h_ = h_.reshape(b, t, c, h, w) - h_ = h_.permute(0, 2, 1, 3 ,4) + h_ = h_.permute(0, 2, 1, 3, 4) h_ = self.proj_out(h_) @@ -217,7 +222,8 @@ class TemporalAttnBlock(Block): def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown" + if not attn_type in ["vanilla", "linear", "none", "vanilla3D"]: + logger.error(f"attn_type {attn_type} unknown") print(f"making attention of type '{attn_type}' with {in_channels} in_channels") print(attn_type) if attn_type == "vanilla": diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py index e93672d20b..94366413c1 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/block.py @@ -1,5 +1,6 @@ import torch.nn as nn + class Block(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py index 629715b535..4b27f0c506 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/conv.py @@ -1,13 +1,17 @@ -import torch.nn as nn from typing import Union, Tuple -import torch.nn.functional as F import torch import torch_npu +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from einops import rearrange +from diffusers.utils import logging from .block import Block from .ops import cast_tuple -from einops import rearrange from .ops import video_to_image -from torch.utils.checkpoint import checkpoint + +logger = logging.get_logger(__name__) + class Conv2d(nn.Conv2d): def __init__( @@ -41,7 +45,6 @@ class Conv2d(nn.Conv2d): @video_to_image def forward(self, x): return super().forward(x) - class CausalConv3d(nn.Module): @@ -65,17 +68,17 @@ class CausalConv3d(nn.Module): def _init_weights(self, init_method): ks = torch.tensor(self.kernel_size) if init_method == "avg": - assert ( - self.kernel_size[1] == 1 and self.kernel_size[2] == 1 - ), "only support temporal up/down sample" - assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out" + if not (self.kernel_size[1] == 1 and self.kernel_size[2] == 1): + logger.error("only support temporal up/down sample") + if not self.chan_in == self.chan_out: + logger.error("chan_in must be equal to chan_out") weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)) eyes = torch.concat( [ - torch.eye(self.chan_in).unsqueeze(-1) * 1/3, - torch.eye(self.chan_in).unsqueeze(-1) * 1/3, - torch.eye(self.chan_in).unsqueeze(-1) * 1/3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, + torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3, ], dim=-1, ) @@ -108,6 +111,7 @@ class CausalConv3d(nn.Module): class 
CausalConv3d_GC(CausalConv3d): def __init__(self, chan_in, chan_out, kernel_size: Union[int, Tuple[int]], init_method="random", **kwargs): super().__init__(chan_in, chan_out, kernel_size, init_method, **kwargs) + def forward(self, x): # 1 + 16 16 as video, 1 as image first_frame_pad = x[:, :, :1, :, :].repeat( diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py index 7c8c05f0fa..cdb9078a5e 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py @@ -1,25 +1,33 @@ import torch import torch.nn as nn +from diffusers.utils import logging from .block import Block +logger = logging.get_logger(__name__) + + class GroupNorm(Block): def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.norm = torch.nn.GroupNorm( num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True ) + def forward(self, x): return self.norm(x) + def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm( num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True ) + class ActNorm(nn.Module): def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): - assert affine + if not affine: + logger.error("not supported") super().__init__() self.logdet = logdet self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) @@ -28,9 +36,9 @@ class ActNorm(nn.Module): self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) - def initialize(self, input): + def initialize(self, input_x): with torch.no_grad(): - flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + flatten = input_x.permute(1, 0, 2, 3).contiguous().view(input_x.shape[1], -1) mean = ( flatten.mean(1) .unsqueeze(1) @@ -53,7 +61,7 @@ class ActNorm(nn.Module): if reverse: return self.reverse(input) if len(input.shape) == 2: - input = input[:,:,None,None] + input = input[:, :, None, None] squeeze = True else: squeeze = False @@ -71,7 +79,7 @@ class ActNorm(nn.Module): if self.logdet: log_abs = torch.log(torch.abs(self.scale)) - logdet = height*width*torch.sum(log_abs) + logdet = height * width * torch.sum(log_abs) logdet = logdet * torch.ones(input.shape[0]).to(input) return h, logdet @@ -89,7 +97,7 @@ class ActNorm(nn.Module): self.initialized.fill_(1) if len(output.shape) == 2: - output = output[:,:,None,None] + output = output[:, :, None, None] squeeze = True else: squeeze = False diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py index e298cceebf..8ba2072c0c 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/ops.py @@ -2,6 +2,9 @@ import torch import torch_npu import torch.nn as nn from einops import rearrange +from diffusers.utils import logging + +logger = logging.get_logger(__name__) def video_to_image(func): @@ -30,7 +33,8 @@ def shift_dim(x, src_dim=-1, 
dest_dim=-1, make_contiguous=True): src_dim = n_dims + src_dim if dest_dim < 0: dest_dim = n_dims + dest_dim - assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + if not (0 <= src_dim < n_dims and 0 <= dest_dim < n_dims): + logger.error("dim error") dims = list(range(n_dims)) del dims[src_dim] permutation = [] diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py index bd225a81af..40d4bff8bf 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/resnet_block.py @@ -1,12 +1,13 @@ import torch +import torch_npu import torch.nn as nn +from torch.utils.checkpoint import checkpoint from einops import rearrange, pack, unpack from .normalize import Normalize from .ops import nonlinearity, video_to_image from .conv import CausalConv3d, CausalConv3d_GC from .block import Block -from torch.utils.checkpoint import checkpoint -import torch_npu + class ResnetBlock2D(Block): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, @@ -53,6 +54,7 @@ class ResnetBlock2D(Block): x = x + h return x + class ResnetBlock3D(Block): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): super().__init__() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py index 760c0673fe..4009f70693 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py @@ -1,6 +1,7 @@ import torch import numpy as np + class DiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): self.parameters = parameters @@ -30,7 +31,7 @@ class DiagonalGaussianDistribution(object): + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3]) - def nll(self, sample, dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: return torch.Tensor([0.]) logtwopi = np.log(2.0 * np.pi) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py index 3f843e4f3d..8c92fd359d 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/module_utils.py @@ -3,6 +3,7 @@ import importlib Module = str MODULES_BASE = "opensora.models.causalvideovae.model.modules." 
+ def resolve_str_to_obj(str_val, append=True): if append: str_val = MODULES_BASE + str_val @@ -10,6 +11,7 @@ def resolve_str_to_obj(str_val, append=True): module = importlib.import_module(module_name) return getattr(module, class_name) + def create_instance(module_class_str: str, **kwargs): module_name, class_name = module_class_str.rsplit('.', 1) module = importlib.import_module(module_name) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py index e65e774fc7..3b5b9116fe 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modeling_opensora.py @@ -122,7 +122,8 @@ class OpenSoraT2V(ModelMixin, ConfigMixin): # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration - assert in_channels is not None and patch_size is not None + if not (in_channels is not None and patch_size is not None): + logger.error("in_channels and patch_size should not be None") if norm_type == "layer_norm" and num_embeds_ada_norm is not None: deprecation_message = ( @@ -140,8 +141,10 @@ class OpenSoraT2V(ModelMixin, ConfigMixin): self._init_patched_inputs(norm_type=norm_type) def _init_patched_inputs(self, norm_type): - assert self.config.sample_size_t is not None, "OpenSoraT2V over patched input must provide sample_size_t" - assert self.config.sample_size is not None, "OpenSoraT2V over patched input must provide sample_size" + if self.config.sample_size_t is None: + logger.error("OpenSoraT2V over patched input must provide sample_size_t") + if self.config.sample_size is None: + logger.error("OpenSoraT2V over patched input must provide sample_size") self.num_frames = self.config.sample_size_t self.config.sample_size = to_2tuple(self.config.sample_size) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index eacfce8ba0..8ffd5e7b85 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -172,7 +172,6 @@ class PatchEmbed2D(nn.Module): temp_pos_embed = get_1d_sincos_pos_embed(embed_dim, self.num_frames, base_size=self.base_size_t, interpolation_scale=self.interpolation_scale_t) self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) - # self.temp_embed_gate = nn.Parameter(torch.tensor([0.0])) def forward(self, latent, num_frames): b, _, _, _, _ = latent.shape @@ -230,7 +229,6 @@ class PatchEmbed2D(nn.Module): video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] if self.use_abs_pos: - # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() temp_pos_embed = temp_pos_embed.unsqueeze(2) video_latent = (video_latent + temp_pos_embed).to( video_latent.dtype) if video_latent is not None and video_latent.numel() 
> 0 else None @@ -378,7 +376,8 @@ class OverlapPatchEmbed2D(nn.Module): use_abs_pos=True, ): super().__init__() - assert patch_size_t == 1 + if not patch_size_t == 1: + logger.error("patch_size_t is not 1") self.use_abs_pos = use_abs_pos self.flatten = flatten self.layer_norm = layer_norm @@ -452,7 +451,6 @@ class OverlapPatchEmbed2D(nn.Module): video_latent, image_latent = latent[:, :num_frames], latent[:, num_frames:] if self.use_abs_pos: - # temp_pos_embed = temp_pos_embed.unsqueeze(2) * self.temp_embed_gate.tanh() temp_pos_embed = temp_pos_embed.unsqueeze(2) video_latent = (video_latent + temp_pos_embed).to( video_latent.dtype) if video_latent is not None and video_latent.numel() > 0 else None @@ -831,7 +829,6 @@ class FeedForward_Conv2d(nn.Module): self.project_out = nn.Linear(hidden_features, dim, bias=bias) def forward(self, x, t, h, w): - # import ipdb;ipdb.set_trace() x = self.project_in(x) x = rearrange(x, 'b (t h w) d -> (b t) d h w', t=t, h=h, w=w) x = F.gelu(x) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py index da862ede22..ca2865f65a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py @@ -1,8 +1,11 @@ import torch import torch_npu +from diffusers.utils import logging from utils.parallel_mgr import get_sequence_parallel_state +logger = logging.get_logger(__name__) + class PositionGetter3D(object): """ return positions of patches """ @@ -52,7 +55,8 @@ class RoPE3D(torch.nn.Module): return self.cache[D, seq_len, device, dtype] def apply_rope1d(self, tokens, pos1d, cos, sin): - assert pos1d.ndim == 2 + if not pos1d.ndim == 2: + logger.error("only support 2d pos") # for (batch_size x ntokens x nheads x dim) if (pos1d, cos) not in self.cache or (pos1d, sin) not in self.cache: _cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] @@ -72,10 +76,12 @@ class RoPE3D(torch.nn.Module): output: * tokens after appplying RoPE3D (batch_size x nheads x ntokens x x dim) """ - assert tokens.size(3) % 3 == 0, "number of dimensions should be a multiple of three" + if not tokens.size(3) % 3 == 0: + logger.error("number of dimensions should be a multiple of three") D = tokens.size(3) // 3 poses, max_poses = positions - assert len(poses) == 3 and poses[0].ndim == 2 # Batch, Seq, 3 + if not len(poses) == 3 and poses[0].ndim == 2: + logger.error("only support Batch, Seq, 3 shape") cos_t, sin_t = self.get_cos_sin(D, max_poses[0] + 1, tokens.device, tokens.dtype, self.interpolation_scale_t) cos_y, sin_y = self.get_cos_sin(D, max_poses[1] + 1, tokens.device, tokens.dtype, self.interpolation_scale_h) cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py index 003ee0cea4..d62dc6023f 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/sample/pipeline_opensora_sp.py @@ -1,22 +1,18 @@ import html import inspect import math -from tqdm import tqdm import re import 
urllib.parse as ul from typing import Callable, List, Optional, Tuple, Union +from tqdm import tqdm import torch import torch_npu from diffusers.models import AutoencoderKL, Transformer2DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput from diffusers.schedulers import EulerAncestralDiscreteScheduler -from diffusers.utils import ( - BACKENDS_MAPPING, - is_bs4_available, - is_ftfy_available, - logging, -) +from diffusers.utils import BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, logging + from diffusers.utils.torch_utils import randn_tensor from einops import rearrange from mindiesd.pipeline.sampling_optm import SamplingOptm @@ -294,9 +290,8 @@ class OpenSoraPipeline(DiffusionPipeline): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): + is_callback_steps_valid = (callback_steps is None) or (not isinstance(callback_steps, int) or callback_steps <= 0) + if is_callback_steps_valid: raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." @@ -408,7 +403,6 @@ class OpenSoraPipeline(DiffusionPipeline): caption = re.sub(r"[\u3300-\u33ff]+", "", caption) caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) - # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py index 33f79ad1e5..8d2fffb3c1 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/comm.py @@ -1,22 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - import torch import torch.distributed as dist diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py index 0adaccdd1d..bb1dbedf81 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/group_coordinator.py @@ -33,10 +33,11 @@ def _split_tensor_dict( metadata_list: List[Tuple[str, Any]] = [] tensor_list = [] for key, value in tensor_dict.items(): - assert "%" not in key, ( - "Avoid having '%' in key " - "as it is used as a separator for nested entries." 
- ) + if "%" in key: + logger.error( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." + ) if isinstance(value, torch.Tensor): # Note: we cannot use `value.device` here, # because it contains not only the device type but also the device @@ -123,9 +124,6 @@ class GroupCoordinator: self.device_group = device_group self.cpu_group = cpu_group - assert self.cpu_group is not None - assert self.device_group is not None - if torch.npu.is_available(): self.device = torch.device(f"npu:{local_rank}") else: @@ -213,9 +211,6 @@ class GroupCoordinator: # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -259,9 +254,6 @@ class GroupCoordinator: # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -284,7 +276,6 @@ class GroupCoordinator: """Broadcast the input tensor. NOTE: `src` is the local rank of the source rank. """ - assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: @@ -299,13 +290,11 @@ class GroupCoordinator: """Broadcast the input object. NOTE: `src` is the local rank of the source rank. """ - assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: return obj if self.shm_broadcaster is not None: - assert src == 0, "Shared memory broadcaster only supports src=0" return self.shm_broadcaster.broadcast_object(obj) if self.rank_in_group == src: torch.distributed.broadcast_object_list( @@ -325,7 +314,6 @@ class GroupCoordinator: """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ - assert src < self.world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. if self.world_size == 1: @@ -340,13 +328,6 @@ class GroupCoordinator: """Send the input object list to the destination rank.""" """NOTE: `dst` is the local rank of the destination rank.""" - assert dst < self.world_size, f"Invalid dst rank ({dst})" - - assert dst != self.rank, ( - "Invalid destination rank. Destination rank is the same " - "as the current rank." - ) - # Serialize object to tensor and get the size as well object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) @@ -367,12 +348,6 @@ class GroupCoordinator: """Receive the input object list from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" - assert src < self.world_size, f"Invalid src rank ({src})" - - assert ( - src != self.rank - ), "Invalid source rank. Source rank is the same as the current rank." - size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size @@ -391,10 +366,6 @@ class GroupCoordinator: object_tensor, src=self.ranks[src], group=self.cpu_group ) - assert ( - rank_object == rank_size - ), "Received object sender rank does not match the size sender rank." 
- obj = pickle.loads(object_tensor.numpy().tobytes()) return obj @@ -415,15 +386,11 @@ class GroupCoordinator: group = self.device_group metadata_group = self.cpu_group - assert src < self.world_size, f"Invalid src rank ({src})" src = self.ranks[src] rank = self.rank if rank == src: metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, dict - ), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `broadcast_object_list` has serialization & deserialization, @@ -496,12 +463,8 @@ class GroupCoordinator: if dst is None: dst = self.group_next_rank - assert dst < self.world_size, f"Invalid dst rank ({dst})" metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, dict - ), f"Expecting a dictionary, got {type(tensor_dict)}" metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, @@ -536,7 +499,6 @@ class GroupCoordinator: if src is None: src = self.group_prev_rank - assert src < self.world_size, f"Invalid src rank ({src})" recv_metadata_list = self.recv_object(src=src) tensor_dict: Dict[str, Any] = {} diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py index 3cedc7fe16..24278e695b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py @@ -1,32 +1,12 @@ -#!/usr/bin/env python -# coding=utf-8 -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- import os from typing import List, Optional from dataclasses import dataclass import torch.distributed as dist import torch_npu +from diffusers.utils import logging from .utils import RankGenerator, generate_masked_orthogonal_rank_groups from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinator -from diffusers.utils import logging - logger = logging.get_logger(__name__) _WORLD: Optional[GroupCoordinator] = None @@ -45,29 +25,34 @@ class ParallelConfig: self.cfg_degree = 2 else: self.cfg_degree = 1 - assert self.sp_degree * self.cfg_degree <= self.world_size, ( - "sp_degree * cfg_degree must be less than or equal to " - "world_size because of classifier free guidance" - ) - assert ( - self.world_size % (self.sp_degree * self.cfg_degree) == 0 - ), "world_size must be divisible by sp_degree * cfg_degree" + if not self.sp_degree * self.cfg_degree <= self.world_size: + logger.error( + "sp_degree * cfg_degree must be less than or equal to " + "world_size because of classifier free guidance" + ) + if not (self.world_size % (self.sp_degree * self.cfg_degree) == 0): + logger.error("world_size must be divisible by sp_degree * cfg_degree") # * QUERY def get_world_group() -> GroupCoordinator: - assert _WORLD is not None, "world group is not initialized" + if _WORLD is None: + logger.error("world group is not initialized") return _WORLD + # SP def get_sp_group() -> SequenceParallelGroupCoordinator: - assert _SP is not None, "pipeline model parallel group is not initialized" + if _SP is None: + logger.error("pipeline model parallel group is not initialized") return _SP + def get_sequence_parallel_state(): """Return state for the sequence parallel group.""" return _SP is not None + def get_sequence_parallel_world_size(): """Return world size for the sequence parallel group.""" if not get_sequence_parallel_state(): @@ -81,17 +66,19 @@ def get_sequence_parallel_rank(): return 0 return get_sp_group().rank_in_group + # CFG def get_cfg_group() -> GroupCoordinator: - assert ( - _CFG is not None - ), "classifier_free_guidance parallel group is not initialized" + if _CFG is None: + logger.error("classifier_free_guidance parallel group is not initialized") return _CFG + def get_cfg_state(): """Return state for the sequence parallel group.""" return _CFG is not None + def get_classifier_free_guidance_world_size(): """Return world size for the classifier_free_guidance parallel group.""" if not get_cfg_state(): @@ -132,10 +119,11 @@ def init_distributed_environment( backend, ) if not dist.is_initialized(): - assert distributed_init_method is not None, ( - "distributed_init_method must be provided when initializing " - "distributed environment" - ) + if distributed_init_method is None: + logger.error( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) # this backend is used for WORLD dist.init_process_group( backend=backend, @@ -159,9 +147,8 @@ def init_distributed_environment( ranks = list(range(dist.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) else: - assert ( - _WORLD.world_size == dist.get_world_size() - ), "world group already initialized with a different world size" + if not _WORLD.world_size == dist.get_world_size(): + logger.error("world group already initialized with a different world size") def model_parallel_is_initialized(): @@ -171,6 +158,7 @@ def model_parallel_is_initialized(): and _SP is not None ) + def init_model_parallel_group( group_ranks: List[List[int]], local_rank: int, @@ -178,10 +166,11 @@ def 
init_model_parallel_group( parallel_mode: str, **kwargs, ) -> GroupCoordinator: - assert parallel_mode in [ + if parallel_mode not in [ "sequence", "classifier_free_guidance", - ], f"parallel_mode {parallel_mode} is not supported" + ]: + logger.error(f"parallel_mode {parallel_mode} is not supported") if parallel_mode == "sequence": return SequenceParallelGroupCoordinator( group_ranks=group_ranks, @@ -196,6 +185,7 @@ def init_model_parallel_group( torch_distributed_backend=backend, ) + def initialize_model_parallel( classifier_free_guidance_degree: int = 1, sequence_parallel_degree: int = 1, @@ -210,7 +200,8 @@ def initialize_model_parallel( backend: distributed backend of pytorch collective comm. """ # Get world size and rank. Ensure some consistencies. - assert dist.is_initialized() + if not dist.is_initialized(): + logger.error("dist is not initialized") world_size: int = dist.get_world_size() backend = backend or dist.get_backend(get_world_group().device_group) @@ -233,7 +224,8 @@ def initialize_model_parallel( ) global _CFG - assert _CFG is None, "classifier_free_guidance group is already initialized" + if _CFG is not None: + logger.error("classifier_free_guidance group is already initialized") _CFG = init_model_parallel_group( group_ranks=rank_generator.get_ranks("cfg"), local_rank=get_world_group().local_rank, @@ -242,7 +234,8 @@ def initialize_model_parallel( ) global _SP - assert _SP is None, "sequence parallel group is already initialized" + if _SP is not None: + logger.error("sequence parallel group is already initialized") _SP = init_model_parallel_group( group_ranks=rank_generator.get_ranks("sp"), local_rank=get_world_group().local_rank, @@ -250,6 +243,7 @@ def initialize_model_parallel( parallel_mode="sequence", ) + def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _CFG @@ -262,6 +256,7 @@ def destroy_model_parallel(): _SP.destroy() _SP = None + def destroy_distributed_environment(): global _WORLD if _WORLD: @@ -270,6 +265,7 @@ def destroy_distributed_environment(): if dist.is_initialized(): dist.destroy_process_group() + def init_parallel_env(parallel_config: ParallelConfig): if not model_parallel_is_initialized(): logger.warning("Model parallel is not initialized, initializing...") @@ -280,6 +276,7 @@ def init_parallel_env(parallel_config: ParallelConfig): sequence_parallel_degree=parallel_config.sp_degree, ) + def finalize_parallel_env(): if model_parallel_is_initialized(): destroy_model_parallel() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py index ca1cae0436..4ec2a04d74 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py @@ -142,6 +142,6 @@ class RankGenerator(object): ) if self.rank_offset > 0: for rank_group in ranks: - for i in range(len(rank_group)): + for i, _ in enumerate(rank_group): rank_group[i] += self.rank_offset return ranks -- Gitee From 9abec275f8a342b3a0b56a0f1ddb776ab20e26eb Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Fri, 24 Jan 2025 09:26:49 +0000 Subject: [PATCH 17/19] cleancode --- .../model/causal_vae/modeling_causalvae.py | 11 ++++++----- .../causalvideovae/model/modules/attention.py | 2 +- .../causalvideovae/model/modules/normalize.py | 16 ++++++++-------- .../causalvideovae/model/utils/distrib_utils.py | 4 +++- .../foundation/open_sora_planv1_2/utils/utils.py | 8 ++++++-- 5 files 
changed, 24 insertions(+), 17 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py index a7fe2e64e5..d588fa63ba 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/causal_vae/modeling_causalvae.py @@ -387,11 +387,12 @@ class CausalVAEModel(VideoBaseAE): return posterior def decode(self, z): - if self.use_tiling and ( - z.shape[-1] > self.tile_latent_min_size - or z.shape[-2] > self.tile_latent_min_size - or z.shape[-3] > self.tile_latent_min_size_t - ): + shape_status = ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + or z.shape[-3] > self.tile_latent_min_size_t + ) + if self.use_tiling and shape_status: return self.tiled_decode(z) if self.use_quant_layer: z = self.post_quant_conv(z) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py index f386b6ab93..ad17fcbeb4 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/attention.py @@ -222,7 +222,7 @@ class TemporalAttnBlock(Block): def make_attn(in_channels, attn_type="vanilla"): - if not attn_type in ["vanilla", "linear", "none", "vanilla3D"]: + if attn_type not in ["vanilla", "linear", "none", "vanilla3D"]: logger.error(f"attn_type {attn_type} unknown") print(f"making attention of type '{attn_type}' with {in_channels} in_channels") print(attn_type) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py index cdb9078a5e..623b722f43 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/modules/normalize.py @@ -57,22 +57,22 @@ class ActNorm(nn.Module): self.loc.data.copy_(-mean) self.scale.data.copy_(1 / (std + 1e-6)) - def forward(self, input, reverse=False): + def forward(self, input_x, reverse=False): if reverse: - return self.reverse(input) - if len(input.shape) == 2: - input = input[:, :, None, None] + return self.reverse(input_x) + if len(input_x.shape) == 2: + input_x = input_x[:, :, None, None] squeeze = True else: squeeze = False - _, _, height, width = input.shape + _, _, height, width = input_x.shape if self.training and self.initialized.item() == 0: - self.initialize(input) + self.initialize(input_x) self.initialized.fill_(1) - h = self.scale * (input + self.loc) + h = self.scale * (input_x + self.loc) if squeeze: h = h.squeeze(-1).squeeze(-1) @@ -80,7 +80,7 @@ class ActNorm(nn.Module): if self.logdet: log_abs = torch.log(torch.abs(self.scale)) logdet = height * width * torch.sum(log_abs) - logdet = logdet * 
torch.ones(input.shape[0]).to(input) + logdet = logdet * torch.ones(input_x.shape[0]).to(input_x) return h, logdet return h diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py index 4009f70693..18fdf7d1ea 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/causalvideovae/model/utils/distrib_utils.py @@ -31,9 +31,11 @@ class DiagonalGaussianDistribution(object): + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3]) - def nll(self, sample, dims=[1, 2, 3]): + def nll(self, sample, dims=None): if self.deterministic: return torch.Tensor([0.]) + if dims is None: + dims = [1, 2, 3] logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py index 4ec2a04d74..bffe388d48 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py @@ -1,4 +1,7 @@ from typing import List +from diffusers.utils import logging + +logger = logging.get_logger(__name__) def generate_masked_orthogonal_rank_groups( @@ -47,9 +50,10 @@ def generate_masked_orthogonal_rank_groups( idx = [(index // d) % s for s, d in zip(shape, stride)] # stride is a prefix_product result. And the value of stride[-1] # is not used. 
- assert ( + if not ( sum([x * y for x, y in zip(idx, stride[:-1])]) == index - ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx) + ): + logger.error("idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)) return idx masked_shape = [s for s, m in zip(parallel_size, mask) if m] -- Gitee From 832babdb7840a57f5a916741e2ae2c9c0bc8e45b Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Tue, 25 Feb 2025 11:23:00 +0000 Subject: [PATCH 18/19] support 24 card and 32 card parallel --- .../models/diffusion/opensora/modules.py | 54 ++++++++++++++++++- .../models/diffusion/opensora/rope.py | 23 +++++--- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index 8ffd5e7b85..fc08aec651 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -17,7 +17,12 @@ from einops import rearrange from torch import nn from utils.comm import all_to_all_sbh -from utils.parallel_mgr import get_sequence_parallel_state, get_sequence_parallel_world_size, get_sequence_parallel_rank +from utils.parallel_mgr import ( + get_sp_group, + get_sequence_parallel_state, + get_sequence_parallel_world_size, + get_sequence_parallel_rank +) from .rope import PositionGetter3D, RoPE3D logger = logging.get_logger(__name__) @@ -669,7 +674,52 @@ class AttnProcessor2_0: if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if get_sequence_parallel_state(): + sp_size = get_sequence_parallel_world_size() + if get_sequence_parallel_state() and sp_size > 8: + query = attn.to_q(hidden_states) + inner_dim = query.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(-1, batch_size, inner_dim) # [s // sp, b, h * d] + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + key = key.view(-1, batch_size, inner_dim) # [s // sp, b, h * d] + + value = attn.to_v(encoder_hidden_states) + value = value.view(-1, batch_size, inner_dim) # [s // sp, b, h * d] + + if self.use_rope: + query = query.view(-1, batch_size, attn.heads, head_dim) + key = key.view(-1, batch_size, attn.heads, head_dim) + # require the shape of (batch_size x nheads x ntokens x dim) + pos_thw = self.position_getter(batch_size, t=frame * sp_size, h=height, w=width, + device=query.device) + query = self.rope(query, pos_thw, rope_split=True) + key = self.rope(key, pos_thw, rope_split=True) + query = query.view(-1, batch_size, attn.heads * head_dim) + key = key.view(-1, batch_size, attn.heads * head_dim) + + key_value = torch.concat([key, value], dim=1) + key_value_full = get_sp_group().all_gather(key_value, dim=0, separate_tensors=False) + key_full, value_full = key_value_full.chunk(2, dim=1) + + if attention_mask is not None and attention_mask.shape[-2] == 1: + attention_mask = attention_mask.repeat(1, 1, query.shape[0], 1) + + query = rearrange(query, "s b (n d) -> b n s d", d=head_dim) # BNSD + key_full = rearrange(key_full, "s b (n d) -> b n s d", d=head_dim) # BNSD + value_full = rearrange(value_full, "s b (n d) -> b 
n s d", d=head_dim) # BNSD + hidden_states = torch_npu.npu_fused_infer_attention_score(query, key_full, value_full, + atten_mask=attention_mask, + input_layout="BNSD", + scale=1 / math.sqrt(head_dim), + num_heads=attn.heads)[0] + hidden_states = rearrange(hidden_states, "b n s d -> s b (n d)") + elif get_sequence_parallel_state() and sp_size <= 8: query = attn.to_q(hidden_states) inner_dim = query.shape[-1] head_dim = inner_dim // attn.heads diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py index ca2865f65a..601ceba290 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py @@ -2,7 +2,11 @@ import torch import torch_npu from diffusers.utils import logging -from utils.parallel_mgr import get_sequence_parallel_state +from utils.parallel_mgr import( + get_sequence_parallel_state, + get_sequence_parallel_world_size, + get_sequence_parallel_rank +) logger = logging.get_logger(__name__) @@ -54,7 +58,7 @@ class RoPE3D(torch.nn.Module): self.cache[D, seq_len, device, dtype] = (cos, sin) return self.cache[D, seq_len, device, dtype] - def apply_rope1d(self, tokens, pos1d, cos, sin): + def apply_rope1d(self, tokens, pos1d, cos, sin, rope_split=False): if not pos1d.ndim == 2: logger.error("only support 2d pos") # for (batch_size x ntokens x nheads x dim) @@ -66,9 +70,14 @@ class RoPE3D(torch.nn.Module): else: _cos = self.cache[pos1d, cos] _sin = self.cache[pos1d, sin] - return torch_npu.npu_rotary_mul(tokens, _cos, _sin) + if rope_split: + rank = get_sequence_parallel_rank() + seqlen = _cos.shape[0] // get_sequence_parallel_world_size() + return torch_npu.npu_rotary_mul(tokens, _cos[seqlen * rank: seqlen * (rank + 1)], _sin[seqlen * rank: seqlen * (rank + 1)]) + else: + return torch_npu.npu_rotary_mul(tokens, _cos, _sin) - def forward(self, tokens, positions): + def forward(self, tokens, positions, rope_split=False): """ input: * tokens: batch_size x nheads x ntokens x dim @@ -87,8 +96,8 @@ class RoPE3D(torch.nn.Module): cos_x, sin_x = self.get_cos_sin(D, max_poses[2] + 1, tokens.device, tokens.dtype, self.interpolation_scale_w) # split features into three along the feature dimension, and apply rope1d on each half t, y, x = tokens.chunk(3, dim=-1) - t = self.apply_rope1d(t, poses[0], cos_t, sin_t) - y = self.apply_rope1d(y, poses[1], cos_y, sin_y) - x = self.apply_rope1d(x, poses[2], cos_x, sin_x) + t = self.apply_rope1d(t, poses[0], cos_t, sin_t, rope_split) + y = self.apply_rope1d(y, poses[1], cos_y, sin_y, rope_split) + x = self.apply_rope1d(x, poses[2], cos_x, sin_x, rope_split) tokens = torch.cat((t, y, x), dim=-1) return tokens -- Gitee From 1cbede5bcad681d73fb1e0de154c2bd9e6ba0f35 Mon Sep 17 00:00:00 2001 From: pu-zhe Date: Wed, 26 Feb 2025 13:27:05 +0000 Subject: [PATCH 19/19] add parallel degree to avoid frame padding --- .../open_sora_planv1_2/dit_cache_search.py | 7 +- .../open_sora_planv1_2/inference.py | 9 ++- .../models/diffusion/opensora/modules.py | 68 +++++-------------- .../models/diffusion/opensora/rope.py | 8 +-- .../open_sora_planv1_2/utils/parallel_mgr.py | 63 +++++++++++++++-- .../open_sora_planv1_2/utils/utils.py | 5 +- 6 files changed, 94 insertions(+), 66 deletions(-) diff --git 
a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py index b2541b6bba..3986cb0f7b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/dit_cache_search.py @@ -11,7 +11,7 @@ from transformers import T5Tokenizer, MT5EncoderModel from opensora.models.causalvideovae import ae_stride_config, CausalVAEModelWrapper from opensora.models.diffusion.opensora.modeling_opensora import OpenSoraT2V from opensora.sample.pipeline_opensora_sp import OpenSoraPipeline -from utils.parallel_mgr import init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank +from utils.parallel_mgr import ParallelConfig, init_parallel_env, finalize_parallel_env, get_sequence_parallel_rank from mindiesd.layers.cache_mgr import CacheManager, DitCacheConfig from msmodelslim.pytorch.multimodal import DitCacheSearcherConfig, DitCacheSearcher @@ -142,7 +142,10 @@ if __name__ == "__main__": local_rank = int(os.getenv('RANK', 0)) world_size = int(os.getenv('WORLD_SIZE', 1)) if world_size > 0: - init_parallel_env(enable_sequence_parallelism=True) + if world_size > 16: + logger.error("World size should less than 16 in DIT cache search.") + parallel_config = ParallelConfig(sp_degree=world_size, use_cfg_parallel=False, world_size=world_size) + init_parallel_env(parallel_config) vae = CausalVAEModelWrapper(args.ae_path, dtype=torch.float16).to("npu") vae.vae.enable_tiling() diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py index 10abaaf06a..cf7dbf58a9 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/inference.py @@ -130,8 +130,15 @@ if __name__ == "__main__": world_size = int(os.getenv('WORLD_SIZE', 1)) if world_size > 1: + tp_degree = 1 sp_degree = world_size // 2 if args.use_cfg_parallel else world_size - parallel_config = ParallelConfig(sp_degree=sp_degree, use_cfg_parallel=args.use_cfg_parallel, world_size=world_size) + if sp_degree > 12: + tp_degree = 2 + sp_degree //= 2 + parallel_config = ParallelConfig(tp_degree=tp_degree, + sp_degree=sp_degree, + use_cfg_parallel=args.use_cfg_parallel, + world_size=world_size) init_parallel_env(parallel_config) vae = CausalVAEModelWrapper(args.ae_path, dtype=torch.float16).to("npu") diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py index fc08aec651..ebf6389d5a 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/modules.py @@ -21,7 +21,10 @@ from utils.parallel_mgr import ( get_sp_group, get_sequence_parallel_state, get_sequence_parallel_world_size, - get_sequence_parallel_rank + get_sequence_parallel_rank, + get_tp_group, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, ) from .rope import PositionGetter3D, RoPE3D @@ -675,57 +678,12 @@ class AttnProcessor2_0: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) sp_size = get_sequence_parallel_world_size() - if 
get_sequence_parallel_state() and sp_size > 8: - query = attn.to_q(hidden_states) - inner_dim = query.shape[-1] - head_dim = inner_dim // attn.heads - query = query.view(-1, batch_size, inner_dim) # [s // sp, b, h * d] - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - key = key.view(-1, batch_size, inner_dim) # [s // sp, b, h * d] - - value = attn.to_v(encoder_hidden_states) - value = value.view(-1, batch_size, inner_dim) # [s // sp, b, h * d] - - if self.use_rope: - query = query.view(-1, batch_size, attn.heads, head_dim) - key = key.view(-1, batch_size, attn.heads, head_dim) - # require the shape of (batch_size x nheads x ntokens x dim) - pos_thw = self.position_getter(batch_size, t=frame * sp_size, h=height, w=width, - device=query.device) - query = self.rope(query, pos_thw, rope_split=True) - key = self.rope(key, pos_thw, rope_split=True) - query = query.view(-1, batch_size, attn.heads * head_dim) - key = key.view(-1, batch_size, attn.heads * head_dim) - - key_value = torch.concat([key, value], dim=1) - key_value_full = get_sp_group().all_gather(key_value, dim=0, separate_tensors=False) - key_full, value_full = key_value_full.chunk(2, dim=1) - - if attention_mask is not None and attention_mask.shape[-2] == 1: - attention_mask = attention_mask.repeat(1, 1, query.shape[0], 1) - - query = rearrange(query, "s b (n d) -> b n s d", d=head_dim) # BNSD - key_full = rearrange(key_full, "s b (n d) -> b n s d", d=head_dim) # BNSD - value_full = rearrange(value_full, "s b (n d) -> b n s d", d=head_dim) # BNSD - hidden_states = torch_npu.npu_fused_infer_attention_score(query, key_full, value_full, - atten_mask=attention_mask, - input_layout="BNSD", - scale=1 / math.sqrt(head_dim), - num_heads=attn.heads)[0] - hidden_states = rearrange(hidden_states, "b n s d -> s b (n d)") - elif get_sequence_parallel_state() and sp_size <= 8: + tp_size = get_tensor_model_parallel_world_size() + if get_sequence_parallel_state(): query = attn.to_q(hidden_states) inner_dim = query.shape[-1] head_dim = inner_dim // attn.heads - h_size = attn.heads * head_dim - sp_size = get_sequence_parallel_world_size() - h_size_sp = h_size // sp_size + h_size_sp = inner_dim // sp_size query = query.view(-1, attn.heads, head_dim) # [s // sp, b, h * d] -> [s // sp * b, h, d] # apply all_to_all to gather sequence and split attention heads [s // sp * b, h, d] -> [s * b, h // sp, d] @@ -746,6 +704,8 @@ class AttnProcessor2_0: req_q.wait() query = query.view(-1, batch_size, h_size_sp) + if tp_size > 1 and self.use_rope: + query = query.chunk(2, dim=0)[get_tensor_model_parallel_rank()] req_k.wait() key = key.view(-1, batch_size, h_size_sp) req_v.wait() @@ -756,7 +716,10 @@ class AttnProcessor2_0: # require the shape of (batch_size x nheads x ntokens x dim) pos_thw = self.position_getter(batch_size, t=frame * sp_size, h=height, w=width, device=query.device) - query = self.rope(query, pos_thw) + if tp_size > 1: + query = self.rope(query, pos_thw, rope_split=True) + else: + query = self.rope(query, pos_thw) key = self.rope(key, pos_thw) query = query.view(-1, batch_size, h_size_sp) # SBH key = key.view(-1, batch_size, h_size_sp) # SBH @@ -775,10 +738,11 @@ class AttnProcessor2_0: head_num=attn.heads // sp_size)[0] hidden_states = rearrange(hidden_states, "b n s d -> s b (n d)") hidden_states = hidden_states.view(-1, attn.heads // sp_size, head_dim) - + if tp_size > 1 
and self.use_rope: + hidden_states = get_tp_group().all_gather(hidden_states, separate_tensors=False) # [s * b, h // sp, d] -> [s // sp * b, h, d] -> [s // sp, b, h * d] _, hidden_states = all_to_all_sbh(hidden_states, scatter_dim=0, gather_dim=1) - hidden_states = hidden_states.view(-1, batch_size, h_size) + hidden_states = hidden_states.view(-1, batch_size, inner_dim) else: query = attn.to_q(hidden_states) if encoder_hidden_states is None: diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py index 601ceba290..c2e9d8ad05 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/opensora/models/diffusion/opensora/rope.py @@ -4,8 +4,8 @@ from diffusers.utils import logging from utils.parallel_mgr import( get_sequence_parallel_state, - get_sequence_parallel_world_size, - get_sequence_parallel_rank + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_rank, ) logger = logging.get_logger(__name__) @@ -71,8 +71,8 @@ class RoPE3D(torch.nn.Module): _cos = self.cache[pos1d, cos] _sin = self.cache[pos1d, sin] if rope_split: - rank = get_sequence_parallel_rank() - seqlen = _cos.shape[0] // get_sequence_parallel_world_size() + rank = get_tensor_model_parallel_rank() + seqlen = _cos.shape[0] // get_tensor_model_parallel_world_size() return torch_npu.npu_rotary_mul(tokens, _cos[seqlen * rank: seqlen * (rank + 1)], _sin[seqlen * rank: seqlen * (rank + 1)]) else: return torch_npu.npu_rotary_mul(tokens, _cos, _sin) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py index 24278e695b..3169ddb5cf 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/parallel_mgr.py @@ -10,12 +10,14 @@ from .group_coordinator import GroupCoordinator, SequenceParallelGroupCoordinato logger = logging.get_logger(__name__) _WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None _SP: Optional[SequenceParallelGroupCoordinator] = None _CFG: Optional[GroupCoordinator] = None @dataclass class ParallelConfig: + tp_degree: int = 1 sp_degree: int = 1 use_cfg_parallel: bool = False world_size: int = 1 @@ -25,12 +27,12 @@ class ParallelConfig: self.cfg_degree = 2 else: self.cfg_degree = 1 - if not self.sp_degree * self.cfg_degree <= self.world_size: + if not self.tp_degree * self.sp_degree * self.cfg_degree <= self.world_size: logger.error( "sp_degree * cfg_degree must be less than or equal to " "world_size because of classifier free guidance" ) - if not (self.world_size % (self.sp_degree * self.cfg_degree) == 0): + if not (self.world_size % (self.tp_degree * self.sp_degree * self.cfg_degree) == 0): logger.error("world_size must be divisible by sp_degree * cfg_degree") @@ -41,10 +43,36 @@ def get_world_group() -> GroupCoordinator: return _WORLD +# TP +def get_tp_group() -> GroupCoordinator: + if _TP is None: + logger.error("tensor model parallel group is not initialized") + return _TP + + +def get_tensor_parallel_state(): + """Return state for the tensor parallel group.""" + return _TP is not None + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor 
model parallel group.""" + if not get_tensor_parallel_state(): + return 1 + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + if not get_tensor_parallel_state(): + return 1 + return get_tp_group().rank_in_group + + # SP def get_sp_group() -> SequenceParallelGroupCoordinator: if _SP is None: - logger.error("pipeline model parallel group is not initialized") + logger.error("sequence model parallel group is not initialized") return _SP @@ -156,6 +184,7 @@ def model_parallel_is_initialized(): return ( _CFG is not None and _SP is not None + and _TP is not None ) @@ -167,6 +196,7 @@ def init_model_parallel_group( **kwargs, ) -> GroupCoordinator: if parallel_mode not in [ + "tensor", "sequence", "classifier_free_guidance", ]: @@ -189,14 +219,16 @@ def init_model_parallel_group( def initialize_model_parallel( classifier_free_guidance_degree: int = 1, sequence_parallel_degree: int = 1, + tensor_parallel_degree: int = 1, backend: Optional[str] = None, ) -> None: """ Initialize model parallel groups. Arguments: - classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) - sequence_parallel_degree: number of GPUs used for sequence parallelism. + classifier_free_guidance_degree: number of NPUs used for Classifier Free Guidance (CFG) + tensor_parallel_degree: number of NPUs used for tensor parallelism. + sequence_parallel_degree: number of NPUs used for sequence parallelism. backend: distributed backend of pytorch collective comm. """ # Get world size and rank. Ensure some consistencies. @@ -208,19 +240,22 @@ def initialize_model_parallel( if ( world_size != classifier_free_guidance_degree + * tensor_parallel_degree * sequence_parallel_degree ): raise RuntimeError( f"world_size ({world_size}) is not equal to " + f"tensor_parallel_degree ({tensor_parallel_degree}) x " f"sequence_parallel_degree ({sequence_parallel_degree}) x" f"classifier_free_guidance_degree " f"({classifier_free_guidance_degree}) x" ) rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, sequence_parallel_degree, classifier_free_guidance_degree, - "sp-cfg", + "tp-sp-cfg", ) global _CFG @@ -243,6 +278,16 @@ def initialize_model_parallel( parallel_mode="sequence", ) + global _TP + if _TP is not None: + logger.error("Tensor parallel group is already initialized") + _TP = init_model_parallel_group( + group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + def destroy_model_parallel(): """Set the groups to none and destroy them.""" @@ -256,6 +301,11 @@ def destroy_model_parallel(): _SP.destroy() _SP = None + global _TP + if _TP: + _TP.destroy() + _TP = None + def destroy_distributed_environment(): global _WORLD @@ -274,6 +324,7 @@ def init_parallel_env(parallel_config: ParallelConfig): initialize_model_parallel( classifier_free_guidance_degree=parallel_config.cfg_degree, sequence_parallel_degree=parallel_config.sp_degree, + tensor_parallel_degree=parallel_config.tp_degree, ) diff --git a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py index bffe388d48..84638d0a5b 100644 --- a/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py +++ b/MindIE/MindIE-Torch/built-in/foundation/open_sora_planv1_2/utils/utils.py @@ -85,17 +85,20 @@ def generate_masked_orthogonal_rank_groups( class RankGenerator(object): def 
__init__( self, + tp: int, sp: int, cfg: int, order: str, rank_offset: int = 0, ) -> None: + self.tp = tp self.sp = sp self.cfg = cfg self.rank_offset = rank_offset - self.world_size = sp * cfg + self.world_size = tp * sp * cfg self.name_to_size = { + "tp": self.tp, "sp": self.sp, "cfg": self.cfg, } -- Gitee
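Note (illustration, not part of the patch series): the final commit decomposes the world group into tensor-parallel, sequence-parallel and classifier-free-guidance groups through RankGenerator(tp, sp, cfg, "tp-sp-cfg"). The standalone Python sketch below shows the rank grouping this order implies for the 24-NPU case that inference.py selects (tp_degree=2, sp_degree=12, cfg_degree=1 when world_size=24 and --use_cfg_parallel is off). It assumes the first name in the order string varies fastest across consecutive ranks, the usual Megatron-style convention; the helpers rank_of and build_groups are illustrative names, not functions from the patch.

def rank_of(tp_i, sp_i, cfg_i, tp, sp):
    # "tp-sp-cfg": the tp index varies fastest, the cfg index slowest.
    return tp_i + sp_i * tp + cfg_i * tp * sp


def build_groups(tp=2, sp=12, cfg=1):
    # Ranks that differ only in one index form one group of that kind.
    tp_groups = [[rank_of(t, s, c, tp, sp) for t in range(tp)]
                 for c in range(cfg) for s in range(sp)]
    sp_groups = [[rank_of(t, s, c, tp, sp) for s in range(sp)]
                 for c in range(cfg) for t in range(tp)]
    cfg_groups = [[rank_of(t, s, c, tp, sp) for c in range(cfg)]
                  for s in range(sp) for t in range(tp)]
    return tp_groups, sp_groups, cfg_groups


if __name__ == "__main__":
    tp_groups, sp_groups, cfg_groups = build_groups(tp=2, sp=12, cfg=1)  # 24 NPUs
    print(tp_groups[0])   # [0, 1] -> one tensor-parallel pair
    print(sp_groups[0])   # [0, 2, 4, ..., 22] -> one sequence-parallel group of 12 ranks

Changing the arguments reproduces the other layout handled by inference.py, e.g. build_groups(tp=2, sp=16, cfg=1) for 32 NPUs without CFG parallelism.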