nccl_ops.cc 3.2 KB
Newer Older
D
dzhwinter 已提交
1
#include "paddle/operators/nccl/nccl_ops.h"
D
dongzhihong 已提交
2 3 4 5 6

namespace paddle {
namespace operators {

// AllreduceOp
D
dzhwinter 已提交
7
class NCCLAllReduceOp : public framework::OperatorWithKernel {
D
dongzhihong 已提交
8 9 10 11 12
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // allreduce do nothing in infershape
D
dzhwinter 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            " Input(X) of AllReduce op input should not be NULL");
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
    PADDLE_ENFORCE(ins.size() == outs.size(), "Input(X) and Output(Out) must have same size");
    for(size_t i=0; i < ins.size(); ++i) {
      outs[i]->Resize(ins[i]->dims());
    }
    std::string reduction = ctx.Attr<std::string>("reduction");
    PADDLE_ENFORCE( (reduction == "ncclSum" || reduction == "ncclProd" ||
                     reduction == "ncclMin" || reduction == "ncclMax"), "invalid reduction!");
  }
D
dongzhihong 已提交
26 27 28 29 30 31 32 33 34 35 36 37
};

// Kernel class for the AllReduce op, templated on T.
//
// Compute() currently only casts the execution context's device context to
// NCCLContext* and does nothing with the result.
template <typename T>
class NCCLAllreduceOp : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    // Obtained but not used anywhere in this method.
    auto *nccl_ctx = static_cast<NCCLContext *>(context.device_context());
  }
};

// BcastSendOp
template <typename T>
D
dzhwinter 已提交
38
class NCCLBcastSendOp final : public framework::OperatorWithKernel {
D
dongzhihong 已提交
39 40 41 42
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
D
dzhwinter 已提交
43 44 45 46
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            " Input(X) of BcastSend op input should not be NULL");
  }
D
dongzhihong 已提交
47 48 49 50
};

// BcastRecvOp
template <typename T>
D
dzhwinter 已提交
51
class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
D
dongzhihong 已提交
52 53 54 55
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
D
dzhwinter 已提交
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            " Input(X) of BcastRecv op input should not be NULL");
  }
};


class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
  NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of AllReduce op");
    AddOutput("Out", "The output of AllReduce op");
    AddAttr<std::string>("reduction: {'min', 'max', 'prod', 'sum'}.");
    AddComment(R"DOC(
            AllReduce the input tensors.
        )DOC");
  }
D
dongzhihong 已提交
73
};
D
dzhwinter 已提交
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94

class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
  NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of BcastSend op");
    AddComment(R"DOC(
            BcastSend the tensors.
        )DOC");
  }
};

class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
  NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "The output of BcastRecv op");
    AddComment(R"DOC(
            BcastRecv the tensors.
        )DOC");
  }
};

D
dongzhihong 已提交
95 96
}
}