提交 2281ebf0 编写于 作者: G guomingz 提交者: Tao Luo

Enable the convolution/relu6(bounded_relu) fusion for FP32 on Intel platform. (#17130)

* Relu6 is the bottleneck op for Mobilenet-v2. As the mkldnn supports the conv/relu6 fusion, we implement it fusion via cpass way. Due to the int8 enabling for this fusion will be supported in MKLDNN v0.20, so this PR is focused on the fp32 optimization.

Below table shows the benchmark(FPS) which measured on skx-8180(28 cores)
Batch size | with fusion | without fusion
-- | -- | --
1 | 214.7 | 53.4
50 | 1219.727 | 137.280

test=develop

* Fix the format issue

test=develop

* Add the missing nolint comments.

test=develop

* Fix the typos.

test=develop

* Register the conv_brelu_mkldnn_fuse_pass for the MKLDNN engine.

test=develop

* Adjust the indentation.

test=develop

* Add the test_conv_brelu_mkldnn_fuse_pass case.

test=develop

* Slightly update the code per Baidu comments.
Let the parameter definition embedded into the code.
That's will make the code easy to understand.

test=develop
上级 3398f996
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_paddle_tiny_install delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 improve_sccache incubate/infrt inplace_addto make_flag_adding_easier move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel paddle_tiny_install paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/1.5 release/1.6 release/1.7 release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0 v1.7.2 v1.7.1 v1.7.0 v1.6.3 v1.6.2 v1.6.1 v1.6.0 v1.6.0-rc0 v1.5.2 v1.5.1 v1.5.0
无相关合并请求
......@@ -85,6 +85,7 @@ if(WITH_MKLDNN)
pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
......@@ -114,6 +115,7 @@ if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
......
......@@ -785,6 +785,33 @@ PDNode *patterns::ConvReLU::operator()(
return relu_out_var;
}
PDNode *patterns::ConvBReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("relu6");
// output
auto *brelu_out_var = pattern->NewNode(brelu_out_repr())
->AsOutput()
->assert_is_op_output("relu6");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var});
return brelu_out_var;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
......
......@@ -449,6 +449,27 @@ struct ConvReLU : public PatternBase {
PATTERN_DECL_NODE(relu_out);
};
// CONV with ReLU6
// op: conv + relu6
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// relu6_out, relu6
struct ConvBReLU : public PatternBase {
ConvBReLU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bounded_relu") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(brelu);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(brelu_out);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph);
GraphPatternDetector gpd;
auto* conv_input = gpd.mutable_pattern()
->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input")
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
"conv_bounded_relu_mkldnn_fuse");
conv_brelu_pattern(conv_input);
int found_conv_brelu_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBoundedReLUFusePass fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_brelu_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern); // CONV op
GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern); // ReLU op
// Transform Conv node into ConvBReLU node.
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()}));
desc->SetAttr("fuse_brelu", true);
desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));
GraphSafeRemoveNodes(graph, {brelu, conv_out});
PADDLE_ENFORCE(subgraph.count(conv_input));
IR_NODE_LINK_TO(conv, brelu_out);
found_conv_brelu_count++;
};
gpd(graph, handler);
AddStatis(found_conv_brelu_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_brelu_mkldnn_fuse_pass,
paddle::framework::ir::ConvBReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the CONV and ReLU6 to a ConvReLU6Op.
*/
class ConvBReLUFusePass : public FusePassBase {
public:
virtual ~ConvBReLUFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu6") {
op->SetAttr("use_mkldnn", use_mkldnn);
if (use_mkldnn) {
op->SetAttr("threshold", 6.0f);
}
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
// b->OP1->c
// (c, weights, bias)->conv->f
// (f)->brelu->g
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
// conv+brelu, both with MKL-DNN
SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}), true);
SetOp(&prog, "relu6", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+brelu, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true);
SetOp(&prog, "relu6", "relu2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}));
return prog;
}
TEST(ConvBReLUFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
// Remove 3 Nodes: CONV, BRELU, conv_out
// Add 1 Node: ConvBReLU
EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
// Assert conv_brelu op in newly generated graph
int conv_brelu_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
// check if only "conv1" convolution is fused
auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_brelu"));
ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold"));
bool fuse_brelu = boost::get<bool>(op->GetAttr("fuse_brelu"));
if (fuse_brelu) {
++conv_brelu_count;
float fuse_brelu_threshold =
boost::get<float>(op->GetAttr("fuse_brelu_threshold"));
EXPECT_EQ(fuse_brelu_threshold, 6.0f);
}
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_brelu"));
}
}
}
EXPECT_EQ(conv_brelu_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_brelu_mkldnn_fuse_pass);
......@@ -153,7 +153,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass"})) {
"conv_relu_mkldnn_fuse_pass", //
"conv_brelu_mkldnn_fuse_pass"})) {
passes_.push_back(pass);
}
}
......
......@@ -209,6 +209,12 @@ void Conv2DOpMaker::Make() {
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_brelu",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<float>("fuse_brelu_threshold",
"(float, default false 6.0) Only used in mkldnn kernel")
.SetDefault(6.0f);
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual "
......
......@@ -119,9 +119,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = false;
float fuse_brelu_threshold = 6.0;
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
if (!is_conv3d) {
fuse_brelu = ctx.Attr<bool>("fuse_brelu");
fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
}
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
is_conv3d
......@@ -142,8 +147,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
std::vector<primitive> pipeline;
......@@ -194,11 +199,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn, fwd_prop_kind);
fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold,
fwd_prop_kind);
} else {
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind);
mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
fuse_brelu_threshold, fwd_prop_kind);
}
// create mkldnn memory from input tensors (data/weights)
......@@ -317,13 +324,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
if (fuse_residual_conn) {
PADDLE_ENFORCE(force_fp32_output != true,
"residual fusion does not support force output with fp32");
}
bool is_conv3d = strides.size() == 3U;
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
......@@ -334,6 +340,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"dilation in convolution is not implemented yet");
PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently");
PADDLE_ENFORCE(fuse_brelu != true,
"int8 does not support conv/relu6 fusion currently");
const T* input_data = input->data<T>();
......@@ -341,15 +349,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<uint8_t>::DataType)
: paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<int8_t>::DataType);
if (force_fp32_output) {
dst_dt = paddle::framework::ToMKLDNNDataType(
......@@ -367,8 +377,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_relu, fuse_residual_conn,
input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
......@@ -449,22 +460,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale, is_test);
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale,
is_test);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale, is_test);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine, fuse_relu,
fuse_residual_conn, false /*fuse_brelu*/,
0.0 /*fuse_brelu_threshold*/,
output_shift_scale, sum_scale, is_test);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights)
user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
......@@ -632,11 +645,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private:
mkldnn::primitive_attr CreatePostOps(
bool fuse_relu, bool fuse_residual_conn,
const std::vector<float> output_shift_scale, float sum_scale) const {
const std::vector<float> output_shift_scale, float sum_scale,
bool fuse_brelu, float fuse_brelu_threshold) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
......@@ -647,6 +662,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
if (fuse_brelu) {
constexpr float scale = 1.0f;
constexpr float placeholder = 0.0f; // beta
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_brelu_threshold, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
......@@ -656,7 +678,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float> output_shift_scale,
const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
......@@ -668,9 +691,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto conv_desc = mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
sum_scale, fuse_brelu, fuse_brelu_threshold);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......@@ -685,7 +708,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float> output_shift_scale,
const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
......@@ -698,8 +722,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
propagation, mkldnn::convolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
sum_scale, fuse_brelu, fuse_brelu_threshold);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
......@@ -762,7 +787,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz =
paddle::framework::vectorize2int(output_grad->dims());
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_brelu = false;
if (!is_conv3d) {
fuse_brelu = ctx.Attr<bool>("fuse_brelu");
}
auto src_format = input->format();
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
......@@ -771,8 +800,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline;
......
......@@ -166,11 +166,11 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x);
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu, false, fwd_prop_kind);
fuse_relu, false, false, 0.0, fwd_prop_kind);
} else {
conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, false, fwd_prop_kind);
mkldnn_engine, fuse_relu, false, false, 0.0, fwd_prop_kind);
}
// create mkldnn memory from input tensors (data/weights)
......
......@@ -212,25 +212,29 @@ class MKLDNNHandler {
dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
}
static void AppendKey(std::string* key,
const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations, const int& groups,
const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format, const bool& relu,
const bool& residual, const std::string& suffix) {
static void AppendKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const mkldnn::memory::format& format, const bool& relu,
const bool& residual, const bool& brelu, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides);
AppendKeyVec(key, paddings);
AppendKeyVec(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, std::to_string(relu));
AppendKey(key, std::to_string(residual));
AppendKey(key, std::to_string(brelu));
AppendKey(key, suffix);
}
......@@ -562,8 +566,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
scale_data, mask);
}
mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
bool fuse_residual_conn = false) const {
mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
bool fuse_brelu,
float fuse_brelu_threshold) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
// Fusion with Elementwise layer relies on adding a sum post-operation with
......@@ -583,6 +588,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
negative_slope, placeholder);
}
if (fuse_brelu) {
constexpr float scale = 1.0f;
constexpr float placeholder = 0.0f;
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_brelu_threshold, placeholder);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
......@@ -594,6 +607,7 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
const mkldnn::memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const mkldnn::engine& engine,
const bool fuse_relu, const bool fuse_residual_conn,
const bool fuse_brelu, const float fuse_brelu_threshold,
mkldnn::prop_kind fwd_prop_kind) {
const std::string key_conv_pd = key_ + "@conv_pd";
......@@ -614,8 +628,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
weights, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn);
mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
conv_pd_.reset(
new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
......@@ -714,6 +728,22 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
return conv_bwd_data_p;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
mkldnn::memory::dims& weights_dims, // NOLINT
const bool& fuse_relu, // NOLINT
const bool& fuse_brelu, // NOLINT
std::vector<int>& strides, // NOLINT
std::vector<int>& paddings, // NOLINT
std::vector<int>& dilations, // NOLINT
int groups, const std::string& suffix) {
return dims2str(input_dims) + dims2str(weights_dims) +
std::to_string(fuse_relu) + std::to_string(fuse_brelu) +
dims2str(strides) + dims2str(paddings) + dims2str(dilations) +
std::to_string(groups) + suffix;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT
......
......@@ -57,6 +57,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
self.fuse_bias = False
self.bias_size = None
self.fuse_relu = False
self.fuse_brelu = False
self.fuse_brelu_threshold = 6.0
self.fuse_residual_connection = False
self.input_residual_size = None
TestConv2dOp.setUp(self)
......@@ -84,15 +86,38 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
if self.fuse_relu:
output = np.maximum(output, 0).astype(self.dsttype)
if self.fuse_brelu:
output = np.minimum(
np.maximum(output, 0),
self.fuse_brelu_threshold).astype(self.dsttype)
output = output.astype(self.dtype)
self.attrs['fuse_bias'] = self.fuse_bias
self.attrs['fuse_relu'] = self.fuse_relu
self.attrs['fuse_brelu'] = self.fuse_brelu
self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold
self.attrs['fuse_residual_connection'] = self.fuse_residual_connection
self.outputs['Output'] = output
class TestWithbreluFusion(TestConv2dMKLDNNOp):
def init_test_case(self):
TestConv2dMKLDNNOp.init_test_case(self)
self.fuse_brelu = True
self.fuse_brelu_threshold = 6.0
self.dsttype = np.float32
def test_check_grad(self):
pass
def test_check_grad_no_filter(self):
pass
def test_check_grad_no_input(self):
pass
class TestWithFuse(TestConv2dMKLDNNOp):
def init_test_case(self):
TestConv2dMKLDNNOp.init_test_case(self)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册