Commit 4be45af1 authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: skip connection attribute renamed. Comments about patterns added.

test=develop
Parent 9a335e02
@@ -111,7 +111,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     op_desc.SetOutput("Output", {conv_output->Name()});
     op_desc.SetAttr("use_mkldnn", true);
-    op_desc.SetAttr("fuse_eltwise", true);
+    op_desc.SetAttr("fuse_residual_connection", true);
     auto fused_conv_op = g->CreateOpNode(&op_desc);
...
@@ -600,6 +600,15 @@ struct ConvBias : public PatternBase {
   PATTERN_DECL_NODE(eltwise_out);
 };
+// Convolution op
+// Forward pass for convolution.
+// conv_input, conv_bias and conv_filter are inputs.
+// conv_output is a result of the operator.
+// residual_data is data used by skip connection.
+// If residual connection fusion is on, the formula is:
+// conv_output = conv_op(conv_filter, conv_input, conv_bias)
+//             + conv_residual_data
+// If the fusion is off, conv_residual_data is not added.
 struct Conv : public PatternBase {
   Conv(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "convolution") {}
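A toy, runnable illustration of the formula in the new comment (illustrative C++, not PaddlePaddle code): with the fusion on, the skip-connection data is summed into the convolution result, so the graph no longer needs a separate elementwise_add operator.

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Stand-ins for conv_op(conv_filter, conv_input, conv_bias) and residual_data.
  std::vector<float> conv_output = {1.0f, 2.0f, 3.0f};
  std::vector<float> conv_residual_data = {0.5f, 0.5f, 0.5f};
  bool fuse_residual_connection = true;  // flip to false to skip the addition

  for (size_t i = 0; i < conv_output.size(); ++i) {
    if (fuse_residual_connection) conv_output[i] += conv_residual_data[i];
    std::printf("%.1f\n", conv_output[i]);  // prints 1.5 2.5 3.5 when fused
  }
  return 0;
}
```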
@@ -614,6 +623,10 @@ struct Conv : public PatternBase {
   PATTERN_DECL_NODE(conv_output);
 };
+// ElementwiseAdd used in residual connections.
+// y_var is used as convolution output.
+// The operator is removed when residual
+// connection fusion is on.
 struct ElementwiseAdd : public PatternBase {
   ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "elementwise_add") {}
...
@@ -300,7 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
-    bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
+    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
     int groups = ctx.Attr<int>("groups");
     // TODO(tpatejko): add support for dilation
@@ -369,11 +369,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
       conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
                                      strides, paddings, mkldnn_engine,
-                                     fuse_relu, fuse_eltwise);
+                                     fuse_relu, fuse_residual_conn);
     } else {
       conv_pd =
           ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                               mkldnn_engine, fuse_relu, fuse_eltwise);
+                               mkldnn_engine, fuse_relu, fuse_residual_conn);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -388,7 +388,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     T* output_data = nullptr;
-    if (fuse_eltwise) {
+    if (fuse_residual_conn) {
       auto residual_param = ctx.Input<Tensor>("ResidualData");
       auto residual_param_data = residual_param->data<T>();
@@ -442,14 +442,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  private:
   mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
-                                       bool fuse_eltwise) const {
+                                       bool fuse_residual_conn) const {
     mkldnn::primitive_attr conv_attr;
     mkldnn::post_ops post_operations;
     // Fusion with Elementwise layer relies on adding a sum post-operation with
-    // the scale parameter. It is assumed that when fuse_eltwise is true, the
-    // Output tensor contains the data coming from residual connection. The
-    // result of this post_op is: Output = scale * Output + Conv_Out.
-    if (fuse_eltwise) {
+    // the scale parameter. It is assumed that when fuse_residual_connection is
+    // true, the output tensor contains the data coming from residual
+    // connection. The result of this post_op is:
+    // Output = scale * Output + Conv_Out.
+    if (fuse_residual_conn) {
       post_operations.append_sum(1.0f);
     }
     // Fusion with ReLU layer is executed through the PostOps feature. Create a
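For readers unfamiliar with MKL-DNN post-ops, here is a condensed, self-contained sketch of the setup above, assuming the MKL-DNN 0.x C++ API this kernel targets. The sum branch comes from the hunk; the eltwise_relu branch is an assumption inferred from the trailing comment about ReLU fusion and is shown only for completeness.

```cpp
#include <mkldnn.hpp>

// Sketch (assumed MKL-DNN 0.x API): compose the residual-sum and ReLU fusions
// as post-operations attached to the convolution primitive's attributes.
mkldnn::primitive_attr CreatePostOpsSketch(bool fuse_relu,
                                           bool fuse_residual_conn) {
  mkldnn::primitive_attr conv_attr;
  mkldnn::post_ops post_operations;
  if (fuse_residual_conn) {
    // Accumulate into the destination: Output = 1.0f * Output + Conv_Out.
    post_operations.append_sum(1.0f);
  }
  if (fuse_relu) {
    // scale = 1, alpha (negative slope) = 0, beta unused for ReLU.
    post_operations.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu,
                                   0.0f, 0.0f);
  }
  conv_attr.set_post_ops(post_operations);
  return conv_attr;
}
```

The scale passed to append_sum multiplies the existing destination contents before the convolution result is added, which is why 1.0f gives the plain residual sum.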
@@ -470,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const memory::desc& dst, const std::vector<int>& strides,
       const std::vector<int>& paddings,
       const mkldnn::engine& engine, const bool fuse_relu,
-      const bool fuse_eltwise) const {
+      const bool fuse_residual_conn) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
@@ -479,7 +480,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
-    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
+    mkldnn::primitive_attr conv_attr =
+        CreatePostOps(fuse_relu, fuse_residual_conn);
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -494,7 +496,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const std::vector<int>& strides,
       const std::vector<int>& paddings,
       const mkldnn::engine& engine, const bool fuse_relu,
-      const bool fuse_eltwise) const {
+      const bool fuse_residual_conn) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
@@ -503,7 +505,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         bias, dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
-    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
+    mkldnn::primitive_attr conv_attr =
+        CreatePostOps(fuse_relu, fuse_residual_conn);
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
...
@@ -135,7 +135,7 @@ void Conv2DOpMaker::Make() {
   AddInput("ResidualData",
            "(Tensor) Tensor with residual data "
            "to which convolution output will be added."
-           "Used on with fuse_eltwise fusion.")
+           "Used with fuse_residual_connection fusion.")
       .AsDispensable();
   AddAttr<std::vector<int>>("strides",
                             "(vector<int> default:{1, 1}), the "
@@ -169,10 +169,10 @@ void Conv2DOpMaker::Make() {
       .SetDefault(false);
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
-  AddAttr<bool>("fuse_eltwise",
-                "(bool, default false) Only used in mkldnn kernel. Used "
-                "whenever convolution output is connected via skip connection "
-                "to a previous layer.")
+  AddAttr<bool>("fuse_residual_connection",
+                "(bool, default false) Only used in mkldnn kernel. Used "
+                "whenever convolution output is used as an input to residual "
+                "connection.")
       .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
...
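The ResidualData input and the fuse_residual_connection attribute work together: because the sum post-op accumulates the convolution result into whatever the destination buffer already holds, the residual data has to land in the output tensor before the convolution runs. A toy model of that ordering (illustrative C++, not PaddlePaddle code; all names here are hypothetical):

```cpp
#include <algorithm>
#include <vector>

// Toy model of the fused execution order: first place ResidualData in the
// output buffer, then let the "convolution" accumulate on top of it.
void run_fused_conv(std::vector<float>& output,
                    const std::vector<float>& residual_data,
                    const std::vector<float>& conv_result,  // stand-in for the conv primitive
                    bool fuse_residual_connection) {
  if (fuse_residual_connection) {
    std::copy(residual_data.begin(), residual_data.end(), output.begin());
  }
  for (size_t i = 0; i < output.size(); ++i) {
    // With the fusion on, dst = conv + previous dst contents (the residual).
    output[i] = conv_result[i] + (fuse_residual_connection ? output[i] : 0.0f);
  }
}
```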
@@ -74,7 +74,7 @@ class InferenceTranspiler(object):
         '''
         Transpile the program fusing elementwise_add into conv for MKLDNN
         program. Elementwise add following convolution OP can be fused by adding
-        'fuse_eltwise' attribute to convolution OP and replacing its output
+        'fuse_residual_connection' attribute to convolution OP and replacing its output
         Tensor with second parameter of elementwise_add.
         The result of fuse is:
         - before:
@@ -465,7 +465,7 @@ class InferenceTranspiler(object):
         in_var = self.block.vars[conv_op.input("Input")[0]]
         bias_var = self.block.vars[conv_op.input("Bias")[0]]
-        conv_op.set_attr("fuse_eltwise", True)
+        conv_op.set_attr("fuse_residual_connection", True)
         attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
         self.block._insert_op(
...