diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h index 71db8d952f4c205b875ad254dc19c0c1f74e61b3..fc479a4c4a1e7d5c824d3c202e0cccf743dd52c9 100644 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ b/paddle/fluid/framework/details/reference_count_op_handle.h @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/tensor.h" namespace paddle { @@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase { const std::vector &var_names, GarbageCollector *gc, AtomicReferenceCountMap *ref_cnts) - : OpHandleBase(node), - scope_(scope), - var_names_(var_names), - gc_(gc), - ref_cnts_(ref_cnts) { + : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) { dev_ctx_ = static_cast( platform::DeviceContextPool::Instance().Get(place)); if (IsStreamGarabageCollector()) { PADDLE_ENFORCE(cudaSetDevice(place.device)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } + + for (auto &name : var_names) AddVar(name); } ~ReferenceCountOpHandle() { @@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase { std::string Name() const override { return "reference_count"; } + void AddVar(const std::string &name) { + auto it = var_names_.find(name); + if (it != var_names_.end()) + ++(it->second); + else + var_names_[name] = 1; + } + protected: void RunImpl() override { auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); - std::vector tensors; - for (auto &name : var_names_) { + std::vector tensors; + for (auto &pair : var_names_) { + auto &name = pair.first; auto it = ref_cnts_->find(name); if (it == ref_cnts_->end()) continue; auto *var = exec_scope->FindVar(name); - if (var == nullptr || !var->IsType()) continue; - - if (it->second.fetch_sub(1) <= 1) { - tensors.emplace_back(var->GetMutable()); + if (var == nullptr) continue; + + if (var->IsType()) { + if (it->second.fetch_sub(pair.second) <= pair.second) { + tensors.emplace_back(var->GetMutable()); + } + } else if (var->IsType()) { + if (it->second.fetch_sub(pair.second) <= pair.second) { + tensors.emplace_back( + var->GetMutable()->mutable_value()); + } } } @@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase { } private: - void ClearTensors(const std::vector &tensors) { + void ClearTensors(const std::vector &tensors) { auto *gc = dynamic_cast *>(gc_); if (gc != nullptr) { auto compute_stream = dev_ctx_->stream(); @@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase { const Scope *scope_; platform::CUDADeviceContext *dev_ctx_; - std::vector var_names_; + std::unordered_map var_names_; GarbageCollector *gc_; // not own AtomicReferenceCountMap *ref_cnts_; // not own cudaEvent_t event_; diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 344754d5a1e119c04cae08ad50126924b5824315..b1ce551ce73de33bcede187c72feebad6e2fa1a5 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -23,6 +24,25 @@ namespace paddle { namespace framework { namespace details { +static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { + std::queue queue; + queue.push(var_in); + do { + auto *var = queue.front(); + queue.pop(); + for (auto *op : var->PendingOps()) { + auto *compute_op = dynamic_cast(op); + if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) { + return compute_op; + } + for (auto *out_var : op->Outputs()) { + queue.push(out_var); + } + } + } while (!queue.empty()); + return nullptr; +} + std::unique_ptr ReferenceCountPass::ApplyImpl( std::unique_ptr graph) const { auto &ref_cnts = Get(kGlobalReferenceCount); @@ -34,6 +54,9 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( // Step 2: Find all variables in non-computation ops which refers to variables // in computation ops std::unordered_set names; + std::unordered_map> + compute_ref_cnt_map; + auto get_ref_cnts_from_compute_op = [&]( const std::unique_ptr &op, const std::vector &vars) { @@ -54,15 +77,18 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( VarDesc *var_desc = var_handle->Node()->Var(); auto var_name = var_handle->Node()->Name(); - // This is wierd but there is really some variables without var_desc + // This is weird but there is really some variables without var_desc // in computation_op if (var_desc == nullptr) { if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr) continue; } else { - if (var_desc->Persistable() || - var_desc->Proto()->type().type() != proto::VarType::LOD_TENSOR) + if (var_desc->Persistable()) continue; + auto var_type = var_desc->Proto()->type().type(); + if (var_type != proto::VarType::LOD_TENSOR && + var_type != proto::VarType::SELECTED_ROWS) { continue; + } } // compute op only runs in one device @@ -93,12 +119,33 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( if (ref_cnts.count(place.device) && ref_cnts[place.device]->count(var_name)) { ++(*ref_cnts[place.device])[var_name]; + + auto *next_compute_op = FindNextComputationOpHandle(var_handle); + if (next_compute_op != nullptr) { + if (compute_ref_cnt_map.count(next_compute_op)) { + compute_ref_cnt_map[next_compute_op]->AddVar(var_name); + VLOG(5) << "Add reference count of " << var_name << " to Operator " + << next_compute_op->Name(); + } else { + // Create new reference_count_op_handle + ir::Node *ref_cnt_node = graph->CreateEmptyNode( + "reference_count", ir::Node::Type::kOperation); + auto *ref_cnt_handle = new ReferenceCountOpHandle( + ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, + gcs[place.device].get(), cur_ref_cnts[place.device].get()); + if (next_compute_op->Outputs().empty()) { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + next_compute_op->AddOutput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + } + ref_cnt_handle->AddInput(next_compute_op->Outputs().front()); + compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle); + } + } } } }; - std::unordered_map - compute_ref_cnt_map; auto &all_ops = graph->Get(kGraphOps); for (auto &op : all_ops) { auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); @@ -113,11 +160,13 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( auto *ref_cnt_handle = new ReferenceCountOpHandle( ref_cnt_node, compute_op->GetScope(), place, in_var_names, gcs[place.device].get(), cur_ref_cnts[place.device].get()); - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - compute_op->AddOutput(dep_var); - ref_cnt_handle->AddInput(dep_var); - graph->Get(kGraphDepVars).emplace(dep_var); - compute_ref_cnt_map[compute_op] = ref_cnt_handle; + if (compute_op->Outputs().empty()) { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + compute_op->AddOutput(dep_var); + graph->Get(kGraphDepVars).emplace(dep_var); + } + ref_cnt_handle->AddInput(compute_op->Outputs().front()); + compute_ref_cnt_map[compute_op].reset(ref_cnt_handle); } for (auto &op : all_ops) { @@ -131,7 +180,11 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( new_all_ops.emplace_back(std::move(op)); auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); if (it != compute_ref_cnt_map.end()) { - new_all_ops.emplace_back(it->second); + // Add LeafNode to ReferenceCountOpHandle + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dummy_leaf); + it->second->AddOutput(dummy_leaf); + new_all_ops.emplace_back(std::move(it->second)); } } diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h index 5b27068c9e805146b8bce03f4f676ef0d4d16c53..fbab136dbdd4c8c5f1d39cba7aabdc86f4cb8d2c 100644 --- a/paddle/fluid/operators/adam_op.h +++ b/paddle/fluid/operators/adam_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include // for sqrt in CPU and CUDA #include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" @@ -306,26 +307,45 @@ class AdamOpKernel : public framework::OpKernel { VLOG(3) << "grad row size is 0!!"; return; } - // merge duplicated rows if any. - // The rows of grad_merge have been sorted inside MergeAdd functor - scatter::MergeAdd merge_func; - auto& grad_merge = *(ctx.scope() - .NewScope() - .Var("sparse_adam_grad_merge") - ->GetMutable()); - merge_func(ctx.template device_context(), grad, - &grad_merge); + + std::vector cpu_rows(grad.rows().begin(), grad.rows().end()); + bool is_strict_sorted = true; + for (size_t i = 1; i < cpu_rows.size(); ++i) { + if (cpu_rows[i - 1] >= cpu_rows[i]) { + is_strict_sorted = false; + break; + } + } + + const framework::SelectedRows* grad_merge_ptr; + if (is_strict_sorted) { + grad_merge_ptr = &grad; + } else { + // merge duplicated rows if any. + // The rows of grad_merge have been sorted inside MergeAdd functor + scatter::MergeAdd merge_func; + auto* grad_merge_var = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + merge_func(ctx.template device_context(), grad, + grad_merge_var); + grad_merge_ptr = grad_merge_var; + + std::cerr << "Create new variables in adam_op" << std::endl; + } + + auto& grad_merge = *grad_merge_ptr; auto& grad_tensor = grad_merge.value(); const T* grad_data = grad_tensor.template data(); - int64_t* rows = nullptr; -// When compiled without CUDA, the CUDAMutableData() interface should not be + const int64_t* rows = nullptr; +// When compiled without CUDA, the CUDAData() interface should not be // provided. #if defined(PADDLE_WITH_CUDA) if (platform::is_gpu_place(ctx.GetPlace())) { - rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace()); + rows = grad_merge.rows().CUDAData(ctx.GetPlace()); } else { #endif - rows = grad_merge.mutable_rows()->data(); + rows = grad_merge.rows().data(); #if defined(PADDLE_WITH_CUDA) }