// process_group_nccl.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/collective/process_group_stream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/nccl_tools.h"
#endif

#ifdef PADDLE_WITH_RCCL
#include "paddle/phi/backends/dynload/rccl.h"
#elif defined(PADDLE_WITH_NCCL)
#include "paddle/phi/backends/dynload/nccl.h"
#endif

namespace paddle {
namespace distributed {

using Place = paddle::platform::Place;

46
class ProcessGroupNCCL final : public ProcessGroupStream {
47
 public:
48 49
  class NCCLTask final : public ProcessGroupStream::TaskStream,
                         public std::enable_shared_from_this<NCCLTask> {
50
   public:
51 52 53 54 55 56 57 58 59 60 61 62
    NCCLTask(const Place& place,
             int rank,
             CommType comm_type,
             bool sync_op,
             bool use_calc_stream);
    virtual ~NCCLTask();

    bool IsCompleted() override;
    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
    void Synchronize() override;
    void UpdateWaitChain(const phi::DeviceContext& ctx) override;

W
Wen Sun 已提交
63 64 65
    bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
    void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }

66
    // TODO(sunyilun): methods below will be removed later
67 68 69
    NCCLTask(const std::vector<Place>& places,
             int rank,
             CommType CommType,
70
             const std::vector<phi::DenseTensor>& inputs);
71

72
   private:
W
Wen Sun 已提交
73 74 75
    bool block_cpu_in_wait_{false};
    platform::DeviceEvent comm_event_;  // event on comm stream
    Place task_place_;
76 77
  };

78
 public:
L
LiYuRio 已提交
79 80 81
  static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
      const std::shared_ptr<Store>& store, int rank, int size, int gid);

82 83 84 85
  ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                   int rank,
                   int size,
                   int gid);
86

L
LiYuRio 已提交
87
  std::string GetBackendName() const override { return "NCCL"; }
88

89 90
  phi::DeviceContext* GetDeviceContext(const Place& place,
                                       bool use_calc_stream) const override;
L
LiYuRio 已提交
91

92
  phi::DeviceContext* GetDeviceContext(const Place& place) const override;
93

94 95 96
  std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
97 98
      int64_t offset,
      int64_t numel,
99 100 101
      bool sync_op,
      bool use_calc_stream) override;

102
  std::shared_ptr<ProcessGroup::Task> AllReduce(
103 104 105 106 107 108
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

109 110 111 112 113 114 115 116
  std::shared_ptr<ProcessGroup::Task> AllToAll(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const std::vector<int64_t>& out_size_each_rank,
      const std::vector<int64_t>& in_size_each_rank,
      bool sync_op,
      bool use_calc_stream) override;

117 118 119 120 121 122 123
  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
124 125 126
      bool sync_op,
      bool use_calc_stream) override;

127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const ReduceOptions& opts,
                                             bool sync_op,
                                             bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceScatterOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
                                              const phi::DenseTensor& in_tensor,
                                              const ScatterOptions& opts,
                                              bool sync_op,
                                              bool use_calc_stream) override;

146 147
  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                           int src_rank,
148 149
                                           int64_t offset,
                                           int64_t numel,
150 151 152
                                           bool sync_op,
                                           bool use_calc_stream) override;

153
  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
154
                                           int dst_rank,
155 156
                                           int64_t offset,
                                           int64_t numel,
157 158 159
                                           bool sync_op,
                                           bool use_calc_stream) override;

160 161 162 163 164 165
  static void GroupStart();

  static void GroupEnd();

  ncclComm_t NCCLComm(const Place& place) const;

166
  // TODO(liyurui): This API will be moved later
167
  std::shared_ptr<ProcessGroup::Task> AllReduce(
168 169
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
170 171
      const AllreduceOptions& = AllreduceOptions()) override;

172
  // TODO(sunyilun): methods below will be removed later
173
  std::shared_ptr<ProcessGroup::Task> Broadcast(
174 175
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
176 177
      const BroadcastOptions& = BroadcastOptions()) override;

178 179
  std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
B
Baibaifan 已提交
180

181 182
  std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>& tensors, int src_rank) override;
B
Baibaifan 已提交
183

184
  std::shared_ptr<ProcessGroup::Task> AllGather(
185 186
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;
187 188

  std::shared_ptr<ProcessGroup::Task> AllToAll(
189 190 191
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

192
  std::shared_ptr<ProcessGroup::Task> Reduce(
193 194 195
      std::vector<phi::DenseTensor>& tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ReduceOptions& opts) override;
196

197 198 199
  std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
200 201
      const ScatterOptions& opts) override;

202 203 204 205 206 207
 private:
  std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
                                                         int rank,
                                                         CommType op_type,
                                                         bool sync_op,
                                                         bool use_calc_stream);
208

209
  void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
210

211 212
  void CreateNCCLEnvCache(const Place& place, const std::string& place_key);

213 214
  void SyncCalcStream(const Place& place);

215 216 217
  std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
      std::function<void(ncclComm_t, gpuStream_t)> fn,
      const phi::DenseTensor& tensor,
218 219 220
      CommType comm_type,
      bool sync_op,
      bool use_calc_stream);
L
LiYuRio 已提交
221

222 223
  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
224 225
      std::vector<Place> places,
      int rank,
226
      CommType op_type,
227
      const std::vector<phi::DenseTensor>& inputs);
228 229 230

  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> Collective(
231 232
      std::vector<phi::DenseTensor>& inputs,   // NOLINT
      std::vector<phi::DenseTensor>& outputs,  // NOLINT
233 234
      Fn fn,
      CommType op_type);
235

B
Baibaifan 已提交
236 237
  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> PointToPoint(
238
      std::vector<phi::DenseTensor>& tensors,  // NOLINT
239 240 241
      Fn fn,
      int dst_rank,
      CommType op_type);
B
Baibaifan 已提交
242

243 244
  void CreateNCCLManagerCache(const std::string& places_key,
                              const std::vector<Place>& places);
245

246 247
 private:
  std::shared_ptr<Store> store_;
248

W
Wen Sun 已提交
249 250
  std::unordered_map<std::string, platform::DeviceEvent>
      place_to_calc_event_;  // event on calc stream
251 252 253 254 255 256 257
  std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
  std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
      place_to_comm_ctx_;

  // TODO(sunyilun): attrs below will be removed later
  std::mutex mutex_;
  std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
258 259 260 261
};

}  //  namespace distributed
}  //  namespace paddle