未验证 提交 945e0847 编写于 作者: 王明冬 提交者: GitHub

enhance compatiable condition for fc fuse pass. test=develop (#33452)

上级 003b4616
......@@ -14,7 +14,6 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -38,6 +37,7 @@ FCFusePass::FCFusePass() {
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
......@@ -51,6 +51,7 @@ FCFusePass::FCFusePass() {
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(1)
.End();
AddOpCompat(OpCompat("relu"))
......
......@@ -58,12 +58,12 @@ TEST(FCFusePass, basic) {
auto* weights_0 = layers.data("weights_0", {}, true);
auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
auto* bias_1 = layers.data("bias_1", {}, true);
auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1);
auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1);
auto* relu_out_1 = layers.relu(add_out_0);
auto* weights_1 = layers.data("weights_1", {}, true);
auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
auto* bias_2 = layers.data("bias_2", {}, true);
auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2);
auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1);
VLOG(4) << add_out_1;
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
......
......@@ -194,14 +194,18 @@ struct Layers {
}
VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
int x_num_col_dims = 1) {
int x_num_col_dims = 1, int y_num_col_dims = 1) {
AttributeMap attrs;
attrs["x_num_col_dims"] = 1;
attrs["x_num_col_dims"] = x_num_col_dims;
attrs["y_num_col_dims"] = y_num_col_dims;
return binary_op("mul", x, y, out, &attrs);
}
VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
return binary_op("elementwise_add", x, y, out);
VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
int axis = -1) {
AttributeMap attrs;
attrs["axis"] = axis;
return binary_op("elementwise_add", x, y, out, &attrs);
}
VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册