// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/sparse/unary_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h"

namespace phi {
namespace sparse {

// Permutes the COO indices: row i of the output indices is row perm[i] of
// the input indices. One thread handles one output index entry.
__global__ void TransposeCooCudaKernel(const int64_t *x_indices_data,
                                       const int *perm,
                                       const std::size_t n_dim,
                                       const int64_t x_nnz,
                                       int64_t *out_indices_data) {
  CUDA_KERNEL_LOOP_TYPE(index, x_nnz * n_dim, int64_t) {
    int64_t i = index / x_nnz;  // output dimension
    int64_t j = index % x_nnz;  // position within that dimension's row
    out_indices_data[index] = x_indices_data[j + perm[i] * x_nnz];
  }
}

// Transposes a 2-D CSR matrix with a counting sort over the column indices.
// Only the zeroing loop is parallel; the histogram, prefix-sum, and scatter
// phases are inherently sequential here and run on thread 0 of block 0.
template <typename T>
__global__ void TransposeCsr2DCudaKernel(const int64_t *x_crows_data,
                                         const int64_t *x_cols_data,
                                         const T *x_values_data,
                                         const int *perm,
                                         const int64_t *x_dims,
                                         const int64_t *out_dims,
                                         const int64_t x_nnz,
                                         int64_t *out_crows_data,
                                         int64_t *out_cols_data,
                                         T *out_values_data) {
  int64_t __index__ =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Zero the row-pointer buffer. Only block 0 participates so that the
  // __syncthreads() below is a sufficient barrier; a grid-strided loop here
  // would race with thread 0 of block 0, which starts the sequential phases
  // right after the block-level sync.
  if (blockIdx.x == 0) {
    for (int64_t i = threadIdx.x; i <= out_dims[0]; i += blockDim.x) {
      out_crows_data[i] = 0;
    }
  }
  __syncthreads();
  if (__index__ == 0) {
    // Histogram of the column indices, shifted by two so that after the
    // prefix sum below out_crows_data[c + 1] can serve as the insertion
    // cursor for column c. The count of the last column would land one slot
    // past the end of the buffer and is never read, so skip it instead of
    // writing out of bounds.
    for (int64_t i = 0; i < x_nnz; ++i) {
      int64_t j = x_cols_data[i];
      if (j + 2 <= out_dims[0]) {
        out_crows_data[j + 2]++;
      }
    }
    for (int64_t i = 0; i < out_dims[0]; i += 1) {
      out_crows_data[i + 1] += out_crows_data[i];
    }
    // compute out_cols_data and out_values_data by out_crows_data and x
    for (int i = 0; i < x_dims[0]; ++i) {
      int64_t start = x_crows_data[i];
      int64_t end = x_crows_data[i + 1];
      for (int64_t j = start; j < end; ++j) {
        int64_t x_cols_j = x_cols_data[j] + 1;
        int64_t jjj = out_crows_data[x_cols_j];
        out_cols_data[jjj] = i;
        out_values_data[jjj] = x_values_data[j];
        out_crows_data[x_cols_j]++;
      }
    }
  }
}
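// Worked example of the counting-sort scheme shared by the two CSR kernels
// above and below (illustrative values, not part of the original source).
// Transposing a 3-column matrix with nnz = 4 and x_cols = {0, 1, 1, 2}:
//   1. shifted histogram (counts stored at column + 2; the last column is
//      skipped because its count is never read): out_crows = {0, 0, 1, 2}
//   2. prefix sum over the leading entries:      out_crows = {0, 0, 1, 3}
//   3. scatter: out_crows[c + 1] is the insertion cursor for column c and
//      is bumped after each placement, ending as  out_crows = {0, 1, 3, 4},
//      which is exactly the row-pointer array of the transposed matrix.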
// Transposes one 3-D CSR tensor entirely on a single thread (the host side
// launches this kernel with <<<1, 1>>>). Handles the two base permutations
// {0, 2, 1} and {1, 0, 2}; all other 3-D permutations are decomposed into
// these two on the host.
template <typename T>
__global__ void TransposeCsr3DCudaKernel(const int64_t *x_crows_data,
                                         const int64_t *x_cols_data,
                                         const T *x_values_data,
                                         const int *perm,
                                         const int64_t *x_dims,
                                         const int64_t *out_dims,
                                         const std::size_t n_dim,
                                         const int64_t x_nnz,
                                         int64_t *out_crows_data,
                                         int64_t *out_cols_data,
                                         T *out_values_data) {
  int64_t __index__ =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (__index__ == 0) {
    int out_n_rows = out_dims[1];
    int x_n_rows = x_dims[1];
    for (int k = 0; k < out_dims[0]; ++k) {
      if (perm[0] == 0) {  // perm == {0, 2, 1}
        // Transpose each batch independently, with the same counting-sort
        // scheme as the 2-D kernel.
        // compute out_crows_data by x_cols_data
        for (int i = 0; i <= out_n_rows; ++i) {
          out_crows_data[i] = 0;
        }
        for (int i = 0; i < x_crows_data[x_n_rows]; ++i) {
          int64_t j = x_cols_data[i];
          // As in the 2-D kernel, the count of the last column is never
          // read; skip it instead of writing past the batch's crows region.
          if (j + 2 <= out_n_rows) {
            out_crows_data[j + 2]++;
          }
        }
        for (int i = 0; i < out_n_rows; ++i) {
          out_crows_data[i + 1] += out_crows_data[i];
        }
        // compute out_cols_data and out_values_data by out_crows_data and x
        for (int i = 0; i < x_n_rows; ++i) {
          int64_t start = x_crows_data[i];
          int64_t end = x_crows_data[i + 1];
          for (int64_t j = start; j < end; ++j) {
            int64_t x_cols_j = x_cols_data[j] + 1;
            int64_t jjj = out_crows_data[x_cols_j];
            out_cols_data[jjj] = i;
            out_values_data[jjj] = x_values_data[j];
            out_crows_data[x_cols_j]++;
          }
        }
        // advance the input pointers to the next batch
        x_cols_data += x_crows_data[x_n_rows];
        x_values_data += x_crows_data[x_n_rows];
        x_crows_data += x_n_rows + 1;
      } else if (perm[0] == 1 && perm[1] == 0) {  // perm == {1, 0, 2}
        // Batch k of the output gathers row k from every batch of the input.
        for (int i = 0; i < out_n_rows; ++i) {
          out_crows_data[i] = 0;
        }
        int x_cols_offset = 0;
        int out_cols_index = 0;
        for (int i = 0; i < x_dims[0]; ++i) {
          int x_crows_index = i * (x_n_rows + 1);
          int start = x_crows_data[x_crows_index + k];
          int end = x_crows_data[x_crows_index + 1 + k];
          out_crows_data[i + 1] = end - start;
          for (int j = start; j < end; ++j) {
            out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j];
            out_values_data[out_cols_index] = x_values_data[x_cols_offset + j];
            out_cols_index++;
          }
          x_cols_offset += x_crows_data[x_crows_index + x_n_rows];
        }
        for (int i = 1; i <= out_n_rows; ++i) {
          out_crows_data[i] += out_crows_data[i - 1];
        }
      }
      // advance the output pointers to the next batch
      out_cols_data += out_crows_data[out_n_rows];
      out_values_data += out_crows_data[out_n_rows];
      out_crows_data += out_n_rows + 1;
    }
  }
}

template <typename T, typename Context>
void TransposeCooKernel(const Context &dev_ctx,
                        const SparseCooTensor &x,
                        const std::vector<int> &perm,
                        SparseCooTensor *out) {
  // create out sparse tensor
  int64_t x_nnz = x.nnz();
  std::size_t n_dim = perm.size();
  DDim out_dims = x.dims().transpose(perm);
  DenseTensor out_indices = EmptyLike<int64_t, Context>(dev_ctx, x.indices());
  // the values are unchanged by a transpose; only the indices are permuted
  DenseTensor out_values(x.values());
  out->SetMember(out_indices, out_values, out_dims, x.coalesced());

  // compute values of indices
  const DenseTensor &x_indices = x.indices();
  const auto *x_indices_data = x_indices.data<int64_t>();
  auto *out_indices_data = out_indices.data<int64_t>();
  // copy the permutation to the device
  int *d_perm;
#ifdef PADDLE_WITH_HIP
  hipMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
  hipMemcpy(
      d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
#else
  cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
  cudaMemcpy(
      d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice);
#endif
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz * n_dim, 1);
  TransposeCooCudaKernel<<<config.block_per_grid.x,
                           config.thread_per_block.x,
                           0,
                           dev_ctx.stream()>>>(
      x_indices_data, d_perm, n_dim, x_nnz, out_indices_data);
  // release the temporary device buffer; cudaFree/hipFree implicitly
  // synchronize, so the buffer is not freed while the kernel still uses it
#ifdef PADDLE_WITH_HIP
  hipFree(d_perm);
#else
  cudaFree(d_perm);
#endif
}
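// How TransposeCsrKernel (below) handles the six 3-D permutations
// (illustrative summary, not part of the original source): only {0, 2, 1}
// and {1, 0, 2} are implemented on the device. Composing permutations as
// out_dims[i] = x_dims[p1[p2[i]]], the remaining cases are rewritten as two
// stages:
//   {1, 2, 0} = first {1, 0, 2}, then {0, 2, 1}
//   {2, 0, 1} = first {0, 2, 1}, then {1, 0, 2}
//   {2, 1, 0} = first {1, 0, 2}, then {2, 0, 1} (which itself expands into
//               {0, 2, 1} followed by {1, 0, 2}, i.e. three stages in total)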
template <typename T, typename Context>
void TransposeCsrKernel(const Context &dev_ctx,
                        const SparseCsrTensor &x,
                        const std::vector<int> &perm,
                        SparseCsrTensor *out) {
  std::size_t n_dim = perm.size();
  const DenseTensor &x_crows = x.crows();
  const DenseTensor &x_cols = x.cols();
  const DenseTensor &x_values = x.values();
  DenseTensor out_crows, out_cols, out_values;
  // the identity permutation: return a copy of x
  if (perm[0] == 0 && perm[1] == 1 && (n_dim == 2 || perm[2] == 2)) {
    out_crows = x_crows;
    out_cols = x_cols;
    out_values = x_values;
    out->SetMember(out_crows, out_cols, out_values, x.dims());
    return;
  }
  // create out sparse tensor
  DDim out_dims = x.dims().transpose(perm);
  if (n_dim == 2) {
    out_crows = Empty<int64_t, Context>(dev_ctx, {out_dims[0] + 1});
  } else {
    out_crows =
        Empty<int64_t, Context>(dev_ctx, {out_dims[0] * (out_dims[1] + 1)});
  }
  out_cols = EmptyLike<int64_t, Context>(dev_ctx, x.cols());
  out_values = EmptyLike<T, Context>(dev_ctx, x.values());
  out->SetMember(out_crows, out_cols, out_values, out_dims);
  // transpose by two stages
  if (perm[0] == 1 && perm[1] == 2) {  // perm == {1, 2, 0}
    SparseCsrTensor temp;
    TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
    TransposeCsrKernel<T, Context>(dev_ctx, temp, {0, 2, 1}, out);
    return;
  } else if (perm[0] == 2 && perm[1] == 0) {  // perm == {2, 0, 1}
    SparseCsrTensor temp;
    TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &temp);
    TransposeCsrKernel<T, Context>(dev_ctx, temp, {1, 0, 2}, out);
    return;
  } else if (perm[0] == 2 && perm[1] == 1) {  // perm == {2, 1, 0}
    SparseCsrTensor temp;
    TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0, 2}, &temp);
    TransposeCsrKernel<T, Context>(dev_ctx, temp, {2, 0, 1}, out);
    return;
  }
  // only {0, 1}, {0, 2, 1}, and {1, 0, 2} reach the device kernels below
  int64_t *out_crows_data = out_crows.data<int64_t>();
  int64_t *out_cols_data = out_cols.data<int64_t>();
  T *out_values_data = out_values.data<T>();
  const int64_t *x_crows_data = x_crows.data<int64_t>();
  const int64_t *x_cols_data = x_cols.data<int64_t>();
  const T *x_values_data = x_values.data<T>();
  // copy the permutation and the shapes to the device
  int *d_perm;
  int64_t *d_x_dims, *d_out_dims;
#ifdef PADDLE_WITH_HIP
  hipMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
  hipMemcpy(
      d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
  hipMalloc(reinterpret_cast<void **>(&d_x_dims),
            sizeof(int64_t) * x.dims().size());
  hipMemcpy(d_x_dims,
            x.dims().Get(),
            sizeof(int64_t) * x.dims().size(),
            hipMemcpyHostToDevice);
  hipMalloc(reinterpret_cast<void **>(&d_out_dims),
            sizeof(int64_t) * out_dims.size());
  hipMemcpy(d_out_dims,
            out_dims.Get(),
            sizeof(int64_t) * out_dims.size(),
            hipMemcpyHostToDevice);
#else
  cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
  cudaMemcpy(
      d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice);
  cudaMalloc(reinterpret_cast<void **>(&d_x_dims),
             sizeof(int64_t) * x.dims().size());
  cudaMemcpy(d_x_dims,
             x.dims().Get(),
             sizeof(int64_t) * x.dims().size(),
             cudaMemcpyHostToDevice);
  cudaMalloc(reinterpret_cast<void **>(&d_out_dims),
             sizeof(int64_t) * out_dims.size());
  cudaMemcpy(d_out_dims,
             out_dims.Get(),
             sizeof(int64_t) * out_dims.size(),
             cudaMemcpyHostToDevice);
#endif
  int64_t x_nnz = x.nnz();
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_dims[0], 1);
  if (perm.size() == 2) {
    TransposeCsr2DCudaKernel<T><<<config.block_per_grid.x,
                                  config.thread_per_block.x,
                                  0,
                                  dev_ctx.stream()>>>(x_crows_data,
                                                      x_cols_data,
                                                      x_values_data,
                                                      d_perm,
                                                      d_x_dims,
                                                      d_out_dims,
                                                      x_nnz,
                                                      out_crows_data,
                                                      out_cols_data,
                                                      out_values_data);
  } else {
    // the 3-D kernel is fully sequential, so a single thread is launched
    TransposeCsr3DCudaKernel<T><<<1, 1, 0, dev_ctx.stream()>>>(
        x_crows_data,
        x_cols_data,
        x_values_data,
        d_perm,
        d_x_dims,
        d_out_dims,
        perm.size(),
        x_nnz,
        out_crows_data,
        out_cols_data,
        out_values_data);
  }
  // release the temporary device buffers; cudaFree/hipFree implicitly
  // synchronize, so they are not freed while the kernel still uses them
#ifdef PADDLE_WITH_HIP
  hipFree(d_perm);
  hipFree(d_x_dims);
  hipFree(d_out_dims);
#else
  cudaFree(d_perm);
  cudaFree(d_x_dims);
  cudaFree(d_out_dims);
#endif
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(transpose_coo,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::TransposeCooKernel,
                   phi::dtype::float16,
                   float,
                   double,
                   int8_t,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}

PD_REGISTER_KERNEL(transpose_csr,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::TransposeCsrKernel,
                   phi::dtype::float16,
                   float,
                   double,
                   int8_t,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool) {}
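// Usage sketch (illustrative, not part of the original source): the two
// kernels are reached through the phi registry under the names
// "transpose_coo" and "transpose_csr"; a direct C++ invocation on a GPU
// context would look like
//
//   phi::SparseCsrTensor out;
//   phi::sparse::TransposeCsrKernel<float, phi::GPUContext>(
//       dev_ctx, x, /*perm=*/{1, 0}, &out);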