Commit 97d1db18 authored by Adam, committed by Tao Luo

Add generalized Conv+Activation MKLDNN fuse pass creation Part2 (#19237)

* Add generalized Conv+Activation MKLDNN fuse pass creation Part2
test=develop

* Undefined behaviour of GetAttrIfExists<> FIX
test=develop
Parent 37428952
......@@ -800,60 +800,6 @@ PDNode *patterns::ConvActivation::operator()(
return activation_out_var;
}
PDNode *patterns::ConvReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu");
// output
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
PDNode *patterns::ConvBReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu6");
// output
auto *brelu_out_var = pattern->NewNode(brelu_out_repr())
->AsOutput()
->assert_is_op_output("relu6");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var});
return brelu_out_var;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
......
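Note: the ConvReLU and ConvBReLU patterns removed above are both subsumed by the generalized patterns::ConvActivation, whose tail is visible at the top of this hunk. A minimal usage sketch, assuming its operator() takes the activation op type as a parameter (the actual signature sits above this hunk, outside the diff):
// Sketch only: matching conv2d+relu6 with the generalized pattern.
// The (conv_input, conv_type, activation_type) parameter list is an
// assumption inferred from the surrounding code.
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
                       ->NewNode("conv_activation_mkldnn_fuse/conv_input")
                       ->AsInput()
                       ->assert_is_op_input("conv2d", "Input");
patterns::ConvActivation conv_activation_pattern(gpd.mutable_pattern(),
                                                 "conv_activation_mkldnn_fuse");
conv_activation_pattern(conv_input, "conv2d", "relu6");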
......@@ -453,48 +453,6 @@ struct ConvActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};
// CONV with ReLU
// op: conv + relu
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// relu_out, relu
struct ConvReLU : public PatternBase {
ConvReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_relu") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(relu_out);
};
// CONV with ReLU6
// op: conv + relu6
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// relu6_out, relu6
struct ConvBReLU : public PatternBase {
ConvBReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bounded_relu") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(brelu);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(brelu_out);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
......
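For reference, a reconstruction of the generalized declaration that supersedes both removed structs; only its closing lines (ending in PATTERN_DECL_NODE(activation_out)) appear in this hunk, so the default arguments here are assumptions:
// Reconstructed sketch, not part of this diff:
struct ConvActivation : public PatternBase {
  ConvActivation(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "conv_activation") {}
  PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d",
                     std::string activation_type = "relu");
  // declare operator node's name
  PATTERN_DECL_NODE(conv);
  PATTERN_DECL_NODE(activation);
  // declare variable node's name
  PATTERN_DECL_NODE(conv_weight);
  PATTERN_DECL_NODE(conv_out);
  PATTERN_DECL_NODE(activation_out);
};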
......@@ -62,14 +62,10 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
boost::get<float>(activation->Op()->GetAttr("threshold")));
} else {
desc->SetAttr("fuse_alpha",
activation->Op()->HasAttr("alpha")
? boost::get<float>(activation->Op()->GetAttr("alpha"))
: 0.0f);
activation->Op()->GetAttrIfExists<float>("alpha"));
}
desc->SetAttr("fuse_beta",
activation->Op()->HasAttr("beta")
? boost::get<float>(activation->Op()->GetAttr("beta"))
: 0.0f);
activation->Op()->GetAttrIfExists<float>("beta"));
GraphSafeRemoveNodes(graph, {activation, conv_out});
......
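After this change the fused conv2d carries the activation through three attributes. An illustrative layout for a fused relu6 (the values are examples, not taken from the diff):
// Illustration: for relu6 the clamp arrives as the activation's "threshold"
// and lands in fuse_alpha; a missing "alpha"/"beta" now defaults to 0.0f via
// GetAttrIfExists<float>.
desc->SetAttr("fuse_activation", std::string("relu6"));
desc->SetAttr("fuse_alpha", 6.0f);  // relu6 upper bound
desc->SetAttr("fuse_beta", 0.0f);   // unused by relu6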
......@@ -114,11 +114,7 @@ void MainTest(std::string activation) {
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
auto op_name = boost::get<std::string>(op->GetAttr("name"));
std::string fuse_activation =
op->HasAttr("fuse_activation")
? boost::get<std::string>(op->GetAttr("fuse_activation"))
: "";
if (fuse_activation == activation) {
if (op->GetAttrIfExists<std::string>("fuse_activation") == activation) {
++conv_activation_count;
}
// check if only "conv1" convolution is fused
......
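MainTest is parameterized by the activation name; assumed instantiations for illustration (the actual TEST macros lie outside this hunk):
// Assumed, for illustration only:
TEST(ConvActivationFusePass, conv_relu) { MainTest("relu"); }
TEST(ConvActivationFusePass, conv_relu6) { MainTest("relu6"); }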
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
"conv_bounded_relu_mkldnn_fuse");
conv_brelu_pattern(conv_input);
int found_conv_brelu_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBoundedReLUFusePass fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_brelu_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern); // ReLU op
// Transform Conv node into ConvBReLU node.
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()}));
desc->SetAttr("fuse_brelu", true);
desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));
GraphSafeRemoveNodes(graph, {brelu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(conv, brelu_out);
found_conv_brelu_count++;
};
gpd(graph, handler);
AddStatis(found_conv_brelu_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_brelu_mkldnn_fuse_pass,
paddle::framework::ir::ConvBReLUFusePass);
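The dedicated pass deleted above is replaced by the generalized ConvActivationFusePass exercised earlier in this diff. A hypothetical sketch of how a relu6 flavor can be expressed there (the class and method names are assumptions, not confirmed by the diff):
// Hypothetical thin subclass: only the activation type differs.
class Conv2DReLU6FusePass : public ConvActivationFusePass {
 public:
  std::string activation_type() const { return "relu6"; }
};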
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the CONV and ReLU6 to a ConvReLU6Op.
*/
class ConvBReLUFusePass : public FusePassBase {
public:
virtual ~ConvBReLUFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu6") {
op->SetAttr("use_mkldnn", use_mkldnn);
if (use_mkldnn) {
op->SetAttr("threshold", 6.0f);
}
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->brelu->g
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
// conv+brelu, both with MKL-DNN
SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}), true);
SetOp(&prog, "relu6", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+brelu, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true);
SetOp(&prog, "relu6", "relu2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}));
return prog;
}
TEST(ConvBReLUFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
// Remove 3 Nodes: CONV, BRELU, conv_out
// Add 1 Node: ConvBReLU
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_brelu op in newly generated graph
int conv_brelu_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
// check if only "conv1" convolution is fused
auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_brelu"));
ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold"));
bool fuse_brelu = boost::get<bool>(op->GetAttr("fuse_brelu"));
if (fuse_brelu) {
++conv_brelu_count;
float fuse_brelu_threshold =
boost::get<float>(op->GetAttr("fuse_brelu_threshold"));
EXPECT_EQ(fuse_brelu_threshold, 6.0f);
}
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_brelu"));
}
}
}
EXPECT_EQ(conv_brelu_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_brelu_mkldnn_fuse_pass);
......@@ -109,11 +109,7 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
std::string fuse_activation =
conv_op->Op()->HasAttr("fuse_activation")
? boost::get<std::string>(conv_op->Op()->GetAttr("fuse_activation"))
: "";
if (fuse_activation == "relu" || fuse_activation == "relu6") return;
if (HasFusedActivation(conv_op)) return;
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
......@@ -182,12 +178,7 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
return;
}
std::string fuse_activation =
residual_conv_op->Op()->HasAttr("fuse_activation")
? boost::get<std::string>(
residual_conv_op->Op()->GetAttr("fuse_activation"))
: "";
if (fuse_activation == "relu" || fuse_activation == "relu6") return;
if (HasFusedActivation(residual_conv_op)) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
......
......@@ -126,6 +126,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
protected:
void ApplyImpl(graph_ptr graph) const;
static bool HasFusedActivation(Node* conv_node) {
return !(conv_node->Op()
->GetAttrIfExists<std::string>("fuse_activation")
.empty());
}
const std::string name_scope_{"residual_connection_fuse_pass"};
};
......
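Note the widened semantics: the removed inline checks skipped fusion only for "relu"/"relu6", while HasFusedActivation skips it for any non-empty fuse_activation. Illustration with a hypothetical op state:
// Hypothetical state, to show the behavioral difference:
conv_op->Op()->SetAttr("fuse_activation", std::string("swish"));
// removed check: (act == "relu" || act == "relu6") -> false, would still fuse
// new helper:    HasFusedActivation(conv_op)       -> true,  fusion skipped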
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("conv_relu_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_relu_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
"conv_relu_mkldnn_fuse");
conv_relu_pattern(conv_input);
int found_conv_relu_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvReLU fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_relu_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
FuseOptions fuse_option = FindFuseOption(*conv, *relu);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+relu fuse";
return;
}
// Transform Conv node into ConvReLU node.
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
desc->SetAttr("fuse_relu", true);
GraphSafeRemoveNodes(graph, {relu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(conv, relu_out);
found_conv_relu_count++;
};
gpd(graph, handler);
AddStatis(found_conv_relu_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
paddle::framework::ir::ConvReLUFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the CONV and ReLU to a ConvReLUOp.
*/
class ConvReLUFusePass : public FusePassBase {
public:
virtual ~ConvReLUFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->relu->g
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
// conv+relu, both with MKL-DNN
SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}), true);
SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+relu, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}));
return prog;
}
TEST(ConvReLUFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
// Remove 3 Nodes: CONV, RELU, conv_out
// Add 1 Node: ConvReLU
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_relu op in newly generated graph
int conv_relu_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
// check if only "conv1" convolution is fused
auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
}
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_relu"));
}
}
}
EXPECT_EQ(conv_relu_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_relu_mkldnn_fuse_pass);
......@@ -209,12 +209,11 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
is_output_unsigned, "Scale_out");
// change threshold in bounded ReLu
if (conv_op->Op()->HasAttr("fuse_brelu") &&
boost::get<bool>(conv_op->Op()->GetAttr("fuse_brelu"))) {
if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
"relu6") {
float scale_out = boost::get<float>(conv_op->Op()->GetAttr("Scale_out"));
float threshold =
boost::get<float>(conv_op->Op()->GetAttr("fuse_brelu_threshold"));
conv_op->Op()->SetAttr("fuse_brelu_threshold", scale_out * threshold);
float threshold = boost::get<float>(conv_op->Op()->GetAttr("fuse_alpha"));
conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
}
++quantize_conv_count;
......
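The rewrite moves the relu6 bound into the quantized output domain; a worked example with illustrative numbers:
// Illustrative numbers only (not from the diff):
float scale_out = 21.25f;  // "Scale_out" attribute
float threshold = 6.0f;    // original relu6 bound stored in fuse_alpha
conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);  // 127.5f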
......@@ -80,6 +80,15 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const;
template <typename T>
T GetAttrIfExists(const std::string &name) const {
T result{};
if (HasAttr(name)) {
result = boost::get<T>(GetAttr(name));
}
return result;
}
const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const;
......
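The "T result{};" value-initialization is the undefined-behaviour fix named in the commit message: a missing attribute now yields a well-defined default rather than an indeterminate value. Usage sketch:
// Usage sketch: value-initialized defaults when an attribute is absent.
OpDesc op;
float alpha = op.GetAttrIfExists<float>("alpha");          // 0.0f if unset
std::string act = op.GetAttrIfExists<std::string>("act");  // "" if unset
bool flag = op.GetAttrIfExists<bool>("use_mkldnn");        // false if unset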
......@@ -69,10 +69,7 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
if (op->Type() == "conv2d") {
// output of conv2d with relu must be unsigned
std::string fuse_activation =
op->HasAttr("fuse_activation")
? boost::get<std::string>(
op->GetAttr("fuse_activation"))
: "";
op->GetAttrIfExists<std::string>("fuse_activation");
is_unsigned =
(fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "relu") {
......
......@@ -482,14 +482,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.reset(
new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
// create a conv primitive descriptor and save it for usage in backward
// TODO(grygielski) if INT8 brelu post-op will be available, just delete
// whole if statement
if (fuse_activation == "relu6") {
fuse_activation = "relu";
fuse_alpha = 0.0f;
}
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
......
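For context, the relu6-to-relu conversion shown here exists because fuse_activation is lowered to an MKL-DNN eltwise post-op and an INT8 bounded-relu post-op was not yet available (per the TODO). A hedged sketch of that lowering, using the MKL-DNN 0.x post_ops API; this wiring is not part of the diff:
// Hedged sketch (not part of this diff) of the post-op lowering:
mkldnn::post_ops post_operations;
constexpr float scale = 1.0f;
if (fuse_activation == "relu") {
  post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                 fuse_alpha, fuse_beta);
} else if (fuse_activation == "relu6") {
  post_operations.append_eltwise(
      scale, mkldnn::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
}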