nccl_ops.h
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"

#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename Type>
class NCCLTypeWrapper;

template <>
class NCCLTypeWrapper<float> {
 public:
  static const ncclDataType_t type = ncclFloat;
};

template <>
class NCCLTypeWrapper<double> {
 public:
  static const ncclDataType_t type = ncclDouble;
};

template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<Tensor>("X");
    auto outs = ctx.MultiOutput<Tensor>("Out");
    std::string reduction = ctx.Attr<std::string>("reduction");
    std::vector<int> gpus = ctx.Attr<std::vector<int>>("gpus");
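    // map the string "reduction" attribute onto the corresponding ncclRedOp_t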
    ncclRedOp_t op_type;
    if (reduction == "ncclSum") {
      op_type = ncclSum;
    } else if (reduction == "ncclProd") {
      op_type = ncclProd;
    } else if (reduction == "ncclMin") {
      op_type = ncclMin;
    } else if (reduction == "ncclMax") {
      op_type = ncclMax;
    } else {
      PADDLE_THROW("Invalid reduction attribute: it must be one of "
                   "ncclSum, ncclProd, ncclMin or ncclMax");
    }

    auto& dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(ctx.device_context());

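    // look up the shared communicator for this GPU group via the NCCL manager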
    platform::NCCLManager* m = platform::NCCLManager::Get();

    auto* comm = m->GetCommunicator(gpus);
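    // wg_ appears to act as a wait group: each participating device adds
    // itself here, marks Done() after its work, and waits for the rest below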
    comm->wg_.Add(1);

    // per-device stream taken from the CUDA device context
    cudaStream_t stream = dev_ctx.stream();

    // map this device to its slot in the communicator (assumes a GPU place)
    int gid = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
    int idx = gid % gpus.size();
    comm->streams_[idx] = stream;

    for (size_t i = 0; i < ins.size(); ++i) {
      // run the all-reduce on this device's stream, then record an event so
      // completion can be waited on below (comms_/streams_/events_ are
      // assumed to hold ncclComm_t/cudaStream_t/cudaEvent_t handles)
      PADDLE_ENFORCE(ncclAllReduce(
          ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
          outs[i]->numel(), NCCLTypeWrapper<T>::type, op_type,
          comm->comms_[idx], comm->streams_[idx]));
      PADDLE_ENFORCE(
          cudaEventRecord(comm->events_[idx], comm->streams_[idx]));

      // make the stream wait on the recorded event before moving on
      PADDLE_ENFORCE(
          cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
    }

    comm->wg_.Done();

    comm->wg_.Wait();
  }
};
}  // namespace operators
}  // namespace paddle
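
// Usage sketch (not part of this header; names are assumptions): the kernel
// above would typically be registered for GPU execution in the operator's
// .cu file, roughly like the following.
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
//
// The operator name "ncclAllReduce" is illustrative; the actual name must
// match whatever the corresponding OpMaker registers.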