From 5ab2cec53328bf814540a693876ab517a53f8b52 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 14 Mar 2022 10:20:44 +0800 Subject: [PATCH] Move the common function to kernel funcs (#40422) * move the common function to kernel/funcs/sparse/ * add namespace * rm unused file * move func * reuse code --- .../phi/kernels/funcs/sparse/common_shape.h | 45 +++++ paddle/phi/kernels/funcs/sparse/convolution.h | 170 ++++++++++++++++++ .../phi/kernels/sparse/convolution_kernel.h | 96 +--------- paddle/phi/kernels/sparse/cpu/convolution.h | 37 ++-- .../sparse/cpu/convolution_grad_kernel.cc | 33 +--- .../kernels/sparse/cpu/convolution_kernel.cc | 14 +- .../kernels/sparse/cpu/sparse_utils_kernel.cc | 5 +- .../cpu/submanifold_convolution_kernel.cu | 30 ---- .../sparse/gpu/convolution_grad_kernel.cu | 33 +--- .../kernels/sparse/gpu/convolution_kernel.cu | 41 ++--- .../kernels/sparse/gpu/sparse_utils_kernel.cu | 78 ++++---- .../phi/kernels/sparse/sparse_utils_kernel.h | 31 ---- 12 files changed, 333 insertions(+), 280 deletions(-) create mode 100644 paddle/phi/kernels/funcs/sparse/common_shape.h create mode 100644 paddle/phi/kernels/funcs/sparse/convolution.h delete mode 100644 paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu diff --git a/paddle/phi/kernels/funcs/sparse/common_shape.h b/paddle/phi/kernels/funcs/sparse/common_shape.h new file mode 100644 index 0000000000..3617e3cd2f --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/common_shape.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/ddim.h" + +namespace phi { +namespace funcs { +namespace sparse { + +inline const DDim InferDenseDims(const DDim& x_dims, + const int64_t sparse_dim, + const int64_t non_zero_num) { + auto dense_dim = x_dims.size() - sparse_dim; + DDim values_dims; + if (dense_dim > 0) { + std::vector dense_dim_vec(dense_dim + 1); + dense_dim_vec[0] = non_zero_num; + memcpy(&dense_dim_vec[1], + x_dims.Get() + sparse_dim, + dense_dim * sizeof(x_dims[0])); + values_dims = phi::make_ddim(dense_dim_vec); + } else { + values_dims = phi::make_ddim({non_zero_num}); + } + return values_dims; +} + +} // namespace sparse +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h new file mode 100644 index 0000000000..68fe8880a9 --- /dev/null +++ b/paddle/phi/kernels/funcs/sparse/convolution.h @@ -0,0 +1,170 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { +namespace funcs { +namespace sparse { + +struct Dims4D { + int dims[4]; + Dims4D(const int batch, const int x, const int y, const int z) { + dims[0] = batch; + dims[1] = z; + dims[2] = y; + dims[3] = x; + } + HOSTDEVICE const int& operator[](int i) const { return dims[i]; } +}; + +// Judge whether the current position x is in (lower, upper) +inline HOSTDEVICE bool Check(const int& x, + const int& kx, + const int& pad, + const int& stride, + const int dilation, + const int kdim, + const int xdim) { + const int lower = x - dilation * kx + pad; + const int uper = x + (kdim - kx - 1) * dilation - pad; + return (lower >= 0 && lower % stride == 0 && uper < xdim); +} + +// Check whether the current position(x, y, z) is legal: +// Judge the minimum and maximum values at each latitude +inline HOSTDEVICE bool Check(const Dims4D& dims, + const Dims4D& kernel_dims, + const Dims4D& paddings, + const Dims4D& dilations, + const Dims4D& strides, + const int x, + const int y, + const int z, + const int kx, + const int ky, + const int kz) { + bool x_valid = Check( + x, kx, paddings[3], strides[3], dilations[3], kernel_dims[3], dims[3]); + bool y_valid = Check( + y, ky, paddings[2], strides[2], dilations[2], kernel_dims[2], dims[2]); + bool z_valid = Check( + z, kz, paddings[1], strides[1], dilations[1], kernel_dims[1], dims[1]); + return (x_valid && y_valid && z_valid); +} + +template +inline HOSTDEVICE int PointToIndex(const int& batch, + const int& x, + const int& y, + const int& z, + const Dim& dims) { + return batch * dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] + + y * dims[3] + x; +} + +// TODO(zhangkaihuo): use division and multiply to optimize +// modulo operation +template +inline HOSTDEVICE void IndexToPoint( + const int index, const Dim& dims, int* batch, int* x, int* y, int* z) { + int n = index; + *x = n % dims[3]; + n /= dims[3]; + *y = n % dims[2]; + n /= dims[2]; + *z = n % dims[1]; + n /= dims[1]; + *batch = n; +} + +inline void GetOutShape(const DDim& x_dims, + const DDim& kernel_dims, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + DDim* out_dims) { + PADDLE_ENFORCE_EQ( + x_dims.size(), + 5, + phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)")); + PADDLE_ENFORCE_EQ(kernel_dims.size(), + 5, + phi::errors::InvalidArgument( + "the shape of kernel should be (D, H, W, C, OC)")); + + // infer out shape + (*out_dims)[0] = x_dims[0]; + (*out_dims)[4] = kernel_dims[4]; + for (int i = 1; i < 4; i++) { + (*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] - + dilations[i - 1] * (kernel_dims[i - 1] - 1) - 1) / + strides[i - 1] + + 1; + } +} + +inline void ResetSubmKernelSizeAndStrides(const DDim& kernel_dims, + std::vector* paddings, + std::vector* strides) { + for (uint64_t i = 0; i < paddings->size(); i++) { + (*paddings)[i] = kernel_dims[i] / 2; + (*strides)[i] = 1; + } +} + +template +inline void SubmPreProcess(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& kernel, + const SparseCooTensor& out_grad, + const int in_channels, + const int out_channels, + const int half_kernel_size, + DenseTensor* kernel_grad, + DenseTensor* x_grad) { + auto blas = phi::funcs::GetBlas(dev_ctx); + T* d_kernel_ptr = kernel_grad->data(); + blas.GEMM(CblasTrans, + CblasNoTrans, + 
x.non_zero_elements().dims()[1], + out_grad.non_zero_elements().dims()[1], + x.non_zero_elements().dims()[0], + static_cast(1), + x.non_zero_elements().data(), + out_grad.non_zero_elements().data(), + static_cast(0), + d_kernel_ptr + half_kernel_size * in_channels * out_channels); + + // call gemm: d_x = out_grad * transpose(kernel) + // (n, out_channels) * (out_channels, in_channels) + T* x_grad_ptr = x_grad->data(); + blas.GEMM(CblasNoTrans, + CblasTrans, + out_grad.non_zero_elements().dims()[0], + in_channels, + out_grad.non_zero_elements().dims()[1], + static_cast(1), + out_grad.non_zero_elements().data(), + kernel.data() + half_kernel_size * in_channels * out_channels, + static_cast(0), + x_grad_ptr); +} + +} // namespace sparse +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/sparse/convolution_kernel.h b/paddle/phi/kernels/sparse/convolution_kernel.h index 778600a228..ff2cf94edb 100644 --- a/paddle/phi/kernels/sparse/convolution_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_kernel.h @@ -18,105 +18,11 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/sparse/convolution.h" namespace phi { namespace sparse { -struct Dims4D { - int dims[4]; - Dims4D(const int batch, const int x, const int y, const int z) { - dims[0] = batch; - dims[1] = z; - dims[2] = y; - dims[3] = x; - } - HOSTDEVICE const int& operator[](int i) const { return dims[i]; } -}; - -// Judge whether the current position x is in (lower, upper) -inline HOSTDEVICE bool Check(const int& x, - const int& kx, - const int& pad, - const int& stride, - const int dilation, - const int kdim, - const int xdim) { - const int lower = x - dilation * kx + pad; - const int uper = x + (kdim - kx - 1) * dilation - pad; - return (lower >= 0 && lower % stride == 0 && uper < xdim); -} - -// Check whether the current position(x, y, z) is legal: -// Judge the minimum and maximum values at each latitude -inline HOSTDEVICE bool Check(const Dims4D& dims, - const Dims4D& kernel_dims, - const Dims4D& paddings, - const Dims4D& dilations, - const Dims4D& strides, - const int x, - const int y, - const int z, - const int kx, - const int ky, - const int kz) { - bool x_valid = Check( - x, kx, paddings[3], strides[3], dilations[3], kernel_dims[3], dims[3]); - bool y_valid = Check( - y, ky, paddings[2], strides[2], dilations[2], kernel_dims[2], dims[2]); - bool z_valid = Check( - z, kz, paddings[1], strides[1], dilations[1], kernel_dims[1], dims[1]); - return (x_valid && y_valid && z_valid); -} - -template -inline HOSTDEVICE int PointToIndex(const int& batch, - const int& x, - const int& y, - const int& z, - const Dim& dims) { - return batch * dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] + - y * dims[3] + x; -} - -template -inline HOSTDEVICE void IndexToPoint( - const int index, const Dim& dims, int* batch, int* x, int* y, int* z) { - int n = index; - *x = n % dims[3]; - n /= dims[3]; - *y = n % dims[2]; - n /= dims[2]; - *z = n % dims[1]; - n /= dims[1]; - *batch = n; -} - -inline void GetOutShape(const DDim& x_dims, - const DDim& kernel_dims, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - DDim* out_dims) { - PADDLE_ENFORCE_EQ( - x_dims.size(), - 5, - phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)")); - PADDLE_ENFORCE_EQ(kernel_dims.size(), - 5, - phi::errors::InvalidArgument( - "the shape of 
kernel should be (D, H, W, C, OC)")); - - // infer out shape - (*out_dims)[0] = x_dims[0]; - (*out_dims)[4] = kernel_dims[4]; - for (int i = 1; i < 4; i++) { - (*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] - - dilations[i - 1] * (kernel_dims[i - 1] - 1) - 1) / - strides[i - 1] + - 1; - } -} - template void Conv3dKernel(const Context& dev_ctx, const SparseCooTensor& x, diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h index a5a946dce7..64c32df189 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution.h +++ b/paddle/phi/kernels/sparse/cpu/convolution.h @@ -16,8 +16,6 @@ limitations under the License. */ #include -#include "paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -28,6 +26,8 @@ limitations under the License. */ namespace phi { namespace sparse { +using Dims4D = phi::funcs::sparse::Dims4D; + // such as: kernel(3, 3, 3), kernel_size = 27 // counter_per_weight: (kernel_size) // TODO(zhangkaihuo): optimize performance with multithreading @@ -67,7 +67,8 @@ void ProductRuleBook(const Context& dev_ctx, int in_z = indices_ptr[i + non_zero_num]; int in_y = indices_ptr[i + 2 * non_zero_num]; int in_x = indices_ptr[i + 3 * non_zero_num]; - int index = PointToIndex(batch, in_x, in_y, in_z, x_dims); + int index = phi::funcs::sparse::PointToIndex( + batch, in_x, in_y, in_z, x_dims); hash_in.insert(index); } } @@ -86,20 +87,20 @@ void ProductRuleBook(const Context& dev_ctx, int out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0]; int out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1]; int out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2]; - if (Check(c_x_dims, - c_kernel_dims, - c_paddings, - c_dilations, - c_strides, - in_x, - in_y, - in_z, - kx, - ky, - kz)) { + if (phi::funcs::sparse::Check(c_x_dims, + c_kernel_dims, + c_paddings, + c_dilations, + c_strides, + in_x, + in_y, + in_z, + kx, + ky, + kz)) { if (subm) { - int out_index = - PointToIndex(batch, out_x, out_y, out_z, out_dims); + int out_index = phi::funcs::sparse::PointToIndex( + batch, out_x, out_y, out_z, out_dims); if (hash_in.find(out_index) == hash_in.end()) { continue; } @@ -112,7 +113,7 @@ void ProductRuleBook(const Context& dev_ctx, rulebook_ptr[rulebook_index] = kernel_index - 1; rulebook_ptr[rulebook_index + rulebook_len] = i; // in_i rulebook_ptr[rulebook_index + rulebook_len * 2] = - PointToIndex( + phi::funcs::sparse::PointToIndex( batch, out_x, out_y, out_z, out_dims); // out_index ++rulebook_index; } @@ -161,7 +162,7 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx, for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) { const int index = *it; int batch, x, y, z; - IndexToPoint(index, out_dims, &batch, &x, &y, &z); + phi::funcs::sparse::IndexToPoint(index, out_dims, &batch, &x, &y, &z); out_indices_ptr[i] = batch; out_indices_ptr[i + out_non_zero_num] = z; out_indices_ptr[i + out_non_zero_num * 2] = y; diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index bb414faef6..5d7b381b7c 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -94,30 +94,15 @@ void Conv3dGradKernel(const Context& dev_ctx, offsets[kernel_size] = offset; if (subm) { - blas.GEMM(CblasTrans, - CblasNoTrans, - 
x.non_zero_elements().dims()[1], - out_grad.non_zero_elements().dims()[1], - x.non_zero_elements().dims()[0], - static_cast(1), - x.non_zero_elements().data(), - out_grad.non_zero_elements().data(), - static_cast(0), - d_kernel_ptr + half_kernel_size * in_channels * out_channels); - - // call gemm: d_x = out_grad * transpose(kernel) - // (n, out_channels) * (out_channels, in_channels) - T* x_grad_ptr = x_grad->data(); - blas.GEMM(CblasNoTrans, - CblasTrans, - out_grad.non_zero_elements().dims()[0], - in_channels, - out_grad.non_zero_elements().dims()[1], - static_cast(1), - out_grad.non_zero_elements().data(), - kernel.data() + half_kernel_size * in_channels * out_channels, - static_cast(0), - x_grad_ptr); + phi::funcs::sparse::SubmPreProcess(dev_ctx, + x, + kernel, + out_grad, + in_channels, + out_channels, + half_kernel_size, + kernel_grad, + x_grad); if (max_count == 0) { return; } diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc index f65e1cf579..746ca04a82 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/kernels/sparse/cpu/convolution.h" -#include "paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/blas/blas.h" @@ -46,10 +44,16 @@ void Conv3dKernel(const Context& dev_ctx, const auto& kernel_dims = kernel.dims(); int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; DDim out_dims = {1, 1, 1, 1, 1}; - GetOutShape(x_dims, kernel_dims, paddings, dilations, strides, &out_dims); + phi::funcs::sparse::GetOutShape( + x_dims, kernel_dims, paddings, dilations, strides, &out_dims); const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; + std::vector subm_paddings(paddings), subm_strides(strides); + if (subm) { + phi::funcs::sparse::ResetSubmKernelSizeAndStrides( + kernel.dims(), &subm_paddings, &subm_strides); + } // Second algorithm: // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // 1. product rulebook @@ -60,9 +64,9 @@ void Conv3dKernel(const Context& dev_ctx, ProductRuleBook(dev_ctx, x, kernel, - paddings, + subm_paddings, dilations, - strides, + subm_strides, out_dims, subm, rulebook, diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index ba89135641..50e95ee0b8 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -14,9 +14,9 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" #include "paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/funcs/sparse/common_shape.h" namespace phi { namespace sparse { @@ -71,7 +71,8 @@ void DenseToSparseCooKernel(const Context& dev_ctx, int64_t non_zero_num = GetNonZeroNum(x, sparse_dim); const auto place = dev_ctx.GetPlace(); - const auto values_dims = InferDenseDims(x_dims, sparse_dim, non_zero_num); + const auto values_dims = + phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num); DenseTensorMeta indices_meta(DataType::INT64, {sparse_dim, static_cast(non_zero_num)}, DataLayout::NCHW); diff --git a/paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu b/paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu deleted file mode 100644 index 5f6d24093a..0000000000 --- a/paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include "paddle/phi/api/lib/utils/allocator.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/sparse_coo_tensor.h" -#include "paddle/phi/core/tensor_meta.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/sparse/submanifold_convolution_kernel.h" - -namespace phi { -namespace sparse {} // namespace sparse -} // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index a307ab0f54..d6d992d0f4 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -110,30 +110,15 @@ void Conv3dGradKernel(const Context& dev_ctx, offsets[kernel_size] = offset; if (subm) { - blas.GEMM(CblasTrans, - CblasNoTrans, - x.non_zero_elements().dims()[1], - out_grad.non_zero_elements().dims()[1], - x.non_zero_elements().dims()[0], - static_cast(1), - x.non_zero_elements().data(), - out_grad.non_zero_elements().data(), - static_cast(0), - d_kernel_ptr + half_kernel_size * in_channels * out_channels); - - // call gemm: d_x = out_grad * transpose(kernel) - // (n, out_channels) * (out_channels, in_channels) - T* x_grad_ptr = x_grad->data(); - blas.GEMM(CblasNoTrans, - CblasTrans, - out_grad.non_zero_elements().dims()[0], - in_channels, - out_grad.non_zero_elements().dims()[1], - static_cast(1), - out_grad.non_zero_elements().data(), - kernel.data() + half_kernel_size * in_channels * out_channels, - static_cast(0), - x_grad_ptr); + phi::funcs::sparse::SubmPreProcess(dev_ctx, + x, + kernel, + out_grad, + in_channels, + out_channels, + half_kernel_size, + kernel_grad, + x_grad); if (max_count == 0) { return; } diff --git 
a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 94186600f1..1a0c7e9b97 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -33,6 +33,8 @@ limitations under the License. */ namespace phi { namespace sparse { +using Dims4D = phi::funcs::sparse::Dims4D; + __global__ void SetFlagAndUpdateCounterKernel(const int* indexs, const int n, const int rulebook_len, @@ -83,7 +85,8 @@ __global__ void UpdateIndexKernel(const int* unique_keys, for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { const int index = unique_keys[i]; int batch, x, y, z; - IndexToPoint(index, out_dims, &batch, &x, &y, &z); + phi::funcs::sparse::IndexToPoint( + index, out_dims, &batch, &x, &y, &z); // get out indices out_indices[i] = batch; out_indices[i + non_zero_num] = z; @@ -150,23 +153,23 @@ __global__ void ProductRuleBookKernel(const int* x_indices, for (int ky = 0; ky < kernel_dims[2]; ky++) { for (int kx = 0; kx < kernel_dims[3]; kx++) { int in_i = -1, out_index = -1, kernel_i = -1; - if (Check(x_dims, - kernel_dims, - paddings, - dilations, - strides, - in_x, - in_y, - in_z, - kx, - ky, - kz)) { + if (phi::funcs::sparse::Check(x_dims, + kernel_dims, + paddings, + dilations, + strides, + in_x, + in_y, + in_z, + kx, + ky, + kz)) { int out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; int out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; int out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; in_i = i; - out_index = - PointToIndex(batch, out_x, out_y, out_z, out_dims); + out_index = phi::funcs::sparse::PointToIndex( + batch, out_x, out_y, out_z, out_dims); atomicAdd(&counter_buf[kernel_index], 1); kernel_i = kernel_index; } @@ -542,7 +545,8 @@ void Conv3dKernel(const Context& dev_ctx, const auto& kernel_dims = kernel.dims(); int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; DDim out_dims = {1, 1, 1, 1, 1}; - GetOutShape(x_dims, kernel_dims, paddings, dilations, strides, &out_dims); + phi::funcs::sparse::GetOutShape( + x_dims, kernel_dims, paddings, dilations, strides, &out_dims); out->set_dims(out_dims); const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; @@ -564,11 +568,8 @@ void Conv3dKernel(const Context& dev_ctx, std::vector subm_paddings(paddings), subm_strides(strides); if (subm) { - auto kernel_dims = kernel.dims(); - for (int i = 0; i < paddings.size(); i++) { - subm_paddings[i] = kernel_dims[i] / 2; - subm_strides[i] = 1; - } + phi::funcs::sparse::ResetSubmKernelSizeAndStrides( + kernel.dims(), &subm_paddings, &subm_strides); } int n = ProductRuleBook(dev_ctx, diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index 2e741111fb..8048180e42 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -16,8 +16,10 @@ limitations under the License. 
*/ #include #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/funcs/sparse/common_shape.h" #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h" namespace phi { @@ -115,14 +117,16 @@ void DenseToSparseCooKernel(const Context& dev_ctx, PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(nums_ptr, 0, sizeof(int), dev_ctx.stream())); #endif - int grid_size = 1, block_size = 1; - GetGpuLaunchConfig1D(dev_ctx, rows, &grid_size, &block_size); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1); auto temp_indexs_meta = phi::DenseTensorMeta(DataType::INT32, {rows}, phi::DataLayout::NCHW); DenseTensor temp_indexs = phi::Empty(dev_ctx, std::move(temp_indexs_meta)); int* temp_indexs_ptr = temp_indexs.mutable_data(place); - GetNonZeroNums<<>>( + GetNonZeroNums<<>>( x_data, rows, cols, nums_ptr, temp_indexs_ptr); #ifdef PADDLE_WITH_HIP thrust::remove(thrust::hip::par.on(dev_ctx.stream()), @@ -167,7 +171,8 @@ void DenseToSparseCooKernel(const Context& dev_ctx, dev_ctx.Wait(); // wait the copy - const auto values_dims = InferDenseDims(x_dims, sparse_dim, non_zero_num); + const auto values_dims = + phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num); DenseTensorMeta indices_meta(DataType::INT64, {sparse_dim, static_cast(non_zero_num)}, DataLayout::NCHW); @@ -184,16 +189,18 @@ void DenseToSparseCooKernel(const Context& dev_ctx, T* sparse_data = values.mutable_data(place); // 3. calc indices by indexs and get values by indexs - GetGpuLaunchConfig1D(dev_ctx, non_zero_num, &grid_size, &block_size); - GetNonZeroElementsAndIndices<<>>( - x_data, - sparse_dim, - cols, - d_x_dims.data(), - non_zero_num, - temp_indexs_ptr, - indices_data, - sparse_data); + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); + GetNonZeroElementsAndIndices<<>>(x_data, + sparse_dim, + cols, + d_x_dims.data(), + non_zero_num, + temp_indexs_ptr, + indices_data, + sparse_data); out->SetMember(indices, values, x_dims, true); } @@ -263,10 +270,9 @@ void SparseCsrToCooKernel(const Context& dev_ctx, int* offsets_ptr = batchs == 1 ? 
nullptr : offsets.mutable_data(place); T* coo_values_data = values.mutable_data(place); - int grid_size = 1, block_size = 1; if (batchs > 1) { - GetGpuLaunchConfig1D(dev_ctx, batchs, &grid_size, &block_size); - GetBatchSizes<<>>( + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1); + GetBatchSizes<<>>( csr_crows_data, rows, batchs, offsets_ptr); #ifdef PADDLE_WITH_HIP @@ -279,9 +285,10 @@ void SparseCsrToCooKernel(const Context& dev_ctx, offsets_ptr); } - GetGpuLaunchConfig1D(dev_ctx, rows, &grid_size, &block_size); - dim3 grids(grid_size, batchs, 1); - ConvertCsrCrowsToCooRows<<>>( + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rows, 1); + config.block_per_grid.y = batchs; + ConvertCsrCrowsToCooRows<<>>( csr_crows_data, offsets_ptr, coo_rows_data, batch_ptr, rows); #ifdef PADDLE_WITH_HIP @@ -404,21 +411,29 @@ void SparseCooToCsrKernel(const Context& dev_ctx, // TODO(zhangkahuo): call coalesced() to distinct and sort the indices } - int grid_size = 1, block_size = 1; - GetGpuLaunchConfig1D(dev_ctx, batchs, &grid_size, &block_size); + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, batchs, 1); if (batchs > 1) { DenseTensorMeta batchs_meta(DataType::INT64, {batchs}, DataLayout::NCHW); phi::DenseTensor batchs_offset( phi::make_intrusive(place), std::move(batchs_meta)); int64_t* batchs_offset_ptr = batchs_offset.mutable_data(place); - GetBatchsOffset<<>>( + GetBatchsOffset<<>>( batchs_ptr, non_zero_num, batchs_offset_ptr); - dim3 grids(grid_size, batchs, 1); - ConvertCooRowsToCsrCrows<<>>( + config.block_per_grid.y = batchs; + ConvertCooRowsToCsrCrows<<>>( batchs_offset_ptr, coo_rows_data, csr_crows_data, rows, non_zero_num); } else { - ConvertCooRowsToCsrCrows<<>>( + ConvertCooRowsToCsrCrows<<>>( nullptr, coo_rows_data, csr_crows_data, rows, non_zero_num); } @@ -522,12 +537,13 @@ void SparseCooToDenseKernel(const Context& dev_ctx, PADDLE_ENFORCE_GPU_SUCCESS( cudaMemsetAsync(out_data, 0, sizeof(T) * out->numel(), dev_ctx.stream())); #endif - int grid_size = 1, block_size = 1; - GetGpuLaunchConfig1D(dev_ctx, non_zero_num, &grid_size, &block_size); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); - KernelSparseCooToDense< - T, - int64_t><<>>( + KernelSparseCooToDense<<>>( indices.data(), d_sparse_offsets.data(), x_data, diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index c83b2130ed..da05eb3d3c 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -23,37 +23,6 @@ limitations under the License. */ namespace phi { namespace sparse { -inline const DDim InferDenseDims(const DDim& x_dims, - const int64_t sparse_dim, - const int64_t non_zero_num) { - auto dense_dim = x_dims.size() - sparse_dim; - DDim values_dims; - if (dense_dim) { - std::vector dense_dim_vec(dense_dim + 1); - dense_dim_vec[0] = non_zero_num; - memcpy(&dense_dim_vec[1], - x_dims.Get() + sparse_dim, - dense_dim * sizeof(x_dims[0])); - values_dims = phi::make_ddim(dense_dim_vec); - } else { - values_dims = phi::make_ddim({non_zero_num}); - } - return values_dims; -} - -template -inline void GetGpuLaunchConfig1D(const Context& dev_ctx, - const int64_t n, - int* grid_size, - int* block_size) { - const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock(); - const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM; - *block_size = (n >= MAX_BLOCK_DIM) ? 
MAX_BLOCK_DIM
-                                 : (1 << static_cast<int>(std::log2(n)));
-  *grid_size = n / *block_size;
-  *grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
-}
-
 template <typename T, typename Context>
 void DenseToSparseCooKernel(const Context& dev_ctx,
                             const DenseTensor& x,
--
GitLab
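
Note on the refactor: the commit moves the shared sparse-convolution helpers (Dims4D, Check, PointToIndex, IndexToPoint, GetOutShape, ResetSubmKernelSizeAndStrides, SubmPreProcess, InferDenseDims) into phi::funcs::sparse, so the CPU and GPU kernels call one copy instead of duplicating the code in convolution_kernel.h and the per-backend .cc/.cu files. The sketch below restates the PointToIndex/IndexToPoint round trip from funcs/sparse/convolution.h as a self-contained program; the simplified Dims4D struct and the main() driver are illustrative assumptions for this sketch, not part of the patch or the phi API.

// Minimal standalone sketch (not Paddle API): the flattened-index arithmetic
// used by phi::funcs::sparse::PointToIndex / IndexToPoint in this patch.
// Layout follows the patch: dims[0]=batch, dims[1]=D(z), dims[2]=H(y), dims[3]=W(x).
#include <cassert>
#include <cstdio>

struct Dims4D {
  int dims[4];
  int operator[](int i) const { return dims[i]; }
};

// Flatten (batch, x, y, z) into one integer, mirroring convolution.h.
int PointToIndex(int batch, int x, int y, int z, const Dims4D& d) {
  return batch * d[1] * d[2] * d[3] + z * d[2] * d[3] + y * d[3] + x;
}

// Invert the flattening with successive modulo/divide, mirroring convolution.h.
void IndexToPoint(int index, const Dims4D& d, int* batch, int* x, int* y, int* z) {
  int n = index;
  *x = n % d[3];
  n /= d[3];
  *y = n % d[2];
  n /= d[2];
  *z = n % d[1];
  n /= d[1];
  *batch = n;
}

int main() {
  Dims4D dims{{2, 4, 5, 6}};                   // batch=2, D=4, H=5, W=6
  int index = PointToIndex(1, 3, 2, 1, dims);  // batch=1, x=3, y=2, z=1
  int batch = 0, x = 0, y = 0, z = 0;
  IndexToPoint(index, dims, &batch, &x, &y, &z);
  assert(batch == 1 && x == 3 && y == 2 && z == 1);  // lossless round trip
  std::printf("flattened index = %d\n", index);
  return 0;
}

Because the round trip is lossless for every in-range point, the rulebook can store a single integer per output location and recover (batch, x, y, z) later, which is exactly what UpdateRulebookAndOutIndex on the CPU path and UpdateIndexKernel on the GPU path do after this change.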