From 89966525f1d667dda3bf5c44a8071c1ab716af4e Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 28 Nov 2019 10:09:54 +0800 Subject: [PATCH] Polish reference count pass (#21324) * fix ref_cnt pass, test=develop * add cpp unittests to reference_count_pass, test=develop * follow comments, test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 14 +- .../ir/memory_optimize_pass/CMakeLists.txt | 2 + .../reference_count_pass.cc | 146 +++++------- ...est_reference_count_pass_last_lived_ops.cc | 210 ++++++++++++++++++ paddle/fluid/framework/parallel_executor.cc | 35 +-- paddle/fluid/operators/CMakeLists.txt | 2 +- 6 files changed, 298 insertions(+), 111 deletions(-) create mode 100644 paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 6acd54c6027..6789c54210a 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -25,7 +25,7 @@ if(WITH_GPU) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - dynload_cuda variable_visitor place) + dynload_cuda variable_visitor place device_memory_aligment) if(WITH_DGC) nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope @@ -46,7 +46,7 @@ else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory - variable_visitor place) + variable_visitor place device_memory_aligment) if(WITH_DISTRIBUTE) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim selected_rows_functor sendrecvop_rpc) @@ -103,4 +103,14 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS lock_free_optimize_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass + sync_batch_norm_pass runtime_context_cache_pass + pass_builder ${NGRAPH_BS_DEPS}) + +if (WITH_MKLDNN) + target_link_libraries(build_strategy mkldnn_placement_pass) +endif() + +if (WITH_NGRAPH) + target_link_libraries(build_strategy ngraph_subgraph_pass) +endif() diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index 81921c40aa4..726a2d90fcf 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -12,3 +12,5 @@ cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handl cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass) cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass) + +cc_test(test_reference_count_pass_last_lived_ops SRCS test_reference_count_pass_last_lived_ops.cc DEPS parallel_executor elementwise_mul_op elementwise_add_op scale_op) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc index cc26f7f96b2..4584b3d4e0f 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc @@ -202,34 +202,7 @@ static bool ShrinkNoNeedBufferVarOpDependency( } } -/** - * Find the nearest downstream computation op handle. If the op is a - * computation op, just return itself. - */ -static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( - details::OpHandleBase *op, size_t scope_idx) { - std::queue q; - std::unordered_set visited; - q.push(op); - while (!q.empty()) { - auto *op = q.front(); - q.pop(); - auto *compute_op = dynamic_cast(op); - if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { - return compute_op; - } - for (auto *out_var : op->Outputs()) { - for (auto *pending_op : out_var->PendingOps()) { - if (visited.count(pending_op)) continue; - visited.insert(pending_op); - q.push(pending_op); - } - } - } - return nullptr; -} - -enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede }; +enum LastLiveOpSearchStatus { kSuccess, kFailure }; static std::unordered_set ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, @@ -237,22 +210,7 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, const ShrinkDepsOpFunctor &shrink_func, LastLiveOpSearchStatus *status) { // stage one. Get last op for variable. - std::unordered_set candidates; - { - if (var->PendingOps().empty() && var->GeneratedOp()) { - // No operator depends on this variable. So the last operator is the op - // who generates this variable. - candidates.emplace(var->GeneratedOp()); - } else { - candidates = var->PendingOps(); - } - - // No pending ops or generated op is nullptr - if (candidates.empty()) { - *status = LastLiveOpSearchStatus::kFailure; - return {}; - } - } + auto candidates = var->PendingOps(); // stage two. Try to cast them to computation op. // return (*status=kFailure) when failed. @@ -262,37 +220,41 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, // some op handle may operate on many DeviceContext, however, our garbage // collector can only wait one DeviceContext for now. So currently, we wait // the nearest compute op. - std::unordered_set computation_op; + std::unordered_set computation_ops; { for (auto *op : candidates) { - auto *compute_op = - FindNextComputationOpHandleOrReturnItself(op, scope_idx); - if (compute_op == nullptr) { + auto *compute_op = dynamic_cast(op); + if (compute_op && compute_op->GetScopeIdx() == scope_idx) { + computation_ops.emplace(compute_op); + } else { *status = LastLiveOpSearchStatus::kFailure; return {}; } - computation_op.emplace(compute_op); + } + + auto *generated_op = + dynamic_cast(var->GeneratedOp()); + if (generated_op && generated_op->GetScopeIdx() == scope_idx) { + computation_ops.emplace(generated_op); } } // stage three. Try to shrink computation op if any of them does // not need the buffer of var_name. - // If all computation ops do not need the buffer of var_name, - // return empty computation op set, and mark the status as kShouldPrecede, - // which means that the last living ops of var_name should be - // found in the previous version of var_name. - if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) { - *status = LastLiveOpSearchStatus::kShouldPrecede; + if (computation_ops.empty() || + ShrinkNoNeedBufferVarOpDependency(var_name, &computation_ops)) { + *status = LastLiveOpSearchStatus::kFailure; return {}; } - PADDLE_ENFORCE(!computation_op.empty(), - "Computation ops should not be empty"); + PADDLE_ENFORCE_EQ( + computation_ops.empty(), false, + platform::errors::InvalidArgument("Computation ops should not be empty")); // stage four. Try to shrink computation op if they depend on each other. // Get the smallest set of the most ops. *status = LastLiveOpSearchStatus::kSuccess; - return shrink_func(computation_op); + return shrink_func(computation_ops); } void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { @@ -344,47 +306,45 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { PADDLE_ENFORCE_EQ(var_desc->Name(), var_name); - for (auto iter = var_handles.rbegin(); iter != var_handles.rend(); - ++iter) { - if ((*iter)->Node()->IsCtrlVar()) { - break; - } + PADDLE_ENFORCE_EQ( + var_handles.empty(), false, + platform::errors::InvalidArgument("Variable %s not found", var_name)); + auto last_ver_var = var_handles.back(); - VLOG(10) << "Try to find last living ops of " << var_name << " " - << (iter - var_handles.rbegin()) << " time"; - LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure; - auto result = ExtractComputationOpFromLastLivedVar( - *iter, i, var_name, shrink_func, &status); - - // Seldomly, some vars may have no pending or preceding computation ops - // Just break; - if (status == LastLiveOpSearchStatus::kFailure) { - VLOG(1) << "Cannot find last live ops of variable " << var_name - << " in scope " << (*iter)->scope_idx(); - break; - } + if (last_ver_var->Node()->IsCtrlVar()) { + continue; + } - if (status == LastLiveOpSearchStatus::kShouldPrecede) { - VLOG(10) << "Try to precede reference count computing at var " - << var_name; - continue; - } + LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure; + auto result = ExtractComputationOpFromLastLivedVar( + last_ver_var, i, var_name, shrink_func, &status); + + // Seldomly, some vars may have no pending or preceding computation ops + // Just break; + if (status == LastLiveOpSearchStatus::kFailure) { + VLOG(1) << "Cannot find last live ops of variable " << var_name + << " in scope " << last_ver_var->scope_idx(); + continue; + } - PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess); - PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty", - var_name); + PADDLE_ENFORCE_EQ( + status, LastLiveOpSearchStatus::kSuccess, + platform::errors::InvalidArgument("status must be success")); + PADDLE_ENFORCE_EQ(result.empty(), false, + platform::errors::NotFound( + "Last living ops of %s cannot be empty", var_name)); - VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; - var_infos[i][var_name].reset( - new MemOptVarInfo(var_name, result.size())); - auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name]; - last_live_ops_of_var.set_var(*iter); - *(last_live_ops_of_var.mutable_ops()) = std::move(result); - break; + std::string last_live_ops_log_str; + for (auto &each_ret : result) { + last_live_ops_log_str += (" " + each_ret->GetOp()->Type()); } + VLOG(10) << "Extract " << result.size() << " ops of var " << var_name + << " : " << last_live_ops_log_str; - // Seldomly, all preceding trying failed. - // Just skip this corner case + var_infos[i][var_name].reset(new MemOptVarInfo(var_name, result.size())); + auto &last_live_ops_of_var = last_live_ops_of_vars[i][var_name]; + last_live_ops_of_var.set_var(last_ver_var); + *(last_live_ops_of_var.mutable_ops()) = std::move(result); } } } diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc new file mode 100644 index 00000000000..89c97541fbc --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc @@ -0,0 +1,210 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" +#include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/framework/program_desc.h" + +USE_OP(scale); +USE_OP(elementwise_mul); +USE_OP(elementwise_add); +USE_OP(elementwise_add_grad); + +DECLARE_double(eager_delete_tensor_gb); + +namespace paddle { +namespace framework { + +static std::vector CreatePlaces(size_t num, bool use_cuda) { + std::vector result; + result.reserve(num); + for (size_t i = 0; i < num; ++i) { + if (use_cuda) { + result.emplace_back(platform::CUDAPlace(i)); + } else { + result.emplace_back(platform::CPUPlace()); + } + } + return result; +} + +static void NewVar(BlockDesc *block, const std::string &name, + const std::vector &shape) { + auto *var_desc = block->Var(name); + var_desc->SetShape(shape); +} + +static void AppendOp(BlockDesc *block, const std::string &type, + VariableNameMap inputs, VariableNameMap outputs, + AttributeMap attrs) { + auto &op_info = OpInfoMap::Instance().Get(type); + if (op_info.Checker()) { + op_info.Checker()->Check(&attrs); + } + + auto *op = block->AppendOp(); + op->SetType(type); + for (auto &pair : inputs) { + op->SetInput(pair.first, pair.second); + } + + for (auto &pair : outputs) { + op->SetOutput(pair.first, pair.second); + for (auto &var_name : pair.second) { + if (!block->FindVarRecursive(var_name)) { + NewVar(block, var_name, {}); + } + } + } + + op->SetAttrMap(attrs); + op->InferVarType(block); + op->InferShape(*block); +} + +class ReferenceCountPassTestHelper { + public: + ReferenceCountPassTestHelper(const ProgramDesc &program, bool use_cuda) + : graph_(program) { + details::BuildStrategy build_strategy; + build_strategy.enable_inplace_ = false; + build_strategy.memory_optimize_ = false; + FLAGS_eager_delete_tensor_gb = -1; + + details::ExecutionStrategy exec_strategy; + exec_strategy.use_cuda_ = use_cuda; + + executor_.reset(new ParallelExecutor(CreatePlaces(1, use_cuda), {}, "", + &scope_, {}, exec_strategy, + build_strategy, &graph_)); + + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); + ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_); + ref_cnt_pass->Apply(&graph_); + } + + bool IsLastLivedOps(const std::string &name, + std::vector ops) const { + std::sort(ops.begin(), ops.end()); + return LastLivedOpTypes(name) == ops; + } + + std::vector LastLivedOps(const std::string &name) const { + auto &ops = last_live_ops_of_vars_[0].at(name).ops(); + std::vector ret; + for (auto *op : ops) { + ret.emplace_back(op->GetOp()); + } + return ret; + } + + private: + std::vector LastLivedOpTypes(const std::string &name) const { + auto iter = last_live_ops_of_vars_[0].find(name); + std::vector ret; + if (iter != last_live_ops_of_vars_[0].end()) { + for (auto *op : iter->second.ops()) { + ret.emplace_back(op->GetOp()->Type()); + } + } + std::sort(ret.begin(), ret.end()); + return ret; + } + + private: + ir::Graph graph_; + Scope scope_; + std::unique_ptr executor_; + + ir::MemOptVarInfoMapList mem_opt_var_infos_; + std::vector last_live_ops_of_vars_; +}; + +TEST(test_reference_count_pass, test_no_need_buffer_var_shrink) { + ProgramDesc program; + auto *block = program.MutableBlock(0); + std::vector shape{{3, 4, 5}}; + + /** + * The network is: + * + * x0 = fluid.layer.data(...) + * x1 = scale(x0, scale=1) + * x2 = scale(x1, scale=2) + * x3 = elementwise_mul(x1, x2) + * scale(x3, out=x1, scale=3) # produce a new version of x1 + * x4, x5 = elementwise_add_grad(dout=x3, x=x2, y=x1) + * x6 = elementwise_mul(x4, x5) + * x7 = elementwise_add(x5, x5) + */ + std::string x0 = "x0"; + std::string x1 = "x1"; + std::string x2 = "x2"; + std::string x3 = "x3"; + std::string x4 = "x4"; + std::string x5 = "x5"; + std::string x6 = "x6"; + std::string x7 = "x7"; + + NewVar(block, x0, shape); + AppendOp(block, "scale", {{"X", {x0}}}, {{"Out", {x1}}}, {{"scale", 1.0f}}); + AppendOp(block, "scale", {{"X", {x1}}}, {{"Out", {x2}}}, {{"scale", 2.0f}}); + AppendOp(block, "elementwise_mul", {{"X", {x1}}, {"Y", {x2}}}, + {{"Out", {x3}}}, {}); + AppendOp(block, "scale", {{"X", {x3}}}, {{"Out", {x1}}}, {{"scale", 3.0f}}); + AppendOp(block, "elementwise_add_grad", + {{GradVarName("Out"), {x3}}, {"X", {x2}}, {"Y", {x1}}}, + {{GradVarName("X"), {x4}}, {GradVarName("Y"), {x5}}}, {}); + AppendOp(block, "elementwise_mul", {{"X", {x4}}, {"Y", {x5}}}, + {{"Out", {x6}}}, {}); + AppendOp(block, "elementwise_add", {{"X", {x5}}, {"Y", {x5}}}, + {{"Out", {x7}}}, {}); + + std::vector use_cuda_list{false}; +#ifdef PADDLE_WITH_CUDA + use_cuda_list.push_back(true); +#endif + for (auto use_cuda : use_cuda_list) { + ReferenceCountPassTestHelper helper(program, use_cuda); + ASSERT_TRUE(helper.IsLastLivedOps(x0, {"scale"})); + ASSERT_EQ( + boost::get(helper.LastLivedOps(x0)[0]->Attrs().at("scale")), + 1.0f); + + ASSERT_TRUE(helper.IsLastLivedOps(x1, {"scale"})); + ASSERT_EQ( + boost::get(helper.LastLivedOps(x1)[0]->Attrs().at("scale")), + 3.0f); + + ASSERT_TRUE(helper.IsLastLivedOps(x2, {"elementwise_mul"})); + ASSERT_TRUE(helper.IsLastLivedOps(x3, {"elementwise_add_grad"})); + + ASSERT_TRUE(helper.IsLastLivedOps(x4, {"elementwise_mul"})); + ASSERT_TRUE( + helper.IsLastLivedOps(x5, {"elementwise_mul", "elementwise_add"})); + + ASSERT_TRUE(helper.IsLastLivedOps(x6, {"elementwise_mul"})); + ASSERT_TRUE(helper.IsLastLivedOps(x7, {"elementwise_add"})); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 91bc26b1fd5..b6bc41f0b58 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -267,6 +267,26 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { return graph; } + /** + * NOTE(zengjinle): If BuildStrategy.memory_optimize = None in Python, + * set BuildStrategy.memory_optimize according to whether gc is enabled. + * If gc is enabled, BuildStrategy.memory_optimize = False. + * If gc is disabled, BuildStrategy.memory_optimize = True. + * This is because gc+memory_optimize is worse than gc only. + * + * As an option, users can enable BuildStrategy.memory_optimize forcely + * by setting True, and disable it forcely by setting False. + */ + bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0); + if (!build_strategy_.memory_optimize_) { + build_strategy_.memory_optimize_ = !is_gc_enabled; + } + + bool need_mem_opt = build_strategy_.enable_inplace_ || + build_strategy_.memory_optimize_.get() || is_gc_enabled; + + if (!need_mem_opt) return graph; + std::vector last_live_ops_of_vars; auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass"); @@ -288,21 +308,6 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { "build_strategy.enable_inplace = True"; } - /** - * NOTE(zengjinle): If BuildStrategy.memory_optimize = None in Python, - * set BuildStrategy.memory_optimize according to whether gc is enabled. - * If gc is enabled, BuildStrategy.memory_optimize = False. - * If gc is disabled, BuildStrategy.memory_optimize = True. - * This is because gc+memory_optimize is worse than gc only. - * - * As an option, users can enable BuildStrategy.memory_optimize forcely - * by setting True, and disable it forcely by setting False. - */ - bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0); - if (!build_strategy_.memory_optimize_) { - build_strategy_.memory_optimize_ = !is_gc_enabled; - } - if (build_strategy_.memory_optimize_.get()) { auto cross_op_memory_reuse_pass = ir::PassRegistry::Instance().Get( "buffer_shared_cross_op_memory_reuse_pass"); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 236fa84f1bb..920da2cfa66 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -84,7 +84,7 @@ if (WITH_DGC) endif() -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc) -- GitLab