From 4002f3204990f9ba71b690c0a4db2ffedec0b95e Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Thu, 16 Jun 2022 14:20:49 +0800 Subject: [PATCH] Support disable GC for some vars in interpretercore (#43546) * Support disable GC for some vars in standalone executor * Setting skip_gc_vars in interprecore construction --- .../framework/new_executor/interpretercore.cc | 35 +++++++++- .../framework/new_executor/interpretercore.h | 10 +++ .../new_executor/interpretercore_util.cc | 3 +- .../new_executor/interpretercore_util.h | 1 + .../new_executor/standalone_executor.cc | 20 +++--- .../new_executor/standalone_executor_test.cc | 68 +++++++++++++++++-- 6 files changed, 116 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 11d672e8ef0..6de4b9d5e13 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -56,12 +56,15 @@ bool IsInterpretercoreFastGCEnabled() { InterpreterCore::InterpreterCore(const platform::Place& place, const BlockDesc& block, + const std::set& skip_gc_vars, VariableScope* global_scope) : place_(place), block_(block), + skip_gc_vars_(skip_gc_vars), global_scope_(global_scope), stream_analyzer_(place) { VLOG(4) << "InterpreterCore(): " << this << " on " << place_; + is_build_ = false; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -173,7 +176,8 @@ paddle::framework::FetchList InterpreterCore::Run( create_local_scope_); std::vector op_func_nodes; paddle::framework::interpreter::build_op_func_list( - place_, block_, &op_func_nodes, global_scope_, create_local_scope_); + place_, block_, skip_gc_vars_, &op_func_nodes, global_scope_, + create_local_scope_); is_build_ = true; SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph @@ -309,6 +313,15 @@ void InterpreterCore::Convert( } } + // clear the last_live_ops list for all vars in skip_gc_vars + for (const std::string& skip_gc_var : skip_gc_vars_) { + int var_id = global_scope_->GetIdByName(skip_gc_var); + if (var_id != -1) { + last_live_ops_[var_id].clear(); + VLOG(8) << "Skip gc for var: " << skip_gc_var; + } + } + // shrink, find the downstream op that has no other op in the // downstream list happens before it // For example, @@ -934,7 +947,8 @@ void InterpreterCore::Prepare( FeedInput(); std::vector op_func_nodes; paddle::framework::interpreter::build_op_func_list( - place_, block_, &op_func_nodes, global_scope_, create_local_scope_); + place_, block_, skip_gc_vars_, &op_func_nodes, global_scope_, + create_local_scope_); is_build_ = true; SetFeedVarsInplaceSkip(feed_names); // convert vec func_list to graph @@ -986,5 +1000,22 @@ void InterpreterCore::SetFeedVarsInplaceSkip( } } +std::shared_ptr CreateInterpreterCore( + const platform::Place& place, const ProgramDesc& prog, + VariableScope* global_scope, const std::vector& fetch_names, + const std::set& skip_gc_vars) { + std::shared_ptr core = nullptr; + // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy + // a new program. + auto new_prog = std::make_shared(prog); + auto* block = new_prog->MutableBlock(0); + interpreter::add_fetch(fetch_names, block); + + core = std::make_shared(place, *block, skip_gc_vars, + global_scope); + core->SetCopyProgram(new_prog); + return core; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index 3af0ddb675a..e0e9c40d364 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -39,6 +39,7 @@ using AtomicVectorSizeT = std::vector>>; class InterpreterCore { public: InterpreterCore(const platform::Place& place, const BlockDesc& block, + const std::set& skip_gc_vars, VariableScope* global_scope); ~InterpreterCore(); @@ -99,6 +100,8 @@ class InterpreterCore { const platform::Place& place_; const BlockDesc& block_; // not owned + const std::set skip_gc_vars_; + // NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will // copy a new program and block, the copy_program_ here is used to // hold the program, otherwise block_ maybe not valid after the @@ -130,5 +133,12 @@ class InterpreterCore { bool create_local_scope_{true}; Scope* local_scope_{nullptr}; // not owned }; + +std::shared_ptr CreateInterpreterCore( + const platform::Place& place, const ProgramDesc& prog, + VariableScope* global_scope, + const std::vector& fetch_names = {}, + const std::set& skip_gc_vars = {}); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index dbea438b140..be199ac74ba 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -323,6 +323,7 @@ void deal_operator_base(const platform::Place& place, void build_op_func_list(const platform::Place& place, const framework::BlockDesc& block, + const std::set& skip_gc_vars, std::vector* vec_func_list, VariableScope* var_scope, bool use_local_scope) { Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() @@ -562,7 +563,7 @@ void build_op_func_list(const platform::Place& place, for (auto& var_name : delete_vars) { auto* var = var_scope->FindVar(var_name); - if (var == nullptr) { + if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) { continue; } diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.h b/paddle/fluid/framework/new_executor/interpretercore_util.h index 3d5b067c187..4bdeb3de8fc 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.h +++ b/paddle/fluid/framework/new_executor/interpretercore_util.h @@ -114,6 +114,7 @@ void build_variable_scope(const framework::BlockDesc& block, void build_op_func_list(const platform::Place& place, const framework::BlockDesc& block, + const std::set& skip_gc_vars, std::vector* vec_func_list, VariableScope* var_scope, bool use_local_scope = true); diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 64332d7fc90..4ac9f395246 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -55,7 +55,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, // No need to use_local_scope for startup_program, its variables are // persistable paddle::framework::interpreter::build_op_func_list( - place_, startup_prog.Block(0), &vec_func_list, &global_scope_, false); + place_, startup_prog.Block(0), {}, &vec_func_list, &global_scope_, + false); } } @@ -126,19 +127,14 @@ std::shared_ptr StandaloneExecutor::GetInterpreterCore( << place_; VLOG(3) << "add fetch op: " << add_fetch_op; std::shared_ptr core = nullptr; + if (add_fetch_op) { - // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy - // a - // new program. - auto new_prog = std::make_shared(main_prog_); - auto* block = new_prog->MutableBlock(0); - interpreter::add_fetch(fetch_names, block); - - core = std::make_shared(place_, *block, &global_scope_); - core->SetCopyProgram(new_prog); + core = CreateInterpreterCore(place_, main_prog_, &global_scope_, + fetch_names); } else { - core = std::make_shared(place_, main_prog_.Block(0), - &global_scope_); + core = std::make_shared( + place_, main_prog_.Block(0), /*skip_gc_vars=*/std::set(), + &global_scope_); } interpretercores_.emplace(oss.str(), core); return core; diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 60d59899549..84c913bf47d 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -145,9 +145,7 @@ TEST(StandaloneExecutor, run) { exec.Run({}, {}, {}); auto start = std::chrono::steady_clock::now(); - // ProfilerStart("new_executor.prof"); - - for (size_t i = 0; i < 2320; ++i) { + for (size_t i = 0; i < 10; ++i) { if (i % 200 == 0) { std::cout << i << std::endl; } @@ -155,13 +153,71 @@ TEST(StandaloneExecutor, run) { exec.Run({}, {}, {}); } - // ProfilerStop(); - auto end = std::chrono::steady_clock::now(); std::chrono::duration diff = end - start; std::cout << "time cost " << diff.count() << std::endl; - // ASSERT_LT(diff.count(), 30); +} + +TEST(StandaloneExecutor, skip_gc_vars) { + FLAGS_eager_delete_tensor_gb = 0; + + int64_t batch_size = 20; + + auto place = platform::CUDAPlace(0); + auto startup_prog = load_from_file("lm_startup_program"); + auto main_prog = load_from_file("lm_main_program"); + + auto& global_block = main_prog.Block(0); + + auto& op1 = global_block.AllOps()[1]; + auto shape1 = BOOST_GET_CONST(std::vector, op1->GetAttr("shape")); + shape1[0] = batch_size * 20; + op1->SetAttr("shape", shape1); + + auto& op2 = global_block.AllOps()[2]; + auto shape2 = BOOST_GET_CONST(std::vector, op2->GetAttr("shape")); + shape2[0] = batch_size; + op2->SetAttr("shape", shape2); + + auto& op3 = global_block.AllOps()[3]; + auto shape3 = BOOST_GET_CONST(std::vector, op3->GetAttr("shape")); + shape3[0] = batch_size; + op3->SetAttr("shape", shape3); + + Scope scope; + + VariableScope startup_scope(&scope); + std::shared_ptr startup_core = + CreateInterpreterCore(place, startup_prog, &startup_scope); + startup_core->Run({}, {}); + + std::set skip_gc_vars = {"uniform_0.tmp_0", "transpose_0.tmp_0", + "embedding_0.tmp_0", "slice_0.tmp_0", + "split_1.tmp_2"}; + std::set gc_vars = {"uniform_1.tmp_0", "matmul_0.tmp_0", + "split_0.tmp_0", "elementwise_add_0.tmp_0", + "tmp_0"}; + auto check_gc_result = [](VariableScope& scope, std::set& vars, + bool is_skip_gc) { + for (const std::string& var_name : vars) { + ASSERT_EQ( + scope.FindVar(var_name)->GetMutable()->IsInitialized(), + is_skip_gc); + } + }; + + VariableScope main_scope(&scope); + std::shared_ptr main_core = + CreateInterpreterCore(place, main_prog, &main_scope, {}, skip_gc_vars); + + main_core->Run({}, {}); + check_gc_result(main_scope, skip_gc_vars, true); + check_gc_result(main_scope, gc_vars, false); + + main_core->Run({}, {}); + check_gc_result(main_scope, skip_gc_vars, true); + check_gc_result(main_scope, gc_vars, false); } } // namespace framework -- GitLab