diff --git a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/ReadMe.md b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/ReadMe.md index 18ca26c9de04c940e0f11e288b5bcc4b47aa1f5a..865a190861676cb9ce559a47d9002bbe6ad2dd58 100644 --- a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/ReadMe.md +++ b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/ReadMe.md @@ -158,11 +158,11 @@ ![img](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/turing/resourcecenter/img/public_sys-resources/notice_3.0-zh-cn.png) - + 使用ATC工具将.onnx文件转换为.om文件,需要.onnx算子版本需为11。在bert_base_pth2onnx.py脚本中torch.onnx.export方法中的输入参数opset_version的值需为11,请勿修改。 - - 3. 此步可选,根据onnx图里是否存在(0,2,3,1)的transpose进行优化,若存在,根据这类Transpose和紧跟它的MatMul的name更新add_attr_transB.py,也就是更新里面的transpose_nodes和bmm_nodes两个list,然后运行下面命令。 - + + 3. 此步可选,根据onnx图里是否存在(0,2,3,1)的transpose进行优化,若存在,运行下面命令。 + ``` python3 add_attr_trans_b.py bert_base_batch_8.onnx bert_base_batch_8.onnx ``` @@ -180,7 +180,7 @@ export ASCEND_OPP_PATH=${install_path}/opp # 使用二进制输入时,执行如下命令 - atc --input_format=ND --framework=5 --model=bert_base_batch_8.onnx --input_shape="input_ids:8,512;token_type_ids:8,512;attention_mask:8,512" --output=bert_base_batch_8_auto --log=info --soc_version=Ascend710 + atc --input_format=ND --framework=5 --model=bert_base_batch_8.onnx --input_shape="input_ids:8,512;token_type_ids:8,512;attention_mask:8,512" --output=bert_base_batch_8_auto --log=info --soc_version=Ascend710 --optypelist_for_implmode="Gelu" --op_select_implmode=high_performance --input_fp16_nodes="attention_mask" ``` ![img](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/turing/resourcecenter/img/public_sys-resources/note_3.0-zh-cn.png) diff --git a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/add_attr_trans_b.py b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/add_attr_trans_b.py index 45b2063c18bf1bd62bb74e81cd27d73efbe21fbb..b2b94bbb35b75f6a311e7567bd0e0dc9c70a5e97 100644 --- a/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/add_attr_trans_b.py +++ b/ACL_PyTorch/built-in/nlp/Bert_Base_Uncased_for_Pytorch/add_attr_trans_b.py @@ -16,43 +16,22 @@ from gener_core.mod_modify.onnx_graph import OXGraph from gener_core.mod_modify.interface import AttrType as AT +FUS_NODE_TRANS = "Transpose" +FUS_NODE_BMM = "MatMul" input_model = sys.argv[1] output_model = sys.argv[2] mod = OXGraph(input_model) - -transpose_nodes = ['Transpose_60', - 'Transpose_154', - 'Transpose_248', - 'Transpose_342', - 'Transpose_436', - 'Transpose_530', - 'Transpose_624', - 'Transpose_718', - 'Transpose_812', - 'Transpose_906', - 'Transpose_1000', - 'Transpose_1094'] -bmm_nodes = ['MatMul_72', - 'MatMul_166', - 'MatMul_260', - 'MatMul_354', - 'MatMul_448', - 'MatMul_542', - 'MatMul_636', - 'MatMul_730', - 'MatMul_824', - 'MatMul_918', - 'MatMul_1012', - 'MatMul_1106'] io_map = mod.get_net_in_out_map() +trans_nodes = mod.get_nodes_by_optype(FUS_NODE_TRANS) -for transpose_node in transpose_nodes: - now_trans = mod.get_node(transpose_node) - now_trans.set_attr({"perm": (AT.LIST_INT, [0, 2, 1, 3])}) -for bmm in bmm_nodes: - now_bmm = mod.get_node(bmm) - now_bmm.set_attr({"transB": (AT.INT, 1)}) +for trans_node in trans_nodes: + if trans_node.get_attr("perm", AT.LIST_INT) == [0, 2, 3, 1]: + trans_node.set_attr({"perm": (AT.LIST_INT, [0, 2, 1, 3])}) + bmm = io_map.get(trans_node.name) + if FUS_NODE_BMM in bmm[0]: + new_bmm = mod.get_node(bmm[0]) + new_bmm.set_attr({"transB": (AT.INT, 1)}) mod.save_new_model(output_model) print("OK") \ No newline at end of file