未验证 提交 55ce6969 编写于 作者: C chengduo 提交者: GitHub

clean tensor array (#19930)

test=develop
上级 57606205
...@@ -74,7 +74,8 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -74,7 +74,8 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
} }
++drop_scope_counter_; ++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ ||
DropScopeOrNot()) {
DropLocalExeScopes(); DropLocalExeScopes();
} }
...@@ -93,6 +94,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -93,6 +94,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
} }
} }
bool ScopeBufferedSSAGraphExecutor::DropScopeOrNot() const {
for (auto &var : tensor_array_vars_) {
auto tensor_array = var->GetMutable<LoDTensorArray>();
for (LoDTensor &tensor : *tensor_array) {
if (tensor.IsInitialized()) {
return true;
}
}
tensor_array->clear();
}
return false;
}
void ScopeBufferedSSAGraphExecutor::InitVariables() { void ScopeBufferedSSAGraphExecutor::InitVariables() {
for (auto &info : tmp_var_infos_) { for (auto &info : tmp_var_infos_) {
for (auto &pair : info) { for (auto &pair : info) {
...@@ -165,6 +179,9 @@ void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() { ...@@ -165,6 +179,9 @@ void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
Variable *tmp_var = local_scope->Var(info.name_); Variable *tmp_var = local_scope->Var(info.name_);
preserve_vars_[idx].emplace(tmp_var); preserve_vars_[idx].emplace(tmp_var);
tmp_var_infos_[idx].emplace_back(tmp_var, info.type_); tmp_var_infos_[idx].emplace_back(tmp_var, info.type_);
if (info.type_ == proto::VarType::LOD_TENSOR_ARRAY) {
tensor_array_vars_.emplace_back(tmp_var);
}
} }
} }
} }
......
...@@ -61,6 +61,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -61,6 +61,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
private: private:
void InitVariables(); void InitVariables();
bool DropScopeOrNot() const;
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_; std::unique_ptr<SSAGraphExecutor> underlying_executor_;
...@@ -71,8 +73,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -71,8 +73,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::vector<std::pair<Variable*, proto::VarType::Type>>> std::vector<std::vector<std::pair<Variable*, proto::VarType::Type>>>
tmp_var_infos_; tmp_var_infos_;
std::vector<Variable*> tensor_array_vars_;
std::vector<VariableInfo> var_infos_; std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
ScopeBufferedMonitor scope_monitor_; ScopeBufferedMonitor scope_monitor_;
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册