From d402234ba8c74322e461611de7ba4b77c4a3688b Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 20 Sep 2018 09:36:28 +0800 Subject: [PATCH] Feature/op_fuse_pass (#12440) * Add Preface * Add demo code * Save file * Refine code * seems can work * use elementwise strategy * Use ElementwiseComputeEx * Add comments * extract functions from operator * Refine code * Follow comment * code refine * add op_fuse pass * add backward * code refine * use TopologySortOperations * follow comments * refine IsFusible * code enhance * fix op_fusion_pass * refine code * refine fuse_elemwise_act_op * adjust the input and output * refine logic * add intermediate_edge * disable inplace * follow comments * refine logic * follow comments * Remove the removable IntermediateOut * change strategy * code refine * enable fuse backward * code refine * code refine * rename unit test * follow comments --- paddle/fluid/framework/CMakeLists.txt | 12 +- .../fluid/framework/details/build_strategy.h | 2 + paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/fuse_elewise_add_act_pass.cc | 374 ++++++++++++++ .../framework/ir/fuse_elewise_add_act_pass.h | 75 +++ .../framework/ir/graph_pattern_detector.cc | 464 +++++++++++++----- .../framework/ir/graph_pattern_detector.h | 83 ++++ paddle/fluid/framework/ir/node.h | 4 + paddle/fluid/framework/parallel_executor.cc | 16 + .../fluid/operators/elementwise_op_function.h | 358 +++++++++----- .../operators/fused_elemwise_activation_op.cc | 141 +++--- .../operators/fused_elemwise_activation_op.h | 241 +++++---- .../fluid/operators/math/compound_functors.h | 125 +++-- paddle/fluid/operators/math/functors.h | 12 +- paddle/fluid/pybind/pybind.cc | 9 +- .../paddle/fluid/tests/unittests/op_test.py | 16 +- .../unittests/parallel_executor_test_base.py | 2 + .../test_fuse_elewise_add_act_pass.py | 156 ++++++ .../test_fused_elemwise_activation_op.py | 83 ++-- .../fluid/tests/unittests/test_reshape_op.py | 2 +- .../tests/unittests/test_transpose_op.py | 2 +- 21 files changed, 1651 insertions(+), 528 deletions(-) create mode 100644 paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc create mode 100644 paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h create mode 100644 python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index d998109df..6d8cbe5d9 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -148,13 +148,13 @@ if(WITH_DISTRIBUTE) else() cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) endif() - + if (NOT WIN32) -cc_library(parallel_executor SRCS parallel_executor.cc DEPS - threaded_ssa_graph_executor scope_buffered_ssa_graph_executor - graph graph_viz_pass multi_devices_graph_pass - multi_devices_graph_print_pass multi_devices_graph_check_pass - fast_threaded_ssa_graph_executor) + cc_library(parallel_executor SRCS parallel_executor.cc DEPS + threaded_ssa_graph_executor scope_buffered_ssa_graph_executor + graph graph_viz_pass multi_devices_graph_pass + multi_devices_graph_print_pass multi_devices_graph_check_pass + fast_threaded_ssa_graph_executor fuse_elewise_add_act_pass) endif() # NOT WIN32 cc_library(prune SRCS prune.cc DEPS framework_proto) diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 8714a4216..77cafa49f 100644 --- 
a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -54,6 +54,8 @@ struct BuildStrategy { std::string debug_graphviz_path_{""}; + bool fuse_elewise_add_act_ops_{false}; + bool enable_data_balance_{false}; }; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 7004f484a..4dca3ceb4 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -37,6 +37,8 @@ pass_library(fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) +cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) + set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc new file mode 100644 index 000000000..648acc4a7 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -0,0 +1,374 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr FuseElewiseAddActPass::ApplyImpl( + std::unique_ptr graph) const { + std::unordered_set act_types = {"relu", "scale"}; + graph = FuseActElewiseAdd(std::move(graph), act_types); + graph = FuseElewiseAddAct(std::move(graph), act_types); + // backward + { + std::unordered_set in_place_act_types = {"relu_grad"}; + graph = FuseElewiseAddActInplaceGrad(std::move(graph), in_place_act_types); + } + + // Remove the removable intermediate_out. 
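+  // For illustration, a sketch of what one forward fuse does (the
+  // variable names here are made up, not taken from a real program):
+  //
+  //   tmp = elementwise_add(X, Y)        fused_elemwise_activation(X, Y)
+  //   out = relu(tmp)               =>     functor_list =
+  //                                          {"elementwise_add", "relu"}
+  //
+  // `tmp` survives only as the fused op's IntermediateOut and is dropped
+  // below when no remaining op consumes it.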
+ RemoveIntermediateOut(graph.get()); + + return graph; +} + +// ele_add(x, act(y)) +std::unique_ptr FuseElewiseAddActPass::FuseElewiseAddAct( + std::unique_ptr graph, + const std::unordered_set &act_types) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("elewise_add_act", graph.get()); + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode("elewise_add_act/x") + ->AsInput() + ->assert_is_op_input("elementwise_add", "X"); + patterns::ElewiseAddAct elewise_add_act_pattern(gpd.mutable_pattern(), + "elementwise_add"); + + elewise_add_act_pattern(x, act_types); + + int found_elewise_add_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle FuseElewiseAddAct fuse"; + GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out, + elewise_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act, act, elewise_add_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_add, ele_add, elewise_add_act_pattern); + + std::string ele_x_n = subgraph.at(x)->Name(); + std::string ele_y_n = ele_y->Name(); + std::string ele_out_n = ele_out->Name(); + std::string act_out_n = act_out->Name(); + + Node *elewise_add_act_node = CreateFuseElewiseAddActNode( + g, act, ele_add, ele_x_n, ele_y_n, ele_out_n, act_out_n); + + VLOG(4) << "\n\t " << ele_x_n << " and " << ele_y_n << " -> " + << ele_add->Name() << " -> " << ele_out_n << "\n" + << "\t " << ele_out_n << " -> " << act->Name() << " -> " + << act_out_n; + + ReLinkNodes(g, ele_out, ele_add, act, elewise_add_act_node); + found_elewise_add_act_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_elewise_add_act_count); + return graph; +} + +// act(ele_add(x,y)) +std::unique_ptr FuseElewiseAddActPass::FuseActElewiseAdd( + std::unique_ptr graph, + const std::unordered_set &act_types) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("act_elewise_add", graph.get()); + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode("act_elewise_add/x") + ->AsInput() + ->assert_is_ops_input(act_types, "X"); + patterns::ActElewiseAdd act_elewise_add_pattern(gpd.mutable_pattern(), + "act_elewise_add"); + + act_elewise_add_pattern(x, act_types); + + int found_elewise_add_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle FuseElewiseAddAct fuse"; + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out, + act_elewise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act, act, act_elewise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_add, ele_add, act_elewise_add_pattern); + + std::string act_i_n = subgraph.at(x)->Name(); + std::string act_o_n = act_out->Name(); + std::string elewise_add_x_n = ele_x->Name(); + std::string elewise_add_out_n = ele_out->Name(); + + Node *elewise_add_act_node = CreateFuseElewiseAddActNode( + g, ele_add, act, elewise_add_x_n, act_i_n, act_o_n, elewise_add_out_n); + + VLOG(4) << "\n\t " << act_i_n << " -> " << act->Name() << " -> " << act_o_n + << "\n\t " << act_o_n << " and " << elewise_add_x_n << " -> " + << ele_add->Name() << " -> " << elewise_add_out_n; + + ReLinkNodes(g, act_out, act, ele_add, elewise_add_act_node); + found_elewise_add_act_count++; + }; + + gpd(graph.get(), handler); + + 
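+  // A minimal sketch of the OpDesc each match produces (assembled in
+  // CreateFuseElewiseAddActNode, defined below; the values shown are from
+  // the elewise_add + act case and are purely illustrative):
+  //
+  //   desc.SetType("fused_elemwise_activation");
+  //   desc.SetInput("X", {ele_x_n});        // elementwise_add's X
+  //   desc.SetInput("Y", {ele_y_n});        // elementwise_add's Y
+  //   desc.SetOutput("Out", {act_out_n});   // the activation's output
+  //   desc.SetOutput("IntermediateOut", {ele_out_n});
+  //   desc.SetAttr("functor_list", std::vector<std::string>(
+  //                                    {"elementwise_add", "relu"}));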
AddStatis(found_elewise_add_act_count); + return graph; +} + +// the backward of act(ele_add(x,y)) +// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"] +// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"] +std::unique_ptr FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad( + std::unique_ptr graph, + const std::unordered_set &act_types) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("elewise_add_act_grad", graph.get()); + + GraphPatternDetector gpd; + auto *d_act_out = gpd.mutable_pattern() + ->NewNode("elewise_add_act_grad_inplace/x") + ->AsInput() + ->assert_is_ops_input(act_types, GradVarName("Out")); + patterns::ElewiseAddActInplaceGrad elewise_add_act_grad_pattern( + gpd.mutable_pattern(), "elewise_add_act_grad_inplace"); + elewise_add_act_grad_pattern(d_act_out, act_types); + + int found_elewise_add_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle FuseElewiseAddActGrad1 fuse"; + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, elewise_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, elewise_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out, + elewise_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_y, ele_y, elewise_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_add_grad, ele_add_grad, + elewise_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_ele_x, d_ele_x, elewise_add_act_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(d_ele_y, d_ele_y, elewise_add_act_grad_pattern); + + std::string d_act_out_n = subgraph.at(d_act_out)->Name(); + std::string act_out_n = act_out->Name(); + std::string d_itermediate_out_n = d_itermediate_out->Name(); + std::string ele_y_n = ele_y->Name(); + std::string d_ele_x_n = d_ele_x->Name(); + std::string d_ele_y_n = d_ele_y->Name(); + + OpDesc desc; + desc.SetType("fused_elemwise_activation_grad"); + desc.SetInput("IntermediateOut", {}); + desc.SetInput("X", {}); + desc.SetInput("Y", std::vector({ele_y_n})); + desc.SetInput("Out", std::vector({act_out_n})); + desc.SetInput(GradVarName("Out"), std::vector({d_act_out_n})); + desc.SetOutput(GradVarName("X"), std::vector({d_ele_x_n})); + desc.SetOutput(GradVarName("Y"), std::vector({d_ele_y_n})); + desc.SetOutput(GradVarName("IntermediateOut"), + std::vector({d_itermediate_out_n})); + + desc.SetAttr("save_intermediate_out", false); + desc.SetAttr("functor_list", + std::vector( + {act_grad->Op()->Type(), ele_add_grad->Op()->Type()})); + + for (auto &n : {act_grad->Op(), ele_add_grad->Op()}) { + for (auto &m_ele : n->GetAttrMap()) { + desc.SetAttr(m_ele.first, m_ele.second); + } + } + + auto fused_node = g->CreateOpNode(&desc); + + VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> " + << act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t " + << d_itermediate_out_n << " and " << act_out_n << " -> " + << ele_add_grad->Name() << " -> " << d_itermediate_out_n; + + ReLinkNodes(g, d_itermediate_out, act_grad, ele_add_grad, fused_node); + found_elewise_add_act_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_elewise_add_act_count); + return graph; +} + +Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode( + Graph *g, const Node *op_1, const Node *op_2, const std::string &ele_x_n, + const std::string &ele_y_n, const std::string &ele_out_n, + const std::string &act_out_n) const { + OpDesc desc; + desc.SetInput("X", std::vector({ele_x_n})); + desc.SetInput("Y", std::vector({ele_y_n})); + desc.SetOutput("Out", 
std::vector<std::string>({act_out_n}));
+  desc.SetOutput("IntermediateOut", std::vector<std::string>({ele_out_n}));
+  desc.SetType("fused_elemwise_activation");
+  desc.SetAttr("save_intermediate_out", true);
+  desc.SetAttr("functor_list", std::vector<std::string>(
+                                   {op_1->Op()->Type(), op_2->Op()->Type()}));
+
+  // Set attrs
+  for (auto &n : {op_1->Op(), op_2->Op()}) {
+    for (auto &m_ele : n->GetAttrMap()) {
+      desc.SetAttr(m_ele.first, m_ele.second);
+    }
+  }
+
+  auto elewise_add_act_node = g->CreateOpNode(&desc);
+  return elewise_add_act_node;
+}
+
+void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
+  std::unordered_set<const Node *> need_removed_nodes;
+  for (auto &cur_node : graph->Nodes()) {
+    if (cur_node->IsVar()) continue;
+    if (cur_node->Name() == "fused_elemwise_activation") {
+      bool save_intermediate_out =
+          boost::get<bool>(cur_node->Op()->GetAttr("save_intermediate_out"));
+      auto intermediate_out_args = cur_node->Op()->Output("IntermediateOut");
+      PADDLE_ENFORCE(
+          save_intermediate_out && !intermediate_out_args.empty(),
+          "The %s should save the intermediate_out in the fusing stage.",
+          cur_node->Name());
+
+      // If the intermediate_out's output is empty, it should be removed.
+      auto cur_node_outputs = cur_node->outputs;
+      for (auto &out : cur_node_outputs) {
+        if (out->Name() == intermediate_out_args[0]) {
+          if (out->outputs.size() == 0) {
+            cur_node->outputs = this->RemoveNode(out, cur_node->outputs);
+            need_removed_nodes.insert(std::move(out));
+            cur_node->Op()->SetAttr("save_intermediate_out", false);
+          }
+        }
+      }
+    } else if (cur_node->Name() == "fused_elemwise_activation_grad") {
+      auto intermediate_out_grad_args =
+          cur_node->Op()->Output(GradVarName("IntermediateOut"));
+      PADDLE_ENFORCE(
+          !intermediate_out_grad_args.empty(),
+          "The %s should save the intermediate_out in the fusing stage.",
+          cur_node->Name());
+      auto cur_node_outputs = cur_node->outputs;
+      // If the intermediate_out_g's output is empty, it should be removed.
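+      // (GradVarName("IntermediateOut") resolves to "IntermediateOut@GRAD".
+      // Clearing that output below detaches the gradient variable node so
+      // that GraphSafeRemoveNodes can erase it.)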
+      for (auto &out : cur_node_outputs) {
+        if (out->Name() == intermediate_out_grad_args[0] &&
+            out->outputs.empty()) {
+          cur_node->Op()->SetOutput(GradVarName("IntermediateOut"), {});
+          cur_node->outputs = this->RemoveNode(out, cur_node->outputs);
+          need_removed_nodes.insert(std::move(out));
+        }
+      }
+    }
+  }
+  GraphSafeRemoveNodes(graph, need_removed_nodes);
+}
+
+void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
+                                        const Node *intermediate_out,
+                                        Node *op_1, Node *op_2,
+                                        Node *fused_op) const {  // delete act
+  for (auto &in : op_1->inputs) {
+    fused_op->inputs.emplace_back(in);
+    in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
+  }
+
+  std::unordered_set<const Node *> nodes2delete;
+  for (auto &out : op_1->outputs) {
+    if (out->IsCtrlVar()) {
+      auto result_iter = std::find_if(
+          op_2->inputs.begin(), op_2->inputs.end(),
+          [&out](const Node *node) -> bool { return node == out; });
+
+      if (result_iter == op_2->inputs.end()) {
+        IR_OP_VAR_LINK(fused_op, out);
+      } else {
+        nodes2delete.emplace(out);
+      }
+    } else {
+      PADDLE_ENFORCE(out == intermediate_out);
+      IR_OP_VAR_LINK(fused_op, out);
+    }
+  }
+
+  for (auto &in : op_2->inputs) {
+    if (in == intermediate_out || nodes2delete.count(in)) {
+      continue;
+    }
+    fused_op->inputs.emplace_back(in);
+    in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs);
+  }
+
+  for (auto &out : op_2->outputs) {
+    IR_OP_VAR_LINK(fused_op, out);
+  }
+
+  nodes2delete.insert(std::move(op_1));
+  nodes2delete.insert(std::move(op_2));
+
+  GraphSafeRemoveNodes(graph, nodes2delete);
+}
+
+std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
+    Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
+  std::vector<Node *> new_list(nodes.size());
+  bool has_replaced = false;
+  std::transform(nodes.begin(), nodes.end(), new_list.begin(),
+                 [&](Node *node) -> Node * {
+                   if (node == cur_node) {
+                     has_replaced = true;
+                     return new_node;
+                   }
+                   return node;
+                 });
+  PADDLE_ENFORCE(has_replaced, "Cannot find %s in the node list.",
+                 cur_node->Name());
+  return new_list;
+}
+
+std::vector<Node *> FuseElewiseAddActPass::RemoveNode(
+    Node *trg_node, const std::vector<Node *> &nodes) const {
+  std::vector<Node *> new_list(nodes.size());
+  auto end_iter =
+      std::copy_if(nodes.begin(), nodes.end(), new_list.begin(),
+                   [&](Node *node) -> bool { return node != trg_node; });
+  new_list.resize(
+      static_cast<size_t>(std::distance(new_list.begin(), end_iter)));
+  return new_list;
+}
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(fuse_elewise_add_act_pass,
+              paddle::framework::ir::FuseElewiseAddActPass);
diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
new file mode 100644
index 000000000..b2fecc076
--- /dev/null
+++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
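+
+// Usage sketch (this mirrors how parallel_executor.cc applies the pass
+// when BuildStrategy::fuse_elewise_add_act_ops_ is true; it assumes the
+// pass library is linked in so that REGISTER_PASS has run):
+//
+//   auto pass = ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass");
+//   graph = pass->Apply(std::move(graph));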
+#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the ElewiseAdd and activation + */ +class FuseElewiseAddActPass : public FusePassBase { + public: + virtual ~FuseElewiseAddActPass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + std::unique_ptr FuseElewiseAddAct( + std::unique_ptr graph, + const std::unordered_set &act_types) const; + + std::unique_ptr FuseActElewiseAdd( + std::unique_ptr graph, + const std::unordered_set &act_types) const; + + std::unique_ptr FuseElewiseAddActInplaceGrad( + std::unique_ptr graph, + const std::unordered_set &act_types) const; + + /** + * Remove the removable intermediate_out. + * - If the intermediate_out is only used by the backward op, but the + * backward op doesn't use intermediate_out. + * - If the intermediate_out_grad is not used by any op. + */ + void RemoveIntermediateOut(Graph *graph) const; + + std::vector ReplaceNode(Node *cur_node, Node *new_node, + const std::vector &nodes) const; + + std::vector RemoveNode(Node *trg_node, + const std::vector &nodes) const; + + void ReLinkNodes(Graph *graph, const Node *intermediate_out, Node *op_1, + Node *op_2, Node *fused_op) const; + Node *CreateFuseElewiseAddActNode(Graph *g, const Node *op_1, + const Node *op_2, + const std::string &ele_x_n, + const std::string &ele_y_n, + const std::string &ele_out_n, + const std::string &act_out_n) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 11d5998aa..ef5113819 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -20,10 +20,10 @@ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/printf.h" - namespace paddle { namespace framework { namespace ir { @@ -34,7 +34,7 @@ using string::Style; size_t PDPattern::id_ = 0UL; -PDNode* PDPattern::NewNode(const std::string& name) { +PDNode *PDPattern::NewNode(const std::string &name) { if (!name.empty()) { PADDLE_ENFORCE_EQ(node_map_.count(name), 0, "PDNode's name should be unique, get duplicate [%s]", @@ -42,12 +42,12 @@ PDNode* PDPattern::NewNode(const std::string& name) { } nodes_.emplace_back(new PDNode(this, name)); - auto* cur = nodes_.back().get(); + auto *cur = nodes_.back().get(); node_map_[name] = cur; return cur; } -PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { +PDNode *PDPattern::NewNode(PDNode::teller_t &&teller, const std::string &name) { if (!name.empty()) { PADDLE_ENFORCE_EQ(node_map_.count(name), 0, "PDNode's name should be unique, get duplicate [%s]", @@ -55,12 +55,12 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { } nodes_.emplace_back(new PDNode(std::move(teller), this, name)); - auto* cur = nodes_.back().get(); + auto *cur = nodes_.back().get(); node_map_[name] = cur; return cur; } -PDNode* PDPattern::RetrieveNode(const std::string& id) 
const { +PDNode *PDPattern::RetrieveNode(const std::string &id) const { auto it = node_map_.find(id); if (it == node_map_.end()) { return nullptr; @@ -69,14 +69,14 @@ PDNode* PDPattern::RetrieveNode(const std::string& id) const { return it->second; } -void PDPattern::AddEdge(PDNode* a, PDNode* b) { +void PDPattern::AddEdge(PDNode *a, PDNode *b) { PADDLE_ENFORCE(a); PADDLE_ENFORCE(b); PADDLE_ENFORCE(a != b, "can't connect to the same nodes."); edges_.emplace_back(a, b); } -void GraphPatternDetector::operator()(Graph* graph, +void GraphPatternDetector::operator()(Graph *graph, GraphPatternDetector::handle_t handler) { if (!MarkPDNodesInGraph(*graph)) { return; @@ -90,18 +90,18 @@ void GraphPatternDetector::operator()(Graph* graph, if (subgraphs.empty()) return; PrettyLogEndl(Style::detail(), "--- detect %d subgraphs", subgraphs.size()); int id = 0; - for (auto& g : subgraphs) { + for (auto &g : subgraphs) { VLOG(3) << "optimizing #" << id++ << " subgraph"; handler(g, graph); } } -bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { +bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) { VLOG(3) << "mark pdnodes in graph"; if (graph.Nodes().empty()) return false; - for (auto& node : GraphTraits::DFS(graph)) { - for (const auto& pdnode : pattern_.nodes()) { + for (auto &node : GraphTraits::DFS(graph)) { + for (const auto &pdnode : pattern_.nodes()) { if (pdnode->Tell(&node)) { VLOG(4) << "pdnode " << pdnode->name() << " marked"; pdnodes2nodes_[pdnode.get()].insert(&node); @@ -109,15 +109,15 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { } } // Check to early stop if some PDNode can't find matched Node. - for (auto& pdnode : pattern_.nodes()) { + for (auto &pdnode : pattern_.nodes()) { if (!pdnodes2nodes_.count(pdnode.get())) { VLOG(4) << pdnode->name() << " can't find matched Node, early stop"; // return false; } } - for (auto& item : pdnodes2nodes_) { - for (auto& n : item.second) { - GetMarkedNodes(const_cast(&graph)).insert(n); + for (auto &item : pdnodes2nodes_) { + for (auto &n : item.second) { + GetMarkedNodes(const_cast(&graph)).insert(n); } } VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; @@ -128,28 +128,28 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { // The intermediate Nodes can only link to the nodes inside the pattern, or this // subgraph will be droped. void GraphPatternDetector::ValidateByNodeRole( - std::vector* subgraphs) { + std::vector *subgraphs) { std::vector result; subgraphs->erase( std::remove_if( subgraphs->begin(), subgraphs->end(), - [](const GraphPatternDetector::subgraph_t& subgraph) -> bool { + [](const GraphPatternDetector::subgraph_t &subgraph) -> bool { // Collect the inputs and outputs. 
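+                // A PDNode marked AsIntermediate() is erased after the
+                // fuse, so the Node bound to it may only link to other
+                // matched nodes. E.g. if tmp = ele_add(x, y) also feeds
+                // some unrelated op, removing tmp would orphan that op,
+                // so the lambda returns true and the match is dropped.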
- std::unordered_set ios; - for (auto& item : subgraph) { + std::unordered_set ios; + for (auto &item : subgraph) { if (!item.first->IsIntermediate()) { ios.insert(item.second); } } - for (auto& item : subgraph) { + for (auto &item : subgraph) { if (item.first->IsIntermediate()) { - for (auto* x : item.second->inputs) { + for (auto *x : item.second->inputs) { if (!ios.count(x)) { return true; } } - for (auto* x : item.second->outputs) { + for (auto *x : item.second->outputs) { if (!ios.count(x)) { return true; } @@ -162,9 +162,9 @@ void GraphPatternDetector::ValidateByNodeRole( } struct HitGroup { - std::unordered_map roles; + std::unordered_map roles; - bool Match(Node* node, PDNode* pat) { + bool Match(Node *node, PDNode *pat) { if (nodes_.count(node)) { if (!roles.count(pat)) return false; return roles[pat] == node; @@ -172,18 +172,18 @@ struct HitGroup { return !roles.count(pat) || roles.at(pat) == node; } - void Register(Node* node, PDNode* pat) { + void Register(Node *node, PDNode *pat) { roles[pat] = node; nodes_.insert(node); } private: - std::unordered_set nodes_; + std::unordered_set nodes_; }; // Tell whether Node a links to b. -bool IsNodesLink(Node* a, Node* b) { - for (auto* node : a->outputs) { +bool IsNodesLink(Node *a, Node *b) { + for (auto *node : a->outputs) { if (b == node) { return true; } @@ -198,10 +198,10 @@ GraphPatternDetector::DetectPatterns() { std::vector init_groups; std::array, 2> bi_records; // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); - auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() + auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() : pattern_.edges().front().first; if (!pdnodes2nodes_.count(first_pnode)) return result; - for (auto* node : pdnodes2nodes_[first_pnode]) { + for (auto *node : pdnodes2nodes_[first_pnode]) { HitGroup group; group.roles[first_pnode] = node; init_groups.emplace_back(group); @@ -212,21 +212,21 @@ GraphPatternDetector::DetectPatterns() { // Extend a PDNode to subgraphs by deducing the connection relations defined // in edges of PDNodes. - for (const auto& edge : pattern_.edges()) { + for (const auto &edge : pattern_.edges()) { VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); // TODO(Superjomn) Fix bug here, the groups might be duplicate here. // Each role has two PDNodes, which indicates two roles. // Detect two Nodes that can match these two roles and they are connected. - auto& pre_groups = bi_records[step % 2]; - auto& cur_groups = bi_records[1 - (step++ % 2)]; + auto &pre_groups = bi_records[step % 2]; + auto &cur_groups = bi_records[1 - (step++ % 2)]; cur_groups.clear(); if (pre_groups.empty()) break; // source -> target - for (Node* source : pdnodes2nodes_[edge.first]) { - for (Node* target : pdnodes2nodes_[edge.second]) { + for (Node *source : pdnodes2nodes_[edge.first]) { + for (Node *target : pdnodes2nodes_[edge.second]) { VLOG(8) << "check " << source->id() << " -- " << target->id(); // TODO(Superjomn) add some prune strategies. 
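+          // The loop below grows each partial match by one pattern edge:
+          // a HitGroup survives only if (source, target) are actually
+          // linked in the graph and neither role is already bound to a
+          // different Node. One step, assuming pattern edge A -> B with
+          // candidate nodes {a1, a2} for A and {b1} for B (hypothetical):
+          //
+          //   pre_groups: {A:a1}, {A:a2}
+          //   cur_groups: {A:a1, B:b1}   // kept only if a1 links to b1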
- for (const auto& group : pre_groups) { + for (const auto &group : pre_groups) { HitGroup new_group = group; if (IsNodesLink(source, target) && new_group.Match(source, edge.first)) { @@ -241,17 +241,17 @@ GraphPatternDetector::DetectPatterns() { } } VLOG(3) << "step " << step << " get records: " << cur_groups.size(); - for (auto& group : cur_groups) { - for (auto& item : group.roles) { + for (auto &group : cur_groups) { + for (auto &item : group.roles) { VLOG(4) << "node " << item.second->id() << " as " << item.first->name(); } VLOG(4) << "========================================================="; } } - for (auto& group : bi_records[step % 2]) { + for (auto &group : bi_records[step % 2]) { GraphPatternDetector::subgraph_t subgraph; - for (auto& role : group.roles) { + for (auto &role : group.roles) { subgraph.emplace(role.first, role.second); } result.emplace_back(subgraph); @@ -260,16 +260,16 @@ GraphPatternDetector::DetectPatterns() { } void GraphPatternDetector::UniquePatterns( - std::vector* subgraphs) { + std::vector *subgraphs) { if (subgraphs->empty()) return; std::vector result; std::unordered_set set; - for (auto& g : *subgraphs) { + for (auto &g : *subgraphs) { size_t key = 0; - for (auto& item : g) { - key ^= std::hash{}(item.first); - key ^= std::hash{}(item.second); + for (auto &item : g) { + key ^= std::hash{}(item.first); + key ^= std::hash{}(item.second); } if (!set.count(key)) { result.emplace_back(g); @@ -280,20 +280,20 @@ void GraphPatternDetector::UniquePatterns( } void GraphPatternDetector::RemoveOverlappedMatch( - std::vector* subgraphs) { + std::vector *subgraphs) { std::vector result; - std::unordered_set node_set; + std::unordered_set node_set; - for (const auto& subgraph : *subgraphs) { + for (const auto &subgraph : *subgraphs) { bool valid = true; - for (auto& item : subgraph) { + for (auto &item : subgraph) { if (item.first->IsIntermediate() && node_set.count(item.second)) { valid = false; break; } } if (valid) { - for (auto& item : subgraph) { + for (auto &item : subgraph) { node_set.insert(item.second); } result.push_back(subgraph); @@ -307,71 +307,81 @@ std::string PDPattern::DotString() const { Dot dot; int id = 0; // Create Nodes - std::unordered_map node2dot; - for (const auto& node : nodes()) { + std::unordered_map node2dot; + for (const auto &node : nodes()) { std::string node_id = "Node" + std::to_string(id++); dot.AddNode(node_id, {}, node->name()); node2dot[node.get()] = node_id; } // Create Edges - for (const auto& edge : edges()) { + for (const auto &edge : edges()) { if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { LOG(ERROR) << "no node " << edge.first << " " << edge.second; continue; } - auto& src = node2dot.at(edge.first); - auto& trg = node2dot.at(edge.second); + auto &src = node2dot.at(edge.first); + auto &trg = node2dot.at(edge.second); dot.AddEdge(src, trg, {}); } return dot.Build(); } -PDNode& PDNode::LinksTo(const std::vector& others) { +PDNode &PDNode::LinksTo(const std::vector &others) { // extend outlinks. - for (PDNode* x : others) { + for (PDNode *x : others) { pattern_->AddEdge(this, x); } return *this; } -PDNode& PDNode::LinksFrom(const std::vector& others) { +PDNode &PDNode::LinksFrom(const std::vector &others) { // extend outlinks. 
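+  // (Strictly, this adds in-links: each edge runs others -> this, whereas
+  // LinksTo above adds this -> others. Typical pattern-building usage,
+  // taken from the patterns added later in this file:
+  //
+  //   ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var});
+  // )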
- for (PDNode* x : others) { + for (PDNode *x : others) { pattern_->AddEdge(x, this); } return *this; } -PDNode* PDNode::assert_is_op() { - asserts_.emplace_back([](Node* x) { return x && x->IsOp(); }); +PDNode *PDNode::assert_is_op() { + asserts_.emplace_back([](Node *x) { return x && x->IsOp(); }); return this; } -PDNode* PDNode::assert_is_op(const std::string& op_type) { - asserts_.emplace_back([op_type](Node* x) { + +PDNode *PDNode::assert_is_op(const std::string &op_type) { + asserts_.emplace_back([op_type](Node *x) { return x && x->IsOp() && x->Op()->Type() == op_type; }); return this; } -PDNode* PDNode::assert_is_var() { - asserts_.emplace_back([](Node* x) { return x && x->IsVar(); }); + +PDNode *PDNode::assert_is_var() { + asserts_.emplace_back([](Node *x) { return x && x->IsVar(); }); + return this; +} + +PDNode *PDNode::assert_is_not_ctrl_var() { + asserts_.emplace_back([](Node *x) { return x && !x->IsCtrlVar(); }); return this; } -PDNode* PDNode::assert_var_not_persistable() { + +PDNode *PDNode::assert_var_not_persistable() { assert_is_var(); - asserts_.emplace_back([](Node* x) { return !x->Var()->Persistable(); }); + asserts_.emplace_back([](Node *x) { return !x->Var()->Persistable(); }); return this; } -PDNode* PDNode::assert_is_persistable_var() { + +PDNode *PDNode::assert_is_persistable_var() { assert_is_var(); - asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); }); + asserts_.emplace_back([=](Node *x) { return x->Var()->Persistable(); }); return this; } -PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type, - const std::string& argument, int nth) { + +PDNode *PDNode::assert_is_op_nth_input(const std::string &op_type, + const std::string &argument, int nth) { assert_is_var(); assert_is_op_input(op_type); - asserts_.emplace_back([=](Node* x) { - for (auto* op : x->outputs) { + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->outputs) { if (op->IsOp() && op->Op()->Type() == op_type && IsNthInput(x, op, argument, nth)) return true; @@ -380,11 +390,12 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type, }); return this; } -PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type, - const std::string& argument, int nth) { + +PDNode *PDNode::assert_is_op_nth_output(const std::string &op_type, + const std::string &argument, int nth) { assert_is_var(); - asserts_.emplace_back([=](Node* x) { - for (auto* op : x->inputs) { + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->inputs) { if (op->IsOp() && op->Op()->Type() == op_type && IsNthOutput(x, op, argument, nth)) return true; @@ -393,10 +404,11 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type, }); return this; } -PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) { + +PDNode *PDNode::assert_is_only_input_of_op(const std::string &op_type) { assert_is_var(); - asserts_.emplace_back([=](Node* x) { - for (auto* op : x->outputs) { + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->outputs) { if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type && op->inputs.size() == 1) { return true; @@ -406,10 +418,11 @@ PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) { }); return this; } -PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) { + +PDNode *PDNode::assert_is_only_output_of_op(const std::string &op_type) { assert_is_var(); - asserts_.emplace_back([=](Node* x) { - for (auto* op : x->inputs) { + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->inputs) { if 
(op && op->IsOp() && op->Op() && op->Op()->Type() == op_type && op->outputs.size() == 1) { return true; @@ -419,10 +432,11 @@ PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) { }); return this; } -PDNode* PDNode::assert_is_op_output(const std::string& op_type) { + +PDNode *PDNode::assert_is_op_output(const std::string &op_type) { assert_is_var(); - asserts_.emplace_back([=](Node* x) { - for (auto* op : x->inputs) { + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->inputs) { if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { return true; } @@ -431,16 +445,17 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) { }); return this; } -PDNode* PDNode::assert_is_op_output(const std::string& op_type, - const std::string& argument) { + +PDNode *PDNode::assert_is_op_output(const std::string &op_type, + const std::string &argument) { assert_is_var(); assert_is_op_nth_output(op_type, argument, 0); return this; } -PDNode* PDNode::assert_is_op_input(const std::string& op_type) { +PDNode *PDNode::assert_is_op_input(const std::string &op_type) { assert_is_var(); - asserts_.emplace_back([=](Node* x) { - for (auto* op : x->outputs) { + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->outputs) { if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { return true; } @@ -449,72 +464,161 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) { }); return this; } -PDNode* PDNode::assert_is_op_input(const std::string& op_type, - const std::string& argument) { + +PDNode *PDNode::assert_is_op_input(const std::string &op_type, + const std::string &argument) { assert_is_var(); assert_is_op_nth_input(op_type, argument, 0); return this; } -PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) { + +PDNode *PDNode::assert_op_has_n_inputs(const std::string &op_type, size_t n) { assert_is_op(op_type); - asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; }); + asserts_.emplace_back([=](Node *x) { return x->inputs.size() == n; }); return this; } -PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) { + +PDNode *PDNode::assert_op_has_n_outputs(const std::string &op_type, size_t n) { assert_is_op(op_type); - asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; }); + asserts_.emplace_back([=](Node *x) { return x->outputs.size() == n; }); return this; } -PDNode* PDNode::assert_more(PDNode::teller_t&& teller) { + +PDNode *PDNode::assert_more(PDNode::teller_t &&teller) { asserts_.emplace_back(std::move(teller)); return this; } -bool VarLinksToOp(Node* node, const std::string& op_type) { - for (auto* out : node->outputs) { +PDNode *PDNode::assert_is_ops(const std::unordered_set &op_types) { + asserts_.emplace_back([op_types](Node *x) { + return x && x->IsOp() && op_types.count(x->Op()->Type()); + }); + return this; +} + +PDNode *PDNode::assert_is_ops_nth_input( + const std::unordered_set &op_types, + const std::string &argument, int nth) { + assert_is_var(); + assert_is_ops_input(op_types); + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->outputs) { + if (op->IsOp() && op_types.count(op->Op()->Type()) && + IsNthInput(x, op, argument, nth)) + return true; + } + return false; + }); + return this; +} + +PDNode *PDNode::assert_is_ops_nth_output( + const std::unordered_set &op_types, + const std::string &argument, int nth) { + assert_is_var(); + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->inputs) { + if (op->IsOp() && 
op_types.count(op->Op()->Type()) && + IsNthOutput(x, op, argument, nth)) + return true; + } + return false; + }); + return this; +} +PDNode *PDNode::assert_is_ops_output( + const std::unordered_set &op_types) { + assert_is_var(); + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->inputs) { + if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type())) { + return true; + } + } + return false; + }); + return this; +} + +PDNode *PDNode::assert_is_ops_output( + const std::unordered_set &op_types, + const std::string &argument) { + assert_is_var(); + assert_is_ops_nth_output(op_types, argument, 0); + return this; +} + +PDNode *PDNode::assert_is_ops_input( + const std::unordered_set &op_types) { + assert_is_var(); + asserts_.emplace_back([=](Node *x) { + for (auto *op : x->outputs) { + if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type())) { + return true; + } + } + return false; + }); + return this; +} + +PDNode *PDNode::assert_is_ops_input( + const std::unordered_set &op_types, + const std::string &argument) { + assert_is_var(); + assert_is_ops_nth_input(op_types, argument, 0); + return this; +} + +bool VarLinksToOp(Node *node, const std::string &op_type) { + for (auto *out : node->outputs) { if (out->IsOp() && out->Op()->Type() == op_type) { return true; } } return false; } -bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth) { + +bool IsNthInput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(op->IsOp()); if (op->Op()->Input(argument).size() <= nth) return false; return var->Name() == op->Op()->Input(argument)[nth]; } -bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth) { + +bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { PADDLE_ENFORCE(var->IsVar()); PADDLE_ENFORCE(op->IsOp()); if (op->Op()->Output(argument).size() <= nth) return false; return var->Name() == op->Op()->Output(argument)[nth]; } -void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes) { - for (auto* node : nodes) { - graph->RemoveNode(const_cast(node)); + +void GraphSafeRemoveNodes(Graph *graph, + const std::unordered_set &nodes) { + for (auto *node : nodes) { + graph->RemoveNode(const_cast(node)); } - for (auto* node : graph->Nodes()) { + for (auto *node : graph->Nodes()) { for (auto it = node->inputs.begin(); it != node->inputs.end();) { if (nodes.count(*it)) { - it = const_cast(node)->inputs.erase(it); + it = const_cast(node)->inputs.erase(it); } else { it++; } } for (auto it = node->outputs.begin(); it != node->outputs.end();) { if (nodes.count(*it)) { - it = const_cast(node)->outputs.erase(it); + it = const_cast(node)->outputs.erase(it); } else { it++; } } } } -bool VarLinksFromOp(Node* node, const std::string& op_type) { - for (auto* out : node->inputs) { + +bool VarLinksFromOp(Node *node, const std::string &op_type) { + for (auto *out : node->inputs) { if (out->IsOp() && out->Op()->Type() == op_type) { return true; } @@ -522,30 +626,30 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { return false; } -PDNode* patterns::ConvReLU::operator()( - paddle::framework::ir::PDNode* conv_input) { +PDNode *patterns::ConvReLU::operator()( + paddle::framework::ir::PDNode *conv_input) { // Create Operators conv_input->assert_is_op_input("conv2d", "Input"); - auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); - auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); + auto *conv_op = 
pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); // Create variables // Filter - auto* conv_weight_var = pattern->NewNode(conv_weight_repr()) + auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) ->AsInput() ->assert_is_persistable_var() ->assert_is_op_input("conv2d", "Filter"); // Bias - auto* conv_bias_var = pattern->NewNode(conv_bias_repr()) + auto *conv_bias_var = pattern->NewNode(conv_bias_repr()) ->AsInput() ->assert_is_persistable_var() ->assert_is_op_input("conv2d", "Bias"); // intermediate variable, will be removed in the IR after fuse. - auto* conv_out_var = pattern->NewNode(conv_out_repr()) + auto *conv_out_var = pattern->NewNode(conv_out_repr()) ->AsIntermediate() ->assert_is_only_output_of_op("conv2d") ->assert_is_op_input("relu"); // output - auto* relu_out_var = pattern->NewNode(relu_out_repr()) + auto *relu_out_var = pattern->NewNode(relu_out_repr()) ->AsOutput() ->assert_is_op_output("relu"); @@ -555,18 +659,18 @@ PDNode* patterns::ConvReLU::operator()( return relu_out_var; } -PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, +PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, bool with_bias) { // Create shared nodes. x->assert_is_op_input("mul", "X"); - auto* mul = pattern->NewNode(mul_repr())->assert_is_op("mul"); + auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul"); - auto* mul_w_var = pattern->NewNode(w_repr()) + auto *mul_w_var = pattern->NewNode(w_repr()) ->AsInput() ->assert_is_persistable_var() ->assert_is_op_input("mul", "Y"); - auto* mul_out_var = + auto *mul_out_var = pattern->NewNode(mul_out_repr())->assert_is_op_output("mul"); if (!with_bias) { // not with bias @@ -577,14 +681,14 @@ PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, } else { // with bias mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); // Create operators. - auto* elementwise_add = pattern->NewNode(elementwise_add_repr()) + auto *elementwise_add = pattern->NewNode(elementwise_add_repr()) ->assert_is_op("elementwise_add"); // Create variables. 
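+    // Shape of the with-bias FC pattern assembled here, for reference:
+    //
+    //   x ---\
+    //         mul -> mul_out -> elementwise_add -> fc_out
+    //   w ---/                        ^
+    //                                 |
+    //                               bias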
- auto* bias = pattern->NewNode(bias_repr()) + auto *bias = pattern->NewNode(bias_repr()) ->assert_is_op_input("elementwise_add") ->AsInput(); - auto* fc_out = pattern->NewNode(Out_repr()) + auto *fc_out = pattern->NewNode(Out_repr()) ->AsOutput() ->assert_is_op_output("elementwise_add"); @@ -594,11 +698,11 @@ PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, } } -PDNode* patterns::LSTM::operator()(PDNode* x) { +PDNode *patterns::LSTM::operator()(PDNode *x) { x->assert_is_op_input("lstm", "Input"); - auto* lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm"); + auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm"); #define NEW_NODE(arg__, io__) \ - auto* arg__ = \ + auto *arg__ = \ pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__); // Currently, the H0 and C0 are optional @@ -619,11 +723,11 @@ PDNode* patterns::LSTM::operator()(PDNode* x) { return Hidden; } -PDNode* patterns::GRU::operator()(PDNode* x) { +PDNode *patterns::GRU::operator()(PDNode *x) { x->assert_is_op_input("gru", "Input"); - auto* gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru"); + auto *gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru"); #define NEW_NODE(arg__, io__) \ - auto* arg__ = \ + auto *arg__ = \ pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__); NEW_NODE(Weight, input); @@ -648,6 +752,100 @@ PDNode* patterns::GRU::operator()(PDNode* x) { return Hidden; } +PDNode *patterns::ActElewiseAdd::operator()( + paddle::framework::ir::PDNode *in_var, + std::unordered_set act_types) { + in_var->assert_is_ops_input(act_types, "X"); + + auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types); + auto *act_out_var = pattern->NewNode(act_out_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_ops_output(act_types); + act_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + + auto *ele_x_var = pattern->NewNode(ele_x_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_input("elementwise_add") + ->AsInput(); + auto *elementwise_add = + pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); + + auto *elewise_add_out = pattern->NewNode(elewise_add_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_add", "Out"); + + act->LinksFrom({in_var}).LinksTo({act_out_var}); + elementwise_add->LinksFrom({act_out_var, ele_x_var}) + .LinksTo({elewise_add_out}); + + return elewise_add_out; +} + +PDNode *patterns::ElewiseAddAct::operator()( + paddle::framework::ir::PDNode *ele_x_var, + std::unordered_set act_types) { + auto *ele_y_var = pattern->NewNode(ele_y_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + + auto *ele_add = + pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); + + auto *ele_out_var = pattern->NewNode(elewise_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + + ele_out_var->AsIntermediate()->assert_is_ops_input(act_types); + + auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types); + + auto *act_out_var = + pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out"); + + ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var}); + act->LinksFrom({ele_out_var}).LinksTo({act_out_var}); + + return act_out_var; +} + +PDNode *patterns::ElewiseAddActInplaceGrad::operator()( + paddle::framework::ir::PDNode *d_act_out_var, + std::unordered_set act_types) { + // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"] + // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"] + auto *act_grad = 
pattern->NewNode(act_grad_repr())->assert_is_ops(act_types); + + auto *act_out_var = + pattern->NewNode(act_out_repr())->assert_is_ops_input(act_types, "Out"); + + auto *d_intermediate_var = + pattern->NewNode(d_itermediate_out_repr()) + ->assert_is_ops_output(act_types, GradVarName("X")); + + act_grad->LinksFrom({d_act_out_var, act_out_var}) + .LinksTo({d_intermediate_var}); + + auto *ele_y_var = pattern->NewNode(ele_y_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_input("elementwise_add_grad", "Y"); + + auto *ele_add_grad = pattern->NewNode(ele_add_grad_repr()) + ->assert_is_op("elementwise_add_grad"); + + auto *d_ele_x_var = + pattern->NewNode(d_ele_x_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("elementwise_add_grad", GradVarName("X")); + + auto *d_ele_y_var = + pattern->NewNode(d_ele_y_repr()) + ->assert_is_not_ctrl_var() + ->assert_is_op_output("elementwise_add_grad", GradVarName("Y")); + + ele_add_grad->LinksFrom({d_intermediate_var, ele_y_var}) + .LinksTo({d_ele_x_var, d_ele_y_var}); + + return ele_add_grad; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1a8d9cefb..46950ed87 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -95,6 +95,7 @@ struct PDNode { PDNode* assert_is_op(); PDNode* assert_is_op(const std::string& op_type); PDNode* assert_is_var(); + PDNode* assert_is_not_ctrl_var(); PDNode* assert_var_not_persistable(); PDNode* assert_is_persistable_var(); PDNode* assert_is_op_output(const std::string& op_type); @@ -113,6 +114,20 @@ struct PDNode { PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n); PDNode* assert_more(teller_t&& teller); + PDNode* assert_is_ops_output(const std::unordered_set& op_types); + PDNode* assert_is_ops(const std::unordered_set& op_types); + PDNode* assert_is_ops_output(const std::unordered_set& op_types, + const std::string& argument); + PDNode* assert_is_ops_nth_input( + const std::unordered_set& op_types, + const std::string& argument, int nth); + PDNode* assert_is_ops_input(const std::unordered_set& op_types); + PDNode* assert_is_ops_input(const std::unordered_set& op_types, + const std::string& argument); + PDNode* assert_is_ops_nth_output( + const std::unordered_set& op_types, + const std::string& argument, int nth); + private: PDNode(PDPattern* pattern, const std::string& name = "", Type type = Type::kVar) @@ -447,6 +462,68 @@ struct GRU : public PatternBase { PATTERN_DECL_NODE(Hidden); }; +// The following patterns are used to fuse elewise_add and act +// formula: act(ele_add(x, y)) +// op: elementwise_add + act +// named nodes: elementwise_add, act +// ele_x, ele_y, elewise_add_out, act_out +struct ElewiseAddAct : public PatternBase { + ElewiseAddAct(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elewise_add_act") {} + + PDNode* operator()(PDNode* x, std::unordered_set acts); + + // declare operator node's name + PATTERN_DECL_NODE(ele_add); + PATTERN_DECL_NODE(act); + // declare variable node's name + PATTERN_DECL_NODE(elewise_add_out); + PATTERN_DECL_NODE(ele_y); + PATTERN_DECL_NODE(act_out); +}; + +// formula: ele_add(x, act(y)) +// op: elementwise_add + act +// named nodes: elementwise_add, act +// act_in, act_out, ele_x, elewise_add_out +struct ActElewiseAdd : public PatternBase { + ActElewiseAdd(PDPattern* pattern, const std::string& 
name_scope) + : PatternBase(pattern, name_scope, "act_elewise_add") {} + + PDNode* operator()(PDNode* x, std::unordered_set acts); + + // declare operator node's name + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(ele_add); + // declare variable node's name + PATTERN_DECL_NODE(act_out); + PATTERN_DECL_NODE(ele_x); + PATTERN_DECL_NODE(elewise_add_out); +}; + +// the backward of act(ele_add(x, y)) +// the act is inplace. +// op: elementwise_add_grad + act_grad +// named nodes: elementwise_add_grad, act_grad +// act_out, act_out_g, ele_y, d_itermediate_out, d_ele_x, d_ele_y +struct ElewiseAddActInplaceGrad : public PatternBase { + ElewiseAddActInplaceGrad(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elewise_add_act_grad1") {} + + // act_grad: in["Out", "Out@GRAD"], out["X@GRAD"] + // ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"] + PDNode* operator()(PDNode* x, std::unordered_set acts); + + // declare operator node's name + PATTERN_DECL_NODE(act_grad); + PATTERN_DECL_NODE(ele_add_grad); + // declare variable node's name + PATTERN_DECL_NODE(act_out); + PATTERN_DECL_NODE(d_itermediate_out); + PATTERN_DECL_NODE(d_ele_x); + PATTERN_DECL_NODE(d_ele_y); + PATTERN_DECL_NODE(ele_y); +}; } // namespace patterns // Link two ir::Nodes from each other. @@ -454,6 +531,12 @@ struct GRU : public PatternBase { a->outputs.push_back(b); \ b->inputs.push_back(a); +// Set the out_var as the output of the op +#define IR_OP_VAR_LINK(op, out_var) \ + op->outputs.push_back(out_var); \ + out_var->inputs.clear(); \ + out_var->inputs.push_back(op); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 82ab1f40f..5d6da9f1d 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -48,6 +48,10 @@ class Node { bool IsOp() const { return type_ == Type::kOperation; } bool IsVar() const { return type_ == Type::kVariable; } + bool IsCtrlVar() const { + return type_ == Type::kVariable && + Name().find(ir::Node::kControlDepVarName) != std::string::npos; + } std::vector inputs; std::vector outputs; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index dbc3ff865..f5a54c0f4 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -57,6 +57,21 @@ std::unique_ptr ApplyParallelExecutorPass( graph = viz_pass->Apply(std::move(graph)); } + // Apply op fusion. + if (strategy.fuse_elewise_add_act_ops_) { + auto fuse_elewise_add_act_pass = + ir::PassRegistry::Instance().Get("fuse_elewise_add_act_pass"); + graph = fuse_elewise_add_act_pass->Apply(std::move(graph)); + // Apply a graph viz pass to record a graph. + if (!strategy.debug_graphviz_path_.empty()) { + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); + } + } + // Convert graph to run on multi-devices. 
auto multi_devices_pass = ir::PassRegistry::Instance().Get("multi_devices_pass"); @@ -359,6 +374,7 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle +USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index b1a399c22..7c84a9d81 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -987,18 +987,28 @@ void FusedElemwiseAndActComputeWithBroadcast( } // --- backward -template +template struct FusedElemwiseAndActGradNoBroadcast { HOSTDEVICE void operator()(size_t i) { if (dx_ != nullptr) { - dx_[i] = UseIntermediateOut ? dx_op_(x_[i], y_[i], intermediate_out_[i], - out_[i], dout_[i]) - : dx_op_(x_[i], y_[i], out_[i], dout_[i]); + dx_[i] = UseIntermediateOut + ? dx_op_.UseIntermediateOut( + x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i]) + : dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); } if (dy_ != nullptr) { - dy_[i] = UseIntermediateOut ? dy_op_(x_[i], y_[i], intermediate_out_[i], - out_[i], dout_[i]) - : dy_op_(x_[i], y_[i], out_[i], dout_[i]); + dy_[i] = UseIntermediateOut + ? dy_op_.UseIntermediateOut( + x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i]) + : dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + } + if (dintermediate_ != nullptr) { + dintermediate_[i] = + UseIntermediateOut + ? dintermediate_op_.UseIntermediateOut( + x_[i], intermediate_out_[i], out_[i], dout_[i]) + : dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); } } @@ -1009,37 +1019,44 @@ struct FusedElemwiseAndActGradNoBroadcast { const T *dout_; DX_OP dx_op_; DY_OP dy_op_; + DIntermediate_OP dintermediate_op_; T *dx_; T *dy_; + T *dintermediate_; }; template + typename DIntermediate_OP, bool UseIntermediateOut> void FusedElemwiseAndActGradComputeNoBroadcast( const framework::ExecutionContext &ctx, const framework::DDim &x_dim, const framework::DDim &y_dim, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *intermediate_out, const framework::Tensor *out, const framework::Tensor *dout, int axis, - framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { + framework::Tensor *dx, framework::Tensor *dy, + framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op) { size_t N = static_cast(framework::product(x_dim)); platform::ForRange for_range( ctx.template device_context(), N); for_range( - FusedElemwiseAndActGradNoBroadcast{ + FusedElemwiseAndActGradNoBroadcast{ x->data(), y->data(), intermediate_out ? intermediate_out->data() : nullptr, - out->data(), dout->data(), dx_op, dy_op, + out->data(), dout->data(), dx_op, dy_op, dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())}); + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), + dintermediate == nullptr ? 
nullptr : dintermediate->mutable_data( + ctx.GetPlace())}); } -template -static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y, - const T *intermediate_out, - const T *out, const T *dout, - int h, int w, DX_OP dx_op, - DY_OP dy_op, T *dx, T *dy) { +template +static void FusedElemwiseAndActGradBroadcast1CPU( + const T *x, const T *y, const T *intermediate_out, const T *out, + const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { int64_t tmp_out_idx, x_idx, y_idx; for (int i = 0; i < h; ++i) { for (int j = 0; j < w; ++j) { @@ -1055,9 +1072,11 @@ static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y, if (dx != nullptr) { T tmp = UseIntermediateOut - ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], + dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -1071,9 +1090,11 @@ static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y, } if (dy != nullptr) { T tmp = UseIntermediateOut - ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], + dout[offset]); if (BcastY) { if (i == 0) { dy[y_idx] = tmp; @@ -1084,18 +1105,34 @@ static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y, dy[y_idx] = tmp; } } + if (d_intermediate != nullptr) { + T tmp = UseIntermediateOut + ? dintermediate_op.UseIntermediateOut( + x[x_idx], intermediate_out[tmp_out_idx], out[offset], + dout[offset]) + : dintermediate_op.Recompute(x[x_idx], y[y_idx], + out[offset], dout[i]); + if (SameShapeOfIntermediateOutAndOut) { + d_intermediate[tmp_out_idx] = tmp; + } else { + if (i == 0) { + d_intermediate[tmp_out_idx] = tmp; + } else { + d_intermediate[tmp_out_idx] += tmp; + } + } + } } } } -template -static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y, - const T *intermediate_out, - const T *out, const T *dout, - int pre, int n, int post, - DX_OP dx_op, DY_OP dy_op, - T *dx, T *dy) { +template +static void FusedElemwiseAndActGradBroadcast2CPU( + const T *x, const T *y, const T *intermediate_out, const T *out, + const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { int64_t tmp_out_idx, x_idx, y_idx; for (int i = 0; i < pre; ++i) { for (int j = 0; j < n; ++j) { @@ -1112,9 +1149,11 @@ static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y, if (dx != nullptr) { T tmp = UseIntermediateOut - ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], + dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -1128,9 +1167,11 @@ static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y, } if (dy != nullptr) { T tmp = UseIntermediateOut - ? 
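In the broadcast-1 CPU kernel above, the gradient of the broadcast operand is reduced over h by assigning on the first row and accumulating on the rest. The same pattern in isolation, as a plain column sum over precomputed per-element gradients (sketch only; tmp stands for the values the d-functor produces):

void ColumnSumCPU(const float* tmp, int h, int w, float* dy) {
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      if (i == 0) {
        dy[j] = tmp[i * w + j];   // first row seeds the sum
      } else {
        dy[j] += tmp[i * w + j];  // remaining rows accumulate
      }
    }
  }
}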
dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], + dout[offset]); if (BcastY) { if (i == 0 && k == 0) { dy[y_idx] = tmp; @@ -1141,21 +1182,40 @@ static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y, dy[y_idx] = tmp; } } + if (d_intermediate != nullptr) { + T tmp = UseIntermediateOut + ? dintermediate_op.UseIntermediateOut( + x[x_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dintermediate_op.Recompute(x[x_idx], y[y_idx], + out[offset], dout[i]); + if (SameShapeOfIntermediateOutAndOut) { + d_intermediate[tmp_out_idx] = tmp; + } else { + if (i == 0) { + d_intermediate[tmp_out_idx] = tmp; + } else { + d_intermediate[tmp_out_idx] += tmp; + } + } + } } } } } #ifdef __NVCC__ -template +template static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( const T *x, const T *y, const T *intermediate_out, const T *out, - const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { + const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; - T val(0); + T val(0), inter_val(0); int64_t tmp_out_idx, x_idx, y_idx; do { @@ -1170,10 +1230,12 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( } if (dx != nullptr) { - T tmp = UseIntermediateOut - ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = + UseIntermediateOut + ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -1182,23 +1244,38 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( } } if (dy != nullptr) { - T tmp = UseIntermediateOut - ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = + UseIntermediateOut + ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); if (BcastY) { val += tmp; } else { dy[y_idx] = tmp; } } + if (d_intermediate != nullptr) { + T tmp = UseIntermediateOut + ? dintermediate_op.UseIntermediateOut( + y[y_idx], intermediate_out[tmp_out_idx], out[offset], + dout[offset]) + : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset], + dout[offset]); + if (SameShapeOfIntermediateOutAndOut) { + d_intermediate[tmp_out_idx] = tmp; + } else { + inter_val += tmp; + } + } i += ELEMWISE_MAX_BLOCK_DIM; } while (i < h); + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; if (BcastY) { if (dy) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; val = paddle::platform::reduceSum(val, tid, h); if (threadIdx.x == 0) { dy[j] = val; @@ -1206,41 +1283,49 @@ static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( } } else { if (dx) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? 
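The broadcast-2 variants index a flattened (pre, n, post) view of the tensors, with the broadcast operand collapsing to the middle index. A short sketch of the offset arithmetic these kernels share (helper name is hypothetical; the formula follows the surrounding elementwise kernels):

#include <cstdint>

// Linear offset of element (i, j, k) in a (pre, n, post) view; the
// broadcast partner of that element is simply index j.
inline int64_t Broadcast2Offset(int i, int j, int k, int n, int post) {
  return (static_cast<int64_t>(i) * n + j) * post + k;
}
// e.g. with n = 3 and post = 4, element (i=1, j=2, k=3) sits at
// (1 * 3 + 2) * 4 + 3 = 23.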
ELEMWISE_MAX_BLOCK_DIM : h; val = paddle::platform::reduceSum(val, tid, h); if (threadIdx.x == 0) { dx[j] = val; } } } + if (!SameShapeOfIntermediateOutAndOut) { + if (d_intermediate) { + inter_val = paddle::platform::reduceSum(inter_val, tid, h); + if (threadIdx.x == 0) { + d_intermediate[j] = inter_val; + } + } + } } -template -static void FusedElemwiseAndActGradBroadcast1CUDA(cudaStream_t stream, - const T *x, const T *y, - const T *intermediate_out, - const T *out, const T *dout, - int h, int w, DX_OP dx_op, - DY_OP dy_op, T *dx, T *dy) { +template +static void FusedElemwiseAndActGradBroadcast1CUDA( + cudaStream_t stream, const T *x, const T *y, const T *intermediate_out, + const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); int gird_size = w; FusedElemwiseAndActGradBroadcast1CUDAKernel< - T, DX_OP, DY_OP, UseIntermediateOut, BcastY, + T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY, SameShapeOfIntermediateOutAndOut><<>>( - x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dx, dy); + x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dintermediate_op, + dx, dy, d_intermediate); } -template +template static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( const T *x, const T *y, const T *intermediate_out, const T *out, - const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, T *dx, - T *dy) { + const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op, T *dx, T *dy, T *d_intermediate) { int tid = threadIdx.x; int j = blockIdx.x; - T val(0); + T val(0), inter_val(0); int ttid = tid; int64_t tmp_out_idx, x_idx, y_idx; while (true) { @@ -1259,10 +1344,12 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( } if (dx != nullptr) { - T tmp = UseIntermediateOut - ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = + UseIntermediateOut + ? dx_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); if (BcastY) { dx[x_idx] = tmp; @@ -1271,24 +1358,38 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( } } if (dy != nullptr) { - T tmp = UseIntermediateOut - ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], - out[offset], dout[offset]) - : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + T tmp = + UseIntermediateOut + ? dy_op.UseIntermediateOut(x[x_idx], y[y_idx], + intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op.Recompute(x[x_idx], y[y_idx], out[offset], dout[offset]); if (BcastY) { val += tmp; } else { dy[y_idx] = tmp; } } - + if (d_intermediate != nullptr) { + T tmp = UseIntermediateOut + ? dintermediate_op.UseIntermediateOut( + y[y_idx], intermediate_out[tmp_out_idx], out[offset], + dout[offset]) + : dintermediate_op.Recompute(x[x_idx], y[y_idx], out[offset], + dout[offset]); + if (SameShapeOfIntermediateOutAndOut) { + d_intermediate[tmp_out_idx] = tmp; + } else { + inter_val += tmp; + } + } ttid += ELEMWISE_MAX_BLOCK_DIM; } + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; if (BcastY) { if (dy) { - int h = pre * post; - h = h > ELEMWISE_MAX_BLOCK_DIM ? 
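On the GPU, the same reduction is performed per block: each thread keeps a partial sum (val and inter_val above), paddle::platform::reduceSum combines them across the block, and thread 0 writes the result. A naive standalone version of that accumulate-then-reduce shape, with a shared-memory tree reduction standing in for reduceSum (assumes blockDim.x is a power of two no larger than kMaxBlockDim):

#ifdef __NVCC__
constexpr int kMaxBlockDim = 256;  // stand-in for ELEMWISE_MAX_BLOCK_DIM

// One block per output column j; threads stride over the h rows.
__global__ void ColumnSumKernel(const float* tmp, int h, int w, float* dy) {
  __shared__ float buf[kMaxBlockDim];
  int j = blockIdx.x;
  float val = 0;
  for (int i = threadIdx.x; i < h; i += blockDim.x) {
    val += tmp[i * w + j];  // per-thread partial sum
  }
  buf[threadIdx.x] = val;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // tree reduction
    if (threadIdx.x < s) buf[threadIdx.x] += buf[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) dy[j] = buf[0];  // thread 0 publishes the sum
}
#endif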
ELEMWISE_MAX_BLOCK_DIM : h; val = paddle::platform::reduceSum(val, tid, h); if (threadIdx.x == 0) { dy[j] = val; @@ -1296,40 +1397,51 @@ static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( } } else { if (dx) { - int h = pre * post; - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; val = paddle::platform::reduceSum(val, tid, h); if (threadIdx.x == 0) { dx[j] = val; } } } + if (!SameShapeOfIntermediateOutAndOut) { + if (d_intermediate) { + inter_val = paddle::platform::reduceSum(inter_val, tid, h); + if (threadIdx.x == 0) { + d_intermediate[j] = inter_val; + } + } + } } -template +template static void FusedElemwiseAndActGradBroadcast2CUDA( cudaStream_t stream, const T *x, const T *y, const T *intermediate_out, const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op, - DY_OP dy_op, T *dx, T *dy) { + DY_OP dy_op, DIntermediate_OP dintermediate_op, T *dx, T *dy, + T *dintermediate) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); int gird_size = n; FusedElemwiseAndActGradBroadcast2CUDAKernel< - T, DX_OP, DY_OP, UseIntermediateOut, BcastY, + T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, BcastY, SameShapeOfIntermediateOutAndOut><<>>( - x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op, dx, dy); + x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op, + dintermediate_op, dx, dy, dintermediate); } #endif template void FusedElemwiseAndActGradComputeWithBroadcast( const framework::ExecutionContext &ctx, const framework::DDim &x_dim, const framework::DDim &y_dim_untrimed, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *intermediate_out, const framework::Tensor *out, const framework::Tensor *dout, int axis, - framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { + framework::Tensor *dx, framework::Tensor *dy, + framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op) { axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); axis = (y_dim.size() == 0) ? x_dim.size() : axis; @@ -1341,70 +1453,82 @@ void FusedElemwiseAndActGradComputeWithBroadcast( int w = n; if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef __NVCC__ - FusedElemwiseAndActGradBroadcast1CUDA( ctx.template device_context().stream(), x->data(), y->data(), intermediate_out == nullptr ? nullptr : intermediate_out->data(), - out->data(), dout->data(), h, w, dx_op, dy_op, + out->data(), dout->data(), h, w, dx_op, dy_op, dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), + dintermediate == nullptr ? nullptr : dintermediate->mutable_data( + ctx.GetPlace())); #endif } else { - FusedElemwiseAndActGradBroadcast1CPU( x->data(), y->data(), intermediate_out == nullptr ? nullptr : intermediate_out->data(), - out->data(), dout->data(), h, w, dx_op, dy_op, + out->data(), dout->data(), h, w, dx_op, dy_op, dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), + dintermediate == nullptr ? 
nullptr : dintermediate->mutable_data( + ctx.GetPlace())); } } else { if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef __NVCC__ - FusedElemwiseAndActGradBroadcast2CUDA( ctx.template device_context().stream(), x->data(), y->data(), intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), pre, n, post, dx_op, dy_op, + dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), + dintermediate == nullptr ? nullptr : dintermediate->mutable_data( + ctx.GetPlace())); #endif } else { - FusedElemwiseAndActGradBroadcast2CPU( x->data(), y->data(), intermediate_out == nullptr ? nullptr : intermediate_out->data(), out->data(), dout->data(), pre, n, post, dx_op, dy_op, + dintermediate_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()), + dintermediate == nullptr ? nullptr : dintermediate->mutable_data( + ctx.GetPlace())); } } } template + typename DIntermediate_OP, bool UseIntermediateOut, + bool SameShapeOfIntermediateOutAndOut> void FusedElemwiseAndActGradComputeEx( const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, const framework::Tensor *intermediate_out, const framework::Tensor *dout, - int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, - DY_OP dy_op) { + int axis, framework::Tensor *dx, framework::Tensor *dy, + framework::Tensor *dintermediate, DX_OP dx_op, DY_OP dy_op, + DIntermediate_OP dintermediate_op) { const framework::DDim &x_dim = x->dims(); const framework::DDim &y_dim = y->dims(); if (UseIntermediateOut) { PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr"); } if (x_dim == y_dim) { - FusedElemwiseAndActGradComputeNoBroadcast( + FusedElemwiseAndActGradComputeNoBroadcast< + DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut>( ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy, - dx_op, dy_op); + dintermediate, dx_op, dy_op, dintermediate_op); } else { // Y is a scalar bool bcast_y = x_dim.size() >= y_dim.size(); if (x_dim.size() == y_dim.size()) { @@ -1420,16 +1544,16 @@ void FusedElemwiseAndActGradComputeEx( // z = f1(f2(x, y)) if (bcast_y) { // Y should be broadcast. 
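The (pre, n, post) factorization consumed by the broadcast code above can be sketched as follows, assuming y's shape is a contiguous slice of x's shape starting at axis. GetMidDims here is a simplified, hypothetical version of the helper the elementwise kernels use; in the surrounding code, the degenerate post == 1 case collapses to the (h, w) broadcast-1 path:

#include <cstdint>
#include <vector>

void GetMidDims(const std::vector<int64_t>& x_dims,
                const std::vector<int64_t>& y_dims, int axis,
                int* pre, int* n, int* post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) *n *= y_dims[i];
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    *post *= x_dims[i];
}
// e.g. x = {8, 16, 32}, y = {16}, axis = 1  ->  pre = 8, n = 16, post = 32.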
FusedElemwiseAndActGradComputeWithBroadcast< - DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, true /*BcastY*/, - SameShapeOfIntermediateOutAndOut>(ctx, x_dim, y_dim, x, y, - intermediate_out, out, dout, axis, - dx, dy, dx_op, dy_op); + DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, + true /*BcastY*/, SameShapeOfIntermediateOutAndOut>( + ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy, + dintermediate, dx_op, dy_op, dintermediate_op); } else { FusedElemwiseAndActGradComputeWithBroadcast< - DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, false /*BcastY*/, - SameShapeOfIntermediateOutAndOut>(ctx, y_dim, x_dim, x, y, - intermediate_out, out, dout, axis, - dx, dy, dx_op, dy_op); + DeviceContext, T, DX_OP, DY_OP, DIntermediate_OP, UseIntermediateOut, + false /*BcastY*/, SameShapeOfIntermediateOutAndOut>( + ctx, y_dim, x_dim, x, y, intermediate_out, out, dout, axis, dx, dy, + dintermediate, dx_op, dy_op, dintermediate_op); } } } @@ -1444,7 +1568,7 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx, framework::Tensor *intermediate_out) { if (KeepIntermediateOut) { PADDLE_ENFORCE(intermediate_out, - "The keep_intermediate_value is opened, " + "The save_intermediate_out is opened, " "intermediate_out should not be nullptr."); } diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused_elemwise_activation_op.cc index b54f0091b..d88ef1594 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused_elemwise_activation_op.cc @@ -13,18 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fused_elemwise_activation_op.h" -#include -#include namespace paddle { namespace operators { -/* - * Whether the compound function is Unary(Binary(X, Y)). - * For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final - * out. - */ -static bool IsUnaryCompound(const std::vector &functor_list) { +bool IsUnaryCompound(const std::vector &functor_list) { PADDLE_ENFORCE_EQ(functor_list.size(), 2); static std::unordered_set binary_fun = { "elementwise_add", "elementwise_mul", "elementwise_add_grad", @@ -32,10 +25,17 @@ static bool IsUnaryCompound(const std::vector &functor_list) { return binary_fun.count(functor_list[1]) != 0; } -/* - * Whether the Input(X) could be absent. - */ -static bool InputXCanBeAbsent(const std::vector &functor_list) { +bool HasInPlaceUnary(const std::vector &functor_list) { + PADDLE_ENFORCE_EQ(functor_list.size(), 2); + static std::unordered_set InplaceOpSet = {"relu", "relu_grad"}; + bool is_in_place = false; + for (auto &func_name : functor_list) { + is_in_place |= (InplaceOpSet.count(func_name) == 1); + } + return is_in_place; +} + +bool InputXCanBeAbsent(const std::vector &functor_list) { PADDLE_ENFORCE_EQ(functor_list.size(), 2); static std::unordered_set binary_fun = {"elementwise_add_grad"}; return binary_fun.count(functor_list[0]) != 0 || @@ -86,20 +86,12 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { // Whether the shape of Y is a continuous subsequence of X, // For more information please refer to the op's introduction. - bool bcast_y = x_dim.size() >= y_dim.size(); - if (x_dim.size() == y_dim.size()) { - for (int i = 0; i < x_dim.size(); ++i) { - if (x_dim[i] < y_dim[i]) { - bcast_y = false; - break; - } - } - } + bool bcast_y = IsBcastY(x_dim, y_dim); auto &out_dim = bcast_y ? 
x_dim : y_dim; std::string out_lod = bcast_y ? "X" : "Y"; - if (ctx->Attrs().Get("keep_intermediate_value")) { + if (ctx->Attrs().Get("save_intermediate_out")) { PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"), "Output(IntermediateOut) of FusedElemwiseActivationOp " "should not be null."); @@ -123,6 +115,20 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ctx->ShareLoD(out_lod, /*->*/ "Out"); } + static bool IsBcastY(const framework::DDim &x_dim, + const framework::DDim &y_dim) { + bool bcast_y = x_dim.size() >= y_dim.size(); + if (x_dim.size() == y_dim.size()) { + for (int i = 0; i < x_dim.size(); ++i) { + if (x_dim[i] < y_dim[i]) { + bcast_y = false; + break; + } + } + } + return bcast_y; + } + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -157,17 +163,7 @@ class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker { AddAttr("scale", "scale is used by scale_op, the default value is 0.0.") .SetDefault(0.0); - AddAttr( - "recomputation", - "Whether to recompute the Out." - "The computation of fused_elemwise_activation_grad has two methods to " - "get the dx and dy, one is to use the 'Out', and the other is not. " - "The former method will save the time of recomputing the 'Out', but it " - "must occupy the memory to store the 'out'. While, the later method " - "can avoid occupying the memory, but it must recompute the 'Out'. " - "It is useful for Unary(Binary(X, Y)). The default value is true.") - .SetDefault(true); - AddAttr("keep_intermediate_value", + AddAttr("save_intermediate_out", "Whether to save the intermediate_out.") .SetDefault(false); AddAttr>("functor_list", @@ -227,30 +223,38 @@ class FusedElemwiseActivationGradMaker protected: std::unique_ptr Apply() const override { - auto *op_desc_ptr = new framework::OpDesc(); - op_desc_ptr->SetType(this->ForwardOpType() + "_grad"); + auto *grad_op = new framework::OpDesc(); + grad_op->SetType(this->ForwardOpType() + "_grad"); for (auto &input_param : this->InputNames()) { - op_desc_ptr->SetInput(input_param, this->Input(input_param)); - op_desc_ptr->SetOutput(framework::GradVarName(input_param), - this->InputGrad(input_param, true)); + grad_op->SetInput(input_param, this->Input(input_param)); + grad_op->SetOutput(framework::GradVarName(input_param), + this->InputGrad(input_param, true)); } - for (auto &output_param : this->OutputNames()) { - op_desc_ptr->SetInput(output_param, this->Output(output_param)); - op_desc_ptr->SetInput(framework::GradVarName(output_param), - this->OutputGrad(output_param)); - } + grad_op->SetInput("Out", this->Output("Out")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op_desc_ptr->SetAttrMap(this->Attrs()); + grad_op->SetAttrMap(this->Attrs()); std::vector functor_names = - boost::get>( - op_desc_ptr->GetAttr("functor_list")); + boost::get>(grad_op->GetAttr("functor_list")); + functor_names[0] += "_grad"; functor_names[1] += "_grad"; - op_desc_ptr->SetAttr("functor_list", functor_names); - return std::unique_ptr(op_desc_ptr); + grad_op->SetAttr("functor_list", functor_names); + + if (boost::get(grad_op->GetAttr("save_intermediate_out"))) { + PADDLE_ENFORCE_NE(Output("IntermediateOut").size(), 0); + grad_op->SetInput("IntermediateOut", this->Output("IntermediateOut")); + grad_op->SetOutput(framework::GradVarName("IntermediateOut"), + this->OutputGrad("IntermediateOut")); + } else { + grad_op->SetInput("IntermediateOut", {}); + 
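IsBcastY, factored out above, decides which operand is broadcast by checking whether Y's shape is a contiguous subsequence of X's. A self-contained restatement with a few concrete shape pairs (std::vector<int64_t> stands in for framework::DDim):

#include <cassert>
#include <cstdint>
#include <vector>

static bool IsBcastY(const std::vector<int64_t>& x,
                     const std::vector<int64_t>& y) {
  bool bcast_y = x.size() >= y.size();
  if (x.size() == y.size()) {
    for (size_t i = 0; i < x.size(); ++i) {
      if (x[i] < y[i]) {
        bcast_y = false;
        break;
      }
    }
  }
  return bcast_y;
}

int main() {
  assert(IsBcastY({32, 200}, {200}));      // Y broadcasts over X
  assert(!IsBcastY({200}, {32, 200}));     // X broadcasts over Y instead
  assert(IsBcastY({32, 200}, {32, 200}));  // equal shapes: bcast_y stays true
  return 0;
}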
grad_op->SetOutput(framework::GradVarName("IntermediateOut"), {}); + } + + return std::unique_ptr(grad_op); } }; @@ -261,56 +265,65 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@Grad) should not be null"); - if (ctx->Attrs().Get("keep_intermediate_value")) { + + auto functor_list = + ctx->Attrs().Get>("functor_list"); + + if (ctx->Attrs().Get("save_intermediate_out")) { PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"), "Input(IntermediateOut) should not be null"); } else { - PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1); + if (!InputXCanBeAbsent(functor_list)) { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + } } - auto funtor_list = - ctx->Attrs().Get>("functor_list"); auto x_grad_name = framework::GradVarName("X"); auto y_grad_name = framework::GradVarName("Y"); + auto inter_grad_name = framework::GradVarName("IntermediateOut"); if (ctx->HasOutput(x_grad_name)) { if (ctx->HasInputs("X")) { ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); ctx->ShareLoD("X", x_grad_name); } else { - // Node: If "X" is absence, the shape of Y should be a continuous - // subsequence of X, if not, we could not infer the shape of dx. - // Currently, only when Binary is elementwise_add or elementwise_sub, // the "X" could be absent. - PADDLE_ENFORCE(InputXCanBeAbsent(funtor_list), + PADDLE_ENFORCE(InputXCanBeAbsent(functor_list), "Only when BinaryFunctor is elementwise_add, the 'X' " "could be absent."); - // For Unary(Binary(X, Y)), IntermediateOut should not be empty. - if (IsUnaryCompound(funtor_list)) { - PADDLE_ENFORCE( - ctx->HasInputs("IntermediateOut"), - "If the compound_functor is Unary(Binary(X, Y)) and Binary " - "is elementwise_add, the intermediate_out must be not absent."); - } + // Node: If "X" is absence, the shape of Y should be a continuous + // subsequence of X, otherwise, we could not infer the shape of dx. ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(framework::GradVarName("Out"))); ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name); } } + if (ctx->HasOutput(y_grad_name)) { PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y")); ctx->ShareLoD("Y", y_grad_name); } + + if (ctx->HasOutput(inter_grad_name)) { + // For Unary(Binary(X, Y)), IntermediateOut should not be empty. + if (IsUnaryCompound(functor_list)) { + ctx->SetOutputDim(inter_grad_name, + ctx->GetInputDim(framework::GradVarName("Out"))); + ctx->ShareLoD(framework::GradVarName("Out"), inter_grad_name); + } else { + ctx->SetOutputDim(inter_grad_name, ctx->GetInputDim("Y")); + ctx->ShareLoD("Y", inter_grad_name); + } + } } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); auto input_data_type_index = ctx.Input("Y")->type(); auto input_data_type = framework::ToDataType(input_data_type_index); return framework::OpKernelType(input_data_type, ctx.GetPlace()); diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused_elemwise_activation_op.h index 6321541aa..5ae9aea95 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.h +++ b/paddle/fluid/operators/fused_elemwise_activation_op.h @@ -26,6 +26,24 @@ limitations under the License. 
*/ namespace paddle { namespace operators { +/** + * Whether the compound function is Unary(Binary(X, Y)). + * For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final + * out. + */ +bool IsUnaryCompound(const std::vector &functor_list); + +/** + * For the in-place unary functor, the inputs of op_desc only have Out and + * Out@Grad. + */ +bool HasInPlaceUnary(const std::vector &functor_list); + +/** + * Whether the Input(X) could be absent. + */ +bool InputXCanBeAbsent(const std::vector &functor_list); + template static void RunBinaryCompoundFunctor( @@ -39,7 +57,7 @@ static void RunBinaryCompoundFunctor( paddle::operators::math::BinaryCompoundFunctor compound_func(binary_functor, unary_functor); int axis = ctx.Attr("axis"); - if (ctx.Attr("keep_intermediate_value")) { + if (ctx.Attr("save_intermediate_out")) { FusedElemwiseAndActComputeEx, @@ -71,7 +89,7 @@ static void RunUnaryCompoundFunctors( paddle::operators::math::UnaryCompoundFunctor compound_func(unary_functor, binary_functor); - if (ctx.Attr("keep_intermediate_value")) { + if (ctx.Attr("save_intermediate_out")) { FusedElemwiseAndActComputeEx, @@ -89,7 +107,7 @@ static void RunUnaryCompoundFunctors( } template + typename UnaryFunctor, typename UnaryGradFunctor, bool InPlace> static void RunBinaryCompoundGradFunctors( const framework::ExecutionContext &ctx, const BinaryGradFunctor &binary_grad_functor, @@ -98,7 +116,7 @@ static void RunBinaryCompoundGradFunctors( const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, - framework::Tensor *y_grad) { + framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) { // Z = Binary(X, Unary(Y)) int axis = ctx.Attr("axis"); @@ -107,32 +125,40 @@ static void RunBinaryCompoundGradFunctors( UnaryFunctor>; using BinaryCompoundDyFunctor = paddle::operators::math::BinaryCompoundGradDyFunctor< - T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>; + T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>; + using BinaryCompoundDIntermedaiteOutFunctor = + paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor< + T, BinaryGradFunctor, UnaryFunctor>; if (in_intermediate_out) { FusedElemwiseAndActGradComputeEx< DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor, - true /*UseIntermediateOut*/, + BinaryCompoundDIntermedaiteOutFunctor, true /*UseIntermediateOut*/, false /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, - y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), + y_grad, d_intermediate_out, + BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, - unary_grad_functor)); + unary_grad_functor), + BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor, + unary_functor)); } else { FusedElemwiseAndActGradComputeEx< DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor, - false /*UseIntermediateOut*/, + BinaryCompoundDIntermedaiteOutFunctor, false /*UseIntermediateOut*/, false /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, - y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), + y_grad, d_intermediate_out, + BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, - unary_grad_functor)); + unary_grad_functor), + 
BinaryCompoundDIntermedaiteOutFunctor(binary_grad_functor, + unary_functor)); } } template + typename BinaryFunctor, typename BinaryGradFunctor, bool InPlace> static void RunUnaryCompoundGradFunctors( const framework::ExecutionContext &ctx, const UnaryGradFunctor &unary_grad_functor, @@ -141,36 +167,44 @@ static void RunUnaryCompoundGradFunctors( const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, - framework::Tensor *y_grad) { + framework::Tensor *y_grad, framework::Tensor *d_intermediate_out) { // Z = Unary(Binary(X, Y)) int axis = ctx.Attr("axis"); using UnaryCompoundDxFunctor = paddle::operators::math::UnaryCompoundGradDxFunctor< - T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>; + T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>; using UnaryCompoundDyFunctor = paddle::operators::math::UnaryCompoundGradDyFunctor< - T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>; + T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>; + using UnaryCompoundDIntermediateFunctor = + paddle::operators::math::UnaryCompoundGradDIntermediateFunctor< + T, UnaryGradFunctor, BinaryFunctor, InPlace>; if (in_intermediate_out) { FusedElemwiseAndActGradComputeEx< DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor, - true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>( + UnaryCompoundDIntermediateFunctor, true /*UseIntermediateOut*/, + true /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, - y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, - binary_grad_functor), + y_grad, d_intermediate_out, + UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, + binary_grad_functor), UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, - binary_grad_functor)); + binary_grad_functor), + UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor)); } else { - FusedElemwiseAndActGradComputeEx( + FusedElemwiseAndActGradComputeEx< + DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor, + UnaryCompoundDIntermediateFunctor, false /*UseIntermediateOut*/, + true /*SameShapeOfIntermediateOutAndOut*/>( ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, - y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, - binary_grad_functor), + y_grad, d_intermediate_out, + UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, + binary_grad_functor), UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, - binary_grad_functor)); + binary_grad_functor), + UnaryCompoundDIntermediateFunctor(unary_grad_functor, binary_functor)); } } @@ -226,72 +260,67 @@ static void RunFunctors(const framework::ExecutionContext &ctx, } } -template -static void RunGradFunctors(const framework::ExecutionContext &ctx, - const framework::Tensor *in_x, - const framework::Tensor *in_y, - const framework::Tensor *in_out, - const framework::Tensor *in_intermediate_out, - const framework::Tensor *in_out_grad, - framework::Tensor *x_grad, - framework::Tensor *y_grad) { +template +static void RunGradFunctors( + const framework::ExecutionContext &ctx, const framework::Tensor *in_x, + const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_intermediate_out, + const framework::Tensor *in_out_grad, framework::Tensor *x_grad, + framework::Tensor *y_grad, framework::Tensor 
*d_intermediate_out) { auto &functors = ctx.Attr>("functor_list"); auto funcs_str = functors[0] + "," + functors[1]; - // TODO(zcd): The following code can be refined. for example, use registrition if (funcs_str == "elementwise_add_grad,scale_grad") { // The backward of Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); - RunBinaryCompoundGradFunctors, - paddle::operators::math::ScaleFunctor, - paddle::operators::math::ScaleGradFunctor>( + RunBinaryCompoundGradFunctors< + DeviceContext, T, paddle::operators::math::AddGradFunctor, + paddle::operators::math::ScaleFunctor, + paddle::operators::math::ScaleGradFunctor, InPlace>( ctx, paddle::operators::math::AddGradFunctor(), paddle::operators::math::ScaleFunctor(scale), paddle::operators::math::ScaleGradFunctor(scale), in_x, in_y, in_out, - in_intermediate_out, in_out_grad, x_grad, y_grad); + in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else if (funcs_str == "scale_grad,elementwise_add_grad") { // The backward of Z = Unary(Binary(X, Y)) T scale = static_cast(ctx.Attr("scale")); - RunUnaryCompoundGradFunctors, - paddle::operators::math::AddFunctor, - paddle::operators::math::AddGradFunctor, - ReComputation /*Recomputation*/>( + RunUnaryCompoundGradFunctors< + DeviceContext, T, paddle::operators::math::ScaleGradFunctor, + paddle::operators::math::AddFunctor, + paddle::operators::math::AddGradFunctor, InPlace>( ctx, paddle::operators::math::ScaleGradFunctor(scale), paddle::operators::math::AddFunctor(), paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, - in_intermediate_out, in_out_grad, x_grad, y_grad); + in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else if (funcs_str == "elementwise_add_grad,relu_grad") { - RunBinaryCompoundGradFunctors, - paddle::operators::math::ReluFunctor, - paddle::operators::math::ReluGradFunctor>( + RunBinaryCompoundGradFunctors< + DeviceContext, T, paddle::operators::math::AddGradFunctor, + paddle::operators::math::ReluFunctor, + paddle::operators::math::ReluGradFunctor, InPlace>( ctx, paddle::operators::math::AddGradFunctor(), paddle::operators::math::ReluFunctor(), paddle::operators::math::ReluGradFunctor(), in_x, in_y, in_out, - in_intermediate_out, in_out_grad, x_grad, y_grad); + in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else if (funcs_str == "relu_grad,elementwise_add_grad") { - RunUnaryCompoundGradFunctors, - paddle::operators::math::AddFunctor, - paddle::operators::math::AddGradFunctor, - ReComputation /*Recomputation*/>( + RunUnaryCompoundGradFunctors< + DeviceContext, T, paddle::operators::math::ReluGradFunctor, + paddle::operators::math::AddFunctor, + paddle::operators::math::AddGradFunctor, InPlace>( ctx, paddle::operators::math::ReluGradFunctor(), paddle::operators::math::AddFunctor(), paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, - in_intermediate_out, in_out_grad, x_grad, y_grad); + in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else if (funcs_str == "elementwise_mul_grad,scale_grad") { // The backward of Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); - RunBinaryCompoundGradFunctors, - paddle::operators::math::ScaleFunctor, - paddle::operators::math::ScaleGradFunctor>( + RunBinaryCompoundGradFunctors< + DeviceContext, T, paddle::operators::math::MulGradFunctor, + paddle::operators::math::ScaleFunctor, + paddle::operators::math::ScaleGradFunctor, InPlace>( ctx, paddle::operators::math::MulGradFunctor(), 
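As a worked example of the string-keyed dispatch above, take the first branch, the backward of Z = X + scale * Y (forward functor_list = {"elementwise_add", "scale"}). By the chain rule, with intermediate = scale * Y:

// dX               = dZ * dAdd/dx         = dZ
// dY               = dZ * dAdd/dy * scale = scale * dZ
// dIntermediateOut = dZ * dAdd/dy         = dZ
// Scalar form of the three gradients (illustrative struct, not the
// real compound functors):
struct AddScaleGrad {
  float scale;
  float Dx(float dout) const { return dout; }
  float Dy(float dout) const { return scale * dout; }
  float DIntermediate(float dout) const { return dout; }
};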
paddle::operators::math::ScaleFunctor(scale), paddle::operators::math::ScaleGradFunctor(scale), in_x, in_y, in_out, - in_intermediate_out, in_out_grad, x_grad, y_grad); + in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); } else { PADDLE_THROW("%s has not been implemented.", funcs_str); } @@ -313,9 +342,9 @@ class FusedElemwiseActivationKernel : public framework::OpKernel { std::vector outputs; outputs.emplace_back(output); - if (ctx.Attr("keep_intermediate_value")) { + if (ctx.Attr("save_intermediate_out")) { PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"), - "The keep_intermediate_value is enable, so the " + "The save_intermediate_out is enable, so the " "IntermediateOut should not be empty."); auto intermediate_out = ctx.Output("IntermediateOut"); outputs.emplace_back(intermediate_out); @@ -331,65 +360,63 @@ template class FusedElemwiseActivationGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - auto x = ctx.Input("X"); - auto y = ctx.Input("Y"); - + auto in_y = ctx.Input("Y"); + PADDLE_ENFORCE(in_y != nullptr, "Input(Y) should not be nullptr."); auto in_out = ctx.Input("Out"); + PADDLE_ENFORCE(in_out != nullptr, "Input(Out) should not be nullptr."); auto in_out_grad = ctx.Input(framework::GradVarName("Out")); - + PADDLE_ENFORCE(in_out_grad != nullptr, + "Input(Out@Grad) should not be nullptr."); + framework::Tensor *in_x = + const_cast(ctx.Input("X")); framework::Tensor *x_grad = ctx.Output(framework::GradVarName("X")); framework::Tensor *y_grad = ctx.Output(framework::GradVarName("Y")); + framework::Tensor *d_intermediate_out = ctx.Output( + framework::GradVarName("IntermediateOut")); - PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr."); - - if (ctx.Attr("recomputation")) { - PADDLE_ENFORCE( - x != nullptr, - "The recomputation is opened, so Input(X) should not be absent."); - } else { - PADDLE_ENFORCE(in_out != nullptr, - "The recomputation is disabled, so the Input('Out') " - "should not be empty."); - } - - framework::Tensor *in_x; auto functor_list = ctx.Attr>("functor_list"); - // If functor_list contains elementwise_add, the backward doesn't use - // in_x, and in_outs. - if (x == nullptr) { - PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" || - functor_list[1] == "elementwise_add_grad", - "Only when the compoundfunctor contains " - "elementwise_add_grad, the 'X' could be absent."); - in_x = const_cast(in_out_grad); - in_out = const_cast(in_out_grad); - } else { - in_x = const_cast(x); - } - - framework::Tensor *in_intermediate_out; - if (ctx.Attr("keep_intermediate_value")) { + // Get intermediate_out + framework::Tensor *in_intermediate_out = nullptr; + if (ctx.Attr("save_intermediate_out")) { + // if save_intermediate_out is true, for Unary(Binary(x, y)) and + // Binary(x, Unary(y)), the Binary(x, y) and Unary(y) not need to + // recompute. 
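A note on the absent-X case handled below: InputXCanBeAbsent() only admits elementwise_add_grad, whose partial derivatives are constant one, so no gradient formula actually reads x; the kernel aliases in_x to in_out_grad purely to satisfy the pointer-based interface. For example, for Z = relu(X + Y), relu' can be read off Out alone:

// Shared scalar value of dX and dY for Z = relu(X + Y); since
// Out = relu(X + Y), out > 0 iff x + y > 0, and x itself is unused.
float ReluAddGrad(float out, float dout) {
  return out > 0 ? dout : 0.0f;
}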
in_intermediate_out = const_cast( ctx.Input("IntermediateOut")); PADDLE_ENFORCE(in_intermediate_out != nullptr, - "The option of 'keep_intermediate_value' is opened, " + "The option of 'save_intermediate_out' is opened, " "so the number of 'Out' should be two."); } else { - in_intermediate_out = nullptr; + if (!InputXCanBeAbsent(functor_list)) { + PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be null."); + } + } + + // Get in_x + if (ctx.HasInput("X")) { + PADDLE_ENFORCE(in_x != nullptr, "Input(X) should not be nullptr."); + } else { + // If functor_list contains elementwise_add, the backward doesn't use + // in_x, in_y and in_out. + PADDLE_ENFORCE(InputXCanBeAbsent(functor_list), + "Only when the compoundfunctor contains " + "elementwise_add_grad, the 'X' could be absent."); + in_x = const_cast(in_out_grad); } - if (ctx.Attr("recomputation")) { - RunGradFunctors( - ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad, - y_grad); + bool has_in_place = HasInPlaceUnary(functor_list); + if (has_in_place) { + RunGradFunctors( + ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, + y_grad, d_intermediate_out); } else { - RunGradFunctors( - ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad, - y_grad); + RunGradFunctors( + ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, x_grad, + y_grad, d_intermediate_out); } } }; diff --git a/paddle/fluid/operators/math/compound_functors.h b/paddle/fluid/operators/math/compound_functors.h index 1d32a9585..7aba4a917 100644 --- a/paddle/fluid/operators/math/compound_functors.h +++ b/paddle/fluid/operators/math/compound_functors.h @@ -22,11 +22,11 @@ namespace paddle { namespace operators { namespace math { +// Z = BinaryFunctor(X, UnaryFunctor(Y)) template struct BinaryCompoundFunctor { BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2) : func1_(func1), func2_(func2) {} - // Z = BinaryFunctor(X, UnaryFunctor(Y)) inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); } @@ -40,11 +40,11 @@ struct BinaryCompoundFunctor { UnaryFunctor func2_; }; +// Z = UnaryFunctor(BinaryFunctor(X, Y)) template struct UnaryCompoundFunctor { UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2) : func1_(func1), func2_(func2) {} - // Z = UnaryFunctor(BinaryFunctor(X, Y)) inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); } @@ -58,23 +58,19 @@ struct UnaryCompoundFunctor { BinaryFunctor func2_; }; -// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get -// the dx, one is to use the 'out', and the other is not to use it. -// the former method will save the time of recomputing the -// 'out', but it must occupy the memory to store the 'out'. -// While the later method can avoid occupying this memory, -// but it must recompute the 'out'. 
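The InPlace flag threaded through the compound functors below exists because an in-place activation overwrites its input buffer with its output, so the gradient must be computable from the out-side value. A sketch of the distinction for dY of Z = X + relu(Y), where the types are illustrative stand-ins for ReluGradFunctor and BinaryCompoundGradDyFunctor:

struct ReluGrad {
  float UseX(float x) const { return x > 0 ? 1.0f : 0.0f; }
  float UseOut(float out) const { return out > 0 ? 1.0f : 0.0f; }
  float UseXAndOut(float x, float out) const { return out > 0 ? 1.0f : 0.0f; }
};

template <bool InPlace>
float DyOfAddRelu(const ReluGrad& g, float y, float intermediate, float dout) {
  // dY = dZ * relu'(Y). With InPlace, Y's buffer now holds
  // intermediate = relu(Y), so relu' must be derived from it.
  return InPlace ? dout * g.UseOut(intermediate)
                 : dout * g.UseXAndOut(y, intermediate);
}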
+// Z = BinaryFunctor(X, UnaryFunctor(Y)) template struct BinaryCompoundGradDxFunctor { BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun, const UnaryFun &unary_fun) : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {} - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) { return dout * d_binary_fun_.Dx(x, unary_fun_(y)); } - inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, + T dout) { return dout * d_binary_fun_.Dx(x, intermediate_out); } @@ -83,8 +79,9 @@ struct BinaryCompoundGradDxFunctor { UnaryFun unary_fun_; }; +// Z = BinaryFunctor(X, UnaryFunctor(Y)) template + typename DUnaryFun, bool InPlace> struct BinaryCompoundGradDyFunctor { BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun, const UnaryFun &unary_fun, @@ -93,13 +90,19 @@ struct BinaryCompoundGradDyFunctor { unary_fun_(unary_fun), d_unary_fun_(d_unary_fun) {} - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { - return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_(y); + inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) { + return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y); } - inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { - return dout * d_binary_fun_.Dy(x, intermediate_out) * - d_unary_fun_(y, intermediate_out); + inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, + T dout) { + if (InPlace) { + return dout * d_binary_fun_.Dy(x, intermediate_out) * + d_unary_fun_.UseOut(intermediate_out); + } else { + return dout * d_binary_fun_.Dy(x, intermediate_out) * + d_unary_fun_.UseXAndOut(y, intermediate_out); + } } private: @@ -108,8 +111,9 @@ struct BinaryCompoundGradDyFunctor { DUnaryFun d_unary_fun_; }; +// Z = UnaryFunctor(BinaryFunctor(X, Y)) template + typename DBinaryFun, bool InPlace> struct UnaryCompoundGradDxFunctor { UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun, const BinaryFun &binary_fun, @@ -118,22 +122,23 @@ struct UnaryCompoundGradDxFunctor { binary_fun_(binary_fun), d_binary_fun_(d_binary_fun) {} - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) { T base; - if (Recomputation) { - base = dout * d_unary_fun_(binary_fun_(x, y)); + if (InPlace) { + base = dout * d_unary_fun_.UseOut(out); } else { - base = dout * d_unary_fun_(binary_fun_(x, y), out); + base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out); } return base * d_binary_fun_.Dx(x, y); } - inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, + T dout) { T base; - if (Recomputation) { - base = dout * d_unary_fun_(intermediate_out); + if (InPlace) { + base = dout * d_unary_fun_.UseOut(out); } else { - base = dout * d_unary_fun_(intermediate_out, out); + base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out); } return base * d_binary_fun_.Dx(x, y); } @@ -144,8 +149,9 @@ struct UnaryCompoundGradDxFunctor { DBinaryFun d_binary_fun_; }; +// Z = UnaryFunctor(BinaryFunctor(X, Y)) template + typename DBinaryFun, bool InPlace> struct UnaryCompoundGradDyFunctor { UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun, const BinaryFun &binary_fun, @@ -154,22 +160,23 @@ struct UnaryCompoundGradDyFunctor { binary_fun_(binary_fun), d_binary_fun_(d_binary_fun) {} - inline HOSTDEVICE T operator()(T 
x, T y, T out, T dout) { + inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) { T base; - if (Recomputation) { - base = dout * d_unary_fun_(binary_fun_(x, y)); + if (InPlace) { + base = dout * d_unary_fun_.UseOut(out); } else { - base = dout * d_unary_fun_(binary_fun_(x, y), out); + base = dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out); } return base * d_binary_fun_.Dy(x, y); } - inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, + T dout) { T base; - if (Recomputation) { - base = dout * d_unary_fun_(intermediate_out); + if (InPlace) { + base = dout * d_unary_fun_.UseOut(out); } else { - base = dout * d_unary_fun_(intermediate_out, out); + base = dout * d_unary_fun_.UseXAndOut(intermediate_out, out); } return base * d_binary_fun_.Dy(x, y); } @@ -180,6 +187,56 @@ struct UnaryCompoundGradDyFunctor { DBinaryFun d_binary_fun_; }; +// Z = BinaryFunctor(X, UnaryFunctor(Y)) +template +struct BinaryCompoundGradDIntermedaiteOutFunctor { + BinaryCompoundGradDIntermedaiteOutFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun) + : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {} + + inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) { + return dout * d_binary_fun_.Dy(x, unary_fun_(y)); + } + + inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out, + T dout) { + return dout * d_binary_fun_.Dy(x, intermediate_out); + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; +}; + +// Z = UnaryFunctor(BinaryFunctor(X, Y)) +template +struct UnaryCompoundGradDIntermediateFunctor { + UnaryCompoundGradDIntermediateFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun) + : d_unary_fun_(d_unary_fun), binary_fun_(binary_fun) {} + + inline HOSTDEVICE T Recompute(T x, T y, T out, T dout) { + if (InPlace) { + return dout * d_unary_fun_.UseOut(out); + } else { + return dout * d_unary_fun_.UseXAndOut(binary_fun_(x, y), out); + } + } + + inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out, + T dout) { + if (InPlace) { + return dout * d_unary_fun_.UseOut(out); + } else { + return dout * d_unary_fun_.UseXAndOut(intermediate_out, out); + } + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/functors.h b/paddle/fluid/operators/math/functors.h index ddb01cdfc..955c0b6ba 100644 --- a/paddle/fluid/operators/math/functors.h +++ b/paddle/fluid/operators/math/functors.h @@ -58,9 +58,9 @@ template struct ScaleGradFunctor { explicit ScaleGradFunctor(T coeff) : coeff_(coeff) {} - inline HOSTDEVICE T operator()(T x) { return coeff_; } - - inline HOSTDEVICE T operator()(T x, T out) { return coeff_; } + inline HOSTDEVICE T UseX(T x) { return coeff_; } + inline HOSTDEVICE T UseOut(T out) { return coeff_; } + inline HOSTDEVICE T UseXAndOut(T x, T out) { return coeff_; } private: T coeff_; @@ -73,9 +73,9 @@ struct ReluFunctor { template struct ReluGradFunctor { - inline HOSTDEVICE T operator()(T x) { return x > 0 ? 1 : 0; } - - inline HOSTDEVICE T operator()(T x, T out) { return x > 0 ? 1 : 0; } + inline HOSTDEVICE T UseX(T x) { return x > 0 ? 1 : 0; } + inline HOSTDEVICE T UseOut(T out) { return out > 0 ? 1 : 0; } + inline HOSTDEVICE T UseXAndOut(T x, T out) { return out > 0 ? 
1 : 0; }
 };
 
 }  // namespace math
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 1d081f89c..8b62502e3 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -670,7 +670,14 @@ All parameter, weight, gradient are variables in Paddle.
       .def_property(
           "enable_data_balance",
           [](const BuildStrategy &self) { return self.enable_data_balance_; },
-          [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; });
+          [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
+      .def_property("fuse_elewise_add_act_ops",
+                    [](const BuildStrategy &self) {
+                      return self.fuse_elewise_add_act_ops_;
+                    },
+                    [](BuildStrategy &self, bool b) {
+                      self.fuse_elewise_add_act_ops_ = b;
+                    });
 
   pe.def(py::init<const std::vector<platform::Place> &,
                   const std::unordered_set<std::string> &,
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index cefaa3892..e97643cdd 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -47,8 +47,7 @@ def get_numeric_gradient(place,
                          input_to_check,
                          output_names,
                          delta=0.005,
-                         in_place=False,
-                         sum_outputs=None):
+                         in_place=False):
     # FIXME: change this method by compile time concepts
     set_input(scope, op, inputs, place)
 
@@ -59,8 +58,6 @@ def get_numeric_gradient(place,
         sum = []
         op.run(scope, place)
         for output_name in output_names:
-            if sum_outputs and output_name not in sum_outputs:
-                continue
             sum.append(
                 np.array(scope.find_var(output_name).get_tensor()).mean())
         return np.array(sum).sum() / len(output_names)
@@ -407,14 +404,13 @@ class OpTest(unittest.TestCase):
                    numeric_grad_delta=0.005,
                    in_place=False,
                    max_relative_error=0.005,
-                   user_defined_grads=None,
-                   sum_outputs=None):
+                   user_defined_grads=None):
         places = self._get_places()
         for place in places:
             self.check_grad_with_place(place, inputs_to_check, output_names,
                                        no_grad_set, numeric_grad_delta,
                                        in_place, max_relative_error,
-                                       user_defined_grads, sum_outputs)
+                                       user_defined_grads)
 
     def check_grad_with_place(self,
                               place,
@@ -424,8 +420,7 @@ class OpTest(unittest.TestCase):
                               numeric_grad_delta=0.005,
                               in_place=False,
                               max_relative_error=0.005,
-                              user_defined_grads=None,
-                              sum_outputs=None):
+                              user_defined_grads=None):
         self.scope = core.Scope()
         op_inputs = self.inputs if hasattr(self, "inputs") else dict()
         op_outputs = self.outputs if hasattr(self, "outputs") else dict()
@@ -448,8 +443,7 @@ class OpTest(unittest.TestCase):
                 input_to_check,
                 output_names,
                 delta=numeric_grad_delta,
-                in_place=in_place,
-                sum_outputs=sum_outputs) for input_to_check in inputs_to_check
+                in_place=in_place) for input_to_check in inputs_to_check
         ]
         analytic_grads = self._get_gradient(inputs_to_check, place,
                                             output_names, no_grad_set)
diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index 74e9d5c5f..ee291fe74 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -38,6 +38,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                                   seed=None,
                                   use_parallel_executor=True,
                                   use_reduce=False,
+                                  fuse_elewise_add_act_ops=False,
                                   optimizer=fluid.optimizer.Adam,
                                   use_fast_executor=False):
         def run_executor(exe, feed, fetch_list, program=None):
@@ -78,6 +79,7 @@ class TestParallelExecutorBase(unittest.TestCase):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
             if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
+        build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
 
         if use_parallel_executor:
             exe = fluid.ParallelExecutor(
diff --git a/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py b/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py
new file mode 100644
index 000000000..03471a443
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fuse_elewise_add_act_pass.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from parallel_executor_test_base import TestParallelExecutorBase
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import numpy as np
+import paddle
+import paddle.dataset.mnist as mnist
+import unittest
+import os
+
+MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
+
+
+def simple_fc_net(use_feed):
+    if use_feed:
+        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    else:
+        reader = fluid.layers.open_files(
+            filenames=[MNIST_RECORDIO_FILE],
+            shapes=[[-1, 784], [-1, 1]],
+            lod_levels=[0, 0],
+            dtypes=['float32', 'int64'])
+        reader = fluid.layers.io.double_buffer(reader)
+        img, label = fluid.layers.read_file(reader)
+    hidden = img
+    for _ in range(4):
+        hidden = fluid.layers.fc(
+            hidden,
+            size=200,
+            act='relu',
+            bias_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=1.0)))
+    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
+def fc_with_batchnorm(use_feed):
+    if use_feed:
+        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    else:
+        reader = fluid.layers.open_files(
+            filenames=[MNIST_RECORDIO_FILE],
+            shapes=[[-1, 784], [-1, 1]],
+            lod_levels=[0, 0],
+            dtypes=['float32', 'int64'])
+        reader = fluid.layers.io.double_buffer(reader)
+        img, label = fluid.layers.read_file(reader)
+
+    hidden = img
+    for _ in range(2):
+        hidden = fluid.layers.fc(
+            hidden,
+            size=200,
+            act='relu',
+            bias_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=1.0)))
+
+        hidden = fluid.layers.batch_norm(input=hidden)
+
+    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
+class TestMNIST(TestParallelExecutorBase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ['CPU_NUM'] = str(4)
+        # Convert mnist to recordio file
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            reader = paddle.batch(mnist.train(), batch_size=4)
+            feeder = fluid.DataFeeder(
+                feed_list=[  # order is image and label
+                    fluid.layers.data(
+                        name='image', shape=[784]),
+                    fluid.layers.data(
+                        name='label', shape=[1], dtype='int64'),
+                ],
+                place=fluid.CPUPlace())
+            fluid.recordio_writer.convert_reader_to_recordio_file(
+                MNIST_RECORDIO_FILE, reader, feeder)
+
+    def _init_data(self, random=True):
+        np.random.seed(5)
+        if random:
+            img = np.random.random(size=[32, 784]).astype(np.float32)
+        else:
+            img = np.ones(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
+        return img, label
+
+    def _compare_fuse_elewise_add_act_ops(self,
+                                          model,
+                                          use_cuda,
+                                          random_data=True):
+        if use_cuda and not core.is_compiled_with_cuda():
+            return
+        img, label = self._init_data(random_data)
+
+        def _optimizer(learning_rate=1e-6):
+            optimizer = fluid.optimizer.SGD(
+                learning_rate=learning_rate,
+                regularization=fluid.regularizer.L2Decay(1e-6))
+            return optimizer
+
+        not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
+            model,
+            feed_dict={"image": img,
+                       "label": label},
+            use_cuda=use_cuda,
+            fuse_elewise_add_act_ops=False,
+            memory_opt=False,
+            optimizer=_optimizer)
+        fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
+            model,
+            feed_dict={"image": img,
+                       "label": label},
+            use_cuda=use_cuda,
+            fuse_elewise_add_act_ops=True,
+            memory_opt=False,
+            optimizer=_optimizer)
+
+        for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+        for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
+            self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
+
+    def test_simple_fc_with_fuse_op(self):
+        self._compare_fuse_elewise_add_act_ops(simple_fc_net, True)
+        self._compare_fuse_elewise_add_act_ops(simple_fc_net, False)
+
+    def test_batchnorm_fc_with_fuse_op(self):
+        self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, True)
+        self._compare_fuse_elewise_add_act_ops(fc_with_batchnorm, False)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
index 4a213c291..3cf8e7229 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py
@@ -48,7 +48,7 @@ def create_test_class(test_case, callback, attrs):
                 'X': OpTest.np_dtype_to_fluid_dtype(self.x),
                 'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
             }
-            if self.attrs["keep_intermediate_value"]:
+            if self.attrs["save_intermediate_out"]:
                 self.outputs = {
                     'Out': self.out,
                     "IntermediateOut": self.intermediate_out
@@ -73,22 +73,19 @@ def create_test_class(test_case, callback, attrs):
         def test_check_output(self):
             self.check_output()
 
+        # FIXME(zcd): the intermediate_out_grad is not checked.
         def test_check_grad_normal(self):
-            if self.attrs["keep_intermediate_value"]:
-                self.check_grad(
-                    ['X', 'Y'], ['Out', 'IntermediateOut'],
-                    max_relative_error=0.005,
-                    sum_outputs=['Out'])
+            if self.attrs["save_intermediate_out"]:
+                self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
             else:
                 self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
 
         def test_check_grad_ingore_x(self):
-            if self.attrs["keep_intermediate_value"]:
+            if self.attrs["save_intermediate_out"]:
                 self.check_grad(
-                    ['Y'], ['Out', 'IntermediateOut'],
+                    ['Y'], ['Out'],
                     max_relative_error=0.005,
-                    no_grad_set=set("X"),
-                    sum_outputs=['Out'])
+                    no_grad_set=set("X"))
             else:
                 self.check_grad(
                     ['Y'], ['Out'],
@@ -96,12 +93,11 @@ def create_test_class(test_case, callback, attrs):
                     no_grad_set=set("X"))
 
         def test_check_grad_ingore_y(self):
-            if self.attrs["keep_intermediate_value"]:
+            if self.attrs["save_intermediate_out"]:
                 self.check_grad(
-                    ['X'], ['Out', 'IntermediateOut'],
+                    ['X'], ['Out'],
                     max_relative_error=0.005,
-                    no_grad_set=set("Y"),
-                    sum_outputs=['Out'])
+                    no_grad_set=set("Y"))
             else:
                 self.check_grad(
                     ['X'], ['Out'],
@@ -303,39 +299,32 @@ for mode in {0, 1}:
     relu_add_func = partial(relu_add_func, mode=mode)
     add_relu_func = partial(add_relu_func, mode=mode)
 
-    for recomputation in {True, False}:
-        for keep_intermediate_value in {True, False}:
-            suffix = ("_keep_intermediate_value" if keep_intermediate_value else "") \
-                     + ("_recomputation" if recomputation else "") \
-                     + ("_mode_"+ str(mode))
-            create_test_class('scale_add' + suffix, scale_add_func, {
-                'scale': scale,
-                'functor_list': ["scale", "elementwise_add"],
-                'keep_intermediate_value': keep_intermediate_value,
-                'recomputation': recomputation
-            })
-            create_test_class('add_scale' + suffix, add_scale_func, {
-                'scale': scale,
-                'functor_list': ["elementwise_add", "scale"],
-                'keep_intermediate_value': keep_intermediate_value,
-                'recomputation': recomputation
-            })
-            create_test_class('add_relu' + suffix, add_relu_func, {
-                'functor_list': ["elementwise_add", "relu"],
-                'keep_intermediate_value': keep_intermediate_value,
-                'recomputation': recomputation
-            })
-            create_test_class('relu_add' + suffix, relu_add_func, {
-                'functor_list': ["relu", "elementwise_add"],
-                'keep_intermediate_value': keep_intermediate_value,
-                'recomputation': recomputation
-            })
-            create_test_class('mul_scale' + suffix, mul_scale_func, {
-                'scale': scale,
-                'functor_list': ["elementwise_mul", "scale"],
-                'keep_intermediate_value': keep_intermediate_value,
-                'recomputation': recomputation
-            })
+    for save_intermediate_out in {True, False}:
+        suffix = ("_save_intermediate_out" if save_intermediate_out else "") \
+                 + ("_mode_"+ str(mode))
+        create_test_class('scale_add' + suffix, scale_add_func, {
+            'scale': scale,
+            'functor_list': ["scale", "elementwise_add"],
+            'save_intermediate_out': save_intermediate_out,
+        })
+        create_test_class('add_scale' + suffix, add_scale_func, {
+            'scale': scale,
+            'functor_list': ["elementwise_add", "scale"],
+            'save_intermediate_out': save_intermediate_out,
+        })
+        create_test_class('add_relu' + suffix, add_relu_func, {
+            'functor_list': ["elementwise_add", "relu"],
+            'save_intermediate_out': save_intermediate_out,
+        })
+        create_test_class('relu_add' + suffix, relu_add_func, {
+            'functor_list': ["relu", "elementwise_add"],
+            'save_intermediate_out': save_intermediate_out,
+        })
+        create_test_class('mul_scale' + suffix, mul_scale_func, {
+            'scale': scale,
+            'functor_list': ["elementwise_mul", "scale"],
+            'save_intermediate_out': save_intermediate_out,
+        })
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py
index 055759365..7691221a5 100644
--- a/python/paddle/fluid/tests/unittests/test_reshape_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py
@@ -79,7 +79,7 @@ class TestReshapeOpWithInputShape(OpTest):
         self.check_output(no_check_set=['XShape'])
 
     def test_check_grad(self):
-        self.check_grad(["X"], "Out", sum_outputs=["Out"])
+        self.check_grad(["X"], "Out")
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py
index c30da2389..bbcabb751 100644
--- a/python/paddle/fluid/tests/unittests/test_transpose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py
@@ -34,7 +34,7 @@ class TestTransposeOp(OpTest):
         self.check_output(no_check_set=['XShape'])
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', sum_outputs=['Out'])
+        self.check_grad(['X'], 'Out')
 
     def initTestCase(self):
         self.shape = (3, 4)
-- 
GitLab
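
Reviewer note (not part of the patch): below is a minimal sketch of how the new
BuildStrategy.fuse_elewise_add_act_ops flag exposed in pybind.cc above can be
switched on from user code. Only the BuildStrategy/ParallelExecutor wiring is
taken from this patch; the toy network, variable names, and hyperparameters are
illustrative assumptions.

    import os
    import numpy as np
    import paddle.fluid as fluid

    os.environ['CPU_NUM'] = '4'  # ParallelExecutor on CPU needs CPU_NUM set.

    # A toy network: fc(act='relu') emits the elementwise_add + relu pattern
    # that fuse_elewise_add_act_pass rewrites into the fused op.
    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
    hidden = fluid.layers.fc(img, size=200, act='relu')
    loss = fluid.layers.mean(hidden)
    fluid.optimizer.SGD(learning_rate=1e-3).minimize(loss)

    place = fluid.CPUPlace()
    fluid.Executor(place).run(fluid.default_startup_program())

    # Enable the pass through the new flag, mirroring what
    # parallel_executor_test_base.py does in this patch.
    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_elewise_add_act_ops = True

    pe = fluid.ParallelExecutor(
        use_cuda=False, loss_name=loss.name, build_strategy=build_strategy)
    feed = {'image': np.random.random(size=[32, 784]).astype(np.float32)}
    loss_value, = pe.run(feed=feed, fetch_list=[loss.name])

With the flag on, the graph pass fuses the forward add + activation pair (and,
for relu, the in-place backward pair), which the new unit test above verifies
by comparing losses with the flag off and on.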