// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/platform/device_context.h"  // platform::CustomDeviceContext
#include "paddle/fluid/platform/enforce.h"         // PADDLE_THROW, platform::errors
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"

namespace paddle {
namespace pybind {

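// Generic helper: concatenates a list of dense tensors into `out` along
// `axis` via the device-specific phi ConcatFunctor.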
template <typename DeviceContext, typename T>
struct ConcatDenseTensor {
  void operator()(const DeviceContext &context,
                  const std::vector<phi::DenseTensor> &in,
                  phi::DenseTensor *out,
                  int axis = 0) {
    phi::funcs::ConcatFunctor<DeviceContext, T> concat_functor;
    concat_functor(context, in, axis, out);
  }
};

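// Generic helper: splits `in` into the pre-sized tensors in `out` along
// `axis` via the device-specific phi SplitFunctor; each output tensor's own
// shape determines its share of the input.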
template <typename DeviceContext, typename T>
struct SplitDenseTensor {
  void operator()(const DeviceContext &context,
                  const phi::DenseTensor &in,
                  std::vector<phi::DenseTensor *> *out,
                  int axis = 0) {
    std::vector<const phi::DenseTensor *> shape_refer;
    shape_refer.reserve(out->size());
    for (auto *p_tensor : *out) {
      shape_refer.emplace_back(p_tensor);
    }
    phi::funcs::SplitFunctor<DeviceContext, T> split_functor;
    split_functor(context, in, shape_refer, axis, out);
  }
};

#ifdef PADDLE_WITH_CUSTOM_DEVICE
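// Custom-device specialization: the `axis` argument is ignored and each
// input buffer is appended back-to-back into `out` with device-to-device
// memory copies.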
template <typename T>
struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
  void operator()(const platform::CustomDeviceContext &context,
                  const std::vector<phi::DenseTensor> &in,
                  phi::DenseTensor *out,
                  int axis = 0) {
    auto *out_data = out->data<T>();
    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
    // Track the offset in elements so that the typed pointer arithmetic on
    // out_data advances by the right number of bytes; the size passed to
    // MemoryCopyD2D stays in bytes.
    size_t offset = 0;
    for (const auto &tensor : in) {
      const auto *in_data = tensor.data<T>();
      auto count = tensor.numel();
      device->MemoryCopyD2D(out_data + offset, in_data, count * sizeof(T),
                            nullptr);
      offset += count;
    }
  }
};

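// Custom-device specialization: the `axis` argument is ignored and `in` is
// carved into the output tensors back-to-back with device-to-device copies,
// using each output tensor's numel() to size its slice.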
template <typename T>
struct SplitDenseTensor<platform::CustomDeviceContext, T> {
  void operator()(const platform::CustomDeviceContext &context,
                  const phi::DenseTensor &in,
                  std::vector<phi::DenseTensor *> *out,
                  int axis = 0) {
    auto *in_data = in.data<T>();
    auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
    // As above, the offset is tracked in elements so that in_data + offset
    // points at the start of the next slice; the copy size is in bytes.
    size_t offset = 0;
    for (auto *p_tensor : *out) {
      auto *out_data = p_tensor->data<T>();
      auto count = p_tensor->numel();
      device->MemoryCopyD2D(out_data, in_data + offset, count * sizeof(T),
                            nullptr);
      offset += count;
    }
  }
};
#endif

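// Dispatches ConcatDenseTensor on the runtime data type of the inputs.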
template <typename DeviceContext>
void ConcatDenseTensorWithType(const DeviceContext &dev_ctx,
                               const std::vector<phi::DenseTensor> &t_list,
                               phi::DenseTensor *p_out,
                               phi::DataType type) {
  switch (type) {
    case phi::DataType::BOOL:
      ConcatDenseTensor<DeviceContext, bool>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::UINT8:
      ConcatDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT8:
      ConcatDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT32:
      ConcatDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::INT64:
      ConcatDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT16:
      ConcatDenseTensor<DeviceContext, phi::dtype::float16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::BFLOAT16:
      ConcatDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
          dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT32:
      ConcatDenseTensor<DeviceContext, float>()(dev_ctx, t_list, p_out);
      break;
    case phi::DataType::FLOAT64:
      ConcatDenseTensor<DeviceContext, double>()(dev_ctx, t_list, p_out);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when concatenating tensors.", type));
  }
}

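// Dispatches SplitDenseTensor on the runtime data type of the input.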
template <typename DeviceContext>
void SplitDenseTensorWithType(const DeviceContext &dev_ctx,
                              const phi::DenseTensor &t_in,
                              std::vector<phi::DenseTensor *> *p_list,
                              phi::DataType type) {
  switch (type) {
    case phi::DataType::BOOL:
      SplitDenseTensor<DeviceContext, bool>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::UINT8:
      SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT8:
      SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT32:
      SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::INT64:
      SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT16:
      SplitDenseTensor<DeviceContext, phi::dtype::float16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::BFLOAT16:
      SplitDenseTensor<DeviceContext, phi::dtype::bfloat16>()(
          dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT32:
      SplitDenseTensor<DeviceContext, float>()(dev_ctx, t_in, p_list);
      break;
    case phi::DataType::FLOAT64:
      SplitDenseTensor<DeviceContext, double>()(dev_ctx, t_in, p_list);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when splitting tensors.", type));
  }
}

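// Concatenates `tensor_list` into the dense tensor backing `tensor`,
// dispatching on the place of `dev_ctx` (GPU, custom device, or CPU).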
void ConcatTensor(const phi::DeviceContext &dev_ctx,
                  const std::vector<phi::DenseTensor> &tensor_list,
                  const experimental::Tensor *tensor) {
  auto *dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(tensor->impl()).get();

  const auto &place = dev_ctx.GetPlace();
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    ConcatDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensor since it's not compiled with GPU "
        "support, please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    ConcatDenseTensorWithType(
        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
        tensor_list,
        dense_tensor,
        tensor->dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't concat tensor since it's not compiled with "
        "CUSTOM_DEVICE, please recompile or reinstall Paddle with "
        "CUSTOM_DEVICE support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    ConcatDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                              tensor_list,
                              dense_tensor,
                              tensor->dtype());
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Concat tensor not supported on place (%s)", place));
  }
}

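// Splits `tensor` into the dense tensors backing the entries of
// `tensor_list`, dispatching on the place of `dev_ctx` (GPU, custom device,
// or CPU).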
void SplitTensor(const phi::DeviceContext &dev_ctx,
                 const phi::DenseTensor &tensor,
                 const std::vector<experimental::Tensor> *tensor_list) {
  std::vector<phi::DenseTensor *> dense_list;
  dense_list.reserve(tensor_list->size());
  for (auto &t : *tensor_list) {
    auto *p_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(t.impl()).get();
    dense_list.emplace_back(p_tensor);
  }

  const auto &place = dev_ctx.GetPlace();
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    SplitDenseTensorWithType(static_cast<const phi::GPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensor since it's not compiled with GPU "
        "support, please recompile or reinstall Paddle with GPU support."));
#endif
  } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    SplitDenseTensorWithType(
        static_cast<const platform::CustomDeviceContext &>(dev_ctx),
        tensor,
        &dense_list,
        tensor.dtype());
#else
    PADDLE_THROW(platform::errors::PermissionDenied(
        "Paddle can't split tensor since it's not compiled with CUSTOM_DEVICE, "
        "please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
#endif
  } else if (platform::is_cpu_place(place)) {
    SplitDenseTensorWithType(static_cast<const phi::CPUContext &>(dev_ctx),
                             tensor,
                             &dense_list,
                             tensor.dtype());
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Split tensor not supported on place (%s)", place));
  }
}

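// Returns `world_size` equal split sizes along the first dimension. Note the
// integer division: dims()[0] is assumed to be divisible by world_size.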
inline std::vector<int64_t> GetDefaultSplitSizes(const phi::DenseTensor &tensor,
                                                 int world_size) {
  return std::vector<int64_t>(world_size, tensor.dims()[0] / world_size);
}

}  //  namespace pybind
}  //  namespace paddle