提交 0a36ef3c 编写于 作者: S sneaxiy

enhance eager deletion

上级 d775087d
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
namespace paddle { namespace paddle {
...@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
const std::vector<std::string> &var_names, const std::vector<std::string> &var_names,
GarbageCollector<Tensor> *gc, GarbageCollector<Tensor> *gc,
AtomicReferenceCountMap *ref_cnts) AtomicReferenceCountMap *ref_cnts)
: OpHandleBase(node), : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) {
scope_(scope),
var_names_(var_names),
gc_(gc),
ref_cnts_(ref_cnts) {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>( dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device)); PADDLE_ENFORCE(cudaSetDevice(place.device));
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
} }
for (auto &name : var_names) AddVar(name);
} }
~ReferenceCountOpHandle() { ~ReferenceCountOpHandle() {
...@@ -69,20 +68,36 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -69,20 +68,36 @@ class ReferenceCountOpHandle : public OpHandleBase {
std::string Name() const override { return "reference_count"; } std::string Name() const override { return "reference_count"; }
void AddVar(const std::string &name) {
auto it = var_names_.find(name);
if (it != var_names_.end())
++(it->second);
else
var_names_[name] = 1;
}
protected: protected:
void RunImpl() override { void RunImpl() override {
auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
std::vector<LoDTensor *> tensors; std::vector<Tensor *> tensors;
for (auto &name : var_names_) { for (auto &pair : var_names_) {
auto &name = pair.first;
auto it = ref_cnts_->find(name); auto it = ref_cnts_->find(name);
if (it == ref_cnts_->end()) continue; if (it == ref_cnts_->end()) continue;
auto *var = exec_scope->FindVar(name); auto *var = exec_scope->FindVar(name);
if (var == nullptr || !var->IsType<LoDTensor>()) continue; if (var == nullptr) continue;
if (it->second.fetch_sub(1) <= 1) { if (var->IsType<LoDTensor>()) {
if (it->second.fetch_sub(pair.second) <= pair.second) {
tensors.emplace_back(var->GetMutable<LoDTensor>()); tensors.emplace_back(var->GetMutable<LoDTensor>());
} }
} else if (var->IsType<SelectedRows>()) {
if (it->second.fetch_sub(pair.second) <= pair.second) {
tensors.emplace_back(
var->GetMutable<SelectedRows>()->mutable_value());
}
}
} }
if (!tensors.empty()) { if (!tensors.empty()) {
...@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
} }
private: private:
void ClearTensors(const std::vector<LoDTensor *> &tensors) { void ClearTensors(const std::vector<Tensor *> &tensors) {
auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_); auto *gc = dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_);
if (gc != nullptr) { if (gc != nullptr) {
auto compute_stream = dev_ctx_->stream(); auto compute_stream = dev_ctx_->stream();
...@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
const Scope *scope_; const Scope *scope_;
platform::CUDADeviceContext *dev_ctx_; platform::CUDADeviceContext *dev_ctx_;
std::vector<std::string> var_names_; std::unordered_map<std::string, int> var_names_;
GarbageCollector<Tensor> *gc_; // not own GarbageCollector<Tensor> *gc_; // not own
AtomicReferenceCountMap *ref_cnts_; // not own AtomicReferenceCountMap *ref_cnts_; // not own
cudaEvent_t event_; cudaEvent_t event_;
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <queue>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -23,6 +24,25 @@ namespace paddle { ...@@ -23,6 +24,25 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
std::queue<VarHandleBase *> queue;
queue.push(var_in);
do {
auto *var = queue.front();
queue.pop();
for (auto *op : var->PendingOps()) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) {
return compute_op;
}
for (auto *out_var : op->Outputs()) {
queue.push(out_var);
}
}
} while (!queue.empty());
return nullptr;
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount); auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
...@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables // Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops // in computation ops
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>>
compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&]( auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, const std::unique_ptr<OpHandleBase> &op,
const std::vector<VarHandleBase *> &vars) { const std::vector<VarHandleBase *> &vars) {
...@@ -54,16 +77,19 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -54,16 +77,19 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
VarDesc *var_desc = var_handle->Node()->Var(); VarDesc *var_desc = var_handle->Node()->Var();
auto var_name = var_handle->Node()->Name(); auto var_name = var_handle->Node()->Name();
// This is wierd but there is really some variables without var_desc // This is weird but there is really some variables without var_desc
// in computation_op // in computation_op
if (var_desc == nullptr) { if (var_desc == nullptr) {
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
continue; continue;
} else { } else {
if (var_desc->Persistable() || if (var_desc->Persistable()) continue;
var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR) auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS) {
continue; continue;
} }
}
// compute op only runs in one device // compute op only runs in one device
if (ref_cnts[place.device]->count(var_name)) if (ref_cnts[place.device]->count(var_name))
...@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if (ref_cnts.count(place.device) && if (ref_cnts.count(place.device) &&
ref_cnts[place.device]->count(var_name)) { ref_cnts[place.device]->count(var_name)) {
++(*ref_cnts[place.device])[var_name]; ++(*ref_cnts[place.device])[var_name];
auto *next_compute_op = FindNextComputationOpHandle(var_handle);
if (next_compute_op != nullptr) {
if (compute_ref_cnt_map.count(next_compute_op)) {
compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
VLOG(5) << "Add reference count of " << var_name << " to Operator "
<< next_compute_op->Name();
} else {
// Create new reference_count_op_handle
ir::Node *ref_cnt_node = graph->CreateEmptyNode(
"reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (next_compute_op->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
next_compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
}
}
} }
} }
}; };
std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map;
auto &all_ops = graph->Get<GraphOps>(kGraphOps); auto &all_ops = graph->Get<GraphOps>(kGraphOps);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
...@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (compute_op->Outputs().empty()) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
compute_op->AddOutput(dep_var); compute_op->AddOutput(dep_var);
ref_cnt_handle->AddInput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
compute_ref_cnt_map[compute_op] = ref_cnt_handle; }
ref_cnt_handle->AddInput(compute_op->Outputs().front());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
} }
for (auto &op : all_ops) { for (auto &op : all_ops) {
...@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
new_all_ops.emplace_back(std::move(op)); new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); auto it = compute_ref_cnt_map.find(new_all_ops.back().get());
if (it != compute_ref_cnt_map.end()) { if (it != compute_ref_cnt_map.end()) {
new_all_ops.emplace_back(it->second); // Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
it->second->AddOutput(dummy_leaf);
new_all_ops.emplace_back(std::move(it->second));
} }
} }
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <math.h> // for sqrt in CPU and CUDA #include <math.h> // for sqrt in CPU and CUDA
#include <Eigen/Dense> #include <Eigen/Dense>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -306,26 +307,45 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -306,26 +307,45 @@ class AdamOpKernel : public framework::OpKernel<T> {
VLOG(3) << "grad row size is 0!!"; VLOG(3) << "grad row size is 0!!";
return; return;
} }
std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = &grad;
} else {
// merge duplicated rows if any. // merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor // The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func; scatter::MergeAdd<DeviceContext, T> merge_func;
auto& grad_merge = *(ctx.scope() auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
.NewScope() .Var()
.Var("sparse_adam_grad_merge") ->GetMutable<framework::SelectedRows>();
->GetMutable<framework::SelectedRows>());
merge_func(ctx.template device_context<DeviceContext>(), grad, merge_func(ctx.template device_context<DeviceContext>(), grad,
&grad_merge); grad_merge_var);
grad_merge_ptr = grad_merge_var;
std::cerr << "Create new variables in adam_op" << std::endl;
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value(); auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>(); const T* grad_data = grad_tensor.template data<T>();
int64_t* rows = nullptr; const int64_t* rows = nullptr;
// When compiled without CUDA, the CUDAMutableData() interface should not be // When compiled without CUDA, the CUDAData() interface should not be
// provided. // provided.
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace()); rows = grad_merge.rows().CUDAData(ctx.GetPlace());
} else { } else {
#endif #endif
rows = grad_merge.mutable_rows()->data(); rows = grad_merge.rows().data();
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册