sum_kernel.cu 17.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

namespace phi {
namespace sparse {

// Reduces a COO tensor along one of its *sparse* axes without merging
// duplicate output coordinates.
//
// For every non-zero `index_i` of x, out_values[index_i] receives the sum of
// the values of all non-zeros `index_j` whose sparse indices agree with those
// of `index_i` in every sparse dimension except `axis`. The output therefore
// keeps x_nnz entries; out_indices is laid out as [sparse_dim, x_nnz].
//
// Parameters:
//   x_indices_data   : indices of x, laid out as [x_sparse_dim, x_nnz] where
//                      x_sparse_dim == sparse_dim + (keep_dim ? 0 : 1).
//   x_values_data    : values of x; non-zero k owns the `dense_dim`
//                      contiguous elements starting at k * dense_dim.
//   x_nnz            : number of non-zeros of x (and of the output).
//   dense_dim        : product of the dense dimensions of x's values.
//   sparse_dim       : sparse rank of the *output* (already decremented by the
//                      caller when the reduced axis is dropped).
//   axis             : sparse axis of x being reduced.
//   keep_dim         : true  -> the reduced axis is kept with index 0;
//                      false -> the reduced axis is removed from the indices.
//   out_indices_data : [sparse_dim, x_nnz] output indices.
//   out_values_data  : [x_nnz, dense_dim] output values.
//
// Launch: index_i is distributed by CUDA_KERNEL_LOOP_TYPE over the x
// dimension, index_j by a stride loop over the y dimension. Each output row is
// zeroed by the thread that owns index_i right before it accumulates into it,
// so the kernel is only race-free when blockDim.y == gridDim.y == 1 (a 1-D
// launch). NOTE(review): a genuine 2-D launch would need a separate zero-fill
// pass before the atomics.
template <typename T, typename IntT>
__global__ void SumCooCudaKernel(const IntT* x_indices_data,
                                 const T* x_values_data,
                                 const int64_t x_nnz,
                                 const int64_t dense_dim,
                                 const int64_t sparse_dim,
                                 const int64_t axis,
                                 const bool keep_dim,
                                 IntT* out_indices_data,
                                 T* out_values_data) {
  // Sparse rank of the input: one more than the output when the reduced axis
  // is dropped.
  const int64_t x_sparse_dim = sparse_dim + (keep_dim ? 0 : 1);

  CUDA_KERNEL_LOOP_TYPE(index_i, x_nnz, int64_t) {
    for (int64_t j = 0; j < dense_dim; ++j) {
      out_values_data[j + index_i * dense_dim] = 0;
    }

    const int64_t index_j_begin =
        static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    const int64_t index_j_stride =
        static_cast<int64_t>(blockDim.y) * gridDim.y;
    for (int64_t index_j = index_j_begin; index_j < x_nnz;
         index_j += index_j_stride) {
      // Do index_i and index_j share all sparse indices except `axis`?
      bool same = true;
      for (int64_t j = 0; j < x_sparse_dim; ++j) {
        if (j != axis && x_indices_data[index_i + j * x_nnz] !=
                             x_indices_data[index_j + j * x_nnz]) {
          same = false;
          break;
        }
      }
      if (same) {
        for (int64_t j = 0; j < dense_dim; ++j) {
          phi::CudaAtomicAdd(&out_values_data[j + index_i * dense_dim],
                             x_values_data[j + index_j * dense_dim]);
        }
      }
    }

    // Only the y == 0 thread writes the indices of index_i. Use `continue`,
    // never `return`: the enclosing loop may hand this thread further index_i
    // values, and returning would leave their outputs uninitialized.
    if (index_j_begin != 0) {
      continue;
    }
    if (keep_dim) {
      // The reduced axis collapses to coordinate 0; others are copied.
      for (int64_t j = 0; j < sparse_dim; ++j) {
        out_indices_data[index_i + j * x_nnz] =
            (j == axis) ? static_cast<IntT>(0)
                        : x_indices_data[index_i + j * x_nnz];
      }
      continue;
    }
    // The reduced axis is dropped: output row j comes from input row j for
    // j < axis and from input row j + 1 afterwards.
    for (int64_t j = 0; j < sparse_dim; ++j) {
      const int64_t src_row = (j < axis) ? j : j + 1;
      out_indices_data[index_i + j * x_nnz] =
          x_indices_data[index_i + src_row * x_nnz];
    }
  }
}

// Fills the structure of a CSR result that has one row and a single non-zero
// in column 0: out_crows = [0, 1] and out_cols = [0].
// Requires out_crows_data to hold at least 2 elements and out_cols_data at
// least 1.
__global__ void SumAllCsrCudaKernel(int64_t* out_crows_data,
                                    int64_t* out_cols_data) {
  CUDA_KERNEL_LOOP_TYPE(pos, 2, int64_t) {
    // The lone non-zero lives in column 0; only one position writes it.
    if (pos == 0) {
      out_cols_data[0] = 0;
    }
    // Row pointers are simply the identity sequence 0, 1.
    out_crows_data[pos] = pos;
  }
}

// Row-wise sum of a 2-D CSR matrix with x_dim0 rows.
//
// Each position `row` in [0, x_dim0] writes out_crows[row] = row. For every
// real row (row < x_dim0), the sum of that row's stored values goes to
// out_values[row] with column index 0. The result is a CSR matrix with exactly
// one stored entry per row; an empty row yields 0.
// out_crows must hold x_dim0 + 1 elements, out_cols and out_values x_dim0.
template <typename T>
__global__ void SumCsr2DCudaKernel(const int64_t* x_crows_data,
                                   const T* x_values_data,
                                   const int64_t x_dim0,
                                   int64_t* out_crows_data,
                                   int64_t* out_cols_data,
                                   T* out_values_data) {
  CUDA_KERNEL_LOOP_TYPE(row, x_dim0 + 1, int64_t) {
    out_crows_data[row] = row;
    // The trailing slot only terminates the row-pointer array.
    if (row == x_dim0) {
      continue;
    }
    const int64_t row_begin = x_crows_data[row];
    const int64_t row_end = x_crows_data[row + 1];
    T acc = 0;
    for (int64_t k = row_begin; k < row_end; ++k) {
      acc += x_values_data[k];
    }
    out_cols_data[row] = 0;
    out_values_data[row] = acc;
  }
}

// Sums the last axis of a batched (3-D) CSR tensor with x_dim0 batches of
// x_dim1 rows each.
//
// x_crows_data holds x_dim0 consecutive groups of (x_dim1 + 1) row pointers.
// The values of batch b start at offset batch_nnz_data[b - 1] (0 for b == 0),
// so batch_nnz_data is expected to be the inclusive prefix sum of the
// per-batch nnz.
//
// Each position `index` maps to (batch, number) = row pointer `number` of
// batch `batch`:
//   * out_crows[index] = number, i.e. every output row keeps one entry;
//   * for real rows (number < x_dim1) the row sum is stored at the compacted
//     position index - batch of out_values and out_cols, with column 0.
// out_crows must hold x_dim0 * (x_dim1 + 1) elements; out_cols and out_values
// hold x_dim0 * x_dim1 elements.
template <typename T>
__global__ void SumCsr3DCudaKernel(const int64_t* x_crows_data,
                                   const T* x_values_data,
                                   const int64_t x_dim0,
                                   const int64_t x_dim1,
                                   const int64_t* batch_nnz_data,
                                   int64_t* out_crows_data,
                                   int64_t* out_cols_data,
                                   T* out_values_data) {
  CUDA_KERNEL_LOOP_TYPE(index, x_dim0 * (x_dim1 + 1), int64_t) {
    const int64_t batch = index / (x_dim1 + 1);
    const int64_t number = index % (x_dim1 + 1);
    out_crows_data[index] = number;

    // The last row pointer of each batch has no matching row.
    if (number == x_dim1) {
      continue;
    }

    // Position of this row once the per-batch terminator slots are removed.
    // out_cols/out_values only have x_dim0 * x_dim1 entries, so they must be
    // indexed with out_pos, never with `index` (that would run x_dim0
    // elements past the end).
    const int64_t out_pos = index - batch;
    const int64_t x_values_data_offset =
        (batch == 0) ? 0 : batch_nnz_data[batch - 1];

    T sum_value = 0;
    for (int64_t j = x_crows_data[index]; j < x_crows_data[index + 1]; ++j) {
      sum_value += x_values_data[j + x_values_data_offset];
    }
    out_cols_data[out_pos] = 0;
    out_values_data[out_pos] = sum_value;
  }
}

// Sums every element of a COO tensor (used when `axis` is empty).
//
// The result holds one non-zero whose coordinates are all 0:
//   keep_dim == true  -> dims [1, ..., 1] (rank of x), indices [sparse_dim, 1]
//   keep_dim == false -> dims [1],                      indices [1, 1]
// Its value is phi::Sum over all of x's values with the requested dtype and
// keep_dim. The coalesced flag is copied from x. `axis` is not read.
template <typename T, typename IntT, typename Context>
void SumCooGPU0Kernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      const IntArray& axis,
                      DataType dtype,
                      bool keep_dim,
                      SparseCooTensor* out) {
  const auto sparse_dim = x.sparse_dim();

  DDim out_dims;
  DenseTensor out_indices;
  if (!keep_dim) {
    out_dims = make_ddim({1});
    out_indices = Empty<IntT, Context>(dev_ctx, {1, 1});
  } else {
    out_dims = make_ddim(std::vector<int64_t>(x.dims().size(), 1));
    out_indices = Empty<IntT, Context>(dev_ctx, {sparse_dim, 1});
  }

  // Every coordinate of the single output entry is 0.
  phi::funcs::SetConstant<Context, IntT> fill_zero;
  fill_zero(dev_ctx, &out_indices, static_cast<IntT>(0));

  DenseTensor out_values =
      phi::Sum<T>(dev_ctx, x.values(), {}, dtype, keep_dim);
  out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}

// Sums a COO tensor along the single axis axis[0]; a negative value counts
// from the end.
//
// * Dense axis (dim >= sparse_dim): the indices are reused unchanged, and the
//   reduction is delegated to phi::Sum over the matching axis of values
//   (values axis = dim - sparse_dim + 1, because values axis 0 is nnz).
// * Sparse axis: SumCooCudaKernel computes, for each of x's nnz entries, the
//   sum over all entries that share its remaining sparse coordinates. The
//   output keeps x.nnz() entries. When x has a single sparse dimension,
//   keep_dim is forced so the output still has one sparse dimension.
//   NOTE(review): duplicates are not merged, yet the output inherits
//   x.coalesced(); confirm downstream consumers tolerate that.
// The values are cast to `dtype` when it is set and differs from x's dtype.
template <typename T, typename IntT, typename Context>
void SumCooGPU1Kernel(const Context& dev_ctx,
                      const SparseCooTensor& x,
                      const IntArray& axis,
                      DataType dtype,
                      bool keep_dim,
                      SparseCooTensor* out) {
  auto sparse_dim = x.sparse_dim();
  const auto& x_dims = x.dims();
  const auto& x_indices = x.indices();
  const auto& x_values = x.values();
  const auto n_dim = x_dims.size();
  auto dim = axis[0] < 0 ? x_dims.size() + axis[0] : axis[0];

  // Output shape: the reduced axis becomes 1 when kept (always kept for a
  // 1-D sparse part), otherwise it is removed.
  std::vector<int64_t> dims;
  for (int i = 0; i < n_dim; ++i) {
    if (i != dim) {
      dims.emplace_back(x_dims[i]);
    } else if (keep_dim || (dim < sparse_dim && sparse_dim == 1)) {
      dims.emplace_back(1);
    }
  }
  const DDim out_dims = make_ddim(dims);

  DenseTensor out_indices;
  DenseTensor out_values;

  if (dim >= sparse_dim) {
    // Reducing a dense axis: sparsity pattern is unchanged.
    out_indices = x_indices;
    dim = dim - sparse_dim + 1;
    out_values = phi::Sum<T>(dev_ctx, x.values(), {dim}, dtype, keep_dim);
    out->SetMember(out_indices, out_values, out_dims, x.coalesced());
    return;
  }

  // Ensure the output keeps at least one sparse dimension.
  if (sparse_dim == 1) {
    keep_dim = true;
  }
  // Dropping a sparse axis removes one row from the indices.
  if (!keep_dim) {
    sparse_dim -= 1;
  }

  // Keep all shapes in int64_t: a std::vector<int> and an int-initialized
  // accumulate would silently truncate large nnz / dense sizes.
  const int64_t nnz = x.nnz();
  std::vector<int64_t> out_values_dims;
  out_values_dims.reserve(x_values.dims().size());
  out_values_dims.push_back(nnz);
  for (int i = 1; i < x_values.dims().size(); ++i) {
    out_values_dims.push_back(x_values.dims()[i]);
  }
  const int64_t dense_dim = std::accumulate(out_values_dims.begin() + 1,
                                            out_values_dims.end(),
                                            static_cast<int64_t>(1),
                                            std::multiplies<int64_t>());

  out_indices = Empty<IntT, Context>(dev_ctx, {sparse_dim, nnz});
  out_values = Empty<T, Context>(dev_ctx, out_values_dims);

  // With no stored entries there is nothing to compute, and a zero-sized
  // launch configuration is invalid.
  if (nnz > 0) {
    const auto* x_indices_data = x_indices.data<IntT>();
    const auto* x_values_data = x_values.data<T>();
    auto* out_indices_data = out_indices.data<IntT>();
    auto* out_values_data = out_values.data<T>();

    // Only the x part of the configuration is used: SumCooCudaKernel must be
    // launched 1-D (see its zero-then-accumulate pattern).
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, nnz, nnz);
    SumCooCudaKernel<T, IntT><<<config.block_per_grid.x,
                                config.thread_per_block.x,
                                0,
                                dev_ctx.stream()>>>(x_indices_data,
                                                    x_values_data,
                                                    nnz,
                                                    dense_dim,
                                                    sparse_dim,
                                                    dim,
                                                    keep_dim,
                                                    out_indices_data,
                                                    out_values_data);
  }

  if (dtype != phi::DataType::UNDEFINED && dtype != x.dtype()) {
    out_values = phi::Cast<T, Context>(dev_ctx, out_values, dtype);
  }
  out->SetMember(out_indices, out_values, out_dims, x.coalesced());
}

// GPU entry point of sum_coo.
// Dispatches on the integral dtype of x's indices. An empty `axis` reduces
// every element (SumCooGPU0Kernel); otherwise axis[0] is reduced
// (SumCooGPU1Kernel).
template <typename T, typename Context>
void SumCooKernel(const Context& dev_ctx,
                  const SparseCooTensor& x,
                  const IntArray& axis,
                  DataType dtype,
                  bool keep_dim,
                  SparseCooTensor* out) {
  const bool reduce_all = (axis.size() == 0);
  PD_VISIT_BASE_INTEGRAL_TYPES(
      x.indices().dtype(), "SumCooGPUKernel", ([&] {
        if (reduce_all) {
          SumCooGPU0Kernel<T, data_t, Context>(
              dev_ctx, x, axis, dtype, keep_dim, out);
        } else {
          SumCooGPU1Kernel<T, data_t, Context>(
              dev_ctx, x, axis, dtype, keep_dim, out);
        }
      }));
}

// Sums every stored value of a CSR tensor (used when `axis` is empty).
//
// The result is a CSR tensor with a single non-zero at column 0:
// crows = [0, 1], cols = [0], values = Sum(x.values()) computed with `dtype`.
// Its dims are [1, 1, 1] for a 3-D input with keep_dim, otherwise [1, 1].
// `axis` is not read.
template <typename T, typename Context>
void SumCsr0Kernel(const Context& dev_ctx,
                   const SparseCsrTensor& x,
                   const IntArray& axis,
                   DataType dtype,
                   bool keep_dim,
                   SparseCsrTensor* out) {
  DDim out_dims;
  if (keep_dim && x.dims().size() == 3) {
    out_dims = make_ddim({1, 1, 1});
  } else {
    out_dims = make_ddim({1, 1});
  }

  DenseTensor out_crows = Empty<int64_t, Context>(dev_ctx, {2});  // [0, 1]
  DenseTensor out_cols = Empty<int64_t, Context>(dev_ctx, {1});   // [0]

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, 2, 1);
  SumAllCsrCudaKernel<<<config.block_per_grid.x,
                        config.thread_per_block.x,
                        0,
                        dev_ctx.stream()>>>(out_crows.data<int64_t>(),
                                            out_cols.data<int64_t>());

  // keep_dim = true so the summed values keep one entry, matching the single
  // stored position (assumes x.values() is 1-D, as CSR values presumably are).
  DenseTensor out_values = phi::Sum<T>(dev_ctx, x.values(), {}, dtype, true);
  out->SetMember(out_crows, out_cols, out_values, out_dims);
}

// Sums a 2-D or 3-D CSR tensor along its last axis.
//
// Every output row stores exactly one entry (column 0) holding that row's sum.
// out_crows therefore has the same length as x.crows(), and out_cols /
// out_values hold one entry per row.
//   * 2-D [M, N]    -> dims [M, 1].
//   * 3-D [B, M, N] -> dims [B, M, 1] with keep_dim, otherwise [B, M]. The
//     per-batch value offsets come from an inclusive cumulative sum of each
//     batch's last row pointer (the batch's nnz).
// The values are cast to `dtype` when it is set and differs from x's dtype.
// `axis` is not read here.
template <typename T, typename Context>
void SumCsr1Kernel(const Context& dev_ctx,
                   const SparseCsrTensor& x,
                   const IntArray& axis,
                   DataType dtype,
                   bool keep_dim,
                   SparseCsrTensor* out) {
  const auto dim0 = x.dims()[0];
  const auto dim1 = x.dims()[1];
  const auto& x_crows = x.crows();
  const auto* x_crows_ptr = x_crows.data<int64_t>();
  const auto* x_values_ptr = x.values().data<T>();

  DenseTensor out_crows = EmptyLike<int64_t, Context>(dev_ctx, x_crows);
  auto* out_crows_ptr = out_crows.data<int64_t>();
  DenseTensor out_cols;
  DenseTensor out_values;
  DDim out_dims;

  if (x.dims().size() != 2) {
    // Batched (3-D) input: one output entry per (batch, row).
    const int64_t total_rows = dim0 * dim1;
    out_cols = Empty<int64_t, Context>(dev_ctx, {total_rows});
    out_values = Empty<T, Context>(dev_ctx, {total_rows});
    out_dims = keep_dim ? make_ddim({dim0, dim1, 1}) : make_ddim({dim0, dim1});

    // View crows as [B, M + 1] and pick column M: the nnz of each batch.
    DenseTensor crows_2d =
        Reshape<int64_t, Context>(dev_ctx, x_crows, {dim0, dim1 + 1});
    DenseTensor last_col = Empty<int64_t, Context>(dev_ctx, {1});
    phi::funcs::SetConstant<Context, int64_t> fill;
    fill(dev_ctx, &last_col, dim1);
    DenseTensor per_batch_nnz = Empty<int64_t, Context>(dev_ctx, {dim0, 1});
    IndexSelectKernel<int64_t, Context>(
        dev_ctx, crows_2d, last_col, 1, &per_batch_nnz);

    // Inclusive running total along the batch axis.
    DenseTensor batch_nnz = Empty<int64_t, Context>(dev_ctx, {dim0, 1});
    CumsumKernel<int64_t, Context>(dev_ctx,
                                   per_batch_nnz,
                                   Scalar(0),
                                   /*flatten=*/false,
                                   /*exclusive=*/false,
                                   /*reverse=*/false,
                                   &batch_nnz);

    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dim0 * (dim1 + 1), 1);
    SumCsr3DCudaKernel<T><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(x_crows_ptr,
                                                x_values_ptr,
                                                dim0,
                                                dim1,
                                                batch_nnz.data<int64_t>(),
                                                out_crows_ptr,
                                                out_cols.data<int64_t>(),
                                                out_values.data<T>());
  } else {
    // Plain 2-D matrix: one output entry per row.
    out_cols = Empty<int64_t, Context>(dev_ctx, {dim0});
    out_values = Empty<T, Context>(dev_ctx, {dim0});
    out_dims = make_ddim({dim0, 1});

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, dim0 + 1, 1);
    SumCsr2DCudaKernel<T><<<config.block_per_grid.x,
                            config.thread_per_block.x,
                            0,
                            dev_ctx.stream()>>>(x_crows_ptr,
                                                x_values_ptr,
                                                dim0,
                                                out_crows_ptr,
                                                out_cols.data<int64_t>(),
                                                out_values.data<T>());
  }

  const bool needs_cast =
      dtype != phi::DataType::UNDEFINED && dtype != x.dtype();
  if (needs_cast) {
    out_values = phi::Cast<T, Context>(dev_ctx, out_values, dtype);
  }
  out->SetMember(out_crows, out_cols, out_values, out_dims);
}

// GPU entry point of sum_csr.
//
// An empty `axis` reduces every stored value (SumCsr0Kernel). Otherwise exactly
// one axis is accepted, and it must be the last dimension. It may be written
// as -1 or as rank - 1, and is reduced row-wise by SumCsr1Kernel.
// Raises Unimplemented for more than one axis or for any other axis.
template <typename T, typename Context>
void SumCsrKernel(const Context& dev_ctx,
                  const SparseCsrTensor& x,
                  const IntArray& axis,
                  DataType dtype,
                  bool keep_dim,
                  SparseCsrTensor* out) {
  if (axis.size() == 0) {
    SumCsr0Kernel<T, Context>(dev_ctx, x, axis, dtype, keep_dim, out);
    return;
  }

  // Only axis[0] is interpreted below, so reject extra axes instead of
  // silently ignoring them and returning a wrong result.
  PADDLE_ENFORCE_EQ(
      static_cast<int64_t>(axis.size()),
      static_cast<int64_t>(1),
      phi::errors::Unimplemented(
          "`axis` of SumCsrKernel only supports a single axis now, "
          "but received %d axes.",
          static_cast<int64_t>(axis.size())));

  // Normalize a negative axis so that -1 and rank - 1 are treated alike.
  const int64_t rank = x.dims().size();
  const int64_t reduce_axis = axis[0] < 0 ? axis[0] + rank : axis[0];
  PADDLE_ENFORCE_EQ(
      reduce_axis,
      rank - 1,
      phi::errors::Unimplemented(
          "`axis` of SumCsrKernel only supports None or -1 now. "
          "More axes will be supported in the future."));
  SumCsr1Kernel<T, Context>(dev_ctx, x, axis, dtype, keep_dim, out);
}

}  // namespace sparse
}  // namespace phi

// Registers phi::sparse::SumCooKernel as the GPU `sum_coo` kernel for all
// layouts, instantiated for float, double, int and int64_t values.
// The output dtype is set to UNDEFINED at registration time, presumably
// because it depends on the runtime `dtype` argument (the kernel casts the
// output values to it) — confirm against the kernel framework.
PD_REGISTER_KERNEL(sum_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SumCooKernel,
                   float,
                   double,
                   int,
                   int64_t) {
  kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
}

// Registers phi::sparse::SumCsrKernel as the GPU `sum_csr` kernel for all
// layouts, instantiated for float, double, int and int64_t values.
// As with sum_coo, the output dtype is left UNDEFINED here; presumably it is
// resolved from the runtime `dtype` argument — confirm against the framework.
PD_REGISTER_KERNEL(sum_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SumCsrKernel,
                   float,
                   double,
                   int,
                   int64_t) {
  kernel->OutputAt(0).SetDataType(paddle::DataType::UNDEFINED);
}