From 945e0847bc4b2588aef4b8813856f883028e5502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Thu, 10 Jun 2021 16:48:31 +0800 Subject: [PATCH] enhance compatiable condition for fc fuse pass. test=develop (#33452) --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 3 ++- paddle/fluid/framework/ir/fc_fuse_pass_tester.cc | 4 ++-- paddle/fluid/framework/ir/pass_tester_helper.h | 12 ++++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 656d453d403..0bb2782b373 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -14,7 +14,6 @@ #include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include -#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -38,6 +37,7 @@ FCFusePass::FCFusePass() { .IsNumGE(1) .End() .AddAttr("y_num_col_dims") + .IsNumEQ(1) .End(); AddOpCompat(OpCompat("elementwise_add")) @@ -51,6 +51,7 @@ FCFusePass::FCFusePass() { .IsTensor() .End() .AddAttr("axis") + .IsNumGE(1) .End(); AddOpCompat(OpCompat("relu")) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index cf35c1ac772..50469110368 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -58,12 +58,12 @@ TEST(FCFusePass, basic) { auto* weights_0 = layers.data("weights_0", {}, true); auto* mul_out_0 = layers.mul(relu_out_0, weights_0); auto* bias_1 = layers.data("bias_1", {}, true); - auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1); + auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1); auto* relu_out_1 = layers.relu(add_out_0); auto* weights_1 = layers.data("weights_1", {}, true); auto* mul_out_1 = layers.mul(relu_out_1, weights_1); auto* bias_2 = layers.data("bias_2", {}, true); - auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2); + auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1); VLOG(4) << add_out_1; std::unique_ptr graph(new ir::Graph(layers.main_program())); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 6b187e538d1..850d3dca6d0 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -194,14 +194,18 @@ struct Layers { } VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, - int x_num_col_dims = 1) { + int x_num_col_dims = 1, int y_num_col_dims = 1) { AttributeMap attrs; - attrs["x_num_col_dims"] = 1; + attrs["x_num_col_dims"] = x_num_col_dims; + attrs["y_num_col_dims"] = y_num_col_dims; return binary_op("mul", x, y, out, &attrs); } - VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { - return binary_op("elementwise_add", x, y, out); + VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, + int axis = -1) { + AttributeMap attrs; + attrs["axis"] = axis; + return binary_op("elementwise_add", x, y, out, &attrs); } VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, -- GitLab