From aedd81b33241397abd1e6a9153512d17921df8e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=84=A2=E5=97=A3=E8=8A=BE?= Date: Fri, 13 Jun 2025 10:10:36 +0800 Subject: [PATCH 1/3] fix bug --- MindIE/MultiModal/StepVideo_TI2V/run_parallel.py | 14 +++++++++----- MindIE/MultiModal/StepVideo_TI2V/setup.py | 5 +++-- .../stepvideo/diffusion/video_pipeline.py | 5 ++--- .../StepVideo_TI2V/stepvideo/modules/model.py | 13 +++++-------- .../StepVideo_TI2V/stepvideo/modules/rope.py | 5 ++--- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/MindIE/MultiModal/StepVideo_TI2V/run_parallel.py b/MindIE/MultiModal/StepVideo_TI2V/run_parallel.py index cbdbad1493..879655293c 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/run_parallel.py +++ b/MindIE/MultiModal/StepVideo_TI2V/run_parallel.py @@ -1,3 +1,4 @@ +import os from stepvideo.diffusion.video_pipeline import StepVideoPipeline import torch.distributed as dist import torch @@ -30,14 +31,16 @@ if __name__ == "__main__": pipeline.transformer = pipeline.transformer.to(device) if args.use_dit_cache: - from mindiesd.layers.cache_mgr import CacheManager, DitCacheConfig - config = DitCacheConfig(step_start=6, step_interval=2, block_start=11, num_blocks=31) - cache = CacheManager(config) + from mindiesd import CacheAgent, CacheConfig + config = CacheConfig(method='dit_block_cache', steps_count=10, blocks_count=10, step_start=6, step_interval=2, block_start=7) + cache = CacheAgent(config) pipeline.transformer.cache = cache def patch_encode_prompt(): enable_llm_tensor_model_parallel() - caption_pipeline = CaptionPipeline(llm_dir="/home/data/stepvideo-ti2v/step_llm", clip_dir="/home/data/stepvideo-ti2v/hunyuan_clip", device='cpu') + llm_path = os.path.join(args.model_dir, "step_llm") + clip_path = os.path.join(args.model_dir, "hunyuan_clip") + caption_pipeline = CaptionPipeline(llm_dir=llm_path, clip_dir=clip_path, device='cpu') if args.tensor_parallel_degree > 1: llm_tp_applicator = 
TensorParallelApplicator(get_llm_tensor_model_parallel_world_size(), get_llm_tensor_model_parallel_rank(), tp_group=get_llm_tensor_model_parallel_group()) llm_tp_applicator.apply_to_llm_model(caption_pipeline.text_encoder) @@ -63,7 +66,8 @@ if __name__ == "__main__": pipeline.encode_prompt = encode_prompt def patch_vae(): - vae_pipeline = StepVaePipeline(vae_dir="/home/data/stepvideo-ti2v/vae", device='cpu') + vae_path = os.path.join(args.model_dir, "vae") + vae_pipeline = StepVaePipeline(vae_dir=vae_path, device='cpu') vae_pipeline.vae = vae_pipeline.vae.to(device) def encode_vae(img): diff --git a/MindIE/MultiModal/StepVideo_TI2V/setup.py b/MindIE/MultiModal/StepVideo_TI2V/setup.py index e66d19f6d1..f30434a11e 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/setup.py +++ b/MindIE/MultiModal/StepVideo_TI2V/setup.py @@ -19,16 +19,17 @@ if __name__ == "__main__": "diffusers>=0.31.0", "sentencepiece>=0.1.99", "imageio>=2.37.0", + "imageio-ffmpeg", "optimus==2.1", "numpy", + "ninja", "einops", "aiohttp", - "asyncio", "flask", "flask_restful", "ffmpeg-python", "requests", - "yunchang", + "yunchang==0.6.0", ], url="", description="A 30B DiT based text to video and image generation model", diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py index 2210507510..cb6583184f 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py @@ -340,8 +340,7 @@ class StepVideoPipeline(DiffusionPipeline): encoder_hidden_states_2=prompt_embeds_2, condition_hidden_states=condition_hidden_states, motion_score=motion_score, - return_dict=False, - step_id=i, + return_dict=False ) # perform guidance if do_classifier_free_guidance: @@ -363,7 +362,7 @@ class StepVideoPipeline(DiffusionPipeline): video = self.decode_vae(latents) torch.npu.synchronize() print(f"VAE time: {time.time() - 
start_time1}s") - if num_inference_steps > 1 and (not torch.distributed.is_initialized() or int(torch.distributed.get_rank()) == 0): + if num_inference_steps > 2 and (not torch.distributed.is_initialized() or int(torch.distributed.get_rank()) == 0): video = self.video_processor.postprocess_video(video, output_file_name=output_file_name, output_type=output_type) else: diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py index bfd2052c15..7f839dc37e 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py @@ -126,15 +126,14 @@ class StepVideoModel(ModelMixin, ConfigMixin): timestep=None, rope_positions=None, attn_mask=None, - parallel=True, - step_id=0 + parallel=True ): for i, block in enumerate(self.transformer_blocks): if self.cache: - hidden_states = self.cache(block, step_id, i, - hidden_states, + hidden_states = self.cache(block, encoder_hidden_states, + hidden_states=hidden_states, timestep=timestep, attn_mask=attn_mask, rope_positions=rope_positions @@ -163,8 +162,7 @@ class StepVideoModel(ModelMixin, ConfigMixin): fps: torch.Tensor = None, condition_hidden_states: torch.Tensor = None, motion_score: torch.Tensor = None, - return_dict: bool = True, - step_id: int = 0, + return_dict: bool = True ): assert hidden_states.ndim == 5; # "hidden_states's shape should be (bsz, f, ch, h ,w)" @@ -200,8 +198,7 @@ class StepVideoModel(ModelMixin, ConfigMixin): timestep=timestep, rope_positions=[frame, height, width], attn_mask=attn_mask, - parallel=self.parallel, - step_id=step_id + parallel=self.parallel ) hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame) diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py index 717f081aa2..81d1432ba5 100644 --- 
a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py @@ -1,7 +1,6 @@ import torch from stepvideo.parallel import get_sequence_parallel_world_size, get_sequence_parallel_rank -import mindiesd -import mindiesd.layers.embedding +from mindiesd import rotary_position_embedding class RoPE1D: @@ -25,7 +24,7 @@ class RoPE1D: def apply_rope1d(self, tokens, pos1d, cos, sin): cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] - return torch.ops.mindie.rope_mindie_sd(tokens, cos, sin, mode=0) + return rotary_position_embedding(tokens, cos, sin, rotated_mode="rotated_half") def __call__(self, tokens, positions): """ -- Gitee From 52314d99f64ed43d3cd960dc38d159e705c7d1f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=84=A2=E5=97=A3=E8=8A=BE?= Date: Fri, 13 Jun 2025 10:21:15 +0800 Subject: [PATCH 2/3] fix bug --- MindIE/MultiModal/StepVideo_TI2V/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MultiModal/StepVideo_TI2V/README.md b/MindIE/MultiModal/StepVideo_TI2V/README.md index b38db51bfc..a7ee906c30 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/README.md +++ b/MindIE/MultiModal/StepVideo_TI2V/README.md @@ -102,7 +102,7 @@ ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 run_parall --ulysses_degree 1 \ --tensor_parallel_degree 8 \ --prompt="带翅膀的小老鼠先用爪子挠了挠脑袋,随后扑扇着翅膀飞了起来。" \ ---first_image_path './benchmark/Step-Video-TI2V-Eval/ti2v_eval_real/S/061.png' \ +--first_image_path '/path/to/your/first_image.png' \ --save_path './results' ``` @@ -143,7 +143,7 @@ ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 run_parall --ulysses_degree 1 \ --tensor_parallel_degree 8 \ --prompt="带翅膀的小老鼠先用爪子挠了挠脑袋,随后扑扇着翅膀飞了起来。" \ ---first_image_path './benchmark/Step-Video-TI2V-Eval/ti2v_eval_real/S/061.png' \ +--first_image_path
'/path/to/your/first_image.png' \ --save_path './results' \ --use_dit_cache ``` -- Gitee From 66332738d392412ea9d6cf5c0e6638bb0b78fae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=84=A2=E5=97=A3=E8=8A=BE?= Date: Fri, 13 Jun 2025 11:03:31 +0800 Subject: [PATCH 3/3] fix cleancode --- .../StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py | 2 +- MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py index cb6583184f..d26f8d00ef 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/diffusion/video_pipeline.py @@ -326,7 +326,7 @@ class StepVideoPipeline(DiffusionPipeline): # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(self.scheduler.timesteps): + for _, t in enumerate(self.scheduler.timesteps): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = latent_model_input.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py index 7f839dc37e..6929cc6c36 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/model.py @@ -129,7 +129,7 @@ class StepVideoModel(ModelMixin, ConfigMixin): parallel=True ): - for i, block in enumerate(self.transformer_blocks): + for _, block in enumerate(self.transformer_blocks): if self.cache: hidden_states = self.cache(block, encoder_hidden_states, -- Gitee