diff --git a/paddle/phi/kernels/sparse/convolution_grad_kernel.h b/paddle/phi/kernels/sparse/convolution_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1a6ac852448a5f4a25248d2a2b6919a301a04874
--- /dev/null
+++ b/paddle/phi/kernels/sparse/convolution_grad_kernel.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void Conv3dGradKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const DenseTensor& rulebook,
+                      const DenseTensor& kernel,
+                      const SparseCooTensor& out_grad,
+                      const std::vector<int>& paddings,
+                      const std::vector<int>& dilations,
+                      const std::vector<int>& strides,
+                      const int groups,
+                      DenseTensor* x_grad,
+                      DenseTensor* kernel_grad);
+
+template <typename T, typename Context>
+std::vector<DenseTensor> Conv3dGrad(const Context& dev_ctx,
+                                    const SparseCooTensor& x,
+                                    const DenseTensor& rulebook,
+                                    const DenseTensor& kernel,
+                                    const SparseCooTensor& out_grad,
+                                    const std::vector<int>& paddings,
+                                    const std::vector<int>& dilations,
+                                    const std::vector<int>& strides,
+                                    const int groups) {
+  DenseTensor x_grad = phi::Empty<T, Context>(dev_ctx);
+  DenseTensor kernel_grad = phi::Empty<T, Context>(dev_ctx);
+  Conv3dGradKernel<T, Context>(dev_ctx,
+                               x,
+                               rulebook,
+                               kernel,
+                               out_grad,
+                               paddings,
+                               dilations,
+                               strides,
+                               groups,
+                               &x_grad,
+                               &kernel_grad);
+  std::vector<DenseTensor> out(2);
+  out[0] = x_grad;
+  out[1] = kernel_grad;
+  return out;
+}
+
+}  // namespace sparse
+}  // namespace phi
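For orientation before the CPU implementation below, a minimal sketch of how the Conv3dGrad helper is meant to be driven. It mirrors the call added to the unit test at the end of this patch and is not standalone-compilable: x, kernel, paddings, dilations, and strides are set up as in that test, and out plus rulebook come from a prior sparse::Conv3d forward call.

```cpp
// Sketch, assuming a prior forward pass (as in the test below):
//   DenseTensor rulebook = phi::Empty<int, phi::CPUContext>(dev_ctx);
//   SparseCooTensor out = sparse::Conv3d<float>(
//       dev_ctx, x, kernel, paddings, dilations, strides, /*groups=*/1,
//       &rulebook);
std::vector<DenseTensor> grads = sparse::Conv3dGrad<float>(
    dev_ctx, x, rulebook, kernel, out, paddings, dilations, strides,
    /*groups=*/1);
DenseTensor& x_grad = grads[0];       // gradient w.r.t. x's non-zero features
DenseTensor& kernel_grad = grads[1];  // gradient w.r.t. the dense kernel
```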
diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h
index 5803069d927d70947d8bc7c3d6af051d7ea1b81c..ab2fef5320f716b6bc780ad14b8e2adef44427dd 100644
--- a/paddle/phi/kernels/sparse/cpu/convolution.h
+++ b/paddle/phi/kernels/sparse/cpu/convolution.h
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/sparse/convolution_kernel.h"
 
 namespace phi {
 namespace sparse {
diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d4f770ce8713aa84c7f87f0e49bf8468467ffdbf
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc
@@ -0,0 +1,166 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/sparse/cpu/convolution.h"
+
+namespace phi {
+namespace sparse {
+
+// rulebook:
+//[
+//  [kernel_index],
+//  [in_i],
+//  [out_i],
+//]
+// x_grad = out_grad * transpose(kernel)
+// kernel_grad = transpose(x) * out_grad
+template <typename T, typename Context>
+void Conv3dGradKernel(const Context& dev_ctx,
+                      const SparseCooTensor& x,
+                      const DenseTensor& rulebook,
+                      const DenseTensor& kernel,
+                      const SparseCooTensor& out_grad,
+                      const std::vector<int>& paddings,
+                      const std::vector<int>& dilations,
+                      const std::vector<int>& strides,
+                      const int groups,
+                      DenseTensor* x_grad,
+                      DenseTensor* kernel_grad) {
+  const auto& kernel_dims = kernel.dims();
+  const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
+  const int in_channels = kernel_dims[3];
+  const int out_channels = kernel_dims[4];
+  const int* rulebook_ptr = rulebook.data<int>();
+
+  const int rulebook_len = rulebook.dims()[1];
+
+  DenseTensorMeta in_features_meta(
+      x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
+  DenseTensorMeta d_x_features_meta(
+      x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
+  DenseTensorMeta out_grad_features_meta(
+      x.dtype(), {rulebook_len, out_channels}, DataLayout::NCHW);
+  phi::DenseTensor in_features =
+      phi::Empty(dev_ctx, std::move(in_features_meta));
+  phi::DenseTensor d_x_features =
+      phi::Empty(dev_ctx, std::move(d_x_features_meta));
+  phi::DenseTensor out_grad_features =
+      phi::Empty(dev_ctx, std::move(out_grad_features_meta));
+
+  dev_ctx.Alloc(
+      &in_features, in_features.dtype(), sizeof(T) * in_features.numel());
+  T* in_features_ptr = in_features.data<T>();
+  dev_ctx.Alloc(
+      &d_x_features, d_x_features.dtype(), sizeof(T) * d_x_features.numel());
+  T* d_x_features_ptr = d_x_features.data<T>();
+  dev_ctx.Alloc(&out_grad_features,
+                out_grad_features.dtype(),
+                sizeof(T) * out_grad_features.numel());
+  T* out_grad_features_ptr = out_grad_features.data<T>();
+  kernel_grad->Resize(kernel_dims);
+  dev_ctx.Alloc(
+      kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T));
+  T* d_kernel_ptr = kernel_grad->data<T>();
+
+  Gather<T>(x.non_zero_elements().data<T>(),
+            rulebook_ptr + rulebook_len,
+            rulebook_len,
+            in_channels,
+            in_features_ptr);
+  Gather<T>(out_grad.non_zero_elements().data<T>(),
+            rulebook_ptr + rulebook_len * 2,
+            rulebook_len,
+            out_channels,
+            out_grad_features_ptr);
+
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0);
+  for (int i = 0; i < rulebook_len; i++) {
+    counter[rulebook_ptr[i]] += 1;
+  }
+  int offset = 0;
+  for (int i = 0; i < kernel_size; i++) {
+    offsets[i] = offset;
+    offset += counter[i];
+  }
+  offsets[kernel_size] = offset;
+
+  const T* kernel_ptr = kernel.data<T>();
+  for (int i = 0; i < kernel_size; i++) {
+    if (counter[i] <= 0) {
+      continue;
+    }
+
+    const int M = counter[i];
+    const int K = in_channels;
+    const int N = out_channels;
+    T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
+    T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
+    const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
+    T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
+    T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+
+    // call gemm: d_kernel = transpose(x) * out_grad
+    // (in_channels, n) * (n, out_channels)
+    blas.GEMM(CblasTrans,
+              CblasNoTrans,
+              K,
+              N,
+              M,
+              static_cast<T>(1),
+              tmp_in_ptr,
+              tmp_out_grad_ptr,
+              static_cast<T>(0),
+              tmp_d_kernel_ptr);
+
+    // call gemm: d_x = out_grad * transpose(kernel)
+    // (n, out_channels) * (out_channels, in_channels)
+    blas.GEMM(CblasNoTrans,
+              CblasTrans,
+              M,
+              K,
+              N,
+              static_cast<T>(1),
+              tmp_out_grad_ptr,
+              tmp_kernel_ptr,
+              static_cast<T>(0),
+              tmp_d_x_ptr);
+  }
+
+  // scatter the per-rulebook-entry gradients back to x's non-zero positions
+  x_grad->Resize(x.non_zero_elements().dims());
+  dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel());
+  T* x_grad_values_ptr = x_grad->data<T>();
+  memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel());
+  Scatter<T>(d_x_features_ptr,
+             rulebook.data<int>() + rulebook_len,
+             rulebook_len,
+             in_channels,
+             x_grad_values_ptr);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sparse_conv_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::Conv3dGradKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+  kernel->InputAt(3).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
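To make the two GEMM calls in Conv3dGradKernel concrete (and the CblasTrans arguments easier to audit), here is a naive loop rendering of the same per-offset math. NaiveConvGradBlock is a hypothetical illustration written for this note, not part of the patch; M = counter[i], K = in_channels, N = out_channels, and all blocks are row-major, matching the tmp_* pointers above.

```cpp
// Naive, single-threaded equivalent of the two BLAS calls for one kernel
// offset. x_block is M x K, out_grad_block is M x N, kernel_block is K x N.
template <typename T>
void NaiveConvGradBlock(const T* x_block, const T* out_grad_block,
                        const T* kernel_block, int M, int K, int N,
                        T* d_kernel_block, T* d_x_block) {
  // d_kernel (K x N) = transpose(x_block) * out_grad_block
  for (int k = 0; k < K; k++) {
    for (int n = 0; n < N; n++) {
      T sum = static_cast<T>(0);
      for (int m = 0; m < M; m++) {
        sum += x_block[m * K + k] * out_grad_block[m * N + n];
      }
      d_kernel_block[k * N + n] = sum;
    }
  }
  // d_x (M x K) = out_grad_block * transpose(kernel_block)
  for (int m = 0; m < M; m++) {
    for (int k = 0; k < K; k++) {
      T sum = static_cast<T>(0);
      for (int n = 0; n < N; n++) {
        sum += out_grad_block[m * N + n] * kernel_block[k * N + n];
      }
      d_x_block[m * K + k] = sum;
    }
  }
}
```

Reading the loops this way also shows why the first GEMM takes (K, N, M) as its dimension triple: the output d_kernel block is K x N, with M as the contracted dimension.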
diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
index 576015143704b86957073bcf3f06b381e4b61592..00b2a256a9504595dd8ac4ffd492564557f2d783 100644
--- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
+++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/kernels/copy_kernel.h"
+#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
 #include "paddle/phi/kernels/sparse/convolution_kernel.h"
 
 #include "paddle/fluid/memory/allocation/allocator_facade.h"
@@ -59,7 +60,10 @@ void TestConv3dBase(const std::vector<int>& indices,
                     const std::vector<int>& paddings,
                     const std::vector<int>& strides,
                     const std::vector<int>& dilations,
-                    const float diff = 1e-3) {
+                    const float diff = 1e-3,
+                    const bool backward = false,
+                    const std::vector<T> features_grad = {},
+                    const std::vector<T> kernel_grad = {}) {
   phi::CPUContext dev_ctx_cpu;
   dev_ctx_cpu.SetAllocator(
       paddle::memory::allocation::AllocatorFacade::Instance()
@@ -122,10 +126,29 @@ void TestConv3dBase(const std::vector<int>& indices,
                      correct_out_indices.size() * sizeof(int));
     ASSERT_EQ(cmp_indices, 0);
 
-    for (uint64_t i = 0; i < correct_out_features.size(); i++) {
-      float tmp = std::fabs(static_cast<float>(
-          correct_out_features[i] - out.non_zero_elements().data<T>()[i]));
-      ASSERT_LT(tmp, diff);
+    auto f_verify = [&](const T* real_data,
+                        const std::vector<T>& correct_data) {
+      for (uint64_t i = 0; i < correct_data.size(); i++) {
+        float tmp =
+            std::fabs(static_cast<float>(correct_data[i] - real_data[i]));
+        ASSERT_LT(tmp, diff);
+      }
+    };
+
+    f_verify(out.non_zero_elements().data<T>(), correct_out_features);
+
+    if (backward) {
+      std::vector<DenseTensor> grads = sparse::Conv3dGrad<T>(dev_ctx_cpu,
+                                                             x_tensor,
+                                                             rulebook,
+                                                             kernel_tensor,
+                                                             out,
+                                                             paddings,
+                                                             dilations,
+                                                             strides,
+                                                             1);
+      f_verify(grads[0].data<T>(), features_grad);
+      f_verify(grads[1].data<T>(), kernel_grad);
     }
   }
 }
@@ -141,7 +164,11 @@ void TestConv3d(const std::vector<int>& indices,
                 const int non_zero_num,
                 const std::vector<int>& paddings,
                 const std::vector<int>& strides,
-                const std::vector<int>& dilations) {
+                const std::vector<int>& dilations,
+                const float diff = 1e-3,
+                const bool backward = false,
+                const std::vector<float> features_grad = {},
+                const std::vector<float> kernel_grad = {}) {
   // test float
   TestConv3dBase<float>(indices,
                         features,
@@ -154,7 +181,11 @@ void TestConv3d(const std::vector<int>& indices,
                         non_zero_num,
                         paddings,
                         strides,
-                        dilations);
+                        dilations,
+                        diff,
+                        backward,
+                        features_grad,
+                        kernel_grad);
   // test double
   TestConv3dBase<double>(indices,
                          cast<float, double>(features),
@@ -167,7 +198,11 @@ void TestConv3d(const std::vector<int>& indices,
                          non_zero_num,
                          paddings,
                          strides,
-                         dilations);
+                         dilations,
+                         diff,
+                         backward,
+                         cast<float, double>(features_grad),
+                         cast<float, double>(kernel_grad));
 }
 
 TEST(DEV_API, sparse_conv3d) {
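A note on the double-precision path above: it reuses the test file's pre-existing cast helper. For readers without the full file at hand, a sketch of the presumed element-wise conversion (hypothetical reconstruction; the actual helper is defined earlier in this test file and is not touched by this patch):

```cpp
#include <vector>

// Presumed shape of the cast<T1, T2> helper used above: element-wise
// static_cast from a std::vector<T1> to a std::vector<T2>.
template <typename T1, typename T2>
std::vector<T2> cast(const std::vector<T1>& in) {
  std::vector<T2> out(in.size());
  for (size_t i = 0; i < in.size(); i++) {
    out[i] = static_cast<T2>(in[i]);
  }
  return out;
}
```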
@@ -467,5 +502,66 @@ TEST(DEV_API, sparse_conv2d) {
              dilations);
 }
 
+TEST(DEV_API, sparse_conv3d_backward) {
+  const int in_channels = 1;
+  const int out_channels = 1;
+  DDim x_dims = {1, 4, 4, 4, in_channels};
+  DDim kernel_dims = {3, 3, 3, in_channels, out_channels};
+  DDim out_dims = {1, 2, 2, 2, out_channels};
+  std::vector<int> paddings = {0, 0, 0};
+  std::vector<int> strides = {1, 1, 1};
+  std::vector<int> dilations = {1, 1, 1};
+
+  const int non_zero_num = 2;
+  std::vector<int> indices_flatten = {0, 0, 0, 2, 3, 2, 3, 2};
+
+  std::vector<float> features = {-0.28833008, 0.0287323};
+  // 3*3*3=27
+  std::vector<float> kernel = {
+      0.64306641, 0.45043945, 0.47216797, 0.22924805, 0.97509766, 0.86181641,
+      0.57861328, 0.91796875, 0.87255859, 0.16589355, 0.44555664, 0.01889038,
+      0.46459961, 0.44726562, 0.19909668, 0.89697266, 0.37158203, 0.00513077,
+      0.69628906, 0.26904297, 0.74707031, 0.54003906, 0.5390625,  0.07958984,
+      0.47338867, 0.90966797, 0.17126465};
+
+  std::vector<int> out_indices_flatten = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                                          0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0,
+                                          1, 1, 0, 1, 0, 1, 0, 1, 0, 1};
+
+  std::vector<float> out_features = {4.9200e-03,
+                                     2.6140e-02,
+                                     2.2900e-03,
+                                     -2.3596e-01,
+                                     1.5000e-04,
+                                     1.0670e-02,
+                                     5.7200e-03,
+                                     1.2850e-02};
+
+  std::vector<float> features_grad = {-0.20593, -0.09149};
+  std::vector<float> kernel_grad = {
+      0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00,  0.000e+00,
+      0.000e+00, 0.000e+00, 6.805e-02, 0.000e+00, 0.000e+00,  0.000e+00,
+      0.000e+00, 3.700e-04, 1.600e-04, 0.000e+00, 3.100e-04,  0.000e+00,
+      0.000e+00, 0.000e+00, 0.000e+00, 0.000e+00, -6.780e-03, 7.000e-05,
+      0.000e+00, 7.500e-04, 1.400e-04};
+
+  TestConv3d(indices_flatten,
+             features,
+             x_dims,
+             kernel,
+             kernel_dims,
+             out_indices_flatten,
+             out_features,
+             out_dims,
+             non_zero_num,
+             paddings,
+             strides,
+             dilations,
+             1e-3,
+             true,
+             features_grad,
+             kernel_grad);
+}
+
 }  // namespace tests
 }  // namespace phi
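The expected features_grad and kernel_grad values in this test can be checked by hand because in_channels == out_channels == 1 collapses both GEMMs to scalar multiply-accumulates over rulebook columns. ReferenceGrad1Channel below is a hypothetical helper written for illustration, not part of the patch; it relies only on the rulebook layout documented in the kernel (rows kernel_index, in_i, out_i, each of length len).

```cpp
#include <cstring>

// Reference gradients for the single-channel case: each rulebook column j
// contributes one scalar product to each gradient.
void ReferenceGrad1Channel(const int* rulebook, int len,
                           const float* x, const float* out_grad,
                           const float* kernel, int kernel_size,
                           int non_zero_num,
                           float* x_grad, float* kernel_grad) {
  std::memset(x_grad, 0, sizeof(float) * non_zero_num);
  std::memset(kernel_grad, 0, sizeof(float) * kernel_size);
  for (int j = 0; j < len; j++) {
    const int k = rulebook[j];                // kernel offset
    const int in_i = rulebook[len + j];       // index into x's non-zeros
    const int out_i = rulebook[2 * len + j];  // index into out_grad non-zeros
    kernel_grad[k] += x[in_i] * out_grad[out_i];  // d_kernel = x^T * out_grad
    x_grad[in_i] += out_grad[out_i] * kernel[k];  // d_x = out_grad * kernel^T
  }
}
```

Fed with the rulebook produced by the forward pass for this test, this should reproduce the tabulated gradients within the same 1e-3 tolerance that f_verify uses.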