未验证 提交 945e0847 编写于 作者: 王明冬 提交者: GitHub

enhance compatiable condition for fc fuse pass. test=develop (#33452)

上级 003b4616
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -38,6 +37,7 @@ FCFusePass::FCFusePass() { ...@@ -38,6 +37,7 @@ FCFusePass::FCFusePass() {
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("y_num_col_dims") .AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End(); .End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
...@@ -51,6 +51,7 @@ FCFusePass::FCFusePass() { ...@@ -51,6 +51,7 @@ FCFusePass::FCFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("axis") .AddAttr("axis")
.IsNumGE(1)
.End(); .End();
AddOpCompat(OpCompat("relu")) AddOpCompat(OpCompat("relu"))
......
...@@ -58,12 +58,12 @@ TEST(FCFusePass, basic) { ...@@ -58,12 +58,12 @@ TEST(FCFusePass, basic) {
auto* weights_0 = layers.data("weights_0", {}, true); auto* weights_0 = layers.data("weights_0", {}, true);
auto* mul_out_0 = layers.mul(relu_out_0, weights_0); auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
auto* bias_1 = layers.data("bias_1", {}, true); auto* bias_1 = layers.data("bias_1", {}, true);
auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1); auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1);
auto* relu_out_1 = layers.relu(add_out_0); auto* relu_out_1 = layers.relu(add_out_0);
auto* weights_1 = layers.data("weights_1", {}, true); auto* weights_1 = layers.data("weights_1", {}, true);
auto* mul_out_1 = layers.mul(relu_out_1, weights_1); auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
auto* bias_2 = layers.data("bias_2", {}, true); auto* bias_2 = layers.data("bias_2", {}, true);
auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2); auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1);
VLOG(4) << add_out_1; VLOG(4) << add_out_1;
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
......
...@@ -194,14 +194,18 @@ struct Layers { ...@@ -194,14 +194,18 @@ struct Layers {
} }
VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
int x_num_col_dims = 1) { int x_num_col_dims = 1, int y_num_col_dims = 1) {
AttributeMap attrs; AttributeMap attrs;
attrs["x_num_col_dims"] = 1; attrs["x_num_col_dims"] = x_num_col_dims;
attrs["y_num_col_dims"] = y_num_col_dims;
return binary_op("mul", x, y, out, &attrs); return binary_op("mul", x, y, out, &attrs);
} }
VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
return binary_op("elementwise_add", x, y, out); int axis = -1) {
AttributeMap attrs;
attrs["axis"] = axis;
return binary_op("elementwise_add", x, y, out, &attrs);
} }
VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册