/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"

#include <string>
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

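// Traits class mapping a C++ element type to the matching ncclDataType_t
// enumerator expected by the NCCL API.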
template <typename Type>
class NCCLTypeWrapper;

template <>
class NCCLTypeWrapper<float> {
 public:
  static const ncclDataType_t type = ncclFloat;
};

template <>
class NCCLTypeWrapper<double> {
 public:
  static const ncclDataType_t type = ncclDouble;
};

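// Kernel that all-reduces every input tensor in "X" into the corresponding
// output tensor in "Out" across the devices listed in the "gpus" attribute,
// using the reduction operator named by the "reduction" attribute.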
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
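    // Inputs "X" and outputs "Out" are parallel lists of tensors; "reduction"
    // selects the NCCL reduction operator and "gpus" lists the participating
    // device ids.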
    auto ins = ctx.MultiInput<Tensor>("X");
    auto outs = ctx.MultiOutput<Tensor>("Out");
    std::string reduction = ctx.Attr<std::string>("reduction");
    std::vector<int> gpus = ctx.Attr<std::vector<int>>("gpus");
    ncclRedOp_t op_type;
    if (reduction == "ncclSum") {
      op_type = ncclSum;
    } else if (reduction == "ncclProd") {
      op_type = ncclProd;
    } else if (reduction == "ncclMin") {
      op_type = ncclMin;
    } else if (reduction == "ncclMax") {
      op_type = ncclMax;
    } else {
      PADDLE_ENFORCE(false, "Unknown reduction attribute");
    }

    auto& dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(ctx.device_context());

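    // The singleton NCCL manager hands out the communicator group for this
    // set of GPUs; the group keeps per-device comms, streams and events, and
    // its wait group (wg_) counts the devices that have entered the kernel.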
    platform::NCCLManager* m = platform::NCCLManager::Get();

    auto* comm = m->GetCommunicator(gpus);
    comm->wg_.Add(1);

    auto stream = dev_ctx.stream();

    // Map the physical device id to this device's slot in the communicator
    // group, and hand NCCL the stream to issue work on.
    int gid = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
    int idx = gid % gpus.size();
    comm->streams_[idx] = stream;

    for (size_t i = 0; i < ins.size(); ++i) {
      // Launch the asynchronous all-reduce for this tensor on this device's
      // stream, then record an event marking its completion.
      PADDLE_ENFORCE(ncclAllReduce(
          ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
          outs[i]->numel(), NCCLTypeWrapper<T>::type, op_type,
          comm->comms_[idx], comm->streams_[idx]));
      PADDLE_ENFORCE(cudaEventRecord(comm->events_[idx], comm->streams_[idx]));

      // Make later work queued on this stream wait until the recorded
      // all-reduce event has completed.
      PADDLE_ENFORCE(
          cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
    }

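    // Mark this device's submissions as finished and block until every other
    // device in the group has done the same.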
    comm->wg_.Done();

    comm->wg_.Wait();
  }
};

}  // namespace operators
}  // namespace paddle
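
// A minimal registration sketch (an assumption for illustration: the op name
// "ncclAllReduce" and the use of REGISTER_OP_GPU_KERNEL from op_registry.h;
// the actual registration would live in the corresponding .cc/.cu file):
//
//   REGISTER_OP_GPU_KERNEL(ncclAllReduce,
//                          paddle::operators::NCCLAllReduceKernel<float>);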