// nccl_ops.cc -- committed by dongzhihong
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"

namespace paddle {
namespace operators {

// AllreduceOp
class NCCLAllreduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // allreduce do nothing in infershape
  void InferShape(const framework::InferShapeContext &ctx) const override {}
};

template <typename T>
class NCCLAllreduceOp : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *ctx = static_cast<NCCLContext *>(context.device_context());
    // auto *comm = ;
    // auto *src = ;
    // ncclAllReduce(src, dest, )
  }
};

// BcastSendOp
template <typename T>
class NCCLBroadcastSendOp final : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {}
};

// BcastRecvOp
template <typename T>
class NCCLBroadcastRecvOp final : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {}
};
}
}