nccl_ops.cc 3.1 KB
Newer Older
D
dzhwinter 已提交
1
#include "paddle/operators/nccl/nccl_ops.h"
D
dongzhihong 已提交
2 3 4 5 6

namespace paddle {
namespace operators {

// AllreduceOp
D
dzhwinter 已提交
7
class NCCLAllReduceOp : public framework::OperatorWithKernel {
D
dongzhihong 已提交
8 9 10 11 12
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // allreduce do nothing in infershape
D
dzhwinter 已提交
13
  void InferShape(const framework::InferShapeContext &ctx) const override {
D
Dong Zhihong 已提交
14 15 16
    PADDLE_ENFORCE_NOT_NULL(
        ctx.InputVar("X"),
        " Input(X) of AllReduce op input should not be NULL");
D
dzhwinter 已提交
17 18
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
D
Dong Zhihong 已提交
19 20 21
    PADDLE_ENFORCE(ins.size() == outs.size(),
                   "Input(X) and Output(Out) must have same size");
    for (size_t i = 0; i < ins.size(); ++i) {
D
dzhwinter 已提交
22 23 24
      outs[i]->Resize(ins[i]->dims());
    }
    std::string reduction = ctx.Attr<std::string>("reduction");
D
Dong Zhihong 已提交
25 26 27
    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
                    reduction == "ncclMin" || reduction == "ncclMax"),
                   "invalid reduction!");
D
dongzhihong 已提交
28 29 30 31 32
  }
};

// BcastSendOp
template <typename T>
D
dzhwinter 已提交
33
class NCCLBcastSendOp final : public framework::OperatorWithKernel {
D
dongzhihong 已提交
34 35 36 37
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
D
dzhwinter 已提交
38
  void InferShape(const framework::InferShapeContext &ctx) const override {
D
Dong Zhihong 已提交
39 40 41
    PADDLE_ENFORCE_NOT_NULL(
        ctx.InputVar("X"),
        " Input(X) of BcastSend op input should not be NULL");
D
dzhwinter 已提交
42
  }
D
dongzhihong 已提交
43 44 45 46
};

// BcastRecvOp
template <typename T>
D
dzhwinter 已提交
47
class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
D
dongzhihong 已提交
48 49 50 51
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
D
dzhwinter 已提交
52
  void InferShape(const framework::InferShapeContext &ctx) const override {
D
Dong Zhihong 已提交
53 54 55
    PADDLE_ENFORCE_NOT_NULL(
        ctx.OutputVar("Out"),
        " Input(X) of BcastRecv op input should not be NULL");
D
dzhwinter 已提交
56 57 58 59
  }
};

class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
D
Dong Zhihong 已提交
60 61 62
  NCCLAllReduceOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
D
dzhwinter 已提交
63 64
    AddInput("X", "The input of AllReduce op");
    AddOutput("Out", "The output of AllReduce op");
D
Dong Zhihong 已提交
65 66 67
    AddAttr<std::string>("reduction",
                         "{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
    AddAttr<std::vector<int>>("gpus", "gpu id lists");
D
dzhwinter 已提交
68 69 70 71
    AddComment(R"DOC(
            AllReduce the input tensors.
        )DOC");
  }
D
dongzhihong 已提交
72
};
D
dzhwinter 已提交
73 74

class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
D
Dong Zhihong 已提交
75 76 77
  NCCLAllReduceOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
D
dzhwinter 已提交
78 79 80 81 82 83 84 85
    AddInput("X", "The input of BcastSend op");
    AddComment(R"DOC(
            BcastSend the tensors.
        )DOC");
  }
};

class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
D
Dong Zhihong 已提交
86 87 88
  NCCLAllReduceOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
D
dzhwinter 已提交
89 90 91 92 93 94 95
    AddOutput("Out", "The output of BcastRecv op");
    AddComment(R"DOC(
            BcastRecv the tensors.
        )DOC");
  }
};

D
Dong Zhihong 已提交
96 97
}  // operators
}  // paddle