// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <chrono>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/distributed/store/store.h"

namespace paddle {
namespace distributed {
using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext;

class ProcessGroupCustom : public ProcessGroupWithStream {
 public:
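  // Tracks one in-flight collective launched on a set of custom-device
  // places. Events recorded in control_events_ on the communication
  // streams let Wait()/Synchronize() block until the operation finishes.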
  class CustomTask : public ProcessGroup::Task,
                     public std::enable_shared_from_this<CustomTask> {
   public:
    CustomTask(const std::vector<Place>& places,
               int rank,
               CommType comm_type,
               const std::vector<phi::DenseTensor>& inputs);

    bool IsCompleted() override;
    void SynchronizeStreams();
    bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override;
    void Synchronize() override;
    void UpdateWaitChain(const phi::DeviceContext& ctx) override;
    void SetOutputs(std::vector<phi::DenseTensor>& outputs);  // NOLINT
    virtual ~CustomTask();

    std::vector<CustomEventManager> control_events_;
    std::vector<phi::DenseTensor> barrierTensors_;

   protected:
    std::vector<Place> places_;
    std::vector<std::shared_ptr<CustomCCLCommManager>> cclComms_;
    std::shared_ptr<std::vector<phi::DenseTensor>> outputs_;

   private:
    const std::string device_type_;
  };

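  // `store` provides cross-rank rendezvous for exchanging communicator
  // ids, `device_type` names the custom device backend (e.g. "npu"),
  // and `gid` identifies this group among all process groups.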
  ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
                     const std::string& device_type,
                     int rank,
                     int size,
                     int gid);

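  // Factory helper that constructs the group and hands it out as a
  // shared_ptr.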
  static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
      const std::shared_ptr<phi::distributed::Store>& store,
      const std::string& device_type,
      int rank,
      int size,
      int gid);

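  // The backend name encodes the device type, e.g. "XCCL_npu".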
  std::string GetBackendName() const override { return "XCCL_" + device_type_; }

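  // Returns a task that completes once every rank has reached the barrier.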
  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;

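  // Device context this group uses for communication on `place`.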
  phi::DeviceContext* GetDeviceContext(const Place& place) const override;

  phi::ccl::CCLComm CustomCCLComm(const Place& place) const;

  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

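  // Stream-aware overload: gathers the contiguous `offset`/`numel` slice
  // of `in_tensor` from every rank into `out_tensor`.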
  std::shared_ptr<ProcessGroup::Task> AllGather(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      int64_t offset,
      int64_t numel,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> AllReduce(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const AllreduceOptions& = AllreduceOptions()) override;

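  // `sync_op` marks the operation as synchronous to the caller, while
  // `use_calc_stream` launches it on the calculation stream instead of
  // the group's dedicated communication stream.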
  std::shared_ptr<ProcessGroup::Task> AllReduce(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const AllreduceOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const BroadcastOptions& = BroadcastOptions()) override;

  std::shared_ptr<ProcessGroup::Task> Broadcast(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const BroadcastOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

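  // Point-to-point primitives. The `offset`/`numel` overloads transfer
  // only a contiguous slice of the tensor; the vector overloads are the
  // legacy API.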
  std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
                                           int dst_rank,
                                           int64_t offset,
                                           int64_t numel,
                                           bool sync_op,
                                           bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Send(
      std::vector<phi::DenseTensor>& tensors, int dst_rank) override;

  std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
                                           int src_rank,
                                           int64_t offset,
                                           int64_t numel,
                                           bool sync_op,
                                           bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>& tensors, int src_rank) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const ReduceOptions& opts,
                                             bool sync_op,
                                             bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Reduce(
      std::vector<phi::DenseTensor>& tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ReduceOptions& opts) override;

  std::shared_ptr<ProcessGroup::Task> AllToAll(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const std::vector<int64_t>& out_size_each_rank,
      const std::vector<int64_t>& in_size_each_rank,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> AllToAll(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;

  std::shared_ptr<ProcessGroup::Task> ReduceScatter(
      phi::DenseTensor* out_tensor,
      const phi::DenseTensor& in_tensor,
      const ReduceScatterOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
                                              const phi::DenseTensor& in_tensor,
                                              const ScatterOptions& opts,
                                              bool sync_op,
                                              bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Scatter(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors,
      const ScatterOptions& opts) override;

  std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
                                             const phi::DenseTensor& in_tensor,
                                             const GatherOptions& opts,
                                             bool sync_op,
                                             bool use_calc_stream) override;

  std::shared_ptr<ProcessGroup::Task> Gather(
      std::vector<phi::DenseTensor>* gather_tensors_ptr,
      const phi::DenseTensor& in_tensor,
      const GatherOptions& opts,
      bool sync_op,
      bool use_calc_stream) override;

 protected:
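  // Hook that builds the CustomTask returned by the collective
  // implementations; virtual so subclasses can customize task creation.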
  virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
      std::vector<Place> places,
      int rank,
      CommType opType,
      const std::vector<phi::DenseTensor>& inputs);

  std::shared_ptr<phi::distributed::Store> store_;
  std::shared_ptr<CustomCCLCommManager> custom_comm_;
  std::mutex mutex_;
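  // Per-place-set caches, keyed by a string derived from the place list
  // and populated lazily by CreateCustomManagerCache.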
  std::unordered_map<std::string,
                     std::vector<std::shared_ptr<CustomCCLCommManager>>>
      places_to_customcomm_;
  std::unordered_map<std::string, std::vector<CustomEventManager>>
      places_to_events_;
  std::unordered_map<std::string,
                     std::vector<std::unique_ptr<CustomDeviceContext>>>
      places_to_ctx_;
  std::set<int> used_place_ids_;

 private:
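  // Shares the CCL root id generated on `root` with every other rank so
  // all ranks can initialize the same communicator; `server_fd` comes
  // from the gen_comm_id_helper socket setup.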
  void BcastCustomId(std::vector<phi::ccl::CCLRootId>& ccl_ids,  // NOLINT
                     int root,
                     int server_fd);

  void BroadcastUniqueCustomID(
      std::vector<phi::ccl::CCLRootId>& custom_ccl_ids);  // NOLINT

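  // Common implementation path for the collectives above: ensures cached
  // communicators exist for the input places, then invokes `fn` per place
  // on the chosen stream.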
  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> Collective(
      std::vector<phi::DenseTensor>& inputs,   // NOLINT
      std::vector<phi::DenseTensor>& outputs,  // NOLINT
      Fn fn,
      CommType op_type,
      bool sync_op,
      bool use_calc_stream);

  template <typename Fn>
  std::shared_ptr<ProcessGroup::Task> Collective(Fn fn,
                                                 CommType op_type,
                                                 bool sync_op,
                                                 bool use_calc_stream);

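  // Creates the communicators, events, and device contexts for `places`
  // and records them in the caches above under `places_key`.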
  void CreateCustomManagerCache(const std::string& places_key,
                                const std::vector<Place>& places);
  const std::string device_type_;
};
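
// Usage sketch (illustrative only: `store`, `rank`, `nranks`, and the
// tensors are assumed to be set up by the caller, and "npu" is just an
// example device type):
//
//   auto pg = ProcessGroupCustom::CreateProcessGroupCustom(
//       store, /*device_type=*/"npu", rank, nranks, /*gid=*/0);
//   auto task = pg->AllReduce(&out_tensor, in_tensor, AllreduceOptions(),
//                             /*sync_op=*/true, /*use_calc_stream=*/false);
//   task->Wait();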
}  // namespace distributed
}  // namespace paddle