From 14c4a08a8cd5b294aab8f2003aa98a2c6f45fe7a Mon Sep 17 00:00:00 2001
From: alanzhang
Date: Thu, 12 Dec 2024 11:49:14 +0800
Subject: [PATCH] BERT model inference deployment, part 3
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
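Notes:

A possible invocation of the new script (paths here are illustrative; only
--aie_model, --predict_file, --vocab_file and --predict_batch_size are
required):

    python run_aie_eval.py \
        --aie_model bert_base_squad_aie.ts \
        --predict_file dev-v1.1.json \
        --vocab_file vocab.txt \
        --predict_batch_size 8 \
        --do_lower_case

The script loads a model already compiled with mindietorch. Below is a rough,
untested sketch of how such a model might be produced, assuming a
torch_tensorrt-style mindietorch.compile / mindietorch.Input API (the exact
signatures, the dtype enum, and the file names are assumptions, not part of
this patch; consult the MindIE-Torch documentation):

    import torch
    import mindietorch

    BATCH, SEQ = 8, 512  # must match --predict_batch_size / --max_seq_length
    # Hypothetical traced TorchScript BERT with a QA head (start/end logits).
    traced = torch.jit.load("bert_base_squad_traced.ts")
    # Three int32 inputs: input_ids, token_type_ids (segment_ids), attention_mask.
    inputs = [mindietorch.Input((BATCH, SEQ), dtype=mindietorch.dtype.INT32)
              for _ in range(3)]
    compiled = mindietorch.compile(traced, inputs=inputs)
    torch.jit.save(compiled, "bert_base_squad_aie.ts")  # passed via --aie_model
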
+ """ + def __init__(self, + qas_id, + question_text, + doc_tokens, + orig_answer_text=None, + start_position=None, + end_position=None, + is_impossible=None): + self.qas_id = qas_id + self.question_text = question_text + self.doc_tokens = doc_tokens + self.orig_answer_text = orig_answer_text + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = "" + s += "qas_id: %s" % (self.qas_id) + s += ", question_text: %s" % ( + self.question_text) + s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) + if self.start_position: + s += ", start_position: %d" % (self.start_position) + if self.end_position: + s += ", end_position: %d" % (self.end_position) + if self.is_impossible: + s += ", is_impossible: %r" % (self.is_impossible) + return s + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, + unique_id, + example_index, + doc_span_index, + tokens, + token_to_orig_map, + token_is_max_context, + input_ids, + input_mask, + segment_ids, + start_position=None, + end_position=None, + is_impossible=None): + self.unique_id = unique_id + self.example_index = example_index + self.doc_span_index = doc_span_index + self.tokens = tokens + self.token_to_orig_map = token_to_orig_map + self.token_is_max_context = token_is_max_context + self.input_ids = input_ids + self.input_mask = input_mask + self.segment_ids = segment_ids + self.start_position = start_position + self.end_position = end_position + self.is_impossible = is_impossible + + +def read_squad_examples(input_file, is_training): + """Read a SQuAD json file into a list of SquadExample.""" + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader)["data"] + + def is_whitespace(c): + whitespace_chars = {" ", "\t", "\r", "\n"} + return c in whitespace_chars or ord(c) == 0x202F + + examples = [] + for entry in input_data: + for paragraph in entry["paragraphs"]: + paragraph_text = paragraph["context"] + doc_tokens = [] + char_to_word_offset = [] + prev_is_whitespace = True + for c in paragraph_text: + if is_whitespace(c): + prev_is_whitespace = True + else: + if prev_is_whitespace: + doc_tokens.append(c) + else: + doc_tokens[-1] += c + prev_is_whitespace = False + char_to_word_offset.append(len(doc_tokens) - 1) + + for qa in paragraph["qas"]: + qas_id = qa["id"] + question_text = qa["question"] + start_position = None + end_position = None + orig_answer_text = None + is_impossible = False + example = SquadExample( + qas_id=qas_id, + question_text=question_text, + doc_tokens=doc_tokens, + orig_answer_text=orig_answer_text, + start_position=start_position, + end_position=end_position, + is_impossible=is_impossible) + examples.append(example) + return examples + + +RawResult = collections.namedtuple("RawResult", + ["unique_id", "start_logits", "end_logits"]) + + +def main(): + parser = argparse.ArgumentParser() + + ## Required parameters + parser.add_argument("--aie_model", default=None, type=str, required=True, + help="Path to bert-base model compiled with mindietorch. ") + parser.add_argument("--predict_file", default=None, type=str, required=True, + help="SQuAD json for predictions and evaluation. 
+
+    ## Other parameters
+    parser.add_argument("--output_dir", default="./output_predictions", type=str,
+                        help="The output directory where the predictions will be written.")
+    parser.add_argument("--max_seq_length", default=512, type=int,
+                        help="The maximum total input sequence length after WordPiece tokenization. "
+                             "Sequences longer than this will be truncated, and sequences shorter "
+                             "than this will be padded.")
+    parser.add_argument("--doc_stride", default=128, type=int,
+                        help="When splitting up a long document into chunks, how much stride to "
+                             "take between chunks.")
+    parser.add_argument("--max_query_length", default=64, type=int,
+                        help="The maximum number of tokens for the question. Questions longer than "
+                             "this will be truncated to this length.")
+    parser.add_argument("--n_best_size", default=20, type=int,
+                        help="The total number of n-best predictions to generate in the "
+                             "nbest_predictions.json output file.")
+    parser.add_argument("--max_answer_length", default=30, type=int,
+                        help="The maximum length of an answer that can be generated. This is needed "
+                             "because the start and end predictions are not conditioned on one "
+                             "another.")
+    parser.add_argument("--verbose_logging", action='store_true',
+                        help="If true, all of the warnings related to data processing will be "
+                             "printed. A number of warnings are expected for a normal SQuAD "
+                             "evaluation.")
+    parser.add_argument("--seed", type=int, default=42,
+                        help="Random seed for initialization.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Whether to lower case the input text. True for uncased models, "
+                             "False for cased models.")
+    parser.add_argument("--version_2_with_negative", action='store_true',
+                        help="If true, the SQuAD examples contain some that do not have an answer.")
+    parser.add_argument("--eval_script", default="./evaluate_data.py", type=str,
+                        help="Script to evaluate SQuAD predictions.")
+    args = parser.parse_args()
+    mindietorch.set_device(0)
+    logger.info("Parameters: %s", args)
+
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    # 512 is the maximum sequence length supported by the BERT checkpoints.
+    tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512)
+
+    eval_examples = read_squad_examples(
+        input_file=args.predict_file, is_training=False)
+    eval_features = convert_examples_to_features(
+        examples=eval_examples,
+        tokenizer=tokenizer,
+        max_seq_length=args.max_seq_length,
+        doc_stride=args.doc_stride,
+        max_query_length=args.max_query_length,
+        is_training=False)
+
+    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.int32)
+    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.int32)
+    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.int32)
+    all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.int32)
+    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
+    # Run prediction for the full dataset; drop_last=True because the compiled
+    # model expects a fixed batch size.
+    eval_sampler = SequentialSampler(eval_data)
+    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
+                                 batch_size=args.predict_batch_size, drop_last=True)
+    logger.info("Dataset prepared for evaluation.")
+
+    model = torch.jit.load(args.aie_model)
+    model.eval()
+    inf_stream = mindietorch.npu.Stream("npu:0")
+    inference_times = []
+    all_results = []
+    logger.info("mindietorch model loaded, ready to predict.")
+    # Time each batch by launching inference on a dedicated NPU stream and
+    # synchronizing before reading the clock.
+    for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
+        input_ids = input_ids.to("npu:0")
+        input_mask = input_mask.to("npu:0")
+        segment_ids = segment_ids.to("npu:0")
+
+        with torch.no_grad():
+            inf_s = time.time()
+            with mindietorch.npu.stream(inf_stream):
+                outputs = model(input_ids, segment_ids, input_mask)
+            inf_stream.synchronize()
+            inf_e = time.time()
+        inference_times.append(inf_e - inf_s)
+
+        batch_start_logits, batch_end_logits = outputs
+        batch_start_logits = batch_start_logits.to("cpu")
+        batch_end_logits = batch_end_logits.to("cpu")
+        for i, example_index in enumerate(example_indices):
+            start_logits = batch_start_logits[i].tolist()
+            end_logits = batch_end_logits[i].tolist()
+            eval_feature = eval_features[example_index.item()]
+            unique_id = int(eval_feature.unique_id)
+            all_results.append(RawResult(unique_id=unique_id,
+                                         start_logits=start_logits,
+                                         end_logits=end_logits))
+
+    logger.info("Prediction on the dataset finished.")
+    output_prediction_file = os.path.join(args.output_dir, "predictions.json")
+    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
+
+    answers, nbest_answers = get_answers(eval_examples, eval_features, all_results, args)
+    with open(output_prediction_file, "w") as f:
+        f.write(json.dumps(answers, indent=4) + "\n")
+    with open(output_nbest_file, "w") as f:
+        f.write(json.dumps(nbest_answers, indent=4) + "\n")
+    logger.info("Predictions were written to json files.")
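+
+    # NOTE: the parsing below assumes the eval script prints a single dict of
+    # the form {"exact_match": <float>, "f1": <float>}, as the official SQuAD
+    # v1.1 evaluate script does; any other output format will break the
+    # string slicing on `scores`.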
logger.info("Calculating the f1 score...") + eval_out = subprocess.check_output([sys.executable, args.eval_script, args.predict_file, os.path.join(args.output_dir, "predictions.json")]) + scores = str(eval_out).strip() + exact_match = float(scores.split(":")[1].split(",")[0]) + f1 = float(scores.split(":")[2].split("}")[0]) + logger.info("The predictions have %s exact match and the f1 score on squadv1.1 is %s", exact_match, f1) + + avg_inf_time = sum(fps[5:]) / len(fps[5:]) + throughput = args.predict_batch_size / avg_inf_time + logger.info("The average inference time is %s and the throughput is %s", avg_inf_time, throughput) + + +if __name__ == "__main__": + main() -- Gitee