未验证 提交 42aa1d40 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #13485 from tpatejko/tpatejko/capi-resnet-conv-elementwise-fusion

MKLDNN conv+elementwise_add fusion for residual connections in Resnet
......@@ -45,6 +45,9 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
if(WITH_MKLDNN)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif()
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......@@ -58,4 +61,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if (WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
endif ()
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <utility>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto from_in_inputs =
std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type;
std::for_each(std::begin(inputs), std::end(inputs),
[from, to, &node](const input_type& i) -> void {
auto param_names = i.second;
auto pi = std::find(std::begin(param_names),
std::end(param_names), from->Name());
if (pi != std::end(param_names)) {
node.Op()->SetInput(i.first, {to->Name()});
}
});
}
}
}
} // namespace
using graph_ptr = std::unique_ptr<ir::Graph>;
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_};
auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
elementwise_add_pattern(conv_output);
conv_output->AsIntermediate();
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
auto bias_input_names = conv_op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias");
if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty();
if (has_bias) {
auto conv_bias_names = bias_it->second;
auto conv_bias_names_it =
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
[&conv_bias_names](Node* n) -> bool {
return n->Name() == conv_bias_names[0];
});
return std::make_pair(has_bias, *conv_bias_names_it);
}
}
return std::make_pair(false, nullptr);
};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = g->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"residual_connections_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
constexpr int nodes_removed = 3;
constexpr int nodes_added = 1;
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::pair<std::string, std::string>>& inputs,
const std::pair<std::string, std::string>& output) {
auto op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", true);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
op->SetOutput(output.first, {output.second});
}
struct IsReachable {
using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) {
if (name == node.Name()) {
return &node;
}
}
return nullptr;
};
return [&](std::string from, const std::string to) -> bool {
if (from == to) return true;
std::map<std::string, bool> visited;
for (auto& node : GraphTraits::DFS(*graph)) {
visited[node.Name()] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr) return false;
for (auto n : cur->outputs) {
if (n->Name() == to) return true;
if (!visited[n->Name()]) {
visited[n->Name()] = true;
queue.push_back(n->Name());
}
}
}
return false;
};
}
};
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
int conv_count = 0;
int elementwise_add_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
++conv_count;
}
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
}
EXPECT_EQ(conv_count, 1);
EXPECT_EQ(elementwise_add_count, 0);
}
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
const std::vector<std::string>& persistent_vars) {
ProgramDesc prog;
auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
auto var = prog.MutableBlock(0)->Var(var_name);
var->SetType(proto::VarType::LOD_TENSOR);
return var;
};
for (const auto& v : transient_vars) {
add_var_to_prog(v);
}
for (const auto& v : persistent_vars) {
auto var = add_var_to_prog(v);
var->SetPersistable(true);
}
return prog;
}
} // namespace
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass,
ConvolutionWithElementwiseAddReluNoBias) {
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
SetOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "b"});
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "d"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto prog =
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
SetOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{"Output", "c"});
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "f"));
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
AssertOpsCount(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
......@@ -1044,6 +1044,46 @@ PDNode *patterns::ConvBias::operator()(
return eltwise_out_var;
}
PDNode *patterns::Conv::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
x_var->assert_is_op_input("elementwise_add", "X");
auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
elementwise_add_op->LinksFrom({x_var, y_var});
elementwise_add_op->LinksTo({out_var});
return out_var;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -633,6 +633,44 @@ struct ConvBias : public PatternBase {
PATTERN_DECL_NODE(eltwise_bias);
PATTERN_DECL_NODE(eltwise_out);
};
// Convolution op
// Forward pass for convolution.
// conv_input, conv_bias and conv_filter are inputs.
// conv_output is a result of the operator.
// residual_data is data used by skip connection.
// If residual connection fusion is on, the formula is:
// conv_output = conv_op(conv_filter, conv_input, conv_bias)
// + conv_residual_data
// If the fusion is off, conv_residual_data is not added.
struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_residual_data);
PATTERN_DECL_NODE(conv_output);
};
// ElementwiseAdd used in residual connections.
// y_var is used and convolution output.
// The operator is removed, when residual
// connection fusion is on.
struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
......@@ -80,8 +80,9 @@ class Analyzer : public OrderedRegistry<PassManager> {
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
#ifdef PADDLE_WITH_MKLDNN
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", //
#endif
}};
......
......@@ -300,10 +300,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
// TODO: add support for dilation
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
......@@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_eltwise);
fuse_relu, fuse_residual_conn);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_eltwise);
mkldnn_engine, fuse_relu, fuse_residual_conn);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -386,8 +386,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
T* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(
residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
"Output and elementwise parameter need to have the "
"same dimension sizes");
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
}
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
......@@ -424,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private:
mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
bool fuse_eltwise) const {
bool fuse_residual_conn) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_eltwise is true, the
// Output tensor contains the data coming from residual connection. The
// result of this post_op is: Output = scale * Output + Conv_Out.
if (fuse_eltwise) {
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
if (fuse_residual_conn) {
post_operations.append_sum(1.0f);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
......@@ -452,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const {
const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -461,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......@@ -476,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_eltwise) const {
const bool fuse_residual_conn) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
......@@ -485,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......
......@@ -132,6 +132,11 @@ void Conv2DOpMaker::Make() {
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.")
.Reuse("Input");
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
......@@ -164,10 +169,10 @@ void Conv2DOpMaker::Make() {
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_eltwise",
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is connected via skip connection "
"to a previous layer.")
"whenever convolution output is as an input to residual "
"connection.")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
......
......@@ -74,7 +74,7 @@ class InferenceTranspiler(object):
'''
Transpile the program fusing elementwise_add into conv for MKLDNN
program. Elementwise add following convolution OP can be fused by adding
'fuse_eltwise' attribute to convolution OP and replacing its output
'fuse_residual_connection' attribute to convolution OP and replacing its output
Tensor with second parameter of elementwise_add.
The result of fuse is:
- before:
......@@ -92,7 +92,8 @@ class InferenceTranspiler(object):
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'elementwise_add':
self._fuse_conv_eltwise(current_op, next_op)
self._fuse_conv_eltwise(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
self._adjust_input()
......@@ -444,7 +445,7 @@ class InferenceTranspiler(object):
outputs={"Output": out_var},
attrs=attrs)
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
......@@ -454,9 +455,30 @@ class InferenceTranspiler(object):
:type eltwise_op: Operator
'''
conv_op._set_attr("fuse_eltwise", True)
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
eltwise_input = "X"
if eltwise_op.input("X")[0] == conv_op.output("Output")[0]:
eltwise_input = "Y"
residual_var = self.block.vars[eltwise_op.input(eltwise_input)[0]]
out_var = self.block.vars[eltwise_op.output("Out")[0]]
filter_var = self.block.vars[conv_op.input("Filter")[0]]
in_var = self.block.vars[conv_op.input("Input")[0]]
bias_var = self.block.vars[conv_op.input("Bias")[0]]
conv_op._set_attr("fuse_residual_connection", True)
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
self.block._insert_op(
index,
type="conv2d",
inputs={
"Input": in_var,
"Filter": filter_var,
"Bias": bias_var,
"ResidualData": residual_var
},
outputs={"Output": out_var},
attrs=attrs)
def _adjust_input(self):
for i in range(len(self.block.ops)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册