From e52ffb704e835aae2e52dbb6f266d532d467719e Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Fri, 18 Mar 2022 19:03:03 +0800
Subject: [PATCH] Add Sparse OP Maxpool (#40569)

sparse maxpool; kernel_registry support sparse tensor
---
 paddle/phi/core/kernel_registry.h             |  32 ++
 paddle/phi/kernels/funcs/pooling.h            |   2 +-
 paddle/phi/kernels/funcs/sparse/convolution.h |  20 +
 .../sparse/cpu/sparse_pool_grad_kernel.cc     |  73 ++++
 .../kernels/sparse/cpu/sparse_pool_kernel.cc  | 108 +++++
 .../sparse/gpu/sparse_pool_grad_kernel.cu     | 120 ++++++
 .../kernels/sparse/gpu/sparse_pool_kernel.cu  | 140 +++++++
 .../kernels/sparse/sparse_pool_grad_kernel.h  |  49 +++
 .../phi/kernels/sparse/sparse_pool_kernel.h   |  53 +++
 paddle/phi/tests/kernels/CMakeLists.txt       |   1 +
 .../tests/kernels/test_sparse_pool_dev_api.cc | 391 ++++++++++++++++++
 11 files changed, 988 insertions(+), 1 deletion(-)
 create mode 100644 paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
 create mode 100644 paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
 create mode 100644 paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
 create mode 100644 paddle/phi/kernels/sparse/sparse_pool_kernel.h
 create mode 100644 paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc

diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h
index d9ed68593cd..c3356eadcbd 100644
--- a/paddle/phi/core/kernel_registry.h
+++ b/paddle/phi/core/kernel_registry.h
@@ -98,6 +98,28 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
+      } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
+      } else if (arg_type == std::type_index(typeid(
+                     paddle::optional<const SparseCooTensor&>))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
+      } else if (arg_type == std::type_index(typeid(const SparseCsrTensor&))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
+      } else if (arg_type == std::type_index(typeid(
+                     paddle::optional<const SparseCsrTensor&>))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
         args_def->AppendOutput(default_key.backend(),
                                default_tensor_layout,
@@ -114,6 +136,16 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                                default_tensor_layout,
                                default_key.dtype(),
                                arg_type);
+      } else if (arg_type == std::type_index(typeid(SparseCooTensor*))) {
+        args_def->AppendOutput(default_key.backend(),
+                               default_tensor_layout,
+                               default_key.dtype(),
+                               arg_type);
+      } else if (arg_type == std::type_index(typeid(SparseCsrTensor*))) {
+        args_def->AppendOutput(default_key.backend(),
+                               default_tensor_layout,
+                               default_key.dtype(),
+                               arg_type);
       } else {
         // Attribute deal with
         // TODO(chenweihang): now here allow any types of attribute, maybe
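The registry change above only extends the argument classifier: parameter types compared by std::type_index now include the sparse tensor classes, so sparse kernels can be registered through the same PD_REGISTER_KERNEL machinery as dense ones. A minimal standalone sketch of that dispatch mechanism, using stand-in structs rather than the real phi classes:

#include <iostream>
#include <typeindex>
#include <typeinfo>

// Stand-ins for the real phi tensor types; illustration only.
struct SparseCooTensor {};
struct DenseTensor {};

// Mirrors the parser's idea: compare a parameter's std::type_index against
// known tensor categories to decide whether it is an input or an output.
const char* Classify(std::type_index arg_type) {
  if (arg_type == std::type_index(typeid(const SparseCooTensor&)))
    return "sparse input";
  if (arg_type == std::type_index(typeid(SparseCooTensor*)))
    return "sparse output";
  if (arg_type == std::type_index(typeid(const DenseTensor&)))
    return "dense input";
  return "attribute";
}

int main() {
  std::cout << Classify(typeid(const SparseCooTensor&)) << "\n";  // sparse input
  std::cout << Classify(typeid(SparseCooTensor*)) << "\n";        // sparse output
}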
diff --git a/paddle/phi/kernels/funcs/pooling.h b/paddle/phi/kernels/funcs/pooling.h
index 19c6d52c4c9..fa285dc69d1 100644
--- a/paddle/phi/kernels/funcs/pooling.h
+++ b/paddle/phi/kernels/funcs/pooling.h
@@ -43,7 +43,7 @@ template <class T>
 class MaxPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
-  DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
+  HOSTDEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
   DEVICE inline void finalize(const T& pool_field, T* y) {}
 };
diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h
index d82d793e534..19f1f3d3cd2 100644
--- a/paddle/phi/kernels/funcs/sparse/convolution.h
+++ b/paddle/phi/kernels/funcs/sparse/convolution.h
@@ -165,6 +165,26 @@ inline void SubmPreProcess(const Context& dev_ctx,
                            x_grad_ptr);
 }
 
+inline const std::vector<int> PoolResetKernel(
+    const std::vector<int>& kernel_sizes,
+    const int in_channels,
+    const int out_channels) {
+  std::vector<int> res(kernel_sizes);
+  res.resize(5);
+  res[3] = in_channels;
+  res[4] = out_channels;
+  return res;
+}
+
+inline void PrefixSum(const int* counter, int* offsets, const int n) {
+  int offset = 0;
+  for (int i = 0; i < n; i++) {
+    offsets[i] = offset;
+    offset += counter[i];
+  }
+  offsets[n] = offset;
+}
+
 }  // namespace sparse
 }  // namespace funcs
 }  // namespace phi
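PoolResetKernel widens the user's 3-element kernel size vector into the 5-element (D, H, W, C, OC) form the shared convolution helpers expect, and PrefixSum turns per-kernel-position hit counts into segment offsets (an exclusive scan, with the total in the last slot). A standalone demo of the offsets computation, using the same loop as the helper above:

#include <cstdio>
#include <vector>

// Exclusive prefix sum over per-kernel-position counts, as in
// phi::funcs::sparse::PrefixSum: offsets[i] is where position i's
// rulebook segment starts, offsets[n] is the total rulebook length.
void PrefixSum(const int* counter, int* offsets, const int n) {
  int offset = 0;
  for (int i = 0; i < n; i++) {
    offsets[i] = offset;
    offset += counter[i];
  }
  offsets[n] = offset;
}

int main() {
  std::vector<int> counter = {2, 0, 3, 1};  // hits per kernel position
  std::vector<int> offsets(counter.size() + 1);
  PrefixSum(counter.data(), offsets.data(), counter.size());
  for (int v : offsets) std::printf("%d ", v);  // prints: 0 2 2 5 6
  std::printf("\n");
}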
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
new file mode 100644
index 00000000000..3010d480b55
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
@@ -0,0 +1,73 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void MaxPoolGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& rulebook,
                       const SparseCooTensor& out,
                       const DenseTensor& out_grad,
                       const std::vector<int>& kernel_sizes,
                       DenseTensor* x_grad) {
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
  const int channels = x.dims()[4];
  int rulebook_len = rulebook.dims()[1];
  const int* rulebook_ptr = rulebook.data<int>();
  std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0);
  for (int i = 0; i < rulebook_len; i++) {
    counter[rulebook_ptr[i]] += 1;
  }
  phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);

  const T* in_features_ptr = x.non_zero_elements().data<T>();
  const T* out_features_ptr = out.non_zero_elements().data<T>();
  const T* out_grad_ptr = out_grad.data<T>();
  T* x_grad_ptr = x_grad->data<T>();
  memset(x_grad_ptr, 0, sizeof(T) * x_grad->numel());

  phi::funcs::MaxPoolGrad<T> grad_functor;
  for (int i = 0; i < kernel_size; i++) {
    for (int j = 0; j < counter[i]; j++) {
      int in_i = rulebook_ptr[rulebook_len + offsets[i] + j];
      int out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j];
      for (int c = 0; c < channels; c++) {
        grad_functor.compute(in_features_ptr[in_i * channels + c],
                             out_features_ptr[out_i * channels + c],
                             out_grad_ptr[out_i * channels + c],
                             1,
                             &x_grad_ptr[in_i * channels + c]);
      }
    }
  }
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(sparse_maxpool_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::MaxPoolGradKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
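The rulebook consumed here is a 3 x rulebook_len table: row 0 holds each rule's kernel-position id, row 1 the index of the contributing input non-zero, and row 2 the index of the produced output non-zero. MaxPoolGrad routes each output gradient back only to the input that equals the output maximum. A plain-array sketch of that backward loop (helper name and harness are illustrative, not part of the patch):

#include <cstdio>
#include <cstring>

// Standalone sketch of the CPU backward pass above. Gradient flows only to
// inputs that match the output maximum, mirroring MaxPoolGrad::compute.
void MaxPoolGradRef(const float* in, const float* out, const float* out_grad,
                    const int* rulebook, int len, int channels,
                    float* in_grad, int in_nnz) {
  std::memset(in_grad, 0, sizeof(float) * in_nnz * channels);
  for (int r = 0; r < len; r++) {
    int in_i = rulebook[len + r];       // row 1: input non-zero index
    int out_i = rulebook[2 * len + r];  // row 2: output non-zero index
    for (int c = 0; c < channels; c++) {
      if (in[in_i * channels + c] == out[out_i * channels + c]) {
        in_grad[in_i * channels + c] += out_grad[out_i * channels + c];
      }
    }
  }
}

int main() {
  // Two rules feed output 0; only input 1 (value 2) equals the max.
  const float in[] = {1, 2}, out[] = {2}, out_grad[] = {5};
  const int rulebook[] = {0, 1,   // row 0: kernel positions
                          0, 1,   // row 1: input indices
                          0, 0};  // row 2: output indices
  float in_grad[2];
  MaxPoolGradRef(in, out, out_grad, rulebook, 2, 1, in_grad, 2);
  std::printf("%g %g\n", in_grad[0], in_grad[1]);  // prints: 0 5
}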
diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
new file mode 100644
index 00000000000..86971242df5
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
@@ -0,0 +1,108 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"

namespace phi {
namespace sparse {

/**
 * x: (N, D, H, W, C)
 * kernel: (D, H, W, C, OC)
 * out: (N, D, H, W, OC)
**/
template <typename T, typename Context>
void MaxPoolKernel(const Context& dev_ctx,
                   const SparseCooTensor& x,
                   const std::vector<int>& kernel_sizes,
                   const std::vector<int>& paddings,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   SparseCooTensor* out,
                   DenseTensor* rulebook) {
  const auto& x_dims = x.dims();
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
  const std::vector<int>& real_kernel_sizes =
      phi::funcs::sparse::PoolResetKernel(kernel_sizes, x_dims[4], x_dims[4]);
  DDim out_dims = {1, 1, 1, 1, 1};
  phi::funcs::sparse::GetOutShape(
      x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
  const int in_channels = real_kernel_sizes[3];

  DenseTensorMeta counter_meta(
      DataType::INT32, {kernel_size}, DataLayout::NCHW);
  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));

  const T* in_features_ptr = x.non_zero_elements().data<T>();
  // 1. product rule book
  ProductRuleBook<T, Context>(dev_ctx,
                              x,
                              real_kernel_sizes,
                              paddings,
                              dilations,
                              strides,
                              out_dims,
                              false,
                              rulebook,
                              &counter_per_kernel);

  UpdateRulebookAndOutIndex<T>(
      dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out);

  int rulebook_len = rulebook->dims()[1];
  const int* rulebook_ptr = rulebook->data<int>();
  const int* counter_ptr = counter_per_kernel.data<int>();

  std::vector<int> offsets(kernel_size + 1);
  phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
  std::vector<bool> out_flags(out->nnz(), false);

  // 2. max pool
  T* out_features_ptr = out->mutable_non_zero_elements()->data<T>();
  phi::funcs::MaxPool<T> max_pool_functor;
  for (int i = 0; i < kernel_size; i++) {
    for (int j = 0; j < counter_ptr[i]; j++) {
      int in_i = rulebook_ptr[rulebook_len + offsets[i] + j];
      int out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j];
      if (!out_flags[out_i]) {
        out_flags[out_i] = true;
        memcpy(&out_features_ptr[out_i * in_channels],
               &in_features_ptr[in_i * in_channels],
               in_channels * sizeof(T));
      } else {
        for (int c = 0; c < in_channels; c++) {
          max_pool_functor.compute(in_features_ptr[in_i * in_channels + c],
                                   &out_features_ptr[out_i * in_channels + c]);
        }
      }
    }
  }
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(sparse_maxpool,
                   CPU,
                   ALL_LAYOUT,
                   phi::sparse::MaxPoolKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
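The forward pass avoids initializing output rows to -FLT_MAX: the first rule that touches an output row simply copies the input row, and later rules fold in with max. A plain-array sketch of that strategy (helper name and harness are illustrative, not part of the patch):

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

// Standalone sketch of the CPU forward pass above: first writer copies,
// later writers take the elementwise max.
void MaxPoolRef(const float* in, const int* rulebook, int len, int channels,
                float* out, std::vector<bool>* out_flags) {
  for (int r = 0; r < len; r++) {
    int in_i = rulebook[len + r];
    int out_i = rulebook[2 * len + r];
    if (!(*out_flags)[out_i]) {
      (*out_flags)[out_i] = true;
      std::memcpy(&out[out_i * channels], &in[in_i * channels],
                  channels * sizeof(float));
    } else {
      for (int c = 0; c < channels; c++) {
        out[out_i * channels + c] =
            std::max(out[out_i * channels + c], in[in_i * channels + c]);
      }
    }
  }
}

int main() {
  const float in[] = {1, 2};
  const int rulebook[] = {0, 1, 0, 1, 0, 0};  // both rules hit output 0
  float out[1] = {0};
  std::vector<bool> flags(1, false);
  MaxPoolRef(in, rulebook, 2, 1, out, &flags);
  std::printf("%g\n", out[0]);  // prints: 2
}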
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
new file mode 100644
index 00000000000..1048dd1be0c
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
@@ -0,0 +1,120 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"

#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"

namespace phi {
namespace sparse {

template <typename T>
__global__ void MaxPoolGradCudaKernel(const T* in_features_ptr,
                                      const T* out_features_ptr,
                                      const T* out_grad_ptr,
                                      const int* rulebook_ptr,
                                      const int n,
                                      const int rulebook_len,
                                      const int channels,
                                      T* x_grad_ptr) {
  phi::funcs::MaxPoolGrad<T> grad_functor;
  CUDA_KERNEL_LOOP_TYPE(i, n * channels, int64_t) {
    int real_i = i / channels;
    int c = i - real_i * channels;
    int in_i = rulebook_ptr[real_i];
    int out_i = rulebook_ptr[real_i + rulebook_len];
    grad_functor.compute(in_features_ptr[in_i * channels + c],
                         out_features_ptr[out_i * channels + c],
                         out_grad_ptr[out_i * channels + c],
                         1,
                         &x_grad_ptr[in_i * channels + c]);
  }
}

template <typename T, typename Context>
void MaxPoolGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& rulebook,
                       const SparseCooTensor& out,
                       const DenseTensor& out_grad,
                       const std::vector<int>& kernel_sizes,
                       DenseTensor* x_grad) {
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
  const int in_channels = x.dims()[4];
  int rulebook_len = rulebook.dims()[1];
  const int* rulebook_ptr = rulebook.data<int>();
  std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0),
      h_counter(rulebook_len, 0);
  phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
                                     rulebook_ptr,
                                     rulebook_len * sizeof(int),
#ifdef PADDLE_WITH_HIP
                                     hipMemcpyDeviceToHost,
#else
                                     cudaMemcpyDeviceToHost,
#endif
                                     dev_ctx.stream());
  dev_ctx.Wait();
  for (int i = 0; i < rulebook_len; i++) {
    counter[h_counter[i]] += 1;
  }
  phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);

  const T* in_features_ptr = x.non_zero_elements().data<T>();
  const T* out_features_ptr = out.non_zero_elements().data<T>();
  const T* out_grad_ptr = out_grad.data<T>();
  T* x_grad_ptr = x_grad->data<T>();
  phi::funcs::SetConstant<Context, T> set_zero;
  set_zero(dev_ctx, x_grad, static_cast<T>(0.0f));

  for (int i = 0; i < kernel_size; i++) {
    if (counter[i] <= 0) {
      continue;
    }

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, counter[i] * in_channels, 1);
    MaxPoolGradCudaKernel<T><<<config.block_per_grid.x,
                               config.thread_per_block.x,
                               0,
                               dev_ctx.stream()>>>(
        in_features_ptr,
        out_features_ptr,
        out_grad_ptr,
        rulebook_ptr + offsets[i] + rulebook_len,
        counter[i],
        rulebook_len,
        in_channels,
        x_grad_ptr);
  }
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(sparse_maxpool_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::MaxPoolGradKernel,
                   float,
                   double) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
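Both GPU kernels launch one grid per non-empty kernel position. Within a single position, each output voxel has at most one contributing input (the input at that fixed geometric offset), so every (rule, channel) pair writes a distinct output element and the non-atomic MaxPool/MaxPoolGrad functors stay race-free. A host-side sketch of how the copied-back first rulebook row could be segmented, assuming rules arrive grouped by kernel position as the rulebook builder produces them (illustrative names, not part of the patch):

#include <cstdio>
#include <vector>

// One launch per non-empty segment [begin, begin + size) of the rulebook.
struct Segment { int begin, size; };

std::vector<Segment> BuildSegments(const std::vector<int>& kernel_pos_row,
                                   int kernel_size) {
  std::vector<int> counter(kernel_size, 0);
  for (int p : kernel_pos_row) counter[p]++;
  std::vector<Segment> segs;
  int offset = 0;
  for (int i = 0; i < kernel_size; i++) {
    if (counter[i] > 0) segs.push_back({offset, counter[i]});
    offset += counter[i];
  }
  return segs;
}

int main() {
  std::vector<int> row = {0, 0, 2, 2, 2, 3};  // grouped kernel-position ids
  for (const Segment& s : BuildSegments(row, 4))
    std::printf("begin=%d size=%d\n", s.begin, s.size);
  // prints: begin=0 size=2, begin=2 size=3, begin=5 size=1
}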
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
new file mode 100644
index 00000000000..0f6a0d13b1d
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
@@ -0,0 +1,140 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"

namespace phi {
namespace sparse {

template <typename T>
__global__ void MaxPoolCudaKernel(const T* in_features_ptr,
                                  const int* rulebook_ptr,
                                  const int n,
                                  const int rulebook_len,
                                  const int channels,
                                  T* out_features_ptr) {
  phi::funcs::MaxPool<T> max_pool_functor;
  CUDA_KERNEL_LOOP_TYPE(i, n * channels, int64_t) {
    int real_i = i / channels;
    int channel_i = i - real_i * channels;
    int in_i = rulebook_ptr[real_i];
    int out_i = rulebook_ptr[real_i + rulebook_len];
    max_pool_functor.compute(in_features_ptr[in_i * channels + channel_i],
                             &out_features_ptr[out_i * channels + channel_i]);
  }
}

/**
 * x: (N, D, H, W, C)
 * kernel: (D, H, W, C, OC)
 * out: (N, D, H, W, OC)
**/
template <typename T, typename Context>
void MaxPoolKernel(const Context& dev_ctx,
                   const SparseCooTensor& x,
                   const std::vector<int>& kernel_sizes,
                   const std::vector<int>& paddings,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   SparseCooTensor* out,
                   DenseTensor* rulebook) {
  const auto& x_dims = x.dims();
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
  const std::vector<int>& real_kernel_sizes =
      phi::funcs::sparse::PoolResetKernel(kernel_sizes, x_dims[4], x_dims[4]);
  DDim out_dims = {1, 1, 1, 1, 1};
  phi::funcs::sparse::GetOutShape(
      x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
  const int in_channels = real_kernel_sizes[3];

  std::vector<int> offsets(kernel_size + 1), counter(kernel_size);
  DenseTensorMeta counter_meta(
      DataType::INT32, {kernel_size}, DataLayout::NCHW);
  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
  DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
  DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW);
  DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
  DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta));
  DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));

  // 1. product rulebook
  int rulebook_len = ProductRuleBook<T, Context>(dev_ctx,
                                                 x,
                                                 real_kernel_sizes,
                                                 paddings,
                                                 dilations,
                                                 strides,
                                                 out_dims,
                                                 false,
                                                 rulebook,
                                                 &counter_per_kernel,
                                                 &offsets_per_kernel,
                                                 &out_index,
                                                 &unique_key,
                                                 &unique_value,
                                                 out,
                                                 &counter,
                                                 &offsets);

  const int* rulebook_ptr = rulebook->data<int>();

  T* out_features_ptr = out->mutable_non_zero_elements()->data<T>();
  const T* in_features_ptr = x.non_zero_elements().data<T>();
// 2. max pool
#ifdef PADDLE_WITH_HIP
  thrust::fill(thrust::hip::par.on(dev_ctx.stream()),
#else
  thrust::fill(thrust::cuda::par.on(dev_ctx.stream()),
#endif
               out_features_ptr,
               out_features_ptr + out->non_zero_elements().numel(),
               static_cast<T>(-FLT_MAX));
  // TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster
  for (int i = 0; i < kernel_size; i++) {
    if (counter[i] <= 0) {
      continue;
    }

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, counter[i] * in_channels, 1);
    MaxPoolCudaKernel<T><<<config.block_per_grid.x,
                           config.thread_per_block.x,
                           0,
                           dev_ctx.stream()>>>(
        in_features_ptr,
        rulebook_ptr + offsets[i] + rulebook_len,
        counter[i],
        rulebook_len,
        in_channels,
        out_features_ptr);
  }
}

}  // namespace sparse
}  // namespace phi

PD_REGISTER_KERNEL(sparse_maxpool,
                   GPU,
                   ALL_LAYOUT,
                   phi::sparse::MaxPoolKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
diff --git a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
new file mode 100644
index 00000000000..572ade76281
--- /dev/null
+++ b/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
@@ -0,0 +1,49 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void MaxPoolGradKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       const DenseTensor& rulebook,
                       const SparseCooTensor& out,
                       const DenseTensor& out_grad,
                       const std::vector<int>& kernel_sizes,
                       DenseTensor* x_grad);

template <typename T, typename Context>
DenseTensor MaxPoolGrad(const Context& dev_ctx,
                        const SparseCooTensor& x,
                        const DenseTensor& rulebook,
                        const SparseCooTensor& out,
                        const DenseTensor& out_grad,
                        const std::vector<int>& kernel_sizes) {
  DenseTensor x_grad = phi::Empty(
      dev_ctx,
      DenseTensorMeta(x.dtype(), x.non_zero_elements().dims(), x.layout()));
  MaxPoolGradKernel<T, Context>(
      dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad);
  return x_grad;
}

}  // namespace sparse
}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/sparse_pool_kernel.h b/paddle/phi/kernels/sparse/sparse_pool_kernel.h
new file mode 100644
index 00000000000..bfadbf72e30
--- /dev/null
+++ b/paddle/phi/kernels/sparse/sparse_pool_kernel.h
@@ -0,0 +1,53 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"

namespace phi {
namespace sparse {

template <typename T, typename Context>
void MaxPoolKernel(const Context& dev_ctx,
                   const SparseCooTensor& x,
                   const std::vector<int>& kernel_sizes,
                   const std::vector<int>& paddings,
                   const std::vector<int>& dilations,
                   const std::vector<int>& strides,
                   SparseCooTensor* out,
                   DenseTensor* rulebook);

template <typename T, typename Context>
SparseCooTensor MaxPool(const Context& dev_ctx,
                        const SparseCooTensor& x,
                        const std::vector<int>& kernel_sizes,
                        const std::vector<int>& paddings,
                        const std::vector<int>& dilations,
                        const std::vector<int>& strides,
                        DenseTensor* rulebook) {
  DenseTensor indices = phi::Empty(
      dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
  DenseTensor values =
      phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout()));
  SparseCooTensor coo(indices, values, x.dims());
  MaxPoolKernel<T, Context>(
      dev_ctx, x, kernel_sizes, paddings, dilations, strides, &coo, rulebook);
  return coo;
}

}  // namespace sparse
}  // namespace phi
diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt
index 317dcce92c8..3897c182e48 100644
--- a/paddle/phi/tests/kernels/CMakeLists.txt
+++ b/paddle/phi/tests/kernels/CMakeLists.txt
@@ -14,6 +14,7 @@ cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_sparse_conv3d_dev_api SRCS test_sparse_conv3d_dev_api.cc DEPS phi phi_api_utils)
+cc_test(test_sparse_pool_dev_api SRCS test_sparse_pool_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_math_function SRCS test_math_function.cc DEPS math_function)
 
 if(WITH_GPU)
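Putting the two convenience wrappers together, a hypothetical CPU call site might look like the following sketch (context and input setup are elided; the names mirror the unit test below, and out's own values stand in for out_grad exactly as the test does):

#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"

// Forward + backward round trip; assumes dev_ctx and x are prepared as in
// the unit test below. Illustrative only, not part of the patch.
phi::DenseTensor RunMaxPoolRoundTrip(const phi::CPUContext& dev_ctx,
                                     const phi::SparseCooTensor& x) {
  const std::vector<int> kernel_sizes = {1, 3, 3};
  const std::vector<int> paddings = {0, 0, 0};
  const std::vector<int> dilations = {1, 1, 1};
  const std::vector<int> strides = {1, 1, 1};

  // MaxPool fills the rulebook; MaxPoolGrad replays it for the backward pass.
  phi::DenseTensor rulebook = phi::Empty(
      dev_ctx,
      phi::DenseTensorMeta(phi::DataType::INT32, {1}, phi::DataLayout::NCHW));
  phi::SparseCooTensor out = phi::sparse::MaxPool<float>(
      dev_ctx, x, kernel_sizes, paddings, dilations, strides, &rulebook);
  return phi::sparse::MaxPoolGrad<float>(
      dev_ctx, x, rulebook, out, out.non_zero_elements(), kernel_sizes);
}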
diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
new file mode 100644
index 00000000000..27673704168
--- /dev/null
+++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
@@ -0,0 +1,391 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <memory>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {
namespace tests {

template <typename T1, typename T2>
std::vector<T2> cast(const std::vector<T1>& in) {
  std::vector<T2> out(in.size());
  for (uint64_t i = 0; i < in.size(); i++) {
    out[i] = static_cast<T2>(in[i]);
  }
  return out;
}

template <typename T>
void TestMaxPoolBase(const std::vector<int>& indices,
                     const std::vector<T>& features,
                     const DDim& x_dims,
                     const std::vector<int>& correct_out_indices,
                     const std::vector<T>& correct_out_features,
                     const DDim& correct_out_dims,
                     const int non_zero_num,
                     const std::vector<int>& kernel_sizes,
                     const std::vector<int>& paddings,
                     const std::vector<int>& strides,
                     const std::vector<int>& dilations,
                     const float diff = 1e-3,
                     const bool backward = false,
                     const std::vector<T> features_grad = {}) {
  phi::CPUContext dev_ctx_cpu;
  dev_ctx_cpu.SetAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(paddle::platform::CPUPlace())
          .get());
  dev_ctx_cpu.Init();

  const int in_channels = x_dims[4];
  const int out_channels = in_channels;

  DenseTensor indices_tensor = phi::Empty(
      dev_ctx_cpu,
      DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
  memcpy(indices_tensor.data<int>(),
         indices.data(),
         indices.size() * sizeof(int));
  DenseTensor features_tensor = phi::Empty(
      dev_ctx_cpu,
      DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
                      {non_zero_num, in_channels},
                      DataLayout::NHWC));
  memcpy(features_tensor.data<T>(),
         features.data(),
         features.size() * sizeof(T));

  SparseCooTensor x_tensor(indices_tensor, features_tensor, x_dims);

  auto f_verify = [&](const T* real_data,
                      const std::vector<T>& correct_data) {
    for (uint64_t i = 0; i < correct_data.size(); i++) {
      float tmp =
          std::fabs(static_cast<float>(correct_data[i] - real_data[i]));
      ASSERT_LT(tmp, diff);
    }
  };

  if (!std::is_same<T, phi::dtype::float16>::value) {
    DenseTensor rulebook = phi::Empty(
        dev_ctx_cpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
    SparseCooTensor out = sparse::MaxPool<T>(dev_ctx_cpu,
                                             x_tensor,
                                             kernel_sizes,
                                             paddings,
                                             dilations,
                                             strides,
                                             &rulebook);

    ASSERT_EQ(correct_out_dims.size(), out.dims().size());
    for (int i = 0; i < correct_out_dims.size(); i++) {
      ASSERT_EQ(correct_out_dims[i], out.dims()[i]);
    }
    ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz());

    int cmp_indices = memcmp(correct_out_indices.data(),
                             out.non_zero_indices().data<int>(),
                             correct_out_indices.size() * sizeof(int));
    ASSERT_EQ(cmp_indices, 0);

    f_verify(out.non_zero_elements().data<T>(), correct_out_features);

    if (backward) {
      DenseTensor x_grad = sparse::MaxPoolGrad<T>(dev_ctx_cpu,
                                                  x_tensor,
                                                  rulebook,
                                                  out,
                                                  out.non_zero_elements(),
                                                  kernel_sizes);
      f_verify(x_grad.data<T>(), features_grad);
    }
  }

// test gpu
#if defined(PADDLE_WITH_CUDA)
  phi::GPUContext dev_ctx_gpu;
  dev_ctx_gpu.PartialInitWithoutAllocator();
  dev_ctx_gpu.SetAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(dev_ctx_gpu.GetPlace(), dev_ctx_gpu.stream())
          .get());
  dev_ctx_gpu.SetHostAllocator(
      paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(phi::CPUPlace())
          .get());
  dev_ctx_gpu.PartialInitWithAllocator();

  DenseTensor d_indices_tensor = phi::Empty(
      dev_ctx_gpu,
      DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
  phi::Copy(
      dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor);

  DenseTensor d_features_tensor = phi::Empty(
      dev_ctx_gpu,
      DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
                      {non_zero_num, in_channels},
                      DataLayout::NHWC));
  phi::Copy(
      dev_ctx_gpu, features_tensor, phi::GPUPlace(), true, &d_features_tensor);

  SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims);

  DenseTensor d_rulebook = phi::Empty(
      dev_ctx_gpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW));
  SparseCooTensor d_out = sparse::MaxPool<T>(dev_ctx_gpu,
                                             d_x_tensor,
                                             kernel_sizes,
                                             paddings,
                                             dilations,
                                             strides,
                                             &d_rulebook);

  ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
  ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, d_out.nnz());
  for (int i = 0; i < correct_out_dims.size(); i++) {
    ASSERT_EQ(correct_out_dims[i], d_out.dims()[i]);
  }

  DenseTensor h_indices_tensor = phi::Empty(
      dev_ctx_cpu,
      DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW));
  phi::Copy(dev_ctx_gpu,
            d_out.non_zero_indices(),
            phi::CPUPlace(),
            true,
            &h_indices_tensor);

  int cmp_indices2 = memcmp(correct_out_indices.data(),
                            h_indices_tensor.data<int>(),
                            correct_out_indices.size() * sizeof(int));
  ASSERT_EQ(cmp_indices2, 0);

  DenseTensor h_features_tensor = phi::Empty(
      dev_ctx_cpu,
      DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
                      {d_out.nnz()},
                      d_out.layout()));

  phi::Copy(dev_ctx_gpu,
            d_out.non_zero_elements(),
            phi::CPUPlace(),
            true,
            &h_features_tensor);
  f_verify(h_features_tensor.data<T>(), correct_out_features);

  if (backward) {
    DenseTensor x_grad = sparse::MaxPoolGrad<T>(dev_ctx_gpu,
                                                d_x_tensor,
                                                d_rulebook,
                                                d_out,
                                                d_out.non_zero_elements(),
                                                kernel_sizes);
    DenseTensor h_features_grad = phi::Empty(
        dev_ctx_cpu,
        DenseTensorMeta(x_grad.dtype(), x_grad.dims(), x_grad.layout()));
    phi::Copy(dev_ctx_gpu, x_grad, phi::CPUPlace(), true, &h_features_grad);
    f_verify(h_features_grad.data<T>(), features_grad);
  }
#endif
}
void TestMaxPool(const std::vector<int>& indices,
                 const std::vector<float>& features,
                 const DDim& x_dims,
                 const std::vector<int>& correct_out_indices,
                 const std::vector<float>& correct_out_features,
                 const DDim& correct_out_dims,
                 const int non_zero_num,
                 const std::vector<int>& kernel_sizes,
                 const std::vector<int>& paddings,
                 const std::vector<int>& strides,
                 const std::vector<int>& dilations,
                 const float diff = 1e-3,
                 const bool backward = false,
                 const std::vector<float> features_grad = {}) {
  // test float
  TestMaxPoolBase<float>(indices,
                         features,
                         x_dims,
                         correct_out_indices,
                         correct_out_features,
                         correct_out_dims,
                         non_zero_num,
                         kernel_sizes,
                         paddings,
                         strides,
                         dilations,
                         diff,
                         backward,
                         features_grad);
  // test double
  TestMaxPoolBase<double>(indices,
                          cast<float, double>(features),
                          x_dims,
                          correct_out_indices,
                          cast<float, double>(correct_out_features),
                          correct_out_dims,
                          non_zero_num,
                          kernel_sizes,
                          paddings,
                          strides,
                          dilations,
                          diff,
                          backward,
                          cast<float, double>(features_grad));
}
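As a reading aid for the cases below: indices is stored as a 4 x nnz matrix of (batch, depth, height, width) coordinates, so the first case places values 1, 2, 3 at (h, w) = (0, 0), (1, 1), (3, 2) of a single 4 x 4 plane. A standalone check of that case's expected forward output (a sketch, not part of the patch):

#include <algorithm>
#include <cstdio>

int main() {
  // Densify the first test case: 4x4 plane, three non-zeros.
  const int H = 4, W = 4;
  float dense[H][W] = {};
  dense[0][0] = 1;
  dense[1][1] = 2;
  dense[3][2] = 3;
  // 1x3x3 window, stride 1, no padding -> 2x2 output.
  for (int oh = 0; oh < 2; oh++) {
    for (int ow = 0; ow < 2; ow++) {
      float m = -1e30f;
      for (int kh = 0; kh < 3; kh++)
        for (int kw = 0; kw < 3; kw++)
          m = std::max(m, dense[oh + kh][ow + kw]);
      printf("out(%d,%d) = %g\n", oh, ow, m);  // 2 2 3 3, as in out_features
    }
  }
}

Because the test feeds out's own values back in as out_grad, the expected x_grad = {0, 4, 6} is each input's sum of output gradients where it was the maximum: 2 + 2 for the value 2, 3 + 3 for the value 3, and 0 for the never-selected 1.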
TEST(DEV_API, sparse_maxpool) {
  const int channels = 1;
  DDim x_dims = {1, 1, 4, 4, channels};
  DDim out_dims = {1, 1, 2, 2, channels};
  std::vector<int> kernel_sizes = {1, 3, 3};
  std::vector<int> paddings = {0, 0, 0};
  std::vector<int> strides = {1, 1, 1};
  std::vector<int> dilations = {1, 1, 1};

  const int non_zero_num = 3;
  std::vector<int> indices = {0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 1, 2};
  std::vector<float> features = {1, 2, 3};
  std::vector<int> out_indices = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  };
  std::vector<float> out_features = {2, 2, 3, 3};
  std::vector<float> x_grad = {0, 4, 6};

  TestMaxPool(indices,
              features,
              x_dims,
              out_indices,
              out_features,
              out_dims,
              non_zero_num,
              kernel_sizes,
              paddings,
              strides,
              dilations,
              1e-6,
              true,
              x_grad);
}

TEST(DEV_API, sparse_maxpool_stride) {
  const int channels = 1;
  DDim x_dims = {1, 1, 4, 4, channels};
  DDim out_dims = {1, 1, 1, 1, channels};
  std::vector<int> kernel_sizes = {1, 3, 3};
  std::vector<int> paddings = {0, 0, 0};
  std::vector<int> strides = {2, 2, 2};
  std::vector<int> dilations = {1, 1, 1};

  const int non_zero_num = 3;
  std::vector<int> indices = {0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 1, 2};
  std::vector<float> features = {1, 2, 3};
  std::vector<int> out_indices = {0, 0, 0, 0};
  std::vector<float> out_features = {2};
  std::vector<float> x_grad = {0, 2, 0};

  TestMaxPool(indices,
              features,
              x_dims,
              out_indices,
              out_features,
              out_dims,
              non_zero_num,
              kernel_sizes,
              paddings,
              strides,
              dilations,
              1e-6,
              true,
              x_grad);
}

TEST(DEV_API, sparse_maxpool_channel) {
  const int channels = 2;
  DDim x_dims = {1, 1, 4, 4, channels};
  DDim out_dims = {1, 1, 2, 2, channels};
  std::vector<int> kernel_sizes = {1, 3, 3};
  std::vector<int> paddings = {0, 0, 0};
  std::vector<int> strides = {1, 1, 1};
  std::vector<int> dilations = {1, 1, 1};

  const int non_zero_num = 3;
  std::vector<int> indices = {0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 1, 2};
  std::vector<float> features = {1, 1, 2, 2, 3, 3};
  std::vector<int> out_indices = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  };
  std::vector<float> out_features = {2, 2, 2, 2, 3, 3, 3, 3};
  std::vector<float> x_grad = {0, 0, 4, 4, 6, 6};

  TestMaxPool(indices,
              features,
              x_dims,
              out_indices,
              out_features,
              out_dims,
              non_zero_num,
              kernel_sizes,
              paddings,
              strides,
              dilations,
              1e-6,
              true,
              x_grad);
}

TEST(DEV_API, sparse_maxpool3d) {
  const int channels = 2;
  DDim x_dims = {1, 5, 4, 4, channels};
  DDim out_dims = {1, 3, 2, 2, channels};
  std::vector<int> kernel_sizes = {3, 3, 3};
  std::vector<int> paddings = {0, 0, 0};
  std::vector<int> strides = {1, 1, 1};
  std::vector<int> dilations = {1, 1, 1};

  const int non_zero_num = 3;
  std::vector<int> indices = {0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 1, 2};
  std::vector<float> features = {1, 1, 2, 2, 3, 3};
  std::vector<int> out_indices = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1,
  };
  std::vector<float> out_features = {2, 2, 2, 2, 3, 3, 3, 3};
  std::vector<float> x_grad = {0, 0, 4, 4, 6, 6};

  TestMaxPool(indices,
              features,
              x_dims,
              out_indices,
              out_features,
              out_dims,
              non_zero_num,
              kernel_sizes,
              paddings,
              strides,
              dilations,
              1e-6,
              true,
              x_grad);
}

}  // namespace tests
}  // namespace phi
-- 
GitLab