Commit 387bac46 authored by sneaxiy

refine code

test=develop
Parent d0c8b9b9
@@ -31,10 +31,11 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
   const auto &vars = graph->Get<GraphVars>(kGraphVars);
   auto &ref_cnts =
-      Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount);
+      Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
   const auto &last_live_ops =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
-  auto &gcs = Get<GarbageCollectorList>(kGarbageCollector);
+  auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
+  const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);
   ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());
@@ -58,7 +59,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
         graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
     auto *eager_deletion_op = new EagerDeletionOpHandle(
         eager_deletion_node, op->GetScope(), op->GetPlace(),
-        std::move(var_names), gcs[op->GetScopeIdx()].get(),
+        std::move(var_names), gcs.at(places[op->GetScopeIdx()]).get(),
         &(ref_cnts[op->GetScopeIdx()]));
     auto it = std::find_if(
@@ -90,6 +91,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
 REGISTER_PASS(eager_deletion_pass,
               paddle::framework::details::EagerDeletionPass)
-    .RequirePassAttr(paddle::framework::details::kCurReferenceCount)
+    .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
     .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
+    .RequirePassAttr(paddle::framework::details::kAllPlaces)
     .RequirePassAttr(paddle::framework::details::kGarbageCollector);
@@ -23,6 +23,8 @@ namespace details {
 OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
 void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
+  preceding_ops_.clear();
+  pending_ops_.clear();
   for (auto &op : ops) {
     preceding_ops_[op];
     pending_ops_[op];
......
@@ -29,22 +29,22 @@ namespace paddle {
 namespace framework {
 namespace details {
-class OpConnectionDetector {
+class OpRelationDetector {
  public:
  enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
-  explicit OpConnectionDetector(const std::vector<OpHandleBase *> &all_ops)
+  explicit OpRelationDetector(const std::vector<OpHandleBase *> &all_ops)
       : graph_(all_ops) {}
   template <typename OpSet>
-  OpSet MaxNoDepOps(const OpSet &op_set) {
-    if (op_set.size() <= 1) return op_set;
+  OpSet MaxNoDepOps(const OpSet &op_set) const {
     using KeyType = typename OpSet::key_type;
     static_assert(
         std::is_base_of<OpHandleBase,
                         typename std::remove_pointer<KeyType>::type>::value,
-        "Key type of OpSet must be or derived of OpHandleBase");
+        "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
+    if (op_set.size() <= 1) return op_set;
     std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
     OpSet ret;
     auto rels = GetRelations(ops);
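As an aside, the `static_assert` above is what makes `MaxNoDepOps` reject sets that are not keyed by `OpHandleBase` pointers. A minimal self-contained sketch of the same compile-time check (the `MyOp` type and `RequireOpHandleKeys` helper are illustrative, not from this commit):

```cpp
#include <type_traits>
#include <unordered_set>

struct OpHandleBase {};
struct MyOp : OpHandleBase {};  // hypothetical derived handle

template <typename OpSet>
void RequireOpHandleKeys(const OpSet &) {
  using KeyType = typename OpSet::key_type;
  static_assert(
      std::is_base_of<OpHandleBase,
                      typename std::remove_pointer<KeyType>::type>::value,
      "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
}

int main() {
  RequireOpHandleKeys(std::unordered_set<MyOp *>{});   // compiles
  // RequireOpHandleKeys(std::unordered_set<int *>{}); // would fail to compile
  return 0;
}
```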
@@ -59,7 +59,7 @@ class OpConnectionDetector {
  private:
   std::vector<std::vector<RelationShip>> GetRelations(
-      const std::vector<OpHandleBase *> ops) {
+      const std::vector<OpHandleBase *> ops) const {
     std::unordered_map<OpHandleBase *, size_t> op_to_idx;
     for (size_t i = 0; i < ops.size(); ++i) {
       PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
@@ -144,7 +144,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
   last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
   ref_cnts = std::vector<ReferenceCountMap>(vars.size());
-  OpConnectionDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
+  OpRelationDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
   for (size_t i = 0; i < vars.size(); ++i) {
     for (auto &name_var_pair : vars[i]) {
......
@@ -15,6 +15,7 @@
 #pragma once
 #include <atomic>
+#include <map>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -33,12 +34,13 @@ using ReferenceCountMap = std::unordered_map<std::string, size_t>;
 using AtomicReferenceCountMap =
     std::unordered_map<std::string, std::atomic<size_t>>;
-using GarbageCollectorList =
-    std::vector<std::unique_ptr<GarbageCollector<Tensor>>>;
+using GarbageCollectorMap =
+    std::map<platform::Place, std::unique_ptr<GarbageCollector<Tensor>>>;
-const char kGlobalReferenceCount[] = "reference_count";
-const char kCurReferenceCount[] = "current_reference_count";
+const char kGlobalReferenceCount[] = "global_reference_count";
+const char kRuntimeReferenceCount[] = "runtime_reference_count";
 const char kGarbageCollector[] = "garbage_collector";
+const char kAllPlaces[] = "all_places";
 using LastLiveOpsOfVars =
     std::unordered_map<std::string, std::unordered_set<ComputationOpHandle*>>;
......
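The switch from `GarbageCollectorList` (indexed by scope position) to `GarbageCollectorMap` (keyed by `platform::Place`) means a collector is looked up by the place it serves, as the `gcs.at(places[op->GetScopeIdx()])` call in the first hunk shows; `std::map` requires a strict ordering on its key, which `platform::Place` supports. A simplified sketch of the lookup pattern, with placeholder `Place` and `GC` types standing in for the real `platform::Place` and `GarbageCollector<Tensor>`:

```cpp
#include <map>
#include <memory>

// Placeholder stand-ins for platform::Place and GarbageCollector<Tensor>.
struct Place {
  int device_id;
  bool operator<(const Place &other) const { return device_id < other.device_id; }
};
struct GC {
  void Wait() {}
  void Reset() {}
};

using GCMap = std::map<Place, std::unique_ptr<GC>>;

int main() {
  GCMap gcs;
  gcs[Place{0}] = std::unique_ptr<GC>(new GC);

  // at() throws std::out_of_range instead of silently default-inserting,
  // so a missing collector for a place fails loudly.
  GC *gc = gcs.at(Place{0}).get();
  gc->Wait();
  gc->Reset();
  return 0;
}
```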
@@ -32,15 +32,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
       var_infos_(std::move(var_infos)),
       places_(std::move(places)) {
   if (Graph().Has(details::kGarbageCollector)) {
-    gc_ = &(Graph().Get<GarbageCollectorList>(details::kGarbageCollector));
+    gc_ = &(Graph().Get<GarbageCollectorMap>(details::kGarbageCollector));
   }
 }
 void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
   if (gc_) {
-    for (auto &gc : *gc_) {
-      gc->Wait();
-      gc->Reset();
+    for (auto &gc_pair : *gc_) {
+      gc_pair.second->Wait();
+      gc_pair.second->Reset();
     }
   }
 }
......
@@ -60,7 +60,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<VariableInfo> var_infos_;
   std::vector<platform::Place> places_;
-  GarbageCollectorList* gc_{nullptr};
+  GarbageCollectorMap* gc_{nullptr};
 };
 }  // namespace details
 }  // namespace framework
......
@@ -56,13 +56,7 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
             type != proto::VarType::LOD_TENSOR_ARRAY) {
           continue;
         }
-        auto it = ref_cnts.find(name);
-        if (it != ref_cnts.end()) {
-          ++it->second;
-        } else {
-          ref_cnts[name] = 1;
-        }
+        ++ref_cnts[name];
       }
     }
   };
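The replaced find/insert branch and the new `++ref_cnts[name];` are equivalent: `std::unordered_map::operator[]` value-initializes the mapped `size_t` to 0 on first access, so the first increment yields 1. A tiny demonstration:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::size_t> ref_cnts;
  ++ref_cnts["x"];  // "x" absent: inserted as 0, then incremented to 1
  ++ref_cnts["x"];  // "x" present: incremented to 2
  assert(ref_cnts["x"] == 2);
  return 0;
}
```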
@@ -79,8 +73,8 @@ ExecutorPrepareContext::ExecutorPrepareContext(
     const std::vector<std::string>& skip_ref_cnt_vars)
     : prog_(prog), block_id_(block_id) {
   if (GetEagerDeletionThreshold() >= 0) {
-    ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
-                                                 skip_ref_cnt_vars);
+    global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
+                                                        skip_ref_cnt_vars);
   }
 }
@@ -443,7 +437,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
       if (gc) {
         DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
-                            &(ctx->cur_ref_cnts_));
+                            &(ctx->runtime_ref_cnts_));
       }
     }
......
@@ -34,14 +34,14 @@ struct ExecutorPrepareContext {
   ~ExecutorPrepareContext();
-  void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
+  void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; }
   const framework::ProgramDesc& prog_;
   size_t block_id_;
   std::vector<std::unique_ptr<OperatorBase>> ops_;
-  std::unordered_map<std::string, size_t> ref_cnts_;
-  std::unordered_map<std::string, size_t> cur_ref_cnts_;
+  std::unordered_map<std::string, size_t> global_ref_cnts_;
+  std::unordered_map<std::string, size_t> runtime_ref_cnts_;
 };
 class Executor {
......
@@ -51,11 +51,22 @@ class ParallelExecutorPrivate {
     }
   }
-  void ResetRuntimeReferenceCount() {
-    for (size_t i = 0; i < rt_ref_cnts_.size(); ++i) {
-      for (auto &pair : rt_ref_cnts_[i]) {
-        rt_cur_ref_cnts_[i][pair.first] = pair.second;
+  std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
+      std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
+  inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
+  void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
+                                  const std::string &fetched_var_name) {
+    for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
+      for (auto &pair : global_ref_cnts_[i]) {
+        runtime_ref_cnts_[i][pair.first] = pair.second;
       }
+      for (auto &fetch_name : fetch_tensors) {
+        runtime_ref_cnts_[i].erase(fetch_name);
+      }
+      runtime_ref_cnts_[i].erase(fetched_var_name);
     }
   }
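With this change, the per-iteration bookkeeping moves out of `ParallelExecutor::Run` (see the `@@ -316,15 +342,8 @@` hunk below): the runtime counts are rebuilt from the immutable global counts, then the fetched variables are erased so eager deletion never frees a tensor the caller is about to read. A standalone sketch of that reset-then-erase pattern, using plain `size_t` counters in place of the atomic map used here:

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

using RefCntMap = std::unordered_map<std::string, std::size_t>;

// Rebuild per-place runtime counts from the unchanging global counts, then
// drop fetch targets so they are excluded from eager deletion this iteration.
void ResetRuntimeRefCnts(const std::vector<RefCntMap> &global_ref_cnts,
                         std::vector<RefCntMap> *runtime_ref_cnts,
                         const std::vector<std::string> &fetch_tensors,
                         const std::string &fetched_var_name) {
  for (std::size_t i = 0; i < runtime_ref_cnts->size(); ++i) {
    for (const auto &pair : global_ref_cnts[i]) {
      (*runtime_ref_cnts)[i][pair.first] = pair.second;
    }
    for (const auto &fetch_name : fetch_tensors) {
      (*runtime_ref_cnts)[i].erase(fetch_name);
    }
    (*runtime_ref_cnts)[i].erase(fetched_var_name);
  }
}

int main() {
  std::vector<RefCntMap> global{{{"a", 2}, {"fetch_me", 1}}};
  std::vector<RefCntMap> runtime(1);
  ResetRuntimeRefCnts(global, &runtime, {"fetch_me"}, "fetch_list");
  // runtime[0] now holds {"a": 2}; "fetch_me" is excluded.
  return 0;
}
```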
@@ -71,14 +82,75 @@ class ParallelExecutorPrivate {
   bool use_cuda_;
   bool use_all_reduce_;

-  // rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then
-  // keeps unchanged
-  // Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_
-  std::vector<details::ReferenceCountMap> rt_ref_cnts_;
-  std::vector<details::AtomicReferenceCountMap> rt_cur_ref_cnts_;
-  details::GarbageCollectorList gcs_;
+  // global_ref_cnts_ is initialized once when the ParallelExecutor is
+  // constructed, and then stays unchanged.
+  // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_.
+  std::vector<details::ReferenceCountMap> global_ref_cnts_;
+  std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
+  details::GarbageCollectorMap gcs_;
 };

+std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
+    std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
+  for (size_t i = 0; i < places_.size(); ++i) {
+    auto &place = places_[i];
+    if (gcs_.count(place) > 0) {
+      continue;
+    }
+    GarbageCollector<Tensor> *gc = nullptr;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(place)) {
+      if (IsFastEagerDeletionModeEnabled()) {
+        gc = new UnsafeFastGPUGarbageCollector<Tensor>(
+            boost::get<platform::CUDAPlace>(place), max_memory_size);
+      } else {
+        gc = new StreamGarbageCollector<Tensor>(
+            boost::get<platform::CUDAPlace>(place), max_memory_size);
+      }
+      VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
+    } else if (platform::is_cpu_place(place)) {
+#endif
+      gc = new CPUGarbageCollector<Tensor>(
+          boost::get<platform::CPUPlace>(place), max_memory_size);
+      VLOG(10) << "Created GarbageCollector at " << place;
+#ifdef PADDLE_WITH_CUDA
+    }
+#endif
+
+    if (gc) {
+      gcs_[place] = std::unique_ptr<GarbageCollector<Tensor>>(gc);
+    }
+  }
+
+  if (!gcs_.empty()) {
+    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
+
+    auto ref_cnt_pass =
+        ir::PassRegistry::Instance().Get("reference_count_pass");
+    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
+                              &global_ref_cnts_);
+    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+                              &last_live_ops_of_vars);
+    graph = ref_cnt_pass->Apply(std::move(graph));
+    VLOG(10) << "ReferenceCountPass Applied";
+
+    auto eager_deletion_pass =
+        ir::PassRegistry::Instance().Get("eager_deletion_pass");
+    eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
+                                     &runtime_ref_cnts_);
+    eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
+    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+                                     &last_live_ops_of_vars);
+    eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
+    graph = eager_deletion_pass->Apply(std::move(graph));
+    VLOG(10) << "EagerDeletionPass Applied";
+
+    graph->SetNotOwned(details::kGarbageCollector, &gcs_);
+  }
+
+  return graph;
+}

 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
   return member_->local_scopes_;
 }
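`PrepareGCAndRefCnts` builds at most one collector per place (the `gcs_.count(place) > 0` guard skips places already served) and only applies the two passes when at least one collector was created. A condensed sketch of that create-if-absent loop, again with placeholder `Place`/`GC` types rather than the real `platform::Place` and `GarbageCollector<Tensor>`:

```cpp
#include <map>
#include <memory>
#include <vector>

// Placeholder stand-ins for platform::Place and GarbageCollector<Tensor>.
struct Place {
  int device_id;
  bool operator<(const Place &other) const { return device_id < other.device_id; }
};
struct GC {};

using GCMap = std::map<Place, std::unique_ptr<GC>>;

// Create a collector for every place that lacks one; repeated calls (or
// duplicate places) leave existing collectors untouched.
void CreateCollectors(const std::vector<Place> &places, GCMap *gcs) {
  for (const auto &place : places) {
    if (gcs->count(place) > 0) {
      continue;  // this place already has a collector
    }
    (*gcs)[place] = std::unique_ptr<GC>(new GC);
  }
}

int main() {
  GCMap gcs;
  CreateCollectors({Place{0}, Place{1}, Place{0}}, &gcs);
  // gcs.size() == 2: the duplicate Place{0} was skipped.
  return 0;
}
```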
@@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor(
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
-    size_t place_num = member_->places_.size();
-    for (size_t i = 0; i < place_num; ++i) {
-      auto &place = member_->places_[i];
-#ifdef PADDLE_WITH_CUDA
-      if (platform::is_gpu_place(place)) {
-        if (IsFastEagerDeletionModeEnabled()) {
-          member_->gcs_.emplace_back(new UnsafeFastGPUGarbageCollector<Tensor>(
-              boost::get<platform::CUDAPlace>(place), max_memory_size));
-        } else {
-          member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>(
-              boost::get<platform::CUDAPlace>(place), max_memory_size));
-        }
-        VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-      } else if (platform::is_cpu_place(place)) {
-#endif
-        member_->gcs_.emplace_back(new CPUGarbageCollector<Tensor>(
-            boost::get<platform::CPUPlace>(place), max_memory_size));
-        VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-#ifdef PADDLE_WITH_CUDA
-      }
-#endif
-    }
-  }
-
-  if (!member_->gcs_.empty()) {
-    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
-    auto ref_cnt_pass =
-        ir::PassRegistry::Instance().Get("reference_count_pass");
-    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
-                              &(member_->rt_ref_cnts_));
-    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                              &last_live_ops_of_vars);
-    graph = ref_cnt_pass->Apply(std::move(graph));
-    VLOG(10) << "ReferenceCountPass Applied";
-    auto eager_deletion_pass =
-        ir::PassRegistry::Instance().Get("eager_deletion_pass");
-    eager_deletion_pass->SetNotOwned(details::kCurReferenceCount,
-                                     &(member_->rt_cur_ref_cnts_));
-    eager_deletion_pass->SetNotOwned(details::kGarbageCollector,
-                                     &(member_->gcs_));
-    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                                     &last_live_ops_of_vars);
-    graph = eager_deletion_pass->Apply(std::move(graph));
-    VLOG(10) << "EagerDeletionPass Applied";
-    graph->SetNotOwned(details::kGarbageCollector, &(member_->gcs_));
+    graph = member_->PrepareGCAndRefCnts(std::move(graph),
+                                         static_cast<size_t>(max_memory_size));
   }
   // Step 3. Create vars in each scope. Passes may also create new vars.
@@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices(
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
   platform::RecordBlock b(0);
-  if (!member_->gcs_.empty()) {
-    member_->ResetRuntimeReferenceCount();
-    size_t n = member_->rt_ref_cnts_.size();
-    for (size_t i = 0; i < n; ++i) {
-      for (auto &fetch_name : fetch_tensors) {
-        member_->rt_cur_ref_cnts_[i].erase(fetch_name);
-      }
-      member_->rt_cur_ref_cnts_[i].erase(fetched_var_name);
-    }
+  if (member_->HasGarbageCollectors()) {
+    member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
   }
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
......
@@ -74,9 +74,7 @@ class WhileOp : public framework::OperatorBase {
     bool is_test = Attr<bool>("is_test");
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
-    if (framework::GetEagerDeletionThreshold() >= 0) {
-      VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
-    }
+    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
     auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
     while (cond.data<bool>()[0]) {
@@ -144,9 +142,7 @@ class WhileGradOp : public framework::OperatorBase {
     auto *program = block->Program();
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
-    if (framework::GetEagerDeletionThreshold() >= 0) {
-      VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
-    }
+    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
     auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
     auto *step_scopes =
@@ -369,7 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     // while operator could be renamed.
     while_grad->SetAttr("original_output_grad", output_grads_list);
-    /* The followi_ng codes are used in eager deletion mode */
+    /* The following codes are used in eager deletion mode */
     std::unordered_set<std::string> bwd_skip_vars;
     if (framework::GetEagerDeletionThreshold() >= 0) {
       std::unordered_set<std::string> fwd_skip_vars;
......