diff --git a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_base_get_info.py b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_base_get_info.py index ce1b597ce2293a39d5e12d66838f7571e6f648a6..3d9146da31a0d9c87be5c6bbb7a4d3478ddb75f9 100644 --- a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_base_get_info.py +++ b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_base_get_info.py @@ -13,12 +13,20 @@ # limitations under the License. # ============================================================================ import os +import argparse + -test_num = 10649 base_path = './bert_bin/' +# three inputs in each eval +test_num = len(os.listdir(base_path)) // 3 +parser = argparse.ArgumentParser(description='manual to this script') +parser.add_argument('--batchsize', type=int, default=8) +args = parser.parse_args() +batchsize = args.batchsize +real_test = test_num // batchsize * batchsize with open('./bert_base_uncased.info', 'w') as f: - for i in range(test_num): + for i in range(real_test): ids_name = base_path + 'input_ids_{}.bin'.format(i) segment_name = base_path + 'segment_ids_{}.bin'.format(i) mask_name = base_path + 'input_mask_{}.bin'.format(i) diff --git a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_postprocess_data.py b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_postprocess_data.py index 11e0ada13c7998fa4fa3ed11dffa9be3d8c54b3b..ebf52e30e9e945dcd24efda6565854009114287c 100644 --- a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_postprocess_data.py +++ b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/bert_postprocess_data.py @@ -395,8 +395,8 @@ def main(): required=True, help="NPU benchmark infer result path") args = parser.parse_args() - test_num = 10649 npu_path = args.npu_result + test_num = len(os.listdir(npu_path)) // 2 tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large eval_examples = read_squad_examples( input_file=args.predict_file, is_training=False) @@ -432,7 +432,6 @@ def main(): all_results.append(RawResult(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits)) - print(" [INFO] i == ", i) time_to_infer = time.time() - infer_start output_prediction_file = os.path.join("./", "predictions.json") output_nbest_file = os.path.join("./", "nbest_predictions.json") @@ -442,7 +441,8 @@ def main(): f.write(json.dumps(answers, indent=4) + "\n") with open(output_nbest_file, "w") as f: f.write(json.dumps(nbest_answers, indent=4) + "\n") - + print("Completed") + if __name__ == '__main__': main() \ No newline at end of file