From f8a76117da405348cebeac62646155182d5dfb3c Mon Sep 17 00:00:00 2001 From: yu Date: Mon, 3 Apr 2023 12:50:48 +0000 Subject: [PATCH] =?UTF-8?q?update=20PyTorch/built-in/nlp/GPT-2=5Ffor=5FPyT?= =?UTF-8?q?orch/gpt=5Fpatch/gpt=5Fpatch.py.=20=E4=BF=AE=E5=A4=8D=E4=BA=86?= =?UTF-8?q?=E4=B8=80=E4=B8=AAbug=EF=BC=8C=E6=AD=A4bug=E4=BC=9A=E5=9C=A8?= =?UTF-8?q?=E4=BB=85=E5=AD=98=E5=9C=A8scale=E7=9A=84=E6=83=85=E5=86=B5?= =?UTF-8?q?=E4=B8=8B=E8=A7=A6=E5=8F=91=EF=BC=8C=E6=AD=A4=E6=83=85=E5=86=B5?= =?UTF-8?q?=E4=B8=8B=E9=9C=80=E8=A6=81=E4=BD=BF=E7=94=A8=20`self.scale`=20?= =?UTF-8?q?=E8=80=8C=E9=9D=9E=20`scale`=20=E6=8C=87=E4=BB=A3=E7=BC=A9?= =?UTF-8?q?=E6=94=BE=E7=B3=BB=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: yu --- PyTorch/built-in/nlp/GPT-2_for_PyTorch/gpt_patch/gpt_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorch/built-in/nlp/GPT-2_for_PyTorch/gpt_patch/gpt_patch.py b/PyTorch/built-in/nlp/GPT-2_for_PyTorch/gpt_patch/gpt_patch.py index ef205768dd..1377d7c655 100644 --- a/PyTorch/built-in/nlp/GPT-2_for_PyTorch/gpt_patch/gpt_patch.py +++ b/PyTorch/built-in/nlp/GPT-2_for_PyTorch/gpt_patch/gpt_patch.py @@ -849,7 +849,7 @@ def FusedScaleMaskSoftmaxForward(self, input, mask, norm_factor): input = input.float() if self.scale is not None: - input = input * (scale * 1.0 / norm_factor) + input = input * (self.scale * 1.0 / norm_factor) if self.attn_mask_type == AttnMaskType.causal: if self.mask_tri is None: -- Gitee