diff --git a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/transformers_need/modeling_qwen2_vl.py b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/transformers_need/modeling_qwen2_vl.py
index 660bca2c3af04b887a78cbebadfaf659f097d879..b6860bcf14c6e88e1c0013e3309d5fc228ad2cef 100644
--- a/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/transformers_need/modeling_qwen2_vl.py
+++ b/PyTorch/built-in/rl/OpenRLHF_v0.5.7_for_PyTorch/transformers_need/modeling_qwen2_vl.py
@@ -24,6 +24,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch_npu
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -178,8 +179,8 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
         unsqueeze_dim
     )
 
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
     return q_embed, k_embed
 
 
@@ -189,9 +190,19 @@ def apply_rotary_pos_emb_vision(
     orig_q_dtype = q.dtype
     orig_k_dtype = k.dtype
     q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    cos = cos.unsqueeze(1).unsqueeze(2).float()
+    sin = sin.unsqueeze(1).unsqueeze(2).float()
+
+    q = q.unsqueeze(2)
+    k = k.unsqueeze(2)
+
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+
+    q_embed = q_embed.squeeze(2)
+    k_embed = k_embed.squeeze(2)
+
     q_embed = q_embed.to(orig_q_dtype)
     k_embed = k_embed.to(orig_k_dtype)
     return q_embed, k_embed
@@ -380,7 +391,7 @@ class VisionSdpaAttention(nn.Module):
             cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
-        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
+        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
             attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
         q = q.transpose(0, 1)
@@ -442,11 +453,7 @@ class Qwen2RMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
 
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
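For readers without Ascend hardware, the sketch below gives a plain-PyTorch reference for what the fused `torch_npu.npu_rotary_mul` call is relied on to compute, namely `x * cos + rotate_half(x) * sin` (exactly the expression it replaces in both hunks), and checks that the reshaped vision-RoPE path in `apply_rotary_pos_emb_vision` is numerically equivalent to the original `unsqueeze(-2)` broadcast. `rotate_half` mirrors the helper in `modeling_qwen2_vl.py`; `rotary_mul_reference` is a hypothetical stand-in for the NPU kernel, not part of torch_npu.

```python
# A minimal CPU sketch, assuming torch_npu.npu_rotary_mul(x, cos, sin) computes
# x * cos + rotate_half(x) * sin (the expression it replaces in this patch).
# rotary_mul_reference below is a hypothetical stand-in, not a torch_npu API.
import torch


def rotate_half(x):
    # Same helper as in modeling_qwen2_vl.py: rotates half the hidden dims.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotary_mul_reference(x, cos, sin):
    # Eager-mode reference for the fused rotary kernel.
    return (x * cos) + (rotate_half(x) * sin)


if __name__ == "__main__":
    seq_len, num_heads, head_dim = 5, 2, 8
    q = torch.randn(seq_len, num_heads, head_dim)   # vision q: [seq, heads, dim]
    cos = torch.randn(seq_len, head_dim)            # vision cos/sin: [seq, dim]
    sin = torch.randn(seq_len, head_dim)

    # Original upstream path: broadcast cos/sin over the head dimension.
    ref = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))

    # Patched path: lift everything to 4-D for the fused kernel, then squeeze back.
    q4 = q.unsqueeze(2)                             # [seq, heads, 1, dim]
    cos4 = cos.unsqueeze(1).unsqueeze(2)            # [seq, 1, 1, dim]
    sin4 = sin.unsqueeze(1).unsqueeze(2)            # [seq, 1, 1, dim]
    new = rotary_mul_reference(q4, cos4, sin4).squeeze(2)

    print(torch.allclose(ref, new))  # True
```

The unsqueeze/squeeze gymnastics change only the broadcasting shape, not the math, so the patched vision path stays equivalent as long as the fused kernel implements the rotate-half formula above.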
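The `VisionSdpaAttention` hunk only widens the block-diagonal boolean mask from `[1, seq, seq]` to `[1, 1, seq, seq]`, giving it an explicit head dimension before it is handed to `scaled_dot_product_attention`. The toy example below (made-up sizes and a hypothetical two-window `cu_seqlens`) shows how such a mask is built and consumed; it runs on CPU with stock PyTorch.

```python
# Toy sketch of the 4-D block-diagonal bool mask used by the patched
# VisionSdpaAttention; sizes and cu_seqlens values are illustrative only.
import torch
import torch.nn.functional as F

seq_length, num_heads, head_dim = 6, 2, 4
cu_seqlens = torch.tensor([0, 2, 6])  # two windows of 2 and 4 patches (hypothetical)

# [1, 1, S, S] broadcasts over both the batch and head dimensions expected by SDPA.
attention_mask = torch.zeros([1, 1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    # True = "may attend": patches attend only within their own window.
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

q = torch.randn(1, num_heads, seq_length, head_dim)
k = torch.randn(1, num_heads, seq_length, head_dim)
v = torch.randn(1, num_heads, seq_length, head_dim)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
print(out.shape)  # torch.Size([1, 2, 6, 4])
```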
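Since `torch_npu` only exists in Ascend builds, the RMSNorm replacement makes the module depend on it at import time. A guarded variant like the sketch below keeps an eager fallback so the same class still runs on CPU or CUDA. `Qwen2RMSNormSketch` is a hypothetical name, and the sketch assumes `npu_rms_norm(...)[0]` matches the eager result, as the `[0]` index in the patch suggests.

```python
# Sketch of a guarded RMSNorm: fused NPU kernel when torch_npu is importable,
# eager fallback (the code removed by this patch) otherwise. Qwen2RMSNormSketch
# is a hypothetical name; the assumption is that npu_rms_norm(...)[0] matches
# the eager result.
import torch
import torch.nn as nn

try:
    import torch_npu  # only available in Ascend NPU builds
    _HAS_TORCH_NPU = True
except ImportError:
    _HAS_TORCH_NPU = False


class Qwen2RMSNormSketch(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        if _HAS_TORCH_NPU:
            # Fused kernel; element [0] is the normalized, weight-scaled output.
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
        # Eager path, identical to the lines removed in the diff above.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
```

One difference worth noting: the eager path upcasts to float32 before normalizing, while any internal precision handling of the fused kernel is up to the NPU implementation.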