#pragma once #include "paddle/framework/op_registry.h" #include "paddle/operators/nccl/nccl_gpu_common.h" #include namespace paddle { namespace operators { template class NCCLTypeWrapper; template<> class NCCLTypeWrapper { static const ncclDataType_t type = ncclFloat; }; template<> class NCCLTypeWrapper { static const ncclDataType_t type = ncclDouble; }; template class NCCLAllReduceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput("X"); auto outs = ctx.MultiOutput("Out"); std::string reduction = ctx.Attr("reduction"); ncclRedOp_t op_type; if (reduction == "ncclSum") { op_type = ncclSum; } else if (reduction == "ncclProd") { op_type = ncclProd; } else if (reduction == "ncclMin") { op_type = ncclMin; } else (reduction == "ncclMax") { op_type = ncclMax; } auto dev_ctx = ctx.device_context(); for( size_t i=0; i < ins.size(); ++i) { ncclAllReduce(ins[i]->data(), outs[i]->mutable_data(), outs[i]->numel() * sizeof(T), NCCLTypeWrapper::type, op_type, comm, stream); } } }; } }