diff --git a/ACL_PyTorch/built-in/nlp/Uie_for_Pytorch/fix_onnx.py b/ACL_PyTorch/built-in/nlp/Uie_for_Pytorch/fix_onnx.py index e7e5292219effb4bad8e33997620c606156b32d7..356216382b7a71c870222eaea9597aa7ac659208 100644 --- a/ACL_PyTorch/built-in/nlp/Uie_for_Pytorch/fix_onnx.py +++ b/ACL_PyTorch/built-in/nlp/Uie_for_Pytorch/fix_onnx.py @@ -104,7 +104,7 @@ if __name__ == '__main__': seq_len_ = sys.argv[4] onnx_graph = OnnxGraph.parse(input_path) fix_mul(onnx_graph) - fix_add_shape(onnx_graph, bs_) + fix_add_shape(onnx_graph) fix_transpose(onnx_graph) fix_reshape(onnx_graph, bs_, seq_len_) onnx_graph.save(save_path)