From 8f3b252392d8bdd75888e3736ca2c948990a30e3 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 21 Jan 2019 19:49:45 +0800 Subject: [PATCH] squash commits. test=develop --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/details/CMakeLists.txt | 9 +- .../fluid/framework/details/build_strategy.cc | 20 +- .../fluid/framework/details/build_strategy.h | 2 + .../framework/details/inplace_op_pass.cc | 375 ++++++++++++++++++ .../fluid/framework/details/inplace_op_pass.h | 74 ++++ .../details/memory_early_delete_pass.cc | 2 +- ...use_types.cc => memory_optimize_helper.cc} | 52 ++- ...reuse_types.h => memory_optimize_helper.h} | 46 ++- ...test.cc => memory_optimize_helper_test.cc} | 6 +- ...is_var_pass.cc => memory_optimize_pass.cc} | 168 +++----- ...ysis_var_pass.h => memory_optimize_pass.h} | 12 +- ...s_test.cc => memory_optimize_pass_test.cc} | 2 +- paddle/fluid/framework/details/op_registry.h | 21 +- paddle/fluid/framework/inplace_op_inference.h | 135 +++++++ .../framework/inplace_op_inference_test.cc | 287 ++++++++++++++ paddle/fluid/framework/ir/node.h | 1 + paddle/fluid/framework/op_info.h | 1 + paddle/fluid/framework/type_defs.h | 3 + paddle/fluid/operators/activation_op.cc | 14 +- paddle/fluid/operators/batch_norm_op.cc | 39 +- .../elementwise/elementwise_add_op.cc | 1 + .../operators/elementwise/elementwise_op.h | 17 +- paddle/fluid/operators/flatten_op.cc | 40 +- paddle/fluid/operators/reshape_op.cc | 40 +- paddle/fluid/operators/scale_op.cc | 3 +- paddle/fluid/operators/softmax_op.cc | 15 + paddle/fluid/pybind/pybind.cc | 4 + python/paddle/fluid/__init__.py | 3 +- .../unittests/parallel_executor_test_base.py | 2 + 30 files changed, 1228 insertions(+), 167 deletions(-) create mode 100644 paddle/fluid/framework/details/inplace_op_pass.cc create mode 100644 paddle/fluid/framework/details/inplace_op_pass.h rename paddle/fluid/framework/details/{memory_reuse_types.cc => memory_optimize_helper.cc} (72%) rename paddle/fluid/framework/details/{memory_reuse_types.h => memory_optimize_helper.h} (72%) rename paddle/fluid/framework/details/{memory_reuse_types_test.cc => memory_optimize_helper_test.cc} (96%) rename paddle/fluid/framework/details/{analysis_var_pass.cc => memory_optimize_pass.cc} (80%) rename paddle/fluid/framework/details/{analysis_var_pass.h => memory_optimize_pass.h} (90%) rename paddle/fluid/framework/details/{analysis_var_pass_test.cc => memory_optimize_pass_test.cc} (99%) create mode 100644 paddle/fluid/framework/inplace_op_inference.h create mode 100644 paddle/fluid/framework/inplace_op_inference_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index a16751116..d88d9e783 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -200,6 +200,7 @@ cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry proto_desc) +cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index d5966ad5a..de81f6f67 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -50,7 +50,8 @@ 
cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
 
-cc_library(memory_optimize_pass SRCS analysis_var_pass.cc memory_reuse_types.cc DEPS graph graph_helper pass)
+cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc memory_optimize_helper.cc DEPS graph graph_helper pass)
+cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
 cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
@@ -65,12 +66,12 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
 
-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass memory_early_delete_pass inplace_op_pass)
 if (WITH_GPU)
   list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
 
-cc_test(memory_reuse_types_test SRCS memory_reuse_types_test.cc memory_reuse_types.cc DEPS framework_proto graph)
-cc_test(analysis_var_pass_test SRCS analysis_var_pass_test.cc analysis_var_pass.cc memory_reuse_types.cc DEPS framework_proto graph graph_helper op_registry pass)
+cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
+cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
 
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index 756470c5b..0831772a9 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -17,7 +17,7 @@ limitations under the License.
*/ #include #include -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" @@ -42,6 +42,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { public: explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) : ir::PassBuilder(), strategy_(strategy) { + if (strategy_.enable_inplace_) { + AppendPass("inplace_pass"); + } if (strategy_.enable_sequential_execution_) { AppendPass("sequential_execution_pass"); } @@ -87,7 +90,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // A side-effect of that, memory optimize cannot forsee the fetched vars // , so fetchlist should be set persistable before call the Run interface. if (strategy.memory_optimize_) { - auto analysis_var_pass = AppendPass("analysis_var_pass"); + auto memory_optimize_pass = AppendPass("memory_optimize_pass"); } AppendMultiDevPass(strategy); @@ -185,8 +188,7 @@ std::unique_ptr BuildStrategy::Apply( pass->Erase("nccl_ctxs"); pass->SetNotOwned("nccl_ctxs", nctx); #endif - - } else if (pass->Type() == "analysis_var_pass") { + } else if (pass->Type() == "memory_optimize_pass") { const std::vector *all_op_descs = new std::vector(main_program.Block(0).AllOps()); graph->Set>(kAllOpDescs, @@ -213,6 +215,13 @@ std::unique_ptr BuildStrategy::Apply( pass->Set>( kAllOpDescs, new std::vector(main_program.Block(0).AllOps())); + } else if (pass->Type() == "inplace_pass") { + if (graph->Has(kAllOpDescs)) { + graph->Erase(kAllOpDescs); + } + graph->Set>( + kAllOpDescs, + new std::vector(main_program.Block(0).AllOps())); } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") { if (!use_cuda) { LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on " @@ -238,8 +247,9 @@ USE_PASS(allreduce_mode_multi_devices_pass); USE_PASS(dist_multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); -USE_PASS(analysis_var_pass); +USE_PASS(memory_optimize_pass); USE_PASS(sequential_execution_pass); USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); +USE_PASS(inplace_pass); USE_PASS(lock_free_optimize_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 603df2e06..11a80d5f9 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -80,6 +80,8 @@ struct BuildStrategy { bool memory_early_delete_{false}; + bool enable_inplace_{false}; + bool enable_sequential_execution_{false}; bool fuse_broadcast_op_{false}; diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc new file mode 100644 index 000000000..b08935e56 --- /dev/null +++ b/paddle/fluid/framework/details/inplace_op_pass.cc @@ -0,0 +1,375 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/inplace_op_pass.h"
+#include <algorithm>
+#include <deque>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/details/memory_optimize_pass.h"
+#include "paddle/fluid/framework/op_info.h"
+
+// NOTE(dzhwinter): "inplace" means that an op's output variable reuses the
+// memory of one of its inputs. By design, an operator may only read its
+// inputs (const Variable) and write its outputs (non-const Variable). Once
+// an op is inplaced, its kernel gets a chance to write the shared buffer
+// before all reads of the input have happened, especially with certain
+// optimized kernel coding styles, e.g.:
+//
+// /* wrong case in an operator */
+// /* Here a fresh, larger allocation is returned and the input content is
+//    lost. */
+// const Tensor* in = ctx.Input<Tensor>("In");
+// Tensor* out = ctx.Output<Tensor>("Out");
+// auto* out_ptr = out->mutable_data<float>(ctx.GetPlace());
+// out_ptr[0] = 0;  // the input content is overwritten.
+
+// For backward compatibility: if enable_inplace_whitelist is turned on,
+// only the ops in the whitelist use the inplace strategy; otherwise every
+// op that registered an InplaceOpInference is an inplace candidate.
+DEFINE_bool(
+    enable_inplace_whitelist, true,
+    "If this option is turned on, only the ops in the whitelist can be "
+    "inplaced, e.g. scale, elementwise_add. If it is turned off, every "
+    "running op that registered an inplace inference is a candidate. "
+    "It is turned on by default.");
+
+// clang-format off
+const std::string kInplacedOpWhiteList[] = { // NOLINT
+    "sigmoid",
+    "exp",
+    "relu",
+    "tanh",
+    "sqrt",
+    "ceil",
+    "floor",
+    "reciprocal",
+    "relu6",
+    "soft_relu",
+    "hard_sigmoid",
+    "batch_norm",
+    "batch_norm_grad",
+    "sum",
+    "sum_grad",
+    "scale",
+    "reshape",
+    "elementwise_add",
+    "elementwise_add_grad",
+};
+// clang-format on
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+static inline ir::Node* GetNextInplacedOpOutput(ir::Node* var) {
+  // if the next op is inplaced, return its output var;
+  // otherwise return nullptr.
+  PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
+  ir::Node* inplaced_var = nullptr;
+  // only a var consumed by exactly one op can be inplaced
+  if (var->outputs.size() == 1 && var->outputs[0]->IsOp()) {
+    auto* op = var->outputs[0];
+    for (auto* out_var : op->outputs) {
+      if (!out_var->IsVar() || out_var->IsCtrlVar() ||
+          out_var->Var() == nullptr)
+        continue;
+      if (out_var->Name() == var->Name()) {
+        inplaced_var = out_var;
+        break;
+      }
+    }
+  }
+  return inplaced_var;
+}
+
+static inline ir::Node* GetPrevInplacedOpInput(ir::Node* var) {
+  PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
+  ir::Node* inplaced_var = nullptr;
+  if (var->inputs.size() == 1 && var->inputs[0]->IsOp()) {
+    auto* op = var->inputs[0];
+    for (auto* in_var : op->inputs) {
+      if (!in_var->IsVar() || in_var->IsCtrlVar() || in_var->Var() == nullptr)
+        continue;
+      if (in_var->Name() == var->Name()) {
+        inplaced_var = in_var;
+        break;
+      }
+    }
+  }
+  return inplaced_var;
+}
+
+template <typename Container>
+static inline bool ConnectByCtrlVar(const Container& group1,
+                                    const Container& group2) {
+  bool connected = false;
+  std::unordered_set<ir::Node*> outputs;
+  for (auto* op : group1) {
+    for
(auto* var : op->outputs) { + if (var->IsCtrlVar()) outputs.emplace(var); + } + } + for (auto* op : group2) { + for (auto* var : op->inputs) { + if (outputs.count(var)) connected = true; + } + } + return connected; +} + +InplacePass::InplacePass() : Pass() { + if (FLAGS_enable_inplace_whitelist) { + for (auto& s : kInplacedOpWhiteList) { + whitelist_.emplace(s); + } + } +} + +void InplacePass::InitSSAGraphNodes() const { + std::unordered_map> all_vars; + for (auto* op : view_.AllOps()) { + for (auto* node : op->inputs) { + if (!node->IsVar() || node->IsCtrlVar()) continue; + if (all_vars[node->Name()].count(node) == 0) { + all_vars[node->Name()].emplace(node); + var_nodes_[node->Name()].emplace_back(node); + } + } + for (auto* node : op->outputs) { + if (!node->IsVar() || node->IsCtrlVar()) continue; + if (all_vars[node->Name()].count(node) == 0) { + all_vars[node->Name()].emplace(node); + var_nodes_[node->Name()].emplace_back(node); + } + } + } +} + +std::unique_ptr InplacePass::ApplyImpl( + std::unique_ptr graph) const { + var_nodes_.clear(); + view_.Build(graph.get()); + InitSSAGraphNodes(); + + for (auto* op : view_.AllOps()) { + if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) + continue; + TryInplaceOpInputOutput(op, graph.get()); + } + graph->ResolveHazard(var_nodes_); + return graph; +} + +void InplacePass::InplaceModifyDesc(const std::string& var, + const std::string& cache_var, + const size_t& idx) const { + for (size_t i = idx; i < view_.AllOps().size(); ++i) { + auto* op = view_.AllOps()[i]; + PADDLE_ENFORCE(op->IsOp() && op->Op()); + auto* op_desc = op->Op(); + op_desc->RenameInput(var, cache_var); + op_desc->RenameOutput(var, cache_var); + if (op_desc->Block()->HasVar(var)) op_desc->Block()->RemoveVar(var); + op_desc->Flush(); + } +} + +void InplacePass::InplaceModifyVar(const std::string& var, + const std::string& cache_var, + const size_t& idx, ir::Graph* graph) const { + PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && + var_nodes_[var].at(0)->Var() != nullptr); + std::unique_ptr var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); + var_desc->SetName(cache_var); + + for (size_t i = idx; i < view_.AllOps().size(); ++i) { + auto* op = view_.AllOps()[i]; + + // redirect the input to the latest version of cache_var + for (auto* node : op->inputs) { + if (node->Name() == var) { + ir::Node* cache_node = var_nodes_[cache_var].back(); + // swap node to cache_node + cache_node->outputs.insert(cache_node->outputs.end(), + node->outputs.begin(), node->outputs.end()); + for (auto* next_op : node->outputs) { + std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, + cache_node); + } + } + } + + // if we need to rename the output, + // always create a newer version of cache_var + for (auto* node : op->outputs) { + if (node->Name() == var) { + ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); + var_nodes_[cache_var].emplace_back(cache_node); + + // swap node to cache node + cache_node->outputs.insert(cache_node->outputs.end(), + node->outputs.begin(), node->outputs.end()); + cache_node->inputs.emplace_back(op); + std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node); + for (auto* next_op : node->outputs) { + std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, + cache_node); + } + } + } + } + + // release node of unused var in graph + for (auto* node : var_nodes_[var]) { + graph->RemoveNode(node); + } + var_nodes_.at(var).clear(); +} + +void InplacePass::TryInplaceOpInputOutput(ir::Node* op, + ir::Graph* graph) 
const {
+  PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
+                 "op_desc is nullptr");
+  // Three prerequisites must be met before an op can be inplaced.
+  // 1. infer_inplace_ is registered.
+  auto* op_desc = op->Op();
+  auto& infer_inplace =
+      OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_;
+  if (!static_cast<bool>(infer_inplace)) return;
+  PADDLE_ENFORCE(static_cast<bool>(infer_inplace),
+                 "%s's infer_inplace has not been registered", op_desc->Type());
+
+  auto* block = op_desc->Block();
+  auto in_to_outs = infer_inplace(*op_desc, block);
+
+  auto& all_ops = view_.AllOps();
+  auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
+  size_t idx = std::distance(all_ops.begin(), cursor);
+
+  for (auto& pair : in_to_outs) {
+    auto& in_var_name = pair.first;
+    auto& out_var_name = pair.second;
+    auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
+    auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
+    // 2. there is no external pending op on the input node.
+    if (view_.PendingOpsOnVar(in_node).size() > 1) {
+      VLOG(3) << string::Sprintf(
+          "!!! %s input has an external dependency, cannot be inplaced, "
+          "%s => %s skipped",
+          op->Name(), out_var_name, in_var_name);
+      continue;
+    }
+    // 3. if the output reuses the input inplace, the dependency group must
+    // not change. For details, see the description of
+    // "OutConnectInputByCtrlVar".
+    if (view_.OutConnectInputByCtrlVar(in_node, out_node)) {
+      VLOG(3) << string::Sprintf(
+          "!!! %s input and output are connected by a ctrl var, cannot be "
+          "inplaced, %s => %s skipped",
+          op->Name(), out_var_name, in_var_name);
+      continue;
+    }
+    VLOG(3) << string::Sprintf("!!! %s,  %s => %s inplaced", op->Name(),
+                               out_var_name, in_var_name);
+    InplaceModifyDesc(out_var_name, in_var_name, idx);
+    InplaceModifyVar(out_var_name, in_var_name, idx, graph);
+  }
+}
+
+ir::Node* GraphView::GetNodeByName(const std::string& name,
+                                   const std::vector<ir::Node*>& nodes) const {
+  // nodes should be op->inputs/outputs;
+  // variable nodes inside the same op must have distinct names.
+  std::unordered_set<std::string> nodes_in_op;
+  bool has_dup_node =
+      std::any_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) {
+        if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) {
+          return false;
+        }
+        if (nodes_in_op.count(node->Name())) return true;
+        nodes_in_op.emplace(node->Name());
+        return false;
+      });
+  PADDLE_ENFORCE(has_dup_node == false, "nodes have duplicate names!");
+  ir::Node* node = nullptr;
+  for (auto* it : nodes) {
+    if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue;
+    if (it->Name() == name) {
+      node = it;
+      break;
+    }
+  }
+  PADDLE_ENFORCE(node != nullptr,
+                 string::Sprintf("Not found var %s in nodes!", name));
+  return node;
+}
+
+std::vector<ir::Node*> GraphView::PendingOpsOnVar(ir::Node* node) {
+  return node->outputs;
+}
+
+void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); }
+
+const std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
+
+bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
+  // Assume v_a0 and v_a1 are variables; an edge v_a0 -> v_a0 (or
+  // v_a1 -> v_a1) means that the variable has already been inplaced.
+  // Here we decide whether v_a0 -> v_a1 can be inplaced as well:
+  //
+  //         v_a0
+  //          +
+  //          |
+  //          v
+  //         v_a0
+  //          +
+  //          |
+  //          v
+  //         v_a1
+  //          +
+  //          |
+  //          v
+  //         v_a1
+  //
+  // Start from the first inplaced input v_a0 (the top one). Do a DFS
+  // search and collect all of its paths. If one of the paths connecting
+  // in_var and out_var contains a control dependency variable, there is
+  // a control path, and
out_var cannot be inplaced to reuse in_var.
+
+  std::unordered_set<ir::Node*> out_var_set, in_var_set;
+  ir::Node* out = out_var;
+  // collect the chain of var nodes sharing the output name
+  while (out != nullptr) {
+    out_var_set.emplace(out);
+    out = GetNextInplacedOpOutput(out);
+  }
+
+  // collect the chain of var nodes sharing the input name
+  ir::Node* in = in_var;
+  while (in != nullptr) {
+    in_var_set.emplace(in);
+    in = GetPrevInplacedOpInput(in);
+  }
+  // check whether a path with a control dep var connects the in_var_set
+  // and the out_var_set
+  return ConnectByCtrlVar(in_var_set, out_var_set);
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass);
diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h
new file mode 100644
index 000000000..c2b565a74
--- /dev/null
+++ b/paddle/fluid/framework/details/inplace_op_pass.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <map>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+class GraphView {
+ public:
+  GraphView() = default;
+
+  void Build(ir::Graph* g);
+
+  const std::vector<ir::Node*> AllOps();
+
+  ir::Node* GetNodeByName(const std::string& name,
+                          const std::vector<ir::Node*>& nodes) const;
+
+  std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
+
+  bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);
+
+ private:
+  std::vector<ir::Node*> ops_;
+};
+
+class InplacePass : public ir::Pass {
+ public:
+  InplacePass();
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(
+      std::unique_ptr<ir::Graph> graph) const override;
+
+  void InitSSAGraphNodes() const;
+
+ private:
+  void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
+                        const size_t& idx, ir::Graph* graph) const;
+
+  void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
+                         const size_t& idx) const;
+
+  void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const;
+
+  mutable std::map<std::string, std::vector<ir::Node*>> var_nodes_;
+
+  mutable std::unordered_set<std::string> whitelist_;
+  mutable GraphView view_;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/memory_early_delete_pass.cc b/paddle/fluid/framework/details/memory_early_delete_pass.cc
index 5906b7d57..69f8f7054 100644
--- a/paddle/fluid/framework/details/memory_early_delete_pass.cc
+++ b/paddle/fluid/framework/details/memory_early_delete_pass.cc
@@ -16,7 +16,7 @@
 #include
 #include
 #include
-#include "paddle/fluid/framework/details/memory_reuse_types.h"
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
 #include
"paddle/fluid/framework/ir/graph_helper.h" diff --git a/paddle/fluid/framework/details/memory_reuse_types.cc b/paddle/fluid/framework/details/memory_optimize_helper.cc similarity index 72% rename from paddle/fluid/framework/details/memory_reuse_types.cc rename to paddle/fluid/framework/details/memory_optimize_helper.cc index 2b9ff518b..55bac90a8 100644 --- a/paddle/fluid/framework/details/memory_reuse_types.cc +++ b/paddle/fluid/framework/details/memory_optimize_helper.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include #include #include @@ -83,7 +83,7 @@ struct NodeComparator { } }; -void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) { +void OrderedNodeList::Insert(ir::Node* var, ir::Node* op) { PADDLE_ENFORCE(var->IsVar() && !var->IsCtrlVar()); PADDLE_ENFORCE(op->IsOp()); if (mark_table_.count(var->Name()) != 0) { @@ -119,11 +119,11 @@ void OrderedNodePairPool::Insert(ir::Node* var, ir::Node* op) { mark_table_[var->Name()] = it; } -int OrderedNodePairPool::GetIndex(ir::Node* var) { +int OrderedNodeList::GetIndex(ir::Node* var) { return std::distance(nodes_.begin(), mark_table_[var->Name()]); } -ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const { +ir::Node* OrderedNodeList::NodeMatch(ir::Node* var) const { ir::Node* found_node = nullptr; NodeComparator compare_node; @@ -136,13 +136,15 @@ ir::Node* OrderedNodePairPool::NodeMatch(ir::Node* var) const { return found_node; } -void OrderedNodePairPool::Erase(ir::Node* var) { - PADDLE_ENFORCE(mark_table_.count(var->Name())); - nodes_.erase(mark_table_[var->Name()]); - mark_table_.erase(var->Name()); +void OrderedNodeList::Erase(ir::Node* var) { Erase(var->Name()); } + +void OrderedNodeList::Erase(const std::string& var) { + PADDLE_ENFORCE(mark_table_.count(var)); + nodes_.erase(mark_table_[var]); + mark_table_.erase(var); } -std::string OrderedNodePairPool::ToString() const { +std::string OrderedNodeList::ToString() const { std::stringstream ss; for (auto it = nodes_.begin(); it != nodes_.end(); ++it) { ss << DebugString(it->first) << " "; @@ -150,6 +152,38 @@ std::string OrderedNodePairPool::ToString() const { return ss.str(); } +bool NodeCanReused(ir::Node* node) { + if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false; + auto* desc = node->Var(); + auto type = desc->GetType(); + if (desc->Persistable() || type != proto::VarType::LOD_TENSOR || + desc->GetShape().empty()) { + return false; + } + // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad + std::string name = node->Name(); + if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@') + return false; + for (auto* op : node->inputs) { + if (op->Op()->HasAttr("force_cpu")) { + // op output force generated in cpu, can not be reused. 
+ return framework::AttrReader(op->Op()->GetAttrMap()) + .Get("force_cpu") == 0; + } + } + return true; +} + +bool OpHasSubBlock(OpDesc* desc) { + const AttributeMap& attrs = desc->GetAttrMap(); + for (auto& attr : attrs) { + if (attr.second.type() == typeid(BlockDesc*) || // NOLINT + attr.second.type() == typeid(std::vector)) // NOLINT + return true; + } + return false; +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/memory_reuse_types.h b/paddle/fluid/framework/details/memory_optimize_helper.h similarity index 72% rename from paddle/fluid/framework/details/memory_reuse_types.h rename to paddle/fluid/framework/details/memory_optimize_helper.h index 9a9c1d948..02f896325 100644 --- a/paddle/fluid/framework/details/memory_reuse_types.h +++ b/paddle/fluid/framework/details/memory_optimize_helper.h @@ -43,7 +43,7 @@ using GraphNodePool = std::vector< // For example, // node0[-1, 1] node1[-1, 1, 1], node2[1,1], node3[1,1024], .. // O(1) insert, delete -class OrderedNodePairPool { +class OrderedNodeList { public: using NodePair = std::pair>; using Iter = typename std::list::iterator; @@ -53,8 +53,12 @@ class OrderedNodePairPool { void Erase(ir::Node* var); + void Erase(const std::string& var); + bool Has(ir::Node* var) { return mark_table_.count(var->Name()); } + bool Has(const std::string& var) { return mark_table_.count(var); } + ir::Node* NodeMatch(ir::Node* var) const; // map store non-const iterator, can not promise const int GetIndex(ir::Node* var); @@ -67,6 +71,11 @@ class OrderedNodePairPool { ConstIter end() const { return nodes_.end(); } size_t size() const { return nodes_.size(); } + void Clear() { + mark_table_.clear(); + nodes_.clear(); + } + private: // for searching. std::unordered_map mark_table_; @@ -74,14 +83,47 @@ class OrderedNodePairPool { std::list nodes_; }; +// valid a tensor can be reuse or not +bool NodeCanReused(ir::Node* node); + +// check op has subblock or not +bool OpHasSubBlock(OpDesc* desc); + // node memory size in bytes size_t NodeSizeInBytes(ir::Node* n); std::string DebugString(ir::Node* var); -// std::string DebugString(VarDesc* var); VarDesc* FindVarDescInBlock(ir::Node* n); +template +class FilterVariableImpl { + public: + void operator()(const Container& nodes, Callback callback) { + for (auto* node : nodes) { + callback(node); + } + } +}; + +// filter var node for op->inputs/outputs +template +class FilterVariableImpl, Callback> { + public: + void operator()(const std::vector& nodes, Callback callback) { + for (auto* var : nodes) { + if (var->IsVar() && !var->IsCtrlVar()) { + callback(var); + } + } + } +}; + +template +void FilterVariables(const Container& nodes, Callback callback) { + FilterVariableImpl()(nodes, callback); +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/memory_reuse_types_test.cc b/paddle/fluid/framework/details/memory_optimize_helper_test.cc similarity index 96% rename from paddle/fluid/framework/details/memory_reuse_types_test.cc rename to paddle/fluid/framework/details/memory_optimize_helper_test.cc index d2fabf5ce..f2b9baf14 100644 --- a/paddle/fluid/framework/details/memory_reuse_types_test.cc +++ b/paddle/fluid/framework/details/memory_optimize_helper_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include #include #include @@ -27,8 +27,8 @@ namespace paddle { namespace framework { namespace details { -TEST(OrderedNodePairPool, Normal) { - OrderedNodePairPool pool; +TEST(OrderedNodeList, Normal) { + OrderedNodeList pool; std::vector> nodes; // clang-format off diff --git a/paddle/fluid/framework/details/analysis_var_pass.cc b/paddle/fluid/framework/details/memory_optimize_pass.cc similarity index 80% rename from paddle/fluid/framework/details/analysis_var_pass.cc rename to paddle/fluid/framework/details/memory_optimize_pass.cc index 223b9da3c..33ca45668 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.cc +++ b/paddle/fluid/framework/details/memory_optimize_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/analysis_var_pass.h" +#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include #include #include @@ -48,35 +48,7 @@ static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) { op1->Outputs() == op2->Outputs(); } -template -class FilterVariableImpl { - public: - void operator()(const Container& nodes, Callback callback) { - for (auto* node : nodes) { - callback(node); - } - } -}; - -// filter var node for op->inputs/outputs -template -class FilterVariableImpl, Callback> { - public: - void operator()(const std::vector& nodes, Callback callback) { - for (auto* var : nodes) { - if (var->IsVar() && !var->IsCtrlVar()) { - callback(var); - } - } - } -}; - -template -void FilterVariables(const Container& nodes, Callback callback) { - FilterVariableImpl()(nodes, callback); -} - -std::unique_ptr AnalysisVarPass::ApplyImpl( +std::unique_ptr MemoryOptimizePass::ApplyImpl( std::unique_ptr graph) const { auto nodes = graph->Nodes(); auto subblock_vars = GetSubBlockVars(nodes); @@ -103,48 +75,53 @@ std::unique_ptr AnalysisVarPass::ApplyImpl( } for (auto& var : op->outputs) { - if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) { - ir::Node* cache = pool_.NodeMatch(var); - if (var->Name() == FLAGS_memory_optimize_debug) { - VLOG(3) << "start match var " << DebugString(var) << " of op " - << op->Name(); - VLOG(3) << pool_.ToString(); - VLOG(3) << "matched in pool : " - << ((cache == nullptr) ? "False" : "True"); - } - if (cache != nullptr) { - if (var->Name() == cache->Name()) { - VLOG(3) << "The same cache variable is cascade reused." - << var->Name() << " is re-filled to the pool after" - << "the reused op is finished. Current op can not " - << "replace it again. Skip this candidate."; - continue; - } + if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 || + skip_set_.count(var->Name())) + continue; + ir::Node* cache = pool_.NodeMatch(var); + + if (var->Name() == FLAGS_memory_optimize_debug) { + VLOG(3) << "start match var " << DebugString(var) << " of op " + << op->Name(); + VLOG(3) << pool_.ToString(); + VLOG(3) << "matched in pool : " + << ((cache == nullptr) ? "False" : "True"); + } - int node_idx_in_pool = pool_.GetIndex(cache); - VLOG(3) << string::Sprintf( - "!!! %s, %s => %s, cache idx %d, pool size %d", - std::to_string(reuse_id++), DebugString(var), DebugString(cache), - node_idx_in_pool, static_cast(pool_.size())); - // update CFG Graph on the fly. 
-        // reused var maybe re-fill into the pool
-        cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx);
-        // NOTE(dzhwinter): we need to both update the ProgramDesc
-        // and IR Graph. because op_desc/var_desc is used in CreateOp,
-        // CreateVar when running happens. But IR Graph
-        // define the dependence relationship between nodes.
-        RenameVarInGraphDesc(var->Name(), cache->Name(), idx);
-        RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
-
-        pool_.Erase(cache);
+      if (cache == nullptr) continue;
+      if (var->Name() == cache->Name()) {
+        VLOG(3) << "The same cache variable is cascade reused. " << var->Name()
+                << " is re-filled to the pool after "
+                << "the reused op is finished. The current op can not "
+                << "replace it again. Skip this candidate.";
+        continue;
+      }
+
+      int node_idx_in_pool = pool_.GetIndex(cache);
+      VLOG(3) << string::Sprintf(
+          "!!! %s,  %s => %s, cache idx %d, pool size %d",
+          std::to_string(reuse_id++), DebugString(var), DebugString(cache),
+          node_idx_in_pool, static_cast<int>(pool_.size()));
+      // update the CFG graph on the fly;
+      // the reused var may be re-filled into the pool.
+      cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx);
+      // NOTE(dzhwinter): we need to update both the ProgramDesc
+      // and the IR Graph, because op_desc/var_desc are used by
+      // CreateOp/CreateVar at runtime, while the IR Graph defines
+      // the dependence relationships between nodes.
+      RenameVarInGraphDesc(var->Name(), cache->Name(), idx);
+      RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
+
+      pool_.Erase(cache);
+
+      // fill the pool
+      std::unordered_set<std::string> unlived_vars;
+      for (auto var : cfg_->LiveIn(op)) {
+        if (cfg_->LiveOut(op).count(var) == 0) {
+          unlived_vars.emplace(var);
         }
       }
-
-    // fill the pool
-    for (auto var : cfg_->LiveIn(op)) {
-      if (cfg_->LiveOut(op).count(var) == 0) {
+      for (auto var : unlived_vars) {
         ir::Node* var_node = cfg_->GetNodeFromVarName(var, op);
-        if (var_node == nullptr) continue;
         if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
           pool_.Insert(var_node, op);
         }
@@ -177,7 +154,7 @@ std::unique_ptr<ir::Graph> AnalysisVarPass::ApplyImpl(
   return graph;
 }
 
-void AnalysisVarPass::SubGraphOptimize(OpDesc* op_desc) const {
+void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
   // conditional block, while op and their grad op
   auto* sub_block_desc =
       AttrReader(op_desc->GetAttrMap()).Get<BlockDesc*>("sub_block");
@@ -247,7 +224,7 @@
   }
 }
 
-std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars(
+std::unordered_set<std::string> MemoryOptimizePass::GetSubBlockVars(
     const std::unordered_set<ir::Node*>& nodes) const {
   std::unordered_set<std::string> vars;
   for (auto& op : nodes) {
@@ -263,9 +240,9 @@ std::unordered_set<std::string> AnalysisVarPass::GetSubBlockVars(
   return vars;
 }
 
-void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
-                                           const std::string& cache_var,
-                                           size_t idx) const {
+void MemoryOptimizePass::RenameVarInGraphDesc(const std::string& var,
+                                              const std::string& cache_var,
+                                              size_t idx) const {
   for (size_t i = idx; i < cfg_->Ops().size(); ++i) {
     auto* op = cfg_->Ops()[i];
     PADDLE_ENFORCE(op->IsOp() && op->Op());
@@ -277,7 +254,7 @@ void AnalysisVarPass::RenameVarInGraphDesc(const std::string& var,
   }
 }
 
-void AnalysisVarPass::InitSSAGraphNodes() const {
+void MemoryOptimizePass::InitSSAGraphNodes() const {
   std::unordered_map<std::string, std::unordered_set<ir::Node*>> all_vars;
   if (var_nodes_.empty()) {
    for (auto* op : cfg_->Ops()) {
@@ -297,9 +274,10 @@
   }
 }
 
-void AnalysisVarPass::RenameVarInGraphNode(const std::string& var,
-
const std::string& cache_var, - size_t idx, ir::Graph* graph) const { +void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, + const std::string& cache_var, + size_t idx, + ir::Graph* graph) const { // if replace happens, we need to create a newer version cache_var // but use the same dims/data_type with var. PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && @@ -358,39 +336,6 @@ void AnalysisVarPass::RenameVarInGraphNode(const std::string& var, var_nodes_.at(var).clear(); } -bool AnalysisVarPass::NodeCanReused(ir::Node* node) const { - if (!node->IsVar() || node->IsCtrlVar()) return false; - auto* desc = node->Var(); - auto type = desc->GetType(); - if (desc->Persistable() || type != proto::VarType::LOD_TENSOR || - desc->GetShape().empty()) { - return false; - } - // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad - std::string name = node->Name(); - if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@') - return false; - if (skip_set_.count(name)) return false; - for (auto* op : node->inputs) { - if (op->Op()->HasAttr("force_cpu")) { - // op output force generated in cpu, can not be reused. - return framework::AttrReader(op->Op()->GetAttrMap()) - .Get("force_cpu") == 0; - } - } - return true; -} - -bool AnalysisVarPass::OpHasSubBlock(OpDesc* desc) const { - const AttributeMap& attrs = desc->GetAttrMap(); - for (auto& attr : attrs) { - if (attr.second.type() == typeid(BlockDesc*) || // NOLINT - attr.second.type() == typeid(std::vector)) // NOLINT - return true; - } - return false; -} - std::vector SortOpLikeDescOrder(const ir::Graph& graph) { PADDLE_ENFORCE(graph.Has(kAllOpDescs), "Graph has no attribute of kAllOpDescs."); @@ -651,6 +596,7 @@ ir::Node* ControlFlowGraph::GetNodeFromVarName(const std::string& name, } // namespace framework } // namespace paddle -REGISTER_PASS(analysis_var_pass, paddle::framework::details::AnalysisVarPass) +REGISTER_PASS(memory_optimize_pass, + paddle::framework::details::MemoryOptimizePass) .RequireGraphAttr(paddle::framework::details::kGraphNodePool) .RequireGraphAttr(paddle::framework::details::kAllOpDescs); diff --git a/paddle/fluid/framework/details/analysis_var_pass.h b/paddle/fluid/framework/details/memory_optimize_pass.h similarity index 90% rename from paddle/fluid/framework/details/analysis_var_pass.h rename to paddle/fluid/framework/details/memory_optimize_pass.h index 144204bea..b3e026e0b 100644 --- a/paddle/fluid/framework/details/analysis_var_pass.h +++ b/paddle/fluid/framework/details/memory_optimize_pass.h @@ -25,7 +25,7 @@ #include #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/details/memory_reuse_types.h" +#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" @@ -35,12 +35,10 @@ namespace details { constexpr char kAllOpDescs[] = "all_op_descs"; std::vector SortOpLikeDescOrder(const ir::Graph& graph); -// sort op in bfs order -std::vector BFSSortGraphOps(const ir::Graph& graph); class ControlFlowGraph; -class AnalysisVarPass : public ir::Pass { +class MemoryOptimizePass : public ir::Pass { protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; @@ -57,17 +55,13 @@ class AnalysisVarPass : public ir::Pass { ir::Graph* graph) const; void SubGraphOptimize(OpDesc* op_desc) const; - // valid a tensor can be reuse or not - bool NodeCanReused(ir::Node* node) const; // scan subblock and collect the output/input variables. 
std::unordered_set GetSubBlockVars( const std::unordered_set&) const; - // check op has subblock or not - bool OpHasSubBlock(OpDesc* desc) const; private: // Reuse Node Pool, Owned. - mutable OrderedNodePairPool pool_; + mutable OrderedNodeList pool_; // controlflow Graph mutable std::unique_ptr cfg_; // skip set diff --git a/paddle/fluid/framework/details/analysis_var_pass_test.cc b/paddle/fluid/framework/details/memory_optimize_pass_test.cc similarity index 99% rename from paddle/fluid/framework/details/analysis_var_pass_test.cc rename to paddle/fluid/framework/details/memory_optimize_pass_test.cc index 9bc4fd33f..cde78bc3b 100644 --- a/paddle/fluid/framework/details/analysis_var_pass_test.cc +++ b/paddle/fluid/framework/details/memory_optimize_pass_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/analysis_var_pass.h" +#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include #include #include diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index eea7e712f..0901e59f9 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/grad_op_desc_maker.h" +#include "paddle/fluid/framework/inplace_op_inference.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" @@ -32,7 +33,8 @@ enum OpInfoFillType { kOpProtoAndCheckerMaker = 1, kGradOpDescMaker = 2, kVarTypeInference = 3, - kShapeInference = 4 + kShapeInference = 4, + kInplaceOpInference = 5 }; template @@ -48,8 +50,11 @@ struct OpInfoFillTypeID { ? kVarTypeInference : (std::is_base_of::value ? kShapeInference - : static_cast( - -1))))); + : (std::is_base_of< + InplaceOpInference, T>::value + ? kInplaceOpInference + : static_cast( + -1)))))); } }; @@ -139,6 +144,16 @@ struct OpInfoFiller { } }; +template +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const { + info->infer_inplace_ = [](const OpDesc& op_desc, BlockDesc* block) { + T infer; + return infer(op_desc, block); + }; + } +}; + } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/inplace_op_inference.h b/paddle/fluid/framework/inplace_op_inference.h new file mode 100644 index 000000000..fe28c7ed2 --- /dev/null +++ b/paddle/fluid/framework/inplace_op_inference.h @@ -0,0 +1,135 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include <numeric>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include "glog/logging.h"
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/type_defs.h"
+
+namespace paddle {
+namespace framework {
+
+/*
+  Inplace inference creates In->Out pairs for an inplaced operator. If we
+  specify a pair of corresponding names, e.g. X->Out, then Out reuses X's
+  memory. The base class performs legality validation on both variables.
+*/
+class InplaceOpInference {
+ public:
+  virtual ~InplaceOpInference() {}
+  virtual std::unordered_map<std::string, std::string> operator()(
+      const OpDesc& op_desc, BlockDesc* block) const = 0;
+};
+
+class InplaceInToOut : public InplaceOpInference {
+ public:
+  std::unordered_map<std::string, std::string> operator()(
+      const OpDesc& op_desc, BlockDesc* block) const {
+    std::unordered_map<std::string, std::string> ret;
+    auto in_out_var_names_pair = this->Apply(op_desc, block);
+    for (auto& pair : in_out_var_names_pair) {
+      PADDLE_ENFORCE(!op_desc.Input(pair.first).empty(),
+                     string::Sprintf("op %s does not have input %s!",
+                                     op_desc.Type(), pair.first));
+      PADDLE_ENFORCE(!op_desc.Output(pair.second).empty(),
+                     string::Sprintf("op %s does not have output %s!",
+                                     op_desc.Type(), pair.second));
+      auto& in_name = op_desc.Input(pair.first).at(0);
+      auto& out_name = op_desc.Output(pair.second).at(0);
+
+      auto in = block->FindRecursiveOrCreateVar(in_name);
+      auto out = block->FindRecursiveOrCreateVar(out_name);
+      if (TryInplaceInputOutput(in, out)) ret.insert({in_name, out_name});
+    }
+    return ret;
+  }
+
+ protected:
+  virtual std::unordered_map<std::string, std::string> Apply(
+      const OpDesc& op_desc, BlockDesc* block) const = 0;
+
+  bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
+    auto var_can_reused = [&](const VarDesc& node) -> bool {
+      auto type = node.GetType();
+      if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
+          node.GetShape().empty()) {
+        return false;
+      }
+      // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
+      std::string name = node.Name();
+      if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
+        return false;
+      return true;
+    };
+
+    auto var_size_in_bytes = [&](const VarDesc& node) -> size_t {
+      auto shape = node.GetShape();
+      int size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<int>());
+      size_t type_size = SizeOfType(node.GetDataType());
+      return type_size * std::abs(size);
+    };
+
+    return in.Name() != out.Name() && var_can_reused(in) &&
+           var_can_reused(out) &&
+           var_size_in_bytes(out) <= var_size_in_bytes(in);
+  }
+};
+
+/*
+  Inplace In and Out for an operator that has exactly one input and one
+  output, e.g. an activation op.
+ */
+class SingleOpInplaceInToOut : public InplaceInToOut {
+ protected:
+  std::unordered_map<std::string, std::string> Apply(
+      const OpDesc& op_desc, BlockDesc* block) const override {
+    PADDLE_ENFORCE(!op_desc.InputNames().empty(),
+                   "Op inputs must not be empty");
+    PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
+                   "Op outputs must not be empty");
+    auto x_name = op_desc.InputNames().at(0);
+    auto out_name = op_desc.OutputNames().at(0);
+    return std::unordered_map<std::string, std::string>{{x_name, out_name}};
+  }
+};
+
+/*
+  Gradient op: the inplaced output reuses the op's input,
+  i.e. the Input@Grad->Input reuse strategy.
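+
+  A hypothetical illustration: a grad op that has input "X" and output
+  GradVarName("X") (i.e. "X@GRAD") yields the pair {"X", "X@GRAD"},
+  so X@GRAD may reuse X's buffer.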
+ */ +class GradOpInplaceInToOut : public InplaceInToOut { + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + std::unordered_map ret; + std::unordered_set output_names(op_desc.OutputNames().begin(), + op_desc.OutputNames().end()); + for (auto& input_name : op_desc.InputNames()) { + if (output_names.count(GradVarName(input_name))) { + ret.insert({input_name, GradVarName(input_name)}); + } + } + return ret; + } +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/inplace_op_inference_test.cc b/paddle/fluid/framework/inplace_op_inference_test.cc new file mode 100644 index 000000000..121f648a5 --- /dev/null +++ b/paddle/fluid/framework/inplace_op_inference_test.cc @@ -0,0 +1,287 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_type_inference.h" + +namespace paddle { +namespace framework { + +class NOP : public OperatorBase { + public: + NOP(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope& scope, + const platform::Place& place) const override {} +}; + +class SingleOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SingleGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("single_op_grad"); + op->SetInput("Out", OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(op); + } +}; + +class SingleOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->HasInput("X"); + ctx->HasOutput("Out"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } +}; + +class SingleGradOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->HasInput(framework::GradVarName("Out")); + ctx->HasOutput(framework::GradVarName("X")); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); + } +}; + +class MultiOutOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddInput("Y", "").AsDuplicable(); + AddInput("Z", "").AsDuplicable(); + AddOutput("Out", ""); + AddOutput("YOut", ""); + AddOutput("ZOut", ""); + AddOutput("NotReuseOut", ""); + AddComment(""); + } +}; + +class 
MultiOutShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->ShareDim("X", "Out"); + ctx->ShareDim("Y", "YOut"); + ctx->ShareDim("Z", "ZOut"); + } +}; + +class MultiGradOpMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* op = new framework::OpDesc(); + op->SetType("multi_out_grad"); + op->SetInput("X", Input("X")); + op->SetOutput(framework::GradVarName("Y"), OutputGrad("YOut")); + op->SetOutput(framework::GradVarName("X"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("Z"), OutputGrad("ZOut")); + return std::unique_ptr(op); + } +}; + +class MultiOutGradShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("Y"), + ctx->GetInputDim(framework::GradVarName("YOut"))); + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + ctx->SetOutputDim(framework::GradVarName("Z"), + ctx->GetInputDim(framework::GradVarName("ZOut"))); + } +}; + +class MultiOutInplaceInToOut : public framework::InplaceInToOut { + public: + using framework::InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + return std::unordered_map{ + {"X", "Out"}, {"Y", "YOut"}, {"Z", "ZOut"}, + }; + } +}; + +class MultiOutGradInplaceInToOut : public framework::InplaceInToOut { + public: + using framework::InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const OpDesc& op_desc, BlockDesc* block) const override { + return std::unordered_map{ + {framework::GradVarName("YOut"), framework::GradVarName("Y")}, + {framework::GradVarName("Out"), framework::GradVarName("X")}, + {framework::GradVarName("ZOut"), framework::GradVarName("Z")}, + }; + } +}; + +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; +REGISTER_OPERATOR(single_op, f::NOP, f::SingleOpMaker, f::SingleGradOpMaker, + f::SingleOpInplaceInToOut, f::SingleOpShapeInference); +REGISTER_OPERATOR(single_op_grad, f::NOP, f::SingleOpInplaceInToOut, + f::SingleGradOpShapeInference); +REGISTER_OPERATOR(multi_out_op, f::NOP, f::MultiOutOpMaker, f::MultiGradOpMaker, + f::MultiOutInplaceInToOut, f::MultiOutShapeInference); +REGISTER_OPERATOR(multi_out_grad, f::NOP, f::MultiOutGradInplaceInToOut, + f::MultiOutGradShapeInference); + +namespace paddle { +namespace framework { + +TEST(InferInplace, SingleOpInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("single_op"); + op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); + op->SetOutput("Out", {"test2_out"}); + + prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 64}); + prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_out"); + prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 1ul); + auto it = in_to_outs.begin(); + EXPECT_EQ(it->first, "test2_a"); + EXPECT_EQ(it->second, 
"test2_out"); +} + +TEST(InferInplace, SingleGradOpInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("single_op_grad"); + op->SetInput(GradVarName("Out"), {"test2_out"}); + op->SetOutput(GradVarName("X"), {"test2_a", "test2_b", "test2_c"}); + + prog.MutableBlock(0)->Var("test2_a")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_a")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("test2_b")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_c")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("test2_out"); + prog.MutableBlock(0)->Var("test2_out")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 1ul); + auto it = in_to_outs.begin(); + EXPECT_EQ(it->first, "test2_out"); + EXPECT_EQ(it->second, "test2_a"); +} + +TEST(InferInplace, MultiOutInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("multi_out_op"); + op->SetInput("X", {"a0", "a1"}); + op->SetInput("Y", {"b0"}); + op->SetInput("Z", {"c0", "c1"}); + op->SetOutput("Out", {"o0"}); + op->SetOutput("YOut", {"y0"}); + op->SetOutput("ZOut", {"z0"}); + + prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("o0"); + prog.MutableBlock(0)->Var("y0"); + prog.MutableBlock(0)->Var("z0"); + prog.MutableBlock(0)->Var("a0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("b0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("c0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("o0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("y0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("z0")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; + auto in_to_outs = infer_inplace(*op, op->Block()); + EXPECT_EQ(in_to_outs.size(), 3ul); + std::unordered_map expects = { + {"a0", "o0"}, {"b0", "y0"}, {"c0", "z0"}, + }; + EXPECT_TRUE(expects == in_to_outs); +} + +TEST(InferInplace, MultiGradInplaceInToOut) { + ProgramDesc prog; + auto* op = prog.MutableBlock(0)->AppendOp(); + op->SetType("multi_out_grad"); + op->SetInput(GradVarName("Out"), {"o0"}); + op->SetInput(GradVarName("YOut"), {"y0"}); + op->SetInput(GradVarName("ZOut"), {"z0"}); + op->SetOutput(GradVarName("X"), {"a0", "a1"}); + op->SetOutput(GradVarName("Y"), {"b0"}); + op->SetOutput(GradVarName("Z"), {"c0", "c1"}); + + prog.MutableBlock(0)->Var("a0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("b0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c0")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("c1")->SetType(proto::VarType::LOD_TENSOR); + prog.MutableBlock(0)->Var("o0"); + prog.MutableBlock(0)->Var("y0"); + prog.MutableBlock(0)->Var("z0"); + prog.MutableBlock(0)->Var("a0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("b0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("c0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("o0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("y0")->SetShape({32, 16}); + prog.MutableBlock(0)->Var("z0")->SetShape({32, 16}); + + auto& infer_inplace = OpInfoMap::Instance().Get(op->Type()).infer_inplace_; 
+  auto in_to_outs = infer_inplace(*op, op->Block());
+  EXPECT_EQ(in_to_outs.size(), 3ul);
+  std::unordered_map<std::string, std::string> expects = {
+      {"o0", "a0"}, {"y0", "b0"}, {"z0", "c0"},
+  };
+  EXPECT_TRUE(expects == in_to_outs);
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 9eade9eaa..fb4fa54d3 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <memory>
 #include <string>
 #include <typeindex>
 #include <vector>
diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h
index 19e5c2c73..4b55bd070 100644
--- a/paddle/fluid/framework/op_info.h
+++ b/paddle/fluid/framework/op_info.h
@@ -38,6 +38,7 @@ struct OpInfo {
   OpAttrChecker* checker_{nullptr};
   InferVarTypeFN infer_var_type_;
   InferShapeFN infer_shape_;
+  InferInplaceOpFN infer_inplace_;
 
   bool HasOpProtoAndChecker() const {
     return proto_ != nullptr && checker_ != nullptr;
diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h
index 938e2024c..d02c699b9 100644
--- a/paddle/fluid/framework/type_defs.h
+++ b/paddle/fluid/framework/type_defs.h
@@ -57,5 +57,8 @@ using InferVarTypeFN =
 using InferShapeFN = std::function<void(InferShapeContext*)>;
 
+using InplacePair = std::unordered_map<std::string, std::string>;
+using InferInplaceOpFN = std::function<InplacePair(const OpDesc&, BlockDesc*)>;
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 9c5b8604f..7c29eac46 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -547,12 +547,14 @@ namespace ops = paddle::operators;
   __macro(Swish, swish);                       \
   __macro(ThresholdedRelu, thresholded_relu);
 
-#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)        \
-  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
-                    ::paddle::operators::OP_NAME##OpMaker,          \
-                    ::paddle::operators::ActivationOpInferVarType,  \
-                    ::paddle::operators::OP_NAME##GradMaker);       \
-  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
+#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                   \
+  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,            \
+                    ::paddle::operators::OP_NAME##OpMaker,                     \
+                    ::paddle::operators::ActivationOpInferVarType,             \
+                    ::paddle::operators::OP_NAME##GradMaker,                   \
+                    ::paddle::framework::SingleOpInplaceInToOut);              \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
+                    ::paddle::framework::SingleOpInplaceInToOut)
 
 #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 8b672e09b..facfc8a91 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -602,13 +602,48 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
+class BatchNormInplaceInToOut : public framework::InplaceInToOut {
+ public:
+  using InplaceInToOut::InplaceInToOut;
+
+ protected:
+  std::unordered_map<std::string, std::string> Apply(
+      const framework::OpDesc &op_desc,
+      framework::BlockDesc *block) const override {
+    std::unordered_map<std::string, std::string> inplace_in_to_out = {
+        {"Mean", "MeanOut"}, {"Variance", "VarianceOut"}, {"X", "Y"},
+    };
+    return inplace_in_to_out;
+  }
+};
+
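+// The grad variant below pairs Y@GRAD with X@GRAD and reuses the saved batch
+// statistics' buffers for the scale/bias gradients; each pair matches in shape.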
+class BatchNormGradInplaceInToOut : public framework::InplaceInToOut {
+ public:
+  using InplaceInToOut::InplaceInToOut;
+
+ protected:
+  std::unordered_map<std::string, std::string> Apply(
+      const framework::OpDesc &op_desc,
+      framework::BlockDesc *block) const override {
+    std::unordered_map<std::string, std::string> inplace_in_to_out = {
+        // Scale, Bias, SavedMean, SavedVariance all have shape [C]
+        {framework::GradVarName("Y"), framework::GradVarName("X")},
+        {"SavedMean", framework::GradVarName("Scale")},
+        {"SavedVariance", framework::GradVarName("Bias")},
+    };
+    return inplace_in_to_out;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
-                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
-REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
+                  ops::BatchNormOpInferVarType, ops::BatchNormGradMaker,
+                  ops::BatchNormInplaceInToOut);
+REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp,
+                  ops::BatchNormGradInplaceInToOut);
 
 REGISTER_OP_CPU_KERNEL(
     batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
index 7e789cd8d..c6c658236 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
@@ -18,6 +18,7 @@ namespace ops = paddle::operators;
 REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
 REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
                               "X");
+
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index fd2a98cb4..d04bb8f33 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -250,6 +250,20 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
   }
 };
 
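+// Out of an elementwise op has the same shape and layout as X, so Out can
+// reuse X's buffer once no other op consumes X.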
framework::GradVarName("X")}, + }; + return inplace_in_to_out; + } +}; + } // namespace operators } // namespace paddle @@ -275,10 +304,13 @@ USE_OP(reshape); namespace ops = paddle::operators; REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker, ops::FlattenOpInferShape, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape); + paddle::framework::DefaultGradOpDescMaker, + ops::FlattenOpInplaceInToOut); +REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape, + ops::FlattenGradInplaceinToOut); REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker, - ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker); + ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker, + ops::FlattenOpInplaceInToOut); REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp, - ops::Flatten2GradInferShape); + ops::Flatten2GradInferShape, ops::FlattenGradInplaceinToOut); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 8eab3a6f8..91fdd4309 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -327,13 +327,44 @@ class Reshape2GradOp : public framework::OperatorWithKernel { } }; +class ReshapeOpInplaceInToOut : public framework::InplaceInToOut { + public: + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {"X", "Out"}, + }; + return inplace_in_to_out; + } +}; + +class ReshapeGradInplaceInToOut : public framework::InplaceInToOut { + using InplaceInToOut::InplaceInToOut; + + protected: + std::unordered_map Apply( + const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + std::unordered_map inplace_in_to_out = { + {framework::GradVarName("Out"), framework::GradVarName("X")}, + }; + return inplace_in_to_out; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, - paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); + paddle::framework::DefaultGradOpDescMaker, + ops::ReshapeOpInplaceInToOut); +REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp, + ops::ReshapeGradInplaceInToOut); REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel); @@ -343,8 +374,9 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, ops::ReshapeGradKernel); REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker, - ops::Reshape2GradMaker); -REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp); + ops::Reshape2GradMaker, ops::ReshapeOpInplaceInToOut); +REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp, + ops::ReshapeGradInplaceInToOut); REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, ops::ReshapeKernel); diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 981969d2a..4ea77ed30 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -100,13 +100,14 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { } }; +using ScaleOpInplace = framework::SingleOpInplaceInToOut; } // namespace operators } // namespace paddle namespace ops = 
+using ScaleOpInplace = framework::SingleOpInplaceInToOut;
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker,
-                  ops::ScaleOpVarTypeInference);
+                  ops::ScaleOpVarTypeInference, ops::ScaleOpInplace);
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index bc889a5a0..8fbf299a7 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -198,6 +198,21 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
     return std::unique_ptr<framework::OpDesc>(op);
   }
 };
+
+class SoftmaxInplaceInToOut : public framework::InplaceInToOut {
+ public:
+  using framework::InplaceInToOut::InplaceInToOut;
+
+ protected:
+  std::unordered_map<std::string, std::string> Apply(
+      const framework::OpDesc& op_desc,
+      framework::BlockDesc* block) const override {
+    return std::unordered_map<std::string, std::string>{
+        {"X", "Out"},
+    };
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 96d0d16bf..86b19e907 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1049,6 +1049,10 @@ All parameter, weight, gradient are variables in Paddle.
           "memory_early_delete",
           [](const BuildStrategy &self) { return self.memory_early_delete_; },
           [](BuildStrategy &self, bool b) { self.memory_early_delete_ = b; })
+      .def_property(
+          "enable_inplace",
+          [](const BuildStrategy &self) { return self.enable_inplace_; },
+          [](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
       .def("_finalize_strategy_and_create_passes",
            [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
              return self.CreatePassesFromStrategy(true);
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 564882bd2..396f36e18 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -158,7 +158,8 @@ def __bootstrap__():
         'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
         'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
         'sync_nccl_allreduce', 'limit_of_tmp_allocation',
-        'times_excess_than_required_tmp_allocation'
+        'times_excess_than_required_tmp_allocation',
+        'enable_inplace_whitelist'
     ]
 
     core.init_gflags([sys.argv[0]] +
diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index fdacd241f..5ef1d2cfa 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -41,6 +41,7 @@ class TestParallelExecutorBase(unittest.TestCase):
                                    use_parallel_executor=True,
                                    use_reduce=False,
                                    use_ir_memory_optimize=False,
+                                   enable_inplace=True,
                                    fuse_elewise_add_act_ops=False,
                                    fuse_relu_depthwise_conv=False,
                                    optimizer=fluid.optimizer.Adam,
@@ -80,6 +81,7 @@ class TestParallelExecutorBase(unittest.TestCase):
         build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
         build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
         build_strategy.memory_optimize = use_ir_memory_optimize
+        build_strategy.enable_inplace = enable_inplace
         build_strategy.enable_sequential_execution = enable_sequential_execution
         if use_cuda and core.is_compiled_with_cuda():
             build_strategy.remove_unnecessary_lock = True
-- 
GitLab
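Usage note (reviewer sketch, not part of the diff): a minimal example of how the
new build-strategy switch is meant to be driven from Python, mirroring the
parallel_executor_test_base.py change above. The `loss` variable stands in for
whatever network output the caller has built.

    import paddle.fluid as fluid

    build_strategy = fluid.BuildStrategy()
    # Enable the inplace_op_pass added by this patch; ops registered with an
    # inplace inferer (reshape, scale, softmax, batch_norm, ...) may then
    # write their outputs over their inputs.
    build_strategy.enable_inplace = True

    exe = fluid.ParallelExecutor(use_cuda=False,
                                 loss_name=loss.name,
                                 build_strategy=build_strategy)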