Commit eb825246 authored by sneaxiy

polish code

add unittest model containing while_op
remove unnecessary codes
test=develop
Parent 8a31b2eb
@@ -72,6 +72,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
 nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
+cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
 cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
 cc_test(reader_test SRCS reader_test.cc DEPS reader)
@@ -164,7 +166,7 @@ cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
 cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper garbage_collector)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
...
@@ -33,9 +33,10 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
 cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
-cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows op_handle_base)
+cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
+cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
 cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass)
-cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view)
+cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
...
@@ -26,8 +26,8 @@ namespace details {
 EagerDeletionOpHandle::EagerDeletionOpHandle(
     ir::Node *node, const Scope *scope, const platform::Place &place,
-    const std::unordered_set<std::string> &var_names,
-    GarbageCollector<Tensor> *gc, AtomicReferenceCountMap *ref_cnts)
+    const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
+    AtomicReferenceCountMap *ref_cnts)
     : OpHandleBase(node),
       scope_(scope),
       var_names_(var_names),
@@ -35,9 +35,9 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
       ref_cnts_(ref_cnts) {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place)) {
-    dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
+    dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
         platform::DeviceContextPool::Instance().Get(place));
-    if (dynamic_cast<StreamGarbageCollector<Tensor> *>(gc_)) {
+    if (dynamic_cast<StreamGarbageCollector *>(gc_)) {
       platform::CUDADeviceGuard guard(
           boost::get<platform::CUDAPlace>(place).device);
       PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
@@ -61,10 +61,11 @@ std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
 void EagerDeletionOpHandle::RunImpl() {
   auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  std::vector<Tensor *> tensors;
+  std::deque<std::shared_ptr<memory::Allocation>> garbages;
   for (auto &name : var_names_) {
     auto it = ref_cnts_->find(name);
-    if (it == ref_cnts_->end()) {
+    // Var not found, or reference count has not decreased to 0
+    if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
       continue;
     }
@@ -73,43 +74,44 @@ void EagerDeletionOpHandle::RunImpl() {
       continue;
     }
+    VLOG(2) << "Erase variable " << name;
     if (var->IsType<LoDTensor>()) {
-      if (it->second.fetch_sub(1) == 1) {
-        tensors.emplace_back(var->GetMutable<LoDTensor>());
-      }
+      garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemory());
     } else if (var->IsType<SelectedRows>()) {
-      if (it->second.fetch_sub(1) == 1) {
-        tensors.emplace_back(var->GetMutable<SelectedRows>()->mutable_value());
-      }
+      garbages.emplace_back(
+          var->GetMutable<SelectedRows>()->mutable_value()->MoveMemory());
     } else if (var->IsType<LoDTensorArray>()) {
-      if (it->second.fetch_sub(1) == 1) {
-        auto *tensor_arr = var->GetMutable<LoDTensorArray>();
-        for (auto &t : *tensor_arr) {
-          tensors.emplace_back(&t);
-        }
+      auto *tensor_arr = var->GetMutable<LoDTensorArray>();
+      for (auto &t : *tensor_arr) {
+        garbages.emplace_back(t.MoveMemory());
       }
+    } else {
+      PADDLE_THROW("Type %s of %s is not supported eager deletion",
+                   var->Type().name(), name);
     }
   }

-  if (!tensors.empty()) {
-    ClearTensors(tensors);
+  if (!garbages.empty()) {
+    ClearGarbages(&garbages);
   }
 }

-void EagerDeletionOpHandle::ClearTensors(const std::vector<Tensor *> &tensors) {
+void EagerDeletionOpHandle::ClearGarbages(
+    std::deque<std::shared_ptr<memory::Allocation>> *garbages) {
 #ifdef PADDLE_WITH_CUDA
   if (event_) {
     auto compute_stream = dev_ctx_->stream();
     auto callback_stream =
-        static_cast<StreamGarbageCollector<Tensor> *>(gc_)->stream();
+        reinterpret_cast<StreamGarbageCollector *>(gc_)->stream();
     auto callback_func = [=]() {
       PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream));
       PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0));
     };
-    gc_->Add(tensors, callback_func);
+    gc_->Add(std::move(*garbages), callback_func);
   } else {
 #endif
-    gc_->Add(tensors);
+    gc_->Add(std::move(*garbages));
 #ifdef PADDLE_WITH_CUDA
   }
 #endif
...
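Editorial note: the rewritten RunImpl above folds the reference-count check into a single atomic fetch_sub — every op that finishes with a variable decrements its counter, and only the op that observes the pre-decrement value 1 (the last user) moves the variable's buffer into the garbage queue. The following is a minimal standalone sketch of that last-user-collects pattern; it uses plain std::vector buffers instead of Paddle's Tensor/Allocation types and is illustrative only.

// Minimal standalone sketch (not Paddle code): the last consumer of a
// variable is the one that observes fetch_sub(1) returning 1, and only
// that consumer moves the buffer into the garbage queue.
#include <atomic>
#include <deque>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::atomic<size_t>> ref_cnts;
  ref_cnts["x"] = 3;  // three ops still read "x"

  std::unordered_map<std::string, std::shared_ptr<std::vector<float>>> vars;
  vars["x"] = std::make_shared<std::vector<float>>(1024, 0.f);

  std::deque<std::shared_ptr<std::vector<float>>> garbages;
  for (int op = 0; op < 3; ++op) {
    auto it = ref_cnts.find("x");
    // fetch_sub returns the value *before* the decrement, so only the op
    // that sees 1 is the last user and may reclaim the memory.
    if (it == ref_cnts.end() || it->second.fetch_sub(1) != 1) continue;
    garbages.emplace_back(std::move(vars["x"]));
    std::cout << "op " << op << " queued x for deletion\n";
  }
  return 0;
}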
@@ -14,8 +14,8 @@
 #pragma once

+#include <deque>
 #include <string>
-#include <vector>

 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/details/reference_count_pass_helper.h"
@@ -30,7 +30,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
   EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
                         const platform::Place &place,
                         const std::unordered_set<std::string> &var_names,
-                        GarbageCollector<Tensor> *gc,
+                        GarbageCollector *gc,
                         AtomicReferenceCountMap *ref_cnts);

   ~EagerDeletionOpHandle();
@@ -41,11 +41,11 @@ class EagerDeletionOpHandle : public OpHandleBase {
   void RunImpl() override;

  private:
-  void ClearTensors(const std::vector<Tensor *> &tensors);
+  void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);

   const Scope *scope_;
   std::unordered_set<std::string> var_names_;
-  GarbageCollector<Tensor> *gc_;       // not own
+  GarbageCollector *gc_;               // not own
   AtomicReferenceCountMap *ref_cnts_;  // not own

 #ifdef PADDLE_WITH_CUDA
   platform::CUDADeviceContext *dev_ctx_{nullptr};
...
@@ -28,17 +28,21 @@ namespace details {
 std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  const auto &vars = graph->Get<GraphVars>(kGraphVars);
   auto &ref_cnts =
       Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
+  PADDLE_ENFORCE(ref_cnts.empty(),
+                 "kRuntimeReferenceCount should be initialized here!");
+
+  const auto &vars = graph->Get<GraphVars>(kGraphVars);
+  ref_cnts.resize(vars.size());
+
   const auto &last_live_ops =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
-  auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
+  const auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
   const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);

-  ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());
+  // a reverse map of last_live_ops
+  //   i.e., last op --> variable names which can be deleted.
   std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
       op_vars_map;
@@ -58,8 +62,8 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
       auto *eager_deletion_node =
           graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
       auto *eager_deletion_op = new EagerDeletionOpHandle(
-          eager_deletion_node, op->GetScope(), op->GetPlace(),
-          std::move(var_names), gcs.at(places[op->GetScopeIdx()]).get(),
+          eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
+          gcs.at(places[op->GetScopeIdx()]).get(),
           &(ref_cnts[op->GetScopeIdx()]));

       auto it = std::find_if(
...
@@ -42,6 +42,7 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
 std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
   std::unordered_set<OpHandleBase *> ret;
+  ret.reserve(preceding_ops_.size());
   for (auto &pair : preceding_ops_) {
     ret.insert(pair.first);
   }
...
@@ -29,15 +29,17 @@ namespace paddle {
 namespace framework {
 namespace details {

-class OpRelationDetector {
- public:
+// A functor to shrink/remove operators who depend on other operators in a set
+class ShrinkDepsOpFunctor {
+ private:
   enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };

-  explicit OpRelationDetector(const std::vector<OpHandleBase *> &all_ops)
+ public:
+  explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
       : graph_(all_ops) {}

   template <typename OpSet>
-  OpSet MaxNoDepOps(const OpSet &op_set) const {
+  OpSet operator()(const OpSet &op_set) const {
     using KeyType = typename OpSet::key_type;
     static_assert(
         std::is_base_of<OpHandleBase,
@@ -51,7 +53,7 @@ class OpRelationDetector {
     auto not_before = [](RelationShip r) { return r != kBefore; };
     for (size_t i = 0; i < rels.size(); ++i) {
       if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) {
-        ret.insert(static_cast<KeyType>(ops[i]));
+        ret.emplace(static_cast<KeyType>(ops[i]));
       }
     }
     return ret;
@@ -59,7 +61,7 @@ class OpRelationDetector {
  private:
   std::vector<std::vector<RelationShip>> GetRelations(
-      const std::vector<OpHandleBase *> ops) const {
+      const std::vector<OpHandleBase *> &ops) const {
     std::unordered_map<OpHandleBase *, size_t> op_to_idx;
     for (size_t i = 0; i < ops.size(); ++i) {
       PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
@@ -112,6 +114,10 @@ class OpRelationDetector {
   const OpGraphView graph_;
 };

+/**
+ * Find the nearest downstream computation op handle. If the op is a
+ * computation op, just return itself.
+ */
 static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
     OpHandleBase *op, size_t scope_idx) {
   std::queue<OpHandleBase *> q;
@@ -134,33 +140,87 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
   return nullptr;
 }

+static std::unordered_set<ComputationOpHandle *>
+ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
+                                     const ShrinkDepsOpFunctor &shrink_func,
+                                     bool *ok) {
+  // stage one. Get last op for variable.
+  std::unordered_set<OpHandleBase *> candidates;
+  {
+    if (var->PendingOps().empty() && var->GeneratedOp()) {
+      // No operator depends on this variable. So the last operator is the op
+      // who generates this variable.
+      candidates.emplace(var->GeneratedOp());
+    } else {
+      candidates = var->PendingOps();
+    }
+
+    // No pending ops or generated op is nullptr
+    if (candidates.empty()) {
+      *ok = false;
+      return {};
+    }
+  }
+
+  // stage two. Try to cast them to computation op.
+  // return (*ok = false) when failed.
+  //
+  // The reason why we cannot make any type of op handle the last lived op is:
+  // some op handles may operate on many DeviceContexts, but our garbage
+  // collector can only wait on one DeviceContext for now. So currently, we
+  // wait on the nearest compute op.
+  std::unordered_set<ComputationOpHandle *> computation_op;
+  {
+    for (auto *op : candidates) {
+      auto *compute_op =
+          FindNextComputationOpHandleOrReturnItself(op, scope_idx);
+      if (compute_op == nullptr) {
+        *ok = false;
+        return {};
+      }
+      computation_op.emplace(compute_op);
+    }
+  }
+
+  // stage three. Try to shrink computation ops if they depend on each other.
+  // Get the smallest set of the last ops.
+  *ok = true;
+  return shrink_func(computation_op);
+}
+
+static VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
+  VarDesc *var_desc = nullptr;
+  std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool {
+    var_desc = var_handle->Node()->Var();
+    return var_desc != nullptr;
+  });
+  return var_desc;
+}
+
 std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  auto &vars = graph->Get<GraphVars>(kGraphVars);
   auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
   auto &last_live_ops_of_vars =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);

-  last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
-  ref_cnts = std::vector<ReferenceCountMap>(vars.size());
+  PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(),
+                 "Last Live Ops and Reference Counts of vars should be "
+                 "initialized here.");

-  OpRelationDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
+  const auto &vars = graph->Get<GraphVars>(kGraphVars);
+  last_live_ops_of_vars.resize(vars.size());
+  ref_cnts.resize(vars.size());

-  for (size_t i = 0; i < vars.size(); ++i) {
-    for (auto &name_var_pair : vars[i]) {
-      if (name_var_pair.second.empty()) {
-        continue;
-      }
-      const std::string &var_name = name_var_pair.first;
-      auto *last_ver_var = name_var_pair.second.back();
-
-      VarDesc *var_desc = nullptr;
-      std::find_if(name_var_pair.second.rbegin(), name_var_pair.second.rend(),
-                   [&](VarHandle *var_handle) -> bool {
-                     var_desc = var_handle->Node()->Var();
-                     return var_desc != nullptr;
-                   });
+  ShrinkDepsOpFunctor shrink_func(
+      ir::FilterByNodeWrapper<OpHandleBase>(*graph));

+  for (size_t i = 0; i < vars.size(); ++i) {
+    for (auto &name_var_pair : vars[i]) {
+      // Whether this variable can be reused or deleted? If not, we do not
+      // compute reference counts and dependencies.
+      VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second);

       if (var_desc == nullptr || var_desc->Persistable()) {
         continue;
@@ -170,50 +230,20 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
       if (var_type != proto::VarType::LOD_TENSOR &&
           var_type != proto::VarType::SELECTED_ROWS &&
           var_type != proto::VarType::LOD_TENSOR_ARRAY) {
+        // Var type cannot be deleted
         continue;
       }

-      std::unordered_set<ComputationOpHandle *> last_live_op;
-      auto add_last_live_op = [&](OpHandleBase *op) -> bool {
-        auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i);
-        if (compute_op) {
-          last_live_op.insert(compute_op);
-          return true;
-        } else {
-          return false;
-        }
-      };
-
-      bool can_delete = false;
-      auto &pending_ops = last_ver_var->PendingOps();
-      if (pending_ops.empty()) {
-        auto *generated_op = last_ver_var->GeneratedOp();
-        if (generated_op && add_last_live_op(generated_op)) {
-          can_delete = true;
-        }
-      } else {
-        can_delete = true;
-        for (auto *pending_op : pending_ops) {
-          if (!add_last_live_op(pending_op)) {
-            can_delete = false;
-            break;
-          }
-        }
-      }
-
-      if (can_delete) {
-        size_t original_size = last_live_op.size();
-        last_live_op = detector.MaxNoDepOps(last_live_op);
-        if (last_live_op.size() != original_size) {
-          VLOG(10) << "Shrink last living op number of " << var_name << " from "
-                   << original_size << " to " << last_live_op.size();
-        }
-
-        PADDLE_ENFORCE(!last_live_op.empty(),
-                       "Last living ops of %s cannot be empty", var_name);
+      bool ok;
+      auto result = ExtractComputationOpFromLastLivedVar(
+          name_var_pair.second.back(), i, shrink_func, &ok);

-        ref_cnts[i].emplace(var_name, last_live_op.size());
-        last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
+      if (ok) {
+        auto &var_name = name_var_pair.first;
+        PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty",
+                       var_name);
+        ref_cnts[i].emplace(var_name, result.size());
+        last_live_ops_of_vars[i].emplace(var_name, std::move(result));
       }
     }
   }
...
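Editorial note: ShrinkDepsOpFunctor keeps, for each variable, only the candidate ops whose relation to every other candidate is not kBefore, i.e. it drops any candidate that runs before another candidate, since waiting for the downstream op already covers the upstream one. Below is a small standalone sketch of that shrinking rule on a toy dependency graph; the graph, names, and reachability check are illustrative stand-ins, not Paddle's classes.

// Standalone sketch of the "shrink dependencies" idea behind
// ShrinkDepsOpFunctor: from a candidate set of ops, drop any op that is an
// ancestor of (runs before) another op in the same set, because waiting for
// the downstream op already implies the upstream one has finished.
#include <iostream>
#include <set>
#include <vector>

int main() {
  // adjacency: op -> ops that directly depend on it
  std::vector<std::vector<int>> next = {{1, 2}, {3}, {3}, {}};

  auto reaches = [&](int from, int to) {
    std::vector<int> stack = {from};
    std::set<int> seen;
    while (!stack.empty()) {
      int cur = stack.back();
      stack.pop_back();
      if (cur == to) return true;
      if (!seen.insert(cur).second) continue;
      for (int nxt : next[cur]) stack.push_back(nxt);
    }
    return false;
  };

  std::set<int> candidates = {1, 2, 3};  // ops that all read the same var
  std::set<int> shrunk;
  for (int op : candidates) {
    bool before_other = false;
    for (int other : candidates) {
      if (other != op && reaches(op, other)) {
        before_other = true;  // op runs before `other`; `other` suffices
        break;
      }
    }
    if (!before_other) shrunk.insert(op);
  }
  for (int op : shrunk) std::cout << "keep op " << op << "\n";  // keeps only 3
  return 0;
}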
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
namespace paddle {
namespace framework {
namespace details {} // namespace details
} // namespace framework
} // namespace paddle
@@ -18,10 +18,10 @@
 #include <map>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>

 #include "paddle/fluid/framework/garbage_collector.h"
-#include "paddle/fluid/framework/tensor.h"

 namespace paddle {
 namespace framework {
@@ -35,7 +35,7 @@ using AtomicReferenceCountMap =
     std::unordered_map<std::string, std::atomic<size_t>>;

 using GarbageCollectorMap =
-    std::map<platform::Place, std::unique_ptr<GarbageCollector<Tensor>>>;
+    std::map<platform::Place, std::unique_ptr<GarbageCollector>>;

 const char kGlobalReferenceCount[] = "global_reference_count";
 const char kRuntimeReferenceCount[] = "runtime_reference_count";
...
@@ -30,20 +30,7 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
       underlying_executor_(std::move(underlying_executor)),
       local_scopes_(std::move(local_scopes)),
       var_infos_(std::move(var_infos)),
-      places_(std::move(places)) {
-  if (Graph().Has(details::kGarbageCollector)) {
-    gc_ = &(Graph().Get<GarbageCollectorMap>(details::kGarbageCollector));
-  }
-}
-
-void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
-  if (gc_) {
-    for (auto &gc_pair : *gc_) {
-      gc_pair.second->Wait();
-      gc_pair.second->Reset();
-    }
-  }
-}
+      places_(std::move(places)) {}

 FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
@@ -83,19 +70,15 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
       drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
     drop_scope_counter_ = 0;
     // Wait All computational streams
-    for (auto &p : places_) {
+    for (auto p : places_) {
       platform::DeviceContextPool::Instance().Get(p)->Wait();
     }
-    WaitAllGarbageCollectors();
     for (auto &scope : local_scopes_) {
       auto &local_scope =
           *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
       scope->DeleteScope(local_scope);
     }
-  } else {
-    WaitAllGarbageCollectors();
   }

   if (eptr) {
     std::rethrow_exception(eptr);
   } else {
...
@@ -21,11 +21,9 @@
 #include "paddle/fluid/framework/details/var_handle.h"

 #include "paddle/fluid/framework/details/execution_strategy.h"
-#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/details/ssa_graph_executor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/place.h"

 namespace paddle {
 namespace framework {
 namespace details {
@@ -50,8 +48,6 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;

  private:
-  void WaitAllGarbageCollectors();
-
   size_t drop_scope_counter_{0};

   ExecutionStrategy strategy_;
@@ -59,8 +55,6 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope*> local_scopes_;
   std::vector<VariableInfo> var_infos_;
   std::vector<platform::Place> places_;
-
-  GarbageCollectorMap* gc_{nullptr};
 };

 }  // namespace details
 }  // namespace framework
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/executor.h"
+#include <deque>

 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
@@ -83,31 +84,37 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
 }

 static void DeleteUnusedTensors(
-    const Scope& scope, const OperatorBase* op, GarbageCollector<Tensor>* gc,
+    const Scope& scope, const OperatorBase* op, GarbageCollector* gc,
     std::unordered_map<std::string, size_t>* ref_cnts) {
-  std::unordered_set<Tensor*> erase_tensors;
+  std::deque<std::shared_ptr<memory::Allocation>> garbages;

   auto handler = [&](const VariableNameMap& name_map) {
     for (auto& name_pair : name_map) {
       for (auto& name : name_pair.second) {
         auto it = ref_cnts->find(name);
         if (it == ref_cnts->end()) continue;
-        if (--(it->second) == 0) {
-          auto* var = scope.FindVar(name);
-          if (var != nullptr) {
-            VLOG(2) << "Erase tensor \'" << name << "\'";
-            if (var->IsType<LoDTensor>()) {
-              erase_tensors.insert(var->GetMutable<LoDTensor>());
-            } else if (var->IsType<SelectedRows>()) {
-              erase_tensors.insert(
-                  var->GetMutable<SelectedRows>()->mutable_value());
-            } else if (var->IsType<LoDTensorArray>()) {
-              auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
-              for (auto& t : *lod_tensor_arr) {
-                erase_tensors.insert(&t);
-              }
-            }
-          }
+        if (--(it->second) != 0) {
+          continue;
+        }
+        auto* var = scope.FindVar(name);
+        if (var == nullptr) {
+          continue;
+        }
+
+        VLOG(2) << "Erase variable " << name;
+        if (var->IsType<LoDTensor>()) {
+          garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemory());
+        } else if (var->IsType<SelectedRows>()) {
+          garbages.emplace_back(
+              var->GetMutable<SelectedRows>()->mutable_value()->MoveMemory());
+        } else if (var->IsType<LoDTensorArray>()) {
+          auto* lod_tensor_arr = var->GetMutable<LoDTensorArray>();
+          for (auto& t : *lod_tensor_arr) {
+            garbages.emplace_back(t.MoveMemory());
+          }
+        } else {
+          PADDLE_THROW("Type %s of %s is not supported eager deletion",
+                       var->Type().name(), name);
         }
       }
     }
@@ -116,8 +123,8 @@ static void DeleteUnusedTensors(
   handler(op->Inputs());
   handler(op->Outputs());

-  if (!erase_tensors.empty()) {
-    gc->Add(erase_tensors);
+  if (!garbages.empty()) {
+    gc->Add(std::move(garbages));
   }
 }
@@ -411,22 +418,22 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }

   int64_t max_memory_size = GetEagerDeletionThreshold();
-  std::unique_ptr<GarbageCollector<Tensor>> gc;
+  std::unique_ptr<GarbageCollector> gc;
   if (max_memory_size >= 0) {
     ctx->ResetReferenceCount();
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place_)) {
       if (IsFastEagerDeletionModeEnabled()) {
-        gc.reset(new UnsafeFastGPUGarbageCollector<Tensor>(
+        gc.reset(new UnsafeFastGPUGarbageCollector(
             boost::get<platform::CUDAPlace>(place_), max_memory_size));
       } else {
-        gc.reset(new DefaultStreamGarbageCollector<Tensor>(
+        gc.reset(new DefaultStreamGarbageCollector(
             boost::get<platform::CUDAPlace>(place_), max_memory_size));
       }
     } else if (platform::is_cpu_place(place_)) {
 #endif
-      gc.reset(new CPUGarbageCollector<Tensor>(
-          boost::get<platform::CPUPlace>(place_), max_memory_size));
+      gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place_),
+                                       max_memory_size));
 #ifdef PADDLE_WITH_CUDA
     }
 #endif
@@ -442,7 +449,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }

   platform::DeviceContextPool::Instance().Get(place_)->Wait();
-  if (gc) gc->Wait();

   if (local_scope != scope) {
     scope->DeleteScope(local_scope);
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
#include "paddle/fluid/framework/garbage_collector.h"
namespace paddle {
namespace framework {
GarbageCollector::GarbageCollector(const platform::Place &place,
size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
garbages_.reset(new GarbageQueue());
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}
CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void CPUGarbageCollector::ClearCallback(const std::function<void()> &callback) {
callback();
}
#ifdef PADDLE_WITH_CUDA
UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector(
const platform::CUDAPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void UnsafeFastGPUGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}
DefaultStreamGarbageCollector::DefaultStreamGarbageCollector(
const platform::CUDAPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void DefaultStreamGarbageCollector::Wait() const {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
void DefaultStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
}
StreamGarbageCollector::~StreamGarbageCollector() {
auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
platform::CUDADeviceGuard guard(place.device);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
cudaStream_t StreamGarbageCollector::stream() const { return stream_; }
void StreamGarbageCollector::Wait() const { callback_manager_->Wait(); }
void StreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback_manager_->AddCallback(callback);
}
#endif
} // namespace framework
} // namespace paddle
@@ -14,160 +14,83 @@
 #pragma once

-#include <algorithm>
 #include <deque>
 #include <functional>
 #include <memory>
 #include <mutex>  // NOLINT
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/cuda_device_guard.h"
-#endif
 #include "paddle/fluid/platform/device_context.h"

 namespace paddle {
 namespace framework {

-// T should have memory_size() and clear() method
-template <typename T>
 class GarbageCollector {
  public:
-  GarbageCollector(const platform::Place &place, size_t max_memory_size)
-      : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
-    garbages_.reset(new std::deque<T *>());
-    dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
-  }
+  using GarbageQueue = std::deque<std::shared_ptr<memory::Allocation>>;

-  virtual ~GarbageCollector() {}
+  GarbageCollector(const platform::Place &place, size_t max_memory_size);

-  size_t NumOfGarbages() const {
-    std::lock_guard<std::mutex> guard(mutex_);
-    return garbages_->size();
-  }
+  virtual ~GarbageCollector() = default;

-  void Reset() {
-    std::lock_guard<std::mutex> guard(mutex_);
-    garbages_.reset(new std::deque<T *>());
-    cur_memory_size_ = 0;
-  }
+  virtual void Wait() const {}

   template <typename Container>
-  void Add(const Container &objs) {
-    Add(objs, []() {});
-  }
+  void Add(Container &&objs);

   template <typename Container, typename Callback>
-  void Add(const Container &objs, Callback &&callback) {
-    std::deque<T *> *clear_deque = nullptr;
-    {
-      std::lock_guard<std::mutex> guard(mutex_);
-      for (auto *obj : objs) {
-        garbages_->push_back(obj);
-        cur_memory_size_ += obj->memory_size();
-      }
-      if (cur_memory_size_ >= max_memory_size_) {
-        cur_memory_size_ = 0;
-        clear_deque = garbages_.release();
-        garbages_.reset(new std::deque<T *>());
-      }
-    }
-
-    if (clear_deque != nullptr) {
-      callback();
-      ClearCallback([clear_deque]() {
-        for (auto *obj : *clear_deque) obj->clear();
-        delete clear_deque;
-      });
-    }
-  }
-
-  virtual void Wait() const {}
+  void Add(Container &&objs, Callback &&callback);

  protected:
   virtual void ClearCallback(const std::function<void()> &callback) = 0;

   platform::DeviceContext *dev_ctx_;
-  std::unique_ptr<std::deque<T *>> garbages_;
+  std::unique_ptr<GarbageQueue> garbages_;
   mutable std::mutex mutex_;
   const size_t max_memory_size_;
-  size_t cur_memory_size_ = 0;
+  size_t cur_memory_size_{0};
 };

-template <typename T>
-class CPUGarbageCollector : public GarbageCollector<T> {
+class CPUGarbageCollector : public GarbageCollector {
  public:
-  CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {}
+  CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size);

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    callback();
-  }
+  void ClearCallback(const std::function<void()> &callback) override;
 };

 #ifdef PADDLE_WITH_CUDA
-template <typename T>
-class UnsafeFastGPUGarbageCollector : public GarbageCollector<T> {
+class UnsafeFastGPUGarbageCollector : public GarbageCollector {
  public:
   UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place,
-                                size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {}
+                                size_t max_memory_size);

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    callback();
-  }
+  void ClearCallback(const std::function<void()> &callback) override;
 };

-template <typename T>
-class DefaultStreamGarbageCollector : public GarbageCollector<T> {
+class DefaultStreamGarbageCollector : public GarbageCollector {
  public:
   DefaultStreamGarbageCollector(const platform::CUDAPlace &place,
-                                size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {}
-
-  cudaStream_t stream() const {
-    return static_cast<const platform::CUDADeviceContext *>(this->dev_ctx_)
-        ->stream();
-  }
-
-  void Wait() const override {
-    static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
-        ->WaitStreamCallback();
-  }
+                                size_t max_memory_size);
+
+  void Wait() const override;

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    static_cast<platform::CUDADeviceContext *>(this->dev_ctx_)
-        ->AddStreamCallback(callback);
-  }
+  void ClearCallback(const std::function<void()> &callback) override;
 };

-template <typename T>
-class StreamGarbageCollector : public GarbageCollector<T> {
+class StreamGarbageCollector : public GarbageCollector {
  public:
   StreamGarbageCollector(const platform::CUDAPlace &place,
-                         size_t max_memory_size)
-      : GarbageCollector<T>(place, max_memory_size) {
-    platform::CUDADeviceGuard guard(place.device);
-    PADDLE_ENFORCE(cudaStreamCreate(&stream_));
-    callback_manager_.reset(new platform::StreamCallbackManager(stream_));
-  }
+                         size_t max_memory_size);

-  ~StreamGarbageCollector() {
-    auto place = boost::get<platform::CUDAPlace>(this->dev_ctx_->GetPlace());
-    platform::CUDADeviceGuard guard(place.device);
-    PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
-    PADDLE_ENFORCE(cudaStreamDestroy(stream_));
-  }
+  ~StreamGarbageCollector();

-  void Wait() const override { callback_manager_->Wait(); }
+  void Wait() const override;

-  cudaStream_t stream() const { return stream_; }
+  cudaStream_t stream() const;

  protected:
-  void ClearCallback(const std::function<void()> &callback) override {
-    callback_manager_->AddCallback(callback);
-  }
+  void ClearCallback(const std::function<void()> &callback) override;

  private:
   cudaStream_t stream_;
@@ -175,5 +98,33 @@ class StreamGarbageCollector : public GarbageCollector<T> {
 };
 #endif

+template <typename Container>
+void GarbageCollector::Add(Container &&objs) {
+  Add(std::forward<Container>(objs), []() {});
+}
+
+template <typename Container, typename Callback>
+void GarbageCollector::Add(Container &&objs, Callback &&callback) {
+  GarbageQueue *garbage_queue = nullptr;
+  {
+    std::lock_guard<std::mutex> guard(mutex_);
+    for (auto &obj : objs) {
+      if (!obj) continue;
+      cur_memory_size_ += obj->size();
+      garbages_->push_back(std::move(obj));
+    }
+    if (cur_memory_size_ >= max_memory_size_) {
+      cur_memory_size_ = 0;
+      garbage_queue = garbages_.release();
+      garbages_.reset(new GarbageQueue());
+    }
+  }
+
+  if (garbage_queue) {
+    callback();
+    ClearCallback([garbage_queue]() { delete garbage_queue; });
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
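Editorial note: the de-templated GarbageCollector now queues std::shared_ptr<memory::Allocation> garbage and, once the accumulated size passes max_memory_size_, hands the whole batch to ClearCallback (synchronously on CPU, via a stream callback on GPU). The class below is a simplified, self-contained stand-in for that threshold-batched deferred free, not Paddle's API; it just drops the batch synchronously where Paddle would defer the release.

// Minimal standalone sketch of the threshold-batched deferred free that
// GarbageCollector::Add implements. Buffers are held as shared_ptrs until the
// accumulated size crosses the threshold, then the whole batch is released.
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <vector>

class ToyGarbageCollector {
 public:
  explicit ToyGarbageCollector(size_t max_bytes) : max_bytes_(max_bytes) {}

  void Add(std::deque<std::shared_ptr<std::vector<char>>> &&objs) {
    std::deque<std::shared_ptr<std::vector<char>>> batch;
    {
      std::lock_guard<std::mutex> guard(mutex_);
      for (auto &obj : objs) {
        if (!obj) continue;
        cur_bytes_ += obj->size();
        garbages_.push_back(std::move(obj));
      }
      if (cur_bytes_ >= max_bytes_) {
        cur_bytes_ = 0;
        batch.swap(garbages_);  // hand the whole batch to the "callback"
      }
    }
    if (!batch.empty()) {
      // Paddle defers this via ClearCallback (e.g. onto a CUDA stream
      // callback); this sketch simply drops the batch right away.
      std::cout << "releasing " << batch.size() << " buffers\n";
      batch.clear();
    }
  }

 private:
  std::mutex mutex_;
  std::deque<std::shared_ptr<std::vector<char>>> garbages_;
  size_t cur_bytes_{0};
  const size_t max_bytes_;
};

int main() {
  ToyGarbageCollector gc(1 << 20);  // 1 MB threshold
  for (int i = 0; i < 4; ++i) {
    std::deque<std::shared_ptr<std::vector<char>>> objs;
    objs.push_back(std::make_shared<std::vector<char>>(512 << 10));  // 512 KB
    gc.Add(std::move(objs));  // flushes on every second call
  }
  return 0;
}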
@@ -97,29 +97,31 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
     if (gcs_.count(place) > 0) {
       continue;
     }
-    GarbageCollector<Tensor> *gc = nullptr;
+    std::unique_ptr<GarbageCollector> gc;
 #ifdef PADDLE_WITH_CUDA
     if (platform::is_gpu_place(place)) {
       if (IsFastEagerDeletionModeEnabled()) {
-        gc = new UnsafeFastGPUGarbageCollector<Tensor>(
-            boost::get<platform::CUDAPlace>(place), max_memory_size);
+        gc.reset(new UnsafeFastGPUGarbageCollector(
+            boost::get<platform::CUDAPlace>(place), max_memory_size));
       } else {
-        gc = new StreamGarbageCollector<Tensor>(
-            boost::get<platform::CUDAPlace>(place), max_memory_size);
+        gc.reset(new StreamGarbageCollector(
+            boost::get<platform::CUDAPlace>(place), max_memory_size));
       }
       VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-    } else if (platform::is_cpu_place(place)) {
+    } else {
 #endif
-      gc = new CPUGarbageCollector<Tensor>(
-          boost::get<platform::CPUPlace>(place), max_memory_size);
-      VLOG(10) << "Created GarbageCollector at " << place;
+      if (platform::is_cpu_place(place)) {
+        gc.reset(new CPUGarbageCollector(boost::get<platform::CPUPlace>(place),
+                                         max_memory_size));
+        VLOG(10) << "Created GarbageCollector at " << place;
+      } else {
+        PADDLE_THROW("Unsupported place for garbage collection");
+      }
 #ifdef PADDLE_WITH_CUDA
     }
 #endif
-    if (gc) {
-      gcs_[place] = std::unique_ptr<GarbageCollector<Tensor>>(gc);
-    }
+
+    gcs_.emplace(place, std::move(gc));
   }

   if (!gcs_.empty()) {
@@ -144,8 +146,6 @@ std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
     eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
     graph = eager_deletion_pass->Apply(std::move(graph));
     VLOG(10) << "EagerDeletionPass Applied";
-
-    graph->SetNotOwned(details::kGarbageCollector, &gcs_);
   }

   return graph;
...
@@ -38,7 +38,7 @@ DEFINE_double(
     "Memory size threshold (GB) when the garbage collector clear tensors."
     "Disabled when this value is less than 0");

-DEFINE_bool(fast_eager_deletion_mode, true,
+DEFINE_bool(fast_eager_deletion_mode, false,
            "Fast eager deletion mode. If enabled, memory would release "
            "immediately without waiting GPU kernel ends.");
...
@@ -158,6 +158,10 @@ class Tensor {
   const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
   size_t offset() const { return offset_; }

+  std::shared_ptr<memory::Allocation> MoveMemory() {
+    return std::move(holder_);
+  }
+
  private:
   /*! holds the memory block if allocated. */
   std::shared_ptr<memory::Allocation> holder_;
...
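Editorial note: Tensor::MoveMemory() transfers the shared_ptr holder out of the tensor, so the tensor stops keeping the buffer alive while the garbage queue becomes its last owner; the actual free happens when the queue is cleared. A tiny standalone sketch of that ownership handoff, with simplified stand-in types rather than Paddle's Tensor/Allocation:

// Standalone sketch of what moving the allocation holder achieves: after the
// move, the tensor-like owner no longer keeps the buffer alive, and the
// buffer is freed only when the garbage queue drops its shared_ptr.
#include <deque>
#include <iostream>
#include <memory>
#include <vector>

struct ToyTensor {
  std::shared_ptr<std::vector<float>> holder =
      std::make_shared<std::vector<float>>(1024);
  // Analogue of Tensor::MoveMemory(): give up ownership of the buffer.
  std::shared_ptr<std::vector<float>> MoveMemory() { return std::move(holder); }
};

int main() {
  ToyTensor t;
  std::deque<std::shared_ptr<std::vector<float>>> garbages;
  garbages.emplace_back(t.MoveMemory());
  std::cout << std::boolalpha << "tensor still owns buffer: "
            << static_cast<bool>(t.holder) << "\n";           // false
  std::cout << "queue keeps buffer alive: " << garbages.back()->size()
            << " floats\n";                                    // 1024
  garbages.clear();  // this is where the memory is actually released
  return 0;
}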
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_eager_deletion_lstm_net import TestBase
import paddle.fluid as fluid
def gru_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=400.0):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
gru_max_tanh = fluid.layers.tanh(gru_max)
fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class GRUTest(TestBase):
def setUp(self):
self.net = gru_net
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['CPU_NUM'] = '2'
import six
import unittest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
if use_cuda and not core.is_compiled_with_cuda():
print('Skip use_cuda=True because Paddle is not compiled with cuda')
return
word_dict = paddle.dataset.imdb.word_dict()
train_reader = paddle.batch(
paddle.dataset.imdb.train(word_dict), batch_size=batch_size)
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost = network(data, label, len(word_dict))
optimizer = fluid.optimizer.Adagrad(learning_rate=0.2)
optimizer.minimize(cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
reader = feeder.decorate_reader(
train_reader, multi_devices=use_parallel_executor)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if use_parallel_executor:
train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=cost.name)
fetch_list = [cost.name]
else:
train_exe = exe
fetch_list = [cost]
for pass_id in six.moves.xrange(pass_num):
batch_id = 0
for data in reader():
train_exe.run(feed=data,
fetch_list=fetch_list if batch_id % 4 == 0 else [])
batch_id += 1
if batch_id > 16:
break
def lstm_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2,
emb_lr=30.0):
emb = fluid.layers.embedding(
input=data,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(learning_rate=emb_lr))
fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
lstm_h, c = fluid.layers.dynamic_lstm(
input=fc0, size=hid_dim * 4, is_reverse=False)
lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
lstm_max_tanh = fluid.layers.tanh(lstm_max)
fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class TestBase(unittest.TestCase):
def setUp(self):
self.net = lstm_net
def test_network(self):
for use_cuda in [True, False]:
for use_parallel_executor in [False, True]:
print('network: {}, use_cuda: {}, use_parallel_executor: {}'.
format(self.net.__name__, use_cuda,
use_parallel_executor))
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(core.Scope()):
train(self.net, use_cuda, use_parallel_executor)
if __name__ == "__main__":
unittest.main()