diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 656d453d4030439f0229492a7c2ab2ee46481950..0bb2782b3737ee3130e2d7bee68fd932c3b87932 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -14,7 +14,6 @@ #include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include -#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -38,6 +37,7 @@ FCFusePass::FCFusePass() { .IsNumGE(1) .End() .AddAttr("y_num_col_dims") + .IsNumEQ(1) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -51,6 +51,7 @@ FCFusePass::FCFusePass() { .IsTensor() .End() .AddAttr("axis") + .IsNumGE(1) .End(); AddOpCompat(OpCompat("relu")) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index cf35c1ac772da079159cb4ced2edc234d7325b1e..5046911036818c902844a35220101836b6404478 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -58,12 +58,12 @@ TEST(FCFusePass, basic) { auto* weights_0 = layers.data("weights_0", {}, true); auto* mul_out_0 = layers.mul(relu_out_0, weights_0); auto* bias_1 = layers.data("bias_1", {}, true); - auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1); + auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1); auto* relu_out_1 = layers.relu(add_out_0); auto* weights_1 = layers.data("weights_1", {}, true); auto* mul_out_1 = layers.mul(relu_out_1, weights_1); auto* bias_2 = layers.data("bias_2", {}, true); - auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2); + auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1); VLOG(4) << add_out_1; std::unique_ptr graph(new ir::Graph(layers.main_program())); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 6b187e538d1c082dec47144ed144a746794767b9..850d3dca6d0e10dd2f93a2149bef268042de339b 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -194,14 +194,18 @@ struct Layers { } VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, - int x_num_col_dims = 1) { + int x_num_col_dims = 1, int y_num_col_dims = 1) { AttributeMap attrs; - attrs["x_num_col_dims"] = 1; + attrs["x_num_col_dims"] = x_num_col_dims; + attrs["y_num_col_dims"] = y_num_col_dims; return binary_op("mul", x, y, out, &attrs); } - VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { - return binary_op("elementwise_add", x, y, out); + VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, + int axis = -1) { + AttributeMap attrs; + attrs["axis"] = axis; + return binary_op("elementwise_add", x, y, out, &attrs); } VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,