From 8f0590e7c5924e9281a957cf0d355176c4bed301 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 16 Mar 2018 16:31:58 +0800 Subject: [PATCH] Add ncclAllReduce --- paddle/fluid/framework/parallel_executor.cc | 50 +++++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 996273c720..ec5eb57910 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -138,14 +138,6 @@ struct ScaleLossGradOpHandle : public OpHandle { } }; -struct NCCLAllReduceOpHandle : public OpHandle { - void Run() override { - if (this->inputs_.size() == 1) { - return; // No need to all reduce when GPU count = 1; - } - } -}; - class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(size_t num_threads = 12) @@ -243,6 +235,46 @@ ncclDataType_t ToNCCLDataType(std::type_index type) { } } +struct NCCLAllReduceOpHandle : public OpHandle { + ParallelExecutorPrivate *member_; + + explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) + : member_(member) {} + + void Run() override { + if (this->inputs_.size() == 1) { + return; // No need to all reduce when GPU count = 1; + } else { + auto &var_name = static_cast(this->inputs_[0])->name_; + + int dtype = -1; + size_t numel = 0; + + for (auto &p : member_->places_) { + int dev_id = boost::get(p).device; + + Scope *s = member_->local_scopes_[p]; + auto &lod_tensor = s->FindVar(var_name)->Get(); + void *buffer = const_cast(lod_tensor.data()); + if (dtype == -1) { + dtype = ToNCCLDataType(lod_tensor.type()); + } + + if (numel == 0) { + numel = static_cast(lod_tensor.numel()); + } + + auto &nccl_ctx = member_->communication_streams_.at(dev_id); + + ncclAllReduce(buffer, buffer, numel, static_cast(dtype), + ncclSum, nccl_ctx.comm, nccl_ctx.stream()); + } + + ncclGroupEnd(); + } + } +}; + ParallelExecutor::ParallelExecutor( const std::vector &places, const std::unordered_set ¶ms, @@ -361,7 +393,7 @@ void ParallelExecutor::ConstructDependencyGraph( for (auto &og : var_names) { if (grads.count(og) != 0) { // is param grad // Insert NCCL AllReduce Op - member_->ops_.emplace_back(new NCCLAllReduceOpHandle()); + member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_)); auto *op_handle = member_->ops_.back().get(); for (auto &pair : member_->local_scopes_) { -- GitLab