diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py
index 9d05a5f980aaa6c8e380353195c517d817cead0c..2c233898d33252020e56a65bd7f2e2463a7cac09 100644
--- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py
+++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py
@@ -30,6 +30,10 @@ logger = logging.get_logger(__name__)
 def apply_rotary_pos_emb_flashatt_npu(
     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    cos = cos.chunk(2, dim=-1)[0].contiguous()
+    sin = sin.chunk(2, dim=-1)[0].contiguous()
+    cos = cos.repeat(1, 2)
+    sin = sin.repeat(1, 2)
     q_embed = apply_rotary_emb(q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(q)
     k_embed = apply_rotary_emb(k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(k)
     return q_embed, k_embed
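
Note (not part of the patch): a minimal standalone sketch of what the four added lines do to the cos/sin tables before they reach apply_rotary_emb. The shapes used here (seq_len=16, head_dim=128) are assumptions for illustration only; the actual tensors come from the model's rotary embedding.

import torch

# Assumed shapes for illustration only.
seq_len, head_dim = 16, 128
cos = torch.randn(seq_len, head_dim)

# The patch keeps only the first half of the last dimension and then tiles it
# back to the original width, so the two halves of the result are identical.
cos_half = cos.chunk(2, dim=-1)[0].contiguous()   # (seq_len, head_dim // 2)
cos_tiled = cos_half.repeat(1, 2)                 # (seq_len, head_dim)

assert cos_tiled.shape == cos.shape
assert torch.equal(cos_tiled[:, :head_dim // 2], cos_tiled[:, head_dim // 2:])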