From 353103fb0ef28526a6e2539b1ceab344e975f399 Mon Sep 17 00:00:00 2001 From: sunyi001 <1659275352@qq.com> Date: Fri, 6 Jun 2025 14:16:53 +0800 Subject: [PATCH] fix generation method of cos and sin --- PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py index 9d05a5f980..2c233898d3 100644 --- a/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py +++ b/PyTorch/built-in/rl/VeRL_for_PyTorch/verl/utils/npu_patch.py @@ -30,6 +30,10 @@ logger = logging.get_logger(__name__) def apply_rotary_pos_emb_flashatt_npu( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + cos = cos.repeat(1, 2) + sin = sin.repeat(1, 2) q_embed = apply_rotary_emb(q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(q) k_embed = apply_rotary_emb(k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()).type_as(k) return q_embed, k_embed -- Gitee