diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc
index f2987e36d3db0163c275562562bf5d6bf7aa91af..ceaebe4e35b7120af160e27fca4347add941d458 100644
--- a/paddle/phi/core/sparse_coo_tensor.cc
+++ b/paddle/phi/core/sparse_coo_tensor.cc
@@ -106,7 +106,7 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices,
                                 const bool coalesced) {
   this->non_zero_indices_ = non_zero_indices;
   this->non_zero_elements_ = non_zero_elements;
-  this->dims_ = dims_;
+  this->dims_ = dims;
   this->coalesced_ = coalesced;
 }
 
diff --git a/paddle/phi/kernels/sparse/convolution_kernel.h b/paddle/phi/kernels/sparse/convolution_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..71160a6365dc778e40476af960f21443cac698e5
--- /dev/null
+++ b/paddle/phi/kernels/sparse/convolution_kernel.h
@@ -0,0 +1,148 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+struct Dims4D {
+  int dims[4];
+  Dims4D(const int batch, const int x, const int y, const int z) {
+    dims[0] = batch;
+    dims[1] = z;
+    dims[2] = y;
+    dims[3] = x;
+  }
+  HOSTDEVICE const int& operator[](int i) const { return dims[i]; }
+};
+
+// Check whether input position x, paired with kernel offset kx, maps to a
+// valid output position: lower must be a non-negative multiple of the stride
+// and upper must stay inside the input extent.
+inline HOSTDEVICE bool Check(const int& x,
+                             const int& kx,
+                             const int& pad,
+                             const int& stride,
+                             const int dilation,
+                             const int kdim,
+                             const int xdim) {
+  const int lower = x - dilation * kx + pad;
+  const int upper = x + (kdim - kx - 1) * dilation - pad;
+  return (lower >= 0 && lower % stride == 0 && upper < xdim);
+}
+
+// Check whether the current position (x, y, z) is legal:
+// judge the minimum and maximum values along each dimension.
+inline HOSTDEVICE bool Check(const Dims4D& dims,
+                             const Dims4D& kernel_dims,
+                             const Dims4D& paddings,
+                             const Dims4D& dilations,
+                             const Dims4D& strides,
+                             const int x,
+                             const int y,
+                             const int z,
+                             const int kx,
+                             const int ky,
+                             const int kz) {
+  bool x_valid = Check(
+      x, kx, paddings[3], strides[3], dilations[3], kernel_dims[3], dims[3]);
+  bool y_valid = Check(
+      y, ky, paddings[2], strides[2], dilations[2], kernel_dims[2], dims[2]);
+  bool z_valid = Check(
+      z, kz, paddings[1], strides[1], dilations[1], kernel_dims[1], dims[1]);
+  return (x_valid && y_valid && z_valid);
+}
+
+template <typename Dim>
+inline HOSTDEVICE int PointToIndex(const int& batch,
+                                   const int& x,
+                                   const int& y,
+                                   const int& z,
+                                   const Dim& dims) {
+  return batch * dims[1] * dims[2] * dims[3] + z * dims[2] * dims[3] +
+         y * dims[3] + x;
+}
+
+template <typename Dim>
+inline HOSTDEVICE void IndexToPoint(
+    const int index, const Dim& dims, int* batch, int* x, int* y, int* z) {
+  int n = index;
+  *x = n % dims[3];
+  n /= dims[3];
+  *y = n % dims[2];
+  n /= dims[2];
+  *z = n % dims[1];
+  n /= dims[1];
+  *batch = n;
+}
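+
+// Illustrative round-trip (a sketch, not exercised by this PR's tests):
+// with Dims4D dims(1, 2, 2, 2), i.e. N=1, W=2, H=2, D=2,
+//   PointToIndex(0, /*x=*/1, /*y=*/0, /*z=*/1, dims) == 1*4 + 0*2 + 1 == 5
+// and IndexToPoint(5, dims, ...) recovers (batch=0, x=1, y=0, z=1), so the
+// two functions are inverses whenever every coordinate lies inside dims.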
+
+inline void GetOutShape(const DDim& x_dims,
+                        const DDim& kernel_dims,
+                        const std::vector<int>& paddings,
+                        const std::vector<int>& dilations,
+                        const std::vector<int>& strides,
+                        DDim* out_dims) {
+  PADDLE_ENFORCE_EQ(
+      x_dims.size(),
+      5,
+      phi::errors::InvalidArgument("the shape of x should be (N, D, H, W, C)"));
+  PADDLE_ENFORCE_EQ(kernel_dims.size(),
+                    5,
+                    phi::errors::InvalidArgument(
+                        "the shape of kernel should be (D, H, W, C, OC)"));
+
+  // infer out shape
+  (*out_dims)[0] = x_dims[0];
+  (*out_dims)[4] = kernel_dims[4];
+  for (int i = 1; i < 4; i++) {
+    (*out_dims)[i] = (x_dims[i] + 2 * paddings[i - 1] -
+                      dilations[i - 1] * (kernel_dims[i - 1] - 1) - 1) /
+                         strides[i - 1] +
+                     1;
+  }
+}
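+
+// Example (matches the sparse_conv3d unit test added in this PR): for
+// x_dims = (1, 4, 4, 4, 1), kernel_dims = (3, 3, 3, 1, 1), paddings = 0,
+// dilations = 1 and strides = 1, every spatial dim becomes
+// (4 + 2*0 - 1*(3 - 1) - 1) / 1 + 1 = 2, so *out_dims = (1, 2, 2, 2, 1).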
+
+template <typename T, typename Context>
+void Conv3dKernel(const Context& dev_ctx,
+                  const SparseCooTensor& x,
+                  const DenseTensor& kernel,
+                  const std::vector<int>& paddings,
+                  const std::vector<int>& dilations,
+                  const std::vector<int>& strides,
+                  const int groups,
+                  SparseCooTensor* out,
+                  DenseTensor* rulebook);
+
+template <typename T, typename Context>
+SparseCooTensor Conv3d(const Context& dev_ctx,
+                       const SparseCooTensor& x,
+                       const DenseTensor kernel,
+                       const std::vector<int>& paddings,
+                       const std::vector<int>& dilations,
+                       const std::vector<int>& strides,
+                       const int groups,
+                       DenseTensor* rulebook) {
+  DenseTensor indices = phi::Empty<T, Context>(dev_ctx);
+  DenseTensor values = phi::Empty<T, Context>(dev_ctx);
+  SparseCooTensor coo(indices, values, x.dims());
+  Conv3dKernel<T, Context>(
+      dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook);
+  return coo;
+}
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h
new file mode 100644
index 0000000000000000000000000000000000000000..5803069d927d70947d8bc7c3d6af051d7ea1b81c
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/convolution.h
@@ -0,0 +1,181 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <set>
+
+#include "paddle/phi/api/lib/utils/allocator.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/core/tensor_meta.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+namespace sparse {
+
+// such as: kernel(3, 3, 3), kernel_size = 27
+// counter_per_kernel: (kernel_size)
+// TODO(zhangkaihuo): optimize performance with multithreading
+template <typename T, typename Context>
+void ProductRuleBook(const Context& dev_ctx,
+                     const SparseCooTensor& x,
+                     const DenseTensor& kernel,
+                     const std::vector<int>& paddings,
+                     const std::vector<int>& dilations,
+                     const std::vector<int>& strides,
+                     const DDim& out_dims,
+                     DenseTensor* rulebook,
+                     DenseTensor* counter_per_kernel) {
+  const auto& kernel_dims = kernel.dims();
+  const int64_t non_zero_num = x.nnz();
+  const auto& non_zero_indices = x.non_zero_indices();
+  const int* indices_ptr = non_zero_indices.data<int>();
+  dev_ctx.Alloc(counter_per_kernel,
+                counter_per_kernel->dtype(),
+                sizeof(int) * counter_per_kernel->numel());
+  int* counter_ptr = counter_per_kernel->data<int>();
+  int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
+  memset(counter_ptr, 0, kernel_size * sizeof(int));
+
+  int rulebook_len = 0;
+  // calc the rulebook_len
+  const auto& x_dims = x.dims();
+  const Dims4D c_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
+  const Dims4D c_kernel_dims(1, kernel_dims[2], kernel_dims[1], kernel_dims[0]);
+  const Dims4D c_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
+  const Dims4D c_paddings(1, paddings[2], paddings[1], paddings[0]);
+  const Dims4D c_strides(1, strides[2], strides[1], strides[0]);
+  const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]);
+
+  auto f_calc_rulebook = [&](int* rulebook_ptr) {
+    int kernel_index = 0, rulebook_index = 0;
+    for (int kz = 0; kz < kernel_dims[0]; kz++) {
+      for (int ky = 0; ky < kernel_dims[1]; ky++) {
+        for (int kx = 0; kx < kernel_dims[2]; kx++) {
+          for (int64_t i = 0; i < non_zero_num; i++) {
+            int batch = indices_ptr[i];
+            int in_z = indices_ptr[i + non_zero_num];
+            int in_y = indices_ptr[i + 2 * non_zero_num];
+            int in_x = indices_ptr[i + 3 * non_zero_num];
+            int out_z = (in_z + paddings[0] - kz * dilations[0]) / strides[0];
+            int out_y = (in_y + paddings[1] - ky * dilations[1]) / strides[1];
+            int out_x = (in_x + paddings[2] - kx * dilations[2]) / strides[2];
+            if (Check(c_x_dims,
+                      c_kernel_dims,
+                      c_paddings,
+                      c_dilations,
+                      c_strides,
+                      in_x,
+                      in_y,
+                      in_z,
+                      kx,
+                      ky,
+                      kz)) {
+              if (rulebook_ptr == nullptr) {
+                counter_ptr[kernel_index] += 1;
+                ++rulebook_len;
+              } else {
+                rulebook_ptr[rulebook_index] = kernel_index;
+                rulebook_ptr[rulebook_index + rulebook_len] = i;  // in_i
+                rulebook_ptr[rulebook_index + rulebook_len * 2] =
+                    PointToIndex<DDim>(
+                        batch, out_x, out_y, out_z, out_dims);  // out_index
+                ++rulebook_index;
+              }
+            }
+          }
+          ++kernel_index;
+        }
+      }
+    }
+  };
+
+  f_calc_rulebook(nullptr);
+  // alloc the rulebook
+  rulebook->ResizeAndAllocate({3, rulebook_len});
+  dev_ctx.Alloc(rulebook, rulebook->dtype(), rulebook->numel() * sizeof(int));
+  int* rulebook_ptr = rulebook->data<int>();
+  f_calc_rulebook(rulebook_ptr);
+}
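+
+// The rulebook built above is a (3 x rulebook_len) tensor; column j records
+// one (input non-zero, weight) contribution:
+//   row 0: kernel_index, which of the kernel_size weight offsets fired
+//   row 1: in_i, the position of the input non-zero element
+//   row 2: out_index, the linearized output position (remapped below to an
+//          index into the set of unique output positions)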
+
+template <typename T, typename Context>
+void UpdateRulebookAndOutIndex(const Context& dev_ctx,
+                               const SparseCooTensor& x,
+                               const int kernel_size,
+                               const int out_channels,
+                               const DDim& out_dims,
+                               DenseTensor* rulebook,
+                               SparseCooTensor* out) {
+  std::set<int> out_indexs;
+  int n = rulebook->dims()[1];
+  int* rulebook_ptr = rulebook->data<int>();
+  for (int i = 0; i < n; i++) {
+    out_indexs.insert(rulebook_ptr[i + n * 2]);
+  }
+
+  int out_non_zero_num = out_indexs.size();
+  const int64_t sparse_dim = 4;
+  DenseTensorMeta indices_meta(
+      DataType::INT32, {sparse_dim, out_non_zero_num}, DataLayout::NCHW);
+  DenseTensorMeta values_meta(
+      x.dtype(), {out_non_zero_num, out_channels}, x.layout());
+  phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta));
+  phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta));
+  dev_ctx.Alloc(
+      &out_indices, out_indices.dtype(), out_indices.numel() * sizeof(int));
+  int* out_indices_ptr = out_indices.data<int>();
+  int i = 0;
+  for (auto it = out_indexs.begin(); it != out_indexs.end(); it++, i++) {
+    const int index = *it;
+    int batch, x, y, z;
+    IndexToPoint<DDim>(index, out_dims, &batch, &x, &y, &z);
+    out_indices_ptr[i] = batch;
+    out_indices_ptr[i + out_non_zero_num] = z;
+    out_indices_ptr[i + out_non_zero_num * 2] = y;
+    out_indices_ptr[i + out_non_zero_num * 3] = x;
+  }
+  for (i = 0; i < n; i++) {
+    int out_index = rulebook_ptr[i + n * 2];
+    rulebook_ptr[i + n * 2] =
+        std::distance(out_indexs.begin(), out_indexs.find(out_index));
+  }
+
+  out->SetMember(out_indices, out_values, out_dims, true);
+}
+
+template <typename T>
+void Gather(
+    const T* x, const int* indexs, const int n, const int channels, T* out) {
+  for (int i = 0; i < n; i++) {
+    int real_i = indexs[i];
+    memcpy(out + i * channels, x + real_i * channels, channels * sizeof(T));
+  }
+}
+
+template <typename T>
+void Scatter(
+    const T* x, const int* indexs, const int n, const int channels, T* out) {
+  for (int i = 0; i < n; i++) {
+    int real_i = indexs[i];
+    for (int j = 0; j < channels; j++) {
+      out[real_i * channels + j] += x[i * channels + j];
+    }
+  }
+}
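+
+// A minimal sketch of how Gather and Scatter pair up (illustrative values,
+// not part of the kernel itself):
+//
+//   float x[4] = {1.f, 2.f, 3.f, 4.f};     // 4 rows, channels = 1
+//   int idx[2] = {3, 1};
+//   float buf[2];
+//   Gather<float>(x, idx, 2, 1, buf);      // buf = {4.f, 2.f}
+//   float out[4] = {0.f, 0.f, 0.f, 0.f};
+//   Scatter<float>(buf, idx, 2, 1, out);   // out = {0.f, 2.f, 0.f, 4.f}
+//
+// Gather copies the rows selected by idx into contiguous storage for the
+// GEMMs; Scatter accumulates GEMM results back into the selected rows.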
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fdf255bd542e66245b44b2ec906dc207ee51a422
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc
@@ -0,0 +1,151 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/convolution_kernel.h"
+#include "paddle/phi/api/lib/utils/allocator.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_meta.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/sparse/cpu/convolution.h"
+
+namespace phi {
+namespace sparse {
+
+/**
+ * x: (N, D, H, W, C)
+ * kernel: (D, H, W, C, OC)
+ * out: (N, D, H, W, OC)
+ */
+template <typename T, typename Context>
+void Conv3dKernel(const Context& dev_ctx,
+                  const SparseCooTensor& x,
+                  const DenseTensor& kernel,
+                  const std::vector<int>& paddings,
+                  const std::vector<int>& dilations,
+                  const std::vector<int>& strides,
+                  const int groups,
+                  SparseCooTensor* out,
+                  DenseTensor* rulebook) {
+  // update padding and dilation
+  // Currently, only x.layout == NDHWC and groups == 1 are supported;
+  // if x.layout != NDHWC, then transpose(x) and transpose(weight) first.
+
+  const auto& x_dims = x.dims();
+  const auto& kernel_dims = kernel.dims();
+  int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
+  DDim out_dims = {1, 1, 1, 1, 1};
+  GetOutShape(x_dims, kernel_dims, paddings, dilations, strides, &out_dims);
+  const int in_channels = kernel_dims[3];
+  const int out_channels = kernel_dims[4];
+
+  // Second algorithm:
+  // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
+  // 1. product rulebook
+  DenseTensorMeta counter_meta(
+      DataType::INT32, {kernel_size}, DataLayout::NCHW);
+  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
+
+  ProductRuleBook<T, Context>(dev_ctx,
+                              x,
+                              kernel,
+                              paddings,
+                              dilations,
+                              strides,
+                              out_dims,
+                              rulebook,
+                              &counter_per_kernel);
+
+  UpdateRulebookAndOutIndex<T>(
+      dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out);
+
+  int n = rulebook->dims()[1];
+  const int* counter_ptr = counter_per_kernel.data<int>();
+
+  // 2. gather
+  DenseTensorMeta in_features_meta(
+      x.dtype(), {n, in_channels}, DataLayout::NHWC);
+  DenseTensorMeta out_features_meta(
+      x.dtype(), {n, out_channels}, DataLayout::NHWC);
+  phi::DenseTensor in_features =
+      phi::Empty(dev_ctx, std::move(in_features_meta));
+  phi::DenseTensor out_features =
+      phi::Empty(dev_ctx, std::move(out_features_meta));
+  dev_ctx.Alloc(&in_features, x.dtype(), sizeof(T) * in_features.numel());
+  dev_ctx.Alloc(&out_features, x.dtype(), sizeof(T) * out_features.numel());
+  T* in_features_ptr = in_features.data<T>();
+  T* out_features_ptr = out_features.data<T>();
+
+  Gather<T>(x.non_zero_elements().data<T>(),
+            rulebook->data<int>() + n,
+            n,
+            in_channels,
+            in_features_ptr);
+
+  // 3. call gemm for every weight
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  std::vector<int> offsets(kernel_size + 1);
+  int offset = 0;
+  for (int i = 0; i < kernel_size; i++) {
+    offsets[i] = offset;
+    offset += counter_ptr[i];
+  }
+  offsets[kernel_size] = offset;
+
+  const T* kernel_ptr = kernel.data<T>();
+  for (int i = 0; i < kernel_size; i++) {
+    if (counter_ptr[i] <= 0) {
+      continue;
+    }
+
+    // call gemm: (M, in_channels) * (in_channels, out_channels)
+    const int M = counter_ptr[i];
+    const int K = in_channels;
+    const int N = out_channels;
+    T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
+    const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
+    T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;
+    blas.GEMM(CblasNoTrans,
+              CblasNoTrans,
+              M,
+              N,
+              K,
+              static_cast<T>(1),
+              tmp_in_ptr,
+              tmp_kernel_ptr,
+              static_cast<T>(0),
+              tmp_out_ptr);
+  }
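+
+  // Example of the bucketing above (illustrative numbers): with
+  // kernel_size = 3 and counter = {2, 0, 5}, offsets becomes {0, 2, 2, 7}.
+  // Weight 0 multiplies rows [0, 2) of in_features, weight 1 is skipped,
+  // and weight 2 multiplies rows [2, 7), each as a single
+  // (counter[i] x in_channels) * (in_channels x out_channels) GEMM.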
+
+  // 4. scatter
+  dev_ctx.Alloc(out->mutable_non_zero_elements(),
+                out->mutable_non_zero_elements()->dtype(),
+                sizeof(T) * out->mutable_non_zero_elements()->numel());
+  T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
+  memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
+  Scatter<T>(out_features_ptr,
+             rulebook->data<int>() + n * 2,
+             n,
+             out_channels,
+             out_values_ptr);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    sparse_conv3d, CPU, ALL_LAYOUT, phi::sparse::Conv3dKernel, float, double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
index 1e2c70a9cf39bf0df738a74b301afcc0fcbd8699..2e741111fb1489aef5bdc51de637b77eec9d28a7 100644
--- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
@@ -86,19 +86,6 @@ __global__ void GetNonZeroElementsAndIndices(const T* dense_data,
   }
 }
 
-template <typename Context>
-void GetGpuLaunchConfig1D(const Context& dev_ctx,
-                          const int64_t n,
-                          int* grid_size,
-                          int* block_size) {
-  const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
-  const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
-  *block_size = (n >= MAX_BLOCK_DIM) ? MAX_BLOCK_DIM
-                                     : (1 << static_cast<int>(std::log2(n)));
-  *grid_size = n / *block_size;
-  *grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
-}
-
 template <typename T, typename Context>
 void DenseToSparseCooKernel(const Context& dev_ctx,
                             const DenseTensor& x,
diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h
index b5201e16f548d594af47aa9a4611d35f9cf2ad4f..d96d134a26b08a0208122a7ea9a62ce07c033d51 100644
--- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h
+++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h
@@ -40,6 +40,19 @@ inline const DDim InferDenseDims(const DDim& x_dims,
   return values_dims;
 }
 
+template <typename Context>
+inline void GetGpuLaunchConfig1D(const Context& dev_ctx,
+                                 const int64_t n,
+                                 int* grid_size,
+                                 int* block_size) {
+  const int MAX_BLOCK_DIM = dev_ctx.GetMaxThreadsPerBlock();
+  const int MAX_GRID_DIM = dev_ctx.GetMaxPhysicalThreadCount() / MAX_BLOCK_DIM;
+  *block_size = (n >= MAX_BLOCK_DIM) ? MAX_BLOCK_DIM
+                                     : (1 << static_cast<int>(std::log2(n)));
+  *grid_size = n / *block_size;
+  *grid_size = (*grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : *grid_size;
+}
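+
+// Worked example (assuming a device reporting MAX_BLOCK_DIM = 1024 and
+// MAX_GRID_DIM = 64): n = 5000 gives block_size = 1024 and
+// grid_size = 5000 / 1024 = 4; n = 100 rounds the block size down to the
+// previous power of two, 1 << (int)std::log2(100) = 64, with grid_size = 1.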
+
 template <typename T, typename Context>
 void DenseToSparseCooKernel(const Context& dev_ctx,
                             const DenseTensor& x,
diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt
index 35137aa474e9380fcfe9a98e95d8261f40a2eae2..c92e10f8dd74af072bb8836d65898e2fc9a79bcc 100644
--- a/paddle/phi/tests/kernels/CMakeLists.txt
+++ b/paddle/phi/tests/kernels/CMakeLists.txt
@@ -13,6 +13,7 @@ cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS phi phi_api_utils)
+cc_test(test_sparse_conv3d_dev_api SRCS test_sparse_conv3d_dev_api.cc DEPS phi phi_api_utils)
 cc_test(test_math_function SRCS test_math_function.cc DEPS math_function)
 
 if(WITH_GPU)
diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
new file mode 100644
index 0000000000000000000000000000000000000000..576015143704b86957073bcf3f06b381e4b61592
--- /dev/null
+++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
@@ -0,0 +1,471 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+#include "paddle/phi/kernels/sparse/convolution_kernel.h"
+
+#include "paddle/fluid/memory/allocation/allocator_facade.h"
+#include "paddle/phi/api/lib/utils/allocator.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace tests {
+
+std::vector<int> flatten(const std::vector<std::vector<int>>& in) {
+  std::vector<int> out;
+  if (in.size() == 0) return out;
+  const int cols = in[0].size();
+  out.resize(in.size() * cols);
+  for (uint64_t i = 0; i < in.size(); i++) {
+    memcpy(&out[i * cols], in[i].data(), cols * sizeof(int));
+  }
+  return out;
+}
+
+template <typename T1, typename T2>
+std::vector<T2> cast(const std::vector<T1>& in) {
+  std::vector<T2> out(in.size());
+  for (uint64_t i = 0; i < in.size(); i++) {
+    out[i] = static_cast<T2>(in[i]);
+  }
+  return out;
+}
+
+template <typename T>
+void TestConv3dBase(const std::vector<int>& indices,
+                    const std::vector<T>& features,
+                    const DDim& x_dims,
+                    const std::vector<T>& kernel,
+                    const DDim& kernel_dims,
+                    const std::vector<int>& correct_out_indices,
+                    const std::vector<T>& correct_out_features,
+                    const DDim& correct_out_dims,
+                    const int non_zero_num,
+                    const std::vector<int>& paddings,
+                    const std::vector<int>& strides,
+                    const std::vector<int>& dilations,
+                    const float diff = 1e-3) {
+  phi::CPUContext dev_ctx_cpu;
+  dev_ctx_cpu.SetAllocator(
+      paddle::memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(paddle::platform::CPUPlace())
+          .get());
+  dev_ctx_cpu.Init();
+
+  const int in_channels = kernel_dims[3];
+  const int out_channels = kernel_dims[4];
+
+  DenseTensor indices_tensor = phi::Empty(
+      dev_ctx_cpu,
+      DenseTensorMeta(DataType::INT32, {4, non_zero_num}, DataLayout::NCHW));
+  dev_ctx_cpu.Alloc(&indices_tensor,
+                    indices_tensor.dtype(),
+                    sizeof(int) * indices_tensor.numel());
+  memcpy(indices_tensor.data<int>(),
+         indices.data(),
+         indices.size() * sizeof(int));
+  DenseTensor features_tensor = phi::Empty(
+      dev_ctx_cpu,
+      DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
+                      {non_zero_num, in_channels},
+                      DataLayout::NHWC));
+  dev_ctx_cpu.Alloc(&features_tensor,
+                    features_tensor.dtype(),
+                    features_tensor.numel() * sizeof(T));
+  memcpy(features_tensor.data<T>(),
+         features.data(),
+         features.size() * sizeof(T));
+
+  SparseCooTensor x_tensor(indices_tensor, features_tensor, x_dims);
+
+  DenseTensor kernel_tensor = phi::Empty(
+      dev_ctx_cpu,
+      DenseTensorMeta(paddle::experimental::CppTypeToDataType<T>::Type(),
+                      kernel_dims,
+                      DataLayout::NHWC));
+  dev_ctx_cpu.Alloc(
+      &kernel_tensor, kernel_tensor.dtype(), kernel_tensor.numel() * sizeof(T));
+  memcpy(kernel_tensor.data<T>(), kernel.data(), kernel.size() * sizeof(T));
+
+  if (!std::is_same<T, phi::dtype::float16>::value) {
+    DenseTensor rulebook = phi::Empty<int, phi::CPUContext>(dev_ctx_cpu);
+    SparseCooTensor out = sparse::Conv3d<T>(dev_ctx_cpu,
+                                            x_tensor,
+                                            kernel_tensor,
+                                            paddings,
+                                            dilations,
+                                            strides,
+                                            1,
+                                            &rulebook);
+
+    ASSERT_EQ(correct_out_dims.size(), out.dims().size());
+    for (int i = 0; i < correct_out_dims.size(); i++) {
+      ASSERT_EQ(correct_out_dims[i], out.dims()[i]);
+    }
+    ASSERT_EQ((int64_t)correct_out_features.size() / out_channels, out.nnz());
+
+    int cmp_indices = memcmp(correct_out_indices.data(),
+                             out.non_zero_indices().data<int>(),
+                             correct_out_indices.size() * sizeof(int));
+    ASSERT_EQ(cmp_indices, 0);
+
+    for (uint64_t i = 0; i < correct_out_features.size(); i++) {
+      float tmp = std::fabs(static_cast<float>(
+          correct_out_features[i] - out.non_zero_elements().data<T>()[i]));
+      ASSERT_LT(tmp, diff);
+    }
+  }
+}
+
+void TestConv3d(const std::vector<int>& indices,
+                const std::vector<float>&
features,
+                const DDim& x_dims,
+                const std::vector<float>& kernel,
+                const DDim& kernel_dims,
+                const std::vector<int>& correct_out_indices,
+                const std::vector<float>& correct_out_features,
+                const DDim& correct_out_dims,
+                const int non_zero_num,
+                const std::vector<int>& paddings,
+                const std::vector<int>& strides,
+                const std::vector<int>& dilations) {
+  // test float
+  TestConv3dBase<float>(indices,
+                        features,
+                        x_dims,
+                        kernel,
+                        kernel_dims,
+                        correct_out_indices,
+                        correct_out_features,
+                        correct_out_dims,
+                        non_zero_num,
+                        paddings,
+                        strides,
+                        dilations);
+  // test double
+  TestConv3dBase<double>(indices,
+                         cast<float, double>(features),
+                         x_dims,
+                         cast<float, double>(kernel),
+                         kernel_dims,
+                         correct_out_indices,
+                         cast<float, double>(correct_out_features),
+                         correct_out_dims,
+                         non_zero_num,
+                         paddings,
+                         strides,
+                         dilations);
+}
+
+TEST(DEV_API, sparse_conv3d) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {1, 4, 4, 4, in_channels};
+  DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
+  DDim out_dims = {1, 2, 2, 2, out_channels};
+  std::vector<int> paddings = {0, 0, 0};
+  std::vector<int> strides = {1, 1, 1};
+  std::vector<int> dilations = {1, 1, 1};
+
+  const int non_zero_num = 4;
+  std::vector<std::vector<int>> indices = {
+      {0, 0, 0, 0}, {0, 2, 0, 2}, {3, 2, 2, 3}, {3, 2, 3, 2}};
+  std::vector<int> indices_flatten = flatten(indices);
+
+  std::vector<float> features = {-0.2883, 0.0287, 0.2864, -0.0992};
+  // 3*3*3=27
+  std::vector<float> kernel = {
+      0.4721, 0.2292, 0.9751, 0.8616, 0.5784, 0.9178, 0.8727, 0.1659, 0.4455,
+
+      0.0189, 0.4646, 0.4472, 0.1991, 0.8968, 0.3717, 0.0051, 0.6963, 0.2690,
+
+      0.7473, 0.5403, 0.5391, 0.0796, 0.4734, 0.9097, 0.1712, 0.6237, 0.8837};
+
+  std::vector<std::vector<int>> out_indices = {{0, 0, 0, 0, 0, 0, 0, 0},
+                                               {0, 0, 0, 0, 1, 1, 1, 1},
+                                               {0, 0, 1, 1, 0, 0, 1, 1},
+                                               {0, 1, 0, 1, 0, 1, 0, 1}};
+  std::vector<int> out_indices_flatten = flatten(out_indices);
+
+  std::vector<float> out_features = {
+      0.0254, 0.1455, -0.0615, 0.0862, 0.0077, 0.0200, -0.0160, -0.0433};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations);
+}
+
+TEST(DEV_API, sparse_conv3d_batch) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {2, 4, 4, 4, in_channels};
+  DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
+  DDim out_dims = {2, 2, 2, 2, out_channels};
+  std::vector<int> paddings = {0, 0, 0};
+  std::vector<int> strides = {1, 1, 1};
+  std::vector<int> dilations = {1, 1, 1};
+
+  const int non_zero_num = 8;
+  std::vector<std::vector<int>> indices = {{0, 0, 0, 0, 1, 1, 1, 1},
+                                           {0, 2, 0, 2, 0, 2, 0, 2},
+                                           {3, 2, 2, 3, 3, 2, 2, 3},
+                                           {3, 2, 3, 2, 3, 2, 3, 2}};
+  std::vector<int> indices_flatten = flatten(indices);
+
+  std::vector<float> features = {
+      -0.2883, 0.0287, 0.2864, -0.0992, -0.2883, 0.0287, 0.2864, -0.0992};
+  // 3*3*3=27
+  std::vector<float> kernel = {
+      0.4721, 0.2292, 0.9751, 0.8616, 0.5784, 0.9178, 0.8727, 0.1659, 0.4455,
+
+      0.0189, 0.4646, 0.4472, 0.1991, 0.8968, 0.3717, 0.0051, 0.6963, 0.2690,
+
+      0.7473, 0.5403, 0.5391, 0.0796, 0.4734, 0.9097, 0.1712, 0.6237, 0.8837};
+
+  std::vector<std::vector<int>> out_indices = {
+      {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1},
+      {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1},
+      {0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1},
+      {0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}};
+  std::vector<int> out_indices_flatten = flatten(out_indices);
+
+  std::vector<float> out_features = {
+      0.0254, 0.1455, -0.0615, 0.0862, 0.0077, 0.0200, -0.0160, -0.0433,
+      0.0254, 0.1455, -0.0615, 0.0862, 0.0077, 0.0200, -0.0160, -0.0433};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations);
+}
+
+TEST(DEV_API, sparse_conv3d_stride) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {1, 4, 4, 4, in_channels};
+  DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
+  DDim out_dims = {1, 1, 1, 1, out_channels};
+  std::vector<int> paddings = {0, 0, 0};
+  std::vector<int> strides = {2, 2, 2};
+  std::vector<int> dilations = {1, 1, 1};
+
+  const int non_zero_num = 3;
+  std::vector<std::vector<int>> indices = {
+      {0, 0, 0}, {0, 2, 0}, {3, 2, 2}, {3, 2, 3}};
+  std::vector<int> indices_flatten = flatten(indices);
+
+  std::vector<float> features = {-0.28833008, 0.02873230, 0.28637695};
+  // 3*3*3=27
+  std::vector<float> kernel = {
+      0.45043945, 0.47216797, 0.22924805, 0.97509766, 0.86181641, 0.57861328,
+      0.91796875, 0.87255859, 0.16589355, 0.44555664, 0.01889038, 0.46459961,
+      0.44726562, 0.19909668, 0.89697266, 0.37158203, 0.00513077, 0.69628906,
+      0.26904297, 0.74707031, 0.54003906, 0.5390625,  0.07958984, 0.47338867,
+      0.90966797, 0.17126465, 0.62353516};
+
+  std::vector<std::vector<int>> out_indices = {{0, 0, 0, 0}};
+  std::vector<int> out_indices_flatten = flatten(out_indices);
+
+  std::vector<float> out_features = {0.01791};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations);
+}
+
+TEST(DEV_API, sparse_conv3d_dilation) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {1, 6, 6, 6, in_channels};
+  DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
+  DDim out_dims = {1, 2, 2, 2, out_channels};
+  std::vector<int> paddings = {0, 0, 0};
+  std::vector<int> strides = {1, 1, 1};
+  std::vector<int> dilations = {2, 2, 2};
+
+  const int non_zero_num = 3;
+  std::vector<std::vector<int>> indices = {
+      {0, 0, 0}, {2, 3, 3}, {2, 3, 3}, {5, 2, 0}};
+  std::vector<int> indices_flatten = flatten(indices);
+
+  std::vector<float> features = {-0.78710938, -0.64746094, 0.98828125};
+  // 3*3*3=27
+  std::vector<float> kernel = {
+      0.20617676, 0.99365234, 0.16760254, 0.30639648, 0.41479492, 0.75732422,
+      0.65625,    0.48535156, 0.72167969, 0.56005859, 0.5,        0.3581543,
+      0.20324707, 0.88769531, 0.81298828, 0.58398438, 0.30810547, 0.12634277,
+      0.70507812, 0.38720703, 0.34814453, 0.02690125, 0.80273438, 0.90625,
+      0.2277832,  0.4362793,  0.44482422};
+
+  std::vector<std::vector<int>> out_indices = {{0, 0, 0, 1, 0, 1, 1, 0}};
+  std::vector<int> out_indices_flatten = flatten(out_indices);
+
+  std::vector<float> out_features = {-0.64014, -0.37402};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations);
+}
+
+TEST(DEV_API, sparse_conv3d_padding) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {1, 3, 3, 3, in_channels};
+  DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
+  DDim out_dims = {1, 3, 3, 3, out_channels};
+  std::vector<int> paddings = {1, 1, 1};
+  std::vector<int> strides = {1, 1, 1};
+  std::vector<int> dilations = {1, 1, 1};
+
+  const int non_zero_num = 1;
+  std::vector<std::vector<int>> indices = {{0, 1, 0, 0}};
+  std::vector<int> indices_flatten = flatten(indices);
+
+  std::vector<float> features = {-0.79394531};
+  // 3*3*3=27
+  std::vector<float> kernel = {
+      0.34375,    0.22485352, 0.65820312, 0.75048828, 0.21411133, 0.17370605,
+      0.85546875, 0.53076172, 0.28833008, 0.71044922, 0.00659943, 0.45922852,
+      0.19372559, 0.64599609, 0.78808594, 0.49316406, 0.62646484, 0.40649414,
+      0.62744141, 0.5703125,  0.23144531, 0.50048828, 0.31835938, 0.90869141,
+      0.38208008, 0.60449219, 0.09075928};
+
+  std::vector<int> out_indices_flatten = {
+      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
+      0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
+
+  std::vector<float> out_features = {-0.25269,
+                                     -0.39746,
+                                     -0.45288,
+                                     -0.49805,
+                                     -0.5127,
+                                     -0.15381,
+                                     -0.00524,
+                                     -0.56396,
+                                     -0.17004,
+                                     -0.5957,
+                                     -0.17847,
+                                     -0.27295};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations);
+}
+
+TEST(DEV_API, sparse_conv2d) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {1, 1, 5, 5, in_channels};
+  DDim kernel_dims = {1, 3, 3, in_channels, out_channels};
+  DDim out_dims = {1, 1, 3, 3, out_channels};
+  std::vector<int> paddings = {0, 0, 0};
+  std::vector<int> strides = {1, 1, 1};
+  std::vector<int> dilations = {1, 1, 1};
+
+  const int non_zero_num = 3;
+  std::vector<int> indices_flatten = {0, 0, 0, 0, 0, 0, 0, 4, 0, 3, 2, 4};
+
+  std::vector<float> features = {-0.79394531, -0.3125, -0.55029297};
+  // 1*3*3=9
+  std::vector<float> kernel = {0.65820312,
+                               0.75048828,
+                               0.21411133,
+                               0.17370605,
+                               0.85546875,
+                               0.53076172,
+                               0.28833008,
+                               0.71044922,
+                               0.00659943};
+
+  std::vector<int> out_indices_flatten = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                          0, 0, 2, 2, 2, 1, 2, 0, 1, 2};
+
+  std::vector<float> out_features = {
+      -0.17004, -0.71338, -0.00206, -0.22205, -0.09009};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations);
+}
+
+}  // namespace tests
+}  // namespace phi
diff --git a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc
index 15c00d385eda9347da8862a8eb88c8f6930000a1..3e2ad0495f3ba85836dc08afa3f4fa4ed0b10afd 100644
--- a/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc
+++ b/paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc
@@ -8,7 +8,7 @@ You may obtain a copy of the License at
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF NCHW KIND, either express or implied.
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */