// nccl_ops.h
// Blame: lines 1-4 committed by dongzhihong.
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"

// Blame: lines 5-6 committed by dzhwinter.
#include <string.h>
#include <stdexcept>
#include <string>
#include <vector>

// Blame: line 7 committed by dongzhihong.
namespace paddle {
D
dzhwinter 已提交
8 9
namespace operators {

D
Dong Zhihong 已提交
10
template <typename Type>
D
dzhwinter 已提交
11 12
class NCCLTypeWrapper;

D
Dong Zhihong 已提交
13
template <>
D
dzhwinter 已提交
14 15 16 17
class NCCLTypeWrapper<float> {
  static const ncclDataType_t type = ncclFloat;
};

D
Dong Zhihong 已提交
18
template <>
D
dzhwinter 已提交
19 20 21 22
class NCCLTypeWrapper<double> {
  static const ncclDataType_t type = ncclDouble;
};

D
Dong Zhihong 已提交
23
template <typename T>
D
dzhwinter 已提交
24
class NCCLAllReduceKernel : public framework::OpKernel {
D
Dong Zhihong 已提交
25
 public:
D
dzhwinter 已提交
26 27 28 29
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<Tensor>("X");
    auto outs = ctx.MultiOutput<Tensor>("Out");
    std::string reduction = ctx.Attr<std::string>("reduction");
D
Dong Zhihong 已提交
30
    std::vector<int> gpus = ctx.Attr<std::vector<int>>("gpus");
D
dzhwinter 已提交
31 32 33 34 35 36 37
    ncclRedOp_t op_type;
    if (reduction == "ncclSum") {
      op_type = ncclSum;
    } else if (reduction == "ncclProd") {
      op_type = ncclProd;
    } else if (reduction == "ncclMin") {
      op_type = ncclMin;
D
Dong Zhihong 已提交
38 39 40 41 42 43 44 45 46 47
    } else
      (reduction == "ncclMax") { op_type = ncclMax; }

    auto dev_ctx =
        static_cast<const platform::CUDADeviceContext>(ctx.device_context());

    NCCLManager* m = NCCLManager::Get();

    auto* comm = m->GetCommunicator(gpus);
    comm->wg_.Add(1);
D
dzhwinter 已提交
48

D
Dong Zhihong 已提交
49
    auto* stream = &dev_ctx.stream();
D
dzhwinter 已提交
50

D
Dong Zhihong 已提交
51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
    // device id
    int gid = ctx.GetPlace().GetDeviceId();
    int idx = gid % gpus.size();
    comm->streams_[idx] = stream;

    for (size_t i = 0; i < ins.size(); ++i) {
      NCCL_CHECK(ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
                               outs[i]->numel() * sizeof(T),
                               NCCLTypeWrapper<T>::type, op_type,
                               &comm->comms_[idx], comm->streams_[idx]));
      NCCL_CHECK(cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));

      // wait finish
      NCCL_CHECK(
          cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
D
dzhwinter 已提交
66 67
    }

D
Dong Zhihong 已提交
68
    comm->wg_.Done();
D
dzhwinter 已提交
69

D
Dong Zhihong 已提交
70 71 72
    wg.Wait();
  }
};
D
dzhwinter 已提交
73
}
D
dongzhihong 已提交
74
}