From fcd0ef511e8b850b5f1190c1feeaa5b0e90ea6f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E5=87=8C?=
Date: Fri, 10 Feb 2023 06:40:49 +0000
Subject: [PATCH] update PyTorch/dev/nlp/FairSeq_Transformer_ID0496_for_PyTorch/train.py.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 王凌
---
 .../dev/nlp/FairSeq_Transformer_ID0496_for_PyTorch/train.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/PyTorch/dev/nlp/FairSeq_Transformer_ID0496_for_PyTorch/train.py b/PyTorch/dev/nlp/FairSeq_Transformer_ID0496_for_PyTorch/train.py
index 488c633579..ed3e324cb0 100644
--- a/PyTorch/dev/nlp/FairSeq_Transformer_ID0496_for_PyTorch/train.py
+++ b/PyTorch/dev/nlp/FairSeq_Transformer_ID0496_for_PyTorch/train.py
@@ -56,6 +56,7 @@ from fairseq import checkpoint_utils, distributed_utils, options, progress_bar,
 from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
+from ptdbg_ascend import *
 
 #hook
 def hook_func(name, save_dict, module):
@@ -92,6 +93,7 @@ def hook_func(name, save_dict, module):
 
 def main(args, init_distributed=False):
     # args.device_id = 5
+    seed_all()
     if not torch.cuda.is_available():
         a = torch.randn([8]).npu().fill_(2)
         float_status = torch.npu_alloc_float_status(a)
@@ -127,6 +129,8 @@ def main(args, init_distributed=False):
 
     # Build model and criterion
     model = task.build_model(args)
+    set_dump_switch("ON")
+    register_hook(model, overflow_check, overflow_nums=2)
     '''
     npu_dict = {}
     for name, module in model.named_modules():
--
Gitee
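
For context, the change above enables ptdbg_ascend overflow detection in train.py. Below is a minimal standalone sketch of the same setup, assuming only the ptdbg_ascend calls that appear in the diff (seed_all, set_dump_switch, register_hook, overflow_check); the placeholder model and entry point stand in for the real task.build_model(args) and training loop and are illustrative only.

    # sketch.py -- assumed minimal reproduction of the patch's ptdbg_ascend setup
    import torch
    from ptdbg_ascend import *  # provides seed_all, set_dump_switch, register_hook, overflow_check (as used in the diff)

    def main():
        # Fix random seeds so dump/compare runs are reproducible.
        seed_all()

        # Placeholder model; in train.py this is task.build_model(args).
        model = torch.nn.Linear(8, 8)

        # Turn on dumping and attach the overflow-check hook to the model's modules;
        # overflow_nums=2 stops collection after two detected overflows, as in the patch.
        set_dump_switch("ON")
        register_hook(model, overflow_check, overflow_nums=2)

        # ... training loop would continue here as in the original script ...

    if __name__ == "__main__":
        main()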