// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/store/store.h"

namespace paddle {
namespace distributed {

34
using Place = phi::Place;
35

36
class ProcessGroupNCCL final : public ProcessGroupWithStream {
37
 public:
38
  class NCCLTask final : public ProcessGroupWithStream::TaskStream,
39
                         public std::enable_shared_from_this<NCCLTask> {
40
   public:
41 42 43 44 45 46 47 48 49 50 51 52
    NCCLTask(const Place& place,
             int rank,
             CommType comm_type,
             bool sync_op,
             bool use_calc_stream);
    virtual ~NCCLTask();

    bool IsCompleted() override;
    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
    void Synchronize() override;
    void UpdateWaitChain(const phi::DeviceContext& ctx) override;

W
Wen Sun 已提交
53 54 55
    bool IsBlockCPUInWait() const { return block_cpu_in_wait_; }
    void SetBlockCPUInWait() { block_cpu_in_wait_ = true; }

56
    // TODO(sunyilun): methods below will be removed later
57 58 59
    NCCLTask(const std::vector<Place>& places,
             int rank,
             CommType CommType,
60
             const std::vector<phi::DenseTensor>& inputs);
61

62
   private:
W
Wen Sun 已提交
63 64 65
    bool block_cpu_in_wait_{false};
    platform::DeviceEvent comm_event_;  // event on comm stream
    Place task_place_;
66 67
  };

68
 public:
L
LiYuRio 已提交
69
  static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
70 71 72 73
      const std::shared_ptr<phi::distributed::Store>& store,
      int rank,
      int size,
      int gid);
L
LiYuRio 已提交
74

75
  ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
76 77 78
                   int rank,
                   int size,
                   int gid);
79

L
LiYuRio 已提交
80
  std::string GetBackendName() const override { return "NCCL"; }
81

82 83
  phi::DeviceContext* GetDeviceContext(const Place& place) const override;

84 85
  phi::DeviceContext* GetDeviceContext(const Place& place,
                                       bool use_calc_stream) const override;
L
LiYuRio 已提交
86

87 88 89
  std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
90 91
      int64_t offset,
      int64_t numel,
92 93 94
      bool sync_op,
      bool use_calc_stream) override;

95
  std::shared_ptr<ProcessGroup::Task> AllReduce(
96 97 98 99 100 101
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

102 103 104 105 106 107 108 109
  std::shared_ptr<ProcessGroup::Task> AllToAll(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const std::vector<int64_t>& out_size_each_rank,
      const std::vector<int64_t>& in_size_each_rank,
      bool sync_op,
      bool use_calc_stream) override;

110 111 112 113 114 115 116
  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
117 118 119
      bool sync_op,
      bool use_calc_stream) override;

120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const ReduceOptions& opts,
                                             bool sync_op,
                                             bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceScatterOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
                                              const phi::DenseTensor& in_tensor,
                                              const ScatterOptions& opts,
                                              bool sync_op,
                                              bool use_calc_stream) override;

139 140
  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                           int src_rank,
141 142
                                           int64_t offset,
                                           int64_t numel,
143 144 145
                                           bool sync_op,
                                           bool use_calc_stream) override;

146
  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
147
                                           int dst_rank,
148 149
                                           int64_t offset,
                                           int64_t numel,
150 151 152
                                           bool sync_op,
                                           bool use_calc_stream) override;

153 154 155 156 157 158
  static void GroupStart();

  static void GroupEnd();

  ncclComm_t NCCLComm(const Place& place) const;

159
  // TODO(liyurui): This API will be moved later
160
  std::shared_ptr<ProcessGroup::Task> AllReduce(
161 162
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
163 164
      const AllreduceOptions& = AllreduceOptions()) override;

165
  // TODO(sunyilun): methods below will be removed later
166
  std::shared_ptr<ProcessGroup::Task> Broadcast(
167 168
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
169 170
      const BroadcastOptions& = BroadcastOptions()) override;

171 172
  std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>& tensors, int dst_rank) override;
B
Baibaifan 已提交
173

174 175
  std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>& tensors, int src_rank) override;
B
Baibaifan 已提交
176

177
  std::shared_ptr<ProcessGroup::Task> AllGather(
178 179
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;
180 181

  std::shared_ptr<ProcessGroup::Task> AllToAll(
182 183 184
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

185
  std::shared_ptr<ProcessGroup::Task> Reduce(
186 187 188
      std::vector<phi::DenseTensor>& tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ReduceOptions& opts) override;
189

190 191 192
  std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
193 194
      const ScatterOptions& opts) override;

195 196 197 198 199 200
 private:
  std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
                                                         int rank,
                                                         CommType op_type,
                                                         bool sync_op,
                                                         bool use_calc_stream);
201

202
  void BroadcastUniqueNCCLID(ncclUniqueId* nccl_id);
203

204 205
  void CreateNCCLEnvCache(const Place& place, const std::string& place_key);

206 207
  void SyncCalcStream(const Place& place);

208 209 210
  std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
      std::function<void(ncclComm_t, gpuStream_t)> fn,
      const phi::DenseTensor& tensor,
211 212 213
      CommType comm_type,
      bool sync_op,
      bool use_calc_stream);
L
LiYuRio 已提交
214

215 216
  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
217 218
      std::vector<Place> places,
      int rank,
219
      CommType op_type,
220
      const std::vector<phi::DenseTensor>& inputs);
221 222 223

  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> Collective(
224 225
      std::vector<phi::DenseTensor>& inputs,   // NOLINT
      std::vector<phi::DenseTensor>& outputs,  // NOLINT
226 227
      Fn fn,
      CommType op_type);
228

B
Baibaifan 已提交
229 230
  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> PointToPoint(
231
      std::vector<phi::DenseTensor>& tensors,  // NOLINT
232 233 234
      Fn fn,
      int dst_rank,
      CommType op_type);
B
Baibaifan 已提交
235

236 237
  void CreateNCCLManagerCache(const std::string& places_key,
                              const std::vector<Place>& places);
238

239
 private:
240
  std::shared_ptr<phi::distributed::Store> store_;
241

W
Wen Sun 已提交
242 243
  std::unordered_map<std::string, platform::DeviceEvent>
      place_to_calc_event_;  // event on calc stream
244 245 246 247 248 249 250
  std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
  std::unordered_map<std::string, std::unique_ptr<phi::GPUContext>>
      place_to_comm_ctx_;

  // TODO(sunyilun): attrs below will be removed later
  std::mutex mutex_;
  std::unordered_map<std::string, std::vector<phi::GPUContext*>> places_to_ctx_;
251 252 253 254
};

}  //  namespace distributed
}  //  namespace paddle