// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"

namespace paddle {
namespace distributed {

20 21
// Constructs a ProcessGroupStream by handing rank, size and gid, unchanged,
// to the ProcessGroup base-class constructor. No other state is initialized
// here.
ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
    : ProcessGroup(rank, size, gid) {
}
22

23
// Default implementation of the stream-aware device-context accessor.
//
// Neither `place` nor `use_calc_stream` is read: the body unconditionally
// raises an Unimplemented error (via PADDLE_THROW) whose message embeds the
// value returned by GetBackendName(). Control never reaches a return.
const phi::DeviceContext& ProcessGroupStream::GetDeviceContext(
    const Place& place, bool use_calc_stream) const {
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support get device_context.", GetBackendName()));
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
30 31
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
32 33
    int64_t offset,
    int64_t numel,
34
    bool sync_op) {
35 36
  return AllGather(out_tensor,
                   in_tensor,
37 38
                   offset,
                   numel,
39 40 41 42 43
                   sync_op,
                   /*use_calc_stream*/ false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
44 45
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
46 47
    int64_t offset,
    int64_t numel,
48 49
    bool sync_op,
    bool use_calc_stream) {
50 51
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support all_gather.", GetBackendName()));
52 53
}

54
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
55 56 57
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
58
    bool sync_op) {
59 60 61
  return AllReduce(out_tensor,
                   in_tensor,
                   opts,
62 63 64 65 66
                   sync_op,
                   /*use_calc_stream*/ false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
67 68 69
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const AllreduceOptions& opts,
70 71
    bool sync_op,
    bool use_calc_stream) {
72 73
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support all_reduce.", GetBackendName()));
74 75
}

76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
// AllToAll (single-tensor form) without a stream selector.
//
// Hands all arguments, unchanged, to the seven-argument AllToAll overload
// with use_calc_stream fixed to false, and returns that overload's result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return AllToAll(out_tensor,
                  in_tensor,
                  out_size_each_rank,
                  in_size_each_rank,
                  sync_op,
                  kUseCalcStream);
}

// Stream-aware AllToAll (single-tensor form): default stub.
//
// Ignores every argument and always raises an Unimplemented error through
// PADDLE_THROW; the message carries the name from GetBackendName().
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const std::vector<int64_t>& out_size_each_rank,
    const std::vector<int64_t>& in_size_each_rank,
    bool sync_op,
    bool use_calc_stream) {
  const auto& backend_name = GetBackendName();
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support all_to_all.", backend_name));
}

101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
// Broadcast without a stream selector.
//
// Hands all arguments, unchanged, to the five-argument Broadcast overload
// with use_calc_stream fixed to false, and returns that overload's result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return Broadcast(out_tensor, in_tensor, opts, sync_op, kUseCalcStream);
}

// Stream-aware Broadcast: default stub.
//
// None of the arguments are inspected. The body unconditionally raises an
// Unimplemented error (via PADDLE_THROW) naming the backend reported by
// GetBackendName(); no Task is ever returned from here.
// NOTE(review): presumably concrete backends override this overload --
// confirm against the class declaration.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support broadcast.", GetBackendName()));
}

123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
// Reduce without a stream selector.
//
// Hands all arguments, unchanged, to the five-argument Reduce overload with
// use_calc_stream fixed to false, and returns that overload's result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return Reduce(out_tensor, in_tensor, opts, sync_op, kUseCalcStream);
}

// Stream-aware Reduce: default stub.
//
// Ignores every argument and always raises an Unimplemented error through
// PADDLE_THROW; the message carries the name from GetBackendName().
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  const auto& backend_name = GetBackendName();
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support reduce.", backend_name));
}

// ReduceScatter without a stream selector.
//
// Hands all arguments, unchanged, to the five-argument ReduceScatter
// overload with use_calc_stream fixed to false, and returns that overload's
// result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return ReduceScatter(out_tensor, in_tensor, opts, sync_op, kUseCalcStream);
}

// Stream-aware ReduceScatter: default stub.
//
// Ignores every argument and always raises an Unimplemented error through
// PADDLE_THROW; the message carries the name from GetBackendName().
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  const auto& backend_name = GetBackendName();
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support reduce_scatter.", backend_name));
}

// Scatter without a stream selector.
//
// Hands all arguments, unchanged, to the five-argument Scatter overload with
// use_calc_stream fixed to false, and returns that overload's result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ScatterOptions& opts,
    bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return Scatter(out_tensor, in_tensor, opts, sync_op, kUseCalcStream);
}

// Stream-aware Scatter: default stub.
//
// Ignores every argument and always raises an Unimplemented error through
// PADDLE_THROW; the message carries the name from GetBackendName().
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
    phi::DenseTensor* out_tensor,
    const phi::DenseTensor& in_tensor,
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
  const auto& backend_name = GetBackendName();
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support scatter.", backend_name));
}

189
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
190 191 192 193 194
    phi::DenseTensor* tensor,
    int src_rank,
    int64_t offset,
    int64_t numel,
    bool sync_op) {
195 196
  return Recv(tensor,
              src_rank,
197 198
              offset,
              numel,
199 200 201 202 203 204 205
              sync_op,
              /*use_calc_stream*/ false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
    phi::DenseTensor* tensor,
    int src_rank,
206 207
    int64_t offset,
    int64_t numel,
208 209
    bool sync_op,
    bool use_calc_stream) {
210 211
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support recv.", GetBackendName()));
212 213
}

214
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
215
    phi::DenseTensor* tensor,
216
    int dst_rank,
217
    int64_t offset,
218
    int64_t numel,
219 220 221
    bool sync_op) {
  return Send(tensor,
              dst_rank,
222 223
              offset,
              numel,
224 225 226 227 228
              sync_op,
              /*use_calc_stream*/ false);
}

std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
229
    phi::DenseTensor*,
230 231
    int dst_rank,
    int64_t offset,
232
    int64_t numel,
233 234
    bool sync_op,
    bool use_calc_stream) {
235 236
  PADDLE_THROW(platform::errors::Unimplemented(
      "ProcessGroup%s does not support send.", GetBackendName()));
237 238
}

// TODO(sunyilun): methods below will be removed later
// AllToAll (tensor-vector form) without a stream selector.
//
// Hands both tensor vectors and sync_op, unchanged, to the four-argument
// vector overload with use_calc_stream fixed to false, and returns that
// overload's result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return AllToAll(in_tensors, out_tensors, sync_op, kUseCalcStream);
}

// Stream-aware AllToAll (tensor-vector form): default stub.
//
// Ignores every argument and always raises an InvalidArgument error through
// PADDLE_THROW; the message carries the name from GetBackendName().
// NOTE(review): the single-tensor stubs in this file raise Unimplemented
// rather than InvalidArgument. The category is kept as-is here because
// callers may depend on it -- confirm before unifying.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors,
    bool sync_op,
    bool use_calc_stream) {
  const auto& backend_name = GetBackendName();
  PADDLE_THROW(platform::errors::InvalidArgument(
      "ProcessGroup%s does not support do alltoall", backend_name));
}

259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
// Recv (tensor-vector form) without a stream selector.
//
// Hands tensors, src_rank and sync_op, unchanged, to the four-argument
// vector overload with use_calc_stream fixed to false, and returns that
// overload's result.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
    std::vector<phi::DenseTensor>& tensors, int src_rank, bool sync_op) {
  // This overload never selects the calculation stream.
  constexpr bool kUseCalcStream = false;
  return Recv(tensors, src_rank, sync_op, kUseCalcStream);
}

// Stream-aware Recv (tensor-vector form): default stub.
//
// Ignores every argument and always raises an InvalidArgument error through
// PADDLE_THROW; the message carries the name from GetBackendName().
// NOTE(review): the single-tensor stubs in this file raise Unimplemented
// rather than InvalidArgument. The category is kept as-is here because
// callers may depend on it -- confirm before unifying.
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
    std::vector<phi::DenseTensor>& tensors,
    int src_rank,
    bool sync_op,
    bool use_calc_stream) {
  const auto& backend_name = GetBackendName();
  PADDLE_THROW(platform::errors::InvalidArgument(
      "ProcessGroup%s does not support do recv", backend_name));
}

276 277
}  // namespace distributed
}  // namespace paddle