提交 d7509d63 编写于 作者: M Michal Gallus

Conv+Bias: Support non-null bias

test=develop
上级 91e8fbac
...@@ -56,6 +56,5 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph ...@@ -56,6 +56,5 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if (WITH_MKLDNN) if (WITH_MKLDNN)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
endif () endif ()
...@@ -11,24 +11,48 @@ ...@@ -11,24 +11,48 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Combines two tensors element by element and returns the result as a new
// tensor.
//
// Both inputs are read as `float` buffers and must have identical dims
// (enforced below). The output has the same dims as `vec_a` and is allocated
// on platform::CPUPlace().
//
// @param vec_a  left-hand operand, read as float.
// @param vec_b  right-hand operand, read as float; dims must equal vec_a's.
// @param f      binary callable invoked as f(a[i], b[i]) for every element;
//               its result is stored into the output as float.
// @return       a freshly allocated LoDTensor holding f applied pairwise.
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
                               BinaryOperation f) {
  PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());

  LoDTensor vec_y;
  vec_y.Resize(vec_a.dims());

  const float* a = vec_a.data<float>();
  const float* b = vec_b.data<float>();
  float* y = vec_y.mutable_data<float>(platform::CPUPlace());

  // Use the element-count type reported by the tensor itself for the index,
  // rather than `int`, so the loop cannot overflow or compare mixed-width
  // integers when the count does not fit in an int. The bound is evaluated
  // once instead of on every iteration.
  using Index = decltype(vec_a.numel());
  const Index count = vec_a.numel();
  for (Index i = 0; i < count; ++i) {
    y[i] = f(a[i], b[i]);
  }
  return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl( std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get()); PADDLE_ENFORCE(graph.get());
FusePassBase::Init("conv_bias_mkldnn_fuse", graph.get()); FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern() auto* conv_input =
->NewNode("conv_bias_mkldnn_fuse/conv_input") gpd.mutable_pattern()
->AsInput() ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->assert_is_op_input("conv2d", "Input"); ->AsInput()
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), ->assert_is_op_input("conv2d", "Input");
"conv_bias_mkldnn_fuse"); patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input); conv_bias_pattern(conv_input);
int found_conv_bias_count = 0; int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
...@@ -44,27 +68,55 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl( ...@@ -44,27 +68,55 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
// elementwise_add op // elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
// Create an ConvBias Node.
OpDesc desc;
std::string conv_bias_i_in = subgraph.at(conv_input)->Name();
std::string conv_bias_w_in = conv_weight->Name();
std::string conv_bias_b_in = eltwise_bias->Name();
std::string conv_bias_out = eltwise_out->Name();
desc.SetInput("Input", std::vector<std::string>({conv_bias_i_in}));
desc.SetInput("Filter", std::vector<std::string>({conv_bias_w_in}));
desc.SetInput("Bias", std::vector<std::string>({conv_bias_b_in}));
desc.SetOutput("Output", std::vector<std::string>({conv_bias_out}));
desc.SetType("conv2d");
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
auto conv_bias_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input)); PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
IR_NODE_LINK_TO(conv_weight, conv_bias_node); auto* eltwise_bias_tensor =
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node); scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
*conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
// take eltwise bias as conv bias
OpDesc desc;
desc.SetInput(
"Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d");
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
auto conv_bias_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
IR_NODE_LINK_TO(conv_weight, conv_bias_node);
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
}
found_conv_bias_count++; found_conv_bias_count++;
}; };
gpd(graph.get(), handler); gpd(graph.get(), handler);
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -28,6 +29,7 @@ class ConvBiasFusePass : public FusePassBase { ...@@ -28,6 +29,7 @@ class ConvBiasFusePass : public FusePassBase {
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
// Appends a new operator of kind `type` to block 0 of `prog`.
//
// Input wiring depends on the operator kind:
//   - "conv2d":          inputs[0] -> "Input", inputs[1] -> "Filter", and the
//                        "use_mkldnn" attribute is set to true.
//   - "elementwise_add": inputs[0] -> "X", inputs[1] -> "Y".
//   - anything else:     no inputs are registered on the op.
// In every case `outputs` is registered under the "Out" slot.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* new_op = prog->MutableBlock(0)->AppendOp();
  new_op->SetType(type);

  const bool is_conv = (type == "conv2d");
  const bool is_eltwise_add = (type == "elementwise_add");

  if (is_conv) {
    new_op->SetAttr("use_mkldnn", true);
    new_op->SetInput("Input", {inputs[0]});
    new_op->SetInput("Filter", {inputs[1]});
  } else if (is_eltwise_add) {
    new_op->SetInput("X", {inputs[0]});
    new_op->SetInput("Y", {inputs[1]});
  }

  new_op->SetOutput("Out", outputs);
}
// a->OP0->b
// b->OP1->c
// (c, weights)->conv->f
// (f, bias)->elementwise_add->g
// Builds the small program the fuse pass is exercised on:
//   a -> OP0 -> b -> OP1 -> c
//   (c, weights) -> conv2d -> f
//   (f, bias)    -> elementwise_add -> g
// Every variable is declared as SELECTED_ROWS; "weights" and "bias" are
// additionally marked persistable.
ProgramDesc BuildProgramDesc() {
  using Names = std::vector<std::string>;

  ProgramDesc prog;
  auto* block = prog.MutableBlock(0);

  const Names var_names{"a", "b", "c", "weights", "bias", "f", "g"};
  for (const auto& name : var_names) {
    auto* var = block->Var(name);
    var->SetType(proto::VarType::SELECTED_ROWS);
    const bool is_parameter = (name == "weights" || name == "bias");
    if (is_parameter) {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "OP0", Names{"a"}, Names{"b"});
  SetOp(&prog, "OP1", Names{"b"}, Names{"c"});
  SetOp(&prog, "conv2d", Names{"c", "weights"}, Names{"f"});
  SetOp(&prog, "elementwise_add", Names{"f", "bias"}, Names{"g"});
  return prog;
}
// Runs conv_bias_mkldnn_fuse_pass on the program from BuildProgramDesc() and
// checks (1) the net change in node count and (2) that exactly one MKL-DNN
// conv2d op carrying a "Bias" input exists afterwards.
TEST(ConvBiasFusePass, basic) {
  auto prog = BuildProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  auto pass = PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass");

  const int original_nodes_num = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  const int current_nodes_num = graph->Nodes().size();

  // Expected bookkeeping: conv, elementwise_add and conv_out are removed (-3)
  // and one fused conv+bias node is added (+1), i.e. two nodes fewer overall.
  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);

  // Count conv2d ops that have use_mkldnn == true and expose a "Bias" input.
  int conv_bias_count = 0;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != "conv2d") {
      continue;
    }
    if (!node->Op()->HasAttr("use_mkldnn")) {
      continue;
    }
    const bool mkldnn_enabled =
        boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
    if (!mkldnn_enabled) {
      continue;
    }
    auto input_slots = node->Op()->InputNames();
    const bool has_bias_slot =
        std::find(input_slots.begin(), input_slots.end(), "Bias") !=
        input_slots.end();
    if (has_bias_slot) {
      conv_bias_count++;
    }
  }
  EXPECT_EQ(conv_bias_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// NOTE(review): presumably this pulls the pass's registration into the test
// binary so that PassRegistry::Instance().Get("conv_bias_mkldnn_fuse_pass")
// above can resolve it — confirm against the USE_PASS macro definition.
USE_PASS(conv_bias_mkldnn_fuse_pass);
...@@ -987,6 +987,7 @@ PDNode *patterns::ConvBias::operator()( ...@@ -987,6 +987,7 @@ PDNode *patterns::ConvBias::operator()(
// Bias stored in elementwise_add // Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("elementwise_add", "Y"); ->assert_is_op_input("elementwise_add", "Y");
// output // output
auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr()) auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册