diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/mask_kernel.cc similarity index 99% rename from paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc rename to paddle/phi/kernels/sparse/cpu/mask_kernel.cc index cf2acd85573331adc954acb852c6d31890e80238..92c015101264c0fc259434a06ea8a59355ac381f 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/mask_kernel.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" +#include "paddle/phi/kernels/sparse/mask_kernel.h" #include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/core/ddim.h" diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu similarity index 72% rename from paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu rename to paddle/phi/kernels/sparse/gpu/mask_kernel.cu index 21d6850bdc4aa8b41a3e9e289777e5d260300cf2..39fa89c0379b7881d1d69a0f3a55883432f64f2f 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" - -#include +#include "paddle/phi/kernels/sparse/mask_kernel.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" @@ -24,6 +22,7 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h" @@ -72,11 +71,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data(), &h_sparse_offsets[0], sizeof(int64_t) * sparse_dim, -#ifdef PADDLE_WITH_HIP - hipMemcpyHostToDevice, -#else - cudaMemcpyHostToDevice, -#endif + gpuMemcpyHostToDevice, dev_ctx.stream()); DenseTensor out_indices = phi::EmptyLike(dev_ctx, indices); @@ -93,14 +88,15 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num * cols, 1); - MaskKernel<<>>( - x_ptr, - indices_ptr, - sparse_offsets.data(), - non_zero_num, - cols, - sparse_dim, - out_values_ptr); + MaskKernel + <<>>( + x_ptr, + indices_ptr, + sparse_offsets.data(), + non_zero_num, + cols, + sparse_dim, + out_values_ptr); out->SetMember(out_indices, out_values, dims, true); } @@ -121,19 +117,31 @@ void SparseMaskKernel(const Context& dev_ctx, })); } -template -__global__ void SparseMaskCopyKernel(const IntT* x_indexs, - const IntT* mask_indexs, - const IntT* bound_out, - const T* x_values, - const int64_t n, - const int64_t stride, - T* out_values) { +template +__global__ void MaskTable(const IntT* x_indexs, const int n, int* table) { + CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { + int index = x_indexs[i]; + table[index] = i == 0 ? -1 : i; + } +} + +template +__global__ void MaskCopy(const IntT* mask_indexs, + const int* table, + const int n, + const int stride, + const T* x_values, + T* out_values) { + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { - const IntT j = bound_out[i]; - if (j >= 0 && j < n && mask_indexs[i] == x_indexs[j]) { - for (int k = 0; k < stride; k++) { - out_values[i * stride + k] = x_values[j * stride + k]; + int j = table[mask_indexs[i]]; + if (j != 0) { + if (j == -1) j = 0; + for (int k = 0; k < stride; k += VecSize) { + LoadT vec_x; + phi::Load(x_values + j * stride + k, &vec_x); + phi::Store(vec_x, out_values + i * stride + k); } } } @@ -179,11 +187,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), sparse_offsets.data(), sizeof(IntT) * sparse_dim, -#ifdef PADDLE_WITH_HIP - hipMemcpyHostToDevice, -#else - cudaMemcpyHostToDevice, -#endif + gpuMemcpyHostToDevice, dev_ctx.stream()); // 3. flatten x indices and mask indices @@ -210,37 +214,54 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, mask_indexs.numel(), sparse_dim, mask_indexs_ptr); -// 4. call thrust::lower_bound -#ifdef PADDLE_WITH_HIP - thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()), -#else - thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()), -#endif - x_indexs_ptr, - x_indexs_ptr + x_indexs.numel(), - mask_indexs_ptr, - mask_indexs_ptr + mask_indexs.numel(), - bound_out_ptr); - // 5. copy value to out + int table_size = 1; + auto x_dims = x.dims(); + for (int i = 0; i < x_dims.size() - 1; i++) { + table_size *= x_dims[i]; + } + DenseTensor table = phi::Empty(dev_ctx, {table_size}); + phi::backends::gpu::GpuMemsetAsync( + table.data(), 0, table_size * sizeof(int), dev_ctx.stream()); + const int64_t stride = + x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; *out = phi::EmptyLike(dev_ctx, x.non_zero_elements()); phi::funcs::SetConstant set_zero; set_zero(dev_ctx, out, static_cast(0)); T* out_ptr = out->data(); - - const int64_t stride = - x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1]; - - SparseMaskCopyKernel<<>>(x_indexs_ptr, - mask_indexs_ptr, - bound_out_ptr, - x.non_zero_elements().data(), - mask_indexs.numel(), - stride, - out_ptr); + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1); + MaskTable<<>>( + x_indexs_ptr, x_indexs.numel(), table.data()); + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); + const int VecBytes = 16; + const int VecSize = VecBytes / sizeof(T); + if (stride % VecSize == 0) { + MaskCopy + <<>>(mask_indexs_ptr, + table.data(), + mask_indexs.numel(), + stride, + x.non_zero_elements().data(), + out_ptr); + } else { + MaskCopy<<>>(mask_indexs_ptr, + table.data(), + mask_indexs.numel(), + stride, + x.non_zero_elements().data(), + out_ptr); + } } template @@ -257,7 +278,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx, } // namespace sparse } // namespace phi -PD_REGISTER_KERNEL(sparse_mask, +PD_REGISTER_KERNEL(mask, GPU, ALL_LAYOUT, phi::sparse::SparseMaskKernel, @@ -272,7 +293,7 @@ PD_REGISTER_KERNEL(sparse_mask, kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } -PD_REGISTER_KERNEL(sparse_mask_helper, +PD_REGISTER_KERNEL(mask_helper, GPU, ALL_LAYOUT, phi::sparse::SparseMaskHelperKernel, diff --git a/paddle/phi/kernels/sparse/sparse_mask_kernel.h b/paddle/phi/kernels/sparse/mask_kernel.h similarity index 100% rename from paddle/phi/kernels/sparse/sparse_mask_kernel.h rename to paddle/phi/kernels/sparse/mask_kernel.h diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc index 69677be34b23194a2daa6181990cd347b3f0c6b7..9425c14b79b36d1d6db4867bd3b5f2abda0c24dc 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" namespace phi { namespace sparse { diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h index a00b9c275c2923a4f4d36cc0ed877d76c9bf54a1..7cf97c3f48ece8f03bb94bd7a2e447eea28e9bea 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h @@ -16,7 +16,7 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" -#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" +#include "paddle/phi/kernels/sparse/mask_kernel.h" namespace phi { namespace sparse {