diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/README.md b/AscendIE/TorchAIE/built-in/nlp/Albert/README.md
index 0fb595c2f76da8d6021843498d96d31680f8c67d..599e52c04680ded0016ddde0e1c0f4dc5de793b3 100644
--- a/AscendIE/TorchAIE/built-in/nlp/Albert/README.md
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/README.md
@@ -55,9 +55,9 @@ ALBERT is an "improved" BERT, built mainly on Factorized embedding parame
| ------------------------------------------------------------ | ------- |
| Firmware & drivers | 23.0.RC1 |
| CANN | 7.0.RC1.alpha003 |
-| Python | 3.9.11 |
-| PyTorch | 2.0.1 |
-| Torch_AIE | 6.3.rc2 |
+| Python | 3.10.0 |
+| PyTorch | 2.1.0 |
+| MindIETorch | 1.0.RC1 |
| Chip type | Ascend310P3 |
# Quick Start
@@ -85,7 +85,6 @@ ALBERT is an "improved" BERT, built mainly on Factorized embedding parame
git clone https://github.com/lonePatient/albert_pytorch.git
cd albert_pytorch
git checkout 46de9ec
- patch -p1 < ../albert.patch
cd ../
```
@@ -132,14 +131,40 @@ ALBERT is an "improved" BERT, built mainly on Factorized embedding parame
## Model Inference
1. Compile the model.
+
+ mindietorch now supports two model compilation routes: one converts the original PyTorch model to a torchscript model via torch.jit.trace and compiles that; the other compiles the fx model produced from the original PyTorch model via torch.export. Both routes are described below; choose whichever one suits you.
+
+ 1.1. TorchScript model compilation
- First trace the original model into a torchscript model with PyTorch, then compile it into a .pt file with torch_aie.
+ First trace the original model into a torchscript model with PyTorch, then compile it into a .pt file with mindietorch.
```
- # Example: bs32, seq128
- python3.9 export_albert_aie.py --batch_size=32 --max_seq_length=128 --compare_cpu
+ # Patch the original model definition code
+ cd ./albert_pytorch
+ git checkout 46de9ec -f
+ patch -p1 < ../albert_ts.patch
+ cd ../
+ # Example: a static model with bs32, seq128
+ python3 export_albert_ts.py --batch_size=32 --max_seq_length=128 --compare_cpu
```
- - Running the above command stores the compiled model as albert_seq128_bs32.pt in the current directory; with the --compare_cpu flag the script also verifies that the compiled model's outputs match the original torch model's.
+ - Running the above command stores the compiled static model as albert_ts_seq128_bs32.pt in the current directory; adding the --compare_cpu flag makes the script also verify that the compiled model's outputs match the original torch model's.
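+
+ For reference, the command above wraps the flow below; a condensed sketch of what export_albert_ts.py does (argument parsing and checkpoint loading omitted; model is assumed to be the loaded ALBERT nn.Module):
+ ```
+ import torch
+ import mindietorch
+ from mindietorch import _enums
+
+ shape = (32, 128)  # (batch_size, max_seq_length)
+ # dummy int32 inputs: only the shapes matter for compilation
+ dummy = [torch.randint(high=2, size=shape, dtype=torch.int32) for _ in range(3)]
+
+ model.eval()
+ traced = torch.jit.trace(model, dummy)
+ mindietorch.set_device(0)
+ compiled = mindietorch.compile(
+     traced,
+     inputs=[mindietorch.Input(shape=shape, dtype=torch.int32,
+                               format=mindietorch.TensorFormat.ND) for _ in range(3)],
+     precision_policy=_enums.PrecisionPolicy.FP16,
+     soc_version="Ascend310P3",
+     optimization_level=0,
+ )
+ compiled.save("albert_ts_seq128_bs32.pt")
+ ```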
+
+ 1.2. Compiling a torch.export-exported model
+
+ Export the original model to an fx model with PyTorch, then compile that fx model into a .pt file.
+ ```
+ # Patch the original model definition code
+ cd ./albert_pytorch
+ git checkout 46de9ec -f
+ patch -p1 < ../albert_fx.patch
+ cd ../
+ # Example: a static model with bs32, seq128
+ python3 export_albert_fx.py --batch_size=32 --max_seq_length=128 --compare_cpu
+ # Alternatively, export and compile a dynamic model
+ bash albert_dyn.sh
+ ```
+
+ - When compiling a static model, the result is stored as albert_fx_seq128_bs32.pt in the current directory (bs32, seq128 in this example). When compiling a dynamic model, the result is stored as albert_fx_seq128_range.pt; the dynamic model accepts input batch sizes in the range 1~64.
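+
+ For reference, the dynamic (shape-range) path in export_albert_fx.py boils down to the sketch below (argument parsing and checkpoint loading omitted; model is assumed to be the loaded ALBERT nn.Module):
+ ```
+ import torch
+ from torch._export import export, dynamic_dim
+ import mindietorch
+ from mindietorch import _enums
+
+ ids, mask, tok = [torch.randint(high=2, size=(32, 128), dtype=torch.int32) for _ in range(3)]
+ # constrain the batch dimension to [1, 64] and keep it equal across the three inputs
+ constraints = [
+     1 <= dynamic_dim(ids, 0), dynamic_dim(ids, 0) <= 64,
+     dynamic_dim(ids, 0) == dynamic_dim(mask, 0),
+     dynamic_dim(ids, 0) == dynamic_dim(tok, 0),
+ ]
+ model.eval()
+ exported = export(model, args=(ids, mask, tok), constraints=constraints)
+
+ mindietorch.set_device(0)
+ compiled = mindietorch.compile(
+     exported,
+     inputs=[mindietorch.Input(min_shape=(1, 128), max_shape=(64, 128),
+                               dtype=torch.int32) for _ in range(3)],
+     precision_policy=_enums.PrecisionPolicy.FP16,
+     soc_version="Ascend310P3",
+     ir="dynamo",
+ )
+ # trace once on the NPU so the compiled module can be saved as a TorchScript .pt
+ npu_inputs = [t.to("npu:0") for t in (ids, mask, tok)]
+ torch.jit.trace(compiled, npu_inputs).save("albert_fx_seq128_range.pt")
+ ```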
2. Inference validation.
@@ -153,32 +178,40 @@ ALBERT is an "improved" BERT, built mainly on Factorized embedding parame
Run inference to validate the model's accuracy and throughput.
```
- # Example: run inference with the bs32, seq128 model
- python3.9 ./run_aie_eval.py --aie_model_dir=../albert_seq128_bs32.pt --model_type=SST --model_name_or_path="./prev_trained_model/albert_base_v2" --task_name="SST-2" --data_dir="./dataset/SST-2" --spm_model_file="./prev_trained_model/albert_base_v2/30k-clean.model" --output_dir="./outputs/SST-2/" --do_lower_case --max_seq_length=128 --batch_size=32
+ # TorchScript compilation route, using the static bs32, seq128 model as an example
+ python3 ./run_aie_eval.py --aie_model_dir=../albert_ts_seq128_bs32.pt --model_type=SST --model_name_or_path="./prev_trained_model/albert_base_v2" --task_name="SST-2" --data_dir="./dataset/SST-2" --spm_model_file="./prev_trained_model/albert_base_v2/30k-clean.model" --output_dir="./outputs/SST-2/" --do_lower_case --max_seq_length=128 --batch_size=32
+
+ # torch.export compilation route, using the static bs32, seq128 model as an example
+ python3 ./run_aie_eval.py --aie_model_dir=../albert_fx_seq128_bs32.pt --model_type=SST --model_name_or_path="./prev_trained_model/albert_base_v2" --task_name="SST-2" --data_dir="./dataset/SST-2" --spm_model_file="./prev_trained_model/albert_base_v2/30k-clean.model" --output_dir="./outputs/SST-2/" --do_lower_case --max_seq_length=128 --batch_size=32
+
+ # torch.export compilation route, using the dynamic seq128 model as an example
+ python3 ./run_aie_eval.py --aie_model_dir=../albert_fx_seq128_range.pt --model_type=SST --model_name_or_path="./prev_trained_model/albert_base_v2" --task_name="SST-2" --data_dir="./dataset/SST-2" --spm_model_file="./prev_trained_model/albert_base_v2/30k-clean.model" --output_dir="./outputs/SST-2/" --do_lower_case --max_seq_length=128 --batch_size=16
```
- - After inference completes, the script reports the model's classification accuracy on the dataset and the number of samples inferred per unit time (throughput). To benchmark a different bs/seq configuration, only the --aie_model_dir, --max_seq_length and --batch_size arguments need to change.
+ - After inference completes, the script reports the model's classification accuracy on the dataset and the number of samples inferred per unit time (throughput). For static models, benchmarking a different bs/seq configuration requires repeating the compilation step with the desired bs and seq_len, then rerunning this inference step with matching --aie_model_dir, --max_seq_length and --batch_size arguments. For a dynamic model compiled via the torch.export route, no recompilation is needed: rerun the inference step with different --batch_size values (the supported batch range is 1~64).
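+
+ For reference, run_aie_eval.py times inference by staging host-to-device copies and model execution on separate NPU streams; a minimal sketch of its timed loop (eval_dataloader is assumed to yield (input_ids, attention_mask, token_type_ids, labels) batches):
+ ```
+ import time
+ import torch
+ import mindietorch
+
+ mindietorch.set_device(0)
+ model = torch.jit.load("albert_ts_seq128_bs32.pt")
+ infer_stream = mindietorch.npu.Stream("npu:0")
+ h2d = mindietorch.npu.Stream("npu:0")
+
+ samples, elapsed = 0, 0.0
+ for batch in eval_dataloader:
+     with mindietorch.npu.stream(h2d):  # copy the batch to the NPU on its own stream
+         b = [t.to("npu:0") for t in batch[:3]]
+     h2d.synchronize()
+     with torch.no_grad():
+         start = time.time()
+         with mindietorch.npu.stream(infer_stream):
+             outputs = model(b[0], b[1], b[2])
+         infer_stream.synchronize()
+         elapsed += time.time() - start
+     samples += batch[0].size(0)
+ print("throughput: {:.2f} samples/s".format(samples / elapsed))
+ ```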
# Model Inference Performance & Accuracy
-The accuracy and performance for bs32seq128 are as follows:
+Accuracy and performance of models compiled with mindietorch:
Accuracy:
-| Input type | Chip model | ACC (bs32seq128) |
+| Model | Chip model | ACC (bs1seq128) |
| --------- | -------- | ------------- |
-| Static | Ascend310P3 | 92.82% |
-
-Static model performance:
-
-| Model | batch size | 310P3 performance |
-| :------: | :------: | :------: |
-| Albert base v2 | 1 | 532.61 |
-| Albert base v2 | 4 | 841.79 |
-| Albert base v2 | 8 | 1020.77 |
-| Albert base v2 | 16 | 982.26 |
-| Albert base v2 | 32 | 988.63 |
-| Albert base v2 | 64 | 891.06 |
+| Original PyTorch model | CPU | 92.7% |
+| TorchScript-route compiled | Ascend310P3 | 92.7% |
+| torch.export-route compiled | Ascend310P3 | 92.7% |
+
+Static model performance (seq128, throughput in samples/s):
+
+| Model | batch size | TorchScript route | torch.export route |
+| :------: | :------: | :------: | :------: |
+| Albert base v2 | 1 | 532.61 | 373.29 |
+| Albert base v2 | 4 | 841.79 | 354.49 |
+| Albert base v2 | 8 | 1020.77 | 415.50 |
+| Albert base v2 | 16 | 982.26 | 496.13 |
+| Albert base v2 | 32 | 988.63 | 441.28 |
+| Albert base v2 | 64 | 891.06 | 415.09 |
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/albert_dyn.sh b/AscendIE/TorchAIE/built-in/nlp/Albert/albert_dyn.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a1aa8a8f81591e32e384d72d8b10ba8bb8bb1a69
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/albert_dyn.sh
@@ -0,0 +1,3 @@
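+# ASCEND_GLOBAL_LOG_LEVEL=0 enables the most verbose (debug-level) CANN device logging for the export run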
+export ASCEND_GLOBAL_LOG_LEVEL=0
+python3 export_albert_fx.py --compare_cpu --is_range
+unset ASCEND_GLOBAL_LOG_LEVEL
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/albert_fx.patch b/AscendIE/TorchAIE/built-in/nlp/Albert/albert_fx.patch
new file mode 100644
index 0000000000000000000000000000000000000000..2e355daae8f4d7ac4e121bec45c01fa5607fb59f
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/albert_fx.patch
@@ -0,0 +1,104 @@
+diff --git a/model/modeling_albert.py b/model/modeling_albert.py
+index 899e6e6..26a837f 100644
+--- a/model/modeling_albert.py
++++ b/model/modeling_albert.py
+@@ -1,6 +1,5 @@
+ """PyTorch ALBERT model. """
+ from __future__ import absolute_import, division, print_function, unicode_literals
+-import logging
+ import math
+ import os
+ import sys
+@@ -10,7 +9,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
+ from .modeling_utils import PreTrainedModel, prune_linear_layer
+ from .configuration_albert import AlbertConfig
+ from .file_utils import add_start_docstrings
+-logger = logging.getLogger(__name__)
++from tools.common import logger  # reuse the shared logger configured in tools.common
+
+ ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
+ 'albert-base': "",
+@@ -127,7 +126,7 @@ class AlbertEmbeddings(nn.Module):
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
+ seq_length = input_ids.size(1)
+ if position_ids is None:
+- position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
++ position_ids = torch.arange(seq_length, dtype=torch.int32, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+@@ -453,7 +452,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
+
+ ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
+ `ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`_
+- by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
++ by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
+ This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
+ refer to the PyTorch documentation for all matter related to general usage and behavior.
+ .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
+@@ -461,7 +460,7 @@ ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
+ .. _`torch.nn.Module`:
+ https://pytorch.org/docs/stable/nn.html#module
+ Parameters:
+- config (:class:`~transformers.ALbertConfig`): Model configuration class with all the parameters of the model.
++ config (:class:`~transformers.ALbertConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
+ """
+diff --git a/model/modeling_utils.py b/model/modeling_utils.py
+index 56a52e9..4a49178 100644
+--- a/model/modeling_utils.py
++++ b/model/modeling_utils.py
+@@ -12,6 +12,7 @@ from torch.nn import CrossEntropyLoss
+ from torch.nn import functional as F
+
+ from model.configuration_utils import PretrainedConfig
++from model.configuration_albert import AlbertConfig
+ from model.file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
+
+ logger = logging.getLogger(__name__)
+@@ -54,7 +55,7 @@ class PreTrainedModel(nn.Module):
+
+ def __init__(self, config, *inputs, **kwargs):
+ super(PreTrainedModel, self).__init__()
+- if not isinstance(config, PretrainedConfig):
++        if 'AlbertConfig' not in str(type(config)):  # relaxed check: accept the patched config class without the PretrainedConfig base
+ raise ValueError(
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
+ "To create a model from a pretrained model use "
+@@ -123,7 +124,7 @@ class PreTrainedModel(nn.Module):
+ Arguments:
+
+ new_num_tokens: (`optional`) int:
+- New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
++ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
+
+ Return: ``torch.nn.Embeddings``
+diff --git a/processors/glue.py b/processors/glue.py
+index 6628226..8836416 100644
+--- a/processors/glue.py
++++ b/processors/glue.py
+@@ -14,10 +14,6 @@ def collate_fn(batch):
+ Returns a padded tensor of sequences sorted from longest to shortest,
+ """
+ all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))
+- max_len = max(all_lens).item()
+- all_input_ids = all_input_ids[:, :max_len]
+- all_attention_mask = all_attention_mask[:, :max_len]
+- all_token_type_ids = all_token_type_ids[:, :max_len]
+ return all_input_ids, all_attention_mask, all_token_type_ids, all_labels
+
+
+@@ -266,6 +262,11 @@ class Sst2Processor(DataProcessor):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
++ def get_test_examples(self, data_dir):
++ """See base class."""
++ return self._create_examples(
++ self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
++
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/albert.patch b/AscendIE/TorchAIE/built-in/nlp/Albert/albert_ts.patch
old mode 100644
new mode 100755
similarity index 100%
rename from AscendIE/TorchAIE/built-in/nlp/Albert/albert.patch
rename to AscendIE/TorchAIE/built-in/nlp/Albert/albert_ts.patch
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_fx.py b/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_fx.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7f0443c7468b55bc0baad94e4b276245727eb01
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_fx.py
@@ -0,0 +1,141 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import torch
+from torch._export import export, dynamic_dim
+import mindietorch
+from mindietorch import _enums
+import parse
+
+
+COSINE_THRESHOLD = 0.999
+
+
+def cosine_similarity(gt_tensor, pred_tensor):
+ gt_tensor = gt_tensor.flatten().to(torch.float32)
+ pred_tensor = pred_tensor.flatten().to(torch.float32)
+ if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
+ if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
+ return 1.0
+ res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
+ res = res.cpu().detach().item()
+ return res
+
+
+def get_torch_model(ar):
+ _, model, _ = parse.load_data_model(ar)
+ return model
+
+
+def aie_compile(torch_model, ar):
+ torch_model.eval()
+ input_shape = (ar.batch_size, ar.max_seq_length)
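+    # dummy int32 inputs: only the shapes matter for compilation; the values are placeholders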
+ input_ids = torch.randint(high = 1, size = input_shape, dtype = torch.int32)
+ att_mask = torch.randint(high = 3, size = input_shape, dtype = torch.int32)
+ token_ids = torch.randint(high = 1, size = input_shape, dtype = torch.int32)
+ input_data = [ input_ids, att_mask, token_ids ]
+ input_npu_data = [ input_ids.to("npu:0"), att_mask.to("npu:0"), token_ids.to("npu:0") ]
+
+ print("mindietorch compile start.")
+ mindietorch.set_device(0)
+ if ar.is_range:
+ print("compiling shape range model")
+ min_shape = (1, ar.max_seq_length)
+ max_shape = (64, ar.max_seq_length)
+ compile_inputs = [mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.int32),
+ mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.int32),
+ mindietorch.Input(min_shape = min_shape, max_shape = max_shape, dtype = torch.int32)]
+
+ constraint = [
+ 1 <= dynamic_dim(input_ids, 0),
+ dynamic_dim(input_ids, 0) <= 64,
+ dynamic_dim(input_ids, 0) == dynamic_dim(att_mask, 0),
+ dynamic_dim(input_ids, 0) == dynamic_dim(token_ids, 0),
+ ]
+ torch_model = export(torch_model, args=(input_ids, att_mask, token_ids), constraints=constraint)
+ else:
+ print("compiling static model")
+ compile_inputs = [mindietorch.Input(shape = input_shape, dtype = torch.int32),
+ mindietorch.Input(shape = input_shape, dtype = torch.int32),
+ mindietorch.Input(shape = input_shape, dtype = torch.int32)]
+ compiled_model = mindietorch.compile(
+ torch_model,
+ inputs = compile_inputs,
+ precision_policy = _enums.PrecisionPolicy.FP16,
+ soc_version = "Ascend310P3",
+ ir = "dynamo"
+ )
+ print("mindietorch compile done !")
+ traced_model = torch.jit.trace(compiled_model, input_npu_data)
+ traced_model.save(ar.pt_dir)
+
+
+ if ar.compare_cpu:
+        print("start to check the precision of the npu model.")
+ com_res = True
+ mrt_model = torch.jit.load(ar.pt_dir)
+ mrt_res = mrt_model(*input_npu_data)
+ print("mindie infer done !")
+ ref_res = torch_model(*input_data)
+ print("torch infer done !")
+
+ for j, a in zip(mrt_res, ref_res):
+ res = cosine_similarity(j.to("cpu"), a)
+ print(res)
+ if res < COSINE_THRESHOLD:
+ com_res = False
+
+ if com_res:
+            print("Compare success! The NPU model outputs match the CPU model.")
+ else:
+            print("Compare failed! The NPU model outputs do not match the CPU model.")
+ return compiled_model
+
+
+def main():
+ parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_size", type=int, default=32,
+                        help="batch size for input data.")
+ parser.add_argument("--prefix_dir", type=str, default='./albert_pytorch',
+ help="prefix dir for ori model code")
+ parser.add_argument("--pth_dir", type=str, default='./albert_pytorch/outputs/SST-2',
+ help="dir of pth, load args.bin and model.bin")
+ parser.add_argument("--data_dir", type=str, default='./albert_pytorch/dataset/SST-2',
+ help="dir of dataset")
+ parser.add_argument("--max_seq_length", type=int, default=128,
+ help="seq length for input data.")
+ parser.add_argument("--save_dir", type=str, default='./',
+ help="save dir of model compiled by mindietorch")
+ parser.add_argument("--compare_cpu", action='store_true',
+                        help="Whether to check the precision of the npu model.")
+ parser.add_argument("--is_range", action='store_true',
+                        help="Whether to compile a shape-range (dynamic batch) model.")
+
+ ar = parser.parse_args()
+
+ ar.pth_arg_path = os.path.join(ar.pth_dir, "training_args.bin")
+ ar.data_type = 'dev'
+ if ar.is_range:
+ ar.model_name = "albert_fx_seq{}_range".format(ar.max_seq_length)
+ else:
+ ar.model_name = "albert_fx_seq{}_bs{}".format(ar.max_seq_length, ar.batch_size)
+ ar.pt_dir = ar.save_dir + ar.model_name + ".pt"
+ torch_model = get_torch_model(ar)
+ aie_compile(torch_model, ar)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_aie.py b/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_ts.py
similarity index 80%
rename from AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_aie.py
rename to AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_ts.py
index a59038a0be2ad02d2e264bcdef133f0c26ea6265..6072a5ca183e8e6d73d82d76fbe5115c109dc95a 100644
--- a/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_aie.py
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/export_albert_ts.py
@@ -15,8 +15,8 @@
import os
import argparse
import torch
-import torch_aie
-from torch_aie import _enums
+import mindietorch
+from mindietorch import _enums
import parse
@@ -45,25 +45,25 @@ def aie_compile(torch_model, ar):
input_ids = torch.randint(high = 1, size = input_shape, dtype = torch.int32)
att_mask = torch.randint(high = 3, size = input_shape, dtype = torch.int32)
token_ids = torch.randint(high = 1, size = input_shape, dtype = torch.int32)
- input_data = [ input_ids, att_mask, token_ids]
+ input_data = [ input_ids, att_mask, token_ids ]
print("start to trace model.")
traced_model = torch.jit.trace(torch_model, input_data)
traced_model.eval()
print("trace done !")
- print("torch_aie compile start.")
- torch_aie.set_device(0)
- compile_inputs = [torch_aie.Input(shape = input_shape, dtype = torch.int32, format = torch_aie.TensorFormat.ND),
- torch_aie.Input(shape = input_shape, dtype = torch.int32, format = torch_aie.TensorFormat.ND),
- torch_aie.Input(shape = input_shape, dtype = torch.int32, format = torch_aie.TensorFormat.ND)]
- compiled_model = torch_aie.compile(
+ print("mindietorch compile start.")
+ mindietorch.set_device(0)
+ compile_inputs = [mindietorch.Input(shape = input_shape, dtype = torch.int32, format = mindietorch.TensorFormat.ND),
+ mindietorch.Input(shape = input_shape, dtype = torch.int32, format = mindietorch.TensorFormat.ND),
+ mindietorch.Input(shape = input_shape, dtype = torch.int32, format = mindietorch.TensorFormat.ND)]
+ compiled_model = mindietorch.compile(
traced_model,
inputs = compile_inputs,
precision_policy = _enums.PrecisionPolicy.FP16,
soc_version = "Ascend310P3",
optimization_level = 0
)
- print("torch_aie compile done !")
+ print("mindietorch compile done !")
compiled_model.save(ar.pt_dir)
@@ -73,11 +73,11 @@ def aie_compile(torch_model, ar):
compiled_model = torch.jit.load(ar.pt_dir)
jit_res = traced_model(input_ids, att_mask, token_ids)
print("jit infer done !")
- aie_res = compiled_model(input_ids, att_mask, token_ids)
+ aie_res = compiled_model(input_ids.to("npu:0"), att_mask.to("npu:0"), token_ids.to("npu:0"))
print("aie infer done !")
for j, a in zip(jit_res, aie_res):
- res = cosine_similarity(j, a)
+ res = cosine_similarity(j, a.to("cpu"))
print(res)
if res < COSINE_THRESHOLD:
com_res = False
@@ -102,14 +102,14 @@ def main():
parser.add_argument("--max_seq_length", type=int, default=128,
help="seq length for input data.")
parser.add_argument("--save_dir", type=str, default='./',
- help="save dir of model compiled by torch_aie")
+ help="save dir of model compiled by mindietorch")
parser.add_argument("--compare_cpu", action='store_true',
help="Whether to check the percision of npu model.")
ar = parser.parse_args()
ar.pth_arg_path = os.path.join(ar.pth_dir, "training_args.bin")
ar.data_type = 'dev'
- ar.model_name = "albert_seq{}_bs{}".format(ar.max_seq_length, ar.batch_size)
+ ar.model_name = "albert_ts_seq{}_bs{}".format(ar.max_seq_length, ar.batch_size)
ar.pt_dir = ar.save_dir + ar.model_name + ".pt"
torch_model = get_torch_model(ar)
aie_compile(torch_model, ar)
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/requirements.txt b/AscendIE/TorchAIE/built-in/nlp/Albert/requirements.txt
index e7a234a9f887de906fbc5011efb67b4db9841e99..2ad896e75baf9dcbcddca84829e46eab8fb9b317 100644
--- a/AscendIE/TorchAIE/built-in/nlp/Albert/requirements.txt
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/requirements.txt
@@ -1,7 +1,7 @@
ptvsd==4.3.2
six==1.16.0
sentencepiece==0.1.96
-torch==2.0.1
+torch==2.1.0
boto3==1.17.97
botocore==1.20.97
requests==2.25.1
diff --git a/AscendIE/TorchAIE/built-in/nlp/Albert/run_aie_eval.py b/AscendIE/TorchAIE/built-in/nlp/Albert/run_aie_eval.py
index 3d0828f3e59eeb56939a069cba84c169a18b594d..c4e6aed6de41cd5f266920f98b6da961b137aa8b 100644
--- a/AscendIE/TorchAIE/built-in/nlp/Albert/run_aie_eval.py
+++ b/AscendIE/TorchAIE/built-in/nlp/Albert/run_aie_eval.py
@@ -41,7 +41,7 @@ from tools.common import seed_everything
from tools.common import init_logger, logger
from callback.progressbar import ProgressBar
-import torch_aie
+import mindietorch
def evaluate(args, model, tokenizer, prefix=""):
@@ -62,8 +62,8 @@ def evaluate(args, model, tokenizer, prefix=""):
collate_fn=collate_fn, drop_last=True)
# Eval!
- infer_stream = torch_aie.npu.Stream("npu:0")
- h2d = torch_aie.npu.Stream("npu:0")
+ infer_stream = mindietorch.npu.Stream("npu:0")
+ h2d = mindietorch.npu.Stream("npu:0")
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
@@ -77,7 +77,7 @@ def evaluate(args, model, tokenizer, prefix=""):
for step, batch in enumerate(eval_dataloader):
model.eval()
# prepare data on npu
- with torch_aie.npu.stream(h2d):
+ with mindietorch.npu.stream(h2d):
b = []
for t in batch:
t_npu = t.to(args.device)
@@ -90,7 +90,7 @@ def evaluate(args, model, tokenizer, prefix=""):
# forward
with torch.no_grad():
inf_s = time.time()
- with torch_aie.npu.stream(infer_stream):
+ with mindietorch.npu.stream(infer_stream):
outputs = model(b[0], b[1], b[2])
infer_stream.synchronize()
inf_e = time.time()
@@ -220,7 +220,7 @@ def main():
init_logger(log_file=args.output_dir + '/{}-{}.log'.format(args.model_type, args.task_name))
# setup npu device
- torch_aie.set_device(0)
+ mindietorch.set_device(0)
device = torch.device("npu:0")
args.n_gpu = 1
args.device = device