From 7d3840ed1df5545dc197ed8368b00e0f61dd14a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E7=89=B9=E9=A9=B9?= Date: Tue, 10 Jun 2025 15:49:32 +0800 Subject: [PATCH] =?UTF-8?q?=E9=92=88=E5=AF=B9PyTorch2.6=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E6=80=A7=E8=83=BD=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../classification/Resnet50_Cifar_for_PyTorch/tools/train.py | 3 +++ PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py | 3 +++ .../cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py | 3 ++- PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py | 3 +++ PyTorch/contrib/nlp/roberta_for_PyTorch/train.py | 4 ++++ 5 files changed, 15 insertions(+), 1 deletion(-) diff --git a/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py b/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py index 599aba4de9..32c0c593e7 100644 --- a/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py +++ b/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py @@ -24,6 +24,9 @@ if torch.__version__ >= "1.8": import torch_npu torch.npu.config.allow_internal_format = True +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + import torch.distributed as dist import torch_npu from torch_npu.contrib import transfer_to_npu diff --git a/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py b/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py index 9fc4355e33..67b89f370a 100644 --- a/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py +++ b/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py @@ -16,6 +16,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + from omegaconf import DictConfig, OmegaConf # We need to setup root logger before importing any fairseq libraries. diff --git a/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py b/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py index 2465095d90..064e29eb47 100644 --- a/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py +++ b/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py @@ -30,7 +30,8 @@ if torch.__version__ >= "1.8": import torch_npu torch.npu.config.allow_internal_format = True - +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) import torch.nn.parallel import torch.backends.cudnn as cudnn diff --git a/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py b/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py index 9aa8d42fb2..4eb11ce164 100644 --- a/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py +++ b/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py @@ -37,6 +37,9 @@ import json import time import torch +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + import torch.nn as nn import torch_npu from torch_npu.contrib import transfer_to_npu diff --git a/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py b/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py index 321de3d9b5..744a6f251e 100644 --- a/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py +++ b/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py @@ -7,6 +7,10 @@ Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. """ +import torch +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + from fairseq_cli.train import cli_main -- Gitee