提交 387bac46 编写于 作者: S sneaxiy

refine code

test=develop
上级 d0c8b9b9
...@@ -31,10 +31,11 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -31,10 +31,11 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
const auto &vars = graph->Get<GraphVars>(kGraphVars); const auto &vars = graph->Get<GraphVars>(kGraphVars);
auto &ref_cnts = auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount); Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
const auto &last_live_ops = const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars); Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
auto &gcs = Get<GarbageCollectorList>(kGarbageCollector); auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);
ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size()); ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());
...@@ -58,7 +59,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -58,7 +59,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation); graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle( auto *eager_deletion_op = new EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), eager_deletion_node, op->GetScope(), op->GetPlace(),
std::move(var_names), gcs[op->GetScopeIdx()].get(), std::move(var_names), gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()])); &(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if( auto it = std::find_if(
...@@ -90,6 +91,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( ...@@ -90,6 +91,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
REGISTER_PASS(eager_deletion_pass, REGISTER_PASS(eager_deletion_pass,
paddle::framework::details::EagerDeletionPass) paddle::framework::details::EagerDeletionPass)
.RequirePassAttr(paddle::framework::details::kCurReferenceCount) .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::details::kAllPlaces)
.RequirePassAttr(paddle::framework::details::kGarbageCollector); .RequirePassAttr(paddle::framework::details::kGarbageCollector);
...@@ -23,6 +23,8 @@ namespace details { ...@@ -23,6 +23,8 @@ namespace details {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); } OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) { void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
preceding_ops_.clear();
pending_ops_.clear();
for (auto &op : ops) { for (auto &op : ops) {
preceding_ops_[op]; preceding_ops_[op];
pending_ops_[op]; pending_ops_[op];
......
...@@ -29,22 +29,22 @@ namespace paddle { ...@@ -29,22 +29,22 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
class OpConnectionDetector { class OpRelationDetector {
public: public:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 }; enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
explicit OpConnectionDetector(const std::vector<OpHandleBase *> &all_ops) explicit OpRelationDetector(const std::vector<OpHandleBase *> &all_ops)
: graph_(all_ops) {} : graph_(all_ops) {}
template <typename OpSet> template <typename OpSet>
OpSet MaxNoDepOps(const OpSet &op_set) { OpSet MaxNoDepOps(const OpSet &op_set) const {
if (op_set.size() <= 1) return op_set;
using KeyType = typename OpSet::key_type; using KeyType = typename OpSet::key_type;
static_assert( static_assert(
std::is_base_of<OpHandleBase, std::is_base_of<OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value, typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be or derived of OpHandleBase"); "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
if (op_set.size() <= 1) return op_set;
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end()); std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
OpSet ret; OpSet ret;
auto rels = GetRelations(ops); auto rels = GetRelations(ops);
...@@ -59,7 +59,7 @@ class OpConnectionDetector { ...@@ -59,7 +59,7 @@ class OpConnectionDetector {
private: private:
std::vector<std::vector<RelationShip>> GetRelations( std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> ops) { const std::vector<OpHandleBase *> ops) const {
std::unordered_map<OpHandleBase *, size_t> op_to_idx; std::unordered_map<OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) { for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph"); PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
...@@ -144,7 +144,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -144,7 +144,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size()); last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
ref_cnts = std::vector<ReferenceCountMap>(vars.size()); ref_cnts = std::vector<ReferenceCountMap>(vars.size());
OpConnectionDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph)); OpRelationDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) { for (auto &name_var_pair : vars[i]) {
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <atomic> #include <atomic>
#include <map>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -33,12 +34,13 @@ using ReferenceCountMap = std::unordered_map<std::string, size_t>; ...@@ -33,12 +34,13 @@ using ReferenceCountMap = std::unordered_map<std::string, size_t>;
using AtomicReferenceCountMap = using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<size_t>>; std::unordered_map<std::string, std::atomic<size_t>>;
using GarbageCollectorList = using GarbageCollectorMap =
std::vector<std::unique_ptr<GarbageCollector<Tensor>>>; std::map<platform::Place, std::unique_ptr<GarbageCollector<Tensor>>>;
const char kGlobalReferenceCount[] = "reference_count"; const char kGlobalReferenceCount[] = "global_reference_count";
const char kCurReferenceCount[] = "current_reference_count"; const char kRuntimeReferenceCount[] = "runtime_reference_count";
const char kGarbageCollector[] = "garbage_collector"; const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
using LastLiveOpsOfVars = using LastLiveOpsOfVars =
std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>; std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
......
...@@ -32,15 +32,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ...@@ -32,15 +32,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
var_infos_(std::move(var_infos)), var_infos_(std::move(var_infos)),
places_(std::move(places)) { places_(std::move(places)) {
if (Graph().Has(details::kGarbageCollector)) { if (Graph().Has(details::kGarbageCollector)) {
gc_ = &(Graph().Get<GarbageCollectorList>(details::kGarbageCollector)); gc_ = &(Graph().Get<GarbageCollectorMap>(details::kGarbageCollector));
} }
} }
void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() { void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
if (gc_) { if (gc_) {
for (auto &gc : *gc_) { for (auto &gc_pair : *gc_) {
gc->Wait(); gc_pair.second->Wait();
gc->Reset(); gc_pair.second->Reset();
} }
} }
} }
......
...@@ -60,7 +60,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -60,7 +60,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<VariableInfo> var_infos_; std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
GarbageCollectorList* gc_{nullptr}; GarbageCollectorMap* gc_{nullptr};
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -56,13 +56,7 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts( ...@@ -56,13 +56,7 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
type != proto::VarType::LOD_TENSOR_ARRAY) { type != proto::VarType::LOD_TENSOR_ARRAY) {
continue; continue;
} }
++ref_cnts[name];
auto it = ref_cnts.find(name);
if (it != ref_cnts.end()) {
++it->second;
} else {
ref_cnts[name] = 1;
}
} }
} }
}; };
...@@ -79,8 +73,8 @@ ExecutorPrepareContext::ExecutorPrepareContext( ...@@ -79,8 +73,8 @@ ExecutorPrepareContext::ExecutorPrepareContext(
const std::vector<std::string>& skip_ref_cnt_vars) const std::vector<std::string>& skip_ref_cnt_vars)
: prog_(prog), block_id_(block_id) { : prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) { if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
skip_ref_cnt_vars); skip_ref_cnt_vars);
} }
} }
...@@ -443,7 +437,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -443,7 +437,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
if (gc) { if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), gc.get(), DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
&(ctx->cur_ref_cnts_)); &(ctx->runtime_ref_cnts_));
} }
} }
......
...@@ -34,14 +34,14 @@ struct ExecutorPrepareContext { ...@@ -34,14 +34,14 @@ struct ExecutorPrepareContext {
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; }
const framework::ProgramDesc& prog_; const framework::ProgramDesc& prog_;
size_t block_id_; size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, size_t> ref_cnts_; std::unordered_map<std::string, size_t> global_ref_cnts_;
std::unordered_map<std::string, size_t> cur_ref_cnts_; std::unordered_map<std::string, size_t> runtime_ref_cnts_;
}; };
class Executor { class Executor {
......
...@@ -51,11 +51,22 @@ class ParallelExecutorPrivate { ...@@ -51,11 +51,22 @@ class ParallelExecutorPrivate {
} }
} }
void ResetRuntimeReferenceCount() { std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
for (size_t i = 0; i < rt_ref_cnts_.size(); ++i) { std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
for (auto &pair : rt_ref_cnts_[i]) {
rt_cur_ref_cnts_[i][pair.first] = pair.second; inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
for (auto &pair : global_ref_cnts_[i]) {
runtime_ref_cnts_[i][pair.first] = pair.second;
}
for (auto &fetch_name : fetch_tensors) {
runtime_ref_cnts_[i].erase(fetch_name);
} }
runtime_ref_cnts_[i].erase(fetched_var_name);
} }
} }
...@@ -71,14 +82,75 @@ class ParallelExecutorPrivate { ...@@ -71,14 +82,75 @@ class ParallelExecutorPrivate {
bool use_cuda_; bool use_cuda_;
bool use_all_reduce_; bool use_all_reduce_;
// rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// keeps unchanged // then keeps unchanged
// Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_ // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<details::ReferenceCountMap> rt_ref_cnts_; std::vector<details::ReferenceCountMap> global_ref_cnts_;
std::vector<details::AtomicReferenceCountMap> rt_cur_ref_cnts_; std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
details::GarbageCollectorList gcs_; details::GarbageCollectorMap gcs_;
}; };
std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
for (size_t i = 0; i < places_.size(); ++i) {
auto &place = places_[i];
if (gcs_.count(place) > 0) {
continue;
}
#ifdef PADDLE_WITH_CUDA
GarbageCollector<Tensor> *gc = nullptr;
if (platform::is_gpu_place(place)) {
if (IsFastEagerDeletionModeEnabled()) {
gc = new UnsafeFastGPUGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size);
} else {
gc = new StreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size);
}
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
} else if (platform::is_cpu_place(place)) {
#endif
gc = new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place), max_memory_size);
VLOG(10) << "Created GarbageCollector at " << place;
#ifdef PADDLE_WITH_CUDA
}
#endif
if (gc) {
gcs_[place] = std::unique_ptr<GarbageCollector<Tensor>>(gc);
}
}
if (gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&global_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(std::move(graph));
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
&runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied";
graph->SetNotOwned(details::kGarbageCollector, &gcs_);
}
return graph;
}
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
return member_->local_scopes_; return member_->local_scopes_;
} }
...@@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor(
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
size_t place_num = member_->places_.size(); graph = member_->PrepareGCAndRefCnts(std::move(graph),
for (size_t i = 0; i < place_num; ++i) { static_cast<size_t>(max_memory_size));
auto &place = member_->places_[i];
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
if (IsFastEagerDeletionModeEnabled()) {
member_->gcs_.emplace_back(new UnsafeFastGPUGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size));
} else {
member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size));
}
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
} else if (platform::is_cpu_place(place)) {
#endif
member_->gcs_.emplace_back(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place), max_memory_size));
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
#ifdef PADDLE_WITH_CUDA
}
#endif
}
}
if (!member_->gcs_.empty()) {
std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
&(member_->rt_ref_cnts_));
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(std::move(graph));
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(details::kCurReferenceCount,
&(member_->rt_cur_ref_cnts_));
eager_deletion_pass->SetNotOwned(details::kGarbageCollector,
&(member_->gcs_));
eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied";
graph->SetNotOwned(details::kGarbageCollector, &(member_->gcs_));
} }
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
...@@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) { const std::string &fetched_var_name) {
platform::RecordBlock b(0); platform::RecordBlock b(0);
if (!member_->gcs_.empty()) { if (member_->HasGarbageCollectors()) {
member_->ResetRuntimeReferenceCount(); member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
size_t n = member_->rt_ref_cnts_.size();
for (size_t i = 0; i < n; ++i) {
for (auto &fetch_name : fetch_tensors) {
member_->rt_cur_ref_cnts_[i].erase(fetch_name);
}
member_->rt_cur_ref_cnts_[i].erase(fetched_var_name);
}
} }
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
......
...@@ -74,9 +74,7 @@ class WhileOp : public framework::OperatorBase { ...@@ -74,9 +74,7 @@ class WhileOp : public framework::OperatorBase {
bool is_test = Attr<bool>("is_test"); bool is_test = Attr<bool>("is_test");
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars); auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
if (framework::GetEagerDeletionThreshold() >= 0) { VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
}
auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
...@@ -144,9 +142,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -144,9 +142,7 @@ class WhileGradOp : public framework::OperatorBase {
auto *program = block->Program(); auto *program = block->Program();
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars); auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
if (framework::GetEagerDeletionThreshold() >= 0) { VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
}
auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
auto *step_scopes = auto *step_scopes =
...@@ -369,7 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -369,7 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed. // while operator could be renamed.
while_grad->SetAttr("original_output_grad", output_grads_list); while_grad->SetAttr("original_output_grad", output_grads_list);
/* The followi_ng codes are used in eager deletion mode */ /* The following codes are used in eager deletion mode */
std::unordered_set<std::string> bwd_skip_vars; std::unordered_set<std::string> bwd_skip_vars;
if (framework::GetEagerDeletionThreshold() >= 0) { if (framework::GetEagerDeletionThreshold() >= 0) {
std::unordered_set<std::string> fwd_skip_vars; std::unordered_set<std::string> fwd_skip_vars;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册