提交 b837689e 编写于 作者: A Adam 提交者: Tao Luo

Add generalized Conv+Activation MKLDNN fuse pass creation (#19072)

test=develop
上级 50b1cab1
......@@ -86,8 +86,7 @@ if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn)
pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
......@@ -127,8 +126,7 @@ endif()
if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
......
......@@ -771,6 +771,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
return bn_out_var;
}
PDNode *patterns::ConvActivation::operator()(
paddle::framework::ir::PDNode *conv_input, std::string conv_type,
std::string activation_type) {
// Create Operators
conv_input->assert_is_op_input(conv_type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input(conv_type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(conv_type)
->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
return activation_out_var;
}
PDNode *patterns::ConvReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
......
......@@ -431,6 +431,28 @@ struct ConvBN : public PatternBase {
PATTERN_DECL_NODE(bn_saved_variance);
};
// Conv with Activation
// op: conv + activation
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// activation_out, activation
struct ConvActivation : public PatternBase {
ConvActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_activation") {}
PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d",
std::string activation_type = "relu");
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(activation_out);
};
// CONV with ReLU
// op: conv + relu
// named nodes:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, "graph cannot be nullptr.");
FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_activation_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input(conv_type(), "Input");
patterns::ConvActivation conv_activation_pattern(
gpd.mutable_pattern(), "conv_activation_mkldnn_fuse");
conv_activation_pattern(conv_input, conv_type(), activation_type());
int found_conv_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_activation_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
conv_activation_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_activation_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
conv_activation_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(activation, activation,
conv_activation_pattern); // Activation op
// Transform Conv node into ConvActivation node.
OpDesc* desc = conv->Op();
desc->SetOutput("Output",
std::vector<std::string>({activation_out->Name()}));
desc->SetAttr("fuse_activation", activation_type());
// MKLDNN ops use alpha and beta as activation parameters but paddle ops are
// not generalized
if (activation_type() == "relu6") {
desc->SetAttr("fuse_alpha",
boost::get<float>(activation->Op()->GetAttr("threshold")));
} else {
desc->SetAttr("fuse_alpha",
activation->Op()->HasAttr("alpha")
? boost::get<float>(activation->Op()->GetAttr("alpha"))
: 0.0f);
}
desc->SetAttr("fuse_beta",
activation->Op()->HasAttr("beta")
? boost::get<float>(activation->Op()->GetAttr("beta"))
: 0.0f);
GraphSafeRemoveNodes(graph, {activation, conv_out});
PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
"subgraph has to contain conv_input node.");
IR_NODE_LINK_TO(conv, activation_out);
found_conv_activation_count++;
};
gpd(graph, handler);
AddStatis(found_conv_activation_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_activation_mkldnn_fuse_pass,
paddle::framework::ir::ConvActivationFusePass);
REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
paddle::framework::ir::ConvActivationFusePass);
REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DLeakyReLUFusePass);
REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DReLU6FusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Conv and Activation base class.
*/
class ConvActivationFusePass : public FusePassBase {
public:
virtual ~ConvActivationFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
virtual std::string activation_type() const { return "relu"; }
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"conv_activation_mkldnn_fuse"};
};
/*
* Fuse Conv and LeakyReLU class
*/
class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
public:
std::string activation_type() const { return "leaky_relu"; }
};
/*
* Fuse Conv and BoundedReLU class
*/
class Conv2DReLU6FusePass : public ConvActivationFusePass {
public:
std::string activation_type() const { return "relu6"; }
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool is_activation = false,
bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (is_activation) {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
if (type == "leaky_relu") {
op->SetAttr("alpha", 0.02f);
} else if (type == "relu6") {
op->SetAttr("threshold", 6.0f);
}
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->activation->g
ProgramDesc BuildProgramDesc(std::string activation) {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l", "m"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
// conv+activation, both with MKL-DNN
SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}), false, true);
SetOp(&prog, activation, "activation1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true, true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+activation, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), false, true);
SetOp(&prog, "activation", "activation2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}), true, false);
SetOp(&prog, "OP4", "op4", std::vector<std::string>({"l"}),
std::vector<std::string>({"m"}));
return prog;
}
void MainTest(std::string activation) {
auto prog = BuildProgramDesc(activation);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass =
PassRegistry::Instance().Get("conv_" + activation + "_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
// Remove 3 Nodes: CONV, activation, conv_out
// Add 1 Node: ConvActivation
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_activation op in newly generated graph
int conv_activation_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
auto op_name = boost::get<std::string>(op->GetAttr("name"));
std::string fuse_activation =
op->HasAttr("fuse_activation")
? boost::get<std::string>(op->GetAttr("fuse_activation"))
: "";
if (fuse_activation == activation) {
++conv_activation_count;
}
// check if only "conv1" convolution is fused
if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_activation"));
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_activation"));
}
}
}
EXPECT_EQ(conv_activation_count, 1);
}
TEST(ConvActivationFusePass, conv_relu_fuse_pass) { MainTest("relu"); }
TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) {
MainTest("leaky_relu");
}
TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); }
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_activation_mkldnn_fuse_pass);
......@@ -83,7 +83,7 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU(
// Transform Conv node into ConvReLU node.
OpDesc* conv_desc = conv_op->Op();
conv_desc->SetAttr("fuse_relu", true);
conv_desc->SetAttr("fuse_activation", std::string("relu"));
// Remove ReLU when all Convs were transformed.
auto number_of_unfused_convs_left =
......
......@@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("fuse_relu", false);
op->SetAttr("fuse_activation", std::string(""));
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) {
......@@ -109,8 +109,9 @@ void MainTest(const ProgramDesc& prog, bool fuse_relu) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "conv2d") {
ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu_attr = boost::get<bool>(op->GetAttr("fuse_relu"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
bool fuse_relu_attr =
(boost::get<std::string>(op->GetAttr("fuse_activation")) == "relu");
EXPECT_EQ(fuse_relu, fuse_relu_attr);
} else if (op->Type() == "relu") {
relu_count++;
......
......@@ -109,8 +109,11 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
auto fuse_relu = HasAttribute<bool>(*conv_op, "fuse_relu");
if (fuse_relu && *fuse_relu) return;
std::string fuse_activation =
conv_op->Op()->HasAttr("fuse_activation")
? boost::get<std::string>(conv_op->Op()->GetAttr("fuse_activation"))
: "";
if (fuse_activation == "relu" || fuse_activation == "relu6") return;
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
......@@ -179,8 +182,12 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
return;
}
auto fuse_relu = HasAttribute<bool>(*residual_conv_op, "fuse_relu");
if (fuse_relu && *fuse_relu) return;
std::string fuse_activation =
residual_conv_op->Op()->HasAttr("fuse_activation")
? boost::get<std::string>(
residual_conv_op->Op()->GetAttr("fuse_activation"))
: "";
if (fuse_activation == "relu" || fuse_activation == "relu6") return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
......
......@@ -68,10 +68,13 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
if (is_output) {
if (op->Type() == "conv2d") {
// output of conv2d with relu must be unsigned
is_unsigned = (op->HasAttr("fuse_relu") &&
boost::get<bool>(op->GetAttr("fuse_relu"))) ||
(op->HasAttr("fuse_brelu") &&
boost::get<bool>(op->GetAttr("fuse_brelu")));
std::string fuse_activation =
op->HasAttr("fuse_activation")
? boost::get<std::string>(
op->GetAttr("fuse_activation"))
: "";
is_unsigned =
(fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "relu") {
is_unsigned = true;
} else if (op->Type() == "transpose2" ||
......
......@@ -180,8 +180,9 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", //
"conv_brelu_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_leaky_relu_mkldnn_fuse_pass", //
"conv_relu6_mkldnn_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass"
})) {
......
......@@ -215,6 +215,14 @@ void Conv2DOpMaker::Make() {
AddAttr<float>("fuse_brelu_threshold",
"(float, default false 6.0) Only used in mkldnn kernel")
.SetDefault(6.0f);
AddAttr<std::string>("fuse_activation",
"(string, default \"\") Only used in mkldnn kernel")
.SetDefault("");
AddAttr<float>("fuse_alpha",
"(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f);
AddAttr<float>("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f);
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual "
......@@ -352,6 +360,14 @@ void Conv3DOpMaker::Make() {
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>("fuse_activation",
"(string, default \"\") Only used in mkldnn kernel")
.SetDefault("");
AddAttr<float>("fuse_alpha",
"(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f);
AddAttr<float>("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f);
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual "
......
......@@ -170,6 +170,14 @@ void Conv2DTransposeOpMaker::Make() {
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>("fuse_activation",
"(string, default \"\") Only used in mkldnn kernel")
.SetDefault("");
AddAttr<float>("fuse_alpha",
"(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f);
AddAttr<float>("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel")
.SetDefault(0.0f);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
......@@ -71,13 +71,14 @@ inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
static mkldnn::memory::data_type GetDstType(bool is_int8,
bool force_fp32_output,
bool fuse_relu, bool fuse_brelu,
std::string fuse_activation,
bool fuse_residual_conn,
const Tensor* residual_param) {
auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float
if (is_int8) {
dst_dt = (fuse_relu || fuse_brelu) ? mkldnn::memory::data_type::u8
: mkldnn::memory::data_type::s8;
dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
? mkldnn::memory::data_type::u8
: mkldnn::memory::data_type::s8;
if (force_fp32_output) {
dst_dt = mkldnn::memory::data_type::f32;
}
......@@ -100,12 +101,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (!is_INT8) {
ComputeFP32(ctx);
} else {
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto dst_dt = GetDstType(true, force_fp32_output, fuse_relu, fuse_brelu,
auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
fuse_residual_conn, residual_param);
if (dst_dt == mkldnn::memory::data_type::f32) {
ComputeINT8<float>(ctx);
......@@ -150,16 +150,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = false;
float fuse_brelu_threshold = 6.0;
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
if (!is_conv3d) {
fuse_brelu = ctx.Attr<bool>("fuse_brelu");
fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
}
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
is_conv3d
......@@ -180,7 +177,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
src_tz, weights_tz, fuse_activation, strides, paddings, dilations,
groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
std::vector<primitive> pipeline;
......@@ -232,13 +229,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold,
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
fwd_prop_kind);
} else {
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, fwd_prop_kind);
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, fwd_prop_kind);
}
// create mkldnn memory from input tensors (data/weights)
......@@ -355,12 +352,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
float fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool unsigned_output = fuse_relu || fuse_brelu;
bool unsigned_output =
(fuse_activation == "relu" || fuse_activation == "relu6");
PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output,
"residual fusion does not support force output with fp32");
......@@ -394,7 +392,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_relu, fuse_residual_conn, fuse_brelu,
input->format(), fuse_activation, fuse_residual_conn,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
......@@ -484,6 +482,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.reset(
new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
// create a conv primitive descriptor and save it for usage in backward
// TODO(grygielski) if INT8 brelu post-op will be available, just delete
// whole if statement
if (fuse_activation == "relu6") {
fuse_activation = "relu";
fuse_alpha = 0.0f;
}
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
......@@ -493,13 +499,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::format::x);
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, propagation, output_shift_scale, sum_scale);
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, propagation, output_shift_scale, sum_scale);
} else {
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, propagation, output_shift_scale, sum_scale);
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, propagation, output_shift_scale, sum_scale);
}
// create mkldnn memory from input tensors (data/weights)
......@@ -681,11 +687,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz =
paddle::framework::vectorize2int(output_grad->dims());
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_brelu = false;
if (!is_conv3d) {
fuse_brelu = ctx.Attr<bool>("fuse_brelu");
}
auto src_format = input->format();
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
......@@ -694,8 +695,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
src_tz, weights_tz, "", strides, paddings, dilations, groups,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline;
......
......@@ -142,7 +142,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta");
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......@@ -166,11 +168,12 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x);
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, false, false, 0.0, fwd_prop_kind);
fuse_activation, fuse_alpha, fuse_beta, false, fwd_prop_kind);
} else {
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, false, false, 0.0, fwd_prop_kind);
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, false,
fwd_prop_kind);
}
// create mkldnn memory from input tensors (data/weights)
......
......@@ -217,8 +217,8 @@ class MKLDNNHandler {
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format, const bool& relu,
const bool& residual, const bool& brelu, const std::string& suffix) {
const mkldnn::memory::format& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
......@@ -232,9 +232,8 @@ class MKLDNNHandler {
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, std::to_string(relu));
AppendKey(key, fuse_activation);
AppendKey(key, std::to_string(residual));
AppendKey(key, std::to_string(brelu));
AppendKey(key, suffix);
}
......@@ -1179,9 +1178,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
}
mkldnn::primitive_attr CreatePostOps(
bool fuse_relu, bool fuse_residual_conn, bool fuse_brelu,
float fuse_brelu_threshold,
const std::vector<float> output_shift_scale = {},
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
......@@ -1199,20 +1197,17 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_relu) {
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
constexpr float scale = 1.0f;
constexpr float negative_slope = 0.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
fuse_alpha, fuse_beta);
}
if (fuse_brelu) {
if (fuse_activation == "relu6") {
constexpr float scale = 1.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_brelu_threshold, placeholder);
fuse_alpha, fuse_beta);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
......@@ -1224,9 +1219,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
boost::optional<const mkldnn::memory::desc&> bias,
const mkldnn::memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const mkldnn::engine& engine,
const bool fuse_relu, const bool fuse_residual_conn,
const bool fuse_brelu, const float fuse_brelu_threshold,
mkldnn::prop_kind fwd_prop_kind,
const std::string& fuse_activation, float fuse_alpha, float fuse_beta,
const bool fuse_residual_conn, mkldnn::prop_kind fwd_prop_kind,
const std::vector<float> output_shift_scale = {},
const float sum_scale = 1.0f) {
// Conv PD has to be passed to Grad op that
......@@ -1259,8 +1253,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, output_shift_scale, sum_scale);
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, output_shift_scale, sum_scale);
conv_pd_.reset(new typename forward_t::primitive_desc(
conv_desc, conv_attr, engine));
......@@ -1343,14 +1337,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
mkldnn::memory::dims& weights_dims, // NOLINT
const bool& fuse_relu, // NOLINT
const bool& fuse_brelu, // NOLINT
const std::string& fuse_activation, // NOLINT
std::vector<int>& strides, // NOLINT
std::vector<int>& paddings, // NOLINT
std::vector<int>& dilations, // NOLINT
int groups, const std::string& suffix) {
return dims2str(input_dims) + dims2str(weights_dims) +
std::to_string(fuse_relu) + std::to_string(fuse_brelu) +
return dims2str(input_dims) + dims2str(weights_dims) + fuse_activation +
dims2str(strides) + dims2str(paddings) + dims2str(dilations) +
std::to_string(groups) + suffix;
}
......
......@@ -83,12 +83,12 @@ class TestConv2dInt8Op(TestConv2dOp):
input_residual, self.input_residual_size).astype(
self.srctype) * (self.scale_out / self.scale_in_eltwise
))
if self.fuse_relu:
if self.fuse_activation == "relu":
output = np.maximum(output_tmp, 0).astype(self.dsttype)
else:
output = output_tmp.astype(self.dsttype)
else:
if self.fuse_relu:
if self.fuse_activation == "relu":
output = np.maximum(np.round(output1 - output2),
0).astype(self.dsttype)
else:
......@@ -113,12 +113,12 @@ class TestConv2dInt8Op(TestConv2dOp):
input_residual, self.input_residual_size).astype(
np.int32) * (self.scale_out / self.scale_in_eltwise
))
if self.fuse_relu:
if self.fuse_activation == "relu":
output = np.maximum(output_tmp_res, 0).astype(self.dsttype)
else:
output = output_tmp_res.astype(self.dsttype)
else:
if self.fuse_relu:
if self.fuse_activation == "relu":
output = np.maximum(output1_tmp, 0).astype(self.dsttype)
else:
output = output1_tmp.astype(self.dsttype)
......@@ -145,7 +145,7 @@ class TestConv2dInt8Op(TestConv2dOp):
'Scale_out': self.scale_out,
'Scale_weights': self.scale_weights,
'Scale_in_eltwise': self.scale_in_eltwise,
'fuse_relu': self.fuse_relu,
'fuse_activation': self.fuse_activation,
'fuse_residual_connection': self.fuse_residual
}
self.outputs = {'Output': output}
......@@ -178,7 +178,7 @@ class TestConv2dInt8Op(TestConv2dOp):
self.dsttype = np.int8
def init_fuse_relu(self):
self.fuse_relu = True
self.fuse_activation = "relu"
def init_fuse_residual(self):
self.fuse_residual = True
......@@ -262,11 +262,11 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op):
self.groups = 3
def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual):
def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual):
self.srctype = input_dt
self.dsttype = np.uint8 if fuse_relu else np.int8
self.dsttype = np.uint8 if fuse_activation == "relu" else np.int8
self.fuse_relu = fuse_relu
self.fuse_activation = fuse_activation
self.fuse_residual = fuse_residual
......@@ -277,43 +277,43 @@ def create_test_int8_class(parent):
class TestS8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, True, False)
init_data_type_with_fusion(self, np.int8, "relu", False)
#--------------------test conv2d s8 in and s8 out--------------------
class TestS8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, False, False)
init_data_type_with_fusion(self, np.int8, "", False)
#--------------------test conv2d u8 in and s8 out--------------------
class TestU8S8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, False, False)
init_data_type_with_fusion(self, np.uint8, "", False)
#--------------------test conv2d u8 in and u8 out without residual fuse--------------------
class TestU8U8Case(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, True, False)
init_data_type_with_fusion(self, np.uint8, "relu", False)
#--------------------test conv2d s8 in and u8 out with residual fuse--------------------
class TestS8U8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, True, True)
init_data_type_with_fusion(self, np.int8, "relu", True)
#--------------------test conv2d s8 in and s8 out with residual fuse--------------------
class TestS8S8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.int8, False, True)
init_data_type_with_fusion(self, np.int8, "", True)
#--------------------test conv2d u8 in and s8 out with residual fuse--------------------
class TestU8S8ResCase(parent):
def init_data_type(self):
init_data_type_with_fusion(self, np.uint8, False, True)
init_data_type_with_fusion(self, np.uint8, "", True)
cls_name_s8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1")
cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0")
......
......@@ -56,8 +56,9 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
def setUp(self):
self.fuse_bias = False
self.bias_size = None
self.fuse_relu = False
self.fuse_brelu = False
self.fuse_activation = ""
self.fuse_alpha = 0
self.fuse_beta = 0
self.fuse_brelu_threshold = 6.0
self.fuse_residual_connection = False
self.input_residual_size = None
......@@ -83,18 +84,18 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
input_residual)
if self.fuse_relu:
if self.fuse_activation == "relu":
output = np.maximum(output, 0).astype(self.dsttype)
if self.fuse_brelu:
output = np.minimum(
np.maximum(output, 0),
self.fuse_brelu_threshold).astype(self.dsttype)
if self.fuse_activation == "relu6":
output = np.minimum(np.maximum(output, 0),
self.fuse_alpha).astype(self.dsttype)
output = output.astype(self.dtype)
self.attrs['fuse_bias'] = self.fuse_bias
self.attrs['fuse_relu'] = self.fuse_relu
self.attrs['fuse_brelu'] = self.fuse_brelu
self.attrs['fuse_activation'] = self.fuse_activation
self.attrs['fuse_alpha'] = self.fuse_alpha
self.attrs['fuse_beta'] = self.fuse_beta
self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold
self.attrs['fuse_residual_connection'] = self.fuse_residual_connection
......@@ -104,8 +105,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
class TestWithbreluFusion(TestConv2dMKLDNNOp):
def init_test_case(self):
TestConv2dMKLDNNOp.init_test_case(self)
self.fuse_brelu = True
self.fuse_brelu_threshold = 6.0
self.fuse_activation = "relu6"
self.fuse_alpha = 6.0
self.dsttype = np.float32
def test_check_grad(self):
......
......@@ -51,7 +51,9 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp):
self.pad = [0, 0]
self.fuse_bias = False
self.bias_size = None
self.fuse_relu = False
self.fuse_activation = ""
self.fuse_alpha = 0.0
self.fuse_beta = 0.0
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
......@@ -71,11 +73,13 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp):
self.attrs['fuse_bias'] = self.fuse_bias
self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias)
if self.fuse_relu:
if self.fuse_activation == "relu":
output = np.maximum(output, 0).astype(self.dtype)
output = output.astype(self.dtype)
self.attrs['fuse_bias'] = self.fuse_bias
self.attrs['fuse_relu'] = self.fuse_relu
self.attrs['fuse_activation'] = self.fuse_activation
self.attrs['fuse_alpha'] = self.fuse_alpha
self.attrs['fuse_beta'] = self.fuse_beta
self.outputs['Output'] = output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册