//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/nccl_helper.h"

DECLARE_bool(sync_nccl_allreduce);

namespace paddle {
namespace framework {
namespace details {

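// Base class for op handles that launch NCCL collectives. It keeps the
// mapping from places to NCCL device contexts and supports two modes:
// a flat allreduce over one ring of all devices, and a hierarchical (2D)
// allreduce that reduces within each node, allreduces across nodes, and
// broadcasts the result back within each node.
//
// Rough usage sketch (the concrete derived handle is illustrative):
//   NCCLOpHandleBase* op = /* e.g. an allreduce op handle */;
//   op->SetRunEnv(/*run_order=*/0, /*use_hierarchical_allreduce=*/true);
//   op->NCCLAllReduce(place, sendbuff, recvbuff, numel, ncclFloat, ncclSum);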
class NCCLOpHandleBase : public OpHandleBase {
 public:
  NCCLOpHandleBase(ir::Node* node, const std::vector<platform::Place>& places,
                   const platform::NCCLCommunicator* nccl_ctxs)
      : OpHandleBase(node), places_(places), nccl_ctxs_(nccl_ctxs) {
    if (nccl_ctxs == nullptr) {
      return;
    }
    // Init device contexts with the default flat NCCL contexts.
    auto default_nccl_ctxs = nccl_ctxs_->DefaultFlatCtx();
    for (auto& p : places_) {
      this->SetDeviceContext(p, default_nccl_ctxs->DevCtx(p));
    }
  }
  virtual ~NCCLOpHandleBase() {
    for (auto& ev : inter_events_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
    }
    for (auto& ev : exter_events_) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
    }
  }
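  // Bind this handle to the NCCL contexts selected by run_order: the flat
  // contexts when hierarchical allreduce is disabled, otherwise the
  // hierarchical inter contexts. In hierarchical mode this also creates,
  // once per device, the inter/exter CUDA events that order the stages.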
  void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
    PADDLE_ENFORCE_GE(
        run_order, 0,
        platform::errors::InvalidArgument(
            "The argument run_order must be >= 0, but got %d.", run_order));
    run_order_ = run_order;
    use_hierarchical_allreduce_ = use_hierarchical_allreduce;

    VLOG(10) << "SetRunEnv "
             << " run_order:" << run_order
             << ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
             << ", nccl_ctx_:" << nccl_ctxs_;

    if (nccl_ctxs_ == nullptr) {
      return;
    }

    if (!use_hierarchical_allreduce_) {
      auto ctxs = nccl_ctxs_->GetFlatCtx(run_order);
      for (auto& p : places_) {
        this->SetDeviceContext(p, ctxs->DevCtx(p));
      }
      return;
    }

    PADDLE_ENFORCE_EQ(places_.size(), 1,
                      platform::errors::InvalidArgument(
                          "HierarchicalAllReduce can only run in "
                          "one-process-one-card mode, but got %d cards.",
                          places_.size()));

    for (auto& p : places_) {
      auto ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order);
      this->SetDeviceContext(p, ctxs->DevCtx(p));
    }

    for (auto& p : dev_ctxes_) {
      int dev_id = BOOST_GET_CONST(platform::CUDAPlace, p.first).device;
      if (inter_events_.find(dev_id) != inter_events_.end()) {
        continue;
      }

      PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
          &inter_events_[dev_id], cudaEventDisableTiming));
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
          &exter_events_[dev_id], cudaEventDisableTiming));
      VLOG(10) << "Create events on dev_id:" << dev_id
               << ", inter_event:" << &inter_events_[dev_id]
               << ", exter_event:" << &exter_events_[dev_id];
    }
  }

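  // Allreduce on the flat ring selected by run_order_.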
  void FlatNCCLAllReduce(platform::Place place, const void* sendbuff,
                         void* recvbuff, size_t count, ncclDataType_t datatype,
                         ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_, 0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
    int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
    auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place;

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));
  }

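  // Dispatch to the flat or hierarchical implementation according to the
  // environment set by SetRunEnv().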
  void NCCLAllReduce(platform::Place place, const void* sendbuff,
                     void* recvbuff, size_t count, ncclDataType_t datatype,
                     ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_, 0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
    if (!use_hierarchical_allreduce_) {
      FlatNCCLAllReduce(place, sendbuff, recvbuff, count, datatype, op);
      return;
    }

    HierarchicalAllReduce(place, sendbuff, recvbuff, count, datatype, op);
  }

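  // 2D allreduce in three stages: reduce within the node onto a local
  // root, allreduce across nodes among the local roots, then broadcast
  // the result back to every device in the node.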
  void HierarchicalAllReduce(platform::Place place, const void* sendbuff,
                             void* recvbuff, size_t count,
                             ncclDataType_t datatype, ncclRedOp_t op) {
    PADDLE_ENFORCE_GE(
        run_order_, 0,
        platform::errors::InvalidArgument(
            "The argument run_order_ must be >= 0, but got %d.", run_order_));
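    // Step 1: reduce within the node onto the local root.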
    InterReduce(place, sendbuff, recvbuff, count, datatype, op);
    // Step 2: allreduce across nodes. A trainer that is not in the exter
    // allreduce ring does not need to call this.
    if (nccl_ctxs_->NeedExterAllReduce()) {
      ExterAllReduce(place, recvbuff, recvbuff, count, datatype, op);
    }
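    // Step 3: broadcast the result from the local root within the node.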
    InterBroadCast(place, recvbuff, count, datatype, op);
  }

 protected:
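  // Intra-node stage: ncclReduce onto local root 0. Records
  // inter_events_[dev_id] so ExterAllReduce can wait on it.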
  void InterReduce(platform::Place place, const void* sendbuff, void* recvbuff,
                   size_t count, ncclDataType_t datatype, ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce"
             << " run_order:" << run_order_ << ", buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

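    // The inter stage always reduces with ncclSum onto root 0; the op
    // argument is not used here.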
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduce(
        sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream));

    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaEventRecord(inter_events_.at(dev_id), stream));

    if (FLAGS_sync_nccl_allreduce) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
    }
  }

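  // Inter-node stage: allreduce among the local roots on the exter ring.
  // Waits on inter_events_[dev_id] and records exter_events_[dev_id].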
  void ExterAllReduce(platform::Place place, const void* sendbuff,
                      void* recvbuff, size_t count, ncclDataType_t datatype,
                      ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalExterCtx(run_order_);
    PADDLE_ENFORCE_NOT_NULL(
        nccl_ctxs, platform::errors::NotFound(
                       "Can't get exter %d nccl contexts.", run_order_));
    int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before all reduce run_order:" << run_order_
             << "buffer:" << sendbuff << ", numel:" << count
             << ", dev_id:" << dev_id << ", dtype:" << datatype
             << ", place:" << place << ", stream:" << stream;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0));

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, datatype, op, comm, stream));

    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaEventRecord(exter_events_.at(dev_id), stream));

    if (FLAGS_sync_nccl_allreduce) {
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
    }
  }

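  // Final stage: broadcast from local root 0 once the exter allreduce
  // (exter_events_[dev_id]) has finished.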
  void InterBroadCast(platform::Place place, void* sendbuff, size_t count,
                      ncclDataType_t datatype, ncclRedOp_t op) {
    auto nccl_ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order_);
    int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
    auto& nccl_ctx = nccl_ctxs->at(dev_id);
    auto stream = nccl_ctx.stream();
    auto comm = nccl_ctx.comm_;

    VLOG(10) << "before InterBroadCast buffer:" << sendbuff
             << ", numel:" << count << ", dev_id:" << dev_id
             << ", dtype:" << datatype << ", place:" << place
             << ", stream:" << stream;

    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0));
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
        sendbuff, count, datatype, 0, comm, stream));
  }

 protected:
  std::vector<platform::Place> places_;
  const platform::NCCLCommunicator* nccl_ctxs_{nullptr};
  // When multiple trainers call a collective function, they must run it in
  // the same order, or the program will hang. So we use allreduce_deps_pass
  // to set this run_order_.
  int run_order_{0};
  // Whether to use 2D (hierarchical) allreduce.
  bool use_hierarchical_allreduce_{false};

 private:
  // Events needed by hierarchical allreduce.
  std::unordered_map<int, cudaEvent_t> inter_events_;
  std::unordered_map<int, cudaEvent_t> exter_events_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle