// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"

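// Default timeout for ProcessGroup::Task::Wait(); backends treat a value of
// zero as "wait until the operation completes".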
constexpr auto kWaitTimeout = std::chrono::milliseconds(0);

namespace paddle {
namespace distributed {

constexpr int IGNORE_ID = -1;
using Tensor = paddle::experimental::Tensor;

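// Kind of communication primitive; each Task records the CommType that
// created it.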
enum class CommType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_SPARSE = 2,  // TODO(shenliang03): to support sparse in allreduce
  REDUCE = 3,
  ALLGATHER = 4,
  GATHER = 5,
  SCATTER = 6,
  REDUCE_SCATTER = 7,
  ALLTOALL = 8,
  SEND = 9,
  RECV = 10,
  BARRIER = 11,
  UNKNOWN = 100,
};

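// ProcessGroup is the abstract interface that concrete communication
// backends (e.g. the NCCL- or Gloo-based implementations) derive from.
// Every communication call returns a ProcessGroup::Task handle that tracks
// completion of the operation.
//
// A minimal usage sketch (illustrative only; `pg` is assumed to be an
// instance of some concrete backend and `out`/`in` are caller-owned
// dense tensors):
//
//   std::shared_ptr<ProcessGroup> pg = /* create a concrete backend */;
//   phi::DenseTensor out, in;
//   auto task = pg->AllGather(&out, in, /*sync_op=*/true);
//   task->Wait();  // block until the collective has finished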
class ProcessGroup {
 public:
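  // Task represents a single in-flight communication operation. Callers may
  // poll IsCompleted(), block on Wait(), or call Synchronize() to order the
  // operation with subsequent work on the device.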
  class Task {
   public:
    Task(int rank, CommType comm_type, bool sync_op);

    virtual ~Task();
    virtual bool IsCompleted();
    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
    virtual void Synchronize();
    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
    bool IsSync() const { return sync_op_; }

    // TODO(sunyilun): methods below will be removed later
    Task(int rank,
         const std::vector<phi::DenseTensor>& inputs,
         CommType comm_type);
    Task(int rank,
         const std::vector<phi::DenseTensor>& inputs,
         CommType comm_type,
         bool sync_op);

   protected:
    const int rank_;
    CommType comm_type_{CommType::UNKNOWN};
    std::mutex mutex_;
    bool is_completed_{false};

   private:
    bool sync_op_{true};
  };

 public:
  ProcessGroup(int rank, int size, int gid);
  virtual ~ProcessGroup() = default;

  int GetRank() const { return rank_; }

  int GetSize() const { return size_; }

  virtual std::string GetBackendName() const = 0;

  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support get_device_context.",
        GetBackendName()));
  }

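  // Communication primitives. Unless overridden by a concrete backend, each
  // method below throws Unimplemented. `sync_op` selects between a blocking
  // call and an asynchronous one whose completion is tracked through the
  // returned Task.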
  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      bool sync_op) {
    return AllGather(out_tensor,
                     in_tensor,
                     /*offset*/ 0,
                     /*numel*/ -1,  // -1 indicates the whole tensor
                     sync_op);
  }

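  // Gathers `numel` elements of `in_tensor`, starting at `offset`, from every
  // rank into `out_tensor`; numel == -1 means the whole tensor.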
  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      int64_t offset,
      int64_t numel,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support all_gather with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support all_reduce with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const std::vector<int64_t>& out_size_each_rank,
      const std::vector<int64_t>& in_size_each_rank,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support all_to_all with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support barrier.", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support broadcast with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support reduce with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceScatterOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support reduce_scatter with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ScatterOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support scatter with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                                   int src_rank,
                                                   bool sync_op) {
    return Recv(tensor,
                src_rank,
                /*offset*/ 0,
                /*numel*/ -1,  // -1 indicates the whole tensor
                sync_op);
  }

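  // Receives `numel` elements from `src_rank` into `tensor`, starting at
  // `offset`; numel == -1 receives the whole tensor.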
  virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                                   int src_rank,
                                                   int64_t offset,
                                                   int64_t numel,
                                                   bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support recv with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Send(
      const phi::DenseTensor& tensor, int dst_rank, bool sync_op) {
    return Send(tensor,
                dst_rank,
                /*offset*/ 0,
                /*numel*/ -1,  // -1 indicates the whole tensor
                sync_op);
  }

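  // Sends `numel` elements of `tensor`, starting at `offset`, to `dst_rank`;
  // numel == -1 sends the whole tensor.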
  virtual std::shared_ptr<ProcessGroup::Task> Send(
      const phi::DenseTensor& tensor,
      int dst_rank,
      int64_t offset,
      int64_t numel,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support send with sync_op flag.",
        GetBackendName()));
  }

  // TODO(liyurui): This API will be moved later
  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const AllreduceOptions& = AllreduceOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const AllreduceOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce with sync_op flag",
        GetBackendName()));
  }

  // TODO(sunyilun): methods below will be removed later
  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const BroadcastOptions& = BroadcastOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const BroadcastOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>&, int) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support send", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>&, int) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support recv", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_gather", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_gather with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support AllToAll", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ReduceOptions& opts) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support reduce", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ScatterOptions&) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support scatter", GetBackendName()));
  }

 protected:
  int rank_;
  int size_;
  int gid_;
};

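// Singleton registry that maps a group id (gid) to the ProcessGroup instance
// created for it.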
class ProcessGroupIdMap
    : public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
 public:
  static ProcessGroupIdMap& GetInstance();
};

// TODO(dev): The following method will be removed soon.
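// Legacy registry from a group id to a raw ProcessGroup pointer.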
class ProcessGroupMapFromGid {
 public:
  bool has(int gid) {
    auto it = map_.find(gid);
    return it != map_.end();
  }

  void insert(int gid, ProcessGroup* pg) {
    // TODO(sandyhouse): address ut and uncomment the following codes
    // PADDLE_ENFORCE_EQ(has(gid), false,
    //                   platform::errors::PreconditionNotMet(
    //                       "The process group with id %d already exists.",
    //                       gid));
    map_[gid] = pg;
  }

  ProcessGroup* get(int gid) {
    // TODO(sandyhouse): address ut and uncomment the following codes
    // PADDLE_ENFORCE_EQ(has(gid), true,
    //                   platform::errors::PreconditionNotMet(
    //                       "The process group with id %d does not exist.",
    //                       gid));
    return map_.find(gid)->second;
  }

  static std::shared_ptr<ProcessGroupMapFromGid> getInstance() {
    static auto s_instance = std::make_shared<ProcessGroupMapFromGid>();
    return s_instance;
  }

  ProcessGroupMapFromGid() = default;
  ~ProcessGroupMapFromGid() = default;

 private:
  std::unordered_map<int, ProcessGroup*> map_;
};

}  //  namespace distributed
}  //  namespace paddle