From d72742bd7b8aa7f50609a557a24fc32c03a9bf93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=84=A2=E5=97=A3=E8=8A=BE?= Date: Thu, 24 Apr 2025 15:37:02 +0800 Subject: [PATCH] modify interface --- MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py index 717f081aa2..81d1432ba5 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py @@ -1,7 +1,6 @@ import torch from stepvideo.parallel import get_sequence_parallel_world_size, get_sequence_parallel_rank -import mindiesd -import mindiesd.layers.embedding +from mindiesd import rotary_position_embedding class RoPE1D: @@ -25,7 +24,7 @@ class RoPE1D: def apply_rope1d(self, tokens, pos1d, cos, sin): cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] - return torch.ops.mindie.rope_mindie_sd(tokens, cos, sin, mode=0) + return rotary_position_embedding(tokens, cos, sin, rotated_mode="rotated_half") def __call__(self, tokens, positions): """ -- Gitee