// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"

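// Default timeout for Task::Wait(). A zero duration is conventionally treated
// as "no timeout" (block until the task completes); the exact semantics are
// backend-defined.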
constexpr auto kWaitTimeout = std::chrono::milliseconds(0);

namespace paddle {
namespace distributed {

constexpr int IGNORE_ID = -1;
using Tensor = paddle::experimental::Tensor;

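// The kind of communication operation a Task performs.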
enum class CommType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_SPARSE = 2,  // TODO(shenliang03): to support sparse in allreduce
  REDUCE = 3,
  ALLGATHER = 4,
  GATHER = 5,
  SCATTER = 6,
  REDUCE_SCATTER = 7,
  ALLTOALL = 8,
  SEND = 9,
  RECV = 10,
  BARRIER = 11,
  UNKNOWN = 100,
};

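// ProcessGroup is the abstract interface of a communication group: a fixed
// set of `size` processes, each identified by its `rank`, registered under a
// global group id `gid`. Concrete backends override the collectives they
// support; every default implementation below simply throws.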
class ProcessGroup {
 public:
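  // Task is the handle returned by every collective below. Callers may poll
  // it (IsCompleted), block on it (Wait/Synchronize), or let a backend chain
  // it onto a device stream (UpdateWaitChain); IsSync() reports whether the
  // op was launched with the sync_op flag set.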
  class Task {
   public:
    Task(int rank, CommType comm_type, bool sync_op);

    virtual ~Task();
    virtual bool IsCompleted();
    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
    virtual void Synchronize();
    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
    bool IsSync() const { return sync_op_; }

    // TODO(sunyilun): methods below will be removed later
    Task(int rank,
         const std::vector<phi::DenseTensor>& inputs,
         CommType comm_type);
    Task(int rank,
         const std::vector<phi::DenseTensor>& inputs,
         CommType comm_type,
         bool sync_op);

   protected:
    const int rank_;
    CommType comm_type_{CommType::UNKNOWN};
    std::mutex mutex_;
    bool is_completed_{false};

   private:
    bool sync_op_{true};
  };

 public:
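  // rank: index of this process within the group, in [0, size).
  // size: number of processes in the group.
  // gid:  global group id under which this group is registered (see
  //       ProcessGroupIdMap / ProcessGroupMapFromGid below).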
  ProcessGroup(int rank, int size, int gid);
  virtual ~ProcessGroup() = default;

  int GetRank() const { return rank_; }

  int GetSize() const { return size_; }

  virtual std::string GetBackendName() const = 0;

  virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support get device_context.",
        GetBackendName()));
  }

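  // Collectives (new-style API): each takes dense out/in tensors plus a
  // sync_op flag selecting synchronous vs. asynchronous completion semantics;
  // AllGather, Send and Recv additionally take (offset, numel) to address a
  // contiguous slice of the tensor. The default bodies throw Unimplemented,
  // so a backend only overrides what it actually supports.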
  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      int64_t offset,
      int64_t numel,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support all_gather with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support all_reduce with sync_op flag.",
        GetBackendName()));
  }

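  // Variable-size all_to_all: out_size_each_rank / in_size_each_rank hold the
  // per-peer sizes used to split in_tensor and reassemble out_tensor (the
  // exact slicing semantics are backend-defined).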
  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const std::vector<int64_t>& out_size_each_rank,
      const std::vector<int64_t>& in_size_each_rank,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support all_to_all with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support barrier.", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support broadcast with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support reduce with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceScatterOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support reduce_scatter with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ScatterOptions& opts,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support scatter with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                                   int src_rank,
                                                   int64_t offset,
                                                   int64_t numel,
                                                   bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support recv with sync_op flag.",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Send(
      const phi::DenseTensor& tensor,
      int dst_rank,
      int64_t offset,
      int64_t numel,
      bool sync_op) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "ProcessGroup%s does not support send with sync_op flag.",
        GetBackendName()));
  }

  // TODO(liyurui): This API will be moved later
  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const AllreduceOptions& = AllreduceOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const AllreduceOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support allreduce with sync_op flag",
        GetBackendName()));
  }

  // TODO(sunyilun): methods below will be removed later
  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const BroadcastOptions& = BroadcastOptions()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& /* input tensors */,   // NOLINT
      std::vector<phi::DenseTensor>& /* output tensors */,  // NOLINT
      const BroadcastOptions&,
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support broadcast with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>&, int) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support send", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>&, int) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support recv", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_gather", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      bool) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support all_gather with sync_op flag",
        GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support AllToAll", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Reduce(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ReduceOptions& opts) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support reduce", GetBackendName()));
  }

  virtual std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>&,  // NOLINT
      std::vector<phi::DenseTensor>&,  // NOLINT
      const ScatterOptions&) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "ProcessGroup%s does not support scatter", GetBackendName()));
  }

 protected:
  int rank_;
  int size_;
  int gid_;
};

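// A minimal sketch of how a backend plugs in (hypothetical names, for
// illustration only):
//
//   class ProcessGroupFake final : public ProcessGroup {
//    public:
//     ProcessGroupFake(int rank, int size, int gid)
//         : ProcessGroup(rank, size, gid) {}
//     std::string GetBackendName() const override { return "FAKE"; }
//   };
//
// Any collective left unoverridden falls through to the base class and
// throws, e.g. Barrier() -> "ProcessGroupFAKE does not support barrier."

// Singleton map from group id (gid) to the ProcessGroup registered under it.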
class ProcessGroupIdMap
    : public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
 public:
  static ProcessGroupIdMap& GetInstance();
};

// TODO(dev): The following method will be removed soon.
class ProcessGroupMapFromGid {
 public:
  bool has(int gid) {
    auto it = map_.find(gid);
    return it != map_.end();
  }

  void insert(int gid, ProcessGroup* pg) {
    // TODO(sandyhouse): address ut and uncomment the following codes
    // PADDLE_ENFORCE_EQ(has(gid), false,
    //                   platform::errors::PreconditionNotMet(
    //                       "The process group with id %d already exists.",
    //                       gid));
    map_[gid] = pg;
  }

  ProcessGroup* get(int gid) {
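    // NOTE: the lookup below is unchecked; calling get() with an
    // unregistered gid dereferences map_.end(), which is undefined behavior
    // until the enforce below is restored.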
    // TODO(sandyhouse): address ut and uncomment the following codes
    // PADDLE_ENFORCE_EQ(has(gid), true,
    //                   platform::errors::PreconditionNotMet(
    //                       "The process group with id %d does not exist.",
    //                       gid));
    return map_.find(gid)->second;
  }

  static std::shared_ptr<ProcessGroupMapFromGid> getInstance() {
    static auto s_instance = std::make_shared<ProcessGroupMapFromGid>();
    return s_instance;
  }

  ProcessGroupMapFromGid() = default;
  ~ProcessGroupMapFromGid() = default;

 private:
  std::unordered_map<int, ProcessGroup*> map_;
};

}  // namespace distributed
}  // namespace paddle