diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index 1f17a741f190941d352e9ad6346dfdbeca671b50..e5bdb08fe4ab4825aef1d3d3ccd7d3a7f352574e 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -34,10 +34,13 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { .IsTensor() .End() .AddAttr("alpha") // unconstrained. can be any float value. + .IsType() .End() .AddAttr("transpose_X") // unconstrained. can be any bool value. + .IsType() .End() .AddAttr("transpose_Y") // unconstrained. can be any bool value. + .IsType() .End(); AddOpCompat(OpCompat("transpose2")) @@ -51,9 +54,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { .IsTensor() .End() .AddAttr("axis") // ints - .End() - .AddAttr("data_format") - .IsStringIn({"NHWC", "NCHW", "AnyLayout"}) + .IsType>() .End(); AddOpCompat(OpCompat("reshape2")) @@ -75,6 +76,7 @@ MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() { .IsTensor() .End() .AddAttr("shape") // ints + .IsType>() .End(); } void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const { diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc index ac4e6c383dad9d5cc11e5bbce5f24093f9d60d24..d98d640e1002b1ff97e9d03a44a866987e3a2af8 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc @@ -28,7 +28,6 @@ void SetOp(ProgramDesc *prog, const std::string &type, op->SetOutput("Out", {outputs[0]}); if (type == "transpose2") { op->SetAttr("axis", std::vector({0, 2, 1, 3})); - op->SetAttr("data_format", std::string("NCHW")); op->SetOutput("XShape", {outputs[1]}); } if (type == "reshape2") { diff --git a/paddle/fluid/operators/compat/transpose.pbtxt b/paddle/fluid/operators/compat/transpose.pbtxt index 97081e0afc29a823d2dc95f2be31311020da8203..1cd04a4da4a174808f81f3b1d5c4f6093b5126ee 100644 --- a/paddle/fluid/operators/compat/transpose.pbtxt +++ b/paddle/fluid/operators/compat/transpose.pbtxt @@ -10,12 +10,12 @@ def { name: "axis" type: INTS } +} +extra { attrs { name: "data_format" type: STRING } -} -extra { attrs { name: "use_mkldnn" type: BOOLEAN diff --git a/paddle/fluid/operators/compat/transpose2.pbtxt b/paddle/fluid/operators/compat/transpose2.pbtxt index 19d991a6414d131c4833d5b919e9372b38168864..31aecd24bc911b446b43f351885549be9d84533a 100644 --- a/paddle/fluid/operators/compat/transpose2.pbtxt +++ b/paddle/fluid/operators/compat/transpose2.pbtxt @@ -13,12 +13,12 @@ def { name: "axis" type: INTS } +} +extra { attrs { name: "data_format" type: STRING } -} -extra { attrs { name: "use_mkldnn" type: BOOLEAN