From 4954e22b9687322b90cae051c5d3e02f71c45445 Mon Sep 17 00:00:00 2001 From: guohuanliang Date: Fri, 15 Dec 2023 15:09:30 +0800 Subject: [PATCH 1/2] improve bertbase chinese --- .../Bert_Base_Chinese_for_Pytorch/README.md | 31 ++-- .../export_trace_model.py | 4 +- .../Bert_Base_Chinese_for_Pytorch/model.py | 2 +- .../run_torch_aie.py | 143 ++++++++++-------- 4 files changed, 101 insertions(+), 79 deletions(-) diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/README.md b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/README.md index 25a8d1f9e3..9904496fd7 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/README.md +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/README.md @@ -55,17 +55,6 @@ Bert 模型本身的输出是以下表格信息: 模型输出代表的是输入文本,大小为`batchsize x seq_len`,经过bert处理后得到的对应特征向量,每个字对应一个长度为`vocab_size`大小的特征向量,所以总的输出是`batch_size x seq_len x vocab_size`特征矩阵。 在Bert模型的输出后面接一个`argmax`算子,求出每个字对应一个长度为`vocab_size`大小的特征向量中最大值的下标,通过该下标去词表当中查询得到对应的字,即模型预测的概率最大的字 与 真实的标签进行对比,以此计算模型的精度。 -所以在trace的时候,模型定义如下: - ``` - class RefineModel(torch.nn.Module): - def __init__(self, tokenizer, model_path, config_path, device="cpu"): - super(RefineModel, self).__init__() - self._base_model = build_base_model(tokenizer, model_path, config_path, device) - - def forward(self, input_ids, attention_mask, token_type_ids): - x = self._base_model(input_ids, attention_mask, token_type_ids) - return x[0].argmax(dim=-1) - ``` 所以当前模型的输出信息为: - 输出数据 @@ -427,7 +416,11 @@ special_tokens_mask : [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 3. 使用torch-aie接口进行模型推理,测试性能与精度 ``` - python3 run_torch_aie.py + # 使用FastGelu + export ASCENDIE_FASTER_MODE=1 + + # 运行推理 + python3 run_torch_aie.py --batch_size=1 ``` @@ -437,14 +430,18 @@ special_tokens_mask : [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, | 模型 | Pth精度 | om离线推理精度 | torch-aie 推理精度 | | :---------------: | :--------: | :------------: | :----------------: | -| Bert-Base-Chinese | Acc:77.96% | Acc:77.96% | Acc: 77.90% | +| Bert-Base-Chinese | Acc:77.96% | Acc:77.96% | Acc: 77.88% | 性能: -| 模型 | BatchSize | om离线推理性能(经过onnx优化)纯模型推理 | torch-aie性能(11.13号版本)端到端推理 | torch-aie性能(11.13号版本)纯模型推理 | -| :---------------: | :-------: | :------------------------------------: | :----------------------------------: | :----------------------------------: | -| Bert-Base-Chinese | 1 | 5.168ms(193.47 it/s) | 13.96ms(71.6 it/s) | 12.90ms(77.2 it/s) | - +| 芯片型号 | Batch Size | 数据集 | 精度指标(Accuracy)| 性能 | +| :------: | :--------: | :----: | :--: | :--: | +| 310P3 | 1 | zhwiki | 77.88 | | +| 310P3 | 4 | zhwiki | 77.88 | | +| 310P3 | 8 | zhwiki | 77.88 | | +| 310P3 | 16 | zhwiki | 77.88 | | +| 310P3 | 32 | zhwiki | 77.88 | | +| 310P3 | 64 | zhwiki | 77.88 | | # 公网地址说明 代码涉及公网地址参考 public_address_statement.md diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py index 1195021bc5..8083b8ca03 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py @@ -27,7 +27,7 @@ def generate_random_data(shape, dtype, low=0, high=2): raise NotImplementedError("Not supported dtype: {}".format(dtype)) -def export_onnx(model_dir, save_path, seq_len=384, batch_size=1): +def export_ts(model_dir, save_path, seq_len=384, batch_size=1): # build tokenizer tokenizer = build_tokenizer(tokenizer_name=model_dir) @@ -51,4 +51,4 @@ if __name__ == '__main__': model_dir = sys.argv[1] save_path = sys.argv[2] seq_len = int(sys.argv[3]) - export_onnx(model_dir, save_path, seq_len) + export_ts(model_dir, save_path, seq_len) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/model.py b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/model.py index 1425cee781..4b73275db2 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/model.py +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/model.py @@ -53,4 +53,4 @@ class RefineModel(torch.nn.Module): def forward(self, input_ids, attention_mask, token_type_ids): x = self._base_model(input_ids, attention_mask, token_type_ids) - return x[0].argmax(dim=-1) + return x[0].to(torch.half).argmax(dim=-1).reshape(-1, 384) diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/run_torch_aie.py b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/run_torch_aie.py index e4c22e60bc..2272f0616f 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/run_torch_aie.py +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/run_torch_aie.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os, argparse +import os +import argparse import time import torch +import torch_aie import numpy as np from tqdm import tqdm from datasets import load_metric -import torch_aie def compute_metrics(preds, labels, eval_metric_path="./accuracy.py"): @@ -26,7 +27,7 @@ def compute_metrics(preds, labels, eval_metric_path="./accuracy.py"): labels = labels.reshape(-1) preds = preds.reshape(-1) mask = labels != -100 - labels = labels[mask] # 返回labels当中值不等于-100的所有元素组成的列表 + labels = labels[mask] # 返回labels当中值不等于-100的所有元素组成的列表 preds = preds[mask] return metric.compute(predictions=preds, references=labels) @@ -44,51 +45,69 @@ def load_datasets(input_dir, gt_dir, batch_size, seq_length): num_datas = len(os.listdir(gt_dir)) for step in tqdm(range(num_datas)): - input_bin_path = os.path.join(input_dir, "input_ids", "{}.bin".format(step)) - input_ids_vec = np.fromfile(input_bin_path, dtype=np.int64).reshape((batch_size, seq_length)) + input_bin_path = os.path.join( + input_dir, "input_ids", "{}.bin".format(step)) + input_ids_vec = np.fromfile( + input_bin_path, dtype=np.int64).reshape((1, seq_length)) input_ids_vec = torch.tensor(input_ids_vec, dtype=torch.int64) input_ids_vec_list.append(input_ids_vec) - - mask_bin_path = os.path.join(input_dir, "attention_mask", "{}.bin".format(step)) - attention_mask = np.fromfile(mask_bin_path, dtype=np.int64).reshape((batch_size, seq_length)) + + mask_bin_path = os.path.join( + input_dir, "attention_mask", "{}.bin".format(step)) + attention_mask = np.fromfile( + mask_bin_path, dtype=np.int64).reshape((1, seq_length)) attention_mask = torch.tensor(attention_mask, dtype=torch.int64) attention_mask_list.append(attention_mask) - - token_type_path = os.path.join(input_dir, "token_type_ids", "{}.bin".format(step)) - token_type_ids = np.fromfile(token_type_path, dtype=np.int64).reshape((batch_size, seq_length)) + + token_type_path = os.path.join( + input_dir, "token_type_ids", "{}.bin".format(step)) + token_type_ids = np.fromfile( + token_type_path, dtype=np.int64).reshape((1, seq_length)) token_type_ids = torch.tensor(token_type_ids, dtype=torch.int64) token_type_ids_list.append(token_type_ids) label_path = os.path.join(gt_dir, "{}.bin".format(step)) - labels = load_data(label_path, [batch_size, seq_length]) + labels = load_data(label_path, [1, seq_length]) all_labels_list.append(labels) datasets = {'input_ids_vec_list': input_ids_vec_list, - 'attention_mask_list': attention_mask_list, - 'token_type_ids_list': token_type_ids_list, - 'all_labels_list': all_labels_list, - 'num_datas': num_datas} + 'attention_mask_list': attention_mask_list, + 'token_type_ids_list': token_type_ids_list, + 'all_labels_list': all_labels_list, + 'num_datas': num_datas} return datasets -def inference(torchaie_model, datasets): +def inference(torchaie_model, datasets, batch_size): input_ids_vec_list = datasets['input_ids_vec_list'] attention_mask_list = datasets['attention_mask_list'] token_type_ids_list = datasets['token_type_ids_list'] - num_datas = datasets['num_datas'] all_logits_list = [] inference_time = [] count = 0 - for step in tqdm(range(num_datas)): - input_ids_vec_npu = input_ids_vec_list[step].to('npu:0') - attention_mask_npu = attention_mask_list[step].to('npu:0') - token_type_ids_npu = token_type_ids_list[step].to('npu:0') - stream = torch_aie.npu.Stream('npu:0') + # concat bs + input_ids_merged_list = [torch.cat(input_ids_vec_list[i:i+batch_size]) + for i in range(0, len(input_ids_vec_list), batch_size)] + + attention_mask_merged_list = [torch.cat(attention_mask_list[i:i+batch_size]) + for i in range(0, len(attention_mask_list), batch_size)] + + token_type_ids_merged_list = [torch.cat(token_type_ids_list[i:i+batch_size]) + for i in range(0, len(token_type_ids_list), batch_size)] + + # inference + stream = torch_aie.npu.Stream('npu:0') + for step in tqdm(range(len(input_ids_merged_list))): + input_ids_vec_npu = input_ids_merged_list[step].to('npu:0') + attention_mask_npu = attention_mask_merged_list[step].to('npu:0') + token_type_ids_npu = token_type_ids_merged_list[step].to('npu:0') with torch_aie.npu.stream(stream): + stream.synchronize() inf_start = time.time() - logits = torchaie_model(input_ids_vec_npu, attention_mask_npu, token_type_ids_npu) + logits = torchaie_model( + input_ids_vec_npu, attention_mask_npu, token_type_ids_npu) stream.synchronize() inf_end = time.time() inf_time = inf_end - inf_start @@ -96,24 +115,22 @@ def inference(torchaie_model, datasets): if count > 5: inference_time.append(inf_time) logits = logits.to('cpu') - all_logits_list.append(logits) + for i in range(batch_size): + logits_bs1 = logits[i:i+1] + all_logits_list.append(logits_bs1) average_inference_time = np.mean(inference_time) - print('pure model inference performance=', average_inference_time * 1000, 'ms') + print('BatchSize=', batch_size) + print('pure model inference latency=', average_inference_time * 1000, 'ms') + print('QPS=', batch_size / average_inference_time) + return all_logits_list - def run(torchaie_model, input_dir, gt_dir, batch_size, seq_length): datasets = load_datasets(input_dir, gt_dir, batch_size, seq_length) - num_datas = datasets['num_datas'] + all_labels_list = datasets['all_labels_list'] - - start_time = time.time() - all_logits_list = inference(torchaie_model, datasets) - end_time = time.time() - total_time = end_time - start_time - average_time = total_time / num_datas - print('average e2e model inference performance=', average_time * 1000) + all_logits_list = inference(torchaie_model, datasets, batch_size) all_logits_array = np.concatenate(all_logits_list, axis=0) all_labels_array = np.concatenate(all_labels_list, axis=0) @@ -124,36 +141,44 @@ def run(torchaie_model, input_dir, gt_dir, batch_size, seq_length): if __name__ == '__main__': parser = argparse.ArgumentParser(description="Description of your script") parser.add_argument("--input_dir", type=str, default="./input_data") - parser.add_argument("--gt_dir", type=str, help="Path to the model directory", default="./input_data/labels") - parser.add_argument("--seq_length", type=int, help="Sequence length as an integer", default=384) + parser.add_argument( + "--gt_dir", type=str, help="Path to the model directory", default="./input_data/labels") + parser.add_argument("--seq_length", type=int, + help="Sequence length as an integer", default=384) + parser.add_argument("--batch_size", type=int, + help="batch size as an integer", default=1) parser.add_argument('--ts_model_path', type=str, default='bert_base_chinese.pt', - help='Original TorchScript model path') + help='Original TorchScript model path') args = parser.parse_args() - + input_dir = args.input_dir gt_dir = args.gt_dir seq_length = args.seq_length + batch_size = args.batch_size ts_model = torch.jit.load(args.ts_model_path) - + torch_aie.set_device(0) - input_info = [torch_aie.Input((1, 384), dtype=torch.int64), - torch_aie.Input((1, 384), dtype=torch.int64), - torch_aie.Input((1, 384), dtype=torch.int64)] - - torchaie_model = torch_aie.compile( - ts_model, - inputs=input_info, - precision_policy=torch_aie.PrecisionPolicy.PREF_FP32, - truncate_long_and_double=True, - require_full_compilation=True, - allow_tensor_replace_int=True, - torch_executed_ops=[], - soc_version="Ascend310P3", - optimization_level=0, - ) - - torchaie_model.eval() - - batch_size = 1 + input_info = [torch_aie.Input((batch_size, seq_length), dtype=torch.int64), + torch_aie.Input((batch_size, seq_length), dtype=torch.int64), + torch_aie.Input((batch_size, seq_length), dtype=torch.int64)] + + modelpath = 'bertbase_aie_bs' + str(batch_size) + ".pt" + if os.path.exists(modelpath): + torchaie_model = torch.jit.load(modelpath).eval() + else: + torchaie_model = torch_aie.compile( + ts_model, + inputs=input_info, + precision_policy=torch_aie.PrecisionPolicy.PREF_FP32, + truncate_long_and_double=True, + require_full_compilation=True, + allow_tensor_replace_int=True, + torch_executed_ops=[], + soc_version="Ascend310P3", + optimization_level=0, + ) + torchaie_model.eval() + torchaie_model.save(modelpath) + run(torchaie_model, input_dir, gt_dir, batch_size, seq_length) -- Gitee From ded3a088b0af45e2fafd70867f57de57c5b40b14 Mon Sep 17 00:00:00 2001 From: guohuanliang Date: Fri, 15 Dec 2023 15:18:47 +0800 Subject: [PATCH 2/2] improve bertbase chinese --- .../nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py index 8083b8ca03..1561807dbf 100644 --- a/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py +++ b/AscendIE/TorchAIE/built-in/nlp/Bert_Base_Chinese_for_Pytorch/export_trace_model.py @@ -32,7 +32,7 @@ def export_ts(model_dir, save_path, seq_len=384, batch_size=1): tokenizer = build_tokenizer(tokenizer_name=model_dir) # build model - model_path = os.path.join(model_dir, "bert-base-chinese") + model_path = os.path.join(model_dir, "pytorch_model.bin") config_path = os.path.join(model_dir, "config.json") model = RefineModel(tokenizer, model_path, config_path) -- Gitee