From 5d3fd4fee7df4c2dda48212d263fc7d5ac6f6260 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sat, 2 Apr 2022 13:53:41 +0800 Subject: [PATCH] Sparse conv and pool support indices as template (#41137) --- paddle/phi/kernels/empty_kernel.cc | 4 + paddle/phi/kernels/funcs/sparse/convolution.h | 37 +-- .../kernels/sparse/convolution_grad_kernel.h | 4 +- .../phi/kernels/sparse/convolution_kernel.h | 6 +- paddle/phi/kernels/sparse/cpu/convolution.h | 75 +++--- .../sparse/cpu/convolution_grad_kernel.cc | 131 ++++++---- .../kernels/sparse/cpu/convolution_kernel.cc | 96 ++++--- .../sparse/cpu/sparse_pool_grad_kernel.cc | 55 +++- .../kernels/sparse/cpu/sparse_pool_kernel.cc | 72 ++++-- .../phi/kernels/sparse/gpu/convolution.cu.h | 241 +++++++++--------- .../sparse/gpu/convolution_grad_kernel.cu | 143 +++++++---- .../kernels/sparse/gpu/convolution_kernel.cu | 117 +++++---- .../sparse/gpu/sparse_pool_grad_kernel.cu | 77 ++++-- .../kernels/sparse/gpu/sparse_pool_kernel.cu | 99 ++++--- .../kernels/sparse/sparse_pool_grad_kernel.h | 20 +- .../phi/kernels/sparse/sparse_pool_kernel.h | 6 +- .../kernels/test_sparse_conv3d_dev_api.cc | 148 +++++++---- .../tests/kernels/test_sparse_pool_dev_api.cc | 120 +++++---- 18 files changed, 862 insertions(+), 589 deletions(-) diff --git a/paddle/phi/kernels/empty_kernel.cc b/paddle/phi/kernels/empty_kernel.cc index e547e0ea13..06d258a8a4 100644 --- a/paddle/phi/kernels/empty_kernel.cc +++ b/paddle/phi/kernels/empty_kernel.cc @@ -45,6 +45,7 @@ PD_REGISTER_KERNEL(empty, phi::EmptyKernel, float, double, + int8_t, uint8_t, int16_t, int, @@ -61,6 +62,7 @@ PD_REGISTER_KERNEL(empty_like, phi::EmptyLikeKernel, float, double, + int8_t, uint8_t, int16_t, int, @@ -80,6 +82,7 @@ PD_REGISTER_KERNEL(empty, phi::EmptyKernel, float, double, + int8_t, uint8_t, int16_t, int, @@ -95,6 +98,7 @@ PD_REGISTER_KERNEL(empty_like, phi::EmptyLikeKernel, float, double, + int8_t, uint8_t, int16_t, int, diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h index 19f1f3d3cd..f3caa2a62f 100644 --- a/paddle/phi/kernels/funcs/sparse/convolution.h +++ b/paddle/phi/kernels/funcs/sparse/convolution.h @@ -33,28 +33,30 @@ struct Dims4D { }; // Judge whether the current position x is in (lower, upper) -inline HOSTDEVICE bool Check(const int& x, +template +inline HOSTDEVICE bool Check(const IntT& x, const int& kx, const int& pad, const int& stride, const int dilation, const int kdim, const int xdim) { - const int lower = x - dilation * kx + pad; - const int uper = x + (kdim - kx - 1) * dilation - pad; + const IntT lower = x - dilation * kx + pad; + const IntT uper = x + (kdim - kx - 1) * dilation - pad; return (lower >= 0 && lower % stride == 0 && uper < xdim); } // Check whether the current position(x, y, z) is legal: // Judge the minimum and maximum values at each latitude +template inline HOSTDEVICE bool Check(const Dims4D& dims, const Dims4D& kernel_dims, const Dims4D& paddings, const Dims4D& dilations, const Dims4D& strides, - const int x, - const int y, - const int z, + const IntT x, + const IntT y, + const IntT z, const int kx, const int ky, const int kz) { @@ -67,22 +69,22 @@ inline HOSTDEVICE bool Check(const Dims4D& dims, return (x_valid && y_valid && z_valid); } -template -inline HOSTDEVICE int PointToIndex(const int& batch, - const int& x, - const int& y, - const int& z, - const Dim& dims) { +template +inline HOSTDEVICE IntT PointToIndex(const IntT& batch, + const IntT& x, + const IntT& y, + const IntT& z, + const Dim& dims) { return batch * 
dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] + y * dims[3] + x; } // TODO(zhangkaihuo): use division and multiply to optimize // modulo operation -template +template inline HOSTDEVICE void IndexToPoint( - const int index, const Dim& dims, int* batch, int* x, int* y, int* z) { - int n = index; + const IntT index, const Dim& dims, IntT* batch, IntT* x, IntT* y, IntT* z) { + IntT n = index; *x = n % dims[3]; n /= dims[3]; *y = n % dims[2]; @@ -176,8 +178,9 @@ inline const std::vector PoolResetKernel( return res; } -inline void PrefixSum(const int* counter, int* offsets, const int n) { - int offset = 0; +template +inline void PrefixSum(const T* counter, T* offsets, const int n) { + T offset = 0; for (int i = 0; i < n; i++) { offsets[i] = offset; offset += counter[i]; diff --git a/paddle/phi/kernels/sparse/convolution_grad_kernel.h b/paddle/phi/kernels/sparse/convolution_grad_kernel.h index 5a47575141..eebfcddfc7 100644 --- a/paddle/phi/kernels/sparse/convolution_grad_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_grad_kernel.h @@ -49,8 +49,8 @@ std::tuple Conv3dGrad( const int groups, const bool subm) { SparseCooTensor x_grad; - DenseTensor kernel_grad = phi::Empty( - dev_ctx, DenseTensorMeta(kernel.dtype(), {1}, kernel.layout())); + DenseTensor kernel_grad; + // TODO(zhangkaihuo): call InferMeta func here Conv3dGradKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/convolution_kernel.h b/paddle/phi/kernels/sparse/convolution_kernel.h index ff2cf94edb..6120d6339a 100644 --- a/paddle/phi/kernels/sparse/convolution_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_kernel.h @@ -45,11 +45,7 @@ SparseCooTensor Conv3d(const Context& dev_ctx, const int groups, const bool subm, DenseTensor* rulebook) { - DenseTensor indices = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); - DenseTensor values = - phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); - SparseCooTensor coo(indices, values, x.dims()); + SparseCooTensor coo; Conv3dKernel(dev_ctx, x, kernel, diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h index 4ea93f4ad5..b254461977 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution.h +++ b/paddle/phi/kernels/sparse/cpu/convolution.h @@ -31,7 +31,7 @@ using Dims4D = phi::funcs::sparse::Dims4D; // such as: kernel(3, 3, 3), kernel_size = 27 // counter_per_weight: (kernel_size) // TODO(zhangkaihuo): optimize performance with multithreading -template +template void ProductRuleBook(const Context& dev_ctx, const SparseCooTensor& x, const std::vector& kernel_sizes, @@ -44,7 +44,7 @@ void ProductRuleBook(const Context& dev_ctx, DenseTensor* counter_per_kernel) { const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); - const int* indices_ptr = non_zero_indices.data(); + const IntT* indices_ptr = non_zero_indices.data(); int* counter_ptr = counter_per_kernel->data(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; memset(counter_ptr, 0, kernel_size * sizeof(int)); @@ -60,33 +60,33 @@ void ProductRuleBook(const Context& dev_ctx, const Dims4D c_strides(1, strides[2], strides[1], strides[0]); const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]); - std::set hash_in; + std::set hash_in; if (subm) { for (int i = 0; i < non_zero_num; i++) { - int batch = indices_ptr[i]; - int in_z = indices_ptr[i + non_zero_num]; - int in_y = indices_ptr[i + 2 * non_zero_num]; - int in_x = indices_ptr[i + 3 * non_zero_num]; - int index = 
phi::funcs::sparse::PointToIndex( + IntT batch = indices_ptr[i]; + IntT in_z = indices_ptr[i + non_zero_num]; + IntT in_y = indices_ptr[i + 2 * non_zero_num]; + IntT in_x = indices_ptr[i + 3 * non_zero_num]; + IntT index = phi::funcs::sparse::PointToIndex( batch, in_x, in_y, in_z, x_dims); hash_in.insert(index); } } - auto f_calc_rulebook = [&](int* rulebook_ptr) { + auto f_calc_rulebook = [&](IntT* rulebook_ptr) { int kernel_index = 0, rulebook_index = 0; for (int kz = 0; kz < kernel_sizes[0]; kz++) { for (int ky = 0; ky < kernel_sizes[1]; ky++) { for (int kx = 0; kx < kernel_sizes[2]; kx++) { ++kernel_index; for (int64_t i = 0; i < non_zero_num; i++) { - int batch = indices_ptr[i]; - int in_z = indices_ptr[i + non_zero_num]; - int in_y = indices_ptr[i + 2 * non_zero_num]; - int in_x = indices_ptr[i + 3 * non_zero_num]; - int out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0]; - int out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1]; - int out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2]; + IntT batch = indices_ptr[i]; + IntT in_z = indices_ptr[i + non_zero_num]; + IntT in_y = indices_ptr[i + 2 * non_zero_num]; + IntT in_x = indices_ptr[i + 3 * non_zero_num]; + IntT out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0]; + IntT out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1]; + IntT out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2]; if (phi::funcs::sparse::Check(c_x_dims, c_kernel_dims, c_paddings, @@ -99,7 +99,7 @@ void ProductRuleBook(const Context& dev_ctx, ky, kz)) { if (subm) { - int out_index = phi::funcs::sparse::PointToIndex( + IntT out_index = phi::funcs::sparse::PointToIndex( batch, out_x, out_y, out_z, out_dims); if (hash_in.find(out_index) == hash_in.end()) { continue; @@ -126,15 +126,16 @@ void ProductRuleBook(const Context& dev_ctx, f_calc_rulebook(nullptr); // alloc the rulebook - DenseTensorMeta rulebook_meta( - DataType::INT32, {3, rulebook_len}, DataLayout::NCHW); - rulebook->set_meta(rulebook_meta); - dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int)); - int* rulebook_ptr = rulebook->data(); + *rulebook = phi::Empty( + dev_ctx, + DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {3, rulebook_len}, + DataLayout::NCHW)); + IntT* rulebook_ptr = rulebook->data(); f_calc_rulebook(rulebook_ptr); } -template +template void UpdateRulebookAndOutIndex(const Context& dev_ctx, const SparseCooTensor& x, const int kernel_size, @@ -142,9 +143,9 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, const DDim& out_dims, DenseTensor* rulebook, SparseCooTensor* out) { - std::set out_indexs; + std::set out_indexs; int n = rulebook->dims()[1]; - int* rulebook_ptr = rulebook->data(); + IntT* rulebook_ptr = rulebook->data(); for (int i = 0; i < n; i++) { out_indexs.insert(rulebook_ptr[i + n * 2]); } @@ -152,17 +153,19 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, int out_non_zero_num = out_indexs.size(); const int64_t sparse_dim = 4; DenseTensorMeta indices_meta( - DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); + paddle::experimental::CppTypeToDataType::Type(), + {sparse_dim, out_non_zero_num}, + DataLayout::NCHW); DenseTensorMeta values_meta(x.dtype(), {out_non_zero_num, out_channels}, x.non_zero_elements().layout()); phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); - int* out_indices_ptr = out_indices.data(); + IntT* out_indices_ptr = 
out_indices.data(); int i = 0; for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) { - const int index = *it; - int batch, x, y, z; + const IntT index = *it; + IntT batch, x, y, z; phi::funcs::sparse::IndexToPoint(index, out_dims, &batch, &x, &y, &z); out_indices_ptr[i] = batch; out_indices_ptr[i + out_non_zero_num] = z; @@ -170,7 +173,7 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, out_indices_ptr[i + out_non_zero_num * 3] = x; } for (i = 0; i < n; i++) { - int out_index = rulebook_ptr[i + n * 2]; + IntT out_index = rulebook_ptr[i + n * 2]; rulebook_ptr[i + n * 2] = std::distance(out_indexs.begin(), out_indexs.find(out_index)); } @@ -178,20 +181,20 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, out->SetMember(out_indices, out_values, out_dims, true); } -template +template void Gather( - const T* x, const int* indexs, const int n, const int channels, T* out) { + const T* x, const IntT* indexs, const int n, const int channels, T* out) { for (int i = 0; i < n; i++) { - int real_i = indexs[i]; + IntT real_i = indexs[i]; memcpy(out + i * channels, x + real_i * channels, channels * sizeof(T)); } } -template +template void Scatter( - const T* x, const int* indexs, const int n, const int channels, T* out) { + const T* x, const IntT* indexs, const int n, const int channels, T* out) { for (int i = 0; i < n; i++) { - int real_i = indexs[i]; + IntT real_i = indexs[i]; for (int j = 0; j < channels; j++) { out[real_i * channels + j] += x[i * channels + j]; } diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 29079918cb..80693c90d1 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -18,6 +18,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -29,24 +31,24 @@ namespace sparse { //] // x_grad = out_grad * transpose(kenrel) // kernel_grad = transpose(x) * out_grad -template -void Conv3dGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const DenseTensor& rulebook, - const SparseCooTensor& out_grad, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* x_grad, - DenseTensor* kernel_grad) { +template +void Conv3dGradCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - const int* rulebook_ptr = rulebook.data(); + const IntT* rulebook_ptr = rulebook.data(); const int rulebook_len = rulebook.dims()[1]; @@ -66,32 +68,30 @@ void Conv3dGradKernel(const Context& dev_ctx, T* in_features_ptr = in_features.data(); T* d_x_features_ptr = d_x_features.data(); T* out_grad_features_ptr = out_grad_features.data(); - kernel_grad->Resize(kernel_dims); - dev_ctx.Alloc( - kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T)); + *kernel_grad = phi::EmptyLike(dev_ctx, kernel); T* d_kernel_ptr = kernel_grad->data(); memset(d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel()); int half_kernel_size = kernel_size / 2; - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); DenseTensor x_grad_indices = - phi::EmptyLike(dev_ctx, x.non_zero_indices()); + phi::EmptyLike(dev_ctx, x.non_zero_indices()); DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); T* x_grad_values_ptr = x_grad_values.data(); memset(x_grad_values_ptr, 0, sizeof(T) * x_grad_values.numel()); memset(d_x_features_ptr, 0, sizeof(T) * d_x_features.numel()); - phi::Copy(dev_ctx, - x.non_zero_indices(), - dev_ctx.GetPlace(), - false, - &x_grad_indices); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0); + std::vector offsets(kernel_size + 1), counter(kernel_size, 0); for (int i = 0; i < rulebook_len; i++) { counter[rulebook_ptr[i]] += 1; } - int offset = 0, max_count = 0; + IntT offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; offset += counter[i]; @@ -102,30 +102,31 @@ void Conv3dGradKernel(const Context& dev_ctx, offsets[kernel_size] = offset; if (subm) { - phi::funcs::sparse::SubmPreProcess(dev_ctx, - x, - kernel, - out_grad.non_zero_elements(), - in_channels, - out_channels, - half_kernel_size, - kernel_grad, - &x_grad_values); + phi::funcs::sparse::SubmPreProcess( + dev_ctx, + x, + kernel, + out_grad.non_zero_elements(), + in_channels, + out_channels, + half_kernel_size, + kernel_grad, + &x_grad_values); if (max_count == 0) { return; } } - Gather(x.non_zero_elements().data(), - rulebook_ptr + 
rulebook_len, - rulebook_len, - in_channels, - in_features_ptr); - Gather(out_grad.non_zero_elements().data(), - rulebook_ptr + rulebook_len * 2, - rulebook_len, - out_channels, - out_grad_features_ptr); + Gather(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + rulebook_len, + in_channels, + in_features_ptr); + Gather(out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + rulebook_len, + out_channels, + out_grad_features_ptr); const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { @@ -170,11 +171,41 @@ void Conv3dGradKernel(const Context& dev_ctx, } // 4. scatter - Scatter(d_x_features_ptr, - rulebook.data() + rulebook_len, - rulebook_len, - in_channels, - x_grad_values_ptr); + Scatter(d_x_features_ptr, + rulebook.data() + rulebook_len, + rulebook_len, + in_channels, + x_grad_values_ptr); +} + +template +void Conv3dGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dGradCPUKernel", ([&] { + Conv3dGradCPUKernel(dev_ctx, + x, + kernel, + rulebook, + out_grad, + paddings, + dilations, + strides, + groups, + subm, + x_grad, + kernel_grad); + })); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc index f022e4ef4b..a1c8cf014c 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc @@ -17,6 +17,8 @@ limitations under the License. 
*/ #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -25,17 +27,17 @@ namespace sparse { * kernel: (D, H, W, C, OC) * out: (N, D, H, W, OC) **/ -template -void Conv3dKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void Conv3dCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { // update padding and dilation // Currently, only support x.layout is NDHWC, groups = 1 // if x.layout != NDHWC then transpose(x), transpose(weight) @@ -66,18 +68,18 @@ void Conv3dKernel(const Context& dev_ctx, DataType::INT32, {kernel_size}, DataLayout::NCHW); DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); - ProductRuleBook(dev_ctx, - x, - kernel_sizes, - subm_paddings, - dilations, - subm_strides, - out_dims, - subm, - rulebook, - &counter_per_kernel); - - UpdateRulebookAndOutIndex( + ProductRuleBook(dev_ctx, + x, + kernel_sizes, + subm_paddings, + dilations, + subm_strides, + out_dims, + subm, + rulebook, + &counter_per_kernel); + + UpdateRulebookAndOutIndex( dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out); int n = rulebook->dims()[1]; @@ -95,14 +97,14 @@ void Conv3dKernel(const Context& dev_ctx, T* in_features_ptr = in_features.data(); T* out_features_ptr = out_features.data(); - Gather(x.non_zero_elements().data(), - rulebook->data() + n, - n, - in_channels, - in_features_ptr); + Gather(x.non_zero_elements().data(), + rulebook->data() + n, + n, + in_channels, + in_features_ptr); // 3. call gemm for every werght - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); std::vector offsets(kernel_size + 1); int offset = 0; for (int i = 0; i < kernel_size; i++) { @@ -139,11 +141,37 @@ void Conv3dKernel(const Context& dev_ctx, // 4. scatter T* out_values_ptr = out->mutable_non_zero_elements()->data(); memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels); - Scatter(out_features_ptr, - rulebook->data() + n * 2, - n, - out_channels, - out_values_ptr); + Scatter(out_features_ptr, + rulebook->data() + n * 2, + n, + out_channels, + out_values_ptr); +} + +template +void Conv3dKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dCPUKernel", ([&] { + Conv3dCPUKernel(dev_ctx, + x, + kernel, + paddings, + dilations, + strides, + groups, + subm, + out, + rulebook); + })); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc index 3010d480b5..30221975e7 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc @@ -14,24 +14,28 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { -template -void MaxPoolGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const SparseCooTensor& out, - const DenseTensor& out_grad, - const std::vector& kernel_sizes, - DenseTensor* x_grad) { +template +void MaxPoolGradCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const int channels = x.dims()[4]; int rulebook_len = rulebook.dims()[1]; - const int* rulebook_ptr = rulebook.data(); + const IntT* rulebook_ptr = rulebook.data(); std::vector offsets(kernel_size + 1), counter(kernel_size, 0); for (int i = 0; i < rulebook_len; i++) { counter[rulebook_ptr[i]] += 1; @@ -40,15 +44,25 @@ void MaxPoolGradKernel(const Context& dev_ctx, const T* in_features_ptr = x.non_zero_elements().data(); const T* out_features_ptr = out.non_zero_elements().data(); - const T* out_grad_ptr = out_grad.data(); - T* x_grad_ptr = x_grad->data(); + const T* out_grad_ptr = out_grad.non_zero_elements().data(); + // TODO(zhangkaihuo): call phi::sparse::EmptyLike + DenseTensor x_grad_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); + T* x_grad_ptr = x_grad_values.data(); memset(x_grad_ptr, 0, sizeof(T) * x_grad->numel()); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); phi::funcs::MaxPoolGrad grad_functor; for (int i = 0; i < kernel_size; i++) { for (int j = 0; j < counter[i]; j++) { - int in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; - int out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; + IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; + IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; for (int c = 0; c < channels; c++) { grad_functor.compute(in_features_ptr[in_i * channels + c], out_features_ptr[out_i * channels + c], @@ -60,6 +74,21 @@ void MaxPoolGradKernel(const Context& dev_ctx, } } +template +void MaxPoolGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolGradCPUKernel", ([&] { + MaxPoolGradCPUKernel( + dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc index 86971242df..ed6e020058 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc @@ -19,6 +19,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -27,15 +29,15 @@ namespace sparse { * kernel: (D, H, W, C, OC) * out: (N, D, H, W, OC) **/ -template -void MaxPoolKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const std::vector& kernel_sizes, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void MaxPoolCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { const auto& x_dims = x.dims(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const std::vector& real_kernel_sizes = @@ -51,22 +53,22 @@ void MaxPoolKernel(const Context& dev_ctx, const T* in_features_ptr = x.non_zero_elements().data(); // 1. product rule book - ProductRuleBook(dev_ctx, - x, - real_kernel_sizes, - paddings, - dilations, - strides, - out_dims, - false, - rulebook, - &counter_per_kernel); - - UpdateRulebookAndOutIndex( + ProductRuleBook(dev_ctx, + x, + real_kernel_sizes, + paddings, + dilations, + strides, + out_dims, + false, + rulebook, + &counter_per_kernel); + + UpdateRulebookAndOutIndex( dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out); int rulebook_len = rulebook->dims()[1]; - const int* rulebook_ptr = rulebook->data(); + const IntT* rulebook_ptr = rulebook->data(); const int* counter_ptr = counter_per_kernel.data(); std::vector offsets(kernel_size + 1); @@ -78,8 +80,8 @@ void MaxPoolKernel(const Context& dev_ctx, phi::funcs::MaxPool max_pool_functor; for (int i = 0; i < kernel_size; i++) { for (int j = 0; j < counter_ptr[i]; j++) { - int in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; - int out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; + IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; + IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; if (!out_flags[out_i]) { out_flags[out_i] = true; memcpy(&out_features_ptr[out_i * in_channels], @@ -95,6 +97,28 @@ void MaxPoolKernel(const Context& dev_ctx, } } +template +void MaxPoolKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolCPUKernel", ([&] { + MaxPoolCPUKernel(dev_ctx, + x, + kernel_sizes, + paddings, + dilations, + strides, + out, + rulebook); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index a512a60b94..5662a4fac7 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -98,21 +98,21 @@ __global__ void ScatterKernel(const T* input, } } -template -inline int* SortedAndUniqueIndex(const Context& dev_ctx, - const int* rulebook_ptr, - const int len, - DenseTensor* out_index, - DenseTensor* unique_key, - DenseTensor* unique_value) { +template +inline IntT* SortedAndUniqueIndex(const Context& dev_ctx, + const IntT* rulebook_ptr, + const int len, + DenseTensor* out_index, + DenseTensor* unique_key, + DenseTensor* unique_value) { 
phi::IndexKernel>( dev_ctx, out_index, kps::IdentityFunctor()); phi::IndexKernel>( dev_ctx, unique_value, kps::IdentityFunctor()); - phi::backends::gpu::GpuMemcpyAsync(unique_key->data(), + phi::backends::gpu::GpuMemcpyAsync(unique_key->data(), rulebook_ptr, - sizeof(int) * len, + sizeof(IntT) * len, #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToDevice, #else @@ -126,19 +126,19 @@ inline int* SortedAndUniqueIndex(const Context& dev_ctx, #else thrust::sort_by_key(thrust::cuda::par.on(dev_ctx.stream()), #endif - unique_key->data(), - unique_key->data() + len, + unique_key->data(), + unique_key->data() + len, out_index->data()); // 4. unique - thrust::pair new_end = + thrust::pair new_end = #ifdef PADDLE_WITH_HIP thrust::unique_by_key(thrust::hip::par.on(dev_ctx.stream()), #else thrust::unique_by_key(thrust::cuda::par.on(dev_ctx.stream()), #endif - unique_key->data(), - unique_key->data() + len, + unique_key->data(), + unique_key->data() + len, unique_value->data()); return new_end.first; } @@ -159,7 +159,7 @@ __global__ void SetFlagAndUpdateCounterKernel(const int* indexs, for (int i = tid; i < n; i += gridDim.x * blockDim.x) { int index = indexs[i]; - int kernel_index = rulebook_ptr[index]; + T kernel_index = rulebook_ptr[index]; rulebook_ptr[index + rulebook_len] = -1; rulebook_ptr[index + 2 * rulebook_len] = -1; rulebook_ptr[index] = -1; @@ -183,18 +183,18 @@ __global__ void SetFlagAndUpdateCounterKernel(const int* indexs, * rulebook_out_indexs: the output index in rulebook **/ template -__global__ void UpdateIndexKernel(const int* unique_keys, +__global__ void UpdateIndexKernel(const T* unique_keys, const int* unique_values, const int* out_indexs, - const int non_zero_num, + const int64_t non_zero_num, const int rulebook_len, const Dims4D out_dims, T* out_indices, T* rulebook_out_indexs) { int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { - const int index = unique_keys[i]; - int batch, x, y, z; + const T index = unique_keys[i]; + T batch, x, y, z; phi::funcs::sparse::IndexToPoint( index, out_dims, &batch, &x, &y, &z); // get out indices @@ -207,7 +207,7 @@ __global__ void UpdateIndexKernel(const int* unique_keys, int start = unique_values[i]; int end = i == non_zero_num - 1 ? 
rulebook_len : unique_values[i + 1]; // max(end-start) = kernel_size - for (int j = start; j < end; j++) { + for (T j = start; j < end; j++) { rulebook_out_indexs[out_indexs[j]] = i; } } @@ -215,7 +215,7 @@ __global__ void UpdateIndexKernel(const int* unique_keys, // brief: calculation the distance between start and end template -__global__ void DistanceKernel(const T* start, const T* end, int* distance) { +__global__ void DistanceKernel(const T* start, const T* end, T* distance) { if (threadIdx.x == 0) { *distance = end - start; } @@ -249,7 +249,7 @@ __global__ void ProductRuleBookKernel(const T* x_indices, const bool subm, T* rulebook, int* counter, - int* in_indexs) { + T* in_indexs) { int tid = threadIdx.x + blockIdx.x * blockDim.x; extern __shared__ int counter_buf[]; // kernel_size const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1]; @@ -261,10 +261,10 @@ __global__ void ProductRuleBookKernel(const T* x_indices, for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { int kernel_index = 0; - int batch = x_indices[i]; - int in_z = x_indices[i + non_zero_num]; - int in_y = x_indices[i + 2 * non_zero_num]; - int in_x = x_indices[i + 3 * non_zero_num]; + T batch = x_indices[i]; + T in_z = x_indices[i + non_zero_num]; + T in_y = x_indices[i + 2 * non_zero_num]; + T in_x = x_indices[i + 3 * non_zero_num]; if (subm) { in_indexs[i] = PointToIndex(batch, in_x, in_y, in_z, x_dims); } @@ -283,9 +283,9 @@ __global__ void ProductRuleBookKernel(const T* x_indices, kx, ky, kz)) { - int out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; - int out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; - int out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; + T out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; + T out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; + T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; in_i = i; out_index = phi::funcs::sparse::PointToIndex( batch, out_x, out_y, out_z, out_dims); @@ -321,7 +321,7 @@ __global__ void ProductRuleBookKernel(const T* x_indices, // 5. 
update the out_index by unique_key, uniqe_value and the index of // unique_value: // the new out_index: 0, 2, 3, 2, 3, 0, 1 -template +template int ProductRuleBook(const Context& dev_ctx, const SparseCooTensor& x, const std::vector& kernel_sizes, @@ -334,26 +334,26 @@ int ProductRuleBook(const Context& dev_ctx, DenseTensor* counter_per_kernel, DenseTensor* offsets_per_kernel, DenseTensor* out_index, - DenseTensor* unique_key, DenseTensor* unique_value, SparseCooTensor* out, std::vector* h_counter, std::vector* h_offsets) { + // TODO(zhangkaihuo): use PD_DISPATCH_INTEGRAL_TYPES for secondary dispatch + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); - const int* indices_ptr = non_zero_indices.data(); + const IntT* indices_ptr = non_zero_indices.data(); DenseTensor in_indexs = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); + dev_ctx, DenseTensorMeta(indices_dtype, {x.nnz()}, DataLayout::NCHW)); int* counter_ptr = counter_per_kernel->data(); int* offsets_ptr = offsets_per_kernel->data(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const int rulebook_rows = 3; const int rulebook_cols = kernel_size * non_zero_num; DenseTensorMeta rulebook_meta( - DataType::INT32, {rulebook_rows, rulebook_cols}, DataLayout::NCHW); - rulebook->set_meta(rulebook_meta); - dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int)); - int* rulebook_ptr = rulebook->data(); + indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW); + *rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta)); + IntT* rulebook_ptr = rulebook->data(); const auto x_dims = x.dims(); Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]); @@ -369,39 +369,39 @@ int ProductRuleBook(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - ProductRuleBookKernel<<>>(indices_ptr, - d_x_dims, - d_kernel_dims, - d_out_dims, - non_zero_num, - d_paddings, - d_dilations, - d_strides, - subm, - rulebook_ptr, - counter_ptr, - in_indexs.data()); + ProductRuleBookKernel<<>>(indices_ptr, + d_x_dims, + d_kernel_dims, + d_out_dims, + non_zero_num, + d_paddings, + d_dilations, + d_strides, + subm, + rulebook_ptr, + counter_ptr, + in_indexs.data()); // 2. remove -1 #ifdef PADDLE_WITH_HIP - int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), #else - int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), #endif - rulebook_ptr, - rulebook_ptr + rulebook_rows * rulebook_cols, - -1); + rulebook_ptr, + rulebook_ptr + rulebook_rows * rulebook_cols, + -1); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1); - int rulebook_len = 0; + IntT rulebook_len = 0; phi::backends::gpu::GpuMemcpyAsync( &rulebook_len, rulebook_ptr + 3 * kernel_size * non_zero_num - 1, - sizeof(int), + sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -418,11 +418,10 @@ int ProductRuleBook(const Context& dev_ctx, // and then the intermediate output index is subtracted from the input index // to obain the rulebook. 
// get difference - int32_t* A_key_ptr = rulebook_ptr + 2 * rulebook_len; - int32_t* B_key_ptr = in_indexs.data(); - DenseTensor A_val = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + IntT* A_key_ptr = rulebook_ptr + 2 * rulebook_len; + IntT* B_key_ptr = in_indexs.data(); + DenseTensorMeta val_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); + DenseTensor A_val = phi::Empty(dev_ctx, std::move(val_meta)); DenseTensor B_val = phi::Empty( dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); phi::IndexKernel>( @@ -431,10 +430,8 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx, &B_val, kps::IdentityFunctor()); DenseTensor key_result = phi::Empty( dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len + 1}, DataLayout::NCHW)); - DenseTensor val_result = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {rulebook_len + 1}, DataLayout::NCHW)); + DenseTensor val_result = phi::Empty(dev_ctx, std::move(val_meta)); #ifdef PADDLE_WITH_HIP thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), @@ -457,7 +454,7 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx.stream()); dev_ctx.Wait(); - thrust::pair end; + thrust::pair end; // Because set_diff does not support duplicate data, set_diff is performed // separately for each segment of data. // TODO(zhangkaihuo): Using hashtable here may get better performance, @@ -465,7 +462,7 @@ int ProductRuleBook(const Context& dev_ctx, for (int i = 0; i < kernel_size; i++) { int start = offsets[i]; int stop = i == kernel_size - 1 ? rulebook_len : offsets[i + 1]; - int* key_result_start = (i == 0 ? key_result.data() : end.first); + IntT* key_result_start = (i == 0 ? key_result.data() : end.first); int* val_result_start = i == 0 ? 
val_result.data() : end.second; end = #ifdef PADDLE_WITH_HIP @@ -483,14 +480,14 @@ int ProductRuleBook(const Context& dev_ctx, val_result_start); } - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( - key_result.data(), + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + key_result.data(), end.first, - key_result.data() + rulebook_len); - int len = 0; + key_result.data() + rulebook_len); + IntT len = 0; phi::backends::gpu::GpuMemcpyAsync(&len, - key_result.data() + rulebook_len, - sizeof(int), + key_result.data() + rulebook_len, + sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -500,10 +497,10 @@ int ProductRuleBook(const Context& dev_ctx, dev_ctx.Wait(); // set the diff value = -1, and update counter auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len, 1); - SetFlagAndUpdateCounterKernel<<>>( + SetFlagAndUpdateCounterKernel<<>>( val_result.data(), len, rulebook_len, @@ -512,18 +509,18 @@ int ProductRuleBook(const Context& dev_ctx, counter_ptr); // remove -1 #ifdef PADDLE_WITH_HIP - int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), #else - int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), + IntT* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), #endif - rulebook_ptr, - rulebook_ptr + 3 * rulebook_len, - -1); - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( - rulebook_ptr, last, key_result.data() + rulebook_len); + rulebook_ptr, + rulebook_ptr + 3 * rulebook_len, + -1); + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + rulebook_ptr, last, key_result.data() + rulebook_len); phi::backends::gpu::GpuMemcpyAsync(&rulebook_len, - key_result.data() + rulebook_len, - sizeof(int), + key_result.data() + rulebook_len, + sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -566,42 +563,47 @@ int ProductRuleBook(const Context& dev_ctx, cudaMemcpyDeviceToHost, dev_ctx.stream()); #endif - rulebook->Resize({rulebook_rows, rulebook_len}); + rulebook->Resize({rulebook_rows, static_cast(rulebook_len)}); // 3. 
sorted or merge the out index - out_index->ResizeAndAllocate({rulebook_len}); - unique_value->ResizeAndAllocate({rulebook_len}); - unique_key->ResizeAndAllocate({rulebook_len}); + out_index->ResizeAndAllocate({static_cast(rulebook_len)}); + unique_value->ResizeAndAllocate({static_cast(rulebook_len)}); + DenseTensor unique_key = phi::Empty( + dev_ctx, + DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {static_cast(rulebook_len)}, + DataLayout::NCHW)); int* out_index_ptr = out_index->data(); int* unique_value_ptr = unique_value->data(); - int* unique_key_ptr = unique_key->data(); - - int* new_end = SortedAndUniqueIndex(dev_ctx, - rulebook_ptr + 2 * rulebook_len, - rulebook_len, - out_index, - unique_key, - unique_value); + IntT* unique_key_ptr = unique_key.data(); + + IntT* new_end = + SortedAndUniqueIndex(dev_ctx, + rulebook_ptr + 2 * rulebook_len, + rulebook_len, + out_index, + &unique_key, + unique_value); // thrust::distance doesn't support stream parameters // const int out_non_zero_num = thrust::distance(unique_key_ptr, // new_end.first); - DistanceKernel<<<1, 1>>>( + DistanceKernel<<<1, 1>>>( unique_key_ptr, new_end, rulebook_ptr + rulebook_rows * rulebook_cols - 1); - int out_non_zero_num = 0; + IntT out_non_zero_num = 0; #ifdef PADDLE_WITH_HIP phi::backends::gpu::GpuMemcpyAsync( &out_non_zero_num, rulebook_ptr + rulebook_rows * rulebook_cols - 1, - sizeof(int), + sizeof(IntT), hipMemcpyDeviceToHost, dev_ctx.stream()); #else phi::backends::gpu::GpuMemcpyAsync( &out_non_zero_num, rulebook_ptr + rulebook_rows * rulebook_cols - 1, - sizeof(int), + sizeof(IntT), cudaMemcpyDeviceToHost, dev_ctx.stream()); #endif @@ -610,28 +612,29 @@ int ProductRuleBook(const Context& dev_ctx, // 5. update out_indices and rulebook by unique_value_ptr const int64_t sparse_dim = 4; DenseTensorMeta indices_meta( - DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); + indices_dtype, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); DenseTensorMeta values_meta(x.dtype(), {out_non_zero_num, kernel_sizes[4]}, x.non_zero_elements().layout()); phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); - int* out_indices_ptr = out_indices.data(); + IntT* out_indices_ptr = out_indices.data(); config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_non_zero_num, 1); - UpdateIndexKernel<<>>(unique_key_ptr, - unique_value_ptr, - out_index_ptr, - out_non_zero_num, - rulebook_len, - d_out_dims, - out_indices_ptr, - rulebook_ptr + 2 * rulebook_len); + UpdateIndexKernel<<>>( + unique_key_ptr, + unique_value_ptr, + out_index_ptr, + out_non_zero_num, + rulebook_len, + d_out_dims, + out_indices_ptr, + rulebook_ptr + 2 * rulebook_len); out->SetMember(out_indices, out_values, out_dims, true); return rulebook_len; } diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 4a6094c23b..2b61be7289 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -24,6 +24,8 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { @@ -35,24 +37,24 @@ namespace sparse { //] // x_grad = out_grad * transpose(kenrel) // kernel_grad = transpose(x) * out_grad -template -void Conv3dGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const DenseTensor& rulebook, - const SparseCooTensor& out_grad, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* x_grad, - DenseTensor* kernel_grad) { +template +void Conv3dGradGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - const int* rulebook_ptr = rulebook.data(); + const IntT* rulebook_ptr = rulebook.data(); const int rulebook_len = rulebook.dims()[1]; @@ -74,29 +76,29 @@ void Conv3dGradKernel(const Context& dev_ctx, T* out_grad_features_ptr = out_grad_features.data(); *kernel_grad = phi::EmptyLike(dev_ctx, kernel); T* d_kernel_ptr = kernel_grad->data(); - phi::funcs::SetConstant set_zero; + phi::funcs::SetConstant set_zero; set_zero(dev_ctx, kernel_grad, static_cast(0.0f)); int half_kernel_size = kernel_size / 2; - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); DenseTensor x_grad_indices = - phi::EmptyLike(dev_ctx, x.non_zero_indices()); + phi::EmptyLike(dev_ctx, x.non_zero_indices()); DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); T* x_grad_values_ptr = x_grad_values.data(); set_zero(dev_ctx, &x_grad_values, static_cast(0.0f)); set_zero(dev_ctx, &d_x_features, static_cast(0.0f)); - phi::Copy(dev_ctx, - x.non_zero_indices(), - dev_ctx.GetPlace(), - false, - &x_grad_indices); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0), + std::vector offsets(kernel_size + 1), counter(kernel_size, 0), h_counter(rulebook_len, 0); phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], rulebook_ptr, - rulebook_len * sizeof(int), + rulebook_len * sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -109,7 +111,7 @@ void Conv3dGradKernel(const Context& dev_ctx, for (int i = 0; i < rulebook_len; i++) { counter[h_counter[i]] += 1; } - int offset = 0, max_count = 0; + IntT offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; offset += counter[i]; @@ -120,15 +122,16 @@ void Conv3dGradKernel(const Context& dev_ctx, offsets[kernel_size] = offset; if (subm) { - phi::funcs::sparse::SubmPreProcess(dev_ctx, - x, - kernel, - out_grad.non_zero_elements(), - in_channels, - out_channels, - half_kernel_size, - kernel_grad, - &x_grad_values); + phi::funcs::sparse::SubmPreProcess( + dev_ctx, + x, + kernel, + out_grad.non_zero_elements(), + in_channels, + out_channels, + half_kernel_size, + kernel_grad, + 
&x_grad_values); if (max_count == 0) { return; } @@ -136,21 +139,21 @@ void Conv3dGradKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + rulebook_len, - in_features_ptr, - rulebook_len, - in_channels); + GatherKernel<<>>(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + in_features_ptr, + rulebook_len, + in_channels); config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * out_channels, 1); - GatherKernel<<>>( + GatherKernel<<>>( out_grad.non_zero_elements().data(), rulebook_ptr + rulebook_len * 2, out_grad_features_ptr, @@ -203,15 +206,19 @@ void Conv3dGradKernel(const Context& dev_ctx, // x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); + DenseTensor unique_key = phi::Empty( + dev_ctx, + DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), + {rulebook_len}, + DataLayout::NCHW)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); - SortedAndUniqueIndex(dev_ctx, - rulebook_ptr + rulebook_len, - rulebook_len, - &out_index, - &unique_key, - &unique_value); + SortedAndUniqueIndex(dev_ctx, + rulebook_ptr + rulebook_len, + rulebook_len, + &out_index, + &unique_key, + &unique_value); config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * in_channels, 1); @@ -229,6 +236,36 @@ void Conv3dGradKernel(const Context& dev_ctx, subm); } +template +void Conv3dGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const DenseTensor& rulebook, + const SparseCooTensor& out_grad, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* x_grad, + DenseTensor* kernel_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dGradGPUKernel", ([&] { + Conv3dGradGPUKernel(dev_ctx, + x, + kernel, + rulebook, + out_grad, + paddings, + dilations, + strides, + groups, + subm, + x_grad, + kernel_grad); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 214e689e93..2d212eadff 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -19,29 +19,25 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/convolution_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/api/ext/dispatch.h" + namespace phi { namespace sparse { -/** - * x: (N, D, H, W, C) - * kernel: (D, H, W, C, OC) - * out: (N, D, H, W, OC) -**/ -template -void Conv3dKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& kernel, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const int groups, - const bool subm, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void Conv3dGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { // update padding and dilation // Currently, only support x.layout is NDHWC, groups = 1 // if x.layout != NDHWC then transpose(x), transpose(weight) - const auto& x_dims = x.dims(); const auto& kernel_dims = kernel.dims(); int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; @@ -67,7 +63,6 @@ void Conv3dKernel(const Context& dev_ctx, DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta)); DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); std::vector subm_paddings(paddings), subm_strides(strides); @@ -75,28 +70,26 @@ void Conv3dKernel(const Context& dev_ctx, phi::funcs::sparse::ResetSubmKernelSizeAndStrides( kernel.dims(), &subm_paddings, &subm_strides); } - - int n = ProductRuleBook(dev_ctx, - x, - kernel_sizes, - subm_paddings, - dilations, - subm_strides, - out_dims, - subm, - rulebook, - &counter_per_kernel, - &offsets_per_kernel, - &out_index, - &unique_key, - &unique_value, - out, - &h_counter, - &offsets); + int n = ProductRuleBook(dev_ctx, + x, + kernel_sizes, + subm_paddings, + dilations, + subm_strides, + out_dims, + subm, + rulebook, + &counter_per_kernel, + &offsets_per_kernel, + &out_index, + &unique_value, + out, + &h_counter, + &offsets); const int* counter_ptr = counter_per_kernel.data(); const int* offsets_ptr = counter_per_kernel.data(); - const int* rulebook_ptr = rulebook->data(); + const IntT* rulebook_ptr = rulebook->data(); // 2. gather DenseTensorMeta in_features_meta( @@ -109,22 +102,22 @@ void Conv3dKernel(const Context& dev_ctx, phi::Empty(dev_ctx, std::move(out_features_meta)); T* in_features_ptr = in_features.data(); T* out_features_ptr = out_features.data(); - phi::funcs::SetConstant set_zero; + phi::funcs::SetConstant set_zero; set_zero(dev_ctx, &out_features, static_cast(0.0f)); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + n, - in_features_ptr, - n, - in_channels); + GatherKernel<<>>(x.non_zero_elements().data(), + rulebook_ptr + n, + in_features_ptr, + n, + in_channels); // 3. 
call gemm for every werght - auto blas = phi::funcs::GetBlas(dev_ctx); + auto blas = phi::funcs::GetBlas(dev_ctx); auto* out_values = out->mutable_non_zero_elements(); T* out_values_ptr = out_values->data(); @@ -168,6 +161,36 @@ void Conv3dKernel(const Context& dev_ctx, out_channels, out_values_ptr); } +/** + * x: (N, D, H, W, C) + * kernel: (D, H, W, C, OC) + * out: (N, D, H, W, OC) +**/ +template +void Conv3dKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const int groups, + const bool subm, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "Conv3dGPUKernel", ([&] { + Conv3dGPUKernel(dev_ctx, + x, + kernel, + paddings, + dilations, + strides, + groups, + subm, + out, + rulebook); + })); +} } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu index 1048dd1be0..8657e7319d 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu @@ -12,24 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" -#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" +#include "paddle/phi/api/ext/dispatch.h" namespace phi { namespace sparse { -template +template __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr, const T* out_features_ptr, const T* out_grad_ptr, - const int* rulebook_ptr, + const IntT* rulebook_ptr, const int n, const int rulebook_len, const int channels, @@ -38,8 +42,8 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr, CUDA_KERNEL_LOOP_TYPE(i, n * channels, int64_t) { int real_i = i / channels; int c = i - real_i * channels; - int in_i = rulebook_ptr[real_i]; - int out_i = rulebook_ptr[real_i + rulebook_len]; + IntT in_i = rulebook_ptr[real_i]; + IntT out_i = rulebook_ptr[real_i + rulebook_len]; grad_functor.compute(in_features_ptr[in_i * channels + c], out_features_ptr[out_i * channels + c], out_grad_ptr[out_i * channels + c], @@ -48,23 +52,23 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr, } } -template -void MaxPoolGradKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const SparseCooTensor& out, - const DenseTensor& out_grad, - const std::vector& kernel_sizes, - DenseTensor* x_grad) { +template +void MaxPoolGradGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const int in_channels = x.dims()[4]; int rulebook_len = rulebook.dims()[1]; - const 
int* rulebook_ptr = rulebook.data(); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0), + const IntT* rulebook_ptr = rulebook.data(); + std::vector offsets(kernel_size + 1), counter(kernel_size, 0), h_counter(kernel_size); phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], rulebook_ptr, - rulebook_len * sizeof(int), + rulebook_len * sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, #else @@ -80,10 +84,20 @@ void MaxPoolGradKernel(const Context& dev_ctx, const T* in_features_ptr = x.non_zero_elements().data(); const T* out_features_ptr = out.non_zero_elements().data(); - const T* out_grad_ptr = out_grad.data(); - T* x_grad_ptr = x_grad->data(); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, x_grad, static_cast(0.0f)); + const T* out_grad_ptr = out_grad.non_zero_elements().data(); + // TODO(zhangkaihuo): call phi::sparse::EmptyLike + DenseTensor x_grad_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); + T* x_grad_ptr = x_grad_values.data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &x_grad_values, static_cast(0.0f)); + phi::Copy(dev_ctx, + x.non_zero_indices(), + dev_ctx.GetPlace(), + false, + &x_grad_indices); for (int i = 0; i < kernel_size; i++) { if (counter[i] <= 0) { @@ -92,10 +106,10 @@ void MaxPoolGradKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, counter[i] * in_channels, 1); - MaxPoolGradCudaKernel<<>>( + MaxPoolGradCudaKernel<<>>( in_features_ptr, out_features_ptr, out_grad_ptr, @@ -107,6 +121,21 @@ void MaxPoolGradKernel(const Context& dev_ctx, } } +template +void MaxPoolGradKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes, + SparseCooTensor* x_grad) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolGradGPUKernel", ([&] { + MaxPoolGradGPUKernel( + dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu index 0f6a0d13b1..a59cd3c7a5 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu @@ -12,19 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" + #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" -#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" + +#include "paddle/phi/api/ext/dispatch.h" namespace phi { namespace sparse { -template +template __global__ void MaxPoolCudaKernel(const T* in_features_ptr, - const int* rulebook_ptr, + const IntT* rulebook_ptr, const int n, const int rulebook_len, const int channels, @@ -33,8 +36,8 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr, CUDA_KERNEL_LOOP_TYPE(i, n * channels, int64_t) { int real_i = i / channels; int channel_i = i - real_i * channels; - int in_i = rulebook_ptr[real_i]; - int out_i = rulebook_ptr[real_i + rulebook_len]; + IntT in_i = rulebook_ptr[real_i]; + IntT out_i = rulebook_ptr[real_i + rulebook_len]; max_pool_functor.compute(in_features_ptr[in_i * channels + channel_i], &out_features_ptr[out_i * channels + channel_i]); } @@ -45,15 +48,15 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr, * kernel: (D, H, W, C, OC) * out: (N, D, H, W, OC) **/ -template -void MaxPoolKernel(const Context& dev_ctx, - const SparseCooTensor& x, - const std::vector& kernel_sizes, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - SparseCooTensor* out, - DenseTensor* rulebook) { +template +void MaxPoolGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { const auto& x_dims = x.dims(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const std::vector& real_kernel_sizes = @@ -70,29 +73,27 @@ void MaxPoolKernel(const Context& dev_ctx, DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); // 1. 
product rulebook - int rulebook_len = ProductRuleBook(dev_ctx, - x, - real_kernel_sizes, - paddings, - dilations, - strides, - out_dims, - false, - rulebook, - &counter_per_kernel, - &offsets_per_kernel, - &out_index, - &unique_key, - &unique_value, - out, - &counter, - &offsets); - - const int* rulebook_ptr = rulebook->data(); + int rulebook_len = ProductRuleBook(dev_ctx, + x, + real_kernel_sizes, + paddings, + dilations, + strides, + out_dims, + false, + rulebook, + &counter_per_kernel, + &offsets_per_kernel, + &out_index, + &unique_value, + out, + &counter, + &offsets); + + const IntT* rulebook_ptr = rulebook->data(); T* out_features_ptr = out->mutable_non_zero_elements()->data(); const T* in_features_ptr = x.non_zero_elements().data(); @@ -113,10 +114,10 @@ void MaxPoolKernel(const Context& dev_ctx, auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, counter[i] * in_channels, 1); - MaxPoolCudaKernel<<>>( + MaxPoolCudaKernel<<>>( in_features_ptr, rulebook_ptr + offsets[i] + rulebook_len, counter[i], @@ -126,6 +127,28 @@ void MaxPoolKernel(const Context& dev_ctx, } } +template +void MaxPoolKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + SparseCooTensor* out, + DenseTensor* rulebook) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "MaxPoolGPUKernel", ([&] { + MaxPoolGPUKernel(dev_ctx, + x, + kernel_sizes, + paddings, + dilations, + strides, + out, + rulebook); + })); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h index 572ade7628..2f7366a010 100644 --- a/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h @@ -26,20 +26,18 @@ void MaxPoolGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, const SparseCooTensor& out, - const DenseTensor& out_grad, + const SparseCooTensor& out_grad, const std::vector& kernel_sizes, - DenseTensor* x_grad); + SparseCooTensor* x_grad); template -DenseTensor MaxPoolGrad(const Context& dev_ctx, - const SparseCooTensor& x, - const DenseTensor& rulebook, - const SparseCooTensor& out, - const DenseTensor& out_grad, - const std::vector& kernel_sizes) { - DenseTensor x_grad = phi::Empty( - dev_ctx, - DenseTensorMeta(x.dtype(), x.non_zero_elements().dims(), x.layout())); +SparseCooTensor MaxPoolGrad(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& rulebook, + const SparseCooTensor& out, + const SparseCooTensor& out_grad, + const std::vector& kernel_sizes) { + SparseCooTensor x_grad; MaxPoolGradKernel( dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad); return x_grad; diff --git a/paddle/phi/kernels/sparse/sparse_pool_kernel.h b/paddle/phi/kernels/sparse/sparse_pool_kernel.h index bfadbf72e3..d5248a1ad2 100644 --- a/paddle/phi/kernels/sparse/sparse_pool_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_pool_kernel.h @@ -39,11 +39,7 @@ SparseCooTensor MaxPool(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, DenseTensor* rulebook) { - DenseTensor indices = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); - DenseTensor values = - phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); - SparseCooTensor coo(indices, values, x.dims()); + SparseCooTensor coo; MaxPoolKernel( dev_ctx, x, 
kernel_sizes, paddings, dilations, strides, &coo, rulebook); return coo; diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index c22464e538..9fb0e56926 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -48,13 +48,13 @@ std::vector cast(const std::vector& in) { return out; } -template -void TestConv3dBase(const std::vector& indices, +template +void TestConv3dBase(const std::vector& indices, const std::vector& features, const DDim& x_dims, const std::vector& kernel, const DDim& kernel_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -80,11 +80,13 @@ void TestConv3dBase(const std::vector& indices, const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); DenseTensor indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); - memcpy( - indices_tensor.data(), indices.data(), indices.size() * sizeof(int)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); + memcpy(indices_tensor.data(), + indices.data(), + indices.size() * sizeof(IntT)); DenseTensor features_tensor = phi::Empty( dev_ctx_cpu, DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), @@ -111,7 +113,7 @@ void TestConv3dBase(const std::vector& indices, if (!std::is_same::value) { DenseTensor rulebook = phi::Empty( - dev_ctx_cpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + dev_ctx_cpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW)); SparseCooTensor out = sparse::Conv3d(dev_ctx_cpu, x_tensor, kernel_tensor, @@ -129,8 +131,8 @@ void TestConv3dBase(const std::vector& indices, ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz()); int cmp_indices = memcmp(correct_out_indices.data(), - out.non_zero_indices().data(), - correct_out_indices.size() * sizeof(int)); + out.non_zero_indices().data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices, 0); f_verify(out.non_zero_elements().data(), correct_out_features); @@ -172,7 +174,7 @@ void TestConv3dBase(const std::vector& indices, DenseTensor d_indices_tensor = phi::Empty( dev_ctx_gpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); phi::Copy( dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor); @@ -195,7 +197,7 @@ void TestConv3dBase(const std::vector& indices, dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor); DenseTensor d_rulebook = phi::Empty( - dev_ctx_gpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + dev_ctx_gpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW)); SparseCooTensor d_out = sparse::Conv3d(dev_ctx_gpu, d_x_tensor, d_kernel_tensor, @@ -214,7 +216,7 @@ void TestConv3dBase(const std::vector& indices, DenseTensor h_indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW)); phi::Copy(dev_ctx_gpu, d_out.non_zero_indices(), phi::CPUPlace(), @@ -222,8 +224,8 @@ void TestConv3dBase(const std::vector& indices, &h_indices_tensor); int cmp_indices2 = 
memcmp(correct_out_indices.data(), - h_indices_tensor.data(), - correct_out_indices.size() * sizeof(int)); + h_indices_tensor.data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices2, 0); DenseTensor h_features_tensor = @@ -264,12 +266,13 @@ void TestConv3dBase(const std::vector& indices, #endif } -void TestConv3d(const std::vector& indices, +template +void TestConv3d(const std::vector& indices, const std::vector& features, const DDim& x_dims, const std::vector& kernel, const DDim& kernel_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -282,41 +285,41 @@ void TestConv3d(const std::vector& indices, const std::vector kernel_grad = {}, const bool subm = false) { // test float - TestConv3dBase(indices, - features, - x_dims, - kernel, - kernel_dims, - correct_out_indices, - correct_out_features, - correct_out_dims, - non_zero_num, - paddings, - strides, - dilations, - diff, - backward, - features_grad, - kernel_grad, - subm); + TestConv3dBase(indices, + features, + x_dims, + kernel, + kernel_dims, + correct_out_indices, + correct_out_features, + correct_out_dims, + non_zero_num, + paddings, + strides, + dilations, + diff, + backward, + features_grad, + kernel_grad, + subm); // test double - TestConv3dBase(indices, - cast(features), - x_dims, - cast(kernel), - kernel_dims, - correct_out_indices, - cast(correct_out_features), - correct_out_dims, - non_zero_num, - paddings, - strides, - dilations, - diff, - backward, - cast(features_grad), - cast(kernel_grad), - subm); + TestConv3dBase(indices, + cast(features), + x_dims, + cast(kernel), + kernel_dims, + correct_out_indices, + cast(correct_out_features), + correct_out_dims, + non_zero_num, + paddings, + strides, + dilations, + diff, + backward, + cast(features_grad), + cast(kernel_grad), + subm); } TEST(DEV_API, sparse_conv3d) { @@ -616,6 +619,51 @@ TEST(DEV_API, sparse_conv2d) { dilations); } +TEST(DEV_API, sparse_conv2d_int64) { + const int in_channels = 1; + const int out_channels = 1; + DDim x_dims = {1, 1, 5, 5, in_channels}; + DDim kernel_dims = {1, 3, 3, in_channels, out_channels}; + DDim out_dims = {1, 1, 3, 3, out_channels}; + std::vector paddings = {0, 0, 0}; + std::vector strides = {1, 1, 1}; + std::vector dilations = {1, 1, 1}; + + const int non_zero_num = 3; + std::vector indices_flatten = {0, 0, 0, 0, 0, 0, 0, 4, 0, 3, 2, 4}; + + std::vector features = {-0.79394531, -0.3125, -0.55029297}; + // 3*3*3=27 + std::vector kernel = {0.65820312, + 0.75048828, + 0.21411133, + 0.17370605, + 0.85546875, + 0.53076172, + 0.28833008, + 0.71044922, + 0.00659943}; + + std::vector out_indices_flatten = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 2, 2, 1, 2, 0, 1, 2}; + + std::vector out_features = { + -0.17004, -0.71338, -0.00206, -0.22205, -0.09009}; + + TestConv3d(indices_flatten, + features, + x_dims, + kernel, + kernel_dims, + out_indices_flatten, + out_features, + out_dims, + non_zero_num, + paddings, + strides, + dilations); +} + TEST(DEV_API, sparse_conv3d_backward) { const int in_channels = 1; const int out_channels = 1; diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc index 632beadf3d..8f7288d70d 100644 --- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc @@ -36,11 +36,11 @@ std::vector cast(const std::vector& in) { } return out; } -template -void 
TestMaxPoolBase(const std::vector& indices, +template +void TestMaxPoolBase(const std::vector& indices, const std::vector& features, const DDim& x_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -65,11 +65,13 @@ void TestMaxPoolBase(const std::vector& indices, const int in_channels = x_dims[4]; const int out_channels = in_channels; + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); DenseTensor indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); - memcpy( - indices_tensor.data(), indices.data(), indices.size() * sizeof(int)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); + memcpy(indices_tensor.data(), + indices.data(), + indices.size() * sizeof(IntT)); DenseTensor features_tensor = phi::Empty( dev_ctx_cpu, DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), @@ -88,8 +90,7 @@ void TestMaxPoolBase(const std::vector& indices, }; if (!std::is_same::value) { - DenseTensor rulebook = phi::Empty( - dev_ctx_cpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + DenseTensor rulebook; SparseCooTensor out = sparse::MaxPool(dev_ctx_cpu, x_tensor, kernel_sizes, @@ -105,20 +106,16 @@ void TestMaxPoolBase(const std::vector& indices, ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz()); int cmp_indices = memcmp(correct_out_indices.data(), - out.non_zero_indices().data(), - correct_out_indices.size() * sizeof(int)); + out.non_zero_indices().data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices, 0); f_verify(out.non_zero_elements().data(), correct_out_features); if (backward) { - DenseTensor x_grad = sparse::MaxPoolGrad(dev_ctx_cpu, - x_tensor, - rulebook, - out, - out.non_zero_elements(), - kernel_sizes); - f_verify(x_grad.data(), features_grad); + SparseCooTensor x_grad = sparse::MaxPoolGrad( + dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes); + f_verify(x_grad.non_zero_elements().data(), features_grad); } } @@ -142,7 +139,7 @@ void TestMaxPoolBase(const std::vector& indices, DenseTensor d_indices_tensor = phi::Empty( dev_ctx_gpu, - DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, non_zero_num}, DataLayout::NCHW)); phi::Copy( dev_ctx_gpu, indices_tensor, phi::GPUPlace(), true, &d_indices_tensor); @@ -153,8 +150,7 @@ void TestMaxPoolBase(const std::vector& indices, SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims); - DenseTensor d_rulebook = phi::Empty( - dev_ctx_gpu, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + DenseTensor d_rulebook; SparseCooTensor d_out = sparse::MaxPool(dev_ctx_gpu, d_x_tensor, kernel_sizes, @@ -171,7 +167,7 @@ void TestMaxPoolBase(const std::vector& indices, DenseTensor h_indices_tensor = phi::Empty( dev_ctx_cpu, - DenseTensorMeta(DataType::INT32, {4, d_out.nnz()}, DataLayout::NCHW)); + DenseTensorMeta(indices_dtype, {4, d_out.nnz()}, DataLayout::NCHW)); phi::Copy(dev_ctx_gpu, d_out.non_zero_indices(), phi::CPUPlace(), @@ -179,8 +175,8 @@ void TestMaxPoolBase(const std::vector& indices, &h_indices_tensor); int cmp_indices2 = memcmp(correct_out_indices.data(), - h_indices_tensor.data(), - correct_out_indices.size() * sizeof(int)); + h_indices_tensor.data(), + correct_out_indices.size() * sizeof(IntT)); ASSERT_EQ(cmp_indices2, 0); DenseTensor h_features_tensor = @@ 
-194,23 +190,25 @@ void TestMaxPoolBase(const std::vector& indices, f_verify(h_features_tensor.data(), correct_out_features); if (backward) { - DenseTensor x_grad = sparse::MaxPoolGrad(dev_ctx_gpu, - d_x_tensor, - d_rulebook, - d_out, - d_out.non_zero_elements(), - kernel_sizes); - DenseTensor h_features_grad = phi::EmptyLike(dev_ctx_cpu, x_grad); - phi::Copy(dev_ctx_gpu, x_grad, phi::CPUPlace(), true, &h_features_grad); + SparseCooTensor x_grad = sparse::MaxPoolGrad( + dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes); + DenseTensor h_features_grad = + phi::EmptyLike(dev_ctx_cpu, x_grad.non_zero_elements()); + phi::Copy(dev_ctx_gpu, + x_grad.non_zero_elements(), + phi::CPUPlace(), + true, + &h_features_grad); f_verify(h_features_grad.data(), features_grad); } #endif } -void TestMaxPool(const std::vector& indices, +template +void TestMaxPool(const std::vector& indices, const std::vector& features, const DDim& x_dims, - const std::vector& correct_out_indices, + const std::vector& correct_out_indices, const std::vector& correct_out_features, const DDim& correct_out_dims, const int non_zero_num, @@ -222,35 +220,35 @@ void TestMaxPool(const std::vector& indices, const bool backward = false, const std::vector features_grad = {}) { // test float - TestMaxPoolBase(indices, - features, - x_dims, - correct_out_indices, - correct_out_features, - correct_out_dims, - non_zero_num, - kernel_sizes, - paddings, - strides, - dilations, - diff, - backward, - features_grad); + TestMaxPoolBase(indices, + features, + x_dims, + correct_out_indices, + correct_out_features, + correct_out_dims, + non_zero_num, + kernel_sizes, + paddings, + strides, + dilations, + diff, + backward, + features_grad); // test double - TestMaxPoolBase(indices, - cast(features), - x_dims, - correct_out_indices, - cast(correct_out_features), - correct_out_dims, - non_zero_num, - kernel_sizes, - paddings, - strides, - dilations, - diff, - backward, - cast(features_grad)); + TestMaxPoolBase(indices, + cast(features), + x_dims, + correct_out_indices, + cast(correct_out_features), + correct_out_dims, + non_zero_num, + kernel_sizes, + paddings, + strides, + dilations, + diff, + backward, + cast(features_grad)); } TEST(DEV_API, sparse_maxpool) { -- GitLab
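For reference, the core pattern this patch applies in every kernel file is the same: the public kernel keeps the runtime dtype of x.non_zero_indices() and uses PD_DISPATCH_INTEGRAL_TYPES to instantiate a device kernel templated on the index type IntT (int32_t or int64_t). The following is a minimal, self-contained C++ sketch of that runtime-to-compile-time dispatch idea only; the names IndexDType, DispatchIndexType and GatherByRulebook are illustrative stand-ins for this sketch, not Paddle APIs.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical runtime tag for the indices dtype (stand-in for phi::DataType).
enum class IndexDType { kInt32, kInt64 };

// Kernel templated on the integer type of the sparse indices, mirroring how the
// patched kernels read rulebook entries as IntT instead of a hard-coded int.
template <typename IntT>
void GatherByRulebook(const std::vector<IntT>& rulebook,
                      const std::vector<float>& in_features,
                      std::vector<float>* out_features) {
  for (size_t i = 0; i < rulebook.size(); ++i) {
    IntT in_i = rulebook[i];  // index into the input features
    (*out_features)[i] = in_features[static_cast<size_t>(in_i)];
  }
}

// Runtime-to-compile-time dispatch, playing the role of
// PD_DISPATCH_INTEGRAL_TYPES(x.non_zero_indices().dtype(), "...", ([&] { ... })).
template <typename Functor>
void DispatchIndexType(IndexDType dtype, Functor&& f) {
  switch (dtype) {
    case IndexDType::kInt32: f(int32_t{}); break;
    case IndexDType::kInt64: f(int64_t{}); break;
    default: throw std::runtime_error("unsupported indices dtype");
  }
}

int main() {
  std::vector<float> features = {0.5f, -1.25f, 3.0f};
  std::vector<int64_t> rulebook64 = {2, 0, 1};  // int64 indices, as in the new tests
  std::vector<float> out(rulebook64.size(), 0.f);

  DispatchIndexType(IndexDType::kInt64, [&](auto zero) {
    using IntT = decltype(zero);
    // Materialize the rulebook with the dispatched index type, then gather.
    std::vector<IntT> rulebook(rulebook64.begin(), rulebook64.end());
    GatherByRulebook<IntT>(rulebook, features, &out);
  });

  for (float v : out) std::cout << v << " ";  // prints: 3 0.5 -1.25
  std::cout << std::endl;
  return 0;
}

The generic lambda is what lets a single registration cover both index widths: the switch selects IntT once at the dispatch point, and every templated helper downstream (rulebook lookup, scatter/gather, counters) is instantiated consistently for that IntT, which is the same structure the patched Conv3dKernel and MaxPoolKernel wrappers follow.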