//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/rccl.h"
#endif

#include "paddle/fluid/platform/device/gpu/nccl_helper.h"

DECLARE_bool(sync_nccl_allreduce);

namespace paddle {
namespace framework {
namespace details {

class NCCLOpHandleBase : public OpHandleBase {
 public:
  NCCLOpHandleBase(ir::Node* node,
                   const std::vector<platform::Place>& places,
                   const platform::NCCLCommunicator* nccl_ctxs)
      : OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) {
    if (nccl_ctxs == nullptr) {
      return;
    }
    // init device context
    auto default_nccl_ctxs = nccl_ctxs_->DefaultFlatCtx();
    for (auto& p : places_) {
      this->SetDeviceContext(p, default_nccl_ctxs->DevCtx(p));
    }
  }
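  // Release the per-device events created for hierarchical allreduce.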
  virtual ~NCCLOpHandleBase() {
    for (auto& ev : inter_events_) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(ev.second));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(ev.second));
#endif
    }
    for (auto& ev : exter_events_) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(ev.second));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(ev.second));
#endif
    }
  }

  const platform::NCCLCommunicator* GetNcclContext() const {
    return nccl_ctxs_;
  }

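  // Returns the NCCL communicator of the flat ring. Only valid for a
  // single place and when hierarchical allreduce is disabled, as the
  // checks below enforce.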
  ncclComm_t GetComm() const {
    PADDLE_ENFORCE_EQ(
        places_.size(),
        1,
        platform::errors::Unimplemented(
            "Only a single place is supported now, but got %d.",
            places_.size()));
    PADDLE_ENFORCE_EQ(use_hierarchical_allreduce_,
                      0,
                      platform::errors::Unimplemented(
                          "use_hierarchical_allreduce_ is not supported now."));
    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
    int dev_id = places_[0].device;
    auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
    auto comm = nccl_ctx.comm_;
    return comm;
  }

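  // Binds the op handle to the NCCL contexts selected by run_order and
  // chooses between flat and hierarchical allreduce. Hierarchical mode
  // additionally creates the inter/exter events used for cross-ring
  // synchronization.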
  void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
    PADDLE_ENFORCE_GE(
        run_order,
        0,
        platform::errors::InvalidArgument(
            "The argument run_order must be >= 0, but got %d.", run_order));
    run_order_ = run_order;
    use_hierarchical_allreduce_ = use_hierarchical_allreduce;

    VLOG(10) << "SetRunEnv "
             << " run_order:" << run_order
102 103
             << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
             << ", nccl_ctx_:" << nccl_ctxs_;
104 105 106 107 108 109 110 111 112 113 114 115 116

    if (nccl_ctxs_ == nullptr) {
      return;
    }

    if (!use_hierarchical_allreduce_) {
      auto ctxs = nccl_ctxs_->GetFlatCtx(run_order);
      for (auto& p : places_) {
        this->SetDeviceContext(p, ctxs->DevCtx(p));
      }
      return;
    }

    PADDLE_ENFORCE_EQ(places_.size(),
                      1,
                      platform::errors::InvalidArgument(
                          "HierarchicalAllReduce can only run in "
                          "one-process-one-card mode, but got %d cards.",
                          places_.size()));

    for (auto& p : places_) {
      auto ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order);
      this->SetDeviceContext(p, ctxs->DevCtx(p));
    }

    for (auto& p : dev_ctxes_) {
      int dev_id = p.first.device;
      if (inter_events_.find(dev_id) != inter_events_.end()) {
        continue;
      }

      platform::SetDeviceId(dev_id);
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(
          &inter_events_[dev_id], hipEventDisableTiming));
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(
          &exter_events_[dev_id], hipEventDisableTiming));
#else
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(
          &inter_events_[dev_id], cudaEventDisableTiming));
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(
          &exter_events_[dev_id], cudaEventDisableTiming));
#endif
      VLOG(10) << "Create events on dev_id:" << dev_id
               << ", inter_event:" << &inter_events_[dev_id]
               << ", exter_event:" << &exter_events_[dev_id];
    }
  }

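  // Allreduce over the single flat ring selected by run_order_.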
  void FlatNCCLAllReduce(platform::Place place,
                         const void* sendbuff,
                         void* recvbuff,
                         size_t count,
                         ncclDataType_t datatype,
                         ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_,
        0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
    int dev_id = place.device;
    auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place;

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));
  }

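  // Dispatches to the flat or the hierarchical allreduce path according
  // to use_hierarchical_allreduce_.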
  void NCCLAllReduce(platform::Place place,
                     const void* sendbuff,
                     void* recvbuff,
                     size_t count,
                     ncclDataType_t datatype,
                     ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_,
        0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
    if (!use_hierarchical_allreduce_) {
      FlatNCCLAllReduce(place, sendbuff, recvbuff, count, datatype, op);
      return;
    }

    HierarchicalAllReduce(place, sendbuff, recvbuff, count, datatype, op);
  }

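  // Two-level allreduce: first reduce within the inter ring to its root,
  // then allreduce across the exter ring, and finally broadcast the
  // result back within the inter ring.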
  void HierarchicalAllReduce(platform::Place place,
                             const void* sendbuff,
                             void* recvbuff,
                             size_t count,
                             ncclDataType_t datatype,
                             ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_,
        0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
    InterReduce(place, sendbuff, recvbuff, count, datatype, op);
    // Trainers that are not in the exter allreduce ring
    // do not need to call ExterAllReduce.
    if (nccl_ctxs_->NeedExterAllReduce()) {
      ExterAllReduce(place, recvbuff, recvbuff, count, datatype, op);
    }
    InterBroadCast(place, recvbuff, count, datatype, op);
  }

 protected:
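  // Step 1 of hierarchical allreduce: ncclReduce to root 0 of the inter
  // ring; records inter_events_[dev_id] so the exter stage can wait on it.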
  void InterReduce(platform::Place place,
                   const void* sendbuff,
                   void* recvbuff,
                   size_t count,
                   ncclDataType_t datatype,
                   ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = place.device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce"
             << " run_order:" << run_order_ << ", buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

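    // Note: the reduction here is fixed to ncclSum; the `op` argument is
    // not forwarded to ncclReduce.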
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
        sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream));

#ifdef PADDLE_WITH_HIP
    hipEventRecord(inter_events_.at(dev_id), stream);
#else
    cudaEventRecord(inter_events_.at(dev_id), stream);
#endif

    if (FLAGS_sync_nccl_allreduce) {
      platform::GpuStreamSync(stream);
    }
  }

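  // Step 2: allreduce across the exter ring. Waits on inter_events_[dev_id]
  // before launching and records exter_events_[dev_id] when done.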
  void ExterAllReduce(platform::Place place,
                      const void* sendbuff,
                      void* recvbuff,
                      size_t count,
                      ncclDataType_t datatype,
                      ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalExterCtx(run_order_);
    PADDLE_ENFORCE_NOT_NULL(
        nccl_ctxs,
        platform::errors::NotFound("Can't get exter %d nccl contexts.",
                                   run_order_));
    int dev_id = place.device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce run_order:" << run_order_
             << "buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place << ", stream:" << stream;

#ifdef PADDLE_WITH_HIP
    hipStreamWaitEvent(stream, inter_events_.at(dev_id), 0);

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));

    hipEventRecord(exter_events_.at(dev_id), stream);
#else
    cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0);

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));

    cudaEventRecord(exter_events_.at(dev_id), stream);
#endif
    if (FLAGS_sync_nccl_allreduce) {
      platform::GpuStreamSync(stream);
    }
  }

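  // Step 3: after waiting on exter_events_[dev_id], broadcast the reduced
  // buffer from root 0 to the rest of the inter ring.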
  void InterBroadCast(platform::Place place,
                      void* sendbuff,
                      size_t count,
                      ncclDataType_t datatype,
                      ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = place.device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before InterBroadCast buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;
#ifdef PADDLE_WITH_HIP
    hipStreamWaitEvent(stream, exter_events_.at(dev_id), 0);
#else
    cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0);
#endif
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
        sendbuff, count, datatype, 0, comm, stream));
  }

 protected:
  std::vector<platform::Place> places_;
  const platform::NCCLCommunicator* nccl_ctxs_{nullptr};
  // When multiple trainers call collective functions, they must run in the
  // same order, or the program will hang. The allreduce_deps_pass sets this
  // run_order_ accordingly.
  int run_order_{0};
  // Whether to use 2D (hierarchical) allreduce.
  bool use_hierarchical_allreduce_{false};

 private:
  // Per-device events needed by hierarchical allreduce.
  std::unordered_map<int, gpuEvent_t> inter_events_;
  std::unordered_map<int, gpuEvent_t> exter_events_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle