Commit 5b2a3c4b authored by Sylwester Fraczek, committed by Tao Luo

Conv concat relu quantization (#17466)

* add conv_concat_relu fuse

test=develop

* add test code

test=develop

* added missing include with unordered_map

test=develop

* review fixes for wojtuss

test=develop

* remove 'should (not) be fused' comment statements

one of them was invalid anyway

test=develop
Parent: bccb0ba4
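In short, the new pass folds the ReLU that follows a concat of conv2d outputs into the convolutions themselves (a sketch of the transformation, per the pass code below):

  conv2d --\
  conv2d ---+--> concat --> relu --> next_op           (before)
  conv2d --/

  conv2d(fuse_relu=true) --\
  conv2d(fuse_relu=true) ---+--> concat --> next_op    (after)
  conv2d(fuse_relu=true) --/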
paddle/fluid/framework/ir/CMakeLists.txt
@@ -86,6 +86,7 @@ if(WITH_MKLDNN)
   pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
+  pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
   pass_library(cpu_quantize_placement_pass base mkldnn)
   pass_library(cpu_quantize_pass inference mkldnn)
@@ -116,6 +117,7 @@ if (WITH_MKLDNN)
   cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
   cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
   cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
+  cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
   cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
   cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1184,6 +1184,46 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
  return out_var;
}
PDNode *patterns::ConcatReLU::operator()() {
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto concat_out =
pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
concat_op->LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
PDNode *patterns::ConvConcatReLU::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
auto conv_out = pattern->NewNode(conv_out_repr())
->assert_is_op_output("conv2d", "Output");
auto concat_out = pattern->NewNode(concat_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("relu", "X");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
conv_op->LinksTo({conv_out});
concat_op->LinksFrom({conv_out}).LinksTo({concat_out});
relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
return relu_out;
}
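Note that the ConvConcatReLU pattern above contains a single conv2d node, so for a concat fed by N conv2d ops the detector produces N matches, one per (conv2d, concat, relu) triple. This is why the fuse pass below keeps a per-concat counter of not-yet-fused convs instead of fusing everything in one match.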
std::unordered_set<std::string> conv_act_set({"identity", "relu"});

PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
...
paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -728,6 +728,39 @@ struct ElementwiseAdd : public PatternBase {
  PATTERN_DECL_NODE(elementwise_add_out);
};
// Concat + ReLU
// named nodes:
// concat_op, concat_out, relu_op, relu_out
struct ConcatReLU : public PatternBase {
ConcatReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "concat_relu") {}
PDNode* operator()();
PATTERN_DECL_NODE(concat_op);
PATTERN_DECL_NODE(concat_out);
PATTERN_DECL_NODE(relu_op);
PATTERN_DECL_NODE(relu_out);
};
// Conv + Concat + ReLU
// named nodes:
// conv_op, conv_out
// concat_op, concat_out, relu_op, relu_out
struct ConvConcatReLU : public PatternBase {
ConvConcatReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_concat_relu") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(concat_op);
PATTERN_DECL_NODE(concat_out);
PATTERN_DECL_NODE(relu_op);
PATTERN_DECL_NODE(relu_out);
};
// Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
...
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void ConvConcatReLUFusePass::FindConcatWithConvs(
ir::Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
GraphPatternDetector gpd;
patterns::ConcatReLU concat_relu_pattern{gpd.mutable_pattern(),
"concat_relu"};
concat_relu_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Find Concats with Convs";
GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_relu_pattern);
GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, concat_relu_pattern);
    auto concat_inputs = concat_op->inputs;
    // The concat qualifies only if every one of its inputs comes from a
    // conv2d that may legally be fused with the trailing relu.
    for (auto node : concat_inputs) {
auto prev_op_node = node->inputs;
PADDLE_ENFORCE_EQ(prev_op_node.size(), 1);
auto* conv_op = prev_op_node[0];
if (conv_op->Op()->Type() != "conv2d") return;
FuseOptions fuse_option = FindFuseOption(*conv_op, *relu_op);
if (fuse_option == DO_NOT_FUSE) {
return;
}
}
    // Record how many convs feed this concat; FuseConvConcatReLU decrements
    // this counter as it fuses them one by one.
    (*concat_with_convs_counter)[concat_op] = concat_inputs.size();
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
void ConvConcatReLUFusePass::FuseConvConcatReLU(
ir::Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::ConvConcatReLU conv_concat_relu(pattern, name_scope_);
conv_concat_relu();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvConcatReLU fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, conv_concat_relu);
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_concat_relu);
if (!concat_with_convs_counter->count(concat_op)) {
VLOG(4) << "this concat has input from non-conv2d operator";
return;
}
// Transform Conv node into ConvReLU node.
OpDesc* conv_desc = conv_op->Op();
conv_desc->SetAttr("fuse_relu", true);
// Remove ReLU when all Convs were transformed.
auto number_of_unfused_convs_left =
--(*concat_with_convs_counter)[concat_op];
if (number_of_unfused_convs_left == 0) {
OpDesc* concat_desc = concat_op->Op();
concat_desc->SetOutput("Out",
std::vector<std::string>({relu_out->Name()}));
GraphSafeRemoveNodes(graph, {relu_op, concat_out});
IR_NODE_LINK_TO(concat_op, relu_out);
}
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
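As an illustrative trace (assuming a concat fed by three fusable conv2d ops), FindConcatWithConvs first stores counter = 3 for the concat, and the three pattern matches then proceed as:

  // match 1: conv1->Op()->SetAttr("fuse_relu", true);  counter -> 2
  // match 2: conv2->Op()->SetAttr("fuse_relu", true);  counter -> 1
  // match 3: conv3->Op()->SetAttr("fuse_relu", true);  counter -> 0
  //          => concat's "Out" is rewired to relu_out, and
  //             {relu_op, concat_out} are removed from the graph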
void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init(name_scope_, graph);
std::unordered_map<const Node*, int> concat_with_convs_counter;
FindConcatWithConvs(graph, &concat_with_convs_counter);
FuseConvConcatReLU(graph, &concat_with_convs_counter);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
paddle::framework::ir::ConvConcatReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
 * Fuse (multiple conv2d) -> Concat -> ReLU -> next_op
 * into (multiple conv2d with fuse_relu=true) -> Concat -> next_op.
 */
class ConvConcatReLUFusePass : public FusePassBase {
public:
virtual ~ConvConcatReLUFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
void FindConcatWithConvs(
Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
void FuseConvConcatReLU(
Graph* graph,
std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
const std::string name_scope_{"conv_concat_relu_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
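Since the pass is registered as "conv_concat_relu_mkldnn_fuse_pass", it can also be fetched and applied to a graph directly; a minimal sketch, mirroring what the unit test below does:

  auto pass =
      PassRegistry::Instance().Get("conv_concat_relu_mkldnn_fuse_pass");
  graph.reset(pass->Apply(graph.release()));  // graph: std::unique_ptr<ir::Graph>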
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn = true) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("fuse_relu", false);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) {
op->SetInput("Bias", {inputs[2]});
}
op->SetOutput("Output", outputs);
} else if (type == "relu") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
} else if (type == "pool2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
} else if (type == "concat") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// (a1,w1,b1)->conv1->c1
// (a2,w2,b2)->conv2->c2
// if put_only_convs_before_concat=true
//   (a3,w3,b3)->conv3->c3
// else
//   a3->pool1->c3
//
// (c1,c2,c3)->concat1->d
// d->relu1->e
ProgramDesc BuildProgramDesc(bool put_only_convs_before_concat,
bool all_convs_use_mkldnn) {
ProgramDesc prog;
for (auto& v :
std::initializer_list<std::string>({"a1", "w1", "c1", "a2", "w2", "b2",
"c2", "a3", "w3", "c3", "d", "e"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v.find("w") == 0 || v.find("b") == 0) {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", {"a1", "w1", "b1"}, {"c1"}, all_convs_use_mkldnn);
SetOp(&prog, "conv2d", {"a2", "w2", "b2"}, {"c2"});
if (put_only_convs_before_concat) {
SetOp(&prog, "conv2d", {"a3", "w3", "b3"}, {"c3"});
} else {
SetOp(&prog, "pool2d", {"a3"}, {"c3"});
}
SetOp(&prog, "concat", {"c1", "c2", "c3"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"});
return prog;
}
void MainTest(const ProgramDesc& prog, bool fuse_relu) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("conv_concat_relu_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
if (fuse_relu) {
// Remove 2 nodes: concat_out, relu
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
} else {
EXPECT_EQ(original_nodes_num, current_nodes_num);
}
int relu_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "conv2d") {
ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu_attr = boost::get<bool>(op->GetAttr("fuse_relu"));
EXPECT_EQ(fuse_relu, fuse_relu_attr);
} else if (op->Type() == "relu") {
relu_count++;
}
}
}
EXPECT_EQ(relu_count, fuse_relu ? 0 : 1);
}
TEST(ConvConcatReLUFusePass, only_convs_before_concat) {
bool all_convs_use_mkldnn = true;
bool put_only_convs_before_concat = true;
auto prog =
BuildProgramDesc(put_only_convs_before_concat, all_convs_use_mkldnn);
bool expect_relu_fuse = true;
MainTest(prog, expect_relu_fuse);
}
TEST(ConvConcatReLUFusePass, only_convs_before_concat_but_one_non_mkldnn) {
bool all_convs_use_mkldnn = false;
bool put_only_convs_before_concat = true;
auto prog =
BuildProgramDesc(put_only_convs_before_concat, all_convs_use_mkldnn);
bool expect_relu_fuse = false;
MainTest(prog, expect_relu_fuse);
}
TEST(ConvConcatReLUFusePass, convs_and_pool_before_concat) {
bool all_convs_use_mkldnn = true;
bool put_only_convs_before_concat = false;
auto prog =
BuildProgramDesc(put_only_convs_before_concat, all_convs_use_mkldnn);
bool expect_relu_fuse = false;
MainTest(prog, expect_relu_fuse);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_concat_relu_mkldnn_fuse_pass);
paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -153,6 +153,7 @@ void CpuPassStrategy::EnableMKLDNN() {
           "conv_bias_mkldnn_fuse_pass",    //
           "conv3d_bias_mkldnn_fuse_pass",  //
           "conv_elementwise_add_mkldnn_fuse_pass",
+          "conv_concat_relu_mkldnn_fuse_pass",
           "conv_relu_mkldnn_fuse_pass",    //
           "conv_brelu_mkldnn_fuse_pass"})) {
        passes_.push_back(pass);
...
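For inference users the new pass becomes active implicitly once MKLDNN is enabled on the CPU pass strategy; a hedged sketch using the C++ AnalysisConfig API of this Paddle version (the model path is hypothetical):

  #include "paddle/fluid/inference/api/paddle_inference_api.h"

  paddle::AnalysisConfig config;
  config.SetModel("path/to/model");  // hypothetical model directory
  config.EnableMKLDNN();  // CpuPassStrategy::EnableMKLDNN() queues the new pass
  auto predictor = paddle::CreatePaddlePredictor(config);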