#include "paddle/operators/nccl/nccl_ops.h"

#include <string>
#include <vector>

namespace paddle {
namespace operators {

// NOTE(review): the explicit template arguments used in this file
// (`<framework::Tensor>`, `<std::string>`, `<std::vector<int>>` and the
// `template <typename T>` headers) were restored from context after being
// stripped from the text; confirm them against nccl_ops.h and the kernels.

// Operator definition for NCCL all-reduce.
class NCCLAllReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Validates the operator and propagates shapes:
  //  - Input "X" must be present.
  //  - "X" and "Out" must contain the same number of entries.
  //  - Every Out[i] is resized to the dims of X[i].
  //  - Attribute "reduction" must be exactly one of
  //    "ncclSum", "ncclProd", "ncclMin" or "ncclMax" (case-sensitive).
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(
        ctx.InputVar("X"),
        " Input(X) of AllReduce op input should not be NULL");

    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    PADDLE_ENFORCE(ins.size() == outs.size(),
                   "Input(X) and Output(Out) must have same size");

    // Each output takes the shape of the input at the same position.
    for (size_t i = 0; i < ins.size(); ++i) {
      outs[i]->Resize(ins[i]->dims());
    }

    std::string reduction = ctx.Attr<std::string>("reduction");
    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
                    reduction == "ncclMin" || reduction == "ncclMax"),
                   "invalid reduction!");
  }
};

// Operator definition for the sending side of an NCCL broadcast.
template <typename T>
class NCCLBcastSendOp final : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Only checks that input "X" is present; no shapes are set here.
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(
        ctx.InputVar("X"),
        " Input(X) of BcastSend op input should not be NULL");
  }
};

// Operator definition for the receiving side of an NCCL broadcast.
template <typename T>
class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Only checks that output "Out" is present; no shapes are set here.
  void InferShape(const framework::InferShapeContext &ctx) const override {
    // The message names the variable that is actually checked (Out), not X.
    PADDLE_ENFORCE_NOT_NULL(
        ctx.OutputVar("Out"),
        " Output(Out) of BcastRecv op output should not be NULL");
  }
};

// Declares the inputs, outputs and attributes of the AllReduce operator.
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Must be public so the maker can be constructed from outside the class.
  NCCLAllReduceOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of AllReduce op");
    AddOutput("Out", "The output of AllReduce op");
    // Spelling matches the values accepted by NCCLAllReduceOp::InferShape.
    AddAttr<std::string>("reduction",
                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.");
    AddAttr<std::vector<int>>("gpus", "gpu id lists");
    AddComment(R"DOC(
AllReduce the input tensors. )DOC");
  }
};

// Declares the input of the BcastSend operator.
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLBcastSendOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of BcastSend op");
    AddComment(R"DOC( BcastSend the tensors. )DOC");
  }
};

// Declares the output of the BcastRecv operator.
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLBcastRecvOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "The output of BcastRecv op");
    AddComment(R"DOC( BcastRecv the tensors. )DOC");
  }
};

}  // namespace operators
}  // namespace paddle