// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/add_n_kernel.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
namespace phi {

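// Ceiling integer division; used below to compute the number of blocks in
// the launch grid.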
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))

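// Element-wise sum of `in_size` device arrays into `out` over N elements.
// Accumulation is done in MPType (e.g. float for float16/bfloat16 inputs) to
// reduce rounding error; when `read_dst` is true the existing contents of
// `out` are included in the sum. Null entries in `in` are skipped.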
template <class T>
__global__ void SumArrayCUDAKernel(
    T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
    MPType total(read_dst ? static_cast<MPType>(out[idx])
                          : static_cast<MPType>(0));
    for (int i = 0; i < in_size; ++i) {
      const T *tmp = in[i];
      if (tmp) {
        total += static_cast<MPType>(tmp[idx]);
      }
    }
    out[idx] = static_cast<T>(total);
  }
}

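// Scatter-add for SelectedRows inputs: `sr_in_out` holds 2 * rows pointers,
// interleaved as (source row, destination row) pairs. Each thread owns one
// column offset and adds that element of every source row into the matching
// row of the output.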
template <class T>
__global__ void SumSelectedRowsCUDAKernel(T **sr_in_out,
                                          int64_t N,
                                          size_t rows) {
  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
    for (int i = 0; i < 2 * rows; i += 2) {
      const T *tmp = sr_in_out[i];
      T *tmp_out = sr_in_out[i + 1];
      if (tmp && tmp_out) {
        tmp_out[idx] += tmp[idx];
      }
    }
  }
}

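// Sums all inputs in `x` (DenseTensors and/or SelectedRows) into `out`.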
template <typename T, typename Context>
void AddNKernel(const Context &dev_ctx,
                const std::vector<const TensorBase *> &x,
                DenseTensor *out) {
  const size_t in_num = x.size();
  for (int i = 0; i < in_num; ++i) {
    PADDLE_ENFORCE_EQ(
        x[i]->initialized(),
        true,
        phi::errors::InvalidArgument(
            "This argument is invalid, %d-th tensor is uninitialized.", i));
  }

  constexpr size_t theory_sm_threads = 1024;
  auto stream = dev_ctx.stream();

  auto max_threads = dev_ctx.GetMaxPhysicalThreadCount();
  auto sm_count = max_threads / theory_sm_threads;
  size_t tile_size = 0;
  dim3 grids;
  dim3 blocks;

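  // Launch-configuration heuristic: inputs large enough to saturate the
  // device use 1024-thread blocks; shorter inputs use 512 or 256 threads so
  // the work still spreads across enough blocks to keep the SMs busy.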
  auto ComputeKernelParameter = [&](size_t length) {
    if (length >= max_threads)
      tile_size = 1024;
    else if (length > sm_count * 128)
      tile_size = 512;
    else if (length <= sm_count * 128)
      tile_size = 256;
    grids = dim3(CEIL_DIV(length, tile_size), 1, 1);
    blocks = dim3(tile_size, 1, 1);
  };
  auto *out_ptr = dev_ctx.template Alloc<T>(out);
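  // Detect in-place summation: when the first input aliases the output
  // buffer, `out` is not zeroed, input 0 is skipped below (start = 1), and
  // the sum kernel reads the existing output values instead.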
  bool in_place = false;
  if (x.size() > 0 && x[0]->initialized() && DenseTensor::classof(x[0])) {
    if ((static_cast<const DenseTensor *>(x[0]))->data() == out->data()) {
      in_place = true;
    }
  }

  if (!in_place && in_num >= 1 && DenseTensor::classof(x[0])) {
    auto &in_0_tensor = *(static_cast<const DenseTensor *>(x[0]));
    if (in_0_tensor.numel() > 0) {
      in_place = (in_0_tensor.data<T>() == out_ptr);
    }
  }

  // Fast path: sum exactly two dense tensors with a fused Eigen expression.
  if (in_num == 2 && DenseTensor::classof(x[0]) && DenseTensor::classof(x[1])) {
    auto &in_0 = *(static_cast<const DenseTensor *>(x[0]));
    auto &in_1 = *(static_cast<const DenseTensor *>(x[1]));
    int64_t length_0 = in_0.numel();
    int64_t length_1 = in_1.numel();
    if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
      using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
      auto result = EigenVector<T>::Flatten(*out);
      auto &place = *dev_ctx.eigen_device();
      auto in_0_e = EigenVector<T>::Flatten(in_0).template cast<MPType>();
      auto in_1_e = EigenVector<T>::Flatten(in_1).template cast<MPType>();
      result.device(place) = (in_0_e + in_1_e).template cast<T>();
    } else if (length_0 && in_0.IsInitialized()) {
      auto result = EigenVector<T>::Flatten(*out);
      auto &place = *dev_ctx.eigen_device();
      result.device(place) = EigenVector<T>::Flatten(in_0);
    } else if (length_1 && in_1.IsInitialized()) {
      auto result = EigenVector<T>::Flatten(*out);
      auto &place = *dev_ctx.eigen_device();
      result.device(place) = EigenVector<T>::Flatten(in_1);
    }
    return;
  }

  int start = in_place ? 1 : 0;
  if (!in_place) {
    phi::funcs::SetConstant<phi::GPUContext, T> constant_functor;
    constant_functor(dev_ctx, out, static_cast<T>(0));
  }

  std::vector<const T *> in_data;
  std::vector<int> selectrow_index;
  int64_t lod_length = 0;
  bool dst_write = false;
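  // Collect device pointers for the dense inputs; SelectedRows inputs are
  // recorded by index and summed separately below. add_n requires all dense
  // inputs to share a shape, so lod_length ends up as their common numel.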
  for (int i = start; i < in_num; ++i) {
    if (DenseTensor::classof(x[i])) {
      auto &in_i = *(static_cast<const DenseTensor *>(x[i]));
      lod_length = in_i.numel();
      if (lod_length && in_i.IsInitialized()) {
        in_data.emplace_back(in_i.data<T>());
      }
    } else if (SelectedRows::classof(x[i])) {
      selectrow_index.push_back(i);
    }
  }

  // Handle the SelectedRows inputs separately.
  if (!selectrow_index.empty()) {
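    // Build the interleaved (source row, destination row) pointer table
    // consumed by SumSelectedRowsCUDAKernel, validating each input against
    // the output shape along the way.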
    std::vector<const T *> sr_in_out_data;
    size_t rows = 0;
    int64_t length = 0;
    for (auto index : selectrow_index) {
      auto &sr = *(static_cast<const SelectedRows *>(x[index]));
      auto &sr_value = sr.value();
      auto &sr_rows = sr.rows();

      auto row_numel = sr_value.numel() / sr_rows.size();
      auto out_dims = out->dims();

      PADDLE_ENFORCE_EQ(sr.height(),
                        out_dims[0],
                        errors::InvalidArgument(
                            "The table height of input must be same as output, "
                            "but received input height is %d"
                            ", output height is %d",
                            sr.height(),
                            out_dims[0]));
      PADDLE_ENFORCE_EQ(row_numel,
                        out->numel() / sr.height(),
                        errors::InvalidArgument(
                            "The table width of input must be same as output, "
                            "but received input width is %d"
                            ", output width is %d",
                            row_numel,
                            out->numel() / sr.height()));

      auto *sr_data = sr_value.data<T>();
      auto *sr_out_data = out->data<T>();
      rows += sr_rows.size();
      length = row_numel;

      for (size_t i = 0; i < sr_rows.size(); ++i) {
        sr_in_out_data.emplace_back(&sr_data[i * row_numel]);
        sr_in_out_data.emplace_back(&sr_out_data[sr_rows[i] * row_numel]);
      }
    }
    if (!sr_in_out_data.empty()) {
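      // Stage the host-side pointer table into device memory; the async copy
      // and the kernel launch are ordered on the same stream.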
      auto tmp_sr_in_out_array = phi::memory_utils::Alloc(
          dev_ctx.GetPlace(), sr_in_out_data.size() * sizeof(T *));

      memory_utils::Copy(dev_ctx.GetPlace(),
                         tmp_sr_in_out_array->ptr(),
                         phi::CPUPlace(),
                         reinterpret_cast<void *>(sr_in_out_data.data()),
                         sr_in_out_data.size() * sizeof(T *),
                         dev_ctx.stream());

      T **sr_in_out_array_data =
          reinterpret_cast<T **>(tmp_sr_in_out_array->ptr());

      ComputeKernelParameter(length);
      SumSelectedRowsCUDAKernel<T>
          <<<grids, blocks, 0, stream>>>(sr_in_out_array_data, length, rows);
      dst_write = true;
    }
  }
  // If in_data is not empty, sum all dense inputs in a single kernel launch.
  if (!in_data.empty()) {
    auto tmp_in_array = phi::memory_utils::Alloc(dev_ctx.GetPlace(),
                                                 in_data.size() * sizeof(T *));

    memory_utils::Copy(dev_ctx.GetPlace(),
                       tmp_in_array->ptr(),
                       phi::CPUPlace(),
                       reinterpret_cast<void *>(in_data.data()),
                       in_data.size() * sizeof(T *),
                       dev_ctx.stream());

    T **in_array_data = reinterpret_cast<T **>(tmp_in_array->ptr());
    ComputeKernelParameter(lod_length);
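    // read_dst (= dst_write || in_place) tells the kernel that `out` already
    // holds partial results, from an in-place input or the SelectedRows pass.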
    SumArrayCUDAKernel<T><<<grids, blocks, 0, stream>>>(in_array_data,
                                                        out->data<T>(),
                                                        lod_length,
                                                        in_data.size(),
                                                        dst_write || in_place);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(add_n,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddNKernel,
                   float,
                   double,
                   int,
                   phi::dtype::bfloat16,
                   phi::dtype::float16,
                   int64_t) {}

PD_REGISTER_KERNEL(add_n_array,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddNArrayKernel,
                   float,
                   double,
                   int,
                   phi::dtype::bfloat16,
                   phi::dtype::float16,
                   int64_t) {}