// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

namespace paddle {
namespace pybind {

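// Concatenates the tensors in `in` into `out` along `axis` using phi's
// ConcatFunctor.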
template <typename DeviceContext, typename T>
struct ConcatDenseTensor {
  void operator()(const DeviceContext &context,
                  const std::vector<phi::DenseTensor> &in,
                  phi::DenseTensor *out,
                  int axis = 0) {
    phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
    concat_functor(context, in, axis, out);
  }
};

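// Splits `in` into the preallocated tensors in `out` along `axis`, using the
// output tensors' shapes as the split sections.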
template <typename DeviceContext, typename T>
struct SplitDenseTensor {
  void operator()(const DeviceContext &context,
                  const phi::DenseTensor &in,
                  std::vector<phi::DenseTensor *> *out,
                  int axis = 0) {
    std::vector<const phi::DenseTensor *> shape_refer;
    shape_refer.reserve(out->size());
    for (auto *p_tensor : *out) {
      shape_refer.emplace_back(p_tensor);
    }
    phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
    split_functor(context, in, shape_refer, axis, out);
  }
};

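// Specializations for custom devices: concat and split are implemented as
// plain device-to-device memory copies that pack/unpack the tensors
// back-to-back, so the `axis` argument is effectively ignored.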
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
  void operator()(const platform::CustomDeviceContext &context,
                  const std::vector<phi::DenseTensor> &in,
                  phi::DenseTensor *out,
                  int axis = 0) {
    // The offset is accumulated in bytes, so do the pointer arithmetic on a
    // byte pointer rather than on T*.
    auto *out_data = reinterpret_cast<char *>(out->data<T>());
    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
    size_t offset = 0;
    for (const auto &tensor : in) {
      const auto *in_data = tensor.data<T>();
      auto sz = tensor.numel() * sizeof(T);
      device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr);
      offset += sz;
    }
  }
};

template <typename T>
struct SplitDenseTensor<platform::CustomDeviceContext, T> {
  void operator()(const platform::CustomDeviceContext &context,
                  const phi::DenseTensor &in,
                  std::vector<phi::DenseTensor *> *out,
                  int axis = 0) {
    // The offset is accumulated in bytes, so do the pointer arithmetic on a
    // byte pointer rather than on T*.
    const auto *in_data = reinterpret_cast<const char *>(in.data<T>());
    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
    size_t offset = 0;
    for (auto *p_tensor : *out) {
      auto *out_data = p_tensor->data<T>();
      auto sz = p_tensor->numel() * sizeof(T);
      device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr);
      offset += sz;
    }
  }
};
#endif

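// Dispatches ConcatDenseTensor on the runtime data type.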
template <typename DeviceContext>
void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
                               const std::vector<phi::DenseTensor> &t_list,
                               phi::DenseTensor *p_out,
                               phi::DataType type) {
  switch (type) {
    case phi::DataType::BOOL:
      ConcatDenseTensor<DeviceContext, bool>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::UINT8:
      ConcatDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT8:
      ConcatDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT32:
      ConcatDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT64:
      ConcatDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT16:
      ConcatDenseTensor<DeviceContext, phi::dtype::float16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::BFLOAT16:
      ConcatDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT32:
      ConcatDenseTensor<DeviceContext, float>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT64:
      ConcatDenseTensor<DeviceContext, double>()(dev_ctx, t_list, p_out);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when concatenating tensors.",
          type));
  }
}

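// XPU specialization: only float16 and float32 are handled here.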
#ifdef PADDLE_WITH_XPU
template <>
void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx,
                               const std::vector<phi::DenseTensor> &t_list,
                               phi::DenseTensor *p_out,
                               phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT16:
      ConcatDenseTensor<phi::XPUContext, phi::dtype::float16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT32:
      ConcatDenseTensor<phi::XPUContext, float>()(dev_ctx, t_list, p_out);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when concatenating tensors.",
          type));
  }
}
#endif

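// Dispatches SplitDenseTensor on the runtime data type.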
template <typename DeviceContext>
void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                              const phi::DenseTensor &t_in,
                              std::vector<phi::DenseTensor *> *p_list,
                              phi::DataType type) {
  switch (type) {
    case phi::DataType::BOOL:
      SplitDenseTensor<DeviceContext, bool>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::UINT8:
      SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT8:
      SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT32:
      SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT64:
      SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT16:
      SplitDenseTensor<DeviceContext, phi::dtype::float16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::BFLOAT16:
      SplitDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT32:
      SplitDenseTensor<DeviceContext, float>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT64:
      SplitDenseTensor<DeviceContext, double>()(dev_ctx, t_in, p_list);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when splitting tensors.", type));
  }
}

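// As above, the XPU specialization handles only float16 and float32.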
#ifdef PADDLE_WITH_XPU
template <>
void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx,
                              const phi::DenseTensor &t_in,
                              std::vector<phi::DenseTensor *> *p_list,
                              phi::DataType type) {
  switch (type) {
    case phi::DataType::FLOAT16:
      SplitDenseTensor<phi::XPUContext, phi::dtype::float16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT32:
      SplitDenseTensor<phi::XPUContext, float>()(dev_ctx, t_in, p_list);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when splitting tensors.", type));
  }
}
#endif

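// Concatenates a list of dense tensors into the API-level output tensor,
// choosing the device-specific implementation from the context's place.
// (Presumably used by the process-group bindings to pack per-rank buffers,
// e.g. for all_gather.)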
void ConcatTensor(const phi::DeviceContext &dev_ctx,
                  const std::vector<phi::DenseTensor> &tensor_list,
                  const experimental::Tensor *tensor) {
  auto *dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();

  const auto &place = dev_ctx.GetPlace();
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensors since it is not compiled with GPU "
        "support; please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    ConcatDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensors since it is not compiled with XPU "
        "support; please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    ConcatDenseTensorWithType(
        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
        tensor_list,
        dense_tensor,
        tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensors since it is not compiled with "
        "CUSTOM_DEVICE support; please recompile or reinstall Paddle with "
        "CUSTOM_DEVICE support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Concatenating tensors is not supported on place (%s).", place));
  }
}

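// Splits a dense tensor into the API-level output tensors, choosing the
// device-specific implementation from the context's place.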
void SplitTensor(const phi::DeviceContext &dev_ctx,
                 const phi::DenseTensor &tensor,
                 const std::vector<experimental::Tensor> *tensor_list) {
  std::vector<phi::DenseTensor *> dense_list;
  for (auto &api_tensor : *tensor_list) {
    auto *p_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(api_tensor.impl()).get();
    dense_list.emplace_back(p_tensor);
  }

  const auto &place = dev_ctx.GetPlace();
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensors since it is not compiled with GPU "
        "support; please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    SplitDenseTensorWithType(static_cast<const phi::XPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensors since it is not compiled with XPU "
        "support; please recompile or reinstall Paddle with XPU support."));
#endif
  } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    SplitDenseTensorWithType(
        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
        tensor,
        &dense_list,
        tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensors since it is not compiled with "
        "CUSTOM_DEVICE support; please recompile or reinstall Paddle with "
        "CUSTOM_DEVICE support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Splitting tensors is not supported on place (%s).", place));
  }
}

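// Default split sizes: the first dimension divided evenly across world_size
// ranks (integer division).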
inline std::vector<int64_t> GetDefaultSplitSizes(const phi::DenseTensor &tensor,
                                                 int world_size) {
  return std::vector<int64_t>(world_size, tensor.dims()[0] / world_size);
}

}  // namespace pybind
}  // namespace paddle