diff --git a/MindIE/MultiModal/CogVideoX/README.md b/MindIE/MultiModal/CogVideoX/README.md index bf9d6697ba0aaa672864917b6fcc6bc3344babc3..ff948b45ee58114d509f050e4f17e20a2f769f99 100644 --- a/MindIE/MultiModal/CogVideoX/README.md +++ b/MindIE/MultiModal/CogVideoX/README.md @@ -169,7 +169,8 @@ TASK_QUEUE_ENABLE=2 ASCEND_RT_VISIBLE_DEVICES=0 torchrun --master_port=2002 --np - num_inference_steps:推理迭代步数,默认值为50。 - dtype:数据类型,默认值为bfloat16。CogVideoX-2b推荐设置为float16,需要在命令前加INF_NAN_MODE_FORCE_DISABLE=1,开启饱和模式避免数值溢出。 - seed: 设置随机种子,默认值为42。 -- enable_skip:是否使用采样优化。 +- enable_skip:是否使用采样优化,注意是有损的加速算法。 + 推理结束后会在当前路径下生成result.json,用于记录文本提示和生成视频的对应关系,便于测试视频精度。 diff --git a/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py index 3ca14c2089e17f3ee97331b76f660b0ee7290958..a7fa9f9f86b561c92610acbe713177682b995824 100644 --- a/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py +++ b/MindIE/MultiModal/CogVideoX/cogvideox_5b/utils/parallel_state.py @@ -265,7 +265,7 @@ def set_parallel(pipe): image_embeds = output[:, text_len:, :].reshape(batch, num_frames, -1, output.shape[-1]) text_embeds = split_tensor(text_embeds, -2, get_sp_world_size(), get_sp_group()) - image_embeds = split_tensor(image_embeds, -2, get_sp_world_size(), get_sp_group(), scale=2) + image_embeds = split_tensor(image_embeds, -2, get_sp_world_size(), get_sp_group()) image_embeds = image_embeds.reshape(batch, -1, image_embeds.shape[-1]) return torch.cat([text_embeds, image_embeds], dim=1) diff --git a/MindIE/MultiModal/CogVideoX/inference.py b/MindIE/MultiModal/CogVideoX/inference.py index 3571d9bb8c64fce8fce58939bf7fe7ab1a8f2498..cadc94c75877c96b5cd3637dc113f2b0e77bc271 100644 --- a/MindIE/MultiModal/CogVideoX/inference.py +++ b/MindIE/MultiModal/CogVideoX/inference.py @@ -101,7 +101,7 @@ def generate_video( export_to_video(video_generate, video_path, fps=fps) result[os.path.abspath(video_path)] = prompt - with open('result_2b_46.json', 'w', encoding='utf-8') as json_file: + with open('result.json', 'w', encoding='utf-8') as json_file: json.dump(result, json_file, ensure_ascii=False, indent=4) print(f"Result saved to result.json.")