diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 216685f0f719184c40dd7abe321e08efb665ca62..34337db558c8af94f28274ab9d5ee6ffe127b537 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -139,16 +139,16 @@ void Conv3dGradCPUKernel(const CPUContext& dev_ctx, T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels; const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels; - T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * out_channels; + T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels; T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels; // call gemm: d_kernel = transpose(x) * out_grad // (in_channels, n) * (n, out_channels) blas.GEMM(CblasTrans, CblasNoTrans, - M, - N, K, + N, + M, static_cast(1), tmp_in_ptr, tmp_out_grad_ptr, diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc index c920f3c46128737425614ced21331215fd244e08..d133464ab853c681c52bb98faa215899429033c4 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc @@ -50,16 +50,19 @@ void Conv3dCPUKernel(const CPUContext& dev_ctx, kernel_sizes[i] = kernel_dims[i]; } - phi::funcs::sparse::GetOutShape( - x_dims, kernel_sizes, paddings, dilations, strides, &out_dims); - const int in_channels = kernel_dims[3]; - const int out_channels = kernel_dims[4]; - std::vector subm_paddings(paddings), subm_strides(strides); if (subm) { + // the out shape of subm_conv is same as input shape + // reset the padding=kernel_size/2 and strides=1 phi::funcs::sparse::ResetSubmKernelSizeAndStrides( kernel.dims(), &subm_paddings, &subm_strides); } + + phi::funcs::sparse::GetOutShape( + x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); + const int in_channels = kernel_dims[3]; + const int out_channels = kernel_dims[4]; + // Second algorithm: // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // 1. product rulebook diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index 1bceb767b670857fabc2577b161085355af43131..2396a5975de4e85c932b43b40a51cf6b03427aa5 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -22,6 +23,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/primitive/compute_primitives.h" @@ -143,35 +145,6 @@ inline IntT* SortedAndUniqueIndex(const Context& dev_ctx, return new_end.first; } -template -__global__ void SetFlagAndUpdateCounterKernel(const int* indexs, - const int n, - const int rulebook_len, - const int kernel_size, - T* rulebook_ptr, - int* counter_ptr) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - extern __shared__ int cache_count[]; // kernel_size - for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { - cache_count[i] = 0; - } - __syncthreads(); - - for (int i = tid; i < n; i += gridDim.x * blockDim.x) { - int index = indexs[i]; - T kernel_index = rulebook_ptr[index]; - rulebook_ptr[index + rulebook_len] = -1; - rulebook_ptr[index + 2 * rulebook_len] = -1; - rulebook_ptr[index] = -1; - atomicAdd(&cache_count[kernel_index], 1); - } - __syncthreads(); - - for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { - atomicSub(&counter_ptr[i], cache_count[i]); - } -} - /** * @brief: update the out index and indices * unique_keys: save the index of the output feature list @@ -221,6 +194,42 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) { } } +template +__global__ void UpdateOutIndexAndCounterAfterLowerBound( + const IntT* x_indexs, + const IntT* bound_out, + const int rulebook_len, + const int kernel_size, + const int64_t non_zero_num, + IntT* rulebook_ptr, + IntT* out_indexs, + int* counter_ptr) { + extern __shared__ int cache_count[]; + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + cache_count[i] = 0; + } + __syncthreads(); + + CUDA_KERNEL_LOOP_TYPE(i, rulebook_len, int64_t) { + int j = bound_out[i]; + if (j >= 0 && j < non_zero_num && out_indexs[i] == x_indexs[j]) { + out_indexs[i] = j; + } else { + // mask this position will be remove + int kernel_index = rulebook_ptr[i]; + rulebook_ptr[i + rulebook_len] = -1; + rulebook_ptr[i + 2 * rulebook_len] = -1; + rulebook_ptr[i] = -1; + atomicAdd(&cache_count[kernel_index], 1); + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + atomicSub(&counter_ptr[i], cache_count[i]); + } +} + /** * @brief product rulebook * for input_i in x_indices: @@ -338,7 +347,6 @@ int ProductRuleBook(const Context& dev_ctx, SparseCooTensor* out, std::vector* h_counter, std::vector* h_offsets) { - // TODO(zhangkaihuo): use PD_VISIT_INTEGRAL_TYPES for secondary dispatch auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); @@ -362,7 +370,6 @@ int ProductRuleBook(const Context& dev_ctx, Dims4D d_paddings(1, paddings[2], paddings[1], paddings[0]); Dims4D d_strides(1, strides[2], strides[1], strides[0]); Dims4D d_dilations(1, dilations[2], dilations[1], dilations[0]); - // 1. product rule book phi::funcs::SetConstant set_zero; set_zero(dev_ctx, counter_per_kernel, 0); @@ -408,8 +415,8 @@ int ProductRuleBook(const Context& dev_ctx, cudaMemcpyDeviceToHost, #endif dev_ctx.stream()); - rulebook_len /= 3; dev_ctx.Wait(); + rulebook_len /= 3; if (subm) { // At present, hashtable is not used to map the input and output indexes. @@ -417,96 +424,41 @@ int ProductRuleBook(const Context& dev_ctx, // convolution, // and then the intermediate output index is subtracted from the input index // to obain the rulebook. - // get difference - IntT* A_key_ptr = rulebook_ptr + 2 * rulebook_len; - IntT* B_key_ptr = in_indexs.data(); - DenseTensorMeta val_meta(DataType::INT32, {rulebook_len}, DataLayout::NCHW); - DenseTensor A_val = phi::Empty(dev_ctx, std::move(val_meta)); - DenseTensor B_val = phi::Empty( - dev_ctx, DenseTensorMeta(DataType::INT32, {x.nnz()}, DataLayout::NCHW)); - phi::IndexKernel>( - dev_ctx, &A_val, kps::IdentityFunctor()); - phi::IndexKernel>( - dev_ctx, &B_val, kps::IdentityFunctor()); - DenseTensor key_result = phi::Empty( - dev_ctx, - DenseTensorMeta(indices_dtype, {rulebook_len + 1}, DataLayout::NCHW)); - DenseTensor val_result = phi::Empty(dev_ctx, std::move(val_meta)); - -#ifdef PADDLE_WITH_HIP - thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), -#else - thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()), -#endif - counter_ptr, - counter_ptr + kernel_size, - offsets_ptr); - std::vector offsets(kernel_size, 0); - // TODO(zhangkaihuo): used unified memcpy interface - phi::backends::gpu::GpuMemcpyAsync(offsets.data(), - offsets_ptr, - kernel_size * sizeof(int), -#ifdef PADDLE_WITH_HIP - hipMemcpyDeviceToHost, -#else - cudaMemcpyDeviceToHost, -#endif - dev_ctx.stream()); - dev_ctx.Wait(); - - thrust::pair end; - // Because set_diff does not support duplicate data, set_diff is performed - // separately for each segment of data. - // TODO(zhangkaihuo): Using hashtable here may get better performance, - // further tests ared needed. - for (int i = 0; i < kernel_size; i++) { - int start = offsets[i]; - int stop = i == kernel_size - 1 ? rulebook_len : offsets[i + 1]; - IntT* key_result_start = (i == 0 ? key_result.data() : end.first); - int* val_result_start = i == 0 ? val_result.data() : end.second; - end = -#ifdef PADDLE_WITH_HIP - thrust::set_difference_by_key(thrust::hip::par.on(dev_ctx.stream()), -#else - thrust::set_difference_by_key(thrust::cuda::par.on(dev_ctx.stream()), -#endif - A_key_ptr + start, - A_key_ptr + stop, - B_key_ptr, - B_key_ptr + x.nnz(), - A_val.data() + start, - B_val.data(), - key_result_start, - val_result_start); - } - DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( - key_result.data(), - end.first, - key_result.data() + rulebook_len); - IntT len = 0; - phi::backends::gpu::GpuMemcpyAsync(&len, - key_result.data() + rulebook_len, - sizeof(IntT), + // call lower_bound to get the real index of out_index + const IntT* in_indexs_ptr = in_indexs.data(); + IntT* out_indexs_ptr = rulebook_ptr + 2 * rulebook_len; + DenseTensor bound = phi::Empty( + dev_ctx, + DenseTensorMeta( + indices_dtype, {static_cast(rulebook_len)}, DataLayout::NCHW)); + IntT* bound_ptr = bound.data(); #ifdef PADDLE_WITH_HIP - hipMemcpyDeviceToHost, + thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()), #else - cudaMemcpyDeviceToHost, + thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()), #endif - dev_ctx.stream()); - dev_ctx.Wait(); - // set the diff value = -1, and update counter - auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len, 1); - SetFlagAndUpdateCounterKernel<<>>( - val_result.data(), - len, + in_indexs_ptr, + in_indexs_ptr + in_indexs.numel(), + out_indexs_ptr, + out_indexs_ptr + rulebook_len, + bound_ptr); + + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + + UpdateOutIndexAndCounterAfterLowerBound<<>>( + in_indexs_ptr, + bound.data(), rulebook_len, kernel_size, + x.nnz(), rulebook_ptr, + out_indexs_ptr, counter_ptr); + // remove -1 #ifdef PADDLE_WITH_HIP IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), @@ -517,9 +469,9 @@ int ProductRuleBook(const Context& dev_ctx, rulebook_ptr + 3 * rulebook_len, -1); DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( - rulebook_ptr, last, key_result.data() + rulebook_len); + rulebook_ptr, last, bound_ptr); phi::backends::gpu::GpuMemcpyAsync(&rulebook_len, - key_result.data() + rulebook_len, + bound_ptr, sizeof(IntT), #ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, @@ -540,102 +492,111 @@ int ProductRuleBook(const Context& dev_ctx, counter_ptr + kernel_size, offsets_ptr); -#ifdef PADDLE_WITH_HIP phi::backends::gpu::GpuMemcpyAsync(&(*h_counter)[0], counter_ptr, kernel_size * sizeof(int), +#ifdef PADDLE_WITH_HIP hipMemcpyDeviceToHost, - dev_ctx.stream()); - phi::backends::gpu::GpuMemcpyAsync(&(*h_offsets)[0], - offsets_ptr, - kernel_size * sizeof(int), - hipMemcpyDeviceToHost, - dev_ctx.stream()); #else - phi::backends::gpu::GpuMemcpyAsync(&(*h_counter)[0], - counter_ptr, - kernel_size * sizeof(int), cudaMemcpyDeviceToHost, +#endif dev_ctx.stream()); + phi::backends::gpu::GpuMemcpyAsync(&(*h_offsets)[0], offsets_ptr, kernel_size * sizeof(int), +#ifdef PADDLE_WITH_HIP + hipMemcpyDeviceToHost, +#else cudaMemcpyDeviceToHost, - dev_ctx.stream()); #endif + dev_ctx.stream()); + rulebook->Resize({rulebook_rows, static_cast(rulebook_len)}); - // 3. sorted or merge the out index - out_index->ResizeAndAllocate({static_cast(rulebook_len)}); - unique_value->ResizeAndAllocate({static_cast(rulebook_len)}); - DenseTensor unique_key = phi::Empty( - dev_ctx, - DenseTensorMeta(paddle::experimental::CppTypeToDataType::Type(), - {static_cast(rulebook_len)}, - DataLayout::NCHW)); - int* out_index_ptr = out_index->data(); - int* unique_value_ptr = unique_value->data(); - IntT* unique_key_ptr = unique_key.data(); - - IntT* new_end = - SortedAndUniqueIndex(dev_ctx, - rulebook_ptr + 2 * rulebook_len, - rulebook_len, - out_index, - &unique_key, - unique_value); - // thrust::distance doesn't support stream parameters - // const int out_non_zero_num = thrust::distance(unique_key_ptr, - // new_end.first); - DistanceKernel<<<1, 1>>>( - unique_key_ptr, - new_end, - rulebook_ptr + rulebook_rows * rulebook_cols - 1); - IntT out_non_zero_num = 0; + if (!subm) { + // 3. sorted or merge the out index + out_index->ResizeAndAllocate({static_cast(rulebook_len)}); + unique_value->ResizeAndAllocate({static_cast(rulebook_len)}); + DenseTensor unique_key = phi::Empty( + dev_ctx, + DenseTensorMeta( + indices_dtype, {static_cast(rulebook_len)}, DataLayout::NCHW)); + int* out_index_ptr = out_index->data(); + int* unique_value_ptr = unique_value->data(); + IntT* unique_key_ptr = unique_key.data(); + + IntT* new_end = + SortedAndUniqueIndex(dev_ctx, + rulebook_ptr + 2 * rulebook_len, + rulebook_len, + out_index, + &unique_key, + unique_value); + // thrust::distance doesn't support stream parameters + // const int out_non_zero_num = thrust::distance(unique_key_ptr, + // new_end.first); + DistanceKernel<<<1, 1, 0, dev_ctx.stream()>>>( + unique_key_ptr, + new_end, + rulebook_ptr + rulebook_rows * rulebook_cols - 1); + IntT out_non_zero_num = 0; #ifdef PADDLE_WITH_HIP - phi::backends::gpu::GpuMemcpyAsync( - &out_non_zero_num, - rulebook_ptr + rulebook_rows * rulebook_cols - 1, - sizeof(IntT), - hipMemcpyDeviceToHost, - dev_ctx.stream()); + phi::backends::gpu::GpuMemcpyAsync( + &out_non_zero_num, + rulebook_ptr + rulebook_rows * rulebook_cols - 1, + sizeof(IntT), + hipMemcpyDeviceToHost, + dev_ctx.stream()); #else - phi::backends::gpu::GpuMemcpyAsync( - &out_non_zero_num, - rulebook_ptr + rulebook_rows * rulebook_cols - 1, - sizeof(IntT), - cudaMemcpyDeviceToHost, - dev_ctx.stream()); + phi::backends::gpu::GpuMemcpyAsync( + &out_non_zero_num, + rulebook_ptr + rulebook_rows * rulebook_cols - 1, + sizeof(IntT), + cudaMemcpyDeviceToHost, + dev_ctx.stream()); #endif - dev_ctx.Wait(); + dev_ctx.Wait(); - // 5. update out_indices and rulebook by unique_value_ptr - const int64_t sparse_dim = 4; - DenseTensorMeta indices_meta( - indices_dtype, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); - DenseTensorMeta values_meta(x.dtype(), - {out_non_zero_num, kernel_sizes[4]}, - x.non_zero_elements().layout()); - phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); - phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); - - IntT* out_indices_ptr = out_indices.data(); - - config = - phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_non_zero_num, 1); - UpdateIndexKernel<<>>( - unique_key_ptr, - unique_value_ptr, - out_index_ptr, - out_non_zero_num, - rulebook_len, - d_out_dims, - out_indices_ptr, - rulebook_ptr + 2 * rulebook_len); - out->SetMember(out_indices, out_values, out_dims, true); + // 5. update out_indices and rulebook by unique_value_ptr + const int64_t sparse_dim = 4; + DenseTensorMeta indices_meta( + indices_dtype, {sparse_dim, out_non_zero_num}, DataLayout::NCHW); + DenseTensorMeta values_meta(x.dtype(), + {out_non_zero_num, kernel_sizes[4]}, + x.non_zero_elements().layout()); + phi::DenseTensor out_indices = phi::Empty(dev_ctx, std::move(indices_meta)); + phi::DenseTensor out_values = phi::Empty(dev_ctx, std::move(values_meta)); + + IntT* out_indices_ptr = out_indices.data(); + + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_non_zero_num, 1); + UpdateIndexKernel<<>>( + unique_key_ptr, + unique_value_ptr, + out_index_ptr, + out_non_zero_num, + rulebook_len, + d_out_dims, + out_indices_ptr, + rulebook_ptr + 2 * rulebook_len); + out->SetMember(out_indices, out_values, out_dims, true); + } else { + DenseTensor out_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor out_values = + phi::Empty(dev_ctx, + DenseTensorMeta(x.dtype(), + {x.nnz(), kernel_sizes[4]}, + x.non_zero_elements().layout())); + phi::Copy( + dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices); + out->SetMember(out_indices, out_values, out_dims, true); + } return rulebook_len; } diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 6c37f759923c33e96a3e164ec2ac0a704d670ee3..ed9579fcd5b672e37c6aa6cbbb1a5c5835a342fa 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -171,16 +171,16 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx, T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels; const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels; - T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * out_channels; + T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels; T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels; // call gemm: d_kernel = transpose(x) * out_grad // (in_channels, n) * (n, out_channels) blas.GEMM(CblasTrans, CblasNoTrans, - M, - N, K, + N, + M, static_cast(1), tmp_in_ptr, tmp_out_grad_ptr, diff --git a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 83f19ce5785df4ba11513e9fe2d0220505ca0f6d..93da65dc0f7d8c630d1348b6ed1c31c7372973f6 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" @@ -45,8 +46,17 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx, for (int i = 0; i < kernel_dims.size(); i++) { kernel_sizes[i] = kernel_dims[i]; } + + std::vector subm_paddings(paddings), subm_strides(strides); + if (subm) { + // the out shape of subm_conv is same as input shape + // reset the padding=kernel_size/2 and strides=1 + phi::funcs::sparse::ResetSubmKernelSizeAndStrides( + kernel.dims(), &subm_paddings, &subm_strides); + } + phi::funcs::sparse::GetOutShape( - x_dims, kernel_sizes, paddings, dilations, strides, &out_dims); + x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; std::vector offsets(kernel_size + 1), h_counter(kernel_size); @@ -64,11 +74,6 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx, DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); - std::vector subm_paddings(paddings), subm_strides(strides); - if (subm) { - phi::funcs::sparse::ResetSubmKernelSizeAndStrides( - kernel.dims(), &subm_paddings, &subm_strides); - } int n = ProductRuleBook(dev_ctx, x, kernel_sizes, @@ -147,18 +152,34 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx, } // 4. scatter - config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, out->nnz() * out_channels, 1); - ScatterKernel<<>>(out_features_ptr, - unique_value.data(), - out_index.data(), - out->nnz(), - n, - out_channels, - out_values_ptr); + if (subm) { + set_zero(dev_ctx, out_values, static_cast(0.0f)); + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1); + phi::funcs::ScatterCUDAKernel<<>>( + out_features_ptr, + rulebook_ptr + 2 * n, + out_values_ptr, + n, + out_channels, + false); + } else { + config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, out->nnz() * out_channels, 1); + ScatterKernel<<>>(out_features_ptr, + unique_value.data(), + out_index.data(), + out->nnz(), + n, + out_channels, + out_values_ptr); + } } /** * x: (N, D, H, W, C) diff --git a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py index d5a61423e9c440c8f2bb80a06a57676caf368d19..42f628c8fb1fd93271cc4bb2e7c80f7c1569a11d 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py @@ -40,14 +40,76 @@ class TestSparseConv(unittest.TestCase): correct_out_values = [[4], [10]] sparse_input = core.eager.sparse_coo_tensor(indices, values, dense_shape, False) - out = _C_ops.final_state_sparse_conv3d(sparse_input, dense_kernel, - paddings, dilations, strides, - 1, False) + out = paddle.sparse.functional.conv3d( + sparse_input, + dense_kernel, + bias=None, + stride=strides, + padding=paddings, + dilation=dilations, + groups=1, + data_format="NDHWC") out.backward(out) - #At present, only backward can be verified to work normally - #TODO(zhangkaihuo): compare the result with dense conv - print(sparse_input.grad.values()) assert np.array_equal(correct_out_values, out.values().numpy()) + def test_subm_conv3d(self): + with _test_eager_guard(): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + sparse_x = paddle.sparse.sparse_coo_tensor( + indices, values, dense_shape, stop_gradient=True) + weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32') + y = paddle.sparse.functional.subm_conv3d(sparse_x, weight) + assert np.array_equal(sparse_x.indices().numpy(), + y.indices().numpy()) + + def test_Conv3D(self): + with _test_eager_guard(): + #(4, non_zero_num), 4-D:(N, D, H, W) + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + #(non_zero_num, C) + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + correct_out_values = [[4], [10]] + sparse_input = paddle.sparse.sparse_coo_tensor(indices, values, + dense_shape, False) + + sparse_conv3d = paddle.sparse.Conv3D( + 1, 1, (1, 3, 3), data_format='NDHWC') + sparse_out = sparse_conv3d(sparse_input) + #test errors + with self.assertRaises(ValueError): + #Currently, only support data_format='NDHWC' + conv3d = paddle.sparse.SubmConv3D( + 1, 1, (1, 3, 3), data_format='NCDHW') + + def test_SubmConv3D(self): + with _test_eager_guard(): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + correct_out_values = [[4], [10]] + sparse_input = paddle.sparse.sparse_coo_tensor(indices, values, + dense_shape, False) + + subm_conv3d = paddle.sparse.SubmConv3D( + 1, 1, (1, 3, 3), data_format='NDHWC') + # test extra_repr + print(subm_conv3d.extra_repr()) + + sparse_out = subm_conv3d(sparse_input) + # the output shape of subm_conv is same as input shape + assert np.array_equal(indices, sparse_out.indices().numpy()) -#TODO: Add more test case + #test errors + with self.assertRaises(ValueError): + #Currently, only support data_format='NDHWC' + conv3d = paddle.sparse.SubmConv3D( + 1, 1, (1, 3, 3), data_format='NCDHW') diff --git a/python/paddle/sparse/__init__.py b/python/paddle/sparse/__init__.py index aff9625469ef2ab1350ba04a6775da37eb12cce3..5e716d69379ed4275191e9f43b1947a6b3ea47f3 100644 --- a/python/paddle/sparse/__init__.py +++ b/python/paddle/sparse/__init__.py @@ -15,5 +15,9 @@ from .creation import sparse_coo_tensor from .creation import sparse_csr_tensor from .layer.activation import ReLU +from .layer.conv import Conv3D +from .layer.conv import SubmConv3D -__all__ = ['sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU'] +__all__ = [ + 'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D' +] diff --git a/python/paddle/sparse/functional/__init__.py b/python/paddle/sparse/functional/__init__.py index f4c5b33a5a7eaa99c2158dfaf2d49cf4df912099..93c3ccda4a6136eedee9fb79b9a9fe4c8a86c7b2 100644 --- a/python/paddle/sparse/functional/__init__.py +++ b/python/paddle/sparse/functional/__init__.py @@ -13,5 +13,7 @@ # limitations under the License. from .activation import relu # noqa: F401 +from .conv import conv3d # noqa: F401 +from .conv import subm_conv3d # noqa: F401 -__all__ = ['relu'] +__all__ = ['relu', 'conv3d', 'subm_conv3d'] diff --git a/python/paddle/sparse/functional/conv.py b/python/paddle/sparse/functional/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c0e5c914ccb0a3779375a33efd815f3d977741 --- /dev/null +++ b/python/paddle/sparse/functional/conv.py @@ -0,0 +1,294 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [] + +from paddle import _C_ops, in_dynamic_mode +from ...fluid.layers.utils import convert_to_list +from paddle.nn.functional.conv import _update_padding_nd + + +def _conv3d(x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + subm=False, + data_format="NDHWC", + name=None): + assert in_dynamic_mode(), "Currently, only support dynamic mode" + assert bias == None, "Currently, sparse_conv3d does not support bias" + assert groups == 1, "Currently, only support groups=1" + + dims = 3 + + # Currently, only support 'NDHWC' + if data_format not in ["NDHWC"]: + raise ValueError("Attr(data_format) should be 'NDHWC'. Received " + "Attr(data_format): {}.".format(data_format)) + if len(x.shape) != 5: + raise ValueError( + "Input x should be 5D tensor, but received x with the shape of {}". + format(x.shape)) + + channel_last = (data_format == "NDHWC") + channel_dim = -1 if channel_last else 1 + if len(x.shape) != 5: + raise ValueError( + "Input x should be 5D tensor, but received x with the shape of {}". + format(x.shape)) + num_channels = x.shape[channel_dim] + if num_channels < 0: + raise ValueError( + "The channel dimension of the input({}) should be defined. " + "Received: {}.".format(x.shape, num_channels)) + + padding, padding_algorithm = _update_padding_nd(padding, channel_last, dims) + stride = convert_to_list(stride, dims, 'stride') + dilation = convert_to_list(dilation, dims, 'dilation') + op_type = "conv3d" + + return _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation, + stride, groups, subm) + + +def conv3d(x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + data_format="NDHWC", + name=None): + r""" + + The sparse convolution3d functional calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input(Input) and + Output(Output) are multidimensional SparseCooTensors with a shape of + :math:`[N, D, H, W, C]` . Where N is batch size, C is the number of + channels, D is the depth of the feature, H is the height of the feature, + and W is the width of the feature. If bias attribution is provided, + bias is added to the output of the convolution. + + For each input :math:`X`, the equation is: + + .. math:: + + Out = \sigma (W \ast X + b) + + In the above equation: + + * :math:`X`: Input value, a tensor with NCDHW or NDHWC format. + * :math:`W`: Filter value, a tensor with MCDHW format. + * :math:`\\ast`: Convolution operation. + * :math:`b`: Bias value, a 1-D tensor with shape [M]. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Example: + + - Input: + + Input shape: :math:`(N, D_{in}, H_{in}, W_{in}, C_{in})` + + Filter shape: :math:`(D_f, H_f, W_f, C_{in}, C_{out})` + + - Output: + Output shape: :math:`(N, D_{out}, H_{out}, W_{out}, C_{out})` + + Where + + .. math:: + + D_{out}&= \\frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{strides[0]} + 1 \\\\ + H_{out}&= \\frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{strides[1]} + 1 \\\\ + W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1 + + Args: + x (Tensor): The input is 5-D SparseCooTensor with shape [N, D, H, W, C], the data + type of input is float16 or float32 or float64. + weight (Tensor): The convolution kernel, a Tensor with shape [kD, kH, kW, C/g, M], + where M is the number of filters(output channels), g is the number of groups, + kD, kH, kW are the filter's depth, height and width respectively. + bias (Tensor, optional): The bias, a Tensor of shape [M, ], currently, only support bias is None. + stride (int|list|tuple): The stride size. It means the stride in convolution. If stride is a + list/tuple, it must contain three integers, (stride_depth, stride_height, stride_width). + Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1. + padding (string|int|list|tuple): The padding size. It means the number of zero-paddings + on both sides for each dimension. If `padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NDHWC"`, `padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. + dilation (int|list|tuple): The dilation size. It means the spacing between the kernel points. + If dilation is a list/tuple, it must contain three integers, (dilation_depth, dilation_height, + dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation. + Default: dilation = 1. + groups (int): The groups number of the Conv3D Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Default: groups=1. Currently, only support groups=1. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`. + The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of: + `[batch_size, input_depth, input_height, input_width, input_channels]`. + name(str|None): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + A SparseCooTensor representing the conv3d, whose data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + with _test_eager_guard(): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True) + weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32') + y = paddle.sparse.functional.conv3d(sparse_x, weight) + print(y.shape) + # (1, 1, 1, 2, 1) + """ + return _conv3d(x, weight, bias, stride, padding, dilation, groups, False, + data_format, name) + + +def subm_conv3d(x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + data_format="NDHWC", + name=None): + r""" + + The sparse submanifold convolution3d functional calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input(Input) and + Output(Output) are multidimensional SparseCooTensors with a shape of + :math:`[N, D, H, W, C]` . Where N is batch size, C is the number of + channels, D is the depth of the feature, H is the height of the feature, + and W is the width of the feature. If bias attribution is provided, + bias is added to the output of the convolution. + + For each input :math:`X`, the equation is: + + .. math:: + + Out = W \ast X + b + + In the above equation: + + * :math:`X`: Input value, a tensor with NCDHW or NDHWC format. + * :math:`W`: Filter value, a tensor with DHWCM format. + * :math:`\\ast`: Submanifold Convolution operation, refer to the paper: https://arxiv.org/abs/1706.01307. + * :math:`b`: Bias value, a 1-D tensor with shape [M]. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Example: + + - Input: + + Input shape: :math:`(N, D_{in}, H_{in}, W_{in}, C_{in})` + + Filter shape: :math:`(D_f, H_f, W_f, C_{in}, C_{out})` + + - Output: + Output shape: :math:`(N, D_{out}, H_{out}, W_{out}, C_{out})` + + Where + + .. math:: + + D_{out}&= \\frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{strides[0]} + 1 \\\\ + H_{out}&= \\frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{strides[1]} + 1 \\\\ + W_{out}&= \\frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{strides[2]} + 1 + + Args: + x (Tensor): The input is 5-D SparseCooTensor with shape [N, D, H, W, C], the data + type of input is float16 or float32 or float64. + weight (Tensor): The convolution kernel, a Tensor with shape [kD, kH, kW, C/g, M], + where M is the number of filters(output channels), g is the number of groups, + kD, kH, kW are the filter's depth, height and width respectively. + bias (Tensor, optional): The bias, a Tensor of shape [M, ], currently, only support bias is None. + stride (int|list|tuple): The stride size. It means the stride in convolution. If stride is a + list/tuple, it must contain three integers, (stride_depth, stride_height, stride_width). + Otherwise, stride_depth = stride_height = stride_width = stride. Default: stride = 1. + padding (string|int|list|tuple): The padding size. It means the number of zero-paddings + on both sides for each dimension. If `padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Default: padding = 0. + dilation (int|list|tuple): The dilation size. It means the spacing between the kernel points. + If dilation is a list/tuple, it must contain three integers, (dilation_depth, dilation_height, + dilation_width). Otherwise, dilation_depth = dilation_height = dilation_width = dilation. + Default: dilation = 1. + groups (int): The groups number of the Conv3D Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. Currently, only support groups=1. + data_format (str, optional): Specify the data format of the input, and the data format of the output + will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`. + The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of: + `[batch_size, input_depth, input_height, input_width, input_channels]`. + name(str|None): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + A SparseCooTensor representing the conv3d, whose data type is + the same with input. + + Examples: + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + with _test_eager_guard(): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True) + weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32') + y = paddle.sparse.functional.subm_conv3d(sparse_x, weight) + print(y.shape) + #(1, 1, 3, 4, 1) + """ + return _conv3d(x, weight, bias, stride, padding, dilation, groups, True, + data_format, name) diff --git a/python/paddle/sparse/layer/__init__.py b/python/paddle/sparse/layer/__init__.py index 66abce260b6f7dc0b98df0c10a1dafd988b2ad5c..a0f9d068e677c183e9c004285d78e1941555318b 100644 --- a/python/paddle/sparse/layer/__init__.py +++ b/python/paddle/sparse/layer/__init__.py @@ -13,5 +13,7 @@ # limitations under the License. from .activation import ReLU +from .conv import Conv3D +from .conv import SubmConv3D __all__ = [] diff --git a/python/paddle/sparse/layer/conv.py b/python/paddle/sparse/layer/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ff421a06a1344888c4c6440d50b3df03388b8607 --- /dev/null +++ b/python/paddle/sparse/layer/conv.py @@ -0,0 +1,380 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from .. import functional as F +from paddle.nn import Layer +from paddle.nn.initializer import Normal +from ..functional.conv import _update_padding_nd +from ...fluid.layers import utils + +__all__ = [] + + +class _Conv3D(Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + subm=False, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format="NDHWC"): + super(_Conv3D, self).__init__() + assert weight_attr is not False, "weight_attr should not be False in Conv." + self._param_attr = weight_attr + self._bias_attr = bias_attr + self._groups = groups + self._in_channels = in_channels + self._out_channels = out_channels + self._data_format = data_format + self._subm = subm + + assert padding_mode == 'zeros', "Currently, only support padding_mode='zeros'" + assert groups == 1, "Currently, only support groups=1" + + valid_format = {'NDHWC'} + if data_format not in valid_format: + raise ValueError( + "data_format must be one of {}, but got data_format='{}'". + format(valid_format, data_format)) + + channel_last = data_format == "NDHWC" + + dims = 3 + self._stride = utils.convert_to_list(stride, dims, 'stride') + self._dilation = utils.convert_to_list(dilation, dims, 'dilation') + self._kernel_size = utils.convert_to_list(kernel_size, dims, + 'kernel_size') + self._padding = padding + self._padding_mode = padding_mode + self._updated_padding, self._padding_algorithm = _update_padding_nd( + padding, channel_last, dims) + + # the sparse conv restricts the shape is [D, H, W, in_channels, out_channels] + filter_shape = self._kernel_size + [ + self._in_channels, self._out_channels + ] + + def _get_default_param_initializer(): + filter_elem_num = np.prod(self._kernel_size) * self._in_channels + std = (2.0 / filter_elem_num)**0.5 + return Normal(0.0, std) + + self.weight = self.create_parameter( + shape=filter_shape, + attr=self._param_attr, + default_initializer=_get_default_param_initializer()) + #self.bias = self.create_parameter( + # attr=self._bias_attr, shape=[self._out_channels], is_bias=True) + self.bias = None + + def forward(self, x): + out = F.conv._conv3d( + x, + self.weight, + bias=self.bias, + stride=self._stride, + padding=self._updated_padding, + dilation=self._dilation, + groups=self._groups, + subm=self._subm, + data_format=self._data_format) + return out + + def extra_repr(self): + main_str = '{_in_channels}, {_out_channels}, kernel_size={_kernel_size}' + if self._stride != [1] * len(self._stride): + main_str += ', stride={_stride}' + if self._padding != 0: + main_str += ', padding={_padding}' + if self._padding_mode != 'zeros': + main_str += ', padding_mode={_padding_mode}' + if self._dilation != [1] * len(self._dilation): + main_str += ', dilation={_dilation}' + if self._groups != 1: + main_str += ', groups={_groups}' + main_str += ', data_format={_data_format}' + return main_str.format(**self.__dict__) + + +class Conv3D(_Conv3D): + r""" + **Sparse Convlution3d Layer** + The Sparse convolution3d layer calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input(Input) and + Output(Output) are multidimensional SparseCooTensors with a shape of + :math:`[N, D, H, W, C]` . Where N is batch size, C is the number of + channels, D is the depth of the feature, H is the height of the feature, + and W is the width of the feature. If bias attribution is provided, + bias is added to the output of the convolution. + For each input :math:`X`, the equation is: + + .. math:: + + Out = W \ast X + b + + In the above equation: + + * :math:`X`: Input value, a tensor with NDHWC format. + * :math:`W`: Filter value, a tensor with DHWCM format. + * :math:`\\ast`: Convolution operation. + * :math:`b`: Bias value, a 1-D tensor with shape [M]. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Parameters: + in_channels(int): The number of input channels in the input image. + out_channels(int): The number of output channels produced by the convolution. + kernel_size(int|list|tuple, optional): The size of the convolving kernel. + stride(int|list|tuple, optional): The stride size. If stride is a list/tuple, it must + contain three integers, (stride_D, stride_H, stride_W). Otherwise, the + stride_D = stride_H = stride_W = stride. The default value is 1. + padding(int|str|tuple|list, optional): The padding size. Padding coule be in one of the following forms. + 1. a string in ['valid', 'same']. + 2. an int, which means each spartial dimension(depth, height, width) is zero paded by size of `padding` + 3. a list[int] or tuple[int] whose length is the number of spartial dimensions, which contains the amount of padding on each side for each spartial dimension. It has the form [pad_d1, pad_d2, ...]. + 4. a list[int] or tuple[int] whose length is 2 * number of spartial dimensions. It has the form [pad_before, pad_after, pad_before, pad_after, ...] for all spartial dimensions. + 5. a list or tuple of pairs of ints. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension are also included. Each pair of integers correspond to the amount of padding for a dimension of the input. Padding in batch dimension and channel dimension should be [0, 0] or (0, 0). + The default value is 0. + dilation(int|list|tuple, optional): The dilation size. If dilation is a list/tuple, it must + contain three integers, (dilation_D, dilation_H, dilation_W). Otherwise, the + dilation_D = dilation_H = dilation_W = dilation. The default value is 1. + groups(int, optional): The groups number of the Conv3D Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. The default value is 1, currently, only support groups=1. + padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``. + weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights + of conv3d. If it is set to None or one attribute of ParamAttr, conv3d + will create ParamAttr as param_attr. If it is set to None, the parameter + is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is + :math:`(\frac{2.0 }{filter\_elem\_num})^{0.5}`. The default value is None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of conv3d. + If it is set to False, no bias will be added to the output units. + If it is set to None or one attribute of ParamAttr, conv3d + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. The default value is None. + data_format(str, optional): Data format that specifies the layout of input. + It can be "NCDHW" or "NDHWC". Currently, only support "NCDHW". + + Attribute: + + **weight** (Parameter): the learnable weights of filters of this layer. + + **bias** (Parameter): the learnable bias of this layer. + + Shape: + + - x: :math:`(N, D_{in}, H_{in}, W_{in}, C_{in})` + + - weight: :math:`(K_{d}, K_{h}, K_{w}, C_{in}, C_{out})` + + - bias: :math:`(C_{out})` + + - output: :math:`(N, D_{out}, H_{out}, W_{out}, C_{out})` + + Where + + .. math:: + + D_{out}&= \frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (kernel\_size[0] - 1) + 1))}{strides[0]} + 1 + + H_{out}&= \frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (kernel\_size[1] - 1) + 1))}{strides[1]} + 1 + + W_{out}&= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (kernel\_size[2] - 1) + 1))}{strides[2]} + 1 + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + with _test_eager_guard(): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + dense_shape = [1, 1, 3, 4, 1] + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True) + conv = paddle.sparse.Conv3D(1, 1, (1, 3, 3)) + y = conv(sparse_x) + print(y.shape) + # (1, 1, 1, 2, 1) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format="NDHWC"): + super(Conv3D, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + subm=False, + padding_mode=padding_mode, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format) + + +class SubmConv3D(_Conv3D): + r""" + **Sparse Submanifold Convlution3d Layer** + The Sparse submanifold convolution3d layer calculates the output based on the input, filter + and strides, paddings, dilations, groups parameters. Input(Input) and + Output(Output) are multidimensional SparseCooTensors with a shape of + :math:`[N, D, H, W, C]` . Where N is batch size, C is the number of + channels, D is the depth of the feature, H is the height of the feature, + and W is the width of the feature. If bias attribution is provided, + bias is added to the output of the convolution. + For each input :math:`X`, the equation is: + + .. math:: + + Out =(W \ast X + b + + In the above equation: + + * :math:`X`: Input value, a tensor with NDHWC format. + * :math:`W`: Filter value, a tensor with DHWCM format. + * :math:`\\ast`: Submanifold Convolution operation, refer to the paper: https://arxiv.org/abs/1706.01307. + * :math:`b`: Bias value, a 1-D tensor with shape [M]. + * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. + + Parameters: + in_channels(int): The number of input channels in the input image. + out_channels(int): The number of output channels produced by the convolution. + kernel_size(int|list|tuple, optional): The size of the convolving kernel. + stride(int|list|tuple, optional): The stride size. If stride is a list/tuple, it must + contain three integers, (stride_D, stride_H, stride_W). Otherwise, the + stride_D = stride_H = stride_W = stride. The default value is 1. + padding(int|str|tuple|list, optional): The padding size. Padding coule be in one of the following forms. + 1. a string in ['valid', 'same']. + 2. an int, which means each spartial dimension(depth, height, width) is zero paded by size of `padding` + 3. a list[int] or tuple[int] whose length is the number of spartial dimensions, which contains the amount of padding on each side for each spartial dimension. It has the form [pad_d1, pad_d2, ...]. + 4. a list[int] or tuple[int] whose length is 2 * number of spartial dimensions. It has the form [pad_before, pad_after, pad_before, pad_after, ...] for all spartial dimensions. + 5. a list or tuple of pairs of ints. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension are also included. Each pair of integers correspond to the amount of padding for a dimension of the input. Padding in batch dimension and channel dimension should be [0, 0] or (0, 0). + The default value is 0. + dilation(int|list|tuple, optional): The dilation size. If dilation is a list/tuple, it must + contain three integers, (dilation_D, dilation_H, dilation_W). Otherwise, the + dilation_D = dilation_H = dilation_W = dilation. The default value is 1. + groups(int, optional): The groups number of the Conv3D Layer. According to grouped + convolution in Alex Krizhevsky's Deep CNN paper: when group=2, + the first half of the filters is only connected to the first half + of the input channels, while the second half of the filters is only + connected to the second half of the input channels. The default value is 1. + padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``. + weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights + of conv3d. If it is set to None or one attribute of ParamAttr, conv3d + will create ParamAttr as param_attr. If it is set to None, the parameter + is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is + :math:`(\frac{2.0 }{filter\_elem\_num})^{0.5}`. The default value is None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of conv3d. + If it is set to False, no bias will be added to the output units. + If it is set to None or one attribute of ParamAttr, conv3d + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. The default value is None. + data_format(str, optional): Data format that specifies the layout of input. + It can be "NCDHW" or "NDHWC". Currently, only support "NCDHW". + + Attribute: + + **weight** (Parameter): the learnable weights of filters of this layer. + + **bias** (Parameter): the learnable bias of this layer. + + Shape: + + - x: :math:`(N, D_{in}, H_{in}, W_{in}, C_{in})` + + - weight: :math:`(K_{d}, K_{h}, K_{w}, C_{in}, C_{out})` + + - bias: :math:`(C_{out})` + + - output: :math:`(N, D_{out}, H_{out}, W_{out}, C_{out})` + + Where + + .. math:: + + D_{out}&= \frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (kernel\_size[0] - 1) + 1))}{strides[0]} + 1 + + H_{out}&= \frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (kernel\_size[1] - 1) + 1))}{strides[1]} + 1 + + W_{out}&= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (kernel\_size[2] - 1) + 1))}{strides[2]} + 1 + + Examples: + + .. code-block:: python + + import paddle + from paddle.fluid.framework import _test_eager_guard + + with _test_eager_guard(): + indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]] + values = [[1], [2], [3], [4]] + dense_shape = [1, 1, 3, 4, 1] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor(values, dtype='float32') + sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True) + subm_conv = paddle.sparse.SubmConv3D(1, 1, (1, 3, 3)) + y = subm_conv(sparse_x) + print(y.shape) + # (1, 1, 3, 4, 1) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode='zeros', + weight_attr=None, + bias_attr=None, + data_format="NDHWC"): + super(SubmConv3D, self).__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + subm=True, + padding_mode=padding_mode, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format)