提交 c47c451a 编写于 作者: S sneaxiy

fix bug

上级 096673f6
...@@ -35,7 +35,7 @@ cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_e ...@@ -35,7 +35,7 @@ cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_e
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base) cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass) cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
......
...@@ -31,6 +31,8 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ...@@ -31,6 +31,8 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); WaitInputVarGenerated(place_);
VLOG(10) << "Run Op" << Name();
auto run_func = [this]() { auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}; };
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -23,28 +24,32 @@ namespace details { ...@@ -23,28 +24,32 @@ namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle( EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place, ir::Node *node, const Scope *scope, const platform::Place &place,
const std::vector<std::string> &var_names, GarbageCollector<Tensor> *gc, const std::unordered_set<std::string> &var_names,
AtomicReferenceCountMap *ref_cnts) GarbageCollector<Tensor> *gc, AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) { : OpHandleBase(node),
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>( dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
if (dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_)) { if (dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_)) {
platform::SetDeviceId(boost::get<platform::CUDAPlace>(place).device); platform::CUDADeviceGuard guard(
boost::get<platform::CUDAPlace>(place).device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
PADDLE_ENFORCE_NOT_NULL(event_);
} }
} }
#endif #endif
for (auto &name : var_names) AddVar(name);
} }
EagerDeletionOpHandle::~EagerDeletionOpHandle() { EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (event_) { if (event_) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace()); auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
platform::SetDeviceId(gpu_place.device); platform::CUDADeviceGuard guard(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_)); PADDLE_ENFORCE(cudaEventDestroy(event_));
} }
#endif #endif
...@@ -52,10 +57,6 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() { ...@@ -52,10 +57,6 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::AddVar(const std::string &name) {
var_names_.insert(name);
}
void EagerDeletionOpHandle::RunImpl() { void EagerDeletionOpHandle::RunImpl() {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<Tensor *> tensors; std::vector<Tensor *> tensors;
......
...@@ -25,13 +25,11 @@ class Scope; ...@@ -25,13 +25,11 @@ class Scope;
namespace details { namespace details {
class EagerDeletionPass;
class EagerDeletionOpHandle : public OpHandleBase { class EagerDeletionOpHandle : public OpHandleBase {
public: public:
EagerDeletionOpHandle(ir::Node *node, const Scope *scope, EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
const platform::Place &place, const platform::Place &place,
const std::vector<std::string> &var_names, const std::unordered_set<std::string> &var_names,
GarbageCollector<Tensor> *gc, GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts); AtomicReferenceCountMap *ref_cnts);
...@@ -45,8 +43,6 @@ class EagerDeletionOpHandle : public OpHandleBase { ...@@ -45,8 +43,6 @@ class EagerDeletionOpHandle : public OpHandleBase {
private: private:
void ClearTensors(const std::vector<Tensor *> &tensors); void ClearTensors(const std::vector<Tensor *> &tensors);
void AddVar(const std::string &name);
const Scope *scope_; const Scope *scope_;
std::unordered_set<std::string> var_names_; std::unordered_set<std::string> var_names_;
GarbageCollector<Tensor> *gc_; // not own GarbageCollector<Tensor> *gc_; // not own
...@@ -55,8 +51,6 @@ class EagerDeletionOpHandle : public OpHandleBase { ...@@ -55,8 +51,6 @@ class EagerDeletionOpHandle : public OpHandleBase {
platform::CUDADeviceContext *dev_ctx_{nullptr}; platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr}; cudaEvent_t event_{nullptr};
#endif #endif
friend class EagerDeletionPass;
}; };
} // namespace details } // namespace details
......
...@@ -26,62 +26,61 @@ namespace paddle { ...@@ -26,62 +26,61 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
// Add leaf node to eager_deletion_node
if (out->Outputs().empty()) {
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
out->AddOutput(dummy_leaf);
}
}
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl( std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &vars = graph->Get<GraphVars>(kGraphVars); const auto &vars = graph->Get<GraphVars>(kGraphVars);
auto &ref_cnts = auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount); Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount);
auto &last_live_ops = Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars); const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
auto &gcs = Get<GarbageCollectorList>(kGarbageCollector); auto &gcs = Get<GarbageCollectorList>(kGarbageCollector);
ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size()); ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());
std::unordered_map<ComputationOpHandle *, EagerDeletionOpHandle *> op_map; std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
op_vars_map;
for (auto &var_ops_map : last_live_ops) { for (auto &var_ops_map : last_live_ops) {
for (auto &var_ops_pair : var_ops_map) { for (auto &var_ops_pair : var_ops_map) {
const std::string &var_name = var_ops_pair.first; const std::string &var_name = var_ops_pair.first;
for (ComputationOpHandle *op : var_ops_pair.second) { for (auto *op : var_ops_pair.second) {
auto it = op_map.find(op); op_vars_map[op].insert(var_name);
if (it != op_map.end()) {
it->second->AddVar(var_name);
} else {
auto *eager_deletion_node = graph->CreateEmptyNode(
"eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), {var_name},
gcs[op->GetScopeIdx()].get(), &(ref_cnts[op->GetScopeIdx()]));
AddDependencyBetween(op, eager_deletion_op, graph.get());
op_map[op] = eager_deletion_op;
} }
} }
} }
for (auto &pair : op_vars_map) {
auto *op = pair.first;
auto &var_names = pair.second;
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
auto *eager_deletion_op = new EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(),
std::move(var_names), gcs[op->GetScopeIdx()].get(),
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != op->Outputs().end()) {
eager_deletion_op->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
op->AddOutput(dep_var);
eager_deletion_op->AddInput(dep_var);
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
} }
VLOG(10) << "Create " << op_map.size() << " EagerDeletionOpHandle(s)";
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
return graph; return graph;
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include <memory> #include <queue>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -34,6 +34,11 @@ class OpGraphView { ...@@ -34,6 +34,11 @@ class OpGraphView {
bool HasOp(OpHandleBase *op) const; bool HasOp(OpHandleBase *op) const;
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
template <typename Callback>
bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const;
private: private:
void Build(const std::vector<OpHandleBase *> &ops); void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const; void EnforceHasOp(OpHandleBase *op) const;
...@@ -44,6 +49,28 @@ class OpGraphView { ...@@ -44,6 +49,28 @@ class OpGraphView {
pending_ops_; pending_ops_;
}; };
template <typename Callback>
bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
Callback &&callback) const {
EnforceHasOp(op);
std::unordered_set<OpHandleBase *> visited;
std::queue<OpHandleBase *> q;
q.push(op);
do {
op = q.front();
q.pop();
for (auto &pending_op : pending_ops_.at(op)) {
if (visited.count(pending_op) == 0) {
visited.insert(pending_op);
if (!callback(pending_op)) {
return false;
}
}
}
} while (!q.empty());
return true;
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,11 +14,13 @@ ...@@ -14,11 +14,13 @@
#include <queue> #include <queue>
#include <string> #include <string>
#include <type_traits>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
...@@ -27,6 +29,89 @@ namespace paddle { ...@@ -27,6 +29,89 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct OpConnectionDetector {
public:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
explicit OpConnectionDetector(const std::vector<OpHandleBase *> &all_ops)
: graph_(all_ops) {}
template <typename OpSet>
std::unordered_set<typename OpSet::key_type> MaxNoDepOps(
const OpSet &op_set) {
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
typename std::remove_pointer<KeyType>::type>::value,
"Key type of OpSet must be or derived of OpHandleBase");
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::unordered_set<KeyType> ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
for (size_t i = 0; i < rels.size(); ++i) {
if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) {
ret.insert(static_cast<KeyType>(ops[i]));
}
}
return ret;
}
private:
std::vector<std::vector<RelationShip>> GetRelations(
const std::vector<OpHandleBase *> ops) {
std::unordered_map<OpHandleBase *, size_t> op_to_idx;
for (size_t i = 0; i < ops.size(); ++i) {
PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
op_to_idx[ops[i]] = i;
}
PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops");
std::vector<std::vector<RelationShip>> ret(ops.size());
for (auto &e : ret) {
e.assign(ops.size(), kSame);
}
size_t found_num = ops.size();
size_t total_num = ops.size() * ops.size();
auto visitor = [&](OpHandleBase *op, size_t i) {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
if (ret[i][j] != kSame) {
ret[i][j] = kBefore;
ret[j][i] = kAfter;
found_num += 2;
if (found_num == total_num) {
return false;
}
}
}
return true;
};
for (size_t i = 0; i < ops.size(); ++i) {
auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
break;
}
}
for (size_t i = 0; i < ops.size(); ++i) {
for (size_t j = i + 1; j < ops.size(); ++j) {
if (ret[i][j] != kSame) continue;
ret[i][j] = kNoDeps;
ret[j][i] = kNoDeps;
}
}
return ret;
}
const OpGraphView graph_;
};
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
OpHandleBase *op, size_t scope_idx) { OpHandleBase *op, size_t scope_idx) {
std::queue<OpHandleBase *> q; std::queue<OpHandleBase *> q;
...@@ -59,9 +144,15 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -59,9 +144,15 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size()); last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
ref_cnts = std::vector<ReferenceCountMap>(vars.size()); ref_cnts = std::vector<ReferenceCountMap>(vars.size());
OpConnectionDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {
for (auto &name_var_pair : vars[i]) { for (auto &name_var_pair : vars[i]) {
if (name_var_pair.second.empty()) continue; if (name_var_pair.second.empty()) {
continue;
}
const std::string &var_name = name_var_pair.first;
auto *last_ver_var = name_var_pair.second.back(); auto *last_ver_var = name_var_pair.second.back();
VarDesc *var_desc = nullptr; VarDesc *var_desc = nullptr;
...@@ -83,30 +174,46 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -83,30 +174,46 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
} }
std::unordered_set<ComputationOpHandle *> last_live_op; std::unordered_set<ComputationOpHandle *> last_live_op;
auto add_last_live_op = [&](OpHandleBase *op) { auto add_last_live_op = [&](OpHandleBase *op) -> bool {
auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i); auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i);
if (compute_op) { if (compute_op) {
last_live_op.insert(compute_op); last_live_op.insert(compute_op);
return true;
} else {
return false;
} }
}; };
const std::string &var_name = name_var_pair.first;
bool can_delete = false;
auto &pending_ops = last_ver_var->PendingOps(); auto &pending_ops = last_ver_var->PendingOps();
if (pending_ops.empty()) { if (pending_ops.empty()) {
auto *generated_op = last_ver_var->GeneratedOp(); auto *generated_op = last_ver_var->GeneratedOp();
if (generated_op) { if (generated_op && add_last_live_op(generated_op)) {
ref_cnts[i].emplace(var_name, 1); can_delete = true;
add_last_live_op(generated_op);
} }
} else { } else {
ref_cnts[i].emplace(var_name, pending_ops.size()); can_delete = true;
for (auto *pending_op : pending_ops) { for (auto *pending_op : pending_ops) {
add_last_live_op(pending_op); if (!add_last_live_op(pending_op)) {
can_delete = false;
break;
}
} }
} }
if (can_delete) {
size_t original_size = last_live_op.size();
last_live_op = detector.MaxNoDepOps(last_live_op);
if (last_live_op.size() != original_size) {
VLOG(10) << "Shrink last living op number of " << var_name << " from "
<< original_size << " to " << last_live_op.size();
}
ref_cnts[i].emplace(var_name, last_live_op.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op)); last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
} }
} }
}
return graph; return graph;
} }
......
...@@ -36,6 +36,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ...@@ -36,6 +36,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
} }
} }
void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
if (gc_) {
for (auto &gc : *gc_) {
gc->Wait();
gc->Reset();
}
}
}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run( FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) { if (drop_scope_counter_ == 0) {
...@@ -74,19 +83,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -74,19 +83,19 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0; drop_scope_counter_ = 0;
// Wait All computational streams // Wait All computational streams
for (size_t i = 0; i < places_.size(); ++i) { for (auto &p : places_) {
platform::DeviceContextPool::Instance().Get(places_[i])->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
if (gc_) {
(*gc_)[i]->Wait();
(*gc_)[i]->Reset();
}
} }
WaitAllGarbageCollectors();
for (auto &scope : local_scopes_) { for (auto &scope : local_scopes_) {
auto &local_scope = auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>(); *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
} }
} else {
WaitAllGarbageCollectors();
} }
if (eptr) { if (eptr) {
std::rethrow_exception(eptr); std::rethrow_exception(eptr);
} else { } else {
......
...@@ -50,6 +50,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -50,6 +50,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override; FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private: private:
void WaitAllGarbageCollectors();
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -37,11 +37,49 @@ namespace { ...@@ -37,11 +37,49 @@ namespace {
int kProgramId = -1; int kProgramId = -1;
} // namespace } // namespace
static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
const BlockDesc& block, const std::vector<std::string>& skip_var_list) {
std::unordered_map<std::string, size_t> ref_cnts;
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
skip_var_list.end());
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
if (skip_vars.count(name)) continue;
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS &&
type != proto::VarType::LOD_TENSOR_ARRAY) {
continue;
}
auto it = ref_cnts.find(name);
if (it != ref_cnts.end()) {
++it->second;
} else {
ref_cnts[name] = 1;
}
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
ExecutorPrepareContext::ExecutorPrepareContext( ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id) const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars)
: prog_(prog), block_id_(block_id) { : prog_(prog), block_id_(block_id) {
if (GetEagerDeletionThreshold() >= 0) { if (GetEagerDeletionThreshold() >= 0) {
ref_cnts_ = GetNonPersistableReferenceCount<int>(prog_, block_id_); ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
skip_ref_cnt_vars);
} }
} }
...@@ -49,10 +87,9 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { ...@@ -49,10 +87,9 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(5) << "destroy ExecutorPrepareContext";
} }
template <typename RefCntMap> static void DeleteUnusedTensors(
static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, const Scope& scope, const OperatorBase* op, GarbageCollector<Tensor>* gc,
GarbageCollector<Tensor>* gc, std::unordered_map<std::string, size_t>* ref_cnts) {
RefCntMap* ref_cnts) {
std::unordered_set<Tensor*> erase_tensors; std::unordered_set<Tensor*> erase_tensors;
auto handler = [&](const VariableNameMap& name_map) { auto handler = [&](const VariableNameMap& name_map) {
...@@ -60,7 +97,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, ...@@ -60,7 +97,7 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
for (auto& name : name_pair.second) { for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name); auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue; if (it == ref_cnts->end()) continue;
if ((it->second)-- == 1) { if (--(it->second) == 0) {
auto* var = scope.FindVar(name); auto* var = scope.FindVar(name);
if (var != nullptr) { if (var != nullptr) {
VLOG(10) << "Erase tensor \'" << name << "\'"; VLOG(10) << "Erase tensor \'" << name << "\'";
...@@ -69,6 +106,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, ...@@ -69,6 +106,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
erase_tensors.insert( erase_tensors.insert(
var->GetMutable<SelectedRows>()->mutable_value()); var->GetMutable<SelectedRows>()->mutable_value());
} else if (var->IsType<LoDTensorArray>()) {
auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
for (auto& t : *lod_tensor_arr) {
erase_tensors.insert(&t);
}
} }
} }
} }
...@@ -351,9 +393,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -351,9 +393,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) { const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars) {
std::unique_ptr<ExecutorPrepareContext> ctx( std::unique_ptr<ExecutorPrepareContext> ctx(
new ExecutorPrepareContext(program, block_id)); new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars));
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
...@@ -364,16 +407,28 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( ...@@ -364,16 +407,28 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
} }
std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids) { const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars) {
PADDLE_ENFORCE(
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
"skip_ref_cnt_vars should be either empty or equals to block number %d",
block_ids.size());
std::vector<std::shared_ptr<ExecutorPrepareContext>> result; std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
size_t idx = 0;
for (auto& bid : block_ids) { for (auto& bid : block_ids) {
auto* ctx = new ExecutorPrepareContext(program, bid); ExecutorPrepareContext* ctx;
if (skip_ref_cnt_vars.empty()) {
ctx = new ExecutorPrepareContext(program, bid);
} else {
ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]);
}
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
auto& block = program.Block(bid); auto& block = program.Block(bid);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx)); result.push_back(std::shared_ptr<ExecutorPrepareContext>(ctx));
++idx;
} }
return result; return result;
} }
...@@ -392,18 +447,18 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -392,18 +447,18 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc; std::unique_ptr<GarbageCollector<Tensor>> gc;
// WhileOp would set keep_kids to true, if (max_memory_size >= 0) {
// because WhileGradOp needs the scopes created in WhileOp.
// Perhaps, we should not perform eager deletion in WhileOp
// The scopes and variables created by WhileOp would be deleted
// in WhileGradOp.
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount(); ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
gc.reset(new DefaultStreamGarbageCollector<Tensor>( if (IsFastEagerDeletionModeEnabled()) {
gc.reset(new UnsafeFastGPUGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place_), max_memory_size)); boost::get<platform::CUDAPlace>(place_), max_memory_size));
} else { } else {
gc.reset(new DefaultStreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place_), max_memory_size));
}
} else if (platform::is_cpu_place(place_)) {
#endif #endif
gc.reset(new CPUGarbageCollector<Tensor>( gc.reset(new CPUGarbageCollector<Tensor>(
boost::get<platform::CPUPlace>(place_), max_memory_size)); boost::get<platform::CPUPlace>(place_), max_memory_size));
...@@ -415,17 +470,14 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -415,17 +470,14 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc != nullptr) { if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), gc.get(), DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
&(ctx->cur_ref_cnts_)); &(ctx->cur_ref_cnts_));
} }
} }
if (gc != nullptr) {
gc->Wait();
} else {
platform::DeviceContextPool::Instance().Get(place_)->Wait(); platform::DeviceContextPool::Instance().Get(place_)->Wait();
} if (gc) gc->Wait();
if (local_scope != scope) { if (local_scope != scope) {
scope->DeleteScope(local_scope); scope->DeleteScope(local_scope);
......
...@@ -28,42 +28,11 @@ namespace paddle { ...@@ -28,42 +28,11 @@ namespace paddle {
namespace framework { namespace framework {
extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); extern void InitializeVariable(Variable* var, proto::VarType::Type var_type);
template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id);
std::unordered_map<std::string, T> ref_cnts;
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS) {
continue;
}
auto it = ref_cnts.find(name);
if (it != ref_cnts.end()) {
++it->second;
} else {
ref_cnts[name] = 1;
}
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
~ExecutorPrepareContext(); ~ExecutorPrepareContext();
void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
...@@ -72,8 +41,8 @@ struct ExecutorPrepareContext { ...@@ -72,8 +41,8 @@ struct ExecutorPrepareContext {
size_t block_id_; size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
std::unordered_map<std::string, int> ref_cnts_; std::unordered_map<std::string, size_t> ref_cnts_;
std::unordered_map<std::string, int> cur_ref_cnts_; std::unordered_map<std::string, size_t> cur_ref_cnts_;
}; };
class Executor { class Executor {
...@@ -109,10 +78,14 @@ class Executor { ...@@ -109,10 +78,14 @@ class Executor {
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch");
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id); const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>());
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare( static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids); const ProgramDesc& program, const std::vector<int>& block_ids,
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars =
std::vector<std::vector<std::string>>());
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
......
...@@ -19,6 +19,9 @@ ...@@ -19,6 +19,9 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -36,6 +39,11 @@ class GarbageCollector { ...@@ -36,6 +39,11 @@ class GarbageCollector {
virtual ~GarbageCollector() {} virtual ~GarbageCollector() {}
size_t NumOfGarbages() const {
std::lock_guard<std::mutex> guard(mutex_);
return garbages_->size();
}
void Reset() { void Reset() {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
garbages_.reset(new std::deque<T *>()); garbages_.reset(new std::deque<T *>());
...@@ -49,7 +57,7 @@ class GarbageCollector { ...@@ -49,7 +57,7 @@ class GarbageCollector {
template <typename Container, typename Callback> template <typename Container, typename Callback>
void Add(const Container &objs, Callback &&callback) { void Add(const Container &objs, Callback &&callback) {
std::shared_ptr<std::deque<T *>> clear_deque; std::deque<T *> *clear_deque = nullptr;
{ {
std::lock_guard<std::mutex> guard(mutex_); std::lock_guard<std::mutex> guard(mutex_);
for (auto *obj : objs) { for (auto *obj : objs) {
...@@ -58,7 +66,7 @@ class GarbageCollector { ...@@ -58,7 +66,7 @@ class GarbageCollector {
} }
if (cur_memory_size_ >= max_memory_size_) { if (cur_memory_size_ >= max_memory_size_) {
cur_memory_size_ = 0; cur_memory_size_ = 0;
clear_deque = garbages_; clear_deque = garbages_.release();
garbages_.reset(new std::deque<T *>()); garbages_.reset(new std::deque<T *>());
} }
} }
...@@ -67,6 +75,7 @@ class GarbageCollector { ...@@ -67,6 +75,7 @@ class GarbageCollector {
callback(); callback();
ClearCallback([clear_deque]() { ClearCallback([clear_deque]() {
for (auto *obj : *clear_deque) obj->clear(); for (auto *obj : *clear_deque) obj->clear();
delete clear_deque;
}); });
} }
} }
...@@ -77,7 +86,7 @@ class GarbageCollector { ...@@ -77,7 +86,7 @@ class GarbageCollector {
virtual void ClearCallback(const std::function<void()> &callback) = 0; virtual void ClearCallback(const std::function<void()> &callback) = 0;
platform::DeviceContext *dev_ctx_; platform::DeviceContext *dev_ctx_;
std::shared_ptr<std::deque<T *>> garbages_; std::unique_ptr<std::deque<T *>> garbages_;
mutable std::mutex mutex_; mutable std::mutex mutex_;
const size_t max_memory_size_; const size_t max_memory_size_;
size_t cur_memory_size_ = 0; size_t cur_memory_size_ = 0;
...@@ -96,6 +105,19 @@ class CPUGarbageCollector : public GarbageCollector<T> { ...@@ -96,6 +105,19 @@ class CPUGarbageCollector : public GarbageCollector<T> {
}; };
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
template <typename T>
class UnsafeFastGPUGarbageCollector : public GarbageCollector<T> {
public:
UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) {}
protected:
void ClearCallback(const std::function<void()> &callback) override {
callback();
}
};
template <typename T> template <typename T>
class DefaultStreamGarbageCollector : public GarbageCollector<T> { class DefaultStreamGarbageCollector : public GarbageCollector<T> {
public: public:
...@@ -109,7 +131,7 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> { ...@@ -109,7 +131,7 @@ class DefaultStreamGarbageCollector : public GarbageCollector<T> {
} }
void Wait() const override { void Wait() const override {
static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_) static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback(); ->WaitStreamCallback();
} }
...@@ -126,31 +148,23 @@ class StreamGarbageCollector : public GarbageCollector<T> { ...@@ -126,31 +148,23 @@ class StreamGarbageCollector : public GarbageCollector<T> {
StreamGarbageCollector(const platform::CUDAPlace &place, StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size) size_t max_memory_size)
: GarbageCollector<T>(place, max_memory_size) { : GarbageCollector<T>(place, max_memory_size) {
platform::SetDeviceId(place.device); platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_)); callback_manager_.reset(new platform::StreamCallbackManager(stream_));
} }
~StreamGarbageCollector() { ~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace()); auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
platform::SetDeviceId(place.device); platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
} }
void Wait() const override { void Wait() const override { callback_manager_->Wait(); }
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
std::lock_guard<std::mutex> guard(this->mutex_);
callback_manager_->Wait();
}
cudaStream_t stream() const { return stream_; } cudaStream_t stream() const { return stream_; }
protected: protected:
// ClearCallback and Wait()/Reset() cannot be call in multiple threads
// But it is not important, because they would not be called in multiple
// threads
// either in Executor or ParallelExecutor
void ClearCallback(const std::function<void()> &callback) override { void ClearCallback(const std::function<void()> &callback) override {
callback_manager_->AddCallback(callback); callback_manager_->AddCallback(callback);
} }
......
...@@ -873,6 +873,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -873,6 +873,8 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized: %s",
ipt_name, DebugString());
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
......
...@@ -158,8 +158,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -158,8 +158,13 @@ ParallelExecutor::ParallelExecutor(
auto &place = member_->places_[i]; auto &place = member_->places_[i];
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
if (IsFastEagerDeletionModeEnabled()) {
member_->gcs_.emplace_back(new UnsafeFastGPUGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size));
} else {
member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>( member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>(
boost::get<platform::CUDAPlace>(place), max_memory_size)); boost::get<platform::CUDAPlace>(place), max_memory_size));
}
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
} else if (platform::is_cpu_place(place)) { } else if (platform::is_cpu_place(place)) {
#endif #endif
...@@ -181,8 +186,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -181,8 +186,8 @@ ParallelExecutor::ParallelExecutor(
&(member_->rt_ref_cnts_)); &(member_->rt_ref_cnts_));
ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars, ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
&last_live_ops_of_vars); &last_live_ops_of_vars);
VLOG(10) << "ReferenceCountPass Applied";
graph = ref_cnt_pass->Apply(std::move(graph)); graph = ref_cnt_pass->Apply(std::move(graph));
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass = auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass"); ir::PassRegistry::Instance().Get("eager_deletion_pass");
...@@ -194,6 +199,8 @@ ParallelExecutor::ParallelExecutor( ...@@ -194,6 +199,8 @@ ParallelExecutor::ParallelExecutor(
&last_live_ops_of_vars); &last_live_ops_of_vars);
graph = eager_deletion_pass->Apply(std::move(graph)); graph = eager_deletion_pass->Apply(std::move(graph));
VLOG(10) << "EagerDeletionPass Applied"; VLOG(10) << "EagerDeletionPass Applied";
graph->SetNotOwned(details::kGarbageCollector, &(member_->gcs_));
} }
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
......
...@@ -38,6 +38,10 @@ DEFINE_double( ...@@ -38,6 +38,10 @@ DEFINE_double(
"Memory size threshold (GB) when the garbage collector clear tensors." "Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0"); "Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
// When in inference scenario, the scopes will not be written by two threads in // When in inference scenario, the scopes will not be written by two threads in
// a mean time, but a scope may be read by multiple threads concurrently, and // a mean time, but a scope may be read by multiple threads concurrently, and
// the mutex will cause serious performance issue. // the mutex will cause serious performance issue.
...@@ -58,6 +62,8 @@ int64_t GetEagerDeletionThreshold() { ...@@ -58,6 +62,8 @@ int64_t GetEagerDeletionThreshold() {
(static_cast<int64_t>(1) << 30)); (static_cast<int64_t>(1) << 30));
} }
bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
......
...@@ -27,6 +27,7 @@ namespace paddle { ...@@ -27,6 +27,7 @@ namespace paddle {
namespace framework { namespace framework {
int64_t GetEagerDeletionThreshold(); int64_t GetEagerDeletionThreshold();
bool IsFastEagerDeletionModeEnabled();
class Scope; class Scope;
......
...@@ -153,7 +153,7 @@ class Tensor { ...@@ -153,7 +153,7 @@ class Tensor {
void set_layout(const DataLayout layout) { layout_ = layout; } void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() { holder_ = nullptr; } void clear() { holder_.reset(); }
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; } const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
size_t offset() const { return offset_; } size_t offset() const { return offset_; }
......
...@@ -59,7 +59,21 @@ class WhileOp : public framework::OperatorBase { ...@@ -59,7 +59,21 @@ class WhileOp : public framework::OperatorBase {
"Condition of while op must in CPU memory."); "Condition of while op must in CPU memory.");
bool is_test = Attr<bool>("is_test"); bool is_test = Attr<bool>("is_test");
auto ctx = executor.Prepare(*program, block->ID()); auto &skip_eager_deletion_vars =
Attr<std::vector<std::string>>("skip_eager_deletion_vars");
if (framework::GetEagerDeletionThreshold() >= 0 && VLOG_IS_ON(10)) {
std::string debug_string =
"Skip " + std::to_string(skip_eager_deletion_vars.size()) +
" vars in eager deletion mode: ";
for (auto &var : skip_eager_deletion_vars) {
debug_string.append(var);
debug_string.push_back(' ');
}
VLOG(10) << debug_string;
}
auto ctx =
executor.Prepare(*program, block->ID(), skip_eager_deletion_vars);
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
...@@ -96,6 +110,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -96,6 +110,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<std::string>>("skip_eager_deletion_vars",
"Vars that would skip eager deletion."
"Users should not set this manually.")
.SetDefault(std::vector<std::string>());
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
} }
...@@ -341,6 +359,30 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -341,6 +359,30 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed. // while operator could be renamed.
while_grad->SetAttr("original_output_grad", output_grads_list); while_grad->SetAttr("original_output_grad", output_grads_list);
/* The following codes are used in eager deletion mode */
if (framework::GetEagerDeletionThreshold() >= 0) {
std::unordered_set<std::string> skip_vars;
for (auto *op_desc : grad_block->AllOps()) {
for (auto &in_arg_name : op_desc->InputArgumentNames()) {
// If input var of ops inside grad_block is not from grad_block,
// it cannot be deleted when forward while_op runs
if (in_arg_name != framework::kEmptyVarName &&
!grad_block->HasVar(in_arg_name)) {
skip_vars.insert(in_arg_name);
}
}
}
if (!skip_vars.empty()) {
// FIXME(zjl): ugly const_cast here, maybe we should find a better way
// to modify forward while_op
auto &fwd_while_op = const_cast<framework::OpDesc &>(ForwardOp());
fwd_while_op.SetAttr(
"skip_eager_deletion_vars",
std::vector<std::string>(skip_vars.begin(), skip_vars.end()));
}
}
return std::unique_ptr<framework::OpDesc>(while_grad); return std::unique_ptr<framework::OpDesc>(while_grad);
} }
}; };
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sys/time.h> #include <sys/time.h>
#include <algorithm>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <cstdlib> #include <cstdlib>
#include <fstream> #include <fstream>
...@@ -55,8 +56,7 @@ class CTRReader : public framework::FileReader { ...@@ -55,8 +56,7 @@ class CTRReader : public framework::FileReader {
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty");
thread_num_ = thread_num_ = std::min<size_t>(file_list_.size(), thread_num);
file_list_.size() > thread_num ? thread_num : file_list_.size();
queue_ = queue; queue_ = queue;
SplitFiles(); SplitFiles();
for (size_t i = 0; i < thread_num_; ++i) { for (size_t i = 0; i < thread_num_; ++i) {
...@@ -95,10 +95,10 @@ class CTRReader : public framework::FileReader { ...@@ -95,10 +95,10 @@ class CTRReader : public framework::FileReader {
queue_->ReOpen(); queue_->ReOpen();
VLOG(3) << "reopen success"; VLOG(3) << "reopen success";
VLOG(3) << "thread_num " << thread_num_; VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) { for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread( read_threads_.emplace_back(new std::thread(std::bind(
std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, &ReadThread, file_groups_[thread_id], slots_, batch_size_,
thread_id, &read_thread_status_, queue_))); static_cast<int>(thread_id), &read_thread_status_, queue_)));
} }
monitor_thread_.reset(new std::thread( monitor_thread_.reset(new std::thread(
std::bind(&MonitorThread, &read_thread_status_, queue_))); std::bind(&MonitorThread, &read_thread_status_, queue_)));
......
...@@ -223,14 +223,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -223,14 +223,10 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback> template <typename Callback>
void AddStreamCallback(Callback&& callback) const { void AddStreamCallback(Callback&& callback) const {
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->AddCallback(callback); callback_manager_->AddCallback(callback);
} }
void WaitStreamCallback() const { void WaitStreamCallback() const { callback_manager_->Wait(); }
std::lock_guard<std::mutex> guard(callback_mtx_);
callback_manager_->Wait();
}
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config, /*! \brief CublasCall may need to change cublas's config,
...@@ -261,9 +257,7 @@ class CUDADeviceContext : public DeviceContext { ...@@ -261,9 +257,7 @@ class CUDADeviceContext : public DeviceContext {
mutable std::mutex mtx_; mutable std::mutex mtx_;
// This lock is only used by callback // StreamCallbackManager is thread-safe
// If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes
mutable std::mutex callback_mtx_;
std::unique_ptr<StreamCallbackManager> callback_manager_; std::unique_ptr<StreamCallbackManager> callback_manager_;
mutable std::mutex cublas_mtx_; mutable std::mutex cublas_mtx_;
......
...@@ -18,52 +18,47 @@ ...@@ -18,52 +18,47 @@
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct StreamCallbackContext { #if CUDA_VERSION >= 10000
inline StreamCallbackContext(const StreamCallbackManager *manager, static void CUDART_CB StreamCallbackFunc(void *user_data);
std::function<void()> callback) #else
: manager_(manager), callback_(std::move(callback)) {} static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, void *user_data)
const StreamCallbackManager *manager_; // do not own #endif
std::function<void()> callback_; {
}; std::unique_ptr<std::function<void()>> func(
reinterpret_cast<std::function<void()> *>(user_data));
(*func)();
}
StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream) StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
: stream_(stream), thread_pool_(new ::ThreadPool(1)) {} : stream_(stream), thread_pool_(1) {}
void StreamCallbackManager::AddCallback(std::function<void()> callback) const { void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
auto *stream_callback_context = auto *callback_func = new std::function<void()>(std::move(callback));
new StreamCallbackContext(this, std::move(callback)); auto *func = new std::function<void()>([this, callback_func] {
std::lock_guard<std::mutex> lock(mtx_);
last_future_ = thread_pool_.enqueue([callback_func] {
std::unique_ptr<std::function<void()>> releaser(callback_func);
(*callback_func)();
});
});
#if CUDA_VERSION >= 10000 #if CUDA_VERSION >= 10000
PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
StreamCallbackManager::StreamCallbackFunc,
stream_callback_context));
#else #else
PADDLE_ENFORCE( PADDLE_ENFORCE(cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
cudaStreamAddCallback(stream_, StreamCallbackManager::StreamCallbackFunc,
stream_callback_context, 0));
#endif #endif
} }
void StreamCallbackManager::Wait() const { StreamCallbackManager::~StreamCallbackManager() { Wait(); }
thread_pool_.reset(new ::ThreadPool(1));
}
#if CUDA_VERSION >= 10000 void StreamCallbackManager::Wait() const {
void CUDART_CB StreamCallbackManager::StreamCallbackFunc(void *user_data) PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
#else {
void CUDART_CB StreamCallbackManager::StreamCallbackFunc(cudaStream_t stream, std::lock_guard<std::mutex> lock(mtx_);
cudaError_t status, if (last_future_.valid()) {
void *user_data) last_future_.wait();
#endif }
{ }
auto *callback_context_ptr =
reinterpret_cast<StreamCallbackContext *>(user_data);
callback_context_ptr->manager_->thread_pool_->enqueue(
[callback_context_ptr]() {
std::unique_ptr<StreamCallbackContext> callback_context(
callback_context_ptr);
callback_context->callback_();
});
} }
} // namespace platform } // namespace platform
......
...@@ -18,30 +18,32 @@ ...@@ -18,30 +18,32 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <functional> #include <functional>
#include <future> // NOLINT
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// NOTE(zjl): clean StreamCallback to make compilation faster // NOTE(zjl): clean StreamCallbackManager to make compilation faster
// Make StreamCallbackManager thread-safe
class StreamCallbackManager { class StreamCallbackManager {
public: public:
explicit StreamCallbackManager(const cudaStream_t stream); explicit StreamCallbackManager(const cudaStream_t stream);
~StreamCallbackManager();
void AddCallback(std::function<void()> callback) const; void AddCallback(std::function<void()> callback) const;
void Wait() const; void Wait() const;
private: private:
const cudaStream_t stream_; const cudaStream_t stream_;
mutable std::unique_ptr<::ThreadPool> thread_pool_; mutable ::ThreadPool thread_pool_;
mutable std::mutex mtx_;
#if CUDA_VERSION >= 10000 mutable std::future<void> last_future_;
static void CUDART_CB StreamCallbackFunc(void *user_data);
#else
static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
cudaError_t status, void *user_data);
#endif
}; };
} // namespace platform } // namespace platform
......
...@@ -162,7 +162,7 @@ void PyCPUTensorSetFromArray( ...@@ -162,7 +162,7 @@ void PyCPUTensorSetFromArray(
paddle::platform::CPUPlace place) { paddle::platform::CPUPlace place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
...@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray( ...@@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray(
paddle::platform::CPUPlace place) { paddle::platform::CPUPlace place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
...@@ -200,7 +200,7 @@ void PyCUDATensorSetFromArray( ...@@ -200,7 +200,7 @@ void PyCUDATensorSetFromArray(
paddle::platform::CUDAPlace place) { paddle::platform::CUDAPlace place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
...@@ -221,7 +221,7 @@ inline void PyCUDATensorSetFromArray( ...@@ -221,7 +221,7 @@ inline void PyCUDATensorSetFromArray(
paddle::platform::CUDAPlace place) { paddle::platform::CUDAPlace place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
...@@ -240,7 +240,7 @@ void PyCUDAPinnedTensorSetFromArray( ...@@ -240,7 +240,7 @@ void PyCUDAPinnedTensorSetFromArray(
const paddle::platform::CUDAPinnedPlace &place) { const paddle::platform::CUDAPinnedPlace &place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
...@@ -260,7 +260,7 @@ inline void PyCUDAPinnedTensorSetFromArray( ...@@ -260,7 +260,7 @@ inline void PyCUDAPinnedTensorSetFromArray(
const paddle::platform::CUDAPinnedPlace &place) { const paddle::platform::CUDAPinnedPlace &place) {
std::vector<int64_t> dims; std::vector<int64_t> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) {
dims.push_back(static_cast<int>(array.shape()[i])); dims.push_back(static_cast<int>(array.shape()[i]));
} }
......
...@@ -116,8 +116,9 @@ def __bootstrap__(): ...@@ -116,8 +116,9 @@ def __bootstrap__():
'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_mkldnn', 'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_mkldnn',
'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem',
'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size", 'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size",
'eager_delete_tensor_gb', 'allocator_strategy', 'eager_delete_tensor_gb', 'fast_eager_deletion_mode',
'reader_queue_speed_test_mode', 'print_sub_graph_dir' 'allocator_strategy', 'reader_queue_speed_test_mode',
'print_sub_graph_dir'
] ]
if 'Darwin' not in sysstr: if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory') read_env_flags.append('use_pinned_memory')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册