未验证 提交 55ce6969 编写于 作者: C chengduo 提交者: GitHub

clean tensor array (#19930)

test=develop
上级 57606205
......@@ -74,7 +74,8 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ ||
DropScopeOrNot()) {
DropLocalExeScopes();
}
......@@ -93,6 +94,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
bool ScopeBufferedSSAGraphExecutor::DropScopeOrNot() const {
for (auto &var : tensor_array_vars_) {
auto tensor_array = var->GetMutable<LoDTensorArray>();
for (LoDTensor &tensor : *tensor_array) {
if (tensor.IsInitialized()) {
return true;
}
}
tensor_array->clear();
}
return false;
}
void ScopeBufferedSSAGraphExecutor::InitVariables() {
for (auto &info : tmp_var_infos_) {
for (auto &pair : info) {
......@@ -165,6 +179,9 @@ void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
Variable *tmp_var = local_scope->Var(info.name_);
preserve_vars_[idx].emplace(tmp_var);
tmp_var_infos_[idx].emplace_back(tmp_var, info.type_);
if (info.type_ == proto::VarType::LOD_TENSOR_ARRAY) {
tensor_array_vars_.emplace_back(tmp_var);
}
}
}
}
......
......@@ -61,6 +61,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
private:
void InitVariables();
bool DropScopeOrNot() const;
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
......@@ -71,8 +73,11 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::vector<std::pair<Variable*, proto::VarType::Type>>>
tmp_var_infos_;
std::vector<Variable*> tensor_array_vars_;
std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_;
ScopeBufferedMonitor scope_monitor_;
};
} // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册