diff --git a/paddle/phi/kernels/sparse/convolution_grad_kernel.h b/paddle/phi/kernels/sparse/convolution_grad_kernel.h index f4265d303d730708f7e6db684accff538f604174..42bde442e1e063a355d2eabb2963865a2ff45bcb 100644 --- a/paddle/phi/kernels/sparse/convolution_grad_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_grad_kernel.h @@ -32,6 +32,7 @@ void Conv3dGradKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, DenseTensor* x_grad, DenseTensor* kernel_grad); @@ -44,7 +45,8 @@ std::vector Conv3dGrad(const Context& dev_ctx, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, - const int groups) { + const int groups, + const bool subm) { DenseTensor x_grad = phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); DenseTensor kernel_grad = phi::Empty( @@ -59,6 +61,7 @@ std::vector Conv3dGrad(const Context& dev_ctx, dilations, strides, groups, + subm, &x_grad, &kernel_grad); std::vector out(2); diff --git a/paddle/phi/kernels/sparse/convolution_kernel.h b/paddle/phi/kernels/sparse/convolution_kernel.h index cfb451afdcbcb007bbce468b3e582db5057796d5..778600a2285de63a481ccd0094cb07a3206b48d9 100644 --- a/paddle/phi/kernels/sparse/convolution_kernel.h +++ b/paddle/phi/kernels/sparse/convolution_kernel.h @@ -125,6 +125,7 @@ void Conv3dKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, SparseCooTensor* out, DenseTensor* rulebook); @@ -136,14 +137,23 @@ SparseCooTensor Conv3d(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, DenseTensor* rulebook) { DenseTensor indices = phi::Empty( dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); DenseTensor values = phi::Empty(dev_ctx, DenseTensorMeta(x.dtype(), {1}, x.layout())); SparseCooTensor coo(indices, values, x.dims()); - Conv3dKernel( - dev_ctx, x, kernel, paddings, dilations, strides, groups, &coo, rulebook); + Conv3dKernel(dev_ctx, + x, + kernel, + paddings, + dilations, + strides, + groups, + subm, + &coo, + rulebook); return coo; } diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/convolution.h index bcb6db407883ff5a0192699d72c360d1b41200ed..a5a946dce7912f706b1b4c149c89331ce9a3744f 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution.h +++ b/paddle/phi/kernels/sparse/cpu/convolution.h @@ -39,6 +39,7 @@ void ProductRuleBook(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const DDim& out_dims, + const bool subm, DenseTensor* rulebook, DenseTensor* counter_per_kernel) { const auto& kernel_dims = kernel.dims(); @@ -59,11 +60,24 @@ void ProductRuleBook(const Context& dev_ctx, const Dims4D c_strides(1, strides[2], strides[1], strides[0]); const Dims4D c_dilations(1, dilations[2], dilations[1], dilations[0]); + std::set hash_in; + if (subm) { + for (int i = 0; i < non_zero_num; i++) { + int batch = indices_ptr[i]; + int in_z = indices_ptr[i + non_zero_num]; + int in_y = indices_ptr[i + 2 * non_zero_num]; + int in_x = indices_ptr[i + 3 * non_zero_num]; + int index = PointToIndex(batch, in_x, in_y, in_z, x_dims); + hash_in.insert(index); + } + } + auto f_calc_rulebook = [&](int* rulebook_ptr) { int kernel_index = 0, rulebook_index = 0; for (int kz = 0; kz < kernel_dims[0]; kz++) { for (int ky = 0; ky < kernel_dims[1]; ky++) { for (int kx = 0; kx < kernel_dims[2]; kx++) { + ++kernel_index; for (int64_t i = 0; i < non_zero_num; i++) { int batch = indices_ptr[i]; int in_z = indices_ptr[i + non_zero_num]; @@ -83,11 +97,19 @@ void ProductRuleBook(const Context& dev_ctx, kx, ky, kz)) { + if (subm) { + int out_index = + PointToIndex(batch, out_x, out_y, out_z, out_dims); + if (hash_in.find(out_index) == hash_in.end()) { + continue; + } + } + if (rulebook_ptr == nullptr) { - counter_ptr[kernel_index] += 1; + counter_ptr[kernel_index - 1] += 1; ++rulebook_len; } else { - rulebook_ptr[rulebook_index] = kernel_index; + rulebook_ptr[rulebook_index] = kernel_index - 1; rulebook_ptr[rulebook_index + rulebook_len] = i; // in_i rulebook_ptr[rulebook_index + rulebook_len * 2] = PointToIndex( @@ -96,7 +118,6 @@ void ProductRuleBook(const Context& dev_ctx, } } } - ++kernel_index; } } } diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 6ee265a329673ad456e0dd491a9544143016aff5..bb414faef6743126cf2e25b49ae17689f0a6048f 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -38,6 +38,7 @@ void Conv3dGradKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, DenseTensor* x_grad, DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); @@ -70,32 +71,72 @@ void Conv3dGradKernel(const Context& dev_ctx, T* d_kernel_ptr = kernel_grad->data(); memset(d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel()); - Gather(x.non_zero_elements().data(), - rulebook_ptr + rulebook_len, - rulebook_len, - in_channels, - in_features_ptr); - Gather(out_grad.non_zero_elements().data(), - rulebook_ptr + rulebook_len * 2, - rulebook_len, - out_channels, - out_grad_features_ptr); - + int half_kernel_size = kernel_size / 2; auto blas = phi::funcs::GetBlas(dev_ctx); + x_grad->Resize(x.non_zero_elements().dims()); + dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel()); + T* x_grad_values_ptr = x_grad->data(); + memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel()); + memset(d_x_features_ptr, 0, sizeof(T) * d_x_features.numel()); + std::vector offsets(kernel_size + 1), counter(kernel_size, 0); for (int i = 0; i < rulebook_len; i++) { counter[rulebook_ptr[i]] += 1; } - int offset = 0; + int offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; offset += counter[i]; + if (i < half_kernel_size) { + max_count = std::max(max_count, counter[i]); + } } offsets[kernel_size] = offset; + if (subm) { + blas.GEMM(CblasTrans, + CblasNoTrans, + x.non_zero_elements().dims()[1], + out_grad.non_zero_elements().dims()[1], + x.non_zero_elements().dims()[0], + static_cast(1), + x.non_zero_elements().data(), + out_grad.non_zero_elements().data(), + static_cast(0), + d_kernel_ptr + half_kernel_size * in_channels * out_channels); + + // call gemm: d_x = out_grad * transpose(kernel) + // (n, out_channels) * (out_channels, in_channels) + T* x_grad_ptr = x_grad->data(); + blas.GEMM(CblasNoTrans, + CblasTrans, + out_grad.non_zero_elements().dims()[0], + in_channels, + out_grad.non_zero_elements().dims()[1], + static_cast(1), + out_grad.non_zero_elements().data(), + kernel.data() + half_kernel_size * in_channels * out_channels, + static_cast(0), + x_grad_ptr); + if (max_count == 0) { + return; + } + } + + Gather(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + rulebook_len, + in_channels, + in_features_ptr); + Gather(out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + rulebook_len, + out_channels, + out_grad_features_ptr); + const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { - if (counter[i] <= 0) { + if (counter[i] <= 0 || (subm && i == half_kernel_size)) { continue; } @@ -136,10 +177,6 @@ void Conv3dGradKernel(const Context& dev_ctx, } // 4. scatter - x_grad->Resize(x.non_zero_elements().dims()); - dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel()); - T* x_grad_values_ptr = x_grad->data(); - memset(x_grad_values_ptr, 0, sizeof(T) * x_grad->numel()); Scatter(d_x_features_ptr, rulebook.data() + rulebook_len, rulebook_len, diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc index 64ef068e03ab53e4338a0f5ba3d5f160a4e66dd5..f65e1cf579a9344b6e46ff693b1ce05600adc6a0 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc @@ -35,6 +35,7 @@ void Conv3dKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, SparseCooTensor* out, DenseTensor* rulebook) { // update padding and dilation @@ -63,6 +64,7 @@ void Conv3dKernel(const Context& dev_ctx, dilations, strides, out_dims, + subm, rulebook, &counter_per_kernel); diff --git a/paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu b/paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..5f6d24093a4d703d86550ab1847a082823f8af6b --- /dev/null +++ b/paddle/phi/kernels/sparse/cpu/submanifold_convolution_kernel.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/sparse/submanifold_convolution_kernel.h" + +namespace phi { +namespace sparse {} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index 03a6aaa68943d7ea8d0ab7c02561a407166e43d5..8826fd7cf87e0a7a4a8251b4da823f18190f4a38 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -71,7 +71,8 @@ __global__ void ScatterKernel(const T* input, const int non_zero_num, const int rulebook_len, const int channels, - T* out) { + T* out, + const bool subm = false) { int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) { int indices_i = i / channels; @@ -82,6 +83,9 @@ __global__ void ScatterKernel(const T* input, : unique_value[indices_i + 1]; // max(end-start) = kernel_size T sum = static_cast(0); + if (subm) { + sum = out[indices_i * channels + channels_i]; + } for (int j = start; j < end; j++) { const int out_feature_i = out_index[j]; sum += input[out_feature_i * channels + channels_i]; diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 861f18f36e632c88147b38e8a9203384050293bb..a307ab0f54613a91deee6215b5c389ca0a44d6e8 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -43,6 +43,7 @@ void Conv3dGradKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, DenseTensor* x_grad, DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); @@ -69,37 +70,18 @@ void Conv3dGradKernel(const Context& dev_ctx, T* in_features_ptr = in_features.data(); T* d_x_features_ptr = d_x_features.data(); T* out_grad_features_ptr = out_grad_features.data(); - kernel_grad->Resize(kernel_dims); - dev_ctx.Alloc( - kernel_grad, kernel_grad->dtype(), kernel_grad->numel() * sizeof(T)); + kernel_grad->ResizeAndAllocate(kernel_dims); T* d_kernel_ptr = kernel_grad->data(); phi::funcs::SetConstant set_zero; set_zero(dev_ctx, kernel_grad, static_cast(0.0f)); - auto config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + rulebook_len, - in_features_ptr, - rulebook_len, - in_channels); - - config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * out_channels, 1); - GatherKernel<<>>( - out_grad.non_zero_elements().data(), - rulebook_ptr + rulebook_len * 2, - out_grad_features_ptr, - rulebook_len, - out_channels); - + int half_kernel_size = kernel_size / 2; auto blas = phi::funcs::GetBlas(dev_ctx); + x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); + T* x_grad_values_ptr = x_grad->data(); + set_zero(dev_ctx, x_grad, static_cast(0.0f)); + set_zero(dev_ctx, &d_x_features, static_cast(0.0f)); + std::vector offsets(kernel_size + 1), counter(kernel_size, 0), h_counter(rulebook_len, 0); phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], @@ -117,16 +99,72 @@ void Conv3dGradKernel(const Context& dev_ctx, for (int i = 0; i < rulebook_len; i++) { counter[h_counter[i]] += 1; } - int offset = 0; + int offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; offset += counter[i]; + if (i < half_kernel_size) { + max_count = std::max(max_count, counter[i]); + } } offsets[kernel_size] = offset; + if (subm) { + blas.GEMM(CblasTrans, + CblasNoTrans, + x.non_zero_elements().dims()[1], + out_grad.non_zero_elements().dims()[1], + x.non_zero_elements().dims()[0], + static_cast(1), + x.non_zero_elements().data(), + out_grad.non_zero_elements().data(), + static_cast(0), + d_kernel_ptr + half_kernel_size * in_channels * out_channels); + + // call gemm: d_x = out_grad * transpose(kernel) + // (n, out_channels) * (out_channels, in_channels) + T* x_grad_ptr = x_grad->data(); + blas.GEMM(CblasNoTrans, + CblasTrans, + out_grad.non_zero_elements().dims()[0], + in_channels, + out_grad.non_zero_elements().dims()[1], + static_cast(1), + out_grad.non_zero_elements().data(), + kernel.data() + half_kernel_size * in_channels * out_channels, + static_cast(0), + x_grad_ptr); + if (max_count == 0) { + return; + } + } + + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, rulebook_len * in_channels, 1); + GatherKernel<<>>(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + in_features_ptr, + rulebook_len, + in_channels); + + config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, rulebook_len * out_channels, 1); + GatherKernel<<>>( + out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + out_grad_features_ptr, + rulebook_len, + out_channels); + const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { - if (counter[i] <= 0) { + if (counter[i] <= 0 || (subm && i == half_kernel_size)) { continue; } @@ -167,19 +205,11 @@ void Conv3dGradKernel(const Context& dev_ctx, } // 4. scatter - x_grad->Resize(x.non_zero_elements().dims()); - dev_ctx.Alloc(x_grad, x_grad->dtype(), sizeof(T) * x_grad->numel()); - T* x_grad_values_ptr = x_grad->data(); - - DenseTensor out_index = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); - DenseTensor unique_key = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); - DenseTensor unique_value = phi::Empty( - dev_ctx, - DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + x_grad->ResizeAndAllocate(x.non_zero_elements().dims()); + DenseTensorMeta index_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); + DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); + DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); + DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); SortedAndUniqueIndex(dev_ctx, rulebook_ptr + rulebook_len, @@ -200,7 +230,8 @@ void Conv3dGradKernel(const Context& dev_ctx, x.nnz(), rulebook_len, in_channels, - x_grad_values_ptr); + x_grad_values_ptr, + subm); } } // namespace sparse diff --git a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 4a533d9d1d5e8f6090976ab63a86b01c1d518c8d..94186600f1e2994f9b464bb8d81e9dbf891a4ae9 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/compute_primitives.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" @@ -32,6 +33,34 @@ limitations under the License. */ namespace phi { namespace sparse { +__global__ void SetFlagAndUpdateCounterKernel(const int* indexs, + const int n, + const int rulebook_len, + const int kernel_size, + int* rulebook_ptr, + int* counter_ptr) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ int cache_count[]; // kernel_size + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + cache_count[i] = 0; + } + __syncthreads(); + + for (int i = tid; i < n; i += gridDim.x * blockDim.x) { + int index = indexs[i]; + int kernel_index = rulebook_ptr[index]; + rulebook_ptr[index + rulebook_len] = -1; + rulebook_ptr[index + 2 * rulebook_len] = -1; + rulebook_ptr[index] = -1; + atomicAdd(&cache_count[kernel_index], 1); + } + __syncthreads(); + + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + atomicSub(&counter_ptr[i], cache_count[i]); + } +} + /** * @brief: update the out index and indices * unique_keys: save the index of the output feature list @@ -95,8 +124,10 @@ __global__ void ProductRuleBookKernel(const int* x_indices, const Dims4D paddings, const Dims4D dilations, const Dims4D strides, + const bool subm, int* rulebook, - int* counter) { + int* counter, + int* in_indexs) { int tid = threadIdx.x + blockIdx.x * blockDim.x; extern __shared__ int counter_buf[]; // kernel_size const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1]; @@ -108,13 +139,16 @@ __global__ void ProductRuleBookKernel(const int* x_indices, for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { int kernel_index = 0; + int batch = x_indices[i]; + int in_z = x_indices[i + non_zero_num]; + int in_y = x_indices[i + 2 * non_zero_num]; + int in_x = x_indices[i + 3 * non_zero_num]; + if (subm) { + in_indexs[i] = PointToIndex(batch, in_x, in_y, in_z, x_dims); + } for (int kz = 0; kz < kernel_dims[1]; kz++) { for (int ky = 0; ky < kernel_dims[2]; ky++) { for (int kx = 0; kx < kernel_dims[3]; kx++) { - int batch = x_indices[i]; - int in_z = x_indices[i + non_zero_num]; - int in_y = x_indices[i + 2 * non_zero_num]; - int in_x = x_indices[i + 3 * non_zero_num]; int in_i = -1, out_index = -1, kernel_i = -1; if (Check(x_dims, kernel_dims, @@ -182,6 +216,7 @@ int ProductRuleBook(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const DDim& out_dims, + const bool subm, DenseTensor* rulebook, DenseTensor* counter_per_kernel, DenseTensor* offsets_per_kernel, @@ -195,13 +230,14 @@ int ProductRuleBook(const Context& dev_ctx, const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); const int* indices_ptr = non_zero_indices.data(); + DenseTensor in_indexs = phi::Empty( + dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); int* counter_ptr = counter_per_kernel->data(); int* offsets_ptr = offsets_per_kernel->data(); int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int rulebook_rows = 3; const int rulebook_cols = kernel_size * non_zero_num; rulebook->ResizeAndAllocate({rulebook_rows, rulebook_cols}); - dev_ctx.Alloc(rulebook, rulebook->dtype(), sizeof(int) * rulebook->numel()); int* rulebook_ptr = rulebook->data(); const auto x_dims = x.dims(); @@ -229,8 +265,10 @@ int ProductRuleBook(const Context& dev_ctx, d_paddings, d_dilations, d_strides, + subm, rulebook_ptr, - counter_ptr); + counter_ptr, + in_indexs.data()); // 2. remove -1 #ifdef PADDLE_WITH_HIP @@ -242,6 +280,144 @@ int ProductRuleBook(const Context& dev_ctx, rulebook_ptr + rulebook_rows * rulebook_cols, -1); + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + rulebook_ptr, last, rulebook_ptr + 3 * kernel_size * non_zero_num - 1); + int rulebook_len = 0; + phi::backends::gpu::GpuMemcpyAsync( + &rulebook_len, + rulebook_ptr + 3 * kernel_size * non_zero_num - 1, + sizeof(int), +#ifdef PADDLE_WITH_HIP + hipMemcpyDeviceToHost, +#else + cudaMemcpyDeviceToHost, +#endif + dev_ctx.stream()); + rulebook_len /= 3; + dev_ctx.Wait(); + + if (subm) { + // At present, hashtable is not used to map the input and output indexes. + // At present, the intermediate output index is generated by normal + // convolution, + // and then the intermediate output index is subtracted from the input index + // to obain the rulebook. + // get difference + int32_t* A_key_ptr = rulebook_ptr + 2 * rulebook_len; + int32_t* B_key_ptr = in_indexs.data(); + DenseTensor A_val = phi::Empty( + dev_ctx, + DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + DenseTensor B_val = phi::Empty( + dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); + phi::IndexKernel>( + dev_ctx, &A_val, kps::IdentityFunctor()); + phi::IndexKernel>( + dev_ctx, &B_val, kps::IdentityFunctor()); + DenseTensor key_result = phi::Empty( + dev_ctx, + DenseTensorMeta(DataType::INT32, {rulebook_len + 1}, DataLayout::NCHW)); + DenseTensor val_result = phi::Empty( + dev_ctx, + DenseTensorMeta(DataType::INT32, {rulebook_len}, DataLayout::NCHW)); + +#ifdef PADDLE_WITH_HIP + thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()), +#endif + counter_ptr, + counter_ptr + kernel_size, + offsets_ptr); + std::vector offsets(kernel_size, 0); + // TODO(zhangkaihuo): used unified memcpy interface + phi::backends::gpu::GpuMemcpyAsync(offsets.data(), + offsets_ptr, + kernel_size * sizeof(int), +#ifdef PADDLE_WITH_HIP + hipMemcpyDeviceToHost, +#else + cudaMemcpyDeviceToHost, +#endif + dev_ctx.stream()); + dev_ctx.Wait(); + + thrust::pair end; + // Because set_diff does not support duplicate data, set_diff is performed + // separately for each segment of data. + // TODO(zhangkaihuo): Using hashtable here may get better performance, + // further tests ared needed. + for (int i = 0; i < kernel_size; i++) { + int start = offsets[i]; + int stop = i == kernel_size - 1 ? rulebook_len : offsets[i + 1]; + int* key_result_start = (i == 0 ? key_result.data() : end.first); + int* val_result_start = i == 0 ? val_result.data() : end.second; + end = +#ifdef PADDLE_WITH_HIP + thrust::set_difference_by_key(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::set_difference_by_key(thrust::cuda::par.on(dev_ctx.stream()), +#endif + A_key_ptr + start, + A_key_ptr + stop, + B_key_ptr, + B_key_ptr + x.nnz(), + A_val.data() + start, + B_val.data(), + key_result_start, + val_result_start); + } + + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + key_result.data(), + end.first, + key_result.data() + rulebook_len); + int len = 0; + phi::backends::gpu::GpuMemcpyAsync(&len, + key_result.data() + rulebook_len, + sizeof(int), +#ifdef PADDLE_WITH_HIP + hipMemcpyDeviceToHost, +#else + cudaMemcpyDeviceToHost, +#endif + dev_ctx.stream()); + dev_ctx.Wait(); + // set the diff value = -1, and update counter + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len, 1); + SetFlagAndUpdateCounterKernel<<>>(val_result.data(), + len, + rulebook_len, + kernel_size, + rulebook_ptr, + counter_ptr); +// remove -1 +#ifdef PADDLE_WITH_HIP + int* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), +#else + int* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), +#endif + rulebook_ptr, + rulebook_ptr + 3 * rulebook_len, + -1); + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + rulebook_ptr, last, key_result.data() + rulebook_len); + phi::backends::gpu::GpuMemcpyAsync(&rulebook_len, + key_result.data() + rulebook_len, + sizeof(int), +#ifdef PADDLE_WITH_HIP + hipMemcpyDeviceToHost, +#else + cudaMemcpyDeviceToHost, +#endif + dev_ctx.stream()); + dev_ctx.Wait(); + rulebook_len /= 3; + } + #ifdef PADDLE_WITH_HIP thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), #else @@ -274,23 +450,14 @@ int ProductRuleBook(const Context& dev_ctx, cudaMemcpyDeviceToHost, dev_ctx.stream()); #endif - dev_ctx.Wait(); - int rulebook_len = - (*h_counter)[kernel_size - 1] + (*h_offsets)[kernel_size - 1]; rulebook->Resize({rulebook_rows, rulebook_len}); // 3. sorted or merge the out index out_index->ResizeAndAllocate({rulebook_len}); unique_value->ResizeAndAllocate({rulebook_len}); unique_key->ResizeAndAllocate({rulebook_len}); - dev_ctx.Alloc( - out_index, out_index->dtype(), sizeof(int) * out_index->numel()); int* out_index_ptr = out_index->data(); - dev_ctx.Alloc( - unique_value, unique_value->dtype(), sizeof(int) * unique_value->numel()); int* unique_value_ptr = unique_value->data(); - dev_ctx.Alloc( - unique_key, unique_key->dtype(), sizeof(int) * unique_key->numel()); int* unique_key_ptr = unique_key->data(); int* new_end = SortedAndUniqueIndex(dev_ctx, @@ -364,6 +531,7 @@ void Conv3dKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, const int groups, + const bool subm, SparseCooTensor* out, DenseTensor* rulebook) { // update padding and dilation @@ -389,20 +557,28 @@ void Conv3dKernel(const Context& dev_ctx, DataType::INT32, {kernel_size}, DataLayout::NCHW); DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta)); - DenseTensor out_index = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); - DenseTensor unique_key = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); - DenseTensor unique_value = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {1}, DataLayout::NCHW)); + DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); + DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); + DenseTensor unique_key = phi::Empty(dev_ctx, std::move(index_meta)); + DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); + + std::vector subm_paddings(paddings), subm_strides(strides); + if (subm) { + auto kernel_dims = kernel.dims(); + for (int i = 0; i < paddings.size(); i++) { + subm_paddings[i] = kernel_dims[i] / 2; + subm_strides[i] = 1; + } + } int n = ProductRuleBook(dev_ctx, x, kernel, - paddings, + subm_paddings, dilations, - strides, + subm_strides, out_dims, + subm, rulebook, &counter_per_kernel, &offsets_per_kernel, @@ -428,6 +604,8 @@ void Conv3dKernel(const Context& dev_ctx, phi::Empty(dev_ctx, std::move(out_features_meta)); T* in_features_ptr = in_features.data(); T* out_features_ptr = out_features.data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &out_features, static_cast(0.0f)); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); diff --git a/paddle/phi/tests/api/test_sparse_conv_api.cc b/paddle/phi/tests/api/test_sparse_conv_api.cc index 76cb01d8a8b98b070d89ba4a3887275f00e228f3..7c4aa164259071667e3d90994759c05454f407ff 100644 --- a/paddle/phi/tests/api/test_sparse_conv_api.cc +++ b/paddle/phi/tests/api/test_sparse_conv_api.cc @@ -78,7 +78,7 @@ void TestConv3dBase(const std::vector& indices, if (!std::is_same::value) { auto outs = paddle::experimental::sparse::conv3d( - x, weight, paddings, dilations, strides, 1); + x, weight, paddings, dilations, strides, 1, false); auto out = std::dynamic_pointer_cast( std::get<0>(outs).impl()); diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index c1a8b853b32e38cb32e2081727e102164ffddb08..37a69a176c6e1ded81a8449da3c571442bd94e78 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -64,7 +64,8 @@ void TestConv3dBase(const std::vector& indices, const float diff = 1e-3, const bool backward = false, const std::vector features_grad = {}, - const std::vector kernel_grad = {}) { + const std::vector kernel_grad = {}, + const bool subm = false) { phi::CPUContext dev_ctx_cpu; dev_ctx_cpu.SetAllocator( paddle::memory::allocation::AllocatorFacade::Instance() @@ -114,6 +115,7 @@ void TestConv3dBase(const std::vector& indices, dilations, strides, 1, + subm, &rulebook); ASSERT_EQ(correct_out_dims.size(), out.dims().size()); @@ -138,7 +140,8 @@ void TestConv3dBase(const std::vector& indices, paddings, dilations, strides, - 1); + 1, + subm); f_verify(grads[0].data(), features_grad); f_verify(grads[1].data(), kernel_grad); } @@ -191,6 +194,7 @@ void TestConv3dBase(const std::vector& indices, dilations, strides, 1, + subm, &d_rulebook); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); @@ -235,7 +239,8 @@ void TestConv3dBase(const std::vector& indices, paddings, dilations, strides, - 1); + 1, + subm); DenseTensor h_features_grad = phi::Empty( dev_ctx_cpu, DenseTensorMeta(grads[0].dtype(), grads[0].dims(), grads[0].layout())); @@ -266,7 +271,8 @@ void TestConv3d(const std::vector& indices, const float diff = 1e-3, const bool backward = false, const std::vector features_grad = {}, - const std::vector kernel_grad = {}) { + const std::vector kernel_grad = {}, + const bool subm = false) { // test float TestConv3dBase(indices, features, @@ -283,7 +289,8 @@ void TestConv3d(const std::vector& indices, diff, backward, features_grad, - kernel_grad); + kernel_grad, + subm); // test double TestConv3dBase(indices, cast(features), @@ -300,7 +307,8 @@ void TestConv3d(const std::vector& indices, diff, backward, cast(features_grad), - cast(kernel_grad)); + cast(kernel_grad), + subm); } TEST(DEV_API, sparse_conv3d) { @@ -661,5 +669,101 @@ TEST(DEV_API, sparse_conv3d_backward) { kernel_grad); } +TEST(DEV_API, sparse_conv2d_subm) { + const int in_channels = 1; + const int out_channels = 1; + DDim x_dims = {1, 1, 4, 5, in_channels}; + DDim kernel_dims = {1, 3, 3, in_channels, out_channels}; + DDim out_dims = {1, 1, 4, 5, out_channels}; + std::vector paddings = {0, 1, 1}; + std::vector strides = {1, 1, 1}; + std::vector dilations = {1, 1, 1}; + + const int non_zero_num = 4; + std::vector indices_flatten = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 3, 2, 2, 3}; + + std::vector features = {0.8854, 0.6505, -0.1999, 0.3583}; + // 3*3*3=27 + std::vector kernel = { + 0.9364, 0.9460, 0.6564, 0.7999, 0.2013, 0.3812, 0.5474, 0.1016, 0.3368}; + + std::vector out_indices_flatten = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 3, 3, 2, 2, 3}; + + std::vector out_features = {0.1782, 0.2313, 0.7117, 0.5214}; + + std::vector features_grad = {0.0359, 1.2080, 0.5838, 0.4541}; + std::vector kernel_grad = { + 0.3391, 0.4630, 0.0000, -0.1042, 0.3528, 0.2550, 0.0000, -0.0462, 0.0829}; + + TestConv3d(indices_flatten, + features, + x_dims, + kernel, + kernel_dims, + out_indices_flatten, + out_features, + out_dims, + non_zero_num, + paddings, + strides, + dilations, + 1e-3, + true, + features_grad, + kernel_grad, + true); +} + +TEST(DEV_API, sparse_conv3d_subm) { + const int in_channels = 1; + const int out_channels = 1; + DDim x_dims = {1, 4, 4, 5, in_channels}; + DDim kernel_dims = {3, 3, 3, in_channels, out_channels}; + DDim out_dims = {1, 4, 4, 5, out_channels}; + std::vector paddings = {1, 1, 1}; + std::vector strides = {1, 1, 1}; + std::vector dilations = {1, 1, 1}; + + const int non_zero_num = 3; + std::vector indices_flatten = {0, 0, 0, 1, 3, 3, 2, 0, 2, 0, 3, 1}; + + std::vector features = {-0.9578, 0.1572, 0.1036}; + // 3*3*3=27 + std::vector kernel = { + 0.1367, 0.4534, 0.2138, 0.8264, 0.7534, 0.3270, 0.2880, 0.1562, 0.7770, + 0.6902, 0.1981, 0.1369, 0.6582, 0.7582, 0.5640, 0.8894, 0.7350, 0.1845, + 0.6892, 0.3654, 0.6076, 0.0326, 0.8412, 0.5289, 0.9824, 0.8235, 0.9802}; + + std::vector out_indices_flatten = {0, 0, 0, 1, 3, 3, 2, 0, 2, 0, 3, 1}; + + std::vector out_features = {-0.7262, 0.1192, 0.0785}; + + std::vector features_grad = {-0.5506, 0.0904, 0.0595}; + std::vector kernel_grad = { + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.7224, 0.0000, 0.0000, 0.0000, 0.0000, + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000}; + + TestConv3d(indices_flatten, + features, + x_dims, + kernel, + kernel_dims, + out_indices_flatten, + out_features, + out_dims, + non_zero_num, + paddings, + strides, + dilations, + 1e-3, + true, + features_grad, + kernel_grad, + true); +} + } // namespace tests } // namespace phi diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index 2f233a2df357df478c96bed2c40e28e8e972f660..9c859022e8ad1d910ecf44426b2850496e793cee 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -1,5 +1,5 @@ - api : conv3d - args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) + args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) kernel : func : sparse_conv3d diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index 8c9f02ebb3198670fffa4ddd80d14798b6fe78a9..6532f103cbf86288ffc739656440dc378d48eb2d 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -1,6 +1,6 @@ - backward_api : conv3d_grad - forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) - args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups) + forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) + args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) output : Tensor(x_grad@DenseTensor), Tensor(kernel_grad@DenseTensor) kernel : - func : sparse_conv_grad + func : sparse_conv3d_grad