// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <future>
#include <memory>
#include <mutex>

#include "paddle/fluid/distributed/collective/ProcessGroup.h"

#ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_store.h"

namespace paddle {
namespace distributed {

class ProcessGroupGloo : public ProcessGroup {
 public:
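  // Task handle for gloo collectives. Operations complete synchronously,
  // which is why Wait() and IsCompleted() below report success
  // unconditionally and Synchronize() is a no-op.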
  class GlooTask : public ProcessGroup::Task,
                   public std::enable_shared_from_this<GlooTask> {
   public:
    explicit GlooTask(int rank,
                      const std::vector<phi::DenseTensor>& input_tensors,
                      CommType comm_type);

    ~GlooTask() = default;

    virtual void Run() = 0;
    bool Wait(std::chrono::milliseconds timeout) override { return true; }
    bool IsCompleted() override { return true; }
    void Synchronize() override {}

   protected:
    friend class ProcessGroupGloo;
  };

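  // Adapter exposing paddle::distributed::Store through the
  // ::gloo::rendezvous::Store interface, so gloo rendezvous can reuse the
  // same key/value store (e.g. a TCPStore) as the rest of the process group.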
  class GlooStore : public ::gloo::rendezvous::Store {
   public:
    explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
        : _store(store) {}

    ~GlooStore() = default;

    std::vector<char> get(const std::string& key) override {
      VLOG(3) << "GlooStore::get";
      auto value = _store->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    void wait(const std::vector<std::string>& keys) override {
      VLOG(3) << "GlooStore::wait";
      for (auto& key : keys) {
        _store->wait(key);
      }
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      VLOG(3) << "GlooStore::set";
      std::vector<uint8_t> tmp(value.begin(), value.end());
      _store->set(key, tmp);
    }

    void wait(const std::vector<std::string>& keys,
              const std::chrono::milliseconds& timeout) override {
      // NOTE: the timeout argument is accepted for interface compatibility
      // but not enforced; each key is waited on indefinitely.
      VLOG(3) << "GlooStore::wait";
      for (auto& key : keys) {
        _store->wait(key);
      }
    }

   protected:
    std::shared_ptr<paddle::distributed::Store> _store;
  };

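  // Construction options; `device` selects the gloo transport device
  // (see the createDevice* helpers at the bottom of this class).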
  class GlooOptions {
   public:
    GlooOptions() = default;
    ~GlooOptions() = default;
    static std::shared_ptr<GlooOptions> create() {
      return std::make_shared<GlooOptions>();
    }
    std::shared_ptr<::gloo::transport::Device> device;
  };

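  // `store` backs the rendezvous, `rank`/`world_size` identify this process
  // within the group, and `gid` is the unique id of the group.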
  ProcessGroupGloo(const std::shared_ptr<paddle::distributed::Store>& store,
                   int rank,
                   int world_size,
                   int gid,
                   std::shared_ptr<GlooOptions> options);

  static std::shared_ptr<ProcessGroupGloo> CreateProcessGroupGloo(
      const std::shared_ptr<paddle::distributed::Store>& store,
      int rank,
      int world_size,
      int gid);

  ~ProcessGroupGloo() = default;

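  // A minimal usage sketch (the TCPStore constructor arguments shown are
  // illustrative; check the actual signature in tcp_store.h):
  //
  //   auto store = std::make_shared<paddle::distributed::TCPStore>(
  //       "127.0.0.1", 6170, /*is_master=*/rank == 0, world_size);
  //   auto pg = ProcessGroupGloo::CreateProcessGroupGloo(
  //       store, rank, world_size, /*gid=*/0);
  //   pg->AllReduce(&out_tensor, in_tensor, AllreduceOptions(),
  //                 /*sync_op=*/true);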
  std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      int64_t /*offset*/,  // kept for API compatibility; currently unused
      int64_t /*numel*/,   // kept for API compatibility; currently unused
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> AllReduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const ReduceOptions& opts,
                                             bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
                                              const phi::DenseTensor& in_tensor,
                                              const ScatterOptions& opts,
                                              bool sync_op) override;

  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& inputs,
      std::vector<phi::DenseTensor>& outputs,
      const BroadcastOptions& = BroadcastOptions()) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& inputs,
      std::vector<phi::DenseTensor>& outputs,
      const BroadcastOptions& opts,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& inputs,
      std::vector<phi::DenseTensor>& outputs,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& inputs,
      std::vector<phi::DenseTensor>& outputs,
      const AllreduceOptions& opts,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;

  std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

  std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      bool sync_op) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ReduceOptions& opts) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ScatterOptions&) override;

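  // Each collective draws a fresh tag from next_tag() so that concurrent
  // operations on the shared gloo context do not collide.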
  std::shared_ptr<::gloo::Context> get_context() { return _context; }
  uint64_t next_tag() { return _tag++; }

  std::string GetBackendName() const override { return "GLOO"; }

  phi::DeviceContext* GetDeviceContext(const Place& place) const override {
    return platform::DeviceContextPool::Instance().Get(place);
  }

  // Helper functions for creating gloo transport devices.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& ifname);
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
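
  // Sketch of wiring a specific transport device into the options (the
  // hostname is illustrative):
  //
  //   auto opts = GlooOptions::create();
  //   opts->device = ProcessGroupGloo::createDeviceForHostname("127.0.0.1");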

 private:
  uint32_t _tag;
  std::shared_ptr<::gloo::rendezvous::Context> _context;
  std::shared_ptr<::gloo::rendezvous::Store> _store;
};

}  // namespace distributed
}  // namespace paddle