Unverified commit 4002f320, authored by Ruibiao Chen, committed by GitHub

Support disabling GC for some vars in InterpreterCore (#43546)

* Support disabling GC for some vars in the standalone executor

* Set skip_gc_vars in InterpreterCore construction
Parent 767efaca
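For orientation before the diff: the commit threads a skip_gc_vars set from the executor entry points down to build_op_func_list and InterpreterCore::Convert, so the named variables are never scheduled for garbage collection. A minimal usage sketch of the new API (the wrapper BuildCoreKeepingVars and the include path are our assumptions; CreateInterpreterCore, Run, and the variable names come from this commit's code and test):

#include <memory>
#include <set>
#include <string>

// Assumed header location for InterpreterCore / CreateInterpreterCore.
#include "paddle/fluid/framework/new_executor/interpretercore.h"

// Sketch only: build an InterpreterCore whose listed variables survive GC.
// `place`, `prog`, and `scope` are assumed to be prepared by the caller.
std::shared_ptr<paddle::framework::InterpreterCore> BuildCoreKeepingVars(
    const paddle::platform::Place& place,
    const paddle::framework::ProgramDesc& prog,
    paddle::framework::VariableScope* scope) {
  // Tensors of these variables stay initialized after Run() returns.
  std::set<std::string> skip_gc_vars = {"embedding_0.tmp_0", "slice_0.tmp_0"};
  return paddle::framework::CreateInterpreterCore(
      place, prog, scope, /*fetch_names=*/{}, skip_gc_vars);
}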
@@ -56,12 +56,15 @@ bool IsInterpretercoreFastGCEnabled() {
 InterpreterCore::InterpreterCore(const platform::Place& place,
                                  const BlockDesc& block,
+                                 const std::set<std::string>& skip_gc_vars,
                                  VariableScope* global_scope)
     : place_(place),
       block_(block),
+      skip_gc_vars_(skip_gc_vars),
       global_scope_(global_scope),
       stream_analyzer_(place) {
   VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
   is_build_ = false;
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -173,7 +176,8 @@ paddle::framework::FetchList InterpreterCore::Run(
                                                create_local_scope_);
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     paddle::framework::interpreter::build_op_func_list(
-        place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
+        place_, block_, skip_gc_vars_, &op_func_nodes, global_scope_,
+        create_local_scope_);
     is_build_ = true;
     SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
@@ -309,6 +313,15 @@ void InterpreterCore::Convert(
     }
   }

+  // clear the last_live_ops list for all vars in skip_gc_vars
+  for (const std::string& skip_gc_var : skip_gc_vars_) {
+    int var_id = global_scope_->GetIdByName(skip_gc_var);
+    if (var_id != -1) {
+      last_live_ops_[var_id].clear();
+      VLOG(8) << "Skip gc for var: " << skip_gc_var;
+    }
+  }
+
   // shrink, find the downstream op that has no other op in the
   // downstream list happens before it
   // For example,
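In the hunk above, last_live_ops_[var_id] holds the ops after whose completion a variable becomes collectible; clearing that set for a skipped variable means no instruction ever carries a GC check for it. A conceptual sketch of that relationship (simplified, invented data structures, not the actual InterpreterCore internals):

#include <map>
#include <set>
#include <vector>

// op_id -> variables that may be freed once that op has finished.
using GCPlan = std::map<int, std::vector<int>>;

GCPlan BuildGCPlan(const std::map<int, std::set<int>>& last_live_ops) {
  GCPlan plan;
  for (const auto& [var_id, ops] : last_live_ops) {
    // A variable whose last-live set was cleared (skip_gc_vars) lands in
    // no op's free list, so nothing in the run ever releases it.
    for (int op_id : ops) {
      plan[op_id].push_back(var_id);
    }
  }
  return plan;
}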
@@ -934,7 +947,8 @@ void InterpreterCore::Prepare(
     FeedInput();
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     paddle::framework::interpreter::build_op_func_list(
-        place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
+        place_, block_, skip_gc_vars_, &op_func_nodes, global_scope_,
+        create_local_scope_);
     is_build_ = true;
     SetFeedVarsInplaceSkip(feed_names);
     // convert vec func_list to graph
@@ -986,5 +1000,22 @@ void InterpreterCore::SetFeedVarsInplaceSkip(
   }
 }

+std::shared_ptr<InterpreterCore> CreateInterpreterCore(
+    const platform::Place& place, const ProgramDesc& prog,
+    VariableScope* global_scope, const std::vector<std::string>& fetch_names,
+    const std::set<std::string>& skip_gc_vars) {
+  std::shared_ptr<InterpreterCore> core = nullptr;
+  // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
+  // a new program.
+  auto new_prog = std::make_shared<framework::ProgramDesc>(prog);
+  auto* block = new_prog->MutableBlock(0);
+  interpreter::add_fetch(fetch_names, block);
+
+  core = std::make_shared<InterpreterCore>(place, *block, skip_gc_vars,
+                                           global_scope);
+  core->SetCopyProgram(new_prog);
+  return core;
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -39,6 +39,7 @@ using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
 class InterpreterCore {
  public:
   InterpreterCore(const platform::Place& place, const BlockDesc& block,
+                  const std::set<std::string>& skip_gc_vars,
                   VariableScope* global_scope);

   ~InterpreterCore();
@@ -99,6 +100,8 @@ class InterpreterCore {
   const platform::Place& place_;
   const BlockDesc& block_;  // not owned
+  const std::set<std::string> skip_gc_vars_;
+
   // NOTE(zhiqiu): when add fetch ops in GetInterpreterCore, we will
   // copy a new program and block, the copy_program_ here is used to
   // hold the program, otherwise block_ maybe not valid after the
@@ -130,5 +133,12 @@
   bool create_local_scope_{true};
   Scope* local_scope_{nullptr};  // not owned
 };

+std::shared_ptr<InterpreterCore> CreateInterpreterCore(
+    const platform::Place& place, const ProgramDesc& prog,
+    VariableScope* global_scope,
+    const std::vector<std::string>& fetch_names = {},
+    const std::set<std::string>& skip_gc_vars = {});
+
 }  // namespace framework
 }  // namespace paddle
@@ -323,6 +323,7 @@ void deal_operator_base(const platform::Place& place,
 void build_op_func_list(const platform::Place& place,
                         const framework::BlockDesc& block,
+                        const std::set<std::string>& skip_gc_vars,
                         std::vector<OpFuncNode>* vec_func_list,
                         VariableScope* var_scope, bool use_local_scope) {
   Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
@@ -562,7 +563,7 @@ void build_op_func_list(const platform::Place& place,
     for (auto& var_name : delete_vars) {
       auto* var = var_scope->FindVar(var_name);
-      if (var == nullptr) {
+      if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
         continue;
       }
...
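The static-build path applies the same policy: when build_op_func_list assembles each op's delete_vars, names present in skip_gc_vars are dropped before any GC event is recorded. A small standalone filter equivalent to the skip-set part of the check above (hypothetical helper, for illustration only):

#include <set>
#include <string>
#include <vector>

std::vector<std::string> FilterDeleteVars(
    const std::vector<std::string>& delete_vars,
    const std::set<std::string>& skip_gc_vars) {
  std::vector<std::string> kept;
  for (const auto& name : delete_vars) {
    if (skip_gc_vars.count(name) != 0) continue;  // protected from GC
    kept.push_back(name);
  }
  return kept;
}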
@@ -114,6 +114,7 @@ void build_variable_scope(const framework::BlockDesc& block,
 void build_op_func_list(const platform::Place& place,
                         const framework::BlockDesc& block,
+                        const std::set<std::string>& skip_gc_vars,
                         std::vector<OpFuncNode>* vec_func_list,
                         VariableScope* var_scope, bool use_local_scope = true);
...
@@ -55,7 +55,8 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
     // No need to use_local_scope for startup_program, its variables are
     // persistable
     paddle::framework::interpreter::build_op_func_list(
-        place_, startup_prog.Block(0), &vec_func_list, &global_scope_, false);
+        place_, startup_prog.Block(0), {}, &vec_func_list, &global_scope_,
+        false);
   }
 }
@@ -126,18 +127,13 @@ std::shared_ptr<InterpreterCore> StandaloneExecutor::GetInterpreterCore(
           << place_;
   VLOG(3) << "add fetch op: " << add_fetch_op;
   std::shared_ptr<InterpreterCore> core = nullptr;

   if (add_fetch_op) {
-    // NOTE(Aurelius84): `add_fetch` will modify BlockDesc, so we should copy
-    // a
-    // new program.
-    auto new_prog = std::make_shared<framework::ProgramDesc>(main_prog_);
-    auto* block = new_prog->MutableBlock(0);
-    interpreter::add_fetch(fetch_names, block);
-    core = std::make_shared<InterpreterCore>(place_, *block, &global_scope_);
-    core->SetCopyProgram(new_prog);
+    core = CreateInterpreterCore(place_, main_prog_, &global_scope_,
+                                 fetch_names);
   } else {
-    core = std::make_shared<InterpreterCore>(place_, main_prog_.Block(0),
-                                             &global_scope_);
+    core = std::make_shared<InterpreterCore>(
+        place_, main_prog_.Block(0), /*skip_gc_vars=*/std::set<std::string>(),
+        &global_scope_);
   }
   interpretercores_.emplace(oss.str(), core);
...
@@ -145,9 +145,7 @@ TEST(StandaloneExecutor, run) {
   exec.Run({}, {}, {});

   auto start = std::chrono::steady_clock::now();
-  // ProfilerStart("new_executor.prof");
-  for (size_t i = 0; i < 2320; ++i) {
+  for (size_t i = 0; i < 10; ++i) {
     if (i % 200 == 0) {
       std::cout << i << std::endl;
     }
@@ -155,13 +153,71 @@ TEST(StandaloneExecutor, run) {
     exec.Run({}, {}, {});
   }

-  // ProfilerStop();
   auto end = std::chrono::steady_clock::now();
   std::chrono::duration<double> diff = end - start;

   std::cout << "time cost " << diff.count() << std::endl;
-  // ASSERT_LT(diff.count(), 30);
+}
+
+TEST(StandaloneExecutor, skip_gc_vars) {
+  FLAGS_eager_delete_tensor_gb = 0;
+
+  int64_t batch_size = 20;
+
+  auto place = platform::CUDAPlace(0);
+  auto startup_prog = load_from_file("lm_startup_program");
+  auto main_prog = load_from_file("lm_main_program");
+
+  auto& global_block = main_prog.Block(0);
+
+  auto& op1 = global_block.AllOps()[1];
+  auto shape1 = BOOST_GET_CONST(std::vector<int64_t>, op1->GetAttr("shape"));
+  shape1[0] = batch_size * 20;
+  op1->SetAttr("shape", shape1);
+
+  auto& op2 = global_block.AllOps()[2];
+  auto shape2 = BOOST_GET_CONST(std::vector<int64_t>, op2->GetAttr("shape"));
+  shape2[0] = batch_size;
+  op2->SetAttr("shape", shape2);
+
+  auto& op3 = global_block.AllOps()[3];
+  auto shape3 = BOOST_GET_CONST(std::vector<int64_t>, op3->GetAttr("shape"));
+  shape3[0] = batch_size;
+  op3->SetAttr("shape", shape3);
+
+  Scope scope;
+  VariableScope startup_scope(&scope);
+  std::shared_ptr<InterpreterCore> startup_core =
+      CreateInterpreterCore(place, startup_prog, &startup_scope);
+  startup_core->Run({}, {});
+
+  std::set<std::string> skip_gc_vars = {"uniform_0.tmp_0", "transpose_0.tmp_0",
+                                        "embedding_0.tmp_0", "slice_0.tmp_0",
+                                        "split_1.tmp_2"};
+  std::set<std::string> gc_vars = {"uniform_1.tmp_0", "matmul_0.tmp_0",
+                                   "split_0.tmp_0", "elementwise_add_0.tmp_0",
+                                   "tmp_0"};
+  auto check_gc_result = [](VariableScope& scope, std::set<std::string>& vars,
+                            bool is_skip_gc) {
+    for (const std::string& var_name : vars) {
+      ASSERT_EQ(
+          scope.FindVar(var_name)->GetMutable<LoDTensor>()->IsInitialized(),
+          is_skip_gc);
+    }
+  };
+
+  VariableScope main_scope(&scope);
+  std::shared_ptr<InterpreterCore> main_core =
+      CreateInterpreterCore(place, main_prog, &main_scope, {}, skip_gc_vars);
+
+  main_core->Run({}, {});
+  check_gc_result(main_scope, skip_gc_vars, true);
+  check_gc_result(main_scope, gc_vars, false);
+
+  main_core->Run({}, {});
+  check_gc_result(main_scope, skip_gc_vars, true);
+  check_gc_result(main_scope, gc_vars, false);
+}

 }  // namespace framework
...
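The test verifies GC behavior by probing tensor state after Run(): a collected variable's LoDTensor releases its buffer, while a skipped one stays initialized across runs. A standalone version of that probe (the helper name is ours; the calls mirror the test code):

#include <string>

// Returns true if the variable still holds initialized tensor memory,
// i.e. it was not garbage-collected during the run.
bool VarSurvivedGC(paddle::framework::VariableScope& scope,
                   const std::string& var_name) {
  auto* var = scope.FindVar(var_name);
  return var != nullptr &&
         var->GetMutable<paddle::framework::LoDTensor>()->IsInitialized();
}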