From fdc06f21587b6d65196a65ff9edacc09442296fb Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 27 Oct 2020 11:21:33 +0800 Subject: [PATCH] add Fuse bn add act pass (#28196) * add fuse_bn_add_act pass --- paddle/fluid/framework/details/CMakeLists.txt | 2 +- .../fluid/framework/details/build_strategy.cc | 8 + .../fluid/framework/details/build_strategy.h | 1 + paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/fuse_bn_add_act_pass.cc | 365 ++++++++++++++++++ .../fluid/framework/ir/fuse_bn_add_act_pass.h | 75 ++++ .../framework/ir/graph_pattern_detector.cc | 186 +++++++++ .../framework/ir/graph_pattern_detector.h | 72 ++++ .../fused/fused_bn_add_activation_op.cc | 2 - .../fused/fused_bn_add_activation_op.cu | 1 - .../fused/fused_bn_add_activation_op.h | 1 - paddle/fluid/pybind/pybind.cc | 25 ++ .../fluid/tests/unittests/CMakeLists.txt | 2 + ...dd_act.py => test_fuse_bn_add_act_pass.py} | 46 ++- 14 files changed, 771 insertions(+), 16 deletions(-) create mode 100644 paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc create mode 100644 paddle/fluid/framework/ir/fuse_bn_add_act_pass.h rename python/paddle/fluid/tests/unittests/{test_fused_bn_add_act.py => test_fuse_bn_add_act_pass.py} (85%) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 8281ec21438..29db49a47cf 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -107,7 +107,7 @@ cc_test(exception_holder_test SRCS exception_holder_test.cc ) set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass - fuse_elewise_add_act_pass fuse_bn_act_pass + fuse_elewise_add_act_pass fuse_bn_act_pass fuse_bn_add_act_pass multi_batch_merge_pass fuse_relu_depthwise_conv_pass lock_free_optimize_pass diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 962f968c84e..678946fbc51 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -164,6 +164,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_, "fuse_relu_depthwise_conv_pass"); AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass"); + AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass"); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass"); #else @@ -390,6 +391,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, "GPU, skipped."; continue; } + } else if (pass->Type() == "fuse_bn_add_act_pass") { + if (!use_cuda) { + LOG(WARNING) << "fuse_bn_add_act_pass is only supported on " + "GPU, skipped."; + continue; + } } else if (pass->Type() == "mkldnn_placement_pass") { pass->Set("mkldnn_enabled_op_types", new std::unordered_set(mkldnn_enabled_op_types_)); @@ -416,6 +423,7 @@ USE_PASS(sync_batch_norm_pass); USE_PASS(fuse_relu_depthwise_conv_pass); USE_PASS(fuse_elewise_add_act_pass); USE_PASS(fuse_bn_act_pass); +USE_PASS(fuse_bn_add_act_pass); USE_PASS(graph_viz_pass); USE_PASS(multi_batch_merge_pass); USE_PASS(reduce_mode_multi_devices_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 87b27eaa440..bc275cb8f3b 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -100,6 +100,7 @@ struct BuildStrategy { // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have // cycle. bool fuse_bn_act_ops_{false}; + bool fuse_bn_add_act_ops_{true}; bool fuse_elewise_add_act_ops_{false}; bool enable_auto_fusion_{false}; // Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 9415fe6e61e..f9ab60c5c74 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -114,6 +114,7 @@ if(WITH_MKLDNN) endif() cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector ) +cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector ) diff --git a/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc new file mode 100644 index 00000000000..774f655c7bb --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc @@ -0,0 +1,365 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fuse_bn_add_act_pass.h" +#include +#include +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace ir { + +void FuseBatchNormAddActPass::ApplyImpl(ir::Graph *graph) const { +#ifdef PADDLE_WITH_CUDA +#if CUDNN_VERSION_MIN(7, 4, 1) + // forward + std::unordered_set act_types = {"relu"}; + graph = FuseBatchNormAddAct(graph, act_types); + // backward + std::unordered_set act_grad_types = {"relu_grad"}; + graph = FuseBatchNormAddActGrad(graph, act_grad_types); +#endif +#endif +} + +// act(bn(x) + z) +ir::Graph *FuseBatchNormAddActPass::FuseBatchNormAddAct( + ir::Graph *graph, const std::unordered_set &act_types) const { + PADDLE_ENFORCE_NE( + graph, nullptr, + platform::errors::InvalidArgument( + "The input graph of FuseBatchNormAddAct should not be nullptr.")); + FusePassBase::Init("bn_add_act", graph); + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode("bn_add_act/x") + ->AsInput() + ->assert_is_op_input("batch_norm", "X") + ->assert_var_dtype(proto::VarType::FP16); + patterns::BatchNormAddAct bn_add_act_pattern(gpd.mutable_pattern(), + "bn_add_act"); + + bn_add_act_pattern(x, act_types); + + int found_bn_add_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle FuseBatchNormAddAct fuse"; + // BN inputs + GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_add_act_pattern); + // BN outputs + GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, + bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, + bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space, + bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_add_act_pattern); + // Add outputs + GET_IR_NODE_FROM_SUBGRAPH(elewise_add_in, elewise_add_in, + bn_add_act_pattern); + // Add outputs + GET_IR_NODE_FROM_SUBGRAPH(elewise_add_out, elewise_add_out, + bn_add_act_pattern); + // ACT output + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_add_act_pattern); + // ops + GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elewise_add, elewise_add, bn_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_add_act_pattern); + + std::string bn_x_n = subgraph.at(x)->Name(); + std::string elewise_add_in_n = elewise_add_in->Name(); + std::string bn_scale_n = bn_scale->Name(); + std::string bn_bias_n = bn_bias->Name(); + std::string bn_mean_out_n = bn_mean_out->Name(); + std::string bn_variance_out_n = bn_variance_out->Name(); + std::string bn_saved_variance_n = bn_saved_variance->Name(); + std::string bn_saved_mean_n = bn_saved_mean->Name(); + std::string bn_reserve_space_n = bn_reserve_space->Name(); + std::string bn_out_n = bn_out->Name(); + std::string elewise_add_out_n = elewise_add_out->Name(); + std::string act_out_n = act_out->Name(); + + Node *fused_bn_add_act_node = CreateFusedBatchNormAddActNode( + g, act, elewise_add, batch_norm, bn_x_n, elewise_add_in_n, bn_scale_n, + bn_bias_n, bn_mean_out_n, bn_variance_out_n, bn_saved_variance_n, + bn_saved_mean_n, bn_reserve_space_n, act_out_n); + + VLOG(4) << "\n\t " << bn_x_n << ", " << bn_scale_n << ", " << bn_bias_n + << " -> " << batch_norm->Name() << " -> " << bn_mean_out_n << ", " + << bn_variance_out_n << ", " << bn_saved_variance_n << ", " + << bn_saved_mean_n << ", " << bn_reserve_space_n << " and " + << bn_out_n << "\n" + << "\t " << bn_out_n << " and " << elewise_add_in_n << " -> " + << elewise_add->Name() << " -> " << elewise_add_out_n << "\n" + << "\t " << elewise_add_out_n << " -> " << act->Name() << " -> " + << act_out_n; + + ReLinkNodes(g, batch_norm, elewise_add, act, fused_bn_add_act_node); + found_bn_add_act_count++; + }; + + gpd(graph, handler); + + AddStatis(found_bn_add_act_count); + return graph; +} + +Node *FuseBatchNormAddActPass::CreateFusedBatchNormAddActNode( + Graph *g, const Node *act, const Node *elewise_add, const Node *bn, + const std::string &bn_x_n, const std::string &elewise_add_in_n, + const std::string &bn_scale_n, const std::string &bn_bias_n, + const std::string &bn_mean_out_n, const std::string &bn_variance_out_n, + const std::string &bn_saved_variance_n, const std::string &bn_saved_mean_n, + const std::string &bn_reserve_space_n, const std::string &act_out_n) const { + OpDesc desc; + desc.SetInput("X", std::vector({bn_x_n})); + desc.SetInput("Z", std::vector({elewise_add_in_n})); + desc.SetInput("Scale", std::vector({bn_scale_n})); + desc.SetInput("Bias", std::vector({bn_bias_n})); + + desc.SetOutput("Y", std::vector({act_out_n})); + desc.SetOutput("MeanOut", std::vector({bn_mean_out_n})); + desc.SetOutput("VarianceOut", std::vector({bn_variance_out_n})); + desc.SetOutput("SavedMean", std::vector({bn_saved_mean_n})); + desc.SetOutput("SavedVariance", + std::vector({bn_saved_variance_n})); + desc.SetOutput("ReserveSpace", + std::vector({bn_reserve_space_n})); + desc.SetType("fused_bn_add_activation"); + + desc.SetAttr("act_type", act->Name()); + // Set attrs + for (auto &n : {act->Op(), elewise_add->Op(), bn->Op()}) { + for (auto &m : n->GetAttrMap()) { + desc.SetAttr(m.first, m.second); + } + } + + auto fused_bn_add_act_node = g->CreateOpNode(&desc); + return fused_bn_add_act_node; +} + +// the backward of act(bn(x) + z) +ir::Graph *FuseBatchNormAddActPass::FuseBatchNormAddActGrad( + ir::Graph *graph, + const std::unordered_set &act_grad_types) const { + PADDLE_ENFORCE_NE( + graph, nullptr, + platform::errors::InvalidArgument( + "The input graph of FuseBatchNormAddActGrad should not be nullptr.")); + FusePassBase::Init("bn_add_act_grad", graph); + + GraphPatternDetector gpd; + auto *d_act_out = + gpd.mutable_pattern() + ->NewNode("bn_add_act_grad/x") + ->AsInput() + ->assert_is_ops_input(act_grad_types, GradVarName("Out")) + ->assert_var_dtype(proto::VarType::FP16); + patterns::BatchNormAddActGrad bn_add_act_grad_pattern(gpd.mutable_pattern(), + "bn_add_act_grad"); + bn_add_act_grad_pattern(d_act_out, act_grad_types); + + int found_bn_add_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle FuseBatchNormAddActGrad fuse"; + GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elewise_add_grad, elewise_add_grad, + bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(batch_norm_grad, batch_norm_grad, + bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_act_x, d_act_x, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_bn_out, d_bn_out, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_x, bn_x, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, + bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, + bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space, + bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_bn_x, d_bn_x, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_bn_scale, d_bn_scale, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_bn_bias, d_bn_bias, bn_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_elewise_add_in, d_elewise_add_in, + bn_add_act_grad_pattern); + + std::string d_act_out_n = subgraph.at(d_act_out)->Name(); // Y@GRAD + std::string act_out_n = act_out->Name(); // Y + std::string d_act_x_n = d_act_x->Name(); + std::string bn_x_n = bn_x->Name(); + std::string bn_scale_n = bn_scale->Name(); + std::string bn_bias_n = bn_bias->Name(); + std::string bn_saved_mean_n = bn_saved_mean->Name(); + std::string bn_saved_variance_n = bn_saved_variance->Name(); + std::string bn_reserve_space_n = bn_reserve_space->Name(); + std::string d_bn_out_n = d_bn_out->Name(); + std::string d_bn_x_n = d_bn_x->Name(); + std::string d_bn_scale_n = d_bn_scale->Name(); + std::string d_bn_bias_n = d_bn_bias->Name(); + std::string d_elewise_add_in_n = d_elewise_add_in->Name(); + + OpDesc desc; + desc.SetType("fused_bn_add_activation_grad"); + desc.SetInput("X", {bn_x_n}); + desc.SetInput("Y", std::vector({act_out_n})); + desc.SetInput(GradVarName("Y"), std::vector({d_act_out_n})); + desc.SetInput("Scale", std::vector({bn_scale_n})); + desc.SetInput("Bias", std::vector({bn_bias_n})); + desc.SetInput("SavedMean", std::vector({bn_saved_mean_n})); + desc.SetInput("SavedVariance", + std::vector({bn_saved_variance_n})); + desc.SetInput("ReserveSpace", + std::vector({bn_reserve_space_n})); + desc.SetOutput(GradVarName("X"), std::vector({d_bn_x_n})); + desc.SetOutput(GradVarName("Z"), + std::vector({d_elewise_add_in_n})); + desc.SetOutput(GradVarName("Scale"), + std::vector({d_bn_scale_n})); + desc.SetOutput(GradVarName("Bias"), + std::vector({d_bn_bias_n})); + std::string act = act_grad->Name(); + act = act.substr(0, act.length() - 5); // remove "_grad" + desc.SetAttr("act_type", act); + + for (auto &n : + {act_grad->Op(), elewise_add_grad->Op(), batch_norm_grad->Op()}) { + for (auto &m : n->GetAttrMap()) { + desc.SetAttr(m.first, m.second); + } + } + + auto fused_node = g->CreateOpNode(&desc); + + VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> " + << act_grad->Name() << " -> " << d_act_x_n << "\n\t "; + VLOG(4) << d_act_x_n << " -> " << elewise_add_grad->Name() << " -> " + << d_elewise_add_in_n << "," << d_bn_out_n << "\n\t "; + VLOG(4) << bn_x_n << ", " << d_bn_out_n << ", " << bn_scale_n << ", " + << bn_bias_n << ", " << bn_saved_mean_n << ", " + << bn_saved_variance_n << " and " << bn_reserve_space_n << " -> " + << batch_norm_grad->Name() << " -> " << d_bn_x_n << ", " + << d_bn_scale_n << " and " << d_bn_bias_n; + + ReLinkNodes(g, act_grad, elewise_add_grad, batch_norm_grad, fused_node); + found_bn_add_act_count++; + }; + + gpd(graph, handler); + + AddStatis(found_bn_add_act_count); + return graph; +} + +void FuseBatchNormAddActPass::ReLinkNodes(Graph *graph, Node *op_1, Node *op_2, + Node *op_3, + Node *fused_op) const { // delete act + // link inputs of op_1 to fused_op + for (auto &in : op_1->inputs) { + fused_op->inputs.emplace_back(in); + in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs); + } + + std::unordered_set nodes2delete; + + LinkOutputsToFuseOp(op_1, op_2, fused_op, &nodes2delete); + LinkOutputsToFuseOp(op_2, op_3, fused_op, &nodes2delete); + LinkInputsToFuseOp(op_2, fused_op, &nodes2delete); + LinkInputsToFuseOp(op_3, fused_op, &nodes2delete); + + for (auto &out : op_3->outputs) { + IR_OP_VAR_LINK(fused_op, out); + } + + nodes2delete.insert(std::move(op_1)); + nodes2delete.insert(std::move(op_2)); + nodes2delete.insert(std::move(op_3)); + + GraphSafeRemoveNodes(graph, nodes2delete); +} + +void FuseBatchNormAddActPass::LinkOutputsToFuseOp( + Node *op_1, Node *op_2, Node *fused_op, + std::unordered_set *nodes2delete) const { + // if the outputs of op_1 are inputs of op_2, add the outputs to nodes2delete + // otherwise link the outputs to fused_op + for (auto &out : op_1->outputs) { + auto result_iter = + std::find_if(op_2->inputs.begin(), op_2->inputs.end(), + [&out](const Node *node) -> bool { return node == out; }); + + if (result_iter == op_2->inputs.end()) { + IR_OP_VAR_LINK(fused_op, out); + } else { + nodes2delete->emplace(out); + } + } +} + +void FuseBatchNormAddActPass::LinkInputsToFuseOp( + Node *op, Node *fused_op, + std::unordered_set *nodes2delete) const { + // if the inputs of the op are outputs of previous op, which means + // these inputs have been added to nodes2delete before, skip the inputs, + // otherwise link the inputs of the op to fused_op + for (auto &in : op->inputs) { + if (nodes2delete->count(in)) { + continue; + } + fused_op->inputs.emplace_back(in); + in->outputs = this->ReplaceNode(op, fused_op, in->outputs); + } +} + +std::vector FuseBatchNormAddActPass::ReplaceNode( + Node *cur_node, Node *new_node, const std::vector &nodes) const { + std::vector new_list(nodes.size()); + bool has_replaced = false; + std::transform(nodes.begin(), nodes.end(), new_list.begin(), + [&](Node *node) -> Node * { + if (node == cur_node) { + has_replaced = true; + return new_node; + } + return node; + }); + PADDLE_ENFORCE_EQ(has_replaced, true, + platform::errors::NotFound("Not found %s in the node list.", + cur_node->Name())); + return new_list; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fuse_bn_add_act_pass, + paddle::framework::ir::FuseBatchNormAddActPass); diff --git a/paddle/fluid/framework/ir/fuse_bn_add_act_pass.h b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.h new file mode 100644 index 00000000000..243a5b1b8df --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.h @@ -0,0 +1,75 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the BatchNorm, add and activation. + */ +class Graph; +class Node; + +class FuseBatchNormAddActPass : public FusePassBase { + public: + virtual ~FuseBatchNormAddActPass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + + ir::Graph *FuseBatchNormAddAct( + ir::Graph *graph, const std::unordered_set &act_types) const; + + ir::Graph *FuseBatchNormAddActGrad( + ir::Graph *graph, + const std::unordered_set &act_grad_types) const; + + void LinkOutputsToFuseOp( + Node *op_1, Node *op_2, Node *fused_op, + std::unordered_set *nodes2delete) const; + + void LinkInputsToFuseOp(Node *op, Node *fused_op, + std::unordered_set *nodes2delete) const; + + std::vector ReplaceNode(Node *cur_node, Node *new_node, + const std::vector &nodes) const; + + void ReLinkNodes(Graph *graph, Node *op_1, Node *op_2, Node *op_3, + Node *fused_op) const; + Node *CreateFusedBatchNormAddActNode( + Graph *g, const Node *act, const Node *add, const Node *bn, + const std::string &bn_x_n, const std::string &add_y_n, + const std::string &bn_scale_n, const std::string &bn_bias_n, + const std::string &bn_mean_out_n, const std::string &bn_variance_out_n, + const std::string &bn_saved_variance_n, + const std::string &bn_saved_mean_n, const std::string &bn_reserve_space_n, + const std::string &act_out_n) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 3127a3fd8a7..5ffaf28fe92 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -93,6 +93,7 @@ void GraphPatternDetector::operator()(Graph *graph, auto subgraphs = DetectPatterns(); UniquePatterns(&subgraphs); + SortSubgraphs(&subgraphs); RemoveOverlappedMatch(&subgraphs); ValidateByNodeRole(&subgraphs); @@ -302,6 +303,46 @@ void GraphPatternDetector::UniquePatterns( *subgraphs = result; } +void GraphPatternDetector::SortSubgraphs( + std::vector *subgraphs) { + if (subgraphs->empty()) return; + bool has_bn_add_act = false; + for (auto &subgraph : *subgraphs) { + for (auto &item : subgraph) { + if (item.first->name().find("bn_add_act") != std::string::npos) { + has_bn_add_act = true; + break; + } + } + } + if (!has_bn_add_act) { + return; + } + + std::sort( + subgraphs->begin(), subgraphs->end(), + [](const GraphPatternDetector::subgraph_t &a, + const GraphPatternDetector::subgraph_t &b) { + for (auto &item : a) { + if (item.first->name().find("bn_add_act") != std::string::npos && + item.first->name().find("bn_reserve_space") != + std::string::npos) { + auto it_b = b.find(item.first); + if (it_b != b.end()) { + if (item.second->Name() != it_b->second->Name()) { + return item.second->Name() < it_b->second->Name(); + } else { + return false; + } + } else { + return false; + } + } + } + return false; + }); +} + void GraphPatternDetector::RemoveOverlappedMatch( std::vector *subgraphs) { std::vector result; @@ -1208,6 +1249,151 @@ PDNode *patterns::BatchNormActOneDNN::operator()(const std::string &act_type) { return act_out; } +PDNode *patterns::BatchNormAddAct::operator()( + paddle::framework::ir::PDNode *bn_x_var, + std::unordered_set act_types) { + bn_x_var->assert_is_op_input("batch_norm", "X") + ->assert_var_dtype(proto::VarType::FP16); + auto *bn_scale_var = pattern->NewNode(bn_scale_repr()) + ->assert_is_op_input("batch_norm", "Scale"); + auto *bn_bias_var = pattern->NewNode(bn_bias_repr()) + ->assert_is_op_input("batch_norm", "Bias"); + + auto *bn = pattern->NewNode(batch_norm_repr()) + ->assert_is_op("batch_norm") + ->assert_is_not_op_input("MomentumTensor") + ->assert_op_attr("is_test", false) + ->assert_op_attr("use_global_stats", false) + ->assert_op_attr("data_layout", "NHWC"); + + auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) + ->assert_is_op_output("batch_norm", "MeanOut"); + auto *bn_variance_out_var = + pattern->NewNode(bn_variance_out_repr()) + ->assert_is_op_output("batch_norm", "VarianceOut"); + auto *bn_saved_variance_var = + pattern->NewNode(bn_saved_variance_repr()) + ->assert_is_op_output("batch_norm", "SavedVariance"); + auto *bn_saved_mean_var = + pattern->NewNode(bn_saved_mean_repr()) + ->assert_is_op_output("batch_norm", "SavedMean"); + auto *bn_reserve_space = + pattern->NewNode(bn_reserve_space_repr()) + ->assert_is_op_output("batch_norm", "ReserveSpace"); + auto *bn_out_var = pattern->NewNode(bn_out_repr()) + ->assert_is_op_output("batch_norm", "Y") + ->assert_var_dtype(proto::VarType::FP16); + + bn_out_var->assert_is_op_input("elementwise_add"); + + auto *elewise_add = + pattern->NewNode(elewise_add_repr())->assert_is_op("elementwise_add"); + + auto *elewise_add_in_var = pattern->NewNode(elewise_add_in_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_input("elementwise_add") + ->assert_var_dtype(proto::VarType::FP16); + + auto *elewise_add_out_var = + pattern->NewNode(elewise_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_has_n_outputs(1); + + elewise_add_out_var->AsIntermediate()->assert_is_ops_input(act_types); + + auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types); + + auto *act_out_var = + pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out"); + + bn->LinksFrom({bn_x_var, bn_scale_var, bn_bias_var}) + .LinksTo({bn_mean_out_var, bn_variance_out_var, bn_saved_variance_var, + bn_saved_mean_var, bn_reserve_space, bn_out_var}); + elewise_add->LinksFrom({elewise_add_in_var, bn_out_var}) + .LinksTo({elewise_add_out_var}); + act->LinksFrom({elewise_add_out_var}).LinksTo({act_out_var}); + + return act_out_var; +} + +PDNode *patterns::BatchNormAddActGrad::operator()( + paddle::framework::ir::PDNode *d_act_out_var, + std::unordered_set act_grad_types) { + auto *act_grad = + pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types); + auto *elewise_add_grad = pattern->NewNode(elewise_add_grad_repr()) + ->assert_is_op("elementwise_add_grad"); + auto *bn_grad = pattern->NewNode(batch_norm_grad_repr()) + ->assert_is_op("batch_norm_grad") + ->assert_op_attr("use_global_stats", false) + ->assert_op_attr("data_layout", "NHWC"); + + auto *act_out_var = pattern->NewNode(act_out_repr()) + ->assert_is_ops_input(act_grad_types, "Out"); + auto *d_act_x_var = + pattern->NewNode(d_act_x_repr()) + ->assert_is_ops_output(act_grad_types, GradVarName("X")) + ->assert_has_n_outputs(1); // d_act_x + + d_act_x_var->AsIntermediate()->assert_is_op_input("elementwise_add_grad"); + + auto *d_elewise_add_in_var = + pattern->NewNode(d_elewise_add_in_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("elementwise_add_grad") + ->assert_var_dtype(proto::VarType::FP16); // d_add_in_1 + auto *d_bn_out_var = + pattern->NewNode(d_bn_out_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("elementwise_add_grad") + ->assert_var_dtype(proto::VarType::FP16); // d_add_in_2 + + d_bn_out_var->assert_is_op_input("batch_norm_grad", GradVarName("Y")); + + auto *bn_x_var = pattern->NewNode(bn_x_repr()) + ->assert_is_op_input("batch_norm_grad", "X") + ->assert_var_dtype(proto::VarType::FP16); + auto *bn_scale_var = pattern->NewNode(bn_scale_repr()) + ->assert_is_op_input("batch_norm_grad", "Scale"); + auto *bn_bias_var = pattern->NewNode(bn_bias_repr()) + ->assert_is_op_input("batch_norm_grad", "Bias"); + auto *bn_saved_mean_var = + pattern->NewNode(bn_saved_mean_repr()) + ->assert_is_op_input("batch_norm_grad", "SavedMean"); + auto *bn_saved_variance_var = + pattern->NewNode(bn_saved_variance_repr()) + ->assert_is_op_input("batch_norm_grad", "SavedVariance"); + + auto *bn_reserve_space = + pattern->NewNode(bn_reserve_space_repr()) + ->assert_is_op_input("batch_norm_grad", "ReserveSpace"); + auto *d_bn_x_var = + pattern->NewNode(d_bn_x_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("batch_norm_grad", GradVarName("X")) + ->assert_var_dtype(proto::VarType::FP16); + auto *d_bn_scale_var = + pattern->NewNode(d_bn_scale_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("batch_norm_grad", GradVarName("Scale")); + auto *d_bn_bias_var = + pattern->NewNode(d_bn_bias_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("batch_norm_grad", GradVarName("Bias")); + + act_grad->LinksFrom({d_act_out_var, act_out_var}).LinksTo({d_act_x_var}); + + elewise_add_grad->LinksFrom({d_act_x_var}) + .LinksTo({d_elewise_add_in_var, d_bn_out_var}); + + bn_grad + ->LinksFrom({bn_x_var, d_bn_out_var, bn_scale_var, bn_bias_var, + bn_saved_mean_var, bn_saved_variance_var, bn_reserve_space}) + .LinksTo({d_bn_x_var, d_bn_scale_var, d_bn_bias_var}); + + return bn_grad; +} + PDNode *patterns::ElewiseAddAct::operator()( paddle::framework::ir::PDNode *ele_x_var, std::unordered_set act_types) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index c44c7b4059e..77a1b034074 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -294,6 +294,12 @@ class GraphPatternDetector { // Remove duplicate patterns. void UniquePatterns(std::vector* subgraphs); + // Sort subgraphs, sort subgraphs by the specified node so that + // the removed forward and backward subgraphs are corresponding + // when two subgraphs are overlapped. Note: this function is + // currently only used for bn_add_act, refer to PR28196 for details. + void SortSubgraphs(std::vector* subgraphs); + // Remove overlapped match subgraphs, when overlapped, keep the previous one. // The intermediate PDNodes will be removed, so can't shared by multiple // patterns. @@ -685,6 +691,72 @@ struct BatchNormActOneDNN : public PatternBase { PATTERN_DECL_NODE(act_out); }; +// The following pattern is used to fuse batch_norm, elewise_add, and act +// formula: act(bn(x) + z) +// op: batch_norm + elewise_add + act +struct BatchNormAddAct : public PatternBase { + BatchNormAddAct(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "bn_add_act") {} + + PDNode* operator()(PDNode* x, std::unordered_set acts); + + // declare operator node's name + PATTERN_DECL_NODE(batch_norm); + PATTERN_DECL_NODE(elewise_add); + PATTERN_DECL_NODE(act); + // declare variable node's name + // BN inputs + PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_bias); + // BN outputs + PATTERN_DECL_NODE(bn_mean_out); + PATTERN_DECL_NODE(bn_variance_out); + PATTERN_DECL_NODE(bn_saved_variance); + PATTERN_DECL_NODE(bn_saved_mean); + PATTERN_DECL_NODE(bn_reserve_space); + PATTERN_DECL_NODE(bn_out); + // Elewise_Add input + PATTERN_DECL_NODE(elewise_add_in); + // Elewise_Add output + PATTERN_DECL_NODE(elewise_add_out); + // ACT output + PATTERN_DECL_NODE(act_out); +}; + +// the backward of act(bn(x) + z) +// op: batch_norm_grad + elewise_add_grad + act_grad +struct BatchNormAddActGrad : public PatternBase { + BatchNormAddActGrad(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "bn_add_act_grad") {} + + // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"] + // elewise_add_grad: in["Out@GRAD"], out["X@GRAD", "Y@GRAD"] + // bn_grad: in["X", "Z", "Y@GRAD", "Scale", "Bias", "SavedMean", + // "SavedVariance", + // "ReserveSpace"], + // out["X@GRAD", "Z@GRAD", "Scale@GRAD", "Bias@GRAD"] + PDNode* operator()(PDNode* x, std::unordered_set act_grad_types); + + // declare operator node's name + PATTERN_DECL_NODE(act_grad); + PATTERN_DECL_NODE(elewise_add_grad); + PATTERN_DECL_NODE(batch_norm_grad); + // declare variable node's name + PATTERN_DECL_NODE(act_out); + PATTERN_DECL_NODE(d_act_x); + PATTERN_DECL_NODE(d_elewise_add_in); + PATTERN_DECL_NODE(d_bn_out); + PATTERN_DECL_NODE(bn_x); + PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_bias); + PATTERN_DECL_NODE(bn_saved_mean); + PATTERN_DECL_NODE(bn_saved_variance); + PATTERN_DECL_NODE(bn_reserve_space); + PATTERN_DECL_NODE(d_bn_x); + PATTERN_DECL_NODE(d_bn_scale); + PATTERN_DECL_NODE(d_bn_bias); +}; + // The following patterns are used to fuse elewise_add and act // formula: act(ele_add(x, y)) // op: elementwise_add + act diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc index 5b3ed03bb64..9f446b48b47 100644 --- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cc @@ -186,8 +186,6 @@ void FusedBatchNormAddActGradOp::InferShape( // check input OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedBatchNormAddActGradOp"); - OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", - "FusedBatchNormAddActGradOp"); OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "FusedBatchNormAddActGradOp"); OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean", diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu index 7f1d297cda3..c92b13b5f58 100644 --- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu +++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.cu @@ -188,7 +188,6 @@ class FusedBatchNormAddActGradKernel std::string act_type = ctx.Attr("act_type"); const auto *x = ctx.Input("X"); - const auto *z = ctx.Input("Z"); const auto *y = ctx.Input("Y"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); const auto *scale = ctx.Input("Scale"); diff --git a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h index 5c7df96e60d..d5e5ae9bda6 100644 --- a/paddle/fluid/operators/fused/fused_bn_add_activation_op.h +++ b/paddle/fluid/operators/fused/fused_bn_add_activation_op.h @@ -61,7 +61,6 @@ class FusedBatchNormAddActGradOpMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType(this->ForwardOpType() + "_grad"); op->SetInput("X", this->Input("X")); - op->SetInput("Z", this->Input("Z")); op->SetInput("Y", this->Output("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8ff7e900653..736669fa4ef 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -2500,6 +2500,31 @@ All parameter, weight, gradient are variables in Paddle. build_strategy = static.BuildStrategy() build_strategy.fuse_bn_act_ops = True )DOC") + .def_property( + "fuse_bn_add_act_ops", + [](const BuildStrategy &self) { return self.fuse_bn_add_act_ops_; }, + [](BuildStrategy &self, bool b) { + PADDLE_ENFORCE_NE(self.IsFinalized(), true, + platform::errors::PreconditionNotMet( + "BuildStrategy has been finlaized, cannot be " + "configured again.")); + self.fuse_bn_add_act_ops_ = b; + }, + R"DOC((bool, optional): fuse_bn_add_act_ops indicate whether + to fuse batch_norm, elementwise_add and activation_op, + it may make the execution faster. Default is True + + Examples: + .. code-block:: python + + import paddle + import paddle.static as static + + paddle.enable_static() + + build_strategy = static.BuildStrategy() + build_strategy.fuse_bn_add_act_ops = True + )DOC") .def_property( "enable_auto_fusion", [](const BuildStrategy &self) { return self.enable_auto_fusion_; }, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 101242808b2..4cd9d9e530d 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -331,6 +331,7 @@ list(REMOVE_ITEM TEST_OPS test_basic_gru_unit_op) list(REMOVE_ITEM TEST_OPS test_basic_lstm_api) list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) list(REMOVE_ITEM TEST_OPS test_fuse_bn_act_pass) +list(REMOVE_ITEM TEST_OPS test_fuse_bn_add_act_pass) list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_mnist) list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while) list(REMOVE_ITEM TEST_OPS test_conv3d_transpose_op) @@ -515,6 +516,7 @@ py_test_modules(test_parallel_executor_transformer_auto_growth MODULES test_para py_test_modules(test_data_norm_op MODULES test_data_norm_op) py_test_modules(test_fuse_bn_act_pass MODULES test_fuse_bn_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000) +py_test_modules(test_fuse_bn_add_act_pass MODULES test_fuse_bn_add_act_pass ENVS FLAGS_cudnn_deterministic=1 FLAGS_cudnn_batchnorm_spatial_persistent=1 FLAGS_conv_workspace_size_limit=1000) # NOTE: These unittests will appear NaN steadily in windows CI. After analysis, # it is found that windows CI will run all the training unittests with the ON_INFER option turned on, diff --git a/python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py b/python/paddle/fluid/tests/unittests/test_fuse_bn_add_act_pass.py similarity index 85% rename from python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py rename to python/paddle/fluid/tests/unittests/test_fuse_bn_add_act_pass.py index 45c27552743..316c40971aa 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py +++ b/python/paddle/fluid/tests/unittests/test_fuse_bn_add_act_pass.py @@ -21,6 +21,8 @@ import paddle import paddle.fluid as fluid from paddle.fluid import core +paddle.enable_static() + @unittest.skipIf(not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA") @@ -163,12 +165,16 @@ class TestFusedBnAddActAPI(unittest.TestCase): iters = 5 batch_size = 16 - # build_fused_program + # build_fused_program: turn on fuse_bn_add_act_ops main_program = fluid.Program() startup_program = fluid.Program() - x, y, loss = self.build_fused_program(main_program, startup_program, - use_cuda) + x, y, loss = self.build_origin_program(main_program, startup_program, + use_cuda) feeder = fluid.DataFeeder(feed_list=[x, y], place=place) + build_strategy_fused = fluid.BuildStrategy() + build_strategy_fused.fuse_bn_add_act_ops = True + binary_fused = fluid.CompiledProgram(main_program).with_data_parallel( + loss_name=loss.name, build_strategy=build_strategy_fused) train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=batch_size) exe = fluid.Executor(place) @@ -178,17 +184,16 @@ class TestFusedBnAddActAPI(unittest.TestCase): exe.run(startup_program) for _ in range(iters): data = next(train_reader()) - loss_v = exe.run(main_program, + loss_v = exe.run(binary_fused, feed=feeder.feed(data), fetch_list=[loss]) loss_vals_fused.append(loss_v[0][0]) - # build_origin_program - main_program = fluid.Program() - startup_program = fluid.Program() - x, y, loss = self.build_origin_program(main_program, startup_program, - use_cuda) - feeder = fluid.DataFeeder(feed_list=[x, y], place=place) + # build_origin_program: turn off fused_bn_act_ops + build_strategy = fluid.BuildStrategy() + build_strategy.fuse_bn_add_act_ops = False + binary = fluid.CompiledProgram(main_program).with_data_parallel( + loss_name=loss.name, build_strategy=build_strategy) train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=batch_size) loss_vals = [] @@ -197,7 +202,7 @@ class TestFusedBnAddActAPI(unittest.TestCase): exe.run(startup_program) for _ in range(iters): data = next(train_reader()) - loss_v = exe.run(main_program, + loss_v = exe.run(binary, feed=feeder.feed(data), fetch_list=[loss]) loss_vals.append(loss_v[0][0]) @@ -210,6 +215,25 @@ class TestFusedBnAddActAPI(unittest.TestCase): place = fluid.CUDAPlace(0) self.check(place, use_cuda=True) + def test_fuse_bn_add_act_API(self): + # build_fused_program: use fused_bn_add_act python API + main_program = fluid.Program() + startup_program = fluid.Program() + place = fluid.CUDAPlace(0) + x, y, loss = self.build_fused_program( + main_program, startup_program, use_cuda=True) + feeder = fluid.DataFeeder(feed_list=[x, y], place=place) + train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16) + exe = fluid.Executor(place) + scope = fluid.Scope() + with fluid.scope_guard(scope): + exe.run(startup_program) + for _ in range(5): + data = next(train_reader()) + loss_v = exe.run(main_program, + feed=feeder.feed(data), + fetch_list=[loss]) + if __name__ == '__main__': unittest.main() -- GitLab