提交 347bf904 作者: Tomasz Patejko

MKLDNN conv + elementwise_add fusion: bias is also handled

上级 bf95ac36
......@@ -49,6 +49,7 @@ struct Pattern : public PatternBase {
struct Conv {
std::string op_name() const { return "conv2d"; }
std::string input_name() const { return "Input"; }
std::string bias_name() const { return "Bias"; }
std::string filter_name() const { return "Filter"; }
std::string residual_data_name() const { return "ResidualData"; }
std::string output_name() const { return "Output"; }
......@@ -60,13 +61,16 @@ struct Conv {
auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(), input_name());
auto bias_var = pattern->new_node(bias_name())
->assert_is_op_input(op_name(), bias_name());
auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(), filter_name());
auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(), output_name());
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksFrom({input_var, bias_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
......@@ -178,13 +182,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate();
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input,
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input, Node* conv_bias,
Node* conv_filter, Node* conv_output,
Node* elementwise_add_x) {
OpDesc op_desc;
op_desc.SetType(conv_pattern.op_name());
op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()});
op_desc.SetInput(conv_pattern.bias_name(), {conv_bias->Name()});
op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()});
op_desc.SetInput(conv_pattern.residual_data_name(),
{elementwise_add_x->Name()});
......@@ -196,6 +201,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
auto fused_conv_op = g->CreateOpNode(&op_desc);
patterns::LinkNodes(conv_input, fused_conv_op);
patterns::LinkNodes(conv_bias, fused_conv_op);
patterns::LinkNodes(conv_filter, fused_conv_op);
patterns::LinkNodes(elementwise_add_x, fused_conv_op);
patterns::LinkNodes(fused_conv_op, conv_output);
......@@ -208,6 +214,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_pattern.op_name());
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.input_name());
auto conv_bias = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.bias_name());
auto conv_filter = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, conv_pattern.filter_name());
auto conv_output = patterns::GetNodeFromSubgraph(
......@@ -220,7 +228,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
auto elementwise_add_out = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.out_name());
fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
fuse_conv(g, conv_input, conv_bias, conv_filter, conv_output,
elementwise_add_x);
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
};
......
......@@ -34,7 +34,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
if (type == "conv2d") {
op->SetAttr("use_mkldnn", true);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[1]});
op->SetInput("Filter", {inputs[2]});
op->SetOutput("Output", outputs);
} else if (type == "elementwise_add") {
op->SetInput("X", {inputs[0]});
......@@ -98,8 +99,8 @@ struct IsReachable {
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) {
for (auto& v : std::vector<std::string>(
{"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights") {
......@@ -107,7 +108,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
}
}
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"});
......@@ -150,7 +151,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "weights"})) {
for (auto& v : std::vector<std::string>({"a", "b", "bias", "weights"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights" || v == "bias") {
......@@ -158,7 +159,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
}
}
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
return prog;
......@@ -199,8 +200,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) {
for (auto& v : std::vector<std::string>(
{"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v.find("weights")) {
......@@ -209,7 +210,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
}
SetOp(&prog, "sigmoid", {"a"}, {"b"});
SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
SetOp(&prog, "relu", {"e"}, {"f"});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册