#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"

#include <string>

#include <nccl.h>  // ncclDataType_t, ncclRedOp_t, ncclAllReduce

// Assumed available at these paths in this tree (pre-fluid layout).
#include "paddle/platform/device_context.h"  // platform::CUDADeviceContext
#include "paddle/platform/enforce.h"         // PADDLE_THROW

namespace paddle {
namespace operators {


// Maps a C++ element type to the matching NCCL data type enum.
template <typename T>
class NCCLTypeWrapper;

template <>
class NCCLTypeWrapper<float> {
 public:  // class members are private by default; `type` must be reachable
  static constexpr ncclDataType_t type = ncclFloat;
};

template <>
class NCCLTypeWrapper<double> {
 public:
  static constexpr ncclDataType_t type = ncclDouble;
};
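// NCCL also defines ncclInt, ncclChar, ncclHalf, etc. (see nccl.h); if this
// op is extended to more element types, further specializations can be added
// the same way. A hypothetical example (not part of the original code):
//
//   template <>
//   class NCCLTypeWrapper<int> {
//    public:
//     static constexpr ncclDataType_t type = ncclInt;
//   };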



template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::Tensor>("X");
    auto outs = ctx.MultiOutput<framework::Tensor>("Out");

    // Translate the "reduction" attribute into the NCCL reduction op.
    std::string reduction = ctx.Attr<std::string>("reduction");
    ncclRedOp_t op_type;
    if (reduction == "ncclSum") {
      op_type = ncclSum;
    } else if (reduction == "ncclProd") {
      op_type = ncclProd;
    } else if (reduction == "ncclMin") {
      op_type = ncclMin;
    } else if (reduction == "ncclMax") {
      op_type = ncclMax;
    } else {
      PADDLE_THROW("Unknown reduction attribute: %s", reduction);
    }

    // Assumption: this kernel runs on GPU, so the device context is a
    // platform::CUDADeviceContext and supplies the stream for the call.
    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      ctx.device_context())
                      .stream();

    for (size_t i = 0; i < ins.size(); ++i) {
      // ncclAllReduce takes the element count (not the byte size), and
      // mutable_data needs the place to allocate the output on.
      ncclAllReduce(ins[i]->data<T>(),
                    outs[i]->mutable_data<T>(ctx.GetPlace()),
                    outs[i]->numel(),
                    NCCLTypeWrapper<T>::type,
                    op_type,
                    // TODO: `comm` (ncclComm_t) is not set up in this
                    // snippet yet; it is expected to come from the NCCL
                    // helpers in nccl_gpu_common.h.
                    comm,
                    stream);
    }
  }
};
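// For reference, the collective used above has this signature in NCCL 2:
//
//   ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff,
//                              size_t count, ncclDataType_t datatype,
//                              ncclRedOp_t op, ncclComm_t comm,
//                              cudaStream_t stream);
//
// `count` is the number of elements, which is why the kernel passes
// numel() rather than numel() * sizeof(T).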


}  // namespace operators
}  // namespace paddle