/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"

#include <thrust/binary_search.h>

#include <vector>

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"

namespace phi {
namespace sparse {

/**
 * Gathers entries of a dense tensor at the positions given by COO indices.
 *
 * Work item i in [0, non_zero_num * cols) (iterated via CUDA_KERNEL_LOOP_TYPE)
 * handles column col_i of non-zero out_i. The flat row of x is
 *   sum_j indices_ptr[j * non_zero_num + out_i] * sparse_offsets[j]
 * and x is read as a row-major [rows, cols] matrix.
 *
 * @param x_ptr          dense input, read as rows of `cols` elements
 * @param indices_ptr    COO indices laid out as [sparse_dim, non_zero_num]
 * @param sparse_offsets device copy of the offsets from CalcOffsetsPerDim
 * @param non_zero_num   number of non-zeros to gather
 * @param cols           product of the dense (non-sparse) dims
 * @param sparse_dim     number of sparse dims
 * @param out_values_ptr output, non_zero_num rows of `cols` elements
 */
template <typename T, typename IntT>
__global__ void MaskKernel(const T* x_ptr,
                           const IntT* indices_ptr,
                           const int64_t* sparse_offsets,
                           const int64_t non_zero_num,
                           const int64_t cols,
                           const int sparse_dim,
                           T* out_values_ptr) {
  CUDA_KERNEL_LOOP_TYPE(i, non_zero_num * cols, int64_t) {
    int64_t out_i = i / cols;
    int64_t col_i = i - out_i * cols;
    int64_t index = 0;
    for (int j = 0; j < sparse_dim; j++) {
      index += indices_ptr[j * non_zero_num + out_i] * sparse_offsets[j];
    }
    out_values_ptr[out_i * cols + col_i] = x_ptr[index * cols + col_i];
  }
}

/**
 * sparse_mask for one (value type T, index type IntT) pair: out receives a
 * copy of mask's indices, and its values are x's entries at those indices.
 *
 * The per-dim offsets of the sparse dims are computed on the host, copied to
 * the device, and MaskKernel is launched; all work is enqueued on
 * dev_ctx.stream().
 *
 * @param dev_ctx GPU context whose stream is used for every operation
 * @param x       dense input; its dims must equal mask.dims()
 * @param mask    COO tensor whose indices select the entries of x
 * @param out     receives (copied mask indices, gathered values, x.dims())
 */
template <typename T, typename IntT>
void SparseMaskGPUKernel(const GPUContext& dev_ctx,
                         const DenseTensor& x,
                         const SparseCooTensor& mask,
                         SparseCooTensor* out) {
  const DDim& dims = x.dims();
  PADDLE_ENFORCE_EQ(
      x.dims(),
      mask.dims(),
      phi::errors::InvalidArgument(
          "the input x and mask must have the same shape"));
  const DenseTensor& indices = mask.non_zero_indices();
  const DenseTensor& values = mask.non_zero_elements();
  const int sparse_dim = mask.sparse_dim();

  DenseTensor sparse_offsets = phi::Empty(
      dev_ctx,
      DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW));
  std::vector<int64_t> h_sparse_offsets(sparse_dim);
  phi::funcs::sparse::CalcOffsetsPerDim(
      dims, sparse_dim, h_sparse_offsets.data());

  // .data() instead of &v[0]: indexing an empty vector is UB when
  // sparse_dim == 0.
  phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(),
                                     h_sparse_offsets.data(),
                                     sizeof(int64_t) * sparse_dim,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyHostToDevice,
#else
                                     cudaMemcpyHostToDevice,
#endif
                                     dev_ctx.stream());

  DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
  DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);

  // The output keeps its own (non-blocking) copy of the mask's indices.
  phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);

  const IntT* indices_ptr = indices.data<IntT>();
  T* out_values_ptr = out_values.data<T>();
  const T* x_ptr = x.data<T>();
  const int64_t non_zero_num = mask.nnz();
  auto dims_2d = flatten_to_2d(dims, sparse_dim);
  // Product of the dense dims; kept 64-bit so large tensors do not truncate.
  const int64_t cols = dims_2d[1];

  const int64_t total = non_zero_num * cols;
  // Skip the launch when there is nothing to gather, so no zero-sized
  // launch configuration is requested.
  if (total > 0) {
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, total, 1);
    MaskKernel<T, IntT><<<config.block_per_grid,
                          config.thread_per_block,
                          0,
                          dev_ctx.stream()>>>(x_ptr,
                                              indices_ptr,
                                              sparse_offsets.data<int64_t>(),
                                              non_zero_num,
                                              cols,
                                              sparse_dim,
                                              out_values_ptr);
  }

  // NOTE(review): the trailing `true` presumably marks the result as
  // coalesced — confirm against SparseCooTensor::SetMember.
  out->SetMember(out_indices, out_values, dims, true);
}

/**
 * @brief Filter the DenseTensor x by mask.non_zero_indices() and output a
 * SparseCooTensor holding x's values at those indices (its indices are a copy
 * of the mask's indices).
 * x and mask must have the same shape.
**/
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const SparseCooTensor& mask,
                      SparseCooTensor* out) {
  // Instantiate SparseMaskGPUKernel for the integer dtype of mask's indices.
  PD_VISIT_INTEGRAL_TYPES(
      mask.non_zero_indices().dtype(), "SparseMaskGPUKernel", ([&] {
        SparseMaskGPUKernel<T, data_t>(dev_ctx, x, mask, out);
      }));
}

/**
 * For each lookup i in [0, n), copies value row bound_out[i] of x into row i
 * of out_values when the flattened index stored there equals mask_indexs[i].
 * Rows without a match are left untouched.
 *
 * @param x_indexs    flattened indices of x, length x_nnz
 * @param mask_indexs flattened indices to look up, length n
 * @param bound_out   lower_bound position of each mask index in x_indexs
 * @param x_values    values of x, x_nnz rows of `stride` elements
 * @param n           number of lookups (mask non-zeros)
 * @param x_nnz       number of non-zeros of x; valid positions are [0, x_nnz)
 * @param stride      number of elements per value row
 * @param out_values  output, at least n rows of `stride` elements
 */
template <typename T, typename IntT>
__global__ void SparseMaskCopyKernel(const IntT* x_indexs,
                                     const IntT* mask_indexs,
                                     const IntT* bound_out,
                                     const T* x_values,
                                     const int64_t n,
                                     const int64_t x_nnz,
                                     const int64_t stride,
                                     T* out_values) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    const IntT j = bound_out[i];
    // lower_bound returns x_nnz for keys past the end of x_indexs, so j must
    // be bounded by x's length (not the lookup count) before x_indexs[j].
    if (j >= 0 && j < x_nnz && mask_indexs[i] == x_indexs[j]) {
      for (int64_t k = 0; k < stride; k++) {
        out_values[i * stride + k] = x_values[j * stride + k];
      }
    }
  }
}

/**
 * For each column of mask_indices, finds the matching non-zero of x and copies
 * its value row into the corresponding row of out; rows with no match are 0.
 *
 * Both index sets are flattened to one integer per non-zero
 * (FlattenIndicesKernel with offsets from CalcOffsetsPerDim), then
 * thrust::lower_bound locates each mask index among x's flattened indices.
 * All work is enqueued on dev_ctx.stream().
 * NOTE(review): lower_bound requires x's flattened indices to be sorted
 * ascending (presumably x is coalesced) — confirm callers guarantee this.
 *
 * @param dev_ctx      GPU context
 * @param x            COO tensor providing the values
 * @param mask_indices 2-D index tensor; dims()[1] lookups, at most x.nnz()
 * @param out          set to a tensor shaped like x.non_zero_elements()
 */
template <typename T, typename IntT>
void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
                               const SparseCooTensor& x,
                               const DenseTensor& mask_indices,
                               DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      mask_indices.dims().size(),
      2,
      phi::errors::InvalidArgument("the mask_indices must be 2-D tensor"));

  const int32_t sparse_dim = x.sparse_dim();
  const int64_t x_nnz = x.nnz();
  const int64_t mask_nnz = mask_indices.dims()[1];
  // One output row is written per lookup, but out is shaped like x's values,
  // so more lookups than x non-zeros would write past its end.
  PADDLE_ENFORCE_LE(
      mask_nnz,
      x_nnz,
      phi::errors::InvalidArgument(
          "the number of mask_indices (%d) must not exceed the nnz of x (%d)",
          mask_nnz,
          x_nnz));

  auto indices_dtype = paddle::experimental::CppTypeToDataType<IntT>::Type();

  // Elements per value row: product of all dense dims (matching the `cols`
  // computed by SparseMaskGPUKernel), or 1 when x is fully sparse.
  const int64_t stride = x.dims().size() == sparse_dim
                             ? 1
                             : flatten_to_2d(x.dims(), sparse_dim)[1];

  // Output rows line up with the lookups; unmatched rows stay zero.
  *out = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
  phi::funcs::SetConstant<GPUContext, T> set_zero;
  set_zero(dev_ctx, out, static_cast<T>(0));
  if (mask_nnz == 0) {
    // Nothing to look up; avoid launching work over empty ranges.
    return;
  }

  // Each meta is built fresh instead of reusing a moved-from object.
  DenseTensor x_indexs = phi::Empty(
      dev_ctx, DenseTensorMeta(indices_dtype, {x_nnz}, DataLayout::NCHW));
  DenseTensor mask_indexs = phi::Empty(
      dev_ctx, DenseTensorMeta(indices_dtype, {mask_nnz}, DataLayout::NCHW));
  DenseTensor bound_out = phi::Empty(
      dev_ctx, DenseTensorMeta(indices_dtype, {mask_nnz}, DataLayout::NCHW));
  DenseTensor d_sparse_offsets = phi::Empty(
      dev_ctx, DenseTensorMeta(indices_dtype, {sparse_dim}, DataLayout::NCHW));
  IntT* x_indexs_ptr = x_indexs.data<IntT>();
  IntT* mask_indexs_ptr = mask_indexs.data<IntT>();
  IntT* bound_out_ptr = bound_out.data<IntT>();

  // 1. calc the offsets of per dim
  std::vector<IntT> sparse_offsets(sparse_dim);
  phi::funcs::sparse::CalcOffsetsPerDim(
      x.dims(), sparse_dim, sparse_offsets.data());
  // 2. copy sparse_offsets to device
  phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data<IntT>(),
                                     sparse_offsets.data(),
                                     sizeof(IntT) * sparse_dim,
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyHostToDevice,
#else
                                     cudaMemcpyHostToDevice,
#endif
                                     dev_ctx.stream());

  // 3. flatten x indices and mask indices
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      x.non_zero_indices().data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      x_nnz,
      sparse_dim,
      x_indexs_ptr);

  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_nnz, 1);
  phi::funcs::sparse::FlattenIndicesKernel<<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
      mask_indices.data<IntT>(),
      d_sparse_offsets.data<IntT>(),
      mask_nnz,
      sparse_dim,
      mask_indexs_ptr);

  // 4. position of each mask index within x's flattened indices
#ifdef PADDLE_WITH_HIP
  thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                      x_indexs_ptr,
                      x_indexs_ptr + x_nnz,
                      mask_indexs_ptr,
                      mask_indexs_ptr + mask_nnz,
                      bound_out_ptr);

  // 5. copy matching value rows to out (config is still sized for mask_nnz)
  T* out_ptr = out->data<T>();
  SparseMaskCopyKernel<<<config.block_per_grid,
                         config.thread_per_block,
                         0,
                         dev_ctx.stream()>>>(x_indexs_ptr,
                                             mask_indexs_ptr,
                                             bound_out_ptr,
                                             x.non_zero_elements().data<T>(),
                                             mask_nnz,
                                             x_nnz,
                                             stride,
                                             out_ptr);
}

/**
 * Dispatches SparseMaskHelperGPUKernel on the integer dtype of x's indices.
 */
template <typename T, typename Context>
void SparseMaskHelperKernel(const Context& dev_ctx,
                            const SparseCooTensor& x,
                            const DenseTensor& mask_indices,
                            DenseTensor* out) {
  PD_VISIT_INTEGRAL_TYPES(
      x.non_zero_indices().dtype(), "SparseMaskHelperGPUKernel", ([&] {
        SparseMaskHelperGPUKernel<T, data_t>(dev_ctx, x, mask_indices, out);
      }));
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(sparse_mask,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseMaskKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}

PD_REGISTER_KERNEL(sparse_mask_helper,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::SparseMaskHelperKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}