提交 2281ebf0 编写于 作者: G guomingz 提交者: Tao Luo

Enable the convolution/relu6(bounded_relu) fusion for FP32 on Intel platform. (#17130)

* Relu6 is the bottleneck op for Mobilenet-v2. Since MKL-DNN supports conv/relu6 fusion, we implement this fusion as a graph pass. Because int8 support for this fusion will only be available in MKLDNN v0.20, this PR focuses on the fp32 optimization.

The table below shows the benchmark results (FPS) measured on skx-8180 (28 cores):
Batch size | with fusion | without fusion
-- | -- | --
1 | 214.7 | 53.4
50 | 1219.727 | 137.280

test=develop

* Fix the format issue

test=develop

* Add the missing nolint comments.

test=develop

* Fix the typos.

test=develop

* Register the conv_brelu_mkldnn_fuse_pass for the MKLDNN engine.

test=develop

* Adjust the indentation.

test=develop

* Add the test_conv_brelu_mkldnn_fuse_pass case.

test=develop

* Slightly update the code per Baidu comments.
Embed the parameter definitions directly in the code,
which makes the code easier to understand.

test=develop
上级 3398f996
...@@ -85,6 +85,7 @@ if(WITH_MKLDNN) ...@@ -85,6 +85,7 @@ if(WITH_MKLDNN)
pass_library(depthwise_conv_mkldnn_pass base mkldnn) pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn) pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn) pass_library(cpu_quantize_pass inference mkldnn)
...@@ -114,6 +115,7 @@ if (WITH_MKLDNN) ...@@ -114,6 +115,7 @@ if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass) cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
......
...@@ -785,6 +785,33 @@ PDNode *patterns::ConvReLU::operator()( ...@@ -785,6 +785,33 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var; return relu_out_var;
} }
// Describes the sub-graph
//   (conv_input, conv_weight) -> conv2d -> conv_out -> relu6 -> brelu_out
// on top of the caller-supplied `conv_input` node and returns the node that
// stands for the relu6 output.
PDNode *patterns::ConvBReLU::operator()(
    paddle::framework::ir::PDNode *conv_input) {
  // The caller's node must feed the "Input" slot of a conv2d op.
  conv_input->assert_is_op_input("conv2d", "Input");

  // --- operator nodes ---
  PDNode *conv = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
  PDNode *brelu = pattern->NewNode(brelu_repr())->assert_is_op("relu6");

  // --- variable nodes ---
  // Convolution filter: a persistable variable wired to the "Filter" slot.
  PDNode *filter = pattern->NewNode(conv_weight_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("conv2d", "Filter");
  // Value flowing from conv2d into relu6. It is marked intermediate because
  // the fuse pass deletes it once the two ops are merged.
  PDNode *conv_result = pattern->NewNode(conv_out_repr())
                            ->AsIntermediate()
                            ->assert_is_only_output_of_op("conv2d")
                            ->assert_is_op_input("relu6");
  // Final result produced by relu6.
  PDNode *brelu_result = pattern->NewNode(brelu_out_repr())
                             ->AsOutput()
                             ->assert_is_op_output("relu6");

  // --- topology ---
  conv->LinksFrom({conv_input, filter}).LinksTo({conv_result});
  brelu->LinksFrom({conv_result}).LinksTo({brelu_result});
  return brelu_result;
}
PDNode *patterns::SeqConvEltAddRelu::operator()( PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) { paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators // Create Operators
......
...@@ -449,6 +449,27 @@ struct ConvReLU : public PatternBase { ...@@ -449,6 +449,27 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out); PATTERN_DECL_NODE(relu_out);
}; };
// CONV with ReLU6 (bounded ReLU)
// op: conv + relu6
// The conv_input node is supplied by the caller of operator().
// named nodes declared below:
// conv, brelu (operators),
// conv_weight, conv_out, brelu_out (variables)
struct ConvBReLU : public PatternBase {
// Forwards the pattern and name scope to PatternBase together with the
// fixed pattern name "conv_bounded_relu".
ConvBReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bounded_relu") {}
// Builds the conv2d -> relu6 pattern starting at conv_input; the definition
// returns the node that represents the relu6 output.
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(brelu);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(brelu_out);
};
// SEQCONV with Elementwise_Add ReLU // SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu // op: seqconv + elementwise_add + relu
// named nodes: // named nodes:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
// Rewrites every matched `conv2d -> relu6` pair in `graph` into a single
// conv2d op.
//
// For each match the conv2d op desc is redirected to write the relu6 output
// variable, gets `fuse_brelu = true` and `fuse_brelu_threshold` copied from the
// relu6 op's "threshold" attribute, and the relu6 op node plus the intermediate
// conv output variable are removed from the graph. The number of fused pairs is
// reported through AddStatis().
void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE(graph);
  FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph);

  GraphPatternDetector gpd;
  // Entry node of the pattern: any variable feeding the "Input" slot of conv2d.
  auto* conv_input = gpd.mutable_pattern()
                         ->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input")
                         ->AsInput()
                         ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
                                         "conv_bounded_relu_mkldnn_fuse");
  conv_brelu_pattern(conv_input);

  int found_conv_brelu_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBoundedReLUFusePass fuse";

    // Sanity-check the match before anything is modified, so that a failed
    // check cannot leave the graph in a half-rewritten state.
    PADDLE_ENFORCE(subgraph.count(conv_input));

    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                              conv_brelu_pattern);                      // Filter
    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern);  // tmp
    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern);  // CONV op
    GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern);  // Out
    GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern);  // ReLU6 op

    // Transform Conv node into ConvBReLU node: it now produces the relu6
    // output directly and carries the clipping threshold as an attribute.
    OpDesc* desc = conv->Op();
    desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()}));
    desc->SetAttr("fuse_brelu", true);
    // The threshold has to be read before the relu6 node is removed below.
    desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));

    // Drop the now-redundant relu6 op and the intermediate variable, then
    // connect conv2d straight to the former relu6 output.
    GraphSafeRemoveNodes(graph, {brelu, conv_out});
    IR_NODE_LINK_TO(conv, brelu_out);
    found_conv_brelu_count++;
  };

  gpd(graph, handler);
  AddStatis(found_conv_brelu_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers ConvBReLUFusePass under the name "conv_brelu_mkldnn_fuse_pass",
// the name by which the pass is looked up in the pass registry.
REGISTER_PASS(conv_brelu_mkldnn_fuse_pass,
paddle::framework::ir::ConvBReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
// Graph pass that merges a relu6 (bounded ReLU) op into the conv2d op that
// feeds it. The conv2d op is kept and annotated with the fusion attributes;
// the relu6 op is removed.
class ConvBReLUFusePass : public FusePassBase {
 public:
  virtual ~ConvBReLUFusePass() = default;

 protected:
  // Performs the fusion on `graph`.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
// Appends one op of the given `type` to block 0 of `prog`.
//
//  - "conv2d": `inputs` must hold exactly {Input, Filter, Bias} in that order;
//    `name` is stored in the "name" attribute so the test can tell the
//    convolutions apart, and the result goes to the "Output" slot (the slot
//    the fuse pass rewrites).
//  - "relu6": all `inputs` go to slot "X"; the "threshold" attribute (6.0) is
//    only set when `use_mkldnn` is true.
//  - any other type: only the outputs are recorded; `inputs`, `name` and
//    `use_mkldnn` are ignored.
//
// Every op is tagged with the forward op role.
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);

  if (type == "conv2d") {
    op->SetAttr("use_mkldnn", use_mkldnn);
    op->SetAttr("name", name);
    op->SetInput("Input", {inputs[0]});
    op->SetInput("Filter", {inputs[1]});
    op->SetInput("Bias", {inputs[2]});
    // conv2d publishes its result in "Output". Using "Out" here would leave a
    // stale entry naming the removed intermediate variable after the pass
    // redirects "Output" to the relu6 result.
    op->SetOutput("Output", outputs);
  } else {
    if (type == "relu6") {
      op->SetAttr("use_mkldnn", use_mkldnn);
      if (use_mkldnn) {
        op->SetAttr("threshold", 6.0f);
      }
      op->SetInput("X", inputs);
    }
    op->SetOutput("Out", outputs);
  }

  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
              static_cast<int>(OpRole::kForward));
}
// Builds the program exercised by the test. Ops, in append order:
//   a -> OP0 -> b
//   b -> OP1 -> c
//   (c, weights, bias)   -> conv2d "conv1" -> f
//   f -> relu6 "relu1" -> g
//   g -> OP3 -> h
//   (h, weights2, bias2) -> conv2d "conv2" -> k
//   k -> relu6 "relu2" -> l
// Only "weights" and "bias" are marked persistable; "weights2" and "bias2"
// are not.
ProgramDesc BuildProgramDesc() {
  using Names = std::vector<std::string>;

  ProgramDesc prog;

  const Names var_names = {"a", "b", "c",        "weights", "bias", "f",
                           "g", "h", "weights2", "bias2",   "k",    "l"};
  for (const auto& var_name : var_names) {
    auto* var_desc = prog.MutableBlock(0)->Var(var_name);
    var_desc->SetType(proto::VarType::SELECTED_ROWS);
    const bool is_first_conv_param =
        (var_name == "weights" || var_name == "bias");
    if (is_first_conv_param) {
      var_desc->SetPersistable(true);
    }
  }

  SetOp(&prog, "OP0", "op0", Names{"a"}, Names{"b"});
  SetOp(&prog, "OP1", "op1", Names{"b"}, Names{"c"});

  // First conv + relu6 pair: both ops flagged use_mkldnn.
  SetOp(&prog, "conv2d", "conv1", Names{"c", "weights", "bias"}, Names{"f"},
        true);
  SetOp(&prog, "relu6", "relu1", Names{"f"}, Names{"g"}, true);

  SetOp(&prog, "OP3", "op3", Names{"g"}, Names{"h"});

  // Second conv + relu6 pair: only the convolution is flagged use_mkldnn.
  SetOp(&prog, "conv2d", "conv2", Names{"h", "weights2", "bias2"}, Names{"k"},
        true);
  SetOp(&prog, "relu6", "relu2", Names{"k"}, Names{"l"});

  return prog;
}
// Runs the fuse pass on the sample program and checks that exactly one
// convolution ("conv1") ends up fused while "conv2" is left untouched.
TEST(ConvBReLUFusePass, basic) {
  auto prog = BuildProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass");

  int original_nodes_num = graph->Nodes().size();
  graph.reset(pass->Apply(graph.release()));
  int current_nodes_num = graph->Nodes().size();

  // One fusion is expected: the relu6 op node and the intermediate conv
  // output variable disappear, the conv2d node itself is reused.
  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);

  // Inspect every conv2d op left in the graph.
  int conv_brelu_count = 0;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != "conv2d") {
      continue;
    }

    auto* op = node->Op();
    ASSERT_TRUE(op->HasAttr("use_mkldnn"));
    EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));

    // check if only "conv1" convolution is fused
    auto op_name = boost::get<std::string>(op->GetAttr("name"));
    if (op_name == "conv2") {
      ASSERT_FALSE(op->HasAttr("fuse_brelu"));
      continue;
    }
    if (op_name != "conv1") {
      continue;
    }

    ASSERT_TRUE(op->HasAttr("fuse_brelu"));
    ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold"));
    bool fuse_brelu = boost::get<bool>(op->GetAttr("fuse_brelu"));
    if (!fuse_brelu) {
      continue;
    }

    ++conv_brelu_count;
    float fuse_brelu_threshold =
        boost::get<float>(op->GetAttr("fuse_brelu_threshold"));
    EXPECT_EQ(fuse_brelu_threshold, 6.0f);
  }
  EXPECT_EQ(conv_brelu_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// NOTE(review): presumably pulls the registration of
// "conv_brelu_mkldnn_fuse_pass" into this test binary so the registry lookup
// in the test can find it -- confirm against the USE_PASS definition.
USE_PASS(conv_brelu_mkldnn_fuse_pass);
...@@ -153,7 +153,8 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -153,7 +153,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_bias_mkldnn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", // "conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", "conv_elementwise_add_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass"})) { "conv_relu_mkldnn_fuse_pass", //
"conv_brelu_mkldnn_fuse_pass"})) {
passes_.push_back(pass); passes_.push_back(pass);
} }
} }
......
...@@ -209,6 +209,12 @@ void Conv2DOpMaker::Make() { ...@@ -209,6 +209,12 @@ void Conv2DOpMaker::Make() {
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel") AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("fuse_brelu",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<float>("fuse_brelu_threshold",
"(float, default false 6.0) Only used in mkldnn kernel")
.SetDefault(6.0f);
AddAttr<bool>("fuse_residual_connection", AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used " "(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual " "whenever convolution output is as an input to residual "
......
...@@ -119,9 +119,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -119,9 +119,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = false;
float fuse_brelu_threshold = 6.0;
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U; bool is_conv3d = strides.size() == 3U;
if (!is_conv3d) {
fuse_brelu = ctx.Attr<bool>("fuse_brelu");
fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
}
// TODO(tpatejko): add support for dilation // TODO(tpatejko): add support for dilation
PADDLE_ENFORCE( PADDLE_ENFORCE(
is_conv3d is_conv3d
...@@ -142,8 +147,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -142,8 +147,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvMKLDNNHandler::GetHash( const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups, src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
ctx.op().Input("Input") + ctx.op().Input("Filter")); groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
...@@ -194,11 +199,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -194,11 +199,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn, fwd_prop_kind); fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold,
fwd_prop_kind);
} else { } else {
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings, src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind); mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, fwd_prop_kind);
} }
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
...@@ -317,13 +324,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -317,13 +324,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups"); int groups = ctx.Attr<int>("groups");
bool fuse_relu = ctx.Attr<bool>("fuse_relu"); bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
if (fuse_residual_conn) { if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true, PADDLE_ENFORCE(force_fp32_output != true,
"residual fusion does not support force output with fp32"); "residual fusion does not support force output with fp32");
} }
bool is_conv3d = strides.size() == 3U; bool is_conv3d = strides.size() == 3U;
// TODO(tpatejko): add support for dilation // TODO(tpatejko): add support for dilation
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -334,6 +340,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -334,6 +340,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"dilation in convolution is not implemented yet"); "dilation in convolution is not implemented yet");
PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently"); PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
PADDLE_ENFORCE(fuse_brelu != true,
"int8 does not support conv/relu6 fusion currently");
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
...@@ -341,15 +349,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -341,15 +349,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> weights_tz = std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims()); paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d); GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType) auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType(
: paddle::framework::ToMKLDNNDataType( framework::DataTypeTrait<uint8_t>::DataType)
framework::DataTypeTrait<int8_t>::DataType); : paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
if (force_fp32_output) { if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType( dst_dt = paddle::framework::ToMKLDNNDataType(
...@@ -367,8 +377,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -367,8 +377,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key.reserve(MaxKeyLength); key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey( platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_relu, fuse_residual_conn, input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
ctx.op().Input("Input") + ctx.op().Input("Filter")); ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false; bool need_s8_to_u8 = false;
...@@ -449,22 +460,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -449,22 +460,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims()); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
memory::format::x); memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, conv_pd = ConvFwdPrimitiveDesc(
fuse_relu, fuse_residual_conn, src_md, weights_md, bias_md, dst_md, strides, paddings,
output_shift_scale, sum_scale, is_test); mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale,
is_test);
} else { } else {
conv_pd = conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, paddings, mkldnn_engine, fuse_relu,
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_residual_conn, false /*fuse_brelu*/,
output_shift_scale, sum_scale, is_test); 0.0 /*fuse_brelu_threshold*/,
output_shift_scale, sum_scale, is_test);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key)); mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
user_src_memory_p = user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
...@@ -632,11 +645,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -632,11 +645,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private: private:
mkldnn::primitive_attr CreatePostOps( mkldnn::primitive_attr CreatePostOps(
bool fuse_relu, bool fuse_residual_conn, bool fuse_relu, bool fuse_residual_conn,
const std::vector<float> output_shift_scale, float sum_scale) const { const std::vector<float> output_shift_scale, float sum_scale,
bool fuse_brelu, float fuse_brelu_threshold) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale); conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_residual_conn) { if (fuse_residual_conn) {
post_operations.append_sum(sum_scale); post_operations.append_sum(sum_scale);
} }
...@@ -647,6 +662,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -647,6 +662,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder); negative_slope, placeholder);
} }
if (fuse_brelu) {
constexpr float scale = 1.0f;
constexpr float placeholder = 0.0f; // beta
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_brelu_threshold, placeholder);
}
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
} }
...@@ -656,7 +678,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -656,7 +678,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides, const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float> output_shift_scale, const std::vector<float> output_shift_scale,
const float sum_scale, bool is_test) const { const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
...@@ -668,9 +691,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -668,9 +691,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto conv_desc = mkldnn::convolution_forward::desc( auto conv_desc = mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
padding_dims, padding_dims, mkldnn::padding_kind::zero); padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
mkldnn::primitive_attr conv_attr = CreatePostOps( CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale); sum_scale, fuse_brelu, fuse_brelu_threshold);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -685,7 +708,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -685,7 +708,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu, const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float> output_shift_scale, const std::vector<float> output_shift_scale,
const float sum_scale, bool is_test) const { const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]}; memory::dims stride_dims = {strides[0], strides[1]};
...@@ -698,8 +722,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -698,8 +722,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, bias, dst, propagation, mkldnn::convolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps( mkldnn::primitive_attr conv_attr =
fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale); CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
sum_scale, fuse_brelu, fuse_brelu_threshold);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine); conv_desc, conv_attr, engine);
...@@ -762,7 +787,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -762,7 +787,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
GetWeightsTz(weights_tz, g, is_conv3d); GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = std::vector<int> dst_tz =
paddle::framework::vectorize2int(output_grad->dims()); paddle::framework::vectorize2int(output_grad->dims());
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_brelu = false;
if (!is_conv3d) {
fuse_brelu = ctx.Attr<bool>("fuse_brelu");
}
auto src_format = input->format(); auto src_format = input->format();
mkldnn::memory::format weights_format = mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d); GetWeightsFormat(filter->format(), g, is_conv3d);
...@@ -771,8 +800,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -771,8 +800,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// as well as attributes of primitive to be created // as well as attributes of primitive to be created
// This name will be used as key when saving info into device context // This name will be used as key when saving info into device context
const std::string key = platform::ConvMKLDNNHandler::GetHash( const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups, src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
ctx.op().Input("Input") + ctx.op().Input("Filter")); groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
......
...@@ -166,11 +166,11 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -166,11 +166,11 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x);
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, false, fwd_prop_kind); fuse_relu, false, false, 0.0, fwd_prop_kind);
} else { } else {
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings, src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, false, fwd_prop_kind); mkldnn_engine, fuse_relu, false, false, 0.0, fwd_prop_kind);
} }
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
......
...@@ -212,25 +212,29 @@ class MKLDNNHandler { ...@@ -212,25 +212,29 @@ class MKLDNNHandler {
dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data))); dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
} }
static void AppendKey(std::string* key, static void AppendKey(
const mkldnn::memory::dims& input_dims, std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& strides, const std::vector<int>& paddings, const std::vector<int>& dilations,
const std::vector<int>& paddings, const int& groups, const mkldnn::memory::data_type& srcdt,
const std::vector<int>& dilations, const int& groups, const mkldnn::memory::format& format, const bool& relu,
const mkldnn::memory::data_type& srcdt, const bool& residual, const bool& brelu, const std::string& suffix) {
const mkldnn::memory::format& format, const bool& relu,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims); AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims); AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides); AppendKeyVec(key, strides);
AppendKeyVec(key, paddings); AppendKeyVec(key, paddings);
AppendKeyVec(key, dilations); AppendKeyVec(key, dilations);
AppendKey(key, std::to_string(groups)); AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt)); AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format)); AppendKey(key, std::to_string(format));
AppendKey(key, std::to_string(relu)); AppendKey(key, std::to_string(relu));
AppendKey(key, std::to_string(residual)); AppendKey(key, std::to_string(residual));
AppendKey(key, std::to_string(brelu));
AppendKey(key, suffix); AppendKey(key, suffix);
} }
...@@ -562,8 +566,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -562,8 +566,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
scale_data, mask); scale_data, mask);
} }
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
bool fuse_residual_conn = false) const { bool fuse_brelu,
float fuse_brelu_threshold) const {
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations; mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with // Fusion with Elementwise layer relies on adding a sum post-operation with
...@@ -583,6 +588,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -583,6 +588,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder); negative_slope, placeholder);
} }
if (fuse_brelu) {
constexpr float scale = 1.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_brelu_threshold, placeholder);
}
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
} }
...@@ -594,6 +607,7 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -594,6 +607,7 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
const mkldnn::memory::desc& dst, const std::vector<int>& strides, const mkldnn::memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const mkldnn::engine& engine, const std::vector<int>& paddings, const mkldnn::engine& engine,
const bool fuse_relu, const bool fuse_residual_conn, const bool fuse_relu, const bool fuse_residual_conn,
const bool fuse_brelu, const float fuse_brelu_threshold,
mkldnn::prop_kind fwd_prop_kind) { mkldnn::prop_kind fwd_prop_kind) {
const std::string key_conv_pd = key_ + "@conv_pd"; const std::string key_conv_pd = key_ + "@conv_pd";
...@@ -614,8 +628,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -614,8 +628,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
weights, dst, stride_dims, padding_dims, padding_dims, weights, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = mkldnn::primitive_attr conv_attr = CreatePostOps(
CreatePostOps(fuse_relu, fuse_residual_conn); fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
conv_pd_.reset( conv_pd_.reset(
new typename forward_t::primitive_desc(conv_desc, conv_attr, engine)); new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
...@@ -714,6 +728,22 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -714,6 +728,22 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
return conv_bwd_data_p; return conv_bwd_data_p;
} }
// Builds the string key used for storing/retrieving primitives for this
// operator. The key is the concatenation, in this order, of: input dims,
// weights dims, the fuse_relu flag, the fuse_brelu flag, strides, paddings,
// dilations, groups and the caller-supplied suffix.
// TODO(jczaja): Make hashing function more optimal
static std::string GetHash(mkldnn::memory::dims& input_dims,    // NOLINT
                           mkldnn::memory::dims& weights_dims,  // NOLINT
                           const bool& fuse_relu,               // NOLINT
                           const bool& fuse_brelu,              // NOLINT
                           std::vector<int>& strides,           // NOLINT
                           std::vector<int>& paddings,          // NOLINT
                           std::vector<int>& dilations,         // NOLINT
                           int groups, const std::string& suffix) {
  // Shapes of the tensors taking part in the convolution.
  std::string hash_key = dims2str(input_dims);
  hash_key += dims2str(weights_dims);
  // Fusion flags, so fused and non-fused variants get distinct keys.
  hash_key += std::to_string(fuse_relu);
  hash_key += std::to_string(fuse_brelu);
  // Convolution geometry.
  hash_key += dims2str(strides);
  hash_key += dims2str(paddings);
  hash_key += dims2str(dilations);
  hash_key += std::to_string(groups);
  // Caller-provided discriminator appended last.
  hash_key += suffix;
  return hash_key;
}
// Generate keys for storing/retrieving primitives for this operator // Generate keys for storing/retrieving primitives for this operator
// TODO(jczaja): Make hashing function more optimal // TODO(jczaja): Make hashing function more optimal
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
......
...@@ -57,6 +57,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp): ...@@ -57,6 +57,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
self.fuse_bias = False self.fuse_bias = False
self.bias_size = None self.bias_size = None
self.fuse_relu = False self.fuse_relu = False
self.fuse_brelu = False
self.fuse_brelu_threshold = 6.0
self.fuse_residual_connection = False self.fuse_residual_connection = False
self.input_residual_size = None self.input_residual_size = None
TestConv2dOp.setUp(self) TestConv2dOp.setUp(self)
...@@ -84,15 +86,38 @@ class TestConv2dMKLDNNOp(TestConv2dOp): ...@@ -84,15 +86,38 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
if self.fuse_relu: if self.fuse_relu:
output = np.maximum(output, 0).astype(self.dsttype) output = np.maximum(output, 0).astype(self.dsttype)
if self.fuse_brelu:
output = np.minimum(
np.maximum(output, 0),
self.fuse_brelu_threshold).astype(self.dsttype)
output = output.astype(self.dtype) output = output.astype(self.dtype)
self.attrs['fuse_bias'] = self.fuse_bias self.attrs['fuse_bias'] = self.fuse_bias
self.attrs['fuse_relu'] = self.fuse_relu self.attrs['fuse_relu'] = self.fuse_relu
self.attrs['fuse_brelu'] = self.fuse_brelu
self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold
self.attrs['fuse_residual_connection'] = self.fuse_residual_connection self.attrs['fuse_residual_connection'] = self.fuse_residual_connection
self.outputs['Output'] = output self.outputs['Output'] = output
class TestWithbreluFusion(TestConv2dMKLDNNOp):
    # conv2d MKLDNN test case that turns on the bounded-ReLU (brelu) fusion
    # with a threshold of 6.0 and a float32 destination type.

    def init_test_case(self):
        # Start from the base MKLDNN conv2d configuration, then override the
        # fields specific to the brelu fusion.
        TestConv2dMKLDNNOp.init_test_case(self)
        self.dsttype = np.float32
        self.fuse_brelu_threshold = 6.0
        self.fuse_brelu = True

    # The three gradient checks below are overridden with no-ops, so only the
    # remaining inherited checks run for this case.
    # NOTE(review): presumably because the fusion targets inference only --
    # confirm.
    def test_check_grad(self):
        pass

    def test_check_grad_no_filter(self):
        pass

    def test_check_grad_no_input(self):
        pass
class TestWithFuse(TestConv2dMKLDNNOp): class TestWithFuse(TestConv2dMKLDNNOp):
def init_test_case(self): def init_test_case(self):
TestConv2dMKLDNNOp.init_test_case(self) TestConv2dMKLDNNOp.init_test_case(self)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册