diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 996273c720a2e821ac37842d066a09f84ad95311..ec5eb579105a46c0c30ce72e3fecd77f711b9f69 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -138,14 +138,6 @@ struct ScaleLossGradOpHandle : public OpHandle {
   }
 };
 
-struct NCCLAllReduceOpHandle : public OpHandle {
-  void Run() override {
-    if (this->inputs_.size() == 1) {
-      return;  // No need to all reduce when GPU count = 1;
-    }
-  }
-};
-
 class ParallelExecutorPrivate {
  public:
   explicit ParallelExecutorPrivate(size_t num_threads = 12)
@@ -243,6 +235,46 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
   }
 }
 
+struct NCCLAllReduceOpHandle : public OpHandle {
+  ParallelExecutorPrivate *member_;
+
+  explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
+      : member_(member) {}
+
+  void Run() override {
+    if (this->inputs_.size() == 1) {
+      return;  // No need to all reduce when GPU count = 1;
+    } else {
+      auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
+
+      int dtype = -1;
+      size_t numel = 0;
+
+      for (auto &p : member_->places_) {
+        int dev_id = boost::get<platform::CUDAPlace>(p).device;
+
+        Scope *s = member_->local_scopes_[p];
+        auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
+        void *buffer = const_cast<void *>(lod_tensor.data<void>());
+        if (dtype == -1) {
+          dtype = ToNCCLDataType(lod_tensor.type());
+        }
+
+        if (numel == 0) {
+          numel = static_cast<size_t>(lod_tensor.numel());
+        }
+
+        auto &nccl_ctx = member_->communication_streams_.at(dev_id);
+
+        ncclAllReduce(buffer, buffer, numel, static_cast<ncclDataType_t>(dtype),
+                      ncclSum, nccl_ctx.comm, nccl_ctx.stream());
+      }
+
+      ncclGroupEnd();
+    }
+  }
+};
+
 ParallelExecutor::ParallelExecutor(
     const std::vector<platform::Place> &places,
     const std::unordered_set<std::string> &params,
@@ -361,7 +393,7 @@ void ParallelExecutor::ConstructDependencyGraph(
     for (auto &og : var_names) {
       if (grads.count(og) != 0) {  // is param grad
        // Insert NCCL AllReduce Op
-        member_->ops_.emplace_back(new NCCLAllReduceOpHandle());
+        member_->ops_.emplace_back(new NCCLAllReduceOpHandle(member_));
         auto *op_handle = member_->ops_.back().get();
 
         for (auto &pair : member_->local_scopes_) {
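
Note on the all-reduce pattern used above: Run() issues one ncclAllReduce per device from a single thread and closes the loop with ncclGroupEnd(); standard NCCL 2 usage pairs that with a ncclGroupStart() before the first per-device call so the calls are fused into one collective rather than blocking each other. Below is a minimal standalone sketch of that grouped pattern, outside the diff and not Paddle code: DeviceCtx and AllReduceGradients are hypothetical names, only the NCCL/CUDA calls are the real API.

// Minimal sketch of a grouped in-place sum all-reduce driven by one thread.
// DeviceCtx and AllReduceGradients are illustrative names, not Paddle APIs.
#include <cuda_runtime.h>
#include <nccl.h>

#include <vector>

// One participating GPU: its NCCL communicator, CUDA stream, and the
// device buffer holding the gradient to be reduced in place.
struct DeviceCtx {
  ncclComm_t comm;
  cudaStream_t stream;
  float *grad;   // device pointer
  size_t numel;  // element count, identical on every device
};

// Sum-all-reduce the gradient across all devices. ncclGroupStart() /
// ncclGroupEnd() bracket the loop so the per-device calls form a single
// collective instead of deadlocking one another.
void AllReduceGradients(const std::vector<DeviceCtx> &devices) {
  ncclGroupStart();
  for (const auto &d : devices) {
    // sendbuff == recvbuff: reduce in place, mirroring how the handle
    // above reuses `buffer` for both input and output.
    ncclAllReduce(d.grad, d.grad, d.numel, ncclFloat, ncclSum, d.comm,
                  d.stream);
  }
  ncclGroupEnd();
}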