From 7faa3e95555ff76f86b747cc987aa212bd27af9f Mon Sep 17 00:00:00 2001
From: Adam <38704900+grygielski@users.noreply.github.com>
Date: Sat, 12 Oct 2019 12:58:18 +0200
Subject: [PATCH] Add ConvTranspose + BatchNorm fuse pass (#20161)

* Add ConvTranspose + BatchNorm fuse pass
test=develop

* Add tests for conv+bn and conv_transpose+bn passes
test=develop
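In outline, the fuse folds the batch_norm statistics into the convolution's
filter and bias so the batch_norm op can be dropped from the inference graph.
The sketch below is only an illustrative recap of that arithmetic, not code
from this patch: the name FoldBnIntoConv and its plain std::vector interface
are hypothetical, and it assumes conv2d filters in OIHW layout and
conv2d_transpose filters in IOHW layout, which is why the transpose case reads
the output channel from dimension 1 instead of dimension 0 (the same
distinction recompute_bias_and_weights makes below).

#include <cmath>
#include <cstdint>
#include <vector>

// For each output channel c:
//   scale_c = gamma_c / sqrt(var_c + eps)
//   W'_c    = W_c * scale_c
//   b'_c    = (b_c - mean_c) * scale_c + beta_c
void FoldBnIntoConv(std::vector<float>* weights,       // flattened 4-D filter
                    const std::vector<int64_t>& dims,  // filter dimensions
                    std::vector<float>* bias,          // one entry per channel
                    const std::vector<float>& gamma,
                    const std::vector<float>& beta,
                    const std::vector<float>& mean,
                    const std::vector<float>& var, float eps,
                    bool is_transpose) {
  const int64_t oc = is_transpose ? dims[1] : dims[0];
  const int64_t hw = dims[2] * dims[3];
  std::vector<float> scale(oc);
  for (int64_t c = 0; c < oc; ++c) {
    scale[c] = gamma[c] / std::sqrt(var[c] + eps);
    (*bias)[c] = ((*bias)[c] - mean[c]) * scale[c] + beta[c];
  }
  for (int64_t i = 0; i < static_cast<int64_t>(weights->size()); ++i) {
    // OIHW: the output channel changes every dims[1]*H*W elements.
    // IOHW: it cycles every H*W elements within each input-channel slice.
    const int64_t c = is_transpose ? (i / hw) % oc : i / (dims[1] * hw);
    (*weights)[i] *= scale[c];
  }
}

After the rescaling, the pass rewires the batch_norm output to the convolution
(or to an elementwise_add when a separate bias tensor has to be introduced)
and removes the batch_norm node, which is what the graph surgery below does.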
---
 paddle/fluid/framework/ir/CMakeLists.txt      |  1 +
 .../fluid/framework/ir/conv_bn_fuse_pass.cc   | 48 +++++++----
 paddle/fluid/framework/ir/conv_bn_fuse_pass.h | 12 +++
 .../framework/ir/conv_bn_fuse_pass_tester.cc  | 86 +++++++++++++++++++
 .../framework/ir/graph_pattern_detector.cc    |  9 +-
 .../framework/ir/graph_pattern_detector.h     |  3 +-
 .../fluid/framework/ir/pass_tester_helper.h   | 40 +++++++++
 .../inference/api/paddle_pass_builder.cc      | 28 +++---
 8 files changed, 194 insertions(+), 33 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 88acc284702..6db8487d67e 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -126,6 +126,7 @@ cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.c
 cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
 cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
 cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass)
+cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass)
 if(WITH_GPU)
   cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
 endif()
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index 4fe3fb4f3dc..372087f0d5f 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -51,7 +51,7 @@ void recompute_bias_and_weights(const Scope* scope,
                                 const ir::Node& bn_mean,         //
                                 const ir::Node& bn_variance,     //
                                 LoDTensor* eltwise_y_in_tensor,  //
-                                float epsilon) {
+                                float epsilon, const std::string& conv_type) {
   using EigenVectorArrayMap =
       Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
   using ConstEigenVectorArrayMap =
@@ -92,13 +92,26 @@ void recompute_bias_and_weights(const Scope* scope,
   // Re-compute weight of conv2d from BN
   auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
   auto weights_shape = weights->dims();
-  auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
+  auto weights_data = weights->mutable_data<float>(platform::CPUPlace());
+
+  // ConvTranspose weights are in IOHW format
+  if (conv_type == "conv2d_transpose") {
+    int kernel_size = weights_shape[2] * weights_shape[3];
+    for (int i = 0; i < weights->numel();) {
+      for (int j = 0; j < weights_shape[1]; ++j) {
+        for (int k = 0; k < kernel_size; ++k, ++i) {
+          weights_data[i] *= variance_array[j];
+        }
+      }
+    }
+  } else {
+    auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
 
-  EigenMatrixArrayMap weights_array_2d(
-      weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
-      weights_shape_2d[1]);
+    EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
+                                         weights_shape_2d[1]);
 
-  weights_array_2d.colwise() *= variance_array;
+    weights_array_2d.colwise() *= variance_array;
+  }
 }
 
 void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -113,14 +126,14 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
   gpd.mutable_pattern()
       ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
       ->AsInput()
-      ->assert_is_op_input("conv2d", "Input");
+      ->assert_is_op_input(conv_type(), "Input");
   patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
-  conv_bn_pattern(conv_input, false /*with_eltwise_add*/);
+  conv_bn_pattern(conv_input, conv_type(), false /*with_eltwise_add*/);
 
   int found_conv_bn_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "handle ConvBN fuse";
+    VLOG(4) << "handle " + conv_type() + "BN fuse";
 
     // conv, batch_norm,
     // conv_weight, conv_out,
@@ -132,7 +145,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
     // check if fuse can be done and if MKL-DNN should be used
     FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
     if (fuse_option == DO_NOT_FUSE) {
-      VLOG(3) << "do not perform conv+bn fuse";
+      VLOG(3) << "do not perform " + conv_type() + " bn fuse";
       return;
     }
 
@@ -160,7 +173,7 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
     float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
     recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                                *bn_mean, *bn_variance, eltwise_y_in_tensor,
-                               epsilon);
+                               epsilon, conv_type());
 
     // with MKL-DNN fuse conv+bn into conv with bias
     // without MKL-DNN fuse conv+bn into conv+elementwise_add
@@ -187,7 +200,6 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
     }
 
     conv->Op()->SetOutput("Output", std::vector<std::string>({bn_out->Name()}));
-
     GraphSafeRemoveNodes(
         graph,
         {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
@@ -233,14 +245,14 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
   gpd.mutable_pattern()
      ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
      ->AsInput()
-     ->assert_is_op_input("conv2d", "Input");
+     ->assert_is_op_input(conv_type(), "Input");
   patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
-  conv_bn_pattern(conv_input, true /*with_eltwise_add*/);
+  conv_bn_pattern(conv_input, conv_type(), true /*with_eltwise_add*/);
 
   int found_conv_bn_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "handle ConvBN fuse";
+    VLOG(4) << "handle " + conv_type() + "BN fuse";
 
     // conv, batch_norm,
     // conv_weight, conv_out,
@@ -266,7 +278,7 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
     float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
     recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                                *bn_mean, *bn_variance, eltwise_y_in_tensor,
-                               epsilon);
+                               epsilon, conv_type());
 
     // Update the elementwise_add node
     eltwise->Op()->SetAttr("axis", 1);
@@ -294,3 +306,7 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
 REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
 REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
               paddle::framework::ir::ConvEltwiseAddBNFusePass);
+REGISTER_PASS(conv_transpose_bn_fuse_pass,
+              paddle::framework::ir::ConvTransposeBNFusePass);
+REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
+              paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
index 837a48ed730..fcdbcf299c5 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
@@ -29,6 +29,7 @@ namespace ir {
 class ConvBNFusePass : public FusePassBase {
  public:
   virtual ~ConvBNFusePass() {}
+  virtual std::string conv_type() const { return "conv2d"; }
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
@@ -38,12 +39,23 @@ class ConvBNFusePass : public FusePassBase {
 class ConvEltwiseAddBNFusePass : public FusePassBase {
  public:
   virtual ~ConvEltwiseAddBNFusePass() {}
+  virtual std::string conv_type() const { return "conv2d"; }
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
 
   const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
 };
 
+class ConvTransposeBNFusePass : public ConvBNFusePass {
+ public:
+  std::string conv_type() const { return "conv2d_transpose"; }
+};
+
+class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
+ public:
+  std::string conv_type() const { return "conv2d_transpose"; }
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc
new file mode 100644
index 00000000000..168d0afb26d
--- /dev/null
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass_tester.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "bias_1", {3});
+  AddVarToScope(param_scope, "scale", {3});
+  AddVarToScope(param_scope, "mean", {3});
+  AddVarToScope(param_scope, "variance", {3});
+  AddVarToScope(param_scope, "filters", {3, 3, 2, 2});
+  return param_scope;
+}
+
+void TestMain(const std::string& conv_type) {
+  // inputs                           operator            output
+  // ------------------------------------------------------------------
+  // (in, filters, bias_0)            conv          ->    conv_out
+  // (conv_out, scale,
+  //  bias_1, mean, variance)         batch_norm    ->    (...)
+  Layers layers;
+  auto* in = layers.data("in", {1, 3, 20, 20});
+  auto* filters = layers.data("filters", {3, 3, 2, 2}, true);
+  auto* bias_0 = layers.data("bias_0", {3}, true);
+  VarDesc* conv_out;
+  if (conv_type == "conv_transpose") {
+    conv_out = layers.conv2d_transpose(in, filters, bias_0);
+  } else {
+    conv_out = layers.conv2d(in, filters, bias_0);
+  }
+  conv_out->SetShape({1, 3, 20, 20});
+  auto* scale = layers.data("scale", {3}, true);
+  auto* bias_1 = layers.data("bias_1", {3}, true);
+  auto* mean = layers.data("mean", {3}, true);
+  auto* variance = layers.data("variance", {3}, true);
+  layers.batch_norm(conv_out, scale, bias_1, mean, variance);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  graph->Set("__param_scope__", CreateParamScope());
+  auto pass = PassRegistry::Instance().Get(conv_type + "_bn_fuse_pass");
+  int num_bn_nodes_before = GetNumOpNodes(graph, "batch_norm");
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_bn_nodes_after = GetNumOpNodes(graph, "batch_norm");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_bn_nodes_before, 1);
+  PADDLE_ENFORCE_EQ(num_bn_nodes_after, 0);
+}
+
+TEST(ConvBNFusePass, conv2d) { TestMain("conv"); }
+
+TEST(ConvBNFusePass, conv2d_transpose) { TestMain("conv_transpose"); }
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(conv_bn_fuse_pass);
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index bbb2ee2f56a..b628ccc8684 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -666,10 +666,11 @@ bool VarLinksFromOp(Node *node, const std::string &op_type) {
 }
 
 PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
+                                     const std::string &conv_type,
                                      bool with_eltwise_add) {
   // Create Operators
-  conv_input->assert_is_op_input("conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  conv_input->assert_is_op_input(conv_type, "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
 
   PDNode *eltwise_op = nullptr;
   if (with_eltwise_add) {
@@ -683,11 +684,11 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
                               ->AsInput()
                               ->assert_is_persistable_var()
-                              ->assert_is_op_input("conv2d", "Filter");
+                              ->assert_is_op_input(conv_type, "Filter");
 
   auto *conv_out_var = pattern->NewNode(conv_out_repr())
                            ->AsIntermediate()
-                           ->assert_is_only_output_of_op("conv2d");
+                           ->assert_is_only_output_of_op(conv_type);
 
   PDNode *eltwise_y_in_var = nullptr;
   PDNode *eltwise_out_var = nullptr;
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 0d7d56cabf3..5fea6523657 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -404,7 +404,8 @@ struct ConvBN : public PatternBase {
   ConvBN(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_bn") {}
 
-  PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
+  PDNode* operator()(PDNode* conv_input, const std::string& conv_type,
+                     bool with_eltwise_add);
 
   // declare operator node's name
   PATTERN_DECL_NODE(conv);
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index 8df292b483b..970bd2d58d5 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -48,6 +48,19 @@ struct Layers {
     return out;
   }
 
+  VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("conv2d_transpose");
+    op->SetInput("Input", {input->Name()});
+    op->SetInput("Filter", {filter->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    op->SetOutput("Out", {out->Name()});
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+    return out;
+  }
+
   VarDesc* depthwise_conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
                             bool use_cudnn) {
     VarDesc* out = lod_tensor(unique_name());
@@ -162,6 +175,33 @@ struct Layers {
     return outs;
   }
 
+  std::vector<VarDesc*> batch_norm(VarDesc* x, VarDesc* scale, VarDesc* bias,
+                                   VarDesc* mean, VarDesc* variance) {
+    VarDesc* y = lod_tensor(unique_name());
+    VarDesc* mean_out = lod_tensor(unique_name());
+    VarDesc* variance_out = lod_tensor(unique_name());
+    VarDesc* saved_mean = lod_tensor(unique_name());
+    VarDesc* saved_variance = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("batch_norm");
+    op->SetInput("X", {x->Name()});
+    op->SetInput("Scale", {scale->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    op->SetInput("Mean", {mean->Name()});
+    op->SetInput("Variance", {variance->Name()});
+    op->SetOutput("Y", {y->Name()});
+    op->SetOutput("MeanOut", {mean_out->Name()});
+    op->SetOutput("VarianceOut", {variance_out->Name()});
+    op->SetOutput("SavedMean", {saved_mean->Name()});
+    op->SetOutput("SavedVariance", {saved_variance->Name()});
+    op->SetAttr("epsilon", static_cast<float>(1e-5));
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+    std::vector<VarDesc*> outs = {y, mean_out, variance_out, saved_mean,
+                                  saved_variance};
+    return outs;
+  }
+
  private:
   VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {},
                       bool is_persistable = false) {
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index e81a842814a..e436367872b 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -155,17 +155,19 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
                    "seqpool_concat_fuse_pass",      //
                    "seqpool_cvm_concat_fuse_pass",  //
                    // "embedding_fc_lstm_fuse_pass", //
-                   "fc_lstm_fuse_pass",             //
-                   "mul_lstm_fuse_pass",            //
-                   "fc_gru_fuse_pass",              //
-                   "mul_gru_fuse_pass",             //
-                   "seq_concat_fc_fuse_pass",       //
-                   "fc_fuse_pass",                  //
-                   "repeated_fc_relu_fuse_pass",    //
-                   "squared_mat_sub_fuse_pass",     //
-                   "conv_bn_fuse_pass",             //
-                   "conv_eltwiseadd_bn_fuse_pass",  //
-                   "is_test_pass",                  //
+                   "fc_lstm_fuse_pass",                       //
+                   "mul_lstm_fuse_pass",                      //
+                   "fc_gru_fuse_pass",                        //
+                   "mul_gru_fuse_pass",                       //
+                   "seq_concat_fc_fuse_pass",                 //
+                   "fc_fuse_pass",                            //
+                   "repeated_fc_relu_fuse_pass",              //
+                   "squared_mat_sub_fuse_pass",               //
+                   "conv_bn_fuse_pass",                       //
+                   "conv_eltwiseadd_bn_fuse_pass",            //
+                   "conv_transpose_bn_fuse_pass",             //
+                   "conv_transpose_eltwiseadd_bn_fuse_pass",  //
+                   "is_test_pass",                            //
                    // following pass should be located in the last, since
                    // it will work on all fused ops.
"runtime_context_cache_pass"}); @@ -185,7 +187,9 @@ void CpuPassStrategy::EnableMKLDNN() { "depthwise_conv_mkldnn_pass", // "conv_bn_fuse_pass", // Execute BN passes again to "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order - "conv_bias_mkldnn_fuse_pass", // + "conv_transpose_bn_fuse_pass", // + "conv_transpose_eltwiseadd_bn_fuse_pass", // + "conv_bias_mkldnn_fuse_pass", // "conv_transpose_bias_mkldnn_fuse_pass", "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", -- GitLab