diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6a67ad177d33bf925e64d1f4d91835f2a99d919c..929a38857345c87cddc676e72956f5d223b15ab5 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -56,6 +56,5 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) if (WITH_MKLDNN) - cc_test(test_conv_bias_mkldnn_fuse_pass SRCS conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) endif () diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc index d0bd09a4f69acac35cf8917d16f1391647723812..ebb217a70bd0fe69acf89cc4b58ae3b668ec0c0f 100644 --- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc @@ -11,24 +11,48 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + #include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h" +#include #include #include +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/platform/enforce.h" + namespace paddle { namespace framework { namespace ir { + +template +LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, + BinaryOperation f) { + PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims()); + LoDTensor vec_y; + vec_y.Resize(vec_a.dims()); + const float* a = vec_a.data(); + const float* b = vec_b.data(); + float* y = vec_y.mutable_data(platform::CPUPlace()); + for (int i = 0; i < vec_a.numel(); i++) { + y[i] = f(a[i], b[i]); + } + return vec_y; +} + std::unique_ptr ConvBiasFusePass::ApplyImpl( std::unique_ptr graph) const { PADDLE_ENFORCE(graph.get()); - FusePassBase::Init("conv_bias_mkldnn_fuse", graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + GraphPatternDetector gpd; - auto* conv_input = gpd.mutable_pattern() - ->NewNode("conv_bias_mkldnn_fuse/conv_input") - ->AsInput() - ->assert_is_op_input("conv2d", "Input"); - patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), - "conv_bias_mkldnn_fuse"); + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_); conv_bias_pattern(conv_input); int found_conv_bias_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -44,27 +68,55 @@ std::unique_ptr ConvBiasFusePass::ApplyImpl( GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern); // elementwise_add op GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern); - // Create an ConvBias Node. - OpDesc desc; - std::string conv_bias_i_in = subgraph.at(conv_input)->Name(); - std::string conv_bias_w_in = conv_weight->Name(); - std::string conv_bias_b_in = eltwise_bias->Name(); - std::string conv_bias_out = eltwise_out->Name(); - desc.SetInput("Input", std::vector({conv_bias_i_in})); - desc.SetInput("Filter", std::vector({conv_bias_w_in})); - desc.SetInput("Bias", std::vector({conv_bias_b_in})); - desc.SetOutput("Output", std::vector({conv_bias_out})); - desc.SetType("conv2d"); - for (auto& attr : conv->Op()->GetAttrMap()) { - desc.SetAttr(attr.first, attr.second); - } - auto conv_bias_node = g->CreateOpNode(&desc); // OpDesc will be copied. - GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out}); + PADDLE_ENFORCE(subgraph.count(conv_input)); - IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node); - IR_NODE_LINK_TO(conv_weight, conv_bias_node); - IR_NODE_LINK_TO(eltwise_bias, conv_bias_node); - IR_NODE_LINK_TO(conv_bias_node, eltwise_out); + + auto* eltwise_bias_tensor = + scope->FindVar(eltwise_bias->Name())->GetMutable(); + + auto input_names = conv->Op()->InputNames(); + bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") != + input_names.end(); + if (has_bias && conv->Op()->Input("Bias").size() > 0) { + auto conv_bias_names = conv->Op()->Input("Bias"); + // add eltwise bias to existing conv bias + PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1); + auto* conv_bias_var = scope->FindVar(conv_bias_names[0]); + auto* conv_bias_tensor = conv_bias_var->GetMutable(); + PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims()); + *conv_bias_tensor = tensor_apply_eltwise( + *conv_bias_tensor, *eltwise_bias_tensor, std::plus()); + + conv->Op()->SetOutput("Output", + std::vector({eltwise_out->Name()})); + + GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out}); + + IR_NODE_LINK_TO(conv, eltwise_out); + } else { + // take eltwise bias as conv bias + OpDesc desc; + + desc.SetInput( + "Input", std::vector({subgraph.at(conv_input)->Name()})); + desc.SetInput("Filter", std::vector({conv_weight->Name()})); + desc.SetInput("Bias", std::vector({eltwise_bias->Name()})); + desc.SetOutput("Output", std::vector({eltwise_out->Name()})); + desc.SetType("conv2d"); + + for (auto& attr : conv->Op()->GetAttrMap()) { + desc.SetAttr(attr.first, attr.second); + } + auto conv_bias_node = g->CreateOpNode(&desc); + + IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node); + IR_NODE_LINK_TO(conv_weight, conv_bias_node); + IR_NODE_LINK_TO(eltwise_bias, conv_bias_node); + IR_NODE_LINK_TO(conv_bias_node, eltwise_out); + + GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out}); + } + found_conv_bias_count++; }; gpd(graph.get(), handler); diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h index 187453b2a697c503b1c4988fb7c3a5e122afcab8..5775b83b88730ec298c421a15f5c0b83c27b0750 100644 --- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" @@ -28,6 +29,7 @@ class ConvBiasFusePass : public FusePassBase { protected: std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_bias_mkldnn_fuse"}; }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc deleted file mode 100644 index 50fc62c1734c641adb124496fedf487c31c1fb38..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass_tester.cc +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h" - -#include - -namespace paddle { -namespace framework { -namespace ir { - -void SetOp(ProgramDesc* prog, const std::string& type, - const std::vector& inputs, - const std::vector& outputs) { - auto* op = prog->MutableBlock(0)->AppendOp(); - op->SetType(type); - if (type == "conv2d") { - op->SetAttr("use_mkldnn", true); - op->SetInput("Input", {inputs[0]}); - op->SetInput("Filter", {inputs[1]}); - } else if (type == "elementwise_add") { - op->SetInput("X", {inputs[0]}); - op->SetInput("Y", {inputs[1]}); - } - op->SetOutput("Out", outputs); -} - -// a->OP0->b -// b->OP1->c -// (c, weights)->conv->f -// (f, bias)->elementwise_add->g -ProgramDesc BuildProgramDesc() { - ProgramDesc prog; - for (auto& v : - std::vector({"a", "b", "c", "weights", "bias", "f", "g"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::SELECTED_ROWS); - if (v == "weights" || v == "bias") { - var->SetPersistable(true); - } - } - - SetOp(&prog, "OP0", std::vector({"a"}), - std::vector({"b"})); - SetOp(&prog, "OP1", std::vector({"b"}), - std::vector({"c"})); - SetOp(&prog, "conv2d", std::vector({"c", "weights"}), - std::vector({"f"})); - SetOp(&prog, "elementwise_add", std::vector({"f", "bias"}), - std::vector({"g"})); - - return prog; -} - -TEST(ConvBiasFusePass, basic) { - auto prog = BuildProgramDesc(); - - std::unique_ptr graph(new ir::Graph(prog)); - - auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass"); - - int original_nodes_num = graph->Nodes().size(); - - graph = pass->Apply(std::move(graph)); - - int current_nodes_num = graph->Nodes().size(); - - // Remove 3 Nodes: conv, elementwise_add, conv_out - // Add 1 Node: ConvBias - EXPECT_EQ(original_nodes_num - 2, current_nodes_num); - - // Assert conv_bias op in newly generated graph - int conv_bias_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - if (node->Op()->HasAttr("use_mkldnn")) { - bool use_mkldnn = boost::get(node->Op()->GetAttr("use_mkldnn")); - if (use_mkldnn) { - auto names = node->Op()->InputNames(); - if (std::find(names.begin(), names.end(), "Bias") != names.end()) { - conv_bias_count++; - } - } - } - } - } - EXPECT_EQ(conv_bias_count, 1); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -USE_PASS(conv_bias_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 83838253330f1aac3f61490dfd75f3c91bf1795a..f28dfe40a2a7c5b514190940d373e8777d234dba 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -987,6 +987,7 @@ PDNode *patterns::ConvBias::operator()( // Bias stored in elementwise_add auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) ->AsInput() + ->assert_is_persistable_var() ->assert_is_op_input("elementwise_add", "Y"); // output auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())