From eb5291285f98125b73bb06bb00ea8c70c3ceba05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Fri, 25 Jun 2021 14:04:40 +0800 Subject: [PATCH] fix the attributes error in transpose.pbtxt,test=develop. (#33770) --- .../ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc | 8 +++++--- .../mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc | 1 - paddle/fluid/operators/compat/transpose.pbtxt | 4 ++-- paddle/fluid/operators/compat/transpose2.pbtxt | 4 ++-- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index 1f17a741f19..e5bdb08fe4a 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -34,10 +34,13 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { .IsTensor() .End() .AddAttr("alpha") // unconstrained. can be any float value. + .IsType() .End() .AddAttr("transpose_X") // unconstrained. can be any bool value. + .IsType() .End() .AddAttr("transpose_Y") // unconstrained. can be any bool value. + .IsType() .End(); AddOpCompat(OpCompat("transpose2")) @@ -51,9 +54,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { .IsTensor() .End() .AddAttr("axis") // ints - .End() - .AddAttr("data_format") - .IsStringIn({"NHWC", "NCHW", "AnyLayout"}) + .IsType>() .End(); AddOpCompat(OpCompat("reshape2")) @@ -75,6 +76,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { .IsTensor() .End() .AddAttr("shape") // ints + .IsType>() .End(); } void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc index ac4e6c383da..d98d640e100 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc @@ -28,7 +28,6 @@ void SetOp(ProgramDesc *prog, const std::string &type, op->SetOutput("Out", {outputs[0]}); if (type == "transpose2") { op->SetAttr("axis", std::vector({0, 2, 1, 3})); - op->SetAttr("data_format", std::string("NCHW")); op->SetOutput("XShape", {outputs[1]}); } if (type == "reshape2") { diff --git a/paddle/fluid/operators/compat/transpose.pbtxt b/paddle/fluid/operators/compat/transpose.pbtxt index 97081e0afc2..1cd04a4da4a 100644 --- a/paddle/fluid/operators/compat/transpose.pbtxt +++ b/paddle/fluid/operators/compat/transpose.pbtxt @@ -10,12 +10,12 @@ def { name: "axis" type: INTS } +} +extra { attrs { name: "data_format" type: STRING } -} -extra { attrs { name: "use_mkldnn" type: BOOLEAN diff --git a/paddle/fluid/operators/compat/transpose2.pbtxt b/paddle/fluid/operators/compat/transpose2.pbtxt index 19d991a6414..31aecd24bc9 100644 --- a/paddle/fluid/operators/compat/transpose2.pbtxt +++ b/paddle/fluid/operators/compat/transpose2.pbtxt @@ -13,12 +13,12 @@ def { name: "axis" type: INTS } +} +extra { attrs { name: "data_format" type: STRING } -} -extra { attrs { name: "use_mkldnn" type: BOOLEAN -- GitLab