From 837778def72fb220bb31d291a276c2bd325abcfb Mon Sep 17 00:00:00 2001 From: mazhixin00 Date: Wed, 26 Mar 2025 11:56:52 +0800 Subject: [PATCH 1/2] =?UTF-8?q?hunyuanvide=E9=83=A8=E5=88=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E4=B8=8A=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../HunyuanVideo/attentioncache_search.py | 93 +++++++ .../HunyuanVideo/ditcache_search.py | 162 +++++++++++ .../hyvideo/diffusion/__init__.py | 2 + .../hyvideo/diffusion/schedulers/__init__.py | 1 + .../scheduling_flow_match_discrete.py | 257 ++++++++++++++++++ MindIE/MultiModal/HunyuanVideo/prompts.txt | 6 + .../MultiModal/HunyuanVideo/requirements.txt | 17 ++ .../MultiModal/HunyuanVideo/sample_video.py | 146 ++++++++++ 8 files changed, 684 insertions(+) create mode 100644 MindIE/MultiModal/HunyuanVideo/attentioncache_search.py create mode 100644 MindIE/MultiModal/HunyuanVideo/ditcache_search.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/__init__.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/__init__.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py create mode 100644 MindIE/MultiModal/HunyuanVideo/prompts.txt create mode 100644 MindIE/MultiModal/HunyuanVideo/requirements.txt create mode 100644 MindIE/MultiModal/HunyuanVideo/sample_video.py diff --git a/MindIE/MultiModal/HunyuanVideo/attentioncache_search.py b/MindIE/MultiModal/HunyuanVideo/attentioncache_search.py new file mode 100644 index 0000000000..63811faf36 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/attentioncache_search.py @@ -0,0 +1,93 @@ +import os +import time +import math +import logging +from typing import List +from pathlib import Path +import cv2 +import numpy as np +from hyvideo.config import parse_args +from hyvideo.inference import HunyuanVideoSampler +from attentioncache_search_tool import CacheSearcher +from mindiesd import CacheConfig, CacheAgent +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format = False + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("Cache Searcher") + + +def prepare_pipeline(): + args_init = parse_args() + models_root_path = Path(args_init.model_base) + if not models_root_path.exists(): + raise ValueError(f"`models_root` not exists: {models_root_path}") + + # Create save folder to save the samples + save_path = args_init.save_path if args_init.save_path_suffix == "" else f'{args_init.save_path}_{args_init.save_path_suffix}' + if not os.path.exists(args_init.save_path): + os.makedirs(save_path, exist_ok=True) + + # Load models + hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args_init) + return hunyuan_video_sampler + + +def generate_videos(prompts, hunyuan_video_sampler): + # prompts, num_sampling_steps, config, self.pipeline + # 通过config的成员对pipeline.transformer成员变量赋值 + + output_list = [] + for prompt_ in prompts: + # Start sampling + outputs = hunyuan_video_sampler.predict( + prompt=prompt_, + height=args.video_size[0], + width=args.video_size[1], + video_length=args.video_length, + seed=args.seed, + negative_prompt=args.neg_prompt, + infer_steps=args.infer_steps, + guidance_scale=args.cfg_scale, + num_videos_per_prompt=args.num_videos, + flow_shift=args.flow_shift, + batch_size=args.batch_size, + embedded_guidance_scale=args.embedded_cfg_scale + ) + samples = 
outputs['samples'] + + for j, _ in enumerate(samples): + output_list.append(samples[j]) + return output_list + +if __name__ == "__main__": + hunyuan_video_sampler = prepare_pipeline() + args = hunyuan_video_sampler.args + pipeline = hunyuan_video_sampler.pipeline + transformer = hunyuan_video_sampler.pipeline.transformer + + prompts = [ + "realistic style, a lone cowboy rides his horse across an open plain at beautiful sunset, soft light, warm colors", + "extreme close-up with a shallow depth of field of a puddle in a street. reflecting a busy futuristic Tokyo city with bright neon signs, night, lens flare" + ] + + config = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.double_blocks + transformer.single_blocks), + steps_count=args.infer_steps + ) + cache = CacheAgent(config) + for block in transformer.double_blocks + transformer.single_blocks: + block.cache = cache + + search = CacheSearcher(cache) + search.search_apply( + args.infer_steps, + 2.0, + generate_videos, + prompts=prompts, + hunyuan_video_sampler=hunyuan_video_sampler + ) \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/ditcache_search.py b/MindIE/MultiModal/HunyuanVideo/ditcache_search.py new file mode 100644 index 0000000000..24340001b7 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/ditcache_search.py @@ -0,0 +1,162 @@ +import os +import time +from pathlib import Path +from loguru import logger + +from hyvideo.utils.file_utils import save_videos_grid +from hyvideo.config import parse_args +from hyvideo.inference import HunyuanVideoSampler +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +import cv2 +import numpy as np +from tqdm import tqdm +from mindiesd.runtime.cache_manager import CacheManager, DitCacheConfig +from msmodelslim.pytorch.multimodal import DitCacheSearcherConfig, DitCacheSearcher +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format = False + + +def run_model_and_save_samples(config: DitCacheSearcherConfig, hunyuan_video_sampler): + # 通过config的成员对pipeline.transformer成员变量赋值 + dit_cache_config = DitCacheConfig( + step_start=config.cache_step_start, + step_interval=config.cache_step_interval, + block_start=config.cache_dit_block_start, + num_blocks=config.cache_num_dit_blocks + ) + cache = CacheManager(dit_cache_config) + if args.search_single_cache: + pipeline.transformer.cache_single = cache + elif args.search_double_cache: + pipeline.transformer.cache_dual = cache + + prompts = [ + "realistic style, a lone cowboy rides his horse across an open plain at beautiful sunset, soft light, warm colors", + "extreme close-up with a shallow depth of field of a puddle in a street. reflecting a busy futuristic Tokyo city with bright neon signs, night, lens flare", + "Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. 
vivid colours", + "Timelapse of the northern lights dancing across the Arctic sky, stars twinkling, snow-covered landscape", + "A panning shot of a serene mountain landscape, slowly revealing snow-capped peaks, granite rocks and a crystal-clear lake reflecting the sky", + "moody shot of a central European alley film noir cinematic black and white high contrast high detail" + ] + for idx, prompt_ in enumerate(prompts): + # Start sampling + outputs = hunyuan_video_sampler.predict( + prompt=prompt_, + height=args.video_size[0], + width=args.video_size[1], + video_length=args.video_length, + seed=args.seed, + negative_prompt=args.neg_prompt, + infer_steps=args.infer_steps, + guidance_scale=args.cfg_scale, + num_videos_per_prompt=args.num_videos, + flow_shift=args.flow_shift, + batch_size=args.batch_size, + embedded_guidance_scale=args.embedded_cfg_scale + ) + samples = outputs['samples'] + + if config.cache_num_dit_blocks == 0: + video_path = os.path.join(args.save_path, f'sample_{idx:04d}_no_cache.mp4') + else: + video_path = os.path.join(args.save_path, + f'sample_{idx:04d}_{config.cache_dit_block_start}_{config.cache_step_interval}_{config.cache_num_dit_blocks}_{config.cache_step_start}.mp4') + # Save samples + if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: + for j, sample in enumerate(samples): + sample = samples[j].unsqueeze(0) + save_videos_grid(sample, video_path, fps=24) + logger.info(f'Sample save to: {video_path}') + + +def prepare_pipeline(): + args_init = parse_args() + models_root_path = Path(args_init.model_base) + if not models_root_path.exists(): + raise ValueError(f"`models_root` not exists: {models_root_path}") + + # Create save folder to save the samples + save_path = args_init.save_path if args_init.save_path_suffix == "" else f'{args_init.save_path}_{args_init.save_path_suffix}' + if not os.path.exists(args_init.save_path): + os.makedirs(save_path, exist_ok=True) + + # Load models + hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args_init) + return hunyuan_video_sampler + + +def search_single_block(): + block_num = len(pipeline.transformer.single_blocks) + print(f"single block num:{block_num}") + + config = DitCacheSearcherConfig( + dit_block_num=block_num, + prompts_num=1, + num_sampling_steps=args.infer_steps, + cache_ratio=args.cache_ratio, + search_cache_path=args.save_path, + cache_step_start=0, + cache_dit_block_start=0, + cache_num_dit_blocks=0 + ) + search_handler = DitCacheSearcher(config, hunyuan_video_sampler, run_model_and_save_samples) + cache_final_list = search_handler.search() + print(f'****************single block: cache_final_list in \ + [cache_dit_block_start, cache_step_interval, cache_num_dit_blocks, cache_step_start] order: {cache_final_list}') + + +def search_double_block(): + block_num = len(pipeline.transformer.double_blocks) + print(f"double block num:{block_num}") + + config = DitCacheSearcherConfig( + dit_block_num=block_num, + prompts_num=1, + num_sampling_steps=args.infer_steps, + cache_ratio=args.cache_ratio, + search_cache_path=args.save_path, + cache_step_start=0, + cache_dit_block_start=0, + cache_num_dit_blocks=0 + ) + search_handler = DitCacheSearcher(config, hunyuan_video_sampler, run_model_and_save_samples) + cache_final_list = search_handler.search() + print(f'****************double block: cache_final_list in \ + [cache_dit_block_start, cache_step_interval, \ + cache_num_dit_blocks, cache_step_start] order: {cache_final_list}') + + +def check_args(args): + if 
args.search_single_cache and args.search_double_cache: + print("Cache policies cannot be searched at the same time.") + return False + if (not args.search_single_cache) and (not args.search_double_cache): + print("Must specify search_single_cache or search_double_cache parameter.") + return False + if args.search_single_cache: + if not args.use_cache: + print(f"If you want to use the 'search_single_cache' parameter, \ + please specify the 'use_cache' parameter at the same time.") + return False + if args.search_double_cache: + if not args.use_cache_double: + print(f"If you want to use the 'search_double_cache' parameter, \ + please specify the 'use_cache_double' parameter at the same time.") + return False + return True + + +if __name__ == "__main__": + hunyuan_video_sampler = prepare_pipeline() + args = hunyuan_video_sampler.args + pipeline = hunyuan_video_sampler.pipeline + args_status = check_args(args) + + if args_status: + if args.search_single_cache: + search_single_block() + if args.search_double_cache: + search_double_block() + print("Search Done!!!!") \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/__init__.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/__init__.py new file mode 100644 index 0000000000..9e03c0b8bc --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/__init__.py @@ -0,0 +1,2 @@ +from .pipelines import HunyuanVideoPipeline +from .schedulers import FlowMatchDiscreteScheduler \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/__init__.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/__init__.py new file mode 100644 index 0000000000..6afd057baf --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/__init__.py @@ -0,0 +1 @@ +from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py new file mode 100644 index 0000000000..8a5f21bc2c --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py @@ -0,0 +1,257 @@ +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class FlowMatchDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed] for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + reverse (`bool`, defaults to `True`): + Whether to reverse the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + reverse: bool = True, + solver: str = "euler", + n_tokens: Optional[int] = None, + ): + sigmas = torch.linspace(1, 0, num_train_timesteps + 1) + + if not reverse: + sigmas = sigmas.flip(0) + + self.sigmas = sigmas + # the value fed to model + self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32) + + self._step_index = None + self._begin_index = None + + self.supported_solver = ["euler"] + if solver not in self.supported_solver: + raise ValueError( + f"Solver {solver} not supported. Supported solvers: {self.supported_solver}" + ) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + n_tokens: int = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
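+        The sigma schedule is drawn from `torch.linspace(1, 0, num_inference_steps + 1)` and then warped by
+        `sd3_time_shift`, which is controlled by `config.shift`.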
+ + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + """ + self.num_inference_steps = num_inference_steps + + sigmas = torch.linspace(1, 0, num_inference_steps + 1) + sigmas = self.sd3_time_shift(sigmas) + + if not self.config.reverse: + sigmas = 1 - sigmas + + self.sigmas = sigmas + self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to( + dtype=torch.float32, device=device + ) + + # Reset step index + self._step_index = None + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def scale_model_input( + self, sample: torch.Tensor, timestep: Optional[int] = None + ) -> torch.Tensor: + return sample + + def sd3_time_shift(self, t: torch.Tensor): + return (self.config.shift * t) / (1 + (self.config.shift - 1) * t) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + return_dict: bool = True, + ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + n_tokens (`int`, *optional*): + Number of tokens in the input sequence. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." 
+ ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] + + if self.config.solver == "euler": + prev_sample = sample + model_output.to(torch.float32) * dt + else: + raise ValueError( + f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}" + ) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample) + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/prompts.txt b/MindIE/MultiModal/HunyuanVideo/prompts.txt new file mode 100644 index 0000000000..ac028cbaf2 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/prompts.txt @@ -0,0 +1,6 @@ +realistic style, a lone cowboy rides his horse across an open plain at beautiful sunset, soft light, warm colors. +extreme close-up with a shallow depth of field of a puddle in a street. reflecting a busy futuristic Tokyo city with bright neon signs, night, lens flare. +Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours. +Timelapse of the northern lights dancing across the Arctic sky, stars twinkling, snow-covered landscape. +A panning shot of a serene mountain landscape, slowly revealing snow-capped peaks, granite rocks and a crystal-clear lake reflecting the sky. +moody shot of a central European alley film noir cinematic black and white high contrast high detail. \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/requirements.txt b/MindIE/MultiModal/HunyuanVideo/requirements.txt new file mode 100644 index 0000000000..a34dfdf5bc --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/requirements.txt @@ -0,0 +1,17 @@ +opencv-python==4.9.0.80 +diffusers==0.31.0 +transformers==4.46.3 +tokenizers==0.20.3 +accelerate==1.1.1 +pandas==2.0.3 +numpy==1.24.4 +einops==0.7.0 +tqdm==4.66.2 +loguru==0.7.2 +imageio==2.34.0 +imageio-ffmpeg==0.5.1 +safetensors==0.4.3 +gradio==5.0.0 +yunchang==0.6.0 +deepspeed +torchvision==0.16.0 \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/sample_video.py b/MindIE/MultiModal/HunyuanVideo/sample_video.py new file mode 100644 index 0000000000..77044fd65d --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/sample_video.py @@ -0,0 +1,146 @@ +import os +import time +from pathlib import Path +from loguru import logger +from hyvideo.utils.file_utils import save_videos_grid +from hyvideo.config import parse_args +from hyvideo.inference import HunyuanVideoSampler +from mindiesd import CacheConfig, CacheAgent +import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) +torch.npu.config.allow_internal_format = False + + +def main(): + args = parse_args() + models_root_path = Path(args.model_base) + if not models_root_path.exists(): + raise ValueError(f"`models_root` not exists: {models_root_path}") + + # Create save folder to save the samples + save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}' + if not os.path.exists(args.save_path): + os.makedirs(save_path, exist_ok=True) + + # Load models + hunyuan_video_sampler = 
HunyuanVideoSampler.from_pretrained(models_root_path, args=args) + transformer = hunyuan_video_sampler.pipeline.transformer + + # Get the updated args + args = hunyuan_video_sampler.args + if args.prompt.endswith('txt'): + with open(args.prompt, 'r') as file: + text_prompt = file.readlines() + prompts = [line.strip() for line in text_prompt] + else: + prompts = [args.prompt] + + if args.use_cache: + # single + config_single = CacheConfig( + method="dit_block_cache", + blocks_count=len(transformer.single_blocks), + steps_count=args.infer_steps, + step_start=args.cache_start_steps, + step_interval=args.cache_interval, + step_end=args.infer_steps - 1, + block_start=args.single_block_start, + block_end=args.single_block_end + ) + cache_single = CacheAgent(config_single) + hunyuan_video_sampler.pipeline.transformer.cache_single = cache_single + if args.use_cache_double: + # double + config_double = CacheConfig( + method="dit_block_cache", + blocks_count=len(transformer.double_blocks), + steps_count=args.infer_steps, + step_start=args.cache_start_steps, + step_interval=args.cache_interval, + step_end=args.infer_steps - 1, + block_start=args.double_block_start, + block_end=args.double_block_end + ) + cache_dual = CacheAgent(config_double) + hunyuan_video_sampler.pipeline.transformer.cache_dual = cache_dual + + if args.use_attentioncache: + config_double = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.double_blocks), + steps_count=args.infer_steps, + step_start=args.start_step, + step_interval=args.attentioncache_interval, + step_end=args.end_step + ) + config_single = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.single_blocks), + steps_count=args.infer_steps, + step_start=args.start_step, + step_interval=args.attentioncache_interval, + step_end=args.end_step + ) + else: + config_double = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.double_blocks), + steps_count=args.infer_steps + ) + config_single = CacheConfig( + method="attention_cache", + blocks_count=len(transformer.single_blocks), + steps_count=args.infer_steps + ) + cache_double = CacheAgent(config_double) + cache_single = CacheAgent(config_single) + for block in transformer.double_blocks: + block.cache = cache_double + for block in transformer.single_blocks: + block.cache = cache_single + + # warmup + outputs = hunyuan_video_sampler.predict( + prompt=prompts[0], + height=args.video_size[0], + width=args.video_size[1], + video_length=args.video_length, + seed=args.seed, + negative_prompt=args.neg_prompt, + infer_steps=2, + guidance_scale=args.cfg_scale, + num_videos_per_prompt=args.num_videos, + flow_shift=args.flow_shift, + batch_size=args.batch_size, + embedded_guidance_scale=args.embedded_cfg_scale + ) + # Start sampling + for idx, prompt_ in enumerate(prompts): + outputs = hunyuan_video_sampler.predict( + prompt=prompt_, + height=args.video_size[0], + width=args.video_size[1], + video_length=args.video_length, + seed=args.seed, + negative_prompt=args.neg_prompt, + infer_steps=args.infer_steps, + guidance_scale=args.cfg_scale, + num_videos_per_prompt=args.num_videos, + flow_shift=args.flow_shift, + batch_size=args.batch_size, + embedded_guidance_scale=args.embedded_cfg_scale + ) + samples = outputs['samples'] + + # Save samples + if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: + for i, sample in enumerate(samples): + sample = samples[i].unsqueeze(0) + video_path = f"{save_path}/sample_{idx}.mp4" + save_videos_grid(sample, video_path, 
fps=24) + logger.info(f'Sample save to: {video_path}') + +if __name__ == "__main__": + main() \ No newline at end of file -- Gitee From eccceda15deb39b0d6ad7031b0dd1e6c78112922 Mon Sep 17 00:00:00 2001 From: mazhixin00 Date: Mon, 31 Mar 2025 11:54:52 +0800 Subject: [PATCH 2/2] =?UTF-8?q?hunyuanvideo=E9=83=A8=E5=88=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E4=B8=8A=E5=BA=93-v2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../HunyuanVideo/hyvideo/modules/__init__.py | 26 + .../hyvideo/modules/activation_layers.py | 23 + .../HunyuanVideo/hyvideo/modules/attention.py | 195 ++++ .../hyvideo/modules/attn_layer.py | 284 ++++++ .../hyvideo/modules/embed_layers.py | 157 ++++ .../HunyuanVideo/hyvideo/modules/fa.py | 72 ++ .../hyvideo/modules/fp8_optimization.py | 101 ++ .../hyvideo/modules/mlp_layers.py | 118 +++ .../HunyuanVideo/hyvideo/modules/models.py | 871 ++++++++++++++++++ .../hyvideo/modules/modulate_layers.py | 76 ++ .../hyvideo/modules/new_parallel.py | 219 +++++ .../hyvideo/modules/norm_layers.py | 75 ++ .../hyvideo/modules/posemb_layers.py | 310 +++++++ .../hyvideo/modules/token_refiner.py | 235 +++++ 14 files changed, 2762 insertions(+) create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/__init__.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/activation_layers.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attention.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attn_layer.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/embed_layers.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fa.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fp8_optimization.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/models.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/modulate_layers.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/new_parallel.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/norm_layers.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/posemb_layers.py create mode 100644 MindIE/MultiModal/HunyuanVideo/hyvideo/modules/token_refiner.py diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/__init__.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/__init__.py new file mode 100644 index 0000000000..098191b69b --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/__init__.py @@ -0,0 +1,26 @@ +from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG + + +def load_model(args, in_channels, out_channels, factor_kwargs): + """load hunyuan video model + + Args: + args (dict): model args + in_channels (int): input channels number + out_channels (int): output channels number + factor_kwargs (dict): factor kwargs + + Returns: + model (nn.Module): The hunyuan video model + """ + if args.model in HUNYUAN_VIDEO_CONFIG.keys(): + model = HYVideoDiffusionTransformer( + args, + in_channels=in_channels, + out_channels=out_channels, + **HUNYUAN_VIDEO_CONFIG[args.model], + **factor_kwargs, + ) + return model + else: + raise NotImplementedError() \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/activation_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/activation_layers.py new file mode 100644 index 0000000000..fd0e145ca8 --- /dev/null +++ 
b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/activation_layers.py @@ -0,0 +1,23 @@ +import torch.nn as nn + + +def get_activation_layer(act_type): + """get activation layer + + Args: + act_type (str): the activation type + + Returns: + torch.nn.functional: the activation layer + """ + if act_type == "gelu": + return lambda: nn.GELU() + elif act_type == "gelu_tanh": + # Approximate `tanh` requires torch >= 1.13 + return lambda: nn.GELU(approximate="tanh") + elif act_type == "relu": + return nn.ReLU + elif act_type == "silu": + return nn.SiLU + else: + raise ValueError(f"Unknown activation type: {act_type}") \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attention.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attention.py new file mode 100644 index 0000000000..11a1b56f69 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attention.py @@ -0,0 +1,195 @@ +import importlib.metadata +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_npu + +MEMORY_LAYOUT = { + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} +MAX_TOKEN=2147483647 + +def get_cu_seqlens(text_mask, img_len): + """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len + + Args: + text_mask (torch.Tensor): the mask of text + img_len (int): the length of image + + Returns: + torch.Tensor: the calculated cu_seqlens for flash attention + """ + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 + cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + + +def attention( + q, + k, + v, + mode="flash", + drop_rate=0, + attn_mask=None, + causal=False, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. 
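+        batch_size (int): Batch size of the inputs. (default: 1)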
+ + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + q = pre_attn_layout(q) + k = pre_attn_layout(k) + v = pre_attn_layout(v) + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + if attn_mask is not None: + x = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal + ) + else: + scale = q.shape[-1] ** -0.5 + x1 = torch_npu.npu_prompt_flash_attention( + q[:, :, :cu_seqlens_q[1], :], + k[:, :, :cu_seqlens_kv[1], :], + v[:, :, :cu_seqlens_kv[1], :], + num_heads=q.shape[1], + input_layout="BNSD", + scale_value=scale, + pre_tokens=MAX_TOKEN, + next_tokens=MAX_TOKEN + ) + x2 = torch_npu.npu_prompt_flash_attention( + q[:, :, cu_seqlens_q[1]:, :], + k[:, :, cu_seqlens_kv[1]:, :], + v[:, :, cu_seqlens_kv[1]:, :], + num_heads=q.shape[1], + input_layout="BNSD", + scale_value=scale, + pre_tokens=MAX_TOKEN, + next_tokens=MAX_TOKEN + ) + x = torch.cat([x1, x2], dim=2) + + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert ( + attn_mask is None + ), "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril( + diagonal=0 + ) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + # TODO: Maybe force q and k to be float32 to avoid numerical overflow + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +def parallel_attention( + hybrid_seq_parallel_attn, + q, + k, + v, + img_q_len, + img_kv_len, + cu_seqlens_q, + cu_seqlens_kv, + scale +): + attn1 = hybrid_seq_parallel_attn( + None, + q[:, :img_q_len, :, :], + k[:, :img_kv_len, :, :], + v[:, :img_kv_len, :, :], + dropout_p=0.0, + causal=False, + joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]], + joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]], + joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]], + joint_strategy="rear", + scale=scale + ) + attn2 = torch_npu.npu_fusion_attention( + q[:, cu_seqlens_q[1]:], + k[:, cu_seqlens_kv[1]:], + v[:, cu_seqlens_kv[1]:], + head_num=q.shape[-2], + input_layout="BSND", + scale=scale, + pre_tockens=MAX_TOKEN, + next_tockens=MAX_TOKEN + )[0] + attn = torch.cat([attn1, attn2], dim=1) + b, s, a, d = attn.shape + attn = attn.reshape(b, s, -1) + + return attn \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attn_layer.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attn_layer.py new file mode 100644 index 0000000000..a4c37b00f0 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/attn_layer.py @@ -0,0 +1,284 @@ +import logging +import torch +from torch import Tensor +import torch_npu +import torch.distributed as dist +import math +import os +from yunchang import LongContextAttention +try: + from yunchang.kernels import 
AttnType +except ImportError: + raise ImportError("Please install yunchang 0.6.0 or later") +from typing import Any +from yunchang.comm.all_to_all import SeqAllToAll4D +from .new_parallel import all_to_all_v1, all_to_all_v2 +from .fa import attention_ATB, attention_FAScore, attention_LA +ALGO = int(os.getenv('ALGO')) +if ALGO == 1: + import torch_atb +if ALGO == 2: + from mindiesd import attention_forward + +logger = logging.getLogger(__name__) +MAX_TOKEN = 2147483647 + +class xFuserLongContextAttention(LongContextAttention): + ring_impl_type_supported_kv_cache = ["basic"] + + def __init__( + self, + scatter_idx: int = 2, + gather_idx: int = 1, + ring_impl_type: str = "basic", + use_pack_qkv: bool = False, + use_kv_cache: bool = False, + attn_type: AttnType = AttnType.FA, + divisable: bool = True + ) -> None: + """ + Arguments: + scatter_idx: int = 2, the scatter dimension index for Ulysses All2All + gather_idx: int = 1, the gather dimension index for Ulysses All2All + ring_impl_type: str = "basic", the ring implementation type, currently only support "basic" + use_pack_qkv: bool = False, whether to use pack qkv in the input + use_kv_cache: bool = False, whether to use kv cache in the attention layer, which is applied in PipeFusion. + """ + super().__init__( + scatter_idx=scatter_idx, + gather_idx=gather_idx, + ring_impl_type=ring_impl_type, + use_pack_qkv=use_pack_qkv, + attn_type = attn_type, + ) + self.use_kv_cache = use_kv_cache + if ( + use_kv_cache + and ring_impl_type not in self.ring_impl_type_supported_kv_cache + ): + raise RuntimeError( + f"ring_impl_type: {ring_impl_type} do not support SP kv cache." + ) + self.world_size = dist.get_world_size() + self.divisable = divisable + + self.algo = int(os.getenv('ALGO')) + self.self_attention = None + if self.algo == 1: + import torch_atb + self_attention_param = torch_atb.SelfAttentionParam() + if self.world_size == 8: + self_attention_param.head_num = 1 + self_attention_param.kv_head_num = 1 + elif self.world_size == 16: + self_attention_param.head_num = 3 + self_attention_param.kv_head_num = 3 + self_attention_param.qk_scale = 1.0 / math.sqrt(1.0 * 128) + self_attention_param.input_layout = torch_atb.TYPE_BNSD + self_attention_param.calc_type = torch_atb.SelfAttentionParam.CalcType.PA_ENCODER + self.self_attention = torch_atb.Operation(self_attention_param) + + def forward( + self, + attn, + query: Tensor, + key: Tensor, + value: Tensor, + *, + joint_tensor_query=None, + joint_tensor_key=None, + joint_tensor_value=None, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + joint_strategy="none", + scale=None + ) -> Tensor: + """forward + + Arguments: + attn (Attention): the attention module + query (Tensor): query input to the layer + key (Tensor): key input to the layer + value (Tensor): value input to the layer + args: other args, + joint_tensor_query: Tensor = None, a replicated tensor among processes appended to the front or rear of query, depends the joint_strategy + joint_tensor_key: Tensor = None, a replicated tensor among processes appended to the front or rear of key, depends the joint_strategy + joint_tensor_value: Tensor = None, a replicated tensor among processes appended to the front or rear of value, depends the joint_strategy, + *args: the args same as flash_attn_interface + joint_strategy: str = "none", the joint strategy for joint attention, currently only support "front" and "rear" + + Returns: + * output (Tensor): 
context output + """ + is_joint = False + if (joint_tensor_query is not None and + joint_tensor_key is not None and + joint_tensor_value is not None): + supported_joint_strategy = ["front", "rear"] + if joint_strategy not in supported_joint_strategy: + raise ValueError( + f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}" + ) + elif joint_strategy == "rear": + query = torch.cat([query, joint_tensor_query], dim=1) + is_joint = True + else: + query = torch.cat([joint_tensor_query, query], dim=1) + is_joint = True + elif (joint_tensor_query is None and + joint_tensor_key is None and + joint_tensor_value is None): + pass + else: + raise ValueError( + f"joint_tensor_query, joint_tensor_key, and joint_tensor_value should be None or not None simultaneously." + ) + if is_joint: + ulysses_world_size = dist.get_world_size(self.ulysses_pg) + ulysses_rank = dist.get_rank(self.ulysses_pg) + attn_heads_per_ulysses_rank = ( + joint_tensor_key.shape[-2] // ulysses_world_size + ) + joint_tensor_key = joint_tensor_key[ + ..., + attn_heads_per_ulysses_rank + * ulysses_rank : attn_heads_per_ulysses_rank + * (ulysses_rank + 1), + :, + ] + joint_tensor_value = joint_tensor_value[ + ..., + attn_heads_per_ulysses_rank + * ulysses_rank : attn_heads_per_ulysses_rank + * (ulysses_rank + 1), + :, + ] + + parallel_method = 1 + if not self.divisable: + parallel_method = 2 + + if parallel_method == 1: + if self.world_size < 16: + # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) + # scatter 2, gather 1 + query_layer = SeqAllToAll4D.apply( + self.ulysses_pg, query, self.scatter_idx, self.gather_idx + ) + key_layer = SeqAllToAll4D.apply( + self.ulysses_pg, key, self.scatter_idx, self.gather_idx + ) + value_layer = SeqAllToAll4D.apply( + self.ulysses_pg, value, self.scatter_idx, self.gather_idx + ) + + key_layer = torch.cat([key_layer, joint_tensor_key], dim=1) + value_layer = torch.cat([value_layer, joint_tensor_value], dim=1) + + if self.algo == 1: + out = attention_ATB(query_layer, key_layer, value_layer, self.self_attention) + elif self.algo == 0: + out = attention_FAScore(query_layer, key_layer, value_layer, scale) + elif self.algo == 2: + out = attention_LA(query_layer, key_layer, value_layer, scale) + elif self.world_size == 16: + # 3 X (bs, seq_len/N, head_cnt, head_size) -> 3 X (bs, seq_len, head_cnt/N, head_size) + # scatter 2, gather 1 + query_layer = SeqAllToAll4D.apply( + self.ulysses_pg, query, self.scatter_idx, self.gather_idx + ) + key_value_layer = SeqAllToAll4D.apply( + self.ulysses_pg, torch.cat((key, value), dim=0), self.scatter_idx, self.gather_idx + ) + joint_tensor_key_value = torch.cat((joint_tensor_key, joint_tensor_value), dim=0) + + b, s, n, d = key_value_layer.shape + kv_full = torch.empty([2, b, s, n, d], dtype=query_layer.dtype, device=query_layer.device) + dist.all_gather_into_tensor(kv_full, key_value_layer, group=self.ring_pg) + + if self.algo == 1: + joint_tensor_key_value = joint_tensor_key_value.transpose(1, 2) + kv_full = kv_full.permute(1, 3, 0, 2, 4).reshape(b, n, -1, d) + kv_full = torch.cat((kv_full, joint_tensor_key_value), dim=2) + + query_layer = query_layer.transpose(1, 2) + key_layer, value_layer = kv_full.chunk(2, dim=0) + seqlen = torch.tensor([[query_layer.shape[2]], [key_layer.shape[2]]], dtype=torch.int32) + intensors = [query_layer, key_layer, value_layer, seqlen] + out = self.self_attention.forward(intensors)[0] + out = out.transpose(1, 2) + elif self.algo == 0: + 
joint_tensor_key_value = joint_tensor_key_value.transpose(1, 2) + kv_full = kv_full.permute(1, 3, 0, 2, 4).reshape(b, n, -1, d) + kv_full = torch.cat((kv_full, joint_tensor_key_value), dim=2) + + query_layer = query_layer.transpose(1, 2) + key_layer, value_layer = kv_full.chunk(2, dim=0) + out = torch_npu.npu_fusion_attention( + query_layer, + key_layer, + value_layer, + head_num=query_layer.shape[1], + input_layout="BNSD", + scale=scale, + pre_tockens=MAX_TOKEN, + next_tockens=MAX_TOKEN + )[0] + out = out.transpose(1, 2) + elif self.algo == 2: + kv_full = kv_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + kv_full = torch.cat((kv_full, joint_tensor_key_value), dim=1) + key_layer, value_layer = kv_full.chunk(2, dim=0) + out = attention_forward( + query_layer, + key_layer, + value_layer, + scale=scale, + opt_mode="manual", + op_type="ascend_laser_attention" + ) + + if isinstance(out, tuple): + context_layer, _, _ = out + else: + context_layer = out + + # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size) + # scatter 1, gather 2 + output = SeqAllToAll4D.apply( + self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx + ) + else: + if self.world_size < 16: + output = all_to_all_v1( + query, + key, + value, + joint_tensor_key, + joint_tensor_value, + 2, + 1, + scale=scale, + algo = self.algo, + self_attention = self.self_attention) + elif self.world_size == 16: + output = all_to_all_v2( + query, + key, + value, + joint_tensor_key, + joint_tensor_value, + self.ulysses_pg, + self.ring_pg, + 2, + 1, + scale=scale, + algo = self.algo, + self_attention = self.self_attention) + # out e.g., [s/p::h] + return output \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/embed_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/embed_layers.py new file mode 100644 index 0000000000..c095ea4d3f --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/embed_layers.py @@ -0,0 +1,157 @@ +import math +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from ..utils.helpers import to_2tuple + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.flatten = flatten + + self.proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + **factory_kwargs + ) + nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1)) + if bias: + nn.init.zeros_(self.proj.bias) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class TextProjection(nn.Module): + """ + Projects text embeddings. Also handles dropout for classifier-free guidance. 
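+    Implemented as a Linear -> activation -> Linear stack (`linear_1`, `act_1`, `linear_2`).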
+ + Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py + """ + + def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.linear_1 = nn.Linear( + in_features=in_channels, + out_features=hidden_size, + bias=True, + **factory_kwargs + ) + self.act_1 = act_layer() + self.linear_2 = nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias=True, + **factory_kwargs + ) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + dim (int): the dimension of the output. + max_period (int): controls the minimum frequency of the embeddings. + + Returns: + embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. + + .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer, + frequency_embedding_size=256, + max_period=10000, + out_size=None, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, hidden_size, bias=True, **factory_kwargs + ), + act_layer(), + nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), + ) + nn.init.normal_(self.mlp[0].weight, std=0.02) + nn.init.normal_(self.mlp[2].weight, std=0.02) + + def forward(self, t): + t_freq = timestep_embedding( + t, self.frequency_embedding_size, self.max_period + ).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fa.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fa.py new file mode 100644 index 0000000000..510ba2cfc8 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fa.py @@ -0,0 +1,72 @@ +import torch +import torch_npu +import os +ALGO = int(os.getenv('ALGO')) +if ALGO == 2: + from mindiesd import attention_forward + +MAX_TOKEN = 2147483647 + +def attention_ATB(query_layer, key_layer, value_layer, self_attention): + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + query_layer_list = query_layer.split(1, dim=1) + key_layer_list = key_layer.split(1, dim=1) + value_layer_list = value_layer.split(1, dim=1) + output = [] + for_loop = query_layer.shape[1] + for i in range(for_loop): + seqlen = torch.tensor([[query_layer.shape[2]], [key_layer.shape[2]]], dtype=torch.int32) + intensors = 
[query_layer_list[i], key_layer_list[i], value_layer_list[i], seqlen] + out = self_attention.forward(intensors)[0] + output.append(out) + out_concat = torch.cat(output, dim=1) + out = out_concat.transpose(1, 2) + return out + +def attention_FAScore(query_layer, key_layer, value_layer, scale): + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + query_layer_list = query_layer.split(1, dim=1) + key_layer_list = key_layer.split(1, dim=1) + value_layer_list = value_layer.split(1, dim=1) + output = [] + for_loop = query_layer.shape[1] + for i in range(for_loop): + out = torch_npu.npu_fusion_attention( + query_layer_list[i], + key_layer_list[i], + value_layer_list[i], + head_num=1, + input_layout="BNSD", + scale=scale, + pre_tockens=MAX_TOKEN, + next_tockens=MAX_TOKEN + )[0] + output.append(out) + out_concat = torch.cat(output, dim=1) + out = out_concat.transpose(1, 2) + return out + +def attention_LA(query_layer, key_layer, value_layer, scale): + # BSND + query_layer_list = query_layer.split(1, dim=2) + key_layer_list = key_layer.split(1, dim=2) + value_layer_list = value_layer.split(1, dim=2) + output = [] + for_loop = query_layer.shape[2] + for i in range(for_loop): + out = attention_forward( + query_layer_list[i], + key_layer_list[i], + value_layer_list[i], + scale=scale, + opt_mode="manual", + op_type="ascend_laser_attention" + ) + output.append(out) + out_concat = torch.cat(output, dim=2) + return out_concat + diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fp8_optimization.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fp8_optimization.py new file mode 100644 index 0000000000..f02d87a091 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/fp8_optimization.py @@ -0,0 +1,101 @@ +import os + +import torch +import torch.nn as nn +from torch.nn import functional as F + +def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1): + _bits = torch.tensor(bits) + _mantissa_bit = torch.tensor(mantissa_bit) + _sign_bits = torch.tensor(sign_bits) + M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits) + E = _bits - _sign_bits - M + bias = 2 ** (E - 1) - 1 + mantissa = 1 + for i in range(mantissa_bit - 1): + mantissa += 1 / (2 ** (i+1)) + maxval = mantissa * 2 ** (2**E - 1 - bias) + return maxval + +def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1): + """ + Default is E4M3. 
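+    With the defaults (bits=8, mantissa_bit=3, sign_bits=1) this is 1 sign bit, 4 exponent bits and
+    3 mantissa bits. Returns the fake-quantized (quant-dequant) tensor together with the scales.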
+ """ + bits = torch.tensor(bits) + mantissa_bit = torch.tensor(mantissa_bit) + sign_bits = torch.tensor(sign_bits) + M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits) + E = bits - sign_bits - M + bias = 2 ** (E - 1) - 1 + mantissa = 1 + for i in range(mantissa_bit - 1): + mantissa += 1 / (2 ** (i+1)) + maxval = mantissa * 2 ** (2**E - 1 - bias) + minval = - maxval + minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval) + input_clamp = torch.min(torch.max(x, minval), maxval) + log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0) + log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype)) + # dequant + qdq_out = torch.round(input_clamp / log_scales) * log_scales + return qdq_out, log_scales + +def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1): + for i in range(len(x.shape) - 1): + scale = scale.unsqueeze(-1) + new_x = x / scale + quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits) + return quant_dequant_x, scale, log_scales + +def fp8_activation_dequant(qdq_out, scale, dtype): + qdq_out = qdq_out.type(dtype) + quant_dequant_x = qdq_out * scale.to(dtype) + return quant_dequant_x + +def fp8_linear_forward(cls, original_dtype, input): + weight_dtype = cls.weight.dtype + ##### + if cls.weight.dtype != torch.float8_e4m3fn: + maxval = get_fp_maxval() + scale = torch.max(torch.abs(cls.weight.flatten())) / maxval + linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale) + linear_weight = linear_weight.to(torch.float8_e4m3fn) + weight_dtype = linear_weight.dtype + else: + scale = cls.fp8_scale.to(cls.weight.device) + linear_weight = cls.weight + ##### + + if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0: + if True or len(input.shape) == 3: + cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype) + if cls.bias != None: + output = F.linear(input, cls_dequant, cls.bias) + else: + output = F.linear(input, cls_dequant) + return output + else: + return cls.original_forward(input.to(original_dtype)) + else: + return cls.original_forward(input) + +def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}): + setattr(module, "fp8_matmul_enabled", True) + + # loading fp8 mapping file + fp8_map_path = dit_weight_path.replace('.pt', '_map.pt') + if os.path.exists(fp8_map_path): + fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage) + else: + raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.") + + fp8_layers = [] + for key, layer in module.named_modules(): + if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key): + fp8_layers.append(key) + original_forward = layer.forward + layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn)) + setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype)) + setattr(layer, "original_forward", original_forward) + setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input)) + diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py new file mode 100644 index 0000000000..23e35db3fb --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/mlp_layers.py @@ -0,0 +1,118 @@ +# Modified from timm library: +# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13 + +from functools import partial + 
+import torch +import torch.nn as nn + +from .modulate_layers import modulate +from ..utils.helpers import to_2tuple + + +class MLP(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer( + in_channels, hidden_channels, bias=bias[0], **factory_kwargs + ) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_channels, **factory_kwargs) + if norm_layer is not None + else nn.Identity() + ) + self.fc2 = linear_layer( + hidden_channels, out_features, bias=bias[1], **factory_kwargs + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +# +class MLPEmbedder(nn.Module): + """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py""" + def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class FinalLayer(nn.Module): + """The final layer of DiT.""" + + def __init__( + self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Just use LayerNorm for the final layer + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + if isinstance(patch_size, int): + self.linear = nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + **factory_kwargs + ) + else: + self.linear = nn.Linear( + hidden_size, + patch_size[0] * patch_size[1] * patch_size[2] * out_channels, + bias=True, + ) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + # Here we don't distinguish between the modulate types. Just use the simple one. 
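+        # Both the output projection above and the adaLN modulation below are
+        # zero-initialized (adaLN-Zero style), so the final layer starts out
+        # predicting zeros and is shaped entirely by training.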
+ self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/models.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/models.py new file mode 100644 index 0000000000..af6e7e61bb --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/models.py @@ -0,0 +1,871 @@ +from typing import Any, List, Tuple, Optional, Union, Dict +from einops import rearrange + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models import ModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config + +from .activation_layers import get_activation_layer +from .norm_layers import get_norm_layer +from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection +from .attention import attention, parallel_attention, get_cu_seqlens +from .posemb_layers import apply_rotary_emb +from .mlp_layers import MLP, MLPEmbedder, FinalLayer +from .modulate_layers import ModulateDiT, modulate, apply_gate +from .token_refiner import SingleTokenRefiner + + +class MMDoubleStreamBlock(nn.Module): + """ + A multimodal dit block with seperate modulation for + text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.scale = head_dim ** -0.5 + + self.img_mod = ModulateDiT( + hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.img_norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.img_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.img_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.img_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.img_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.img_norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + self.img_mlp = MLP( + hidden_size, + mlp_hidden_dim, + act_layer=get_activation_layer(mlp_act_type), + bias=True, + **factory_kwargs, + ) + + self.txt_mod = ModulateDiT( + hidden_size, + factor=6, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.txt_norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.txt_attn_qkv = nn.Linear( + 
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + self.txt_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.txt_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.txt_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.txt_norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + self.txt_mlp = MLP( + hidden_size, + mlp_hidden_dim, + act_layer=get_activation_layer(mlp_act_type), + bias=True, + **factory_kwargs, + ) + self.hybrid_seq_parallel_attn = None + + # ATTENTION_CACHE PARAMETER + self.cache = None + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def double_forward( + self, + img, txt, + img_mod1_shift, + img_mod1_scale, + txt_mod1_shift, + txt_mod1_scale, + freqs_cis, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv + ): + # Prepare image for attention. + img_modulated = self.img_norm1(img) + img_modulated = modulate( + img_modulated, shift=img_mod1_shift, scale=img_mod1_scale + ) + img_qkv = self.img_attn_qkv(img_modulated) + img_q, img_k, img_v = rearrange( + img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num + ) + # Apply QK-Norm if needed + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + + # Apply RoPE if needed. + if freqs_cis is not None: + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" + img_q, img_k = img_qq, img_kk + + # Prepare txt for attention. + txt_modulated = self.txt_norm1(txt) + txt_modulated = modulate( + txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale + ) + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_q, txt_k, txt_v = rearrange( + txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num + ) + # Apply QK-Norm if needed. + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) + + # Run actual attention. 
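+        # Image and text tokens are concatenated along the sequence dimension so a
+        # single joint attention covers both modalities; forward() splits the result
+        # back into image and text parts at img.shape[1].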
+ q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + v = torch.cat((img_v, txt_v), dim=1) + assert ( + cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1 + ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}" + + # attention computation start + if not self.hybrid_seq_parallel_attn: + attn = attention( + q, + k, + v, + mode="torch", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=img_k.shape[0], + ) + else: + attn = parallel_attention( + self.hybrid_seq_parallel_attn, + q, + k, + v, + img_q_len=img_q.shape[1], + img_kv_len=img_k.shape[1], + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + scale=self.scale + ) + + return attn + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + freqs_cis: tuple = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(vec).chunk(6, dim=-1) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(vec).chunk(6, dim=-1) + + attn = self.cache.apply(self.double_forward, + img=img, txt=txt, + img_mod1_shift=img_mod1_shift, + img_mod1_scale=img_mod1_scale, + txt_mod1_shift=txt_mod1_shift, + txt_mod1_scale=txt_mod1_scale, + freqs_cis=freqs_cis, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv) + + # attention computation end + + img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] + + # Calculate the img bloks. + img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + img = img + apply_gate( + self.img_mlp( + modulate( + self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale + ) + ), + gate=img_mod2_gate, + ) + + # Calculate the txt bloks. + txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + txt = txt + apply_gate( + self.txt_mlp( + modulate( + self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale + ) + ), + gate=txt_mod2_gate, + ) + + return img, txt + + +class MMSingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206 + (Flux.1): https://github.com/black-forest-labs/flux + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + self.scale = qk_scale or head_dim ** -0.5 + + # qkv and mlp_in + self.linear1 = nn.Linear( + hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs + ) + # proj and mlp_out + self.linear2 = nn.Linear( + hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs + ) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + + self.pre_norm = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs + ) + + self.mlp_act = get_activation_layer(mlp_act_type)() + self.modulation = ModulateDiT( + hidden_size, + factor=3, + act_layer=get_activation_layer("silu"), + **factory_kwargs, + ) + self.hybrid_seq_parallel_attn = None + + # ATTENTION_CACHE PARAMETER + self.cache = None + + def enable_deterministic(self): + self.deterministic = True + + def disable_deterministic(self): + self.deterministic = False + + def single_forward( + self, + qkv, + freqs_cis, + txt_len, + x, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv + ): + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + # Apply RoPE if needed. + if freqs_cis is not None: + img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}" + img_q, img_k = img_qq, img_kk + q = torch.cat((img_q, txt_q), dim=1) + k = torch.cat((img_k, txt_k), dim=1) + + # Compute attention. 
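+        # cu_seqlens_q/kv encode per-sample cumulative sequence lengths for the
+        # variable-length attention kernel; the assert below checks the expected
+        # 2 * batch_size + 1 entries.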
+ assert ( + cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1 + ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}" + + # attention computation start + if not self.hybrid_seq_parallel_attn: + attn = attention( + q, + k, + v, + mode="torch", + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + batch_size=x.shape[0], + ) + else: + attn = parallel_attention( + self.hybrid_seq_parallel_attn, + q, + k, + v, + img_q_len=img_q.shape[1], + img_kv_len=img_k.shape[1], + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + scale=self.scale + ) + # attention computation end + return attn + + def forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + txt_len: int, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) + qkv, mlp = torch.split( + self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + attn = self.cache.apply(self.single_forward, + qkv=qkv, + freqs_cis=freqs_cis, + txt_len=txt_len, + x=x, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv + ) + + # Compute activation in mlp stream, cat again and run second linear layer. + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + apply_gate(output, gate=mod_gate) + + +class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): + """ + HunyuanVideo Transformer backbone + + Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline. + + Reference: + [1] Flux.1: https://github.com/black-forest-labs/flux + [2] MMDiT: http://arxiv.org/abs/2403.03206 + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. + patch_size: list + The size of the patch. + in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + hidden_size: int + The hidden size of the transformer backbone. + heads_num: int + The number of attention heads. + mlp_width_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + mlp_act_type: str + The activation function of the MLP in the transformer block. + depth_double_blocks: int + The number of transformer blocks in the double blocks. + depth_single_blocks: int + The number of transformer blocks in the single blocks. + rope_dim_list: list + The dimension of the rotary embedding for t, h, w. + qkv_bias: bool + Whether to use bias in the qkv linear layer. + qk_norm: bool + Whether to use qk norm. + qk_norm_type: str + The type of qk norm. + guidance_embed: bool + Whether to use guidance embedding for distillation. + text_projection: str + The type of the text projection, default is single_refiner. + use_attention_mask: bool + Whether to use attention mask for text encoder. + dtype: torch.dtype + The dtype of the model. + device: torch.device + The device of the model. + """ + + @register_to_config + def __init__( + self, + args: Any, + patch_size: list = [1, 2, 2], + in_channels: int = 4, # Should be VAE.config.latent_channels. 
+ out_channels: int = None, + hidden_size: int = 3072, + heads_num: int = 24, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + mm_double_blocks_depth: int = 20, + mm_single_blocks_depth: int = 40, + rope_dim_list: List[int] = [16, 56, 56], + qkv_bias: bool = True, + qk_norm: bool = True, + qk_norm_type: str = "rms", + guidance_embed: bool = False, # For modulation. + text_projection: str = "single_refiner", + use_attention_mask: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.unpatchify_channels = self.out_channels + self.guidance_embed = guidance_embed + self.rope_dim_list = rope_dim_list + + # Text projection. Default to linear projection. + # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831 + self.use_attention_mask = use_attention_mask + self.text_projection = text_projection + + self.text_states_dim = args.text_states_dim + self.text_states_dim_2 = args.text_states_dim_2 + + if hidden_size % heads_num != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}" + ) + pe_dim = hidden_size // heads_num + if sum(rope_dim_list) != pe_dim: + raise ValueError( + f"Got {rope_dim_list} but expected positional dim {pe_dim}" + ) + self.hidden_size = hidden_size + self.heads_num = heads_num + + # image projection + self.img_in = PatchEmbed( + self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs + ) + + # text projection + if self.text_projection == "linear": + self.txt_in = TextProjection( + self.text_states_dim, + self.hidden_size, + get_activation_layer("silu"), + **factory_kwargs, + ) + elif self.text_projection == "single_refiner": + self.txt_in = SingleTokenRefiner( + self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs + ) + else: + raise NotImplementedError( + f"Unsupported text_projection: {self.text_projection}" + ) + + # time modulation + self.time_in = TimestepEmbedder( + self.hidden_size, get_activation_layer("silu"), **factory_kwargs + ) + + # text modulation + self.vector_in = MLPEmbedder( + self.text_states_dim_2, self.hidden_size, **factory_kwargs + ) + + # guidance modulation + self.guidance_in = ( + TimestepEmbedder( + self.hidden_size, get_activation_layer("silu"), **factory_kwargs + ) + if guidance_embed + else None + ) + + # double blocks + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + # single blocks + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(mm_single_blocks_depth) + ] + ) + + self.final_layer = FinalLayer( + self.hidden_size, + self.patch_size, + self.out_channels, + get_activation_layer("silu"), + **factory_kwargs, + ) + + self.use_cache = args.use_cache + self.use_cache_double = args.use_cache_double + self.cache_single = None + self.cache_dual = None + + def enable_deterministic(self): + for block in 
self.double_blocks: + block.enable_deterministic() + for block in self.single_blocks: + block.enable_deterministic() + + def disable_deterministic(self): + for block in self.double_blocks: + block.disable_deterministic() + for block in self.single_blocks: + block.disable_deterministic() + + def seq_forward( + self, + img, + txt, + text_mask, + freqs_cos, + freqs_sin, + vec, + t_idx + ): + txt_seq_len = txt.shape[1] + img_seq_len = img.shape[1] + + # Compute cu_squlens and max_seqlen for flash attention + cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + max_seqlen_kv = max_seqlen_q + + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + # --------------------- Pass through DiT blocks ------------------------ + for i, block in enumerate(self.double_blocks): + double_block_args = [ + img, + txt, + vec, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + freqs_cis + ] + if not self.use_cache_double: + img, txt = block(*double_block_args) + else: + img, txt = self.cache_dual.apply(block, + hidden_states=img, + encoder_hidden_states=txt, + vec=vec, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + freqs_cis=freqs_cis, + ) + + # Merge txt and img to pass through single stream blocks. + x = torch.cat((img, txt), 1) + if len(self.single_blocks) > 0: + for i, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + txt_seq_len, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + (freqs_cos, freqs_sin) + ] + if not self.use_cache: + x = block(*single_block_args) + else: + x = self.cache_single.apply(block, + hidden_states=x, + vec=vec, + txt_len=txt_seq_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + freqs_cis=freqs_cis, + ) + + img = x[:, :img_seq_len, ...] + + # ---------------------------- Final layer ------------------------------ + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + + return img + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, # Should be in range(0, 1000). + text_states: torch.Tensor = None, + text_mask: torch.Tensor = None, # Now we don't use it. + text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation. + freqs_cos: Optional[torch.Tensor] = None, + freqs_sin: Optional[torch.Tensor] = None, + guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. + return_dict: bool = True, + t_idx: torch.Tensor = 0 + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + out = {} + img = x + txt = text_states + _, _, ot, oh, ow = x.shape + tt, th, tw = ( + ot // self.patch_size[0], + oh // self.patch_size[1], + ow // self.patch_size[2], + ) + + # Prepare modulation vectors. + vec = self.time_in(t) + + # text modulation + vec = vec + self.vector_in(text_states_2) + + # guidance modulation + if self.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + + # our timestep_embedding is merged into guidance_in(TimestepEmbedder) + vec = vec + self.guidance_in(guidance) + + # Embed image and text. 
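+        # img is patch-embedded below; txt is projected either linearly or through
+        # the SingleTokenRefiner, depending on self.text_projection.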
+ img = self.img_in(img) + if self.text_projection == "linear": + txt = self.txt_in(txt) + elif self.text_projection == "single_refiner": + txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) + else: + raise NotImplementedError( + f"Unsupported text_projection: {self.text_projection}" + ) + + img = self.seq_forward( + img, + txt, + text_mask, + freqs_cos, + freqs_sin, + vec, + t_idx + ) + + img = self.unpatchify(img, tt, th, tw) + if return_dict: + out["x"] = img + return out + return img + + def unpatchify(self, x, t, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + pt, ph, pw = self.patch_size + assert t * h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + + return imgs + + def params_count(self): + counts = { + "double": sum( + [ + sum(p.numel() for p in block.img_attn_qkv.parameters()) + + sum(p.numel() for p in block.img_attn_proj.parameters()) + + sum(p.numel() for p in block.img_mlp.parameters()) + + sum(p.numel() for p in block.txt_attn_qkv.parameters()) + + sum(p.numel() for p in block.txt_attn_proj.parameters()) + + sum(p.numel() for p in block.txt_mlp.parameters()) + for block in self.double_blocks + ] + ), + "single": sum( + [ + sum(p.numel() for p in block.linear1.parameters()) + + sum(p.numel() for p in block.linear2.parameters()) + for block in self.single_blocks + ] + ), + "total": sum(p.numel() for p in self.parameters()), + } + counts["attn+mlp"] = counts["double"] + counts["single"] + return counts + + +################################################################################# +# HunyuanVideo Configs # +################################################################################# + +HUNYUAN_VIDEO_CONFIG = { + "HYVideo-T/2": { + "mm_double_blocks_depth": 20, + "mm_single_blocks_depth": 40, + "rope_dim_list": [16, 56, 56], + "hidden_size": 3072, + "heads_num": 24, + "mlp_width_ratio": 4, + }, + "HYVideo-T/2-cfgdistill": { + "mm_double_blocks_depth": 20, + "mm_single_blocks_depth": 40, + "rope_dim_list": [16, 56, 56], + "hidden_size": 3072, + "heads_num": 24, + "mlp_width_ratio": 4, + "guidance_embed": True, + }, +} \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/modulate_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/modulate_layers.py new file mode 100644 index 0000000000..65f8c0d939 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/modulate_layers.py @@ -0,0 +1,76 @@ +from typing import Callable + +import torch +import torch.nn as nn + + +class ModulateDiT(nn.Module): + """Modulation layer for DiT.""" + def __init__( + self, + hidden_size: int, + factor: int, + act_layer: Callable, + dtype=None, + device=None, + ): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.act = act_layer() + self.linear = nn.Linear( + hidden_size, factor * hidden_size, bias=True, **factory_kwargs + ) + # Zero-initialize the modulation + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + + +def modulate(x, shift=None, scale=None): + """modulate by shift and scale + + Args: + x (torch.Tensor): input tensor. + shift (torch.Tensor, optional): shift tensor. Defaults to None. + scale (torch.Tensor, optional): scale tensor. Defaults to None. 
+ + Returns: + torch.Tensor: the output tensor after modulate. + """ + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) + elif scale is None: + return x + shift.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False): + """AI is creating summary for apply_gate + + Args: + x (torch.Tensor): input tensor. + gate (torch.Tensor, optional): gate tensor. Defaults to None. + tanh (bool, optional): whether to use tanh function. Defaults to False. + + Returns: + torch.Tensor: the output tensor after apply gate. + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +def ckpt_wrapper(module): + def ckpt_forward(*inputs): + outputs = module(*inputs) + return outputs + + return ckpt_forward \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/new_parallel.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/new_parallel.py new file mode 100644 index 0000000000..96f5c794af --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/new_parallel.py @@ -0,0 +1,219 @@ +#!/bin/bash +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
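+# Sequence-parallel attention helpers for multi-NPU inference: split_sequence()
+# shards hidden states along the sequence dimension across ranks, all_to_all_v1/v2
+# re-shard tensors so each rank attends over the full sequence with a subset of
+# heads, and gather_sequence() restores the full sequence afterwards. The
+# module-level SEQ list records per-rank lengths when the sequence does not divide
+# evenly across ranks.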
+ +import torch +import torch_npu +import torch.distributed as dist +import os +algo = int(os.getenv('ALGO')) +if algo == 2: + from mindiesd import attention_forward +from .fa import attention_ATB, attention_FAScore, attention_LA + +MAX_TOKEN = 2147483647 + +def all_to_all_v1( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + joint_tensor_key: torch.Tensor, + joint_tensor_value: torch.Tensor, + scatter_dim: int, + gather_dim: int, + **kwargs +): + scale = kwargs.get("scale", 1.0) + algo = kwargs.get("algo", 0) + self_attention = kwargs.get("self_attention", None) + world_size = dist.get_world_size() + rank = dist.get_rank() + len_joint = joint_tensor_key.shape[1] + global SEQ + b, s, n, d = k.shape + each_n = int(n // world_size) + + output_q_list = [torch.empty([b, s_i+len_joint, each_n, d], device=q.device, dtype=q.dtype) for s_i in SEQ] + output_k_list = [torch.empty([b, s_i, each_n, d], device=k.device, dtype=k.dtype) for s_i in SEQ] + output_v_list = [torch.empty([b, s_i, each_n, d], device=v.device, dtype=v.dtype) for s_i in SEQ] + q_list = [t.contiguous() for t in torch.tensor_split(q, world_size, scatter_dim)] + k_list = [t.contiguous() for t in torch.tensor_split(k, world_size, scatter_dim)] + v_list = [t.contiguous() for t in torch.tensor_split(v, world_size, scatter_dim)] + dist.all_to_all(output_q_list, q_list) + dist.all_to_all(output_k_list, k_list) + dist.all_to_all(output_v_list, v_list) + + query = torch.cat(output_q_list, dim=gather_dim).contiguous() + key = torch.cat(output_k_list, dim=gather_dim).contiguous() + value = torch.cat(output_v_list, dim=gather_dim).contiguous() + key = torch.cat([key, joint_tensor_key], dim=1) + value = torch.cat([value, joint_tensor_value], dim=1) + + if algo == 0: + output = attention_FAScore(query, key, value, scale) + elif algo == 1: + output = attention_ATB(query, key, value, self_attention) + elif algo == 2: + output = attention_LA(query, key, value, scale) + + output_shape = [b, SEQ[0]+len_joint, each_n, d] if rank < world_size - 1 else [b, SEQ[-1]+len_joint, each_n, d] + output_list = [torch.empty(output_shape, device=output.device, dtype=output.dtype) for _ in SEQ] + + SEQ_joint = [i + len_joint for i in SEQ] + output_con = [chunk.contiguous() for chunk in torch.split(output, SEQ_joint, dim=gather_dim)] + + dist.all_to_all(output_list, output_con) + output = torch.cat(output_list, dim=scatter_dim) + + return output + +def all_to_all_v2( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + joint_tensor_key: torch.Tensor, + joint_tensor_value: torch.Tensor, + ulysses_pg: None, + ring_pg: None, + scatter_dim: int, + gather_dim: int, + **kwargs +): + ulysses_ranks_even = [0, 2, 4, 6, 8, 10, 12, 14] + ulysses_ranks_odd = [1, 3, 5, 7, 9, 11, 13, 15] + scale = kwargs.get("scale", 1.0) + algo = kwargs.get("algo", 0) + self_attention = kwargs.get("self_attention", None) + ulysses_world_size = dist.get_world_size(ulysses_pg) + rank = dist.get_rank() + + b, s, n, d = k.shape + each_n = int(n // ulysses_world_size) + len_joint = joint_tensor_key.shape[1] + + target_ranks = ulysses_ranks_even if rank in ulysses_ranks_even else ulysses_ranks_odd + + output_q_list = [torch.empty([b, SEQ[rank_idx]+len_joint, each_n, d], device=q.device, dtype=q.dtype) for rank_idx in target_ranks] + output_k_list = [torch.empty([b, SEQ[rank_idx], each_n, d], device=k.device, dtype=k.dtype) for rank_idx in target_ranks] + output_v_list = [torch.empty([b, SEQ[rank_idx], each_n, d], device=v.device, dtype=v.dtype) for rank_idx in target_ranks] + q_list = 
[t.contiguous() for t in torch.tensor_split(q, ulysses_world_size, scatter_dim)] + k_list = [t.contiguous() for t in torch.tensor_split(k, ulysses_world_size, scatter_dim)] + v_list = [t.contiguous() for t in torch.tensor_split(v, ulysses_world_size, scatter_dim)] + dist.all_to_all(output_q_list, q_list, group=ulysses_pg) + dist.all_to_all(output_k_list, k_list, group=ulysses_pg) + dist.all_to_all(output_v_list, v_list, group=ulysses_pg) + + query_layer = torch.cat(output_q_list, dim=gather_dim).contiguous() + key = torch.cat(output_k_list, dim=gather_dim).contiguous() + value = torch.cat(output_v_list, dim=gather_dim).contiguous() + + if rank in ulysses_ranks_odd: + b, s, n, d = key.shape + pad_s = SEQ[0] * 8 - s + padding = torch.zeros([b, pad_s, n, d], dtype=key.dtype, device=key.device) + key = torch.cat([key, padding], dim=gather_dim) + value = torch.cat([value, padding], dim=gather_dim) + b, s, n, d = key.shape + k_full = torch.empty([2, b, s, n, d], dtype=key.dtype, device=key.device) + v_full = torch.empty([2, b, s, n, d], dtype=value.dtype, device=value.device) + dist.all_gather_into_tensor(k_full, key, group=ring_pg) + dist.all_gather_into_tensor(v_full, value, group=ring_pg) + k_full = k_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + v_full = v_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d) + key = k_full[:, :sum(SEQ), :, :] + value = v_full[:, :sum(SEQ), :, :] + + key_layer = torch.cat([key, joint_tensor_key], dim=gather_dim) + value_layer = torch.cat([value, joint_tensor_value], dim=gather_dim) + if algo == 1: + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + seqlen = torch.tensor([[query_layer.shape[2]], [key_layer.shape[2]]], dtype=torch.int32) + intensors = [query_layer, key_layer, value_layer, seqlen] + out = self_attention.forward(intensors)[0] + out = out.transpose(1, 2) + elif algo == 0: + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + out = torch_npu.npu_fusion_attention( + query_layer, + key_layer, + value_layer, + head_num=query_layer.shape[1], + input_layout="BNSD", + scale=scale, + pre_tockens=MAX_TOKEN, + next_tockens=MAX_TOKEN + )[0] + out = out.transpose(1, 2) + elif algo == 2: + out = attention_forward( + query_layer, + key_layer, + value_layer, + scale=scale, + opt_mode="manual", + op_type="ascend_laser_attention" + ) + output_shape = [b, SEQ[rank]+len_joint, each_n, d] + output_list = [torch.empty(output_shape, device=out.device, dtype=out.dtype) for _ in ulysses_ranks_even] + + if rank in ulysses_ranks_even: + SEQ_joint = [SEQ[rank]+len_joint for rank in ulysses_ranks_even] + else: + SEQ_joint = [SEQ[rank]+len_joint for rank in ulysses_ranks_odd] + + output_con = [chunk.contiguous() for chunk in torch.split(out, SEQ_joint, dim=gather_dim)] + dist.all_to_all(output_list, output_con, group=ulysses_pg) + output = torch.cat(output_list, dim=scatter_dim) + + return output + +SEQ = None + +def split_sequence(input_, dim=1): + world_size = dist.get_world_size() + rank = dist.get_rank() + if world_size == 1: + return input_ + + tensor_list = torch.chunk(input_, world_size, dim=dim) + global SEQ + if not SEQ and input_.shape[dim] % world_size != 0: + SEQ = [None] * world_size + for i in range(world_size): + SEQ[i] = tensor_list[i].shape[1] + output = tensor_list[rank].contiguous() + return output + +def gather_sequence(input_, dim=1): + input_ = input_.contiguous() + world_size = dist.get_world_size() + 
if world_size == 1: + return input_ + + global SEQ + if not SEQ: + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + b, s, d = input_.shape + tensor_list = [torch.empty([b, s_i, d], device=input_.device, dtype=input_.dtype) for s_i in SEQ] + dist.all_gather(tensor_list, input_) + + output = torch.cat(tensor_list, dim=dim) + SEQ = None + + return output \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/norm_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/norm_layers.py new file mode 100644 index 0000000000..8fd19b42d4 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/norm_layers.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +import torch_npu + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0] + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/posemb_layers.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/posemb_layers.py new file mode 100644 index 0000000000..cd88548189 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/posemb_layers.py @@ -0,0 +1,310 @@ +import torch +from typing import Union, Tuple, List +from mindiesd import rotary_position_embedding + + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, *args, dim=2): + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] 
+ """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +################################################################################# +# Rotary Positional Embedding Functions # +################################################################################# +# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80 + + +def reshape_for_broadcast( + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + x: torch.Tensor, + head_first=False, +): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Notes: + When using FlashMHAModified, head_first should be False. + When using Attention, head_first should be True. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
+ """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [ + d if i == ndim - 2 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + else: + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == ( + x.shape[-2], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [ + d if i == ndim - 2 or i == ndim - 1 else 1 + for i, d in enumerate(x.shape) + ] + else: + assert freqs_cis.shape == ( + x.shape[1], + x.shape[-1], + ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = ( + x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + ) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + # real * cos - imag * sin + # imag * cos + real * sin + xq_out = rotary_position_embedding(xq, cos, sin, rotated_mode="rotated_interleaved", head_first=head_first) + xk_out = rotary_position_embedding(xk, cos, sin, rotated_mode="rotated_interleaved", head_first=head_first) + else: + # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex) + xq_ = torch.view_as_complex( + xq.float().reshape(*xq.shape[:-1], -1, 2) + ) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to( + xq.device + ) # [S, D//2] --> [1, S, 1, D//2] + # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin) + # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + xk_ = torch.view_as_complex( + xk.float().reshape(*xk.shape[:-1], -1, 2) + ) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + theta_rescale_factor: Union[float, List[float]] = 1.0, + interpolation_factor: Union[float, List[float]] = 1.0, +): + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. + Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real + part and an imaginary part separately. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. 
+ + Returns: + pos_embed (torch.Tensor): [HW, D/2] + """ + + grid = get_meshgrid_nd( + start, *args, dim=len(rope_dim_list) + ) # [3, W, H, D] / [2, W, H] + + if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): + theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) + assert len(theta_rescale_factor) == len( + rope_dim_list + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): + interpolation_factor = [interpolation_factor] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) + assert len(interpolation_factor) == len( + rope_dim_list + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # use 1/ndim of dimensions to encode grid_axis + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + if use_real: + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + else: + emb = torch.cat(embs, dim=1) # (WHD, D/2) + return emb + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[torch.FloatTensor, int], + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + use_real (bool, optional): If True, return real part and imaginary part separately. + Otherwise, return complex numbers. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + + Returns: + freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. 
[S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) # [D/2] + # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + else: + freqs_cis = torch.polar( + torch.ones_like(freqs), freqs + ) # complex64 # [S, D/2] + return freqs_cis \ No newline at end of file diff --git a/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/token_refiner.py b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/token_refiner.py new file mode 100644 index 0000000000..72749316f8 --- /dev/null +++ b/MindIE/MultiModal/HunyuanVideo/hyvideo/modules/token_refiner.py @@ -0,0 +1,235 @@ +from typing import Optional + +from einops import rearrange +import torch +import torch.nn as nn + +from .activation_layers import get_activation_layer +from .attention import attention +from .norm_layers import get_norm_layer +from .embed_layers import TimestepEmbedder, TextProjection +from .mlp_layers import MLP +from .modulate_layers import modulate, apply_gate + + +class IndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + hidden_size, + heads_num, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + self.self_attn_qkv = nn.Linear( + hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs + ) + qk_norm_layer = get_norm_layer(qk_norm_type) + self.self_attn_q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) + if qk_norm + else nn.Identity() + ) + self.self_attn_proj = nn.Linear( + hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs + ) + + self.norm2 = nn.LayerNorm( + hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs + ) + act_layer = get_activation_layer(act_type) + self.mlp = MLP( + in_channels=hidden_size, + hidden_channels=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop_rate, + **factory_kwargs, + ) + + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), + ) + # Zero-initialize the modulation + nn.init.zeros_(self.adaLN_modulation[1].weight) + nn.init.zeros_(self.adaLN_modulation[1].bias) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # timestep_aware_representations + context_aware_representations + attn_mask: torch.Tensor = None, + ): + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + + norm_x = self.norm1(x) 
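+        # Fused QKV projection; rearranged below into per-head query/key/value.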
+ qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + # Apply QK-Norm if needed + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + + # Self-Attention + attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + + # FFN Layer + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + + return x + + +class IndividualTokenRefiner(nn.Module): + def __init__( + self, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + for _ in range(depth) + ] + ) + + def forward( + self, + x: torch.Tensor, + c: torch.LongTensor, + mask: Optional[torch.Tensor] = None, + ): + self_attn_mask = None + if mask is not None: + batch_size = mask.shape[0] + seq_len = mask.shape[1] + mask = mask.to(x.device) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( + 1, 1, seq_len, 1 + ) + # batch_size x 1 x seq_len x seq_len + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + # avoids self-attention weight being NaN for padding tokens + self_attn_mask[:, :, :, 0] = True + + for block in self.blocks: + x = block(x, c, self_attn_mask) + return x + + +class SingleTokenRefiner(nn.Module): + """ + A single token refiner block for llm text embedding refine. + """ + def __init__( + self, + in_channels, + hidden_size, + heads_num, + depth, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + attn_mode: str = "torch", + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.attn_mode = attn_mode + assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." 
+ + self.input_embedder = nn.Linear( + in_channels, hidden_size, bias=True, **factory_kwargs + ) + + act_layer = get_activation_layer(act_type) + # Build timestep embedding layer + self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) + # Build context embedding layer + self.c_embedder = TextProjection( + in_channels, hidden_size, act_layer, **factory_kwargs + ) + + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + **factory_kwargs, + ) + + def forward( + self, + x: torch.Tensor, + t: torch.LongTensor, + mask: Optional[torch.LongTensor] = None, + ): + timestep_aware_representations = self.t_embedder(t) + + if mask is None: + context_aware_representations = x.mean(dim=1) + else: + mask_float = mask.float().unsqueeze(-1) # [b, s1, 1] + context_aware_representations = (x * mask_float).sum( + dim=1 + ) / mask_float.sum(dim=1) + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + + x = self.input_embedder(x) + + x = self.individual_token_refiner(x, c, mask) + + return x \ No newline at end of file -- Gitee
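Not part of the patch: a minimal sketch of how the RoPE helpers added in hyvideo/modules/posemb_layers.py fit together. The latent grid size and head count below are illustrative assumptions, and the real-valued path of apply_rotary_emb calls the mindiesd rotary kernel, so this only runs in an NPU environment with mindiesd installed.

    import torch
    from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, apply_rotary_emb

    rope_dim_list = [16, 56, 56]          # per-axis rope dims; sum must equal head_dim
    t, h, w = 4, 8, 8                     # assumed (frames, height, width) of the latent grid
    cos, sin = get_nd_rotary_pos_embed(rope_dim_list, (t, h, w), use_real=True)  # each [t*h*w, 128]

    batch, heads, head_dim = 1, 24, sum(rope_dim_list)
    q = torch.randn(batch, t * h * w, heads, head_dim)
    k = torch.randn(batch, t * h * w, heads, head_dim)
    q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)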