diff --git a/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py b/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py index 599aba4de93db59beb39593a40ae72c2d1c599b0..32c0c593e727cd31402ebd45198fbf8e1e385ca4 100644 --- a/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py +++ b/PyTorch/built-in/cv/classification/Resnet50_Cifar_for_PyTorch/tools/train.py @@ -24,6 +24,9 @@ if torch.__version__ >= "1.8": import torch_npu torch.npu.config.allow_internal_format = True +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + import torch.distributed as dist import torch_npu from torch_npu.contrib import transfer_to_npu diff --git a/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py b/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py index 9fc4355e33127875d665f1d219964a06880d3c00..67b89f370a4b05c57e37c6c5f99e47f33474d62f 100644 --- a/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py +++ b/PyTorch/contrib/audio/wav2vec2.0/fairseq_cli/train.py @@ -16,6 +16,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + from omegaconf import DictConfig, OmegaConf # We need to setup root logger before importing any fairseq libraries. diff --git a/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py b/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py index 2465095d9069064eb6617b4fd590f9dddf54d809..064e29eb473bf231725100e7d6075c07f9019ce9 100644 --- a/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py +++ b/PyTorch/contrib/cv/classification/HRNet_ID1780_for_PyTorch/tools/train.py @@ -30,7 +30,8 @@ if torch.__version__ >= "1.8": import torch_npu torch.npu.config.allow_internal_format = True - +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) import torch.nn.parallel import torch.backends.cudnn as cudnn diff --git a/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py b/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py index 9aa8d42fb219d92b4f09cf5afe7a8bbc7262057c..4eb11ce1646f34433ee5516fc4609d711d472df2 100644 --- a/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py +++ b/PyTorch/contrib/nlp/BERT-NER-Pytorch/run_ner_crf.py @@ -37,6 +37,9 @@ import json import time import torch +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + import torch.nn as nn import torch_npu from torch_npu.contrib import transfer_to_npu diff --git a/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py b/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py index 321de3d9b53f8194b58c26f5cb2c03281afc2bb1..744a6f251e2daf6f58f038e8c743d50780561e8a 100644 --- a/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py +++ b/PyTorch/contrib/nlp/roberta_for_PyTorch/train.py @@ -7,6 +7,10 @@ Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. """ +import torch +if torch.__version__ >= "2.6": + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + from fairseq_cli.train import cli_main