diff --git a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py index 717f081aa2961bc2afc232022d3285a4829396ad..81d1432ba5d83828f54934fff0ea4edb0d5e98a3 100644 --- a/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py +++ b/MindIE/MultiModal/StepVideo_TI2V/stepvideo/modules/rope.py @@ -1,7 +1,6 @@ import torch from stepvideo.parallel import get_sequence_parallel_world_size, get_sequence_parallel_rank -import mindiesd -import mindiesd.layers.embedding +from mindiesd import rotary_position_embedding class RoPE1D: @@ -25,7 +24,7 @@ class RoPE1D: def apply_rope1d(self, tokens, pos1d, cos, sin): cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :] sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :] - return torch.ops.mindie.rope_mindie_sd(tokens, cos, sin, mode=0) + return rotary_position_embedding(tokens, cos, sin, rotated_mode="rotated_half") def __call__(self, tokens, positions): """