// ProcessGroup.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"

// Default `timeout` argument of ProcessGroup::Task::Wait: zero milliseconds.
// NOTE(review): how a zero timeout is interpreted (e.g. "no deadline") is up
// to the Task::Wait implementations, which are not visible here — confirm.
constexpr std::chrono::milliseconds kWaitTimeout{0};

namespace paddle {
namespace distributed {

L
lilong12 已提交
33
constexpr int IGNORE_ID = -1;
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
using Tensor = paddle::experimental::Tensor;

// Kind of collective operation; ProcessGroup::Task records it in
// `comm_type_` (defaulting to UNKNOWN). Values are explicit so their numeric
// encoding is stable; the underlying type is one byte.
enum class CommType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_SPARSE = 2,  // TODO(shenliang03): to support sparse in allreduce
  REDUCE = 3,
  ALLGATHER = 4,
  GATHER = 5,
  SCATTER = 6,
  REDUCE_SCATTER = 7,
  ALLTOALL = 8,
  SEND = 9,
  RECV = 10,
  BARRIER = 11,
  ALLTOALL_SINGLE = 12,
  UNKNOWN = 100,
};

class ProcessGroup {
 public:
  class Task {
   public:
57
    Task(int rank, CommType comm_type, bool sync_op);
58 59 60 61 62

    virtual ~Task();
    virtual bool IsCompleted();
    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
    virtual void Synchronize();
63
    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
64
    bool IsSync() const { return sync_op_; }
65

66 67 68 69 70 71 72 73 74
    // TODO(sunyilun): methods below will be removed later
    Task(int rank,
         const std::vector<phi::DenseTensor>& inputs,
         CommType comm_type);
    Task(int rank,
         const std::vector<phi::DenseTensor>& inputs,
         CommType comm_type,
         bool sync_op);

75 76
   protected:
    const int rank_;
77
    CommType comm_type_{CommType::UNKNOWN};
78
    std::mutex mutex_;
79 80 81 82
    bool is_completed_{false};

   private:
    bool sync_op_{true};
83 84
  };

85
 public:
86 87 88
  explicit ProcessGroup(int rank,
                        int size,
                        const platform::Place& place,
89
                        int gid);
W
wuhuachaocoding 已提交
90 91 92

  explicit ProcessGroup(int rank, int size, int gid);

93 94 95 96 97 98
  virtual ~ProcessGroup() {}

  int GetRank() const { return rank_; }

  int GetSize() const { return size_; }

L
LiYuRio 已提交
99
  virtual std::string GetBackendName() const = 0;
100

101
  virtual const phi::DeviceContext& GetDeviceContext(const Place& place) const {
L
LiYuRio 已提交
102 103 104 105
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Does not support to get device_context from ProcessGroup%s.",
        GetBackendName()));
  }
106

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      bool sync_op) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_gather with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_reduce with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support barrier", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast with sync_op flag",
        GetBackendName()));
  }

142
  // TODO(liyurui): This API will be moved later
143
  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
144 145
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
146 147 148 149 150
      const AllreduceOptions& = AllreduceOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce", GetBackendName()));
  }

151 152 153 154 155 156 157 158 159 160
  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const AllreduceOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce with sync_op flag",
        GetBackendName()));
  }

161
  // TODO(sunyilun): methods below will be removed later
162
  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
163 164
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
165 166
      const BroadcastOptions& = BroadcastOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
B
Baibaifan 已提交
167 168 169
        "ProcessGroup%s does not support broadcast", GetBackendName()));
  }

170 171 172 173 174 175 176 177 178 179
  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const BroadcastOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast with sync_op flag",
        GetBackendName()));
  }

B
Baibaifan 已提交
180
  virtual std::shared_ptr<ProcessGroup::Task> Send(
181
      std::vector<phi::DenseTensor>&, int) {  // NOLINT
B
Baibaifan 已提交
182 183 184 185
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support send", GetBackendName()));
  }

186 187 188 189 190 191 192
  virtual std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>&, int, bool) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support send with sync_op flag",
        GetBackendName()));
  }

B
Baibaifan 已提交
193
  virtual std::shared_ptr<ProcessGroup::Task> Recv(
194
      std::vector<phi::DenseTensor>&, int) {  // NOLINT
B
Baibaifan 已提交
195
    PADDLE_THROW(platform::errors::InvalidArgument(
196
        "ProcessGroup%s does not support recv", GetBackendName()));
197 198
  }

199 200
  virtual std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>&, int, bool) {  // NOLINT
201
    PADDLE_THROW(platform::errors::InvalidArgument(
202 203 204 205 206 207 208
        "ProcessGroup%s does not support recv with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
      phi::DenseTensor&,  // NOLINT
      int,
209 210
      int64_t,
      int64_t) {
211 212 213 214 215
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support send_partial", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(
216
      phi::DenseTensor&, int, int64_t, int64_t, bool) {  // NOLINT
217 218 219
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support send_partial with sync_op flag",
        GetBackendName()));
220 221 222
  }

  virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
223 224
      phi::DenseTensor&,  // NOLINT
      int,
225 226
      int64_t,
      int64_t) {
227
    PADDLE_THROW(platform::errors::InvalidArgument(
228 229 230 231
        "ProcessGroup%s does not support recv_partial", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
232
      phi::DenseTensor&, int, int64_t, int64_t, bool) {  // NOLINT
233 234 235
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support recv_partial with sync_op flag",
        GetBackendName()));
236 237
  }

238
  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
239 240
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
241
    PADDLE_THROW(platform::errors::InvalidArgument(
242 243 244 245 246 247 248 249 250 251
        "ProcessGroup%s does not support all_gather", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_gather with sync_op flag",
        GetBackendName()));
252 253
  }

254 255 256
  virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
257 258
      int64_t offset,
      int64_t length) {
259 260 261 262 263 264 265
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
      std::vector<phi::DenseTensor>& in_tensors,   // NOLINT
      std::vector<phi::DenseTensor>& out_tensors,  // NOLINT
266 267 268
      int64_t offset,
      int64_t length,
      bool) {
269 270 271 272
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support AllGather_Partial", GetBackendName()));
  }

273
  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
274 275
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
276 277 278 279
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support AllToAll", GetBackendName()));
  }

280 281 282 283 284 285 286 287
  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support alltoall", GetBackendName()));
  }

288 289 290 291 292 293 294 295 296
  virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<int64_t>&,
      std::vector<int64_t>&) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
  }

297 298 299 300 301 302 303 304 305 306
  virtual std::shared_ptr<ProcessGroup::Task> AllToAllSingle(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<int64_t>&,
      std::vector<int64_t>&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support alltoall_single", GetBackendName()));
  }

307
  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
308 309 310
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ReduceOptions& opts) {
311
    PADDLE_THROW(platform::errors::InvalidArgument(
312 313 314 315 316 317 318 319 320 321 322
        "ProcessGroup%s does not support reduce", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const ReduceOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support reduce with sync_op flag",
        GetBackendName()));
323 324 325
  }

  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
326 327
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
328
      const ScatterOptions&) {
329
    PADDLE_THROW(platform::errors::InvalidArgument(
330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350
        "ProcessGroup%s does not support scatter", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ScatterOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support scatter with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ReduceScatterOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support reduce_scatter with sync_op flag",
        GetBackendName()));
351 352
  }

353 354 355
 protected:
  const int rank_;
  const int size_;
356
  const platform::Place place_;
357
  const int gid_;
358 359
};

L
lilong12 已提交
360 361 362 363 364 365 366 367
class ProcessGroupMapFromGid {
 public:
  bool has(int gid) {
    auto it = map_.find(gid);
    return it != map_.end();
  }

  void insert(int gid, ProcessGroup* pg) {
368
    // TODO(sandyhouse): address ut and uncomment the following codes
369
    // PADDLE_ENFORCE_EQ(has(gid), false,
370 371 372
    //                   platform::errors::PreconditionNotMet(
    //                       "The process group with id %d doesnot exist.",
    //                       gid));
L
lilong12 已提交
373 374 375 376
    map_[gid] = pg;
  }

  ProcessGroup* get(int gid) {
377
    // TODO(sandyhouse): address ut and uncomment the following codes
378
    // PADDLE_ENFORCE_EQ(has(gid), true,
379 380 381
    //                   platform::errors::PreconditionNotMet(
    //                       "The process group with id %d doesnot exist.",
    //                       gid));
L
lilong12 已提交
382 383 384 385 386 387 388 389 390 391 392 393 394 395 396
    return map_.find(gid)->second;
  }

  static std::shared_ptr<ProcessGroupMapFromGid> getInstance() {
    static auto s_instance = std::make_shared<ProcessGroupMapFromGid>();
    return s_instance;
  }

  ProcessGroupMapFromGid() = default;
  ~ProcessGroupMapFromGid() = default;

 private:
  std::unordered_map<int, ProcessGroup*> map_;
};

397 398
}  //  namespace distributed
}  //  namespace paddle