nccl_op_handle.h 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
24
#ifdef PADDLE_WITH_CUDA
25
#include "paddle/fluid/platform/dynload/nccl.h"
26 27 28 29
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/rccl.h"
#endif
30
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
31 32 33 34 35 36 37 38 39

DECLARE_bool(sync_nccl_allreduce);

namespace paddle {
namespace framework {
namespace details {

class NCCLOpHandleBase : public OpHandleBase {
 public:
40 41
  // Builds a handle for `node` that operates on `places`. When a communicator
  // is supplied, each place is bound to the device context taken from the
  // communicator's default flat context; with a null communicator no device
  // context is set here.
  NCCLOpHandleBase(ir::Node* node,
                   const std::vector<platform::Place>& places,
                   const platform::NCCLCommunicator* nccl_ctxs)
      : OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) {
    if (nccl_ctxs != nullptr) {
      // init device context
      auto flat_ctxs = nccl_ctxs_->DefaultFlatCtx();
      for (auto& place : places_) {
        this->SetDeviceContext(place, flat_ctxs->DevCtx(place));
      }
    }
  }
  // Destroys every GPU event stored in inter_events_ and then exter_events_
  // (the events SetRunEnv creates for hierarchical allreduce).
  // NOTE(review): PADDLE_ENFORCE_GPU_SUCCESS is used inside a destructor; if
  // it reports failure by throwing, that would escape the destructor - confirm.
  virtual ~NCCLOpHandleBase() {
    auto destroy_events = [](std::unordered_map<int, gpuEvent_t>& events) {
      for (auto& entry : events) {
#ifdef PADDLE_WITH_HIP
        PADDLE_ENFORCE_GPU_SUCCESS(hipEventDestroy(entry.second));
#else
        PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(entry.second));
#endif
      }
    };
    destroy_events(inter_events_);
    destroy_events(exter_events_);
  }
69 70 71 72 73

  // Returns the communicator passed at construction (may be nullptr).
  const platform::NCCLCommunicator* GetNcclContext() const {
    return this->nccl_ctxs_;
  }

74
  // Returns the NCCL communicator for the single place owned by this handle,
  // looked up in the flat context selected by run_order_.
  // Enforces: exactly one place, hierarchical allreduce disabled, and a
  // non-null communicator set.
  ncclComm_t GetComm() const {
    PADDLE_ENFORCE_EQ(
        places_.size(),
        1,
        platform::errors::Unimplemented(
            "Only supported for single place now, but got %d", places_.size()));
    PADDLE_ENFORCE_EQ(use_hierarchical_allreduce_,
                      0,
                      platform::errors::Unimplemented(
                          "Not supported use_hierarchical_allreduce_ now"));
    PADDLE_ENFORCE_NOT_NULL(
        nccl_ctxs_,
        platform::errors::NotFound("Can't get flat %d nccl contexts.",
                                   run_order_));

    auto flat_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
    const int device_index = places_[0].device;
    return flat_ctxs->at(device_index).comm_;
  }

95
  void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
96
    PADDLE_ENFORCE_GE(
97 98
        run_order,
        0,
99 100
        platform::errors::InvalidArgument(
            "The argument run_order must be >= 0, but got %d.", run_order));
101 102 103 104 105
    run_order_ = run_order;
    use_hierarchical_allreduce_ = use_hierarchical_allreduce;

    VLOG(10) << "SetRunEnv "
             << " run_order:" << run_order
106 107
             << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
             << ", nccl_ctx_:" << nccl_ctxs_;
108 109 110 111 112 113 114 115 116 117 118 119 120

    if (nccl_ctxs_ == nullptr) {
      return;
    }

    if (!use_hierarchical_allreduce_) {
      auto ctxs = nccl_ctxs_->GetFlatCtx(run_order);
      for (auto& p : places_) {
        this->SetDeviceContext(p, ctxs->DevCtx(p));
      }
      return;
    }

121 122
    PADDLE_ENFORCE_EQ(places_.size(),
                      1,
123 124 125 126
                      platform::errors::InvalidArgument(
                          "HierarchicalAllReduce can only run "
                          "one proccess with one card mode, but got %d cards.",
                          places_.size()));
127 128 129 130 131 132 133

    for (auto& p : places_) {
      auto ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order);
      this->SetDeviceContext(p, ctxs->DevCtx(p));
    }

    for (auto& p : dev_ctxes_) {
134
      int dev_id = p.first.device;
135 136 137 138
      if (inter_events_.find(dev_id) != inter_events_.end()) {
        continue;
      }

L
Leo Chen 已提交
139
      platform::SetDeviceId(dev_id);
140
#ifdef PADDLE_WITH_HIP
141
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(
142
          &inter_events_[dev_id], hipEventDisableTiming));
143
      PADDLE_ENFORCE_GPU_SUCCESS(hipEventCreateWithFlags(
144 145
          &exter_events_[dev_id], hipEventDisableTiming));
#else
146
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(
147
          &inter_events_[dev_id], cudaEventDisableTiming));
148
      PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreateWithFlags(
149
          &exter_events_[dev_id], cudaEventDisableTiming));
150
#endif
151 152 153 154 155 156
      VLOG(10) << "Create events on dev_id:" << dev_id
               << ", inter_event:" << &inter_events_[dev_id]
               << ", exter_event:" << &exter_events_[dev_id];
    }
  }

157 158 159 160 161
  void FlatNCCLAllReduce(platform::Place place,
                         const void* sendbuff,
                         void* recvbuff,
                         size_t count,
                         ncclDataType_t datatype,
162
                         ncclRedOp_t op) {
163
    PADDLE_ENFORCE_GE(
164 165
        run_order_,
        0,
166 167
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
168
    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
169
    int dev_id = place.device;
170 171 172 173 174 175 176 177
    auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place;

178
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
179 180 181
        sendbuff, recvbuff, count, datatype, op, comm, stream));
  }

182 183 184 185 186
  void NCCLAllReduce(platform::Place place,
                     const void* sendbuff,
                     void* recvbuff,
                     size_t count,
                     ncclDataType_t datatype,
187
                     ncclRedOp_t op) {
188
    PADDLE_ENFORCE_GE(
189 190
        run_order_,
        0,
191 192
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
193 194 195 196 197 198 199 200
    if (!use_hierarchical_allreduce_) {
      FlatNCCLAllReduce(place, sendbuff, recvbuff, count, datatype, op);
      return;
    }

    HierarchicalAllReduce(place, sendbuff, recvbuff, count, datatype, op);
  }

201 202 203 204 205 206
  // Runs the three hierarchical stages in order: InterReduce into recvbuff,
  // an in-place ExterAllReduce on recvbuff (only when the communicator
  // reports NeedExterAllReduce()), and InterBroadCast of recvbuff.
  // Enforces run_order_ >= 0.
  void HierarchicalAllReduce(platform::Place place,
                             const void* sendbuff,
                             void* recvbuff,
                             size_t count,
                             ncclDataType_t datatype,
                             ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_,
        0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));

    InterReduce(place, sendbuff, recvbuff, count, datatype, op);

    // A trainer that is not part of the exter allreduce ring skips this stage.
    const bool in_exter_ring = nccl_ctxs_->NeedExterAllReduce();
    if (in_exter_ring) {
      ExterAllReduce(place, recvbuff, recvbuff, count, datatype, op);
    }

    InterBroadCast(place, recvbuff, count, datatype, op);
  }

 protected:
222 223 224 225 226 227
  // First stage of hierarchical allreduce: reduces sendbuff into recvbuff on
  // root rank 0 of the hierarchical inter context selected by run_order_,
  // then records inter_events_[dev_id] on the same stream so a later stage
  // can wait on it.
  //   place: its `device` selects the per-device context and event.
  //   sendbuff / recvbuff / count / datatype: forwarded to ncclReduce.
  //   op: reduction operator. It is forwarded to ncclReduce so that this
  //       stage agrees with ExterAllReduce, which also reduces with `op`.
  // When FLAGS_sync_nccl_allreduce is set, platform::GpuStreamSync is called
  // on the stream before returning.
  void InterReduce(platform::Place place,
                   const void* sendbuff,
                   void* recvbuff,
                   size_t count,
                   ncclDataType_t datatype,
                   ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = place.device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce"
             << " run_order:" << run_order_ << ", buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(
        sendbuff, recvbuff, count, datatype, op, 0, comm, stream));

    // A failed record would leave the next stage waiting on a stale event, so
    // the return code is checked like every other GPU call in this class.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(
        hipEventRecord(inter_events_.at(dev_id), stream));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaEventRecord(inter_events_.at(dev_id), stream));
#endif

    if (FLAGS_sync_nccl_allreduce) {
      platform::GpuStreamSync(stream);
    }
  }

254 255 256 257 258
  void ExterAllReduce(platform::Place place,
                      const void* sendbuff,
                      void* recvbuff,
                      size_t count,
                      ncclDataType_t datatype,
259 260
                      ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalExterCtx(run_order_);
261
    PADDLE_ENFORCE_NOT_NULL(
262 263 264
        nccl_ctxs_,
        platform::errors::NotFound("Can't get exter %d nccl contexts.",
                                   run_order_));
265
    int dev_id = place.device;
266 267 268 269 270 271 272 273 274
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce run_order:" << run_order_
             << "buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place << ", stream:" << stream;

275 276 277
#ifdef PADDLE_WITH_HIP
    hipStreamWaitEvent(stream, inter_events_.at(dev_id), 0);

278
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
279 280 281 282
        sendbuff, recvbuff, count, datatype, op, comm, stream));

    hipEventRecord(exter_events_.at(dev_id), stream);
#else
283 284
    cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0);

285
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
286 287 288
        sendbuff, recvbuff, count, datatype, op, comm, stream));

    cudaEventRecord(exter_events_.at(dev_id), stream);
289
#endif
290
    if (FLAGS_sync_nccl_allreduce) {
291
      platform::GpuStreamSync(stream);
292 293 294
    }
  }

295 296 297 298 299
  // Final stage of hierarchical allreduce: on the hierarchical inter context
  // selected by run_order_, waits for exter_events_[dev_id] and then
  // broadcasts `sendbuff` in place from root rank 0 with ncclBcast.
  //   place: its `device` selects the per-device context and event.
  //   sendbuff / count / datatype: forwarded to ncclBcast.
  //   op: unused here (a broadcast takes no reduction operator); kept so the
  //       signature parallels the other stages.
  void InterBroadCast(platform::Place place,
                      void* sendbuff,
                      size_t count,
                      ncclDataType_t datatype,
                      ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = place.device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before InterBroadCast buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

    // The broadcast must not start before the exter stage's event; a failed
    // wait would silently drop that ordering, so the return code is enforced.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(
        hipStreamWaitEvent(stream, exter_events_.at(dev_id), 0));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0));
#endif
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
        sendbuff, count, datatype, 0, comm, stream));
  }

 protected:
  std::vector<platform::Place> places_;
321
  const platform::NCCLCommunicator* nccl_ctxs_{nullptr};
322 323 324 325 326 327 328 329 330
  // When multiple trainers call collective functions, they must run them in
  // the same order, or the program will hang. allreduce_deps_pass is used to
  // set this run_order_.
  int run_order_{0};
  // Whether to use 2D (hierarchical) allreduce.
  bool use_hierarchical_allreduce_{false};

 private:
  // hierarchical needed events
331 332
  // Keyed by device id. Recorded by InterReduce and waited on by
  // ExterAllReduce.
  std::unordered_map<int, gpuEvent_t> inter_events_;
  // Keyed by device id. Recorded by ExterAllReduce and waited on by
  // InterBroadCast.
  std::unordered_map<int, gpuEvent_t> exter_events_;
333 334 335 336 337
};

}  // namespace details
}  // namespace framework
}  // namespace paddle