From 7cc7da3daf44fe99b559028b0d8b9207ef19ff92 Mon Sep 17 00:00:00 2001 From: mueuler Date: Thu, 5 Jun 2025 20:42:38 +0800 Subject: [PATCH] Add support for different model backends to adapt to the multi-model orchestration feature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../OpenSora-1.2/opensora/__init__.py | 7 +- .../opensora/pipeline/compile_pipe.py | 17 +- .../opensora/pipeline/open_sora_pipeline.py | 512 +++++++++++++++++- .../opensora/pipeline/pipeline_utils.py | 115 +++- .../OpenSora-1.2/opensora/stdit3/stdit3.py | 4 +- .../OpenSora-1.2/opensora/utils/__init__.py | 4 +- .../OpenSora-1.2/opensora/utils/utils.py | 443 ++++++++++++++- .../opensora/vae/VideoAutoencoder.py | 16 + 8 files changed, 1084 insertions(+), 34 deletions(-) diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/__init__.py b/MindIE/MultiModal/OpenSora-1.2/opensora/__init__.py index b0dcd39a69..16108376cd 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/__init__.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/__init__.py @@ -18,8 +18,11 @@ from .vae import (VideoAutoencoder, VideoAutoencoderConfig) from .pipeline import (OpenSoraPipeline12, compile_pipe) from .schedulers import RFlowScheduler from .stdit3 import (STDiT3Config, STDiT3) -from .utils import (set_random_seed, append_score_to_prompts, extract_prompts_loop, merge_prompt, prepare_multi_resolution_info, - split_prompt, is_npu_available, exists, default, Patchify, Depatchify) +from .utils import (set_random_seed, append_score_to_prompts, extract_prompts_loop, merge_prompt, prepare_multi_resolution_info, + split_prompt, is_npu_available, exists, default, append_generated, apply_mask_strategy, + video_encode, video_decode, is_main_process, is_distributed, refine_prompts_by_mindie, + create_shm, get_data_from_shm, get_shm_info, pad_and_flatten, + Patchify, Depatchify) from .layer import (approx_gelu, CaptionEmbedder, PatchEmbed3D, PositionEmbedding2D, SizeEmbedder, TimestepEmbedder, RotaryEmbedding, Mlp, AdaLayerNorm, PatchGroupNorm3d, GroupNorm3dAdapter, all_to_all_with_pad, get_spatial_pad, get_temporal_pad, set_spatial_pad, set_temporal_pad, split_sequence, gather_sequence, set_parallel_manager, get_sequence_parallel_group, diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/compile_pipe.py b/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/compile_pipe.py index 493af6ffd8..8a6dc28ab4 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/compile_pipe.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/compile_pipe.py @@ -17,17 +17,32 @@ import torch.nn as nn from ..utils import is_npu_available +CFG_MAX_STEP = 10000 -def compile_pipe(pipe, cfg=None): +def compile_pipe(pipe, cache_manager=None, + cfg_last_step: int = CFG_MAX_STEP): + if not isinstance(cfg_last_step, int): + raise TypeError(f"Expected int for cfg_last_step, but got {type(cfg_last_step).__name__}") + if is_npu_available(): device = 'npu' if hasattr(pipe, "text_encoder") and isinstance(pipe.text_encoder, nn.Module): pipe.text_encoder.to(device) + else: + raise TypeError("Please input valid pipeline") if hasattr(pipe, "transformer") and isinstance(pipe.transformer, nn.Module): pipe.transformer.to(device) if hasattr(pipe, "vae") and isinstance(pipe.vae, nn.Module): pipe.vae.to(device) + + if cache_manager is not None: + if not hasattr(cache_manager,
"use_cache"): + raise TypeError("Please input valid cache_manager") + pipe.transformer.cache_manager = cache_manager + + if cfg_last_step != CFG_MAX_STEP: + pipe.cfg_last_step = cfg_last_step return pipe else: raise RuntimeError("NPU is not available.") \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/open_sora_pipeline.py b/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/open_sora_pipeline.py index d9decdd169..f1a42ef52b 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/open_sora_pipeline.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/open_sora_pipeline.py @@ -14,8 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from multiprocessing import shared_memory from typing import Tuple, List - +import logging +from concurrent.futures.thread import ThreadPoolExecutor +import torch.distributed as dist +import numpy as np import torch import torch_npu from tqdm import tqdm @@ -23,21 +28,32 @@ from torch import Tensor from transformers import AutoTokenizer, T5EncoderModel -from .pipeline_utils import OpenSoraPipeline +from .pipeline_utils import OpenSoraPipeline, SampleParams from ..utils import ( set_random_seed, append_score_to_prompts, extract_prompts_loop, - merge_prompt, prepare_multi_resolution_info, split_prompt, is_npu_available) + merge_prompt, prepare_multi_resolution_info, split_prompt, is_npu_available, append_generated, apply_mask_strategy, + is_main_process, is_distributed, refine_prompts_by_mindie, video_decode, create_shm, get_data_from_shm, + get_shm_info, pad_and_flatten) from ..stdit3 import STDiT3 from ..vae import VideoAutoencoder from ..schedulers import RFlowScheduler +from ..utils.utils import AppendGeneratedParams + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) torch_npu.npu.config.allow_internal_format = False NUM_FRAMES = 'num_frames' +TEXT_ENCODER_Y_SIZE = [2, 1, 300, 4096] +TEXT_ENCODER_MASK_SIZE = [1, 300] +NULL_Y_SIZE = [1, 1, 300, 4096] +SEED = 42 -target_image_size = [(720, 1280), (512, 512)] +target_image_size = [(480, 640), (1080, 1920), (720, 1280), (512, 512)] target_num_frames = [32, 128] -target_fps = [8] +target_fps = [8, 24, 30, 60] +target_compressed_num_frame = [9, 38] target_output_type = ["latent", "thwc"] target_dtype = [torch.bfloat16, torch.float16] MAX_PROMPT_LENGTH = 1024 # the limits of open-sora1.2 @@ -46,13 +62,15 @@ MAX_PROMPT_LENGTH = 1024 # the limits of open-sora1.2 class OpenSoraPipeline12(OpenSoraPipeline): def __init__(self, text_encoder: T5EncoderModel, tokenizer: AutoTokenizer, transformer: STDiT3, - vae: VideoAutoencoder, scheduler: RFlowScheduler, + vae: VideoAutoencoder, scheduler: RFlowScheduler, model_type='SDModel', num_frames: int = 32, image_size: Tuple[int, int] = (720, 1280), fps: int = 8, + num_sampling_steps=30, refine_server_ip=None, refine_server_port=None, dtype: torch.dtype = torch.bfloat16): super().__init__() torch.set_grad_enabled(False) + self.model_type = model_type self.text_encoder = text_encoder self.tokenizer = tokenizer self.transformer = transformer @@ -61,22 +79,88 @@ class OpenSoraPipeline12(OpenSoraPipeline): self.num_frames = num_frames self.image_size = image_size self.fps = fps + self.num_sampling_steps = num_sampling_steps + self.refine_server_ip = refine_server_ip + self.refine_server_port = refine_server_port + self.num_timesteps = 1000 + + for idx, target_num_frame in enumerate(target_num_frames): + if self.num_frames == 
target_num_frame: + self.compressed_num_frame = target_compressed_num_frame[idx] if is_npu_available(): self.device = 'npu' else: self.device = 'cpu' + self.parallel_refine = os.environ.get("PARALLEL_REFINE", 'true').lower() == 'true' self.dtype = dtype self.text_encoder.to(self.dtype) + if self.text_encoder is not None: + self.text_encoder.to(self.dtype) + if self.model_type == 'SDModel': + self.check_init_input() + else: + self.check_single_model_init_input(self.model_type) @torch.no_grad() - def __call__(self, prompts: List[str], seed: int = 42, output_type: str = "latent"): + def __call__(self, prompts=None, seed=42, output_type="latent", input_video_path=None, prompt_refine=False): + batched_prompt_segment_list = [] + executor = ThreadPoolExecutor() + future = None + if prompt_refine: + if self.parallel_refine: + if is_main_process(): + params = (prompts, self.refine_server_ip, self.refine_server_port) + future = executor.submit(refine_prompts_by_mindie, *params) + else: + batched_prompt_segment_list = refine_prompts_by_mindie( + prompts, self.refine_server_ip, self.refine_server_port) + else: + batched_prompt_segment_list = prompts set_random_seed(seed=seed) + input_video = None + if input_video_path is not None: + if not os.path.exists(input_video_path): + raise FileNotFoundError(f"input_video_path {input_video_path} not found") + input_video = video_decode(video_paths=[input_video_path], return_vals=True) + input_video = input_video.permute(3, 0, 1, 2).unsqueeze(0).to(self.device).to(self.dtype) + + condition_frame_length = 5 + condition_frame_edit = 0.0 + refs = [[]] + ms = [""] * len(prompts) + if input_video is not None: + params = AppendGeneratedParams( + vae=self.vae, + generated_video=input_video, + refs_x=refs, + mask_strategy=ms, + loop_i=1, + condition_frame_length=condition_frame_length, + condition_frame_edit=condition_frame_edit + ) + refs, ms = append_generated(params) + + if prompt_refine and self.parallel_refine and is_main_process(): + batched_prompt_segment_list = future.result() + + #broadcast refine result + if prompt_refine and self.parallel_refine and is_distributed(): + # sycn the prompt if using seq parallel + torch.distributed.barrier() + + # create a list of size equal to world size + broadcast_obj_list = [batched_prompt_segment_list] * dist.get_world_size() + dist.broadcast_object_list(broadcast_obj_list, 0) + + # recover the prompt list + all_prompts = broadcast_obj_list[0] + batched_prompt_segment_list = all_prompts # 1.0 Encode input prompt - text_encoder_res_list = self._encode_prompt(prompts, self.text_encoder) + text_encoder_res_list = self._encode_prompt(batched_prompt_segment_list, self.text_encoder) torch.npu.empty_cache() input_size = (self.num_frames, *self.image_size) @@ -100,17 +184,19 @@ class OpenSoraPipeline12(OpenSoraPipeline): # == Iter over loop generation == z = torch.randn(len(batch_prompts), self.vae.out_channels, *latent_size, device=self.device, dtype=self.dtype) - + masks = apply_mask_strategy(z, refs, ms, 0, align=None) # 2.0 Prepare timesteps timesteps = self._retrieve_timesteps(z, additional_args=model_args, ) - samples = self._sample( - self.transformer, - text_encoder_res_list[i], + params = SampleParams( + model=self.transformer, + model_args=text_encoder_res_list[i], z=z, timesteps=timesteps, additional_args=model_args, + mask=masks ) + samples = self._sample(params) del z, timesteps, text_encoder_res_list @@ -126,6 +212,349 @@ class OpenSoraPipeline12(OpenSoraPipeline): else: return all_videos + @torch.no_grad() + def 
single_model_execute(self, prompts=None, seed=42, + output_type="latent", input_video_path=None, prompt_refine=False): + self.check_call_input(output_type, prompts, seed) + stream = torch_npu.npu.Stream(self.device) + + with torch_npu.npu.stream(stream): + if self.text_encoder is not None and self.text_encoder.device == torch.device("cpu"): + self.text_encoder.to(self.device) + + batched_prompt_segment_list = [] + + if prompt_refine: + batched_prompt_segment_list = refine_prompts_by_mindie( + prompts, self.refine_server_ip, self.refine_server_port) + else: + batched_prompt_segment_list = prompts + + set_random_seed(seed=seed) + input_video = None + if input_video_path is not None: + if not os.path.exists(input_video_path): + raise FileNotFoundError(f"input_video_path {input_video_path} not found") + input_video = video_decode(video_paths=[input_video_path], return_vals=True) + input_video = input_video.permute(3, 0, 1, 2).unsqueeze(0).to(self.device).to(self.dtype) + + condition_frame_length = 5 + condition_frame_edit = 0.0 + refs = [[]] + ms = [""] * len(prompts) + if input_video is not None: + if self.vae is not None and self.model_type == "SDVisualEncoder": + params = AppendGeneratedParams( + vae=self.vae, + generated_video=input_video, + refs_x=refs, + mask_strategy=ms, + loop_i=1, + condition_frame_length=condition_frame_length, + condition_frame_edit=condition_frame_edit + ) + refs, ms = append_generated(params) + logger.info(f"[VAE Backend] done vae encode !") + + if self.text_encoder is not None: + # 1.0 Encode input prompt + text_encoder_res_list = self._encode_prompt(batched_prompt_segment_list, self.text_encoder) + logger.info(f"[TEXT ENCODER Backend] done text encode !") + else: + text_encoder_res_list = [{}] + text_encoder_res_list[0]['y'] = torch.randn(size=TEXT_ENCODER_Y_SIZE, dtype=torch.bfloat16).to(self.device) + text_encoder_res_list[0]['mask'] = \ + torch.randint(0, 1, size=TEXT_ENCODER_MASK_SIZE, dtype=torch.int64).to(self.device) + + with torch_npu.npu.stream(stream): + if self.text_encoder is not None and self.text_encoder.device != torch.device("cpu"): + self.text_encoder.to('cpu') + torch.npu.empty_cache() + + input_size = (self.num_frames, *self.image_size) + if self.vae is not None: + latent_size = self.vae.get_latent_size(input_size) + else: + latent_size = [self.compressed_num_frame, self.image_size[0] // 8, self.image_size[1] // 8] + + batch_size = 1 + num_sample = 1 + + # == Iter over all samples == + all_videos = [] + for i in range(0, len(prompts), batch_size): + # == prepare batch prompts == + batch_prompts = prompts[i: i + batch_size] + + # == multi-resolution info == + model_args = prepare_multi_resolution_info( + 'STDiT2', (len(batch_prompts), self.image_size, self.num_frames, self.fps), self.device, self.dtype) + + # == Iter over number of sampling for one prompt == + for _ in range(num_sample): + # == Iter over loop generation == + if self.vae is not None: + z = torch.randn(len(batch_prompts), self.vae.out_channels, + *latent_size, device=self.device, dtype=self.dtype) + else: + z = torch.randn(len(batch_prompts), 4, *latent_size, device=self.device, dtype=self.dtype) + masks = apply_mask_strategy(z, refs, ms, 0, align=None) + # 2.0 Prepare timesteps + timesteps = self._retrieve_timesteps(z, additional_args=model_args, ) + + if self.transformer is not None: + params = SampleParams( + model=self.transformer, + model_args=text_encoder_res_list[i], + z=z, + timesteps=timesteps, + additional_args=model_args, + mask=masks + ) + samples = 
self._sample(params) + logger.info(f"[DIT Backend] done transformer!") + else: + samples = torch.randn(size=[batch_size, 4, self.compressed_num_frame, self.image_size[0] // 8, + self.image_size[1] // 8], dtype=torch.bfloat16).to(self.device) + del z, timesteps, text_encoder_res_list + if self.vae is not None and self.model_type == "SDVisualDecoder": + samples = self.vae.decode(samples.to(self.dtype), num_frames=self.num_frames) + logger.info(f"[VAE Backend] done vae decode !") + else: + samples = torch.randn(size=[batch_size, 3, self.num_frames, *self.image_size], dtype=torch.bfloat16) + all_videos.append(samples) + + del samples + torch.npu.empty_cache() + + stream.synchronize() + + if not output_type == "latent": + videos = self._video_write(all_videos) + return videos + else: + return all_videos + + @torch.no_grad() + def t5_execute(self, prompts=None, prompt_refine=False): + set_random_seed(SEED) + stream = torch_npu.npu.Stream(self.device) + + with torch_npu.npu.stream(stream): + if self.text_encoder is not None and self.text_encoder.device == torch.device("cpu"): + self.text_encoder.to(self.device) + + batched_prompt_segment_list = [] + if prompt_refine: + batched_prompt_segment_list = refine_prompts_by_mindie( + prompts, self.refine_server_ip, self.refine_server_port) + else: + batched_prompt_segment_list = prompts + + text_encoder_res_list = self._encode_prompt(batched_prompt_segment_list, self.text_encoder) + y_nparray = text_encoder_res_list[0]['y'].cpu().half().numpy() + encode_y_shm_name, encode_y_shape, encode_y_dtype = create_shm(y_nparray) + + mask_nparray = text_encoder_res_list[0]['mask'].cpu().numpy() + encode_mask_shm_name, encode_mask_shape, encode_mask_dtype = create_shm(mask_nparray) + + t5_result, t5_result_length = pad_and_flatten(encode_y_shm_name, encode_y_shape, encode_y_dtype, + encode_mask_shm_name, encode_mask_shape, encode_mask_dtype) + return [t5_result, t5_result_length] + + + @torch.no_grad() + def vae_encode(self, input_video_path=None): + set_random_seed(SEED) + input_video = None + if input_video_path is not None: + if not os.path.exists(input_video_path): + raise FileNotFoundError(f"input_video_path {input_video_path} not found") + input_video = video_decode(video_paths=[input_video_path], return_vals=True) + input_video = input_video.permute(3, 0, 1, 2).unsqueeze(0).to(self.device).to(self.dtype) + condition_frame_length = 5 + condition_frame_edit = 0.0 + refs = [[]] + ms = [""] + if input_video is not None: + if self.vae is not None and self.model_type == "SDVisualEncoder": + params = AppendGeneratedParams( + vae=self.vae, + generated_video=input_video, + refs_x=refs, + mask_strategy=ms, + loop_i=1, + condition_frame_length=condition_frame_length, + condition_frame_edit=condition_frame_edit + ) + refs, ms = append_generated(params) + logger.info(f"[VAE Backend] done vae encode !") + refs_array = np.array(refs[0][0].cpu().float()) + ms_array = np.array(ms) + + encode_refs_shm_name, encode_refs_shape, encode_refs_dtype = create_shm(refs_array) + encode_ms_shm_name, encode_ms_shape, encode_ms_dtype = create_shm(ms_array) + + vae_encode_result, vae_encode_result_length = pad_and_flatten( + encode_refs_shm_name, encode_refs_shape, encode_refs_dtype, + encode_ms_shm_name, encode_ms_shape, encode_ms_dtype) + return [vae_encode_result, vae_encode_result_length] + + @torch.no_grad() + def dit_execute(self, text_encoder_result=None, text_encoder_result_length=None, + video_embedding_name=None, video_embedding=None, + video_embedding_length=None): + 
set_random_seed(SEED) + bs = 1 + text_encoder_res = get_shm_info(text_encoder_result, text_encoder_result_length) + + text_encoder_res_list = [{}] + y_shm_name = text_encoder_res[0] + y_shm_shape = eval(text_encoder_res[1]) + y_shm_dtype = np.dtype(text_encoder_res[2]) + text_encoder_res_list[0]['y'] = ( + get_data_from_shm(y_shm_name, y_shm_shape, y_shm_dtype).to(torch.bfloat16).to('npu')) + + mask_shm_name = text_encoder_res[3] + mask_shm_shape = eval(text_encoder_res[4]) + mask_shm_dtype = np.dtype(text_encoder_res[5]) + text_encoder_res_list[0]['mask'] = get_data_from_shm(mask_shm_name, mask_shm_shape, mask_shm_dtype).to('npu') + + y_null = self._null(bs) + text_encoder_res_list[0]["y"][1] = y_null + + refs = [[]] + ms = [""] + if video_embedding is not None and video_embedding_name == "SD_VISUAL_ENCODER_RESULT": + vae_encode_res = get_shm_info(video_embedding, video_embedding_length) + refs_shm_name = vae_encode_res[0] + refs_shm_shape = eval(vae_encode_res[1]) + refs_shm_dtype = np.dtype(vae_encode_res[2]) + refs_content = get_data_from_shm(refs_shm_name, refs_shm_shape, refs_shm_dtype).to(torch.bfloat16).to('npu') + refs = [[refs_content]] + + ms_shm_name = vae_encode_res[3] + ms_shm_shape = eval(vae_encode_res[4]) + ms_shm_dtype = np.dtype(vae_encode_res[5]) + shm = shared_memory.SharedMemory(name=ms_shm_name) + ms = np.ndarray(ms_shm_shape, dtype=ms_shm_dtype, buffer=shm.buf).copy() + + latent_size = [self.compressed_num_frame, self.image_size[0] // 8, self.image_size[1] // 8] + + # == Iter over all samples == + all_videos = [] + + # == multi-resolution info == + model_args = prepare_multi_resolution_info( + 'STDiT2', (bs, self.image_size, self.num_frames, self.fps), self.device, self.dtype) + + if video_embedding_name == "SD_DIT_RESULT": + dit_res_shm_info = get_shm_info(video_embedding, video_embedding_length) + dit_shm_name = dit_res_shm_info[0] + dit_shm_shape = eval(dit_res_shm_info[1]) + dit_shm_dtype = np.dtype(dit_res_shm_info[2]) + z = get_data_from_shm(dit_shm_name, dit_shm_shape, dit_shm_dtype).to(torch.bfloat16).to('npu') + else: + z = torch.randn(bs, 4, *latent_size, device=self.device, dtype=self.dtype) + + masks = apply_mask_strategy(z, refs, ms, 0, align=None) + # 2.0 Prepare timesteps + timesteps = self._retrieve_timesteps(z, additional_args=model_args, ) + params = SampleParams( + model=self.transformer, + model_args=text_encoder_res_list[0], + z=z, + timesteps=timesteps, + additional_args=model_args, + mask=masks + ) + samples = self._sample(params) + logger.info(f"[DIT Backend] done transformer!") + samples_array = np.array(samples.cpu()) + samples_shm_name, samples_shape, samples_dtype = create_shm(samples_array) + + dit_result, dit_result_length = pad_and_flatten(samples_shm_name, samples_shape, samples_dtype) + + # 存储text encode内容 + y_nparray = text_encoder_res_list[0]['y'].cpu().half().numpy() + encode_y_shm_name, encode_y_shape, encode_y_dtype = create_shm(y_nparray) + + mask_nparray = text_encoder_res_list[0]['mask'].cpu().numpy() + encode_mask_shm_name, encode_mask_shape, encode_mask_dtype = create_shm(mask_nparray) + + t5_result, t5_result_length = pad_and_flatten(encode_y_shm_name, encode_y_shape, encode_y_dtype, + encode_mask_shm_name, encode_mask_shape, encode_mask_dtype) + torch.npu.empty_cache() + return [dit_result, dit_result_length, t5_result, t5_result_length] + + + @torch.no_grad() + def vae_decode(self, dit_result=None, dit_result_length=None, output_type="latent"): + if output_type not in target_output_type: + logger.error("output_type not 
in target_output_type") + raise ValueError("output_type not in target_output_type") + set_random_seed(SEED) + dit_res_shm_info = get_shm_info(dit_result, dit_result_length) + samples_shm_name = dit_res_shm_info[0] + samples_shm_shape = eval(dit_res_shm_info[1]) + samples_shm_dtype = np.dtype(dit_res_shm_info[2]) + samples = get_data_from_shm(samples_shm_name, samples_shm_shape, samples_shm_dtype).to(torch.bfloat16).to('npu') + + all_videos = [] + samples = self.vae.decode(samples, num_frames=self.num_frames) + + all_videos.append(samples) + + del samples + torch.npu.empty_cache() + + if not output_type == "latent": + videos = self._video_write(all_videos) + return videos + else: + return all_videos + def check_init_input(self): + type_checks = { + "text_encoder": T5EncoderModel, + "transformer": STDiT3, + "vae": VideoAutoencoder, + "scheduler": RFlowScheduler + } + + for attr, expected_type in type_checks.items(): + if hasattr(self, attr) and not isinstance(getattr(self, attr), expected_type): + raise TypeError(f"{attr} must be an instance of {expected_type.name}") + value_checks = { + "num_frames": (self.num_frames, target_num_frames), + "image_size": (self.image_size, target_image_size), + "fps": (self.fps, target_fps), + "dtype": (self.dtype, target_dtype) + } + for attr, (val, target_vals) in value_checks.items(): + if hasattr(self, attr) and val not in target_vals: + raise TypeError(f"{attr} must be expected target of {target_vals}") + + def check_single_model_init_input(self, model_type): + member_name, model_cls = None, None + if model_type == "SDDiT": + member_name, model_cls = "transformer", STDiT3 + elif model_type == "SDTextEncoder": + member_name, model_cls = "text_encoder", T5EncoderModel + elif model_type in ['SDVisualEncoder', 'SDVisualDecoder']: + member_name, model_cls = "vae", VideoAutoencoder + + if hasattr(self, member_name) and not isinstance(getattr(self, member_name), model_cls): + raise TypeError(f"{model_type} Backend, {member_name} must be an instance of {model_cls.name}") + value_checks = { + "num_frames": (self.num_frames, target_num_frames), + "image_size": (self.image_size, target_image_size), + "fps": (self.fps, target_fps), + "dtype": (self.dtype, target_dtype) + } + for attr, (val, target_vals) in value_checks.items(): + if hasattr(self, attr) and val not in target_vals: + raise TypeError(f"{model_type} Backend, {attr} must be expected target of {target_vals}") + def _video_write(self, x): x = [x[0][0]] x = torch.cat(x, dim=1) @@ -138,8 +567,8 @@ class OpenSoraPipeline12(OpenSoraPipeline): def _retrieve_timesteps(self, z, additional_args=None, ): # prepare timesteps - timesteps = [(1.0 - i / self.scheduler.num_sampling_steps) * self.scheduler.num_timesteps - for i in range(self.scheduler.num_sampling_steps)] + timesteps = [(1.0 - i / self.num_sampling_steps) * self.scheduler.num_timesteps + for i in range(self.num_sampling_steps)] if self.scheduler.use_discrete_timesteps: timesteps = [int(round(t)) for t in timesteps] timesteps = [torch.tensor([t] * z.shape[0], device=self.device) for t in timesteps] @@ -149,20 +578,44 @@ class OpenSoraPipeline12(OpenSoraPipeline): num_timesteps=self.scheduler.num_timesteps) for t in timesteps] return timesteps - def _sample(self, model, model_args, z, timesteps, additional_args=None, ): + def _sample(self, params: SampleParams): + model = params.model + model_args = params.model_args + z = params.z + timesteps = params.timesteps + additional_args = params.additional_args + mask = params.mask if additional_args is not None: 
model_args.update(additional_args) + if mask is not None: + noise_added = torch.zeros_like(mask, dtype=torch.bool) + noise_added = noise_added | (mask == 1) for i in tqdm(range(0, len(timesteps), 1)): t = timesteps[i] model_args['t_idx'] = i + # mask for adding noise + if mask is not None: + mask_t = mask * self.num_timesteps + x0 = z.clone() + x_noise = self.scheduler.add_noise(x0, torch.randn_like(x0), t) + + mask_t_upper = mask_t >= t.unsqueeze(1) + model_args["x_mask"] = mask_t_upper.repeat(2, 1) + mask_add_noise = mask_t_upper & ~noise_added + + z = torch.where(mask_add_noise[:, None, :, None, None], x_noise, x0) + noise_added = mask_t_upper # classifier-free guidance z_in = torch.cat([z, z], 0) t = torch.cat([t, t], 0) pred = model(z_in, t, **model_args).chunk(2, dim=1)[0] z = self.scheduler.step(pred, timesteps, i, z) + + if mask is not None: + z = torch.where(mask_t_upper[:, None, :, None, None], z, x0) return z def _timestep_transform(self, t, model_kwargs, num_timesteps): @@ -192,19 +645,28 @@ class OpenSoraPipeline12(OpenSoraPipeline): return dict(y=caption_embs, mask=emb_masks) def _null(self, n): - null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + if self.transformer is not None: + null_y = self.transformer.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + logger.info(f"[DIT Backend] Create null_y tensor.") + else: + null_y = torch.randn(size=NULL_Y_SIZE, dtype=torch.bfloat16).to(self.device) return null_y def _get_text_embeddings(self, texts): - text_tokens_and_mask = self.tokenizer( - texts, - max_length=300, - padding="max_length", - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) + text_tokens_and_mask = {} + if self.tokenizer is not None: + text_tokens_and_mask = self.tokenizer( + texts, + max_length=300, + padding="max_length", + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + else: + text_tokens_and_mask["input_ids"] = torch.randint(1, 10, size=TEXT_ENCODER_MASK_SIZE, dtype=torch.int64) + text_tokens_and_mask["attention_mask"] = torch.randint(0, 1, size=TEXT_ENCODER_MASK_SIZE, dtype=torch.int64) input_ids = text_tokens_and_mask["input_ids"].to(self.device) attention_mask = text_tokens_and_mask["attention_mask"].to(self.device) diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/pipeline_utils.py b/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/pipeline_utils.py index 0f0c76bc77..a881e7230b 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/pipeline_utils.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/pipeline/pipeline_utils.py @@ -19,9 +19,11 @@ import os import inspect import logging import importlib +from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union import torch +from torch import Tensor from tqdm import tqdm from mindiesd import ConfigMixin @@ -45,6 +47,10 @@ FPS = 'fps' DTYPE = 'dtype' SET_PATCH_PARALLEL = 'set_patch_parallel' FROM_PRETRAINED = 'from_pretrained' +NUM_SAMPLING_STEPS = 'num_sampling_steps' +REFINE_SERVER_IP = 'refine_server_ip' +REFINE_SERVER_PORT = 'refine_server_port' +MODEL_TYPE = 'model_type' logger = logging.getLogger(__name__) # init python log @@ -80,6 +86,11 @@ class OpenSoraPipeline(ConfigMixin): enable_sequence_parallelism = kwargs.pop(ENABLE_SEQUENCE_PARALLELISM, ENABLE_SEQUENCE_PARALLELISM_DEFAULT_VALUE) set_patch_parallel = kwargs.pop(SET_PATCH_PARALLEL, False) + + num_sampling_steps = kwargs.pop(NUM_SAMPLING_STEPS, 30) + 
refine_server_ip = kwargs.pop(REFINE_SERVER_IP, None) + refine_server_port = kwargs.pop(REFINE_SERVER_PORT, None) + dtype = kwargs.pop("dtype", torch.bfloat16) init_dict, _ = cls.load_config(model_path, **kwargs) @@ -126,10 +137,95 @@ class OpenSoraPipeline(ConfigMixin): pipe_init_dict[NUM_FRAMES] = num_frames pipe_init_dict[IMAGE_SIZE] = image_size pipe_init_dict[FPS] = fps + pipe_init_dict[NUM_SAMPLING_STEPS] = num_sampling_steps + pipe_init_dict[REFINE_SERVER_IP] = refine_server_ip + pipe_init_dict[REFINE_SERVER_PORT] = refine_server_port pipe_init_dict[DTYPE] = dtype return cls(**pipe_init_dict) + @classmethod + def from_pretrained_single_model(cls, model_path, **kwargs): + real_path = _check_path(model_path) + num_frames = kwargs.pop(NUM_FRAMES, OPEN_SORA_DEFAULT_VIDEO_FRAME) + image_size = kwargs.pop(IMAGE_SIZE, OPEN_SORA_DEFAULT_IMAGE_SIZE) + fps = kwargs.pop(FPS, OPEN_SORA_DEFAULT_FPS) + enable_sequence_parallelism = kwargs.pop(ENABLE_SEQUENCE_PARALLELISM, + ENABLE_SEQUENCE_PARALLELISM_DEFAULT_VALUE) + set_patch_parallel = kwargs.pop(SET_PATCH_PARALLEL, False) + num_sampling_steps = kwargs.pop(NUM_SAMPLING_STEPS, 5) + refine_server_ip = kwargs.pop(REFINE_SERVER_IP, None) + refine_server_port = kwargs.pop(REFINE_SERVER_PORT, None) + model_type = kwargs.pop(MODEL_TYPE, None) + + init_list = [VAE, TEXT_ENCODER, TOKENIZER, TRANSFORMER, SCHEDULER] + pipe_init_dict = {} + + all_parameters = inspect.signature(cls.__init__).parameters + + required_param = {k: v for k, v in all_parameters.items() if v.default == inspect.Parameter.empty} + expected_modules = set(required_param.keys()) - {"self"} + # init the module from kwargs + passed_module = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + from_pretrained = kwargs.pop(FROM_PRETRAINED, real_path) + mindiesd_module = 'mindiesd' + if model_type == 'SDTextEncoder': + module = 'transformers' + for key in tqdm(init_list, desc="Loading T5EncoderModel"): + if key == TEXT_ENCODER: + cls_name = 'T5EncoderModel' + library = importlib.import_module(module) + class_obj = getattr(library, cls_name) + initializer = _get_initializers().get(key, init_default) + pipe_init_dict[key] = initializer(class_obj, real_path, from_pretrained, + set_patch_parallel, kwargs) + elif key == TOKENIZER: + cls_name = 'AutoTokenizer' + library = importlib.import_module(module) + class_obj = getattr(library, cls_name) + initializer = _get_initializers().get(key, init_default) + pipe_init_dict[key] = initializer(class_obj, real_path, from_pretrained, + set_patch_parallel, kwargs) + else: + pipe_init_dict[key] = None + elif model_type == 'SDDiT': + cls_name = 'STDiT3' + for key in tqdm(init_list, desc="Loading STDiT3"): + if key == TRANSFORMER: + library = importlib.import_module(mindiesd_module) + class_obj = getattr(library, cls_name) + pipe_init_dict[key] = class_obj.from_pretrained( + real_path, enable_sequence_parallelism=enable_sequence_parallelism, **kwargs) + else: + pipe_init_dict[key] = None + elif model_type in ['SDVisualEncoder', 'SDVisualDecoder']: + cls_name = 'VideoAutoencoder' + for key in tqdm(init_list, desc="Loading VideoAutoencoder"): + if key == VAE: + library = importlib.import_module(mindiesd_module) + class_obj = getattr(library, cls_name) + initializer = _get_initializers().get(key, init_default) + pipe_init_dict[key] = initializer(class_obj, real_path, from_pretrained, + set_patch_parallel, kwargs) + else: + pipe_init_dict[key] = None + + cls_name = 'RFlowScheduler' + library = importlib.import_module(mindiesd_module) + class_obj = 
getattr(library, cls_name) + pipe_init_dict[SCHEDULER] = class_obj() + + pipe_init_dict[NUM_FRAMES] = num_frames + pipe_init_dict[IMAGE_SIZE] = image_size + pipe_init_dict[FPS] = fps + pipe_init_dict[NUM_SAMPLING_STEPS] = num_sampling_steps + pipe_init_dict[REFINE_SERVER_IP] = refine_server_ip + pipe_init_dict[REFINE_SERVER_PORT] = refine_server_port + pipe_init_dict[DTYPE] = kwargs.get(DTYPE, torch.bfloat16) + pipe_init_dict[MODEL_TYPE] = model_type + + return cls(**pipe_init_dict) + def _get_initializers(): initializers = { @@ -139,6 +235,13 @@ def _get_initializers(): return initializers +def _check_path(model_path): + real_path = os.path.abspath(model_path) + if not (os.path.exists(real_path) and os.path.isdir(real_path)): + raise ValueError("model path is invalid!") + return real_path + + def _check(pipe_init_dict): if pipe_init_dict.get(VAE) is None: raise ValueError("Cannot get module 'vae' in init list!") @@ -166,4 +269,14 @@ def init_vae(class_obj, sub_folder, from_pretrained, set_patch_parallel, **kwarg def init_default(class_obj, sub_folder, from_pretrained, set_patch_parallel, **kwargs): - return class_obj.from_pretrained(sub_folder, local_files_only=True, **kwargs) \ No newline at end of file + return class_obj.from_pretrained(sub_folder, local_files_only=True, **kwargs) + + +@dataclass +class SampleParams: + model: any + model_args: dict + z: Tensor + timesteps: list + additional_args: dict = None + mask: Tensor = None \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py b/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py index 348a21df5c..eb0c11821f 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/stdit3/stdit3.py @@ -94,7 +94,7 @@ class T2IFinalLayer(nn.Module): x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) masked_x = rearrange(masked_x, "b (t s) c -> b t s c", t=t, s=s) x = torch.where(x_mask[:, :, None, None], x, masked_x) - x = rearrange(x, "b t s c -> b (t s) c") + x = rearrange(x, "b t s c -> b (t s) c", t=t, s=s) return x def forward(self, x, t, x_mask=None, t0=None, t_s=(None, None)): @@ -158,7 +158,7 @@ class STDiT3Block(nn.Module): x = rearrange(x, "b (t s) c -> b t s c", t=t, s=s) masked_x = rearrange(masked_x, "b (t s) c -> b t s c", t=t, s=s) x = torch.where(x_mask[:, :, None, None], x, masked_x) - x = rearrange(x, "b t s c -> b (t s) c") + x = rearrange(x, "b t s c -> b (t s) c", t=t, s=s) return x def forward( diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/utils/__init__.py b/MindIE/MultiModal/OpenSora-1.2/opensora/utils/__init__.py index ffb91659f0..10fd6be171 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/utils/__init__.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/utils/__init__.py @@ -17,7 +17,9 @@ from .utils import ( set_random_seed, append_score_to_prompts, extract_prompts_loop, - merge_prompt, prepare_multi_resolution_info, split_prompt, is_npu_available, exists, default + merge_prompt, prepare_multi_resolution_info, split_prompt, is_npu_available, exists, default, append_generated, apply_mask_strategy, + video_encode, video_decode, is_main_process, is_distributed, refine_prompts_by_mindie, + create_shm, get_data_from_shm, get_shm_info, pad_and_flatten ) from .patch_utils import ( Patchify, Depatchify diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/utils/utils.py b/MindIE/MultiModal/OpenSora-1.2/opensora/utils/utils.py index c324585f37..79c2a6f16d 100644 --- 
a/MindIE/MultiModal/OpenSora-1.2/opensora/utils/utils.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/utils/utils.py @@ -15,10 +15,81 @@ # limitations under the License. -import importlib +import json +import logging +import multiprocessing import random -import torch +import os +import importlib +import time +from dataclasses import dataclass +from multiprocessing import Manager, shared_memory +from threading import Timer + import numpy as np +import torch +import torchvision.io as io +import torch.distributed as dist +import requests + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"] +MAX_SHM_SIZE = 10**9 +OPENAI_CLIENT = None +REFINE_PROMPTS = None +REFINE_PROMPTS_TEMPLATE = """ +You need to refine user's input prompt. The user's input prompt is used for video generation task. You need to refine the user's prompt to make it more suitable for the task. Here are some examples of refined prompts: +{} + +The refined prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The refined prompt should be in English. +""" +RANDOM_PROMPTS = None +RANDOM_PROMPTS_TEMPLATE = """ +You need to generate one input prompt for video generation task. The prompt should be suitable for the task. Here are some examples of refined prompts: +{} + +The prompt should pay attention to all objects in the video. The description should be useful for AI to re-generate the video. The description should be no more than six sentences. The prompt should be in English. +""" +REFINE_EXAMPLE = [ + "a close - up shot of a woman standing in a room with a white wall and a plant on the left side." + "the woman has curly hair and is wearing a green tank top." + "she is looking to the side with a neutral expression on her face." + "the lighting in the room is soft and appears to be natural, coming from the left side of the frame." + "the focus is on the woman, with the background being out of focus." + "there are no texts or other objects in the video.the style of the video is a simple," + " candid portrait with a shallow depth of field.", + "a serene scene of a pond filled with water lilies.the water is a deep blue, " + "providing a striking contrast to the pink and white flowers that float on its surface." + "the flowers, in full bloom, are the main focus of the video." + "they are scattered across the pond, with some closer to the camera and others further away, " + "creating a sense of depth.the pond is surrounded by lush greenery, adding a touch of nature to the scene." + "the video is taken from a low angle, looking up at the flowers, " + "which gives a unique perspective and emphasizes their beauty." + "the overall composition of the video suggests a peaceful and tranquil setting, likely a garden or a park.", + "a professional setting where a woman is presenting a slide from a presentation." + "she is standing in front of a projector screen, which displays a bar chart." + "the chart is colorful, with bars of different heights, indicating some sort of data comparison." + "the woman is holding a pointer, which she uses to highlight specific parts of the chart." + "she is dressed in a white blouse and black pants, and her hair is styled in a bun." + "the room has a modern design, with a sleek black floor and a white ceiling." + "the lighting is bright, illuminating the woman and the projector screen." 
+ "the focus of the image is on the woman and the projector screen, with the background being out of focus." + "there are no texts visible in the image." + "the relative positions of the objects suggest that the woman is the main subject of the image, " + "and the projector screen is the object of her attention." + "the image does not provide any information about the content of the presentation or the context of the meeting." +] +MAX_NEW_TOKENS = 512 +TEMPERATURE = 1.1 +TOP_P = 0.95 +TOP_K = 100 +SEED = 10 +REPETITION_PENALTY = 1.03 + +TIMEOUT_T = 600 IMG_FPS = 8 @@ -138,3 +209,371 @@ def append_score_to_prompts(prompts, aes=None, flow=None, camera_motion=None): new_prompts.append(new_prompt) return new_prompts + +def is_distributed(): + return os.environ.get("WORLD_SIZE", None) is not None + + +def is_main_process(): + return not is_distributed() or dist.get_rank() == 0 + + +@dataclass +class AppendGeneratedParams: + vae: any + generated_video: list + refs_x: list + mask_strategy: list[str] + loop_i: int + condition_frame_length: int + condition_frame_edit: float + + +def append_generated(params: AppendGeneratedParams): + vae = params.vae + generated_video = params.generated_video + refs_x = params.refs_x + mask_strategy = params.mask_strategy + loop_i = params.loop_i + condition_frame_length = params.condition_frame_length + condition_frame_edit = params.condition_frame_edit + + ref_x = vae.encode(generated_video) + for j, refs in enumerate(refs_x): + if refs is None: + refs_x[j] = [ref_x[j]] + else: + refs.append(ref_x[j]) + if mask_strategy[j] is None or mask_strategy[j] == "": + mask_strategy[j] = "" + else: + mask_strategy[j] += ";" + mask_strategy[ + j + ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}" + return refs_x, mask_strategy + + +def parse_mask_strategy(mask_strategy): + mask_batch = [] + if mask_strategy == "" or mask_strategy is None: + return mask_batch + + mask_strategy = mask_strategy.split(";") + for mask in mask_strategy: + mask_group = mask.split(",") + num_group = len(mask_group) + if not (1 <= num_group <= 6): + raise ValueError(f"Invalid mask strategy: {mask}") + mask_group.extend(MASK_DEFAULT[num_group:]) + for i in range(5): + mask_group[i] = int(mask_group[i]) + mask_group[5] = float(mask_group[5]) + mask_batch.append(mask_group) + return mask_batch + + +def create_shm(target_array): + size = target_array.nbytes + if not isinstance(size, int) or size <= 0: + raise ValueError("Size must be a positive integer.") + if size > MAX_SHM_SIZE: + raise ValueError(f"Size exceeds the maximum allowed limit of {MAX_SHM_SIZE} bytes.") + + try: + shm = shared_memory.SharedMemory(create=True, size=size) + + except Exception as e: + raise RuntimeError(f"Failed to create shared memory: {e}") from e + + target_shared_array = np.ndarray(target_array.shape, dtype=target_array.dtype, buffer=shm.buf) + target_shared_array[:] = target_array + + encode_target_shm_name = str2int(shm.name) + encode_target_shape = str2int(str(target_shared_array.shape)) + encode_target_dtype = str2int(str(target_shared_array.dtype)) + + return encode_target_shm_name, encode_target_shape, encode_target_dtype + + +def str2int(s): + int_list = [ord(char) for char in s] + return int_list + + +def pad_and_flatten(*lists): + lengths = [len(lst) for lst in lists] + + flattened_list = [ + item + for sublist in lists + for item in sublist + ] + return flattened_list, lengths + + +def get_shm_info(info_data, info_length): + shm_info_result = [] + 
start_index = 0 + + for text_enco_res_length in info_length: + end_index = start_index + text_enco_res_length + shm_info_result.append(info_data[start_index:end_index]) + start_index = end_index + + string_list = [''.join(sublist) for sublist in shm_info_result] + return string_list + + +def get_data_from_shm(shm_name, shm_shape, shm_dtype, device='cpu'): + shm_values = None + shm = None + try: + shm = shared_memory.SharedMemory(name=shm_name) + shm_values = np.ndarray(shm_shape, dtype=shm_dtype, buffer=shm.buf).copy() + shm_values = torch.from_numpy(shm_values).to(device) + except Exception as e: + raise ValueError(f"Failed to get data from shared memory: {e}") from e + finally: + if shm is not None: + shm.close() + shm.unlink() + return shm_values + + +def find_nearest_point(value, point, max_value): + if point == 0: + raise ValueError("point cannot be zero") + t = value // point + if value % point > point / 2 and t < max_value // point - 1: + t += 1 + return t * point + + +def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None): + masks = [] + no_mask = True + for i, mask_strategy in enumerate(mask_strategys): + no_mask = False + mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device) + mask_strategy = parse_mask_strategy(mask_strategy) + for mst in mask_strategy: + loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst + if loop_id != loop_i: + continue + ref = refs_x[i][m_id] + + if m_ref_start < 0: + m_ref_start = ref.shape[1] + m_ref_start + if m_target_start < 0: + m_target_start = z.shape[2] + m_target_start + if align is not None: + m_ref_start = find_nearest_point(m_ref_start, align, ref.shape[1]) + m_target_start = find_nearest_point(m_target_start, align, z.shape[2]) + m_length = min(m_length, z.shape[2] - m_target_start, ref.shape[1] - m_ref_start) + z[i, :, m_target_start : m_target_start + m_length] = ref[:, m_ref_start : m_ref_start + m_length] + mask[m_target_start : m_target_start + m_length] = edit_ratio + masks.append(mask) + if no_mask: + return None + masks = torch.stack(masks) + return masks + + +def refine_prompt_by_mindie(prompt, refine_server_ip, refine_server_port): + global REFINE_PROMPTS + if REFINE_PROMPTS is None: + examples = load_prompts(REFINE_EXAMPLE) + REFINE_PROMPTS = REFINE_PROMPTS_TEMPLATE.format("\n".join(examples)) + + response = get_mindie_response(REFINE_PROMPTS, prompt, refine_server_ip, refine_server_port) + return response + + +def get_mindie_response(sys_prompt, usr_prompt, refine_server_ip, refine_server_port): + url = f'http://{refine_server_ip}:{refine_server_port}/infer' + _prompt = f"{sys_prompt} and here is the user's input prompt that needs to be refined:{usr_prompt}" + #定义请求的数据 + payload = { + "inputs": _prompt, + "stream": False, + "parameters":{ + "max_new_tokens": MAX_NEW_TOKENS, + "temperature": TEMPERATURE, + "top_p": TOP_P, + "top_k": TOP_K, + "do_sample": True, + "seed": SEED, + "repetition_penalty": REPETITION_PENALTY, + "typical_p": None, + "watermark": True + } + } + json_payload = json.dumps(payload) + + headers = { + "Accept": "application/json", + "Content-type": "application/json" + } + + response = requests.post(url, data=json_payload, headers=headers) + + if response.status_code == 200: + try: + response_data = response.json() + response_data = usr_prompt + response_data["generated_text"] + return response_data + except json.JSONDecodeError: + logger.error("[Error] Failed to parse response from Mindie") + return None + else: + logger.error(f"[Error] Failed to get response from 
Mindie, status code: {response.status_code}") + return None + + +def load_prompts(refine_example, start_idx=None, end_idx=None): + prompts = [line.strip() for line in refine_example] + prompts = prompts[start_idx:end_idx] + return prompts + + +def get_random_prompt_by_mindie(refine_server_ip, refine_server_port): + global RANDOM_PROMPTS + if RANDOM_PROMPTS is None: + examples = load_prompts(REFINE_EXAMPLE) + RANDOM_PROMPTS = RANDOM_PROMPTS_TEMPLATE.format("\n".join(examples)) + + response = get_mindie_response(RANDOM_PROMPTS, "Generate one example.", refine_server_ip, refine_server_port) + return response + + +def refine_prompts_by_mindie(prompts, refine_server_ip, refine_server_port): + new_prompts = [] + for prompt in prompts: + try: + if prompt.strip() == "": + new_prompt = get_random_prompt_by_mindie(refine_server_ip, refine_server_port) + else: + new_prompt = refine_prompt_by_mindie(prompt, refine_server_ip, refine_server_port) + new_prompts.append(new_prompt) + except Exception as e: + logger.warning(f"Failed to refine prompt: {prompt} due to {e}") + new_prompts.append(prompt) + return new_prompts + + +def decode_single_video(video_path, results, idx, return_vals): + try: + video_frames = io.read_video(video_path, pts_unit='sec') + if return_vals: + results[idx] = video_frames[0].numpy() + return video_frames + except Exception as e: + logger.error(f"Error decoding video {video_path}: {e}") + return None + + +def encode_single_video(video_path, video_tensor, fps, video_codec): + try: + io.write_video(video_path, video_tensor, fps=fps, video_codec=video_codec) + return True + except Exception as e: + logger.error(f"Error encoding video {video_path}: {e}") + return False + + +def timeout(process_list): + # 终止所有进程 + for p in process_list: + p.terminate() + logger.error(f"Time out, finish all sub process!") + + +def video_decode(video_paths, return_vals=False): + if not isinstance(video_paths, list): + logger.error(video_paths, list) + return None + if len(video_paths) > 128 or len(video_paths) == 0: + logger.error("the length of video_paths should in [1, 128]") + return None + for video_path in video_paths: + if not isinstance(video_path, str) or not video_path.endswith('.mp4'): + logger.error("video_path invalid") + return None + if not os.path.exists(video_path): + logger.error("video_path does not exist") + return None + + decode_val = [] + with Manager() as manager: + results = manager.list([None] * len(video_paths)) + process_list = [] + for i, video in enumerate(video_paths): + p = multiprocessing.Process(target=decode_single_video, args=(video, results, i, return_vals)) + p.start() + process_list.append(p) + + # 设置定时器 + timer = Timer(TIMEOUT_T, timeout, args=(process_list,)) + timer.start() + + for p in process_list: + p.join() + + timer.cancel() + decode_val = [(None if result is None else torch.from_numpy(result)) for result in results] + + if len(decode_val) == 1: + return decode_val[0] + return decode_val + + +def video_encode(tensor_list, video_path_list, fps=8, video_codec="h264"): + if not isinstance(tensor_list, list) or not isinstance(video_path_list, list): + logger.error("tensor_list and video_path_list should both be list type") + return None + if len(tensor_list) == 0 or len(tensor_list) > 128: + logger.error("the length of tensor_list should be in [1, 128]") + return None + if len(video_path_list) == 0 or len(video_path_list) > 128: + logger.error("the length of video_path_list should be in [1, 128]") + return None + if len(tensor_list) != len(video_path_list): + 
logger.error("the length of tensor_list and video_path_list should be equal") + return None + for idx, tensor in enumerate(tensor_list): + if not isinstance(tensor, torch.Tensor): + logger.error("the element in tensor_list should be torch tensor") + return None + if not isinstance(video_path_list[idx], str) or not video_path_list[idx].endswith('.mp4'): + logger.error("the element in video_path_list should be str and endswith .mp4") + return None + + start = time.time() + process_list = [] + for i, video in enumerate(tensor_list): + p = multiprocessing.Process(target=encode_single_video, args=(video_path_list[i], video, fps, video_codec)) + p.start() + process_list.append(p) + + # Timer() not work for io.write_video(), so try to terminal(or kill) it + has_alive = True + time_out = False + while has_alive: + has_alive = False + end = time.time() + if not time_out and end - start >= TIMEOUT_T: + time_out = True + for p in process_list: + if p.is_alive(): + if time_out: + logger.error('Time out for video encode(io.write), finish sub process') + p.terminate() + p.join() + has_alive = True + break + + for p in process_list: + p.join() + + return [] \ No newline at end of file diff --git a/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py b/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py index 4bffccdd03..682de158bb 100644 --- a/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py +++ b/MindIE/MultiModal/OpenSora-1.2/opensora/vae/VideoAutoencoder.py @@ -206,3 +206,19 @@ class VideoAutoencoder(DiffusionModel): else: self.load_state_dict(state_dict, strict=False) return state_dict.keys() + + def encode(self, x): + x_z = self.spatial_vae.encode(x) + + if self.micro_frame_size is None: + posterior = self.temporal_vae.encode(x_z) + z = posterior.sample() + else: + z_list = [] + for i in range(0, x_z.shape[2], self.micro_frame_size): + x_z_bs = x_z[:, :, i: i + self.micro_frame_size] + posterior = self.temporal_vae.encode(x_z_bs) + z_list.append(posterior.sample()) + z = torch.cat(z_list, dim=2) + + return (z - self.shift) / self.scale -- Gitee