未验证 提交 641038dc 编写于 作者: W wanghuancoder 提交者: GitHub

clear local scope every setp (#37569)

* clear local scope every setp, test=develop

* refine,test=develop

* refine, test=develop
上级 4201c94a
...@@ -88,6 +88,10 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -88,6 +88,10 @@ paddle::framework::FetchList InterpreterCore::Run(
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
if (create_local_scope_) {
ClearLoDTensorArrayInLocalScope();
}
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
...@@ -122,11 +126,28 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -122,11 +126,28 @@ paddle::framework::FetchList InterpreterCore::Run(
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
} }
if (create_local_scope_) {
ClearLoDTensorArrayInLocalScope();
}
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName); auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
return std::move(*fetch_var->GetMutable<framework::FetchList>()); return std::move(*fetch_var->GetMutable<framework::FetchList>());
} }
// At the end of each step, the holder of Tensor in LoDTensorArray is null.
// Clear these Tensors and leave LoDTensorArray empty, otherwise an exception
// will occur in the next step
void InterpreterCore::ClearLoDTensorArrayInLocalScope() {
auto vars = local_scope_->LocalVars();
for (auto var : vars) {
if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
lod_tensor_arr->clear();
}
}
}
void InterpreterCore::BuildOperatorDependences() { void InterpreterCore::BuildOperatorDependences() {
// analysis the dependences between ops, set the dependecy_count_ and Call // analysis the dependences between ops, set the dependecy_count_ and Call
// Schedule // Schedule
...@@ -609,6 +630,10 @@ interpreter::CostInfo InterpreterCore::DryRun( ...@@ -609,6 +630,10 @@ interpreter::CostInfo InterpreterCore::DryRun(
platform::DeviceContextPool::Instance().Get(place_)->Wait(); platform::DeviceContextPool::Instance().Get(place_)->Wait();
} }
if (create_local_scope_) {
ClearLoDTensorArrayInLocalScope();
}
return cost_info; return cost_info;
} }
......
...@@ -86,6 +86,8 @@ class InterpreterCore { ...@@ -86,6 +86,8 @@ class InterpreterCore {
void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names); void SetFeedVarsInplaceSkip(const std::vector<std::string>& feed_names);
void ClearLoDTensorArrayInLocalScope();
bool is_build_; bool is_build_;
const platform::Place& place_; const platform::Place& place_;
......
...@@ -79,7 +79,6 @@ void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var, ...@@ -79,7 +79,6 @@ void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var,
for (auto& t : *tensor_arr) { for (auto& t : *tensor_arr) {
Add(t.MoveMemoryHolder(), event, ctx); Add(t.MoveMemoryHolder(), event, ctx);
} }
tensor_arr->clear();
} else if (var->IsType<std::vector<Scope*>>()) { } else if (var->IsType<std::vector<Scope*>>()) {
// NOTE(@xiongkun03) conditional_op / while_op will create a STEP_SCOPE // NOTE(@xiongkun03) conditional_op / while_op will create a STEP_SCOPE
// refer to executor.cc to see what old garbage collector does. // refer to executor.cc to see what old garbage collector does.
......
...@@ -411,7 +411,6 @@ void build_op_func_list(const platform::Place& place, ...@@ -411,7 +411,6 @@ void build_op_func_list(const platform::Place& place,
for (auto& t : *lod_tensor_arr) { for (auto& t : *lod_tensor_arr) {
garbages->emplace_back(t.MoveMemoryHolder()); garbages->emplace_back(t.MoveMemoryHolder());
} }
lod_tensor_arr->clear();
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Type %s of variable %s is not supported eager deletion.", "Type %s of variable %s is not supported eager deletion.",
......
...@@ -146,6 +146,18 @@ std::vector<std::string> Scope::LocalVarNames() const { ...@@ -146,6 +146,18 @@ std::vector<std::string> Scope::LocalVarNames() const {
return known_vars; return known_vars;
} }
std::vector<Variable*> Scope::LocalVars() {
std::vector<Variable*> known_vars;
{
SCOPE_VARS_READER_LOCK
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
known_vars.emplace_back(p.second.get());
}
}
return known_vars;
}
void Scope::DeleteScope(Scope* scope) const { void Scope::DeleteScope(Scope* scope) const {
{ {
SCOPE_KIDS_WRITER_LOCK SCOPE_KIDS_WRITER_LOCK
......
...@@ -134,9 +134,12 @@ class Scope : public ScopeBase { ...@@ -134,9 +134,12 @@ class Scope : public ScopeBase {
const std::list<Scope*>& kids() const { return kids_; } const std::list<Scope*>& kids() const { return kids_; }
// enumerate all the variables current contains. // enumerate all the variable names current contains.
std::vector<std::string> LocalVarNames() const; std::vector<std::string> LocalVarNames() const;
// enumerate all the variables current contains.
std::vector<Variable*> LocalVars();
// Rename variable to a new name // Rename variable to a new name
void Rename(const std::string& origin_name, void Rename(const std::string& origin_name,
const std::string& new_name) const; const std::string& new_name) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册