// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

namespace paddle {
namespace pybind {

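// Concatenates the dense tensors in `in` along `axis` into `out`, using phi's
// ConcatFunctor for the given device context.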
template <typename DeviceContext, typename T>
struct ConcatDenseTensor {
  void operator()(const DeviceContext &context,
                  const std::vector<phi::DenseTensor> &in,
                  phi::DenseTensor *out,
                  int axis = 0) {
    phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
    concat_functor(context, in, axis, out);
  }
};

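// Splits `in` along `axis` into the pre-shaped dense tensors in `out`, using
// phi's SplitFunctor for the given device context.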
template <typename DeviceContext, typename T>
struct SplitDenseTensor {
  void operator()(const DeviceContext &context,
                  const phi::DenseTensor &in,
                  std::vector<phi::DenseTensor *> *out,
                  int axis = 0) {
    std::vector<const phi::DenseTensor *> shape_refer;
    shape_refer.reserve(out->size());
    for (auto *p_tensor : *out) {
      shape_refer.emplace_back(p_tensor);
    }
    phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
    split_functor(context, in, shape_refer, axis, out);
  }
};

#ifdef PADDLE_WITH_CUSTOM_DEVICE
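// Specialization for custom devices: concatenates by copying each input
// buffer back-to-back into `out` with device-to-device memcpy. `axis` is
// ignored, which matches a concat along the first dimension for contiguous
// tensors.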
template <typename T>
struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
  void operator()(const platform::CustomDeviceContext &context,
                  const std::vector<phi::DenseTensor> &in,
                  phi::DenseTensor *out,
                  int axis = 0) {
    // The running offset is in bytes, so address the output as char*.
    auto *out_data = reinterpret_cast<char *>(out->data<T>());
    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
    size_t offset = 0;
    for (const auto &tensor : in) {
      const auto *in_data = tensor.data<T>();
      auto sz = tensor.numel() * sizeof(T);
      device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr);
      offset += sz;
    }
  }
};

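// Specialization for custom devices: splits `in` by copying consecutive byte
// ranges into each output tensor with device-to-device memcpy. `axis` is
// ignored.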
template <typename T>
struct SplitDenseTensor<platform::CustomDeviceContext, T> {
  void operator()(const platform::CustomDeviceContext &context,
                  const phi::DenseTensor &in,
                  std::vector<phi::DenseTensor *> *out,
                  int axis = 0) {
    // The running offset is in bytes, so address the input as const char*.
    const auto *in_data = reinterpret_cast<const char *>(in.data<T>());
    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
    size_t offset = 0;
    for (auto *p_tensor : *out) {
      auto *out_data = p_tensor->data<T>();
      auto sz = p_tensor->numel() * sizeof(T);
      device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr);
      offset += sz;
    }
  }
};
#endif

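// Dispatches ConcatDenseTensor on the runtime data type; unsupported dtypes
// throw an Unimplemented error.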
template <typename DeviceContext>
void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
                               const std::vector<phi::DenseTensor> &t_list,
                               phi::DenseTensor *p_out,
                               phi::DataType type) {
  switch (type) {
    case phi::DataType::BOOL:
      ConcatDenseTensor<DeviceContext, bool>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::UINT8:
      ConcatDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT8:
      ConcatDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT32:
      ConcatDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT64:
      ConcatDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT16:
      ConcatDenseTensor<DeviceContext, phi::dtype::float16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::BFLOAT16:
      ConcatDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT32:
      ConcatDenseTensor<DeviceContext, float>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT64:
      ConcatDenseTensor<DeviceContext, double>()(dev_ctx, t_list, p_out);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when concatenating tensors.", type));
  }
}

#ifdef PADDLE_WITH_XPU
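// XPU specialization: only the data types listed below are supported; all
// others throw an Unimplemented error.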
template <>
void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx,
                               const std::vector<phi::DenseTensor> &t_list,
                               phi::DenseTensor *p_out,
                               phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT16:
      ConcatDenseTensor<phi::XPUContext, phi::dtype::float16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT32:
      ConcatDenseTensor<phi::XPUContext, float>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT32:
      ConcatDenseTensor<phi::XPUContext, int32_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT64:
      ConcatDenseTensor<phi::XPUContext, int64_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::UINT8:
      ConcatDenseTensor<phi::XPUContext, uint8_t>()(dev_ctx, t_list, p_out);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when concatenating tensors.", type));
  }
}
#endif

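// Dispatches SplitDenseTensor on the runtime data type; unsupported dtypes
// throw an Unimplemented error.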
template <typename DeviceContext>
void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                              const phi::DenseTensor &t_in,
                              std::vector<phi::DenseTensor *> *p_list,
                              phi::DataType type) {
  switch (type) {
    case phi::DataType::BOOL:
      SplitDenseTensor<DeviceContext, bool>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::UINT8:
      SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT8:
      SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT32:
      SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT64:
      SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT16:
      SplitDenseTensor<DeviceContext, phi::dtype::float16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::BFLOAT16:
      SplitDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT32:
      SplitDenseTensor<DeviceContext, float>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT64:
      SplitDenseTensor<DeviceContext, double>()(dev_ctx, t_in, p_list);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when splitting tensors.", type));
  }
}

#ifdef PADDLE_WITH_XPU
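// XPU specialization: only the data types listed below are supported; all
// others throw an Unimplemented error.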
template <>
void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx,
                              const phi::DenseTensor &t_in,
                              std::vector<phi::DenseTensor *> *p_list,
                              phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT16:
      SplitDenseTensor<phi::XPUContext, phi::dtype::float16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT32:
      SplitDenseTensor<phi::XPUContext, float>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT32:
      SplitDenseTensor<phi::XPUContext, int32_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT64:
      SplitDenseTensor<phi::XPUContext, int64_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::UINT8:
      SplitDenseTensor<phi::XPUContext, uint8_t>()(dev_ctx, t_in, p_list);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when splitting tensors.", type));
  }
}
#endif

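// Concatenates `tensor_list` into the DenseTensor wrapped by `tensor`,
// choosing the device-specific implementation from the context's place.
// Throws if Paddle was not compiled with support for that place.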
void ConcatTensor(const phi::DeviceContext &dev_ctx,
                  const std::vector<phi::DenseTensor> &tensor_list,
                  const experimental::Tensor *tensor) {
  auto *dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();

  const auto &place = dev_ctx.GetPlace();
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensor since it's not compiled with GPU, "
        "please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    ConcatDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensor since it's not compiled with XPU, "
        "please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    ConcatDenseTensorWithType(
        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
        tensor_list,
        dense_tensor,
        tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensor since it's not compiled with "
        "CUSTOM_DEVICE, please recompile or reinstall Paddle with "
        "CUSTOM_DEVICE support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Concat tensor not supported on place (%s)", place));
  }
}

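// Splits `tensor` into the DenseTensors wrapped by `tensor_list`, choosing
// the device-specific implementation from the context's place. Throws if
// Paddle was not compiled with support for that place.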
void SplitTensor(const phi::DeviceContext &dev_ctx,
                 const phi::DenseTensor &tensor,
                 const std::vector<experimental::Tensor> *tensor_list) {
  std::vector<phi::DenseTensor *> dense_list;
  for (auto &api_tensor : *tensor_list) {
    auto *p_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(api_tensor.impl()).get();
    dense_list.emplace_back(p_tensor);
  }

  const auto &place = dev_ctx.GetPlace();
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensor since it's not compiled with GPU, "
        "please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    SplitDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensor since it's not compiled with XPU, "
        "please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    SplitDenseTensorWithType(
        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
        tensor,
        &dense_list,
        tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensor since it's not compiled with CUSTOM_DEVICE, "
        "please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Split tensor not supported on place (%s)", place));
  }
}

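// Returns `world_size` equal chunk sizes along dim 0. Note the integer
// division: this assumes tensor.dims()[0] is evenly divisible by world_size.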
inline std::vector<int64_t> GetDefaultSplitSizes(const phi::DenseTensor &tensor,
                                                 int world_size) {
  return std::vector<int64_t>(world_size, tensor.dims()[0] / world_size);
}
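
// Illustrative usage sketch (the names below are hypothetical, not part of
// this header): a process-group helper that gathers or scatters per-rank
// results might call these helpers as
//
//   // `dev_ctx`, `dense_inputs`, `api_out`, `dense_in`, and `api_outs` are
//   // assumed to already wrap DenseTensors with the expected shapes.
//   ConcatTensor(dev_ctx, dense_inputs, &api_out);
//   SplitTensor(dev_ctx, dense_in, &api_outs);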

}  //  namespace pybind
}  //  namespace paddle