#pragma once

#include <string>
#include <vector>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"

namespace paddle {
namespace operators {

// Maps the element type T to the matching ncclDataType_t.
template <typename T>
class NCCLTypeWrapper;

template <>
class NCCLTypeWrapper<float> {
 public:
  static const ncclDataType_t type = ncclFloat;
};

template <>
class NCCLTypeWrapper<double> {
 public:
  static const ncclDataType_t type = ncclDouble;
};

template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    std::string reduction = ctx.Attr<std::string>("reduction");
    std::vector<int> gpus = ctx.Attr<std::vector<int>>("gpus");

    // Translate the string attribute into the NCCL reduction op.
    ncclRedOp_t op_type;
    if (reduction == "ncclSum") {
      op_type = ncclSum;
    } else if (reduction == "ncclProd") {
      op_type = ncclProd;
    } else if (reduction == "ncclMin") {
      op_type = ncclMin;
    } else if (reduction == "ncclMax") {
      op_type = ncclMax;
    }

    auto& dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(ctx.device_context());

    NCCLManager* m = NCCLManager::Get();

    auto* comm = m->GetCommunicator(gpus);
    comm->wg_.Add(1);

    cudaStream_t stream = dev_ctx.stream();

    // Map this device id to its slot in the gpu group.
    int gid = ctx.GetPlace().GetDeviceId();
    int idx = gid % gpus.size();
    comm->streams_[idx] = stream;

    for (size_t i = 0; i < ins.size(); ++i) {
      NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(),
                               outs[i]->mutable_data<T>(ctx.GetPlace()),
                               outs[i]->numel(), NCCLTypeWrapper<T>::type,
                               op_type, comm->comms_[idx],
                               comm->streams_[idx]));
      // Record an event on this stream and wait on it so the all-reduce
      // has finished before the kernel returns.
      NCCL_CHECK(cudaEventRecord(comm->events_[idx], comm->streams_[idx]));
      NCCL_CHECK(
          cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
    }

    comm->wg_.Done();
    comm->wg_.Wait();
  }
};

}  // namespace operators
}  // namespace paddle