Unverified commit 18e9aafb authored by zhangkaihuo, committed by GitHub

Add Sparse MaxPool3D (#42130)

Parent c7302f96
...
@@ -256,9 +256,11 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
   }
   const int64_t dense_dim = values.dims().size() - 1;

-  const auto place = dev_ctx.GetPlace();
   const T* x_data = values.data<T>();
-  T* out_data = out->mutable_data<T>(place);
+  *out = phi::Empty(
+      dev_ctx,
+      DenseTensorMeta(x.dtype(), x.dims(), x.non_zero_elements().layout()));
+  T* out_data = out->data<T>();
   int64_t base_offset = 1;
   for (int64_t i = 0; i < dense_dim; i++) {
     base_offset *= dense_dims[sparse_dim + i];
...
...
@@ -104,7 +104,7 @@ void MaxPoolGPUKernel(const GPUContext& dev_ctx,
 #endif
                  out_features_ptr,
                  out_features_ptr + out->non_zero_elements().numel(),
-                 static_cast<T>(-FLT_MAX));
+                 static_cast<T>(0));
   // TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster
   for (int i = 0; i < kernel_size; i++) {
     if (counter[i] <= 0) {
...
...
@@ -503,7 +503,10 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
   const auto place = dev_ctx.GetPlace();

   const T* x_data = values.data<T>();
-  T* out_data = out->mutable_data<T>(place);
+  *out = phi::Empty(dev_ctx,
+                    phi::DenseTensorMeta(
+                        x.dtype(), x.dims(), x.non_zero_elements().layout()));
+  T* out_data = out->data<T>();
   int64_t base_offset = 1;
   for (int64_t i = 0; i < dense_dim; i++) {
     base_offset *= dense_dims[sparse_dim + i];
...
...
@@ -110,7 +110,7 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 DenseTensor SparseCooToDense(const Context& dev_ctx, const SparseCooTensor& x) {
-  DenseTensorMeta meta(x.dtype(), x.dims(), x.layout());
+  DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
   DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
   SparseCooToDenseKernel<T, Context>(dev_ctx, x, &dense);
   return dense;
@@ -129,7 +129,7 @@ void SparseCsrToDenseKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 DenseTensor SparseCsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
-  DenseTensorMeta meta(x.dtype(), x.dims(), x.layout());
+  DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
   DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
   SparseCsrToDenseKernel<T, Context>(dev_ctx, x, &dense);
   return dense;
...
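Both conversion helpers above now derive the dense output's meta from the sparse input and allocate it with phi::Empty before the kernel scatters the values. A minimal Python round-trip sketch of the conversion path these helpers back (a sketch only, assuming the eager guard that this change's unit tests use):

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard

    with _test_eager_guard():
        dense_x = paddle.randn((1, 4, 4, 4, 3))
        # keep the four leading axes sparse; the trailing axis holds dense values
        sparse_x = dense_x.to_sparse_coo(4)
        # SparseCooToDense allocates the dense output and scatters the values back
        assert np.allclose(dense_x.numpy(), sparse_x.to_dense().numpy())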
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle import _C_ops
from paddle.fluid.framework import _test_eager_guard


class TestMaxPool3DFunc(unittest.TestCase):
    def setInput(self):
        paddle.seed(0)
        self.dense_x = paddle.randn((1, 4, 4, 4, 4))

    def setKernelSize(self):
        self.kernel_sizes = [3, 3, 3]

    def setStride(self):
        self.strides = [1, 1, 1]

    def setPadding(self):
        self.paddings = [0, 0, 0]

    def setUp(self):
        self.setInput()
        self.setKernelSize()
        self.setStride()
        self.setPadding()

    def test(self):
        with _test_eager_guard():
            self.setUp()
            sparse_x = self.dense_x.to_sparse_coo(4)
            out = paddle.sparse.functional.max_pool3d(
                sparse_x,
                self.kernel_sizes,
                stride=self.strides,
                padding=self.paddings)
            out = out.to_dense()

            dense_out = paddle.nn.functional.max_pool3d(
                self.dense_x,
                self.kernel_sizes,
                stride=self.strides,
                padding=self.paddings,
                data_format='NDHWC')
            # compare with dense
            assert np.allclose(dense_out.flatten().numpy(),
                               out.flatten().numpy())


class TestStride(TestMaxPool3DFunc):
    def setStride(self):
        self.strides = 1


class TestPadding(TestMaxPool3DFunc):
    def setPadding(self):
        self.paddings = 1

    def setInput(self):
        self.dense_x = paddle.randn((1, 5, 6, 8, 3))


class TestKernelSize(TestMaxPool3DFunc):
    def setKernelSize(self):
        self.kernel_sizes = [5, 5, 5]

    def setInput(self):
        paddle.seed(0)
        self.dense_x = paddle.randn((1, 6, 9, 6, 3))


class TestInput(TestMaxPool3DFunc):
    def setInput(self):
        paddle.seed(0)
        self.dense_x = paddle.randn((2, 6, 7, 9, 3))
        dropout = paddle.nn.Dropout(0.8)
        self.dense_x = dropout(self.dense_x)


class TestMaxPool3DAPI(unittest.TestCase):
    def test(self):
        with _test_eager_guard():
            dense_x = paddle.randn((2, 3, 6, 6, 3))
            sparse_x = dense_x.to_sparse_coo(4)
            max_pool3d = paddle.sparse.MaxPool3D(
                kernel_size=3, data_format='NDHWC')
            out = max_pool3d(sparse_x)
            out = out.to_dense()

            dense_out = paddle.nn.functional.max_pool3d(
                dense_x, 3, data_format='NDHWC')
            assert np.allclose(dense_out.numpy(), out.numpy())


if __name__ == "__main__":
    unittest.main()
...
@@ -20,7 +20,9 @@ from .layer.norm import BatchNorm
 from .layer.conv import Conv3D
 from .layer.conv import SubmConv3D
+from .layer.pooling import MaxPool3D

 __all__ = [
     'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D',
-    'BatchNorm'
+    'BatchNorm', 'MaxPool3D'
 ]
...
...
@@ -15,5 +15,6 @@
 from .activation import relu  # noqa: F401
 from .conv import conv3d  # noqa: F401
 from .conv import subm_conv3d  # noqa: F401
+from .pooling import max_pool3d  # noqa: F401

-__all__ = ['relu', 'conv3d', 'subm_conv3d']
+__all__ = ['relu', 'conv3d', 'subm_conv3d', 'max_pool3d']
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...fluid.layers import utils
from paddle import _C_ops, in_dynamic_mode
from paddle.nn.functional.pooling import _update_padding_nd

__all__ = []


def max_pool3d(x,
               kernel_size,
               stride=None,
               padding=0,
               ceil_mode=False,
               data_format="NDHWC",
               name=None):
    """
    Implements sparse max pooling 3d operation.
    See more details in :ref:`api_sparse_pooling_MaxPool3d` .

    Args:
        x (Tensor): The input SparseCooTensor of the pooling operator, which is a 5-D tensor with
            shape [N, D, H, W, C]. The format of the input tensor is `"NDHWC"`, where N represents batch size,
            C represents the number of channels, and D, H and W represent the depth, height and width of the feature respectively.
        kernel_size (int|list|tuple): The pool kernel size. If the kernel size
            is a tuple or list, it must contain three integers,
            (kernel_size_Depth, kernel_size_Height, kernel_size_Width).
            Otherwise, the pool kernel size will be the cube of an int.
        stride (int|list|tuple): The pool stride size. If the pool stride size is a tuple or list,
            it must contain three integers, (stride_Depth, stride_Height, stride_Width).
            Otherwise, the pool stride size will be the cube of an int.
        padding (string|int|list|tuple): The padding size. Padding could be in one of the following forms.
            1. A string in ['valid', 'same'].
            2. An int, which means the feature map is zero padded by size of `padding` on every side.
            3. A list[int] or tuple(int) whose length is 3, [pad_depth, pad_height, pad_width], whose value means the padding size of each dimension.
            4. A list[int] or tuple(int) whose length is 6, [pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right], whose value means the padding size of each side.
            5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that the batch dimension and channel dimension should be [0,0] or (0,0).
            The default value is 0.
        ceil_mode (bool): ${ceil_mode_comment}
        data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`.
            The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of:
            `[batch_size, input_depth, input_height, input_width, input_channels]`. Currently only `"NDHWC"` is supported.
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and
            is None by default.

    Returns:
        Tensor: The output tensor of the pooling result. The data type is the same as the input tensor.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.fluid.framework import _test_eager_guard

            with _test_eager_guard():
                dense_x = paddle.randn((1, 4, 4, 4, 3))
                sparse_x = dense_x.to_sparse_coo(4)
                kernel_sizes = [3, 3, 3]
                paddings = [0, 0, 0]
                strides = [1, 1, 1]
                out = paddle.sparse.functional.max_pool3d(sparse_x, kernel_sizes, stride=strides, padding=paddings)
                #[1, 2, 2, 2, 3]
    """

    assert in_dynamic_mode(), "Currently, Sparse API only supports dynamic mode"
    assert x.is_sparse_coo(
    ), "Currently, sparse.max_pool3d only supports the input of SparseCooTensor"
    assert data_format == 'NDHWC', "Currently, sparse.max_pool3d only supports data format of 'NDHWC'"

    kernel_size = utils.convert_to_list(kernel_size, 3, 'pool_size')
    if stride is None:
        stride = kernel_size
    else:
        stride = utils.convert_to_list(stride, 3, 'pool_stride')

    channel_last = True

    padding, padding_algorithm = _update_padding_nd(
        padding, 3, channel_last=channel_last, ceil_mode=ceil_mode)

    # TODO(zkh2016): remove the dependency on dilation from the backend
    dilation = [1, 1, 1]

    return _C_ops.final_state_sparse_maxpool(x, kernel_size, padding, dilation,
                                             stride)
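As a usage sketch mirroring the unit tests in this change, the sparse functional result can be checked against the dense paddle.nn.functional.max_pool3d it is meant to match (the eager guard is assumed, as elsewhere in this patch):

    import numpy as np
    import paddle
    from paddle.fluid.framework import _test_eager_guard

    with _test_eager_guard():
        dense_x = paddle.randn((1, 4, 4, 4, 4))
        sparse_x = dense_x.to_sparse_coo(4)
        sparse_out = paddle.sparse.functional.max_pool3d(
            sparse_x, [3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0])
        dense_out = paddle.nn.functional.max_pool3d(
            dense_x, [3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0],
            data_format='NDHWC')
        # the sparse and dense paths should agree elementwise
        assert np.allclose(dense_out.numpy(), sparse_out.to_dense().numpy())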
...
@@ -16,5 +16,6 @@ from .activation import ReLU
 from .norm import BatchNorm
 from .conv import Conv3D
 from .conv import SubmConv3D
+from .pooling import MaxPool3D

 __all__ = []
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.nn import Layer
from .. import functional as F


class MaxPool3D(Layer):
    """
    This operation applies 3D max pooling over the input features based on the sparse input,
    and the kernel_size, stride, padding parameters. Input(X) and Output(Out) are
    in NDHWC format, where N is batch size, C is the number of channels,
    and D, H and W are the depth, height and width of the feature respectively.

    Parameters:
        kernel_size(int|list|tuple): The pool kernel size. If the kernel size
            is a tuple or list, it must contain three integers,
            (kernel_size_Depth, kernel_size_Height, kernel_size_Width).
            Otherwise, the pool kernel size will be the cube of an int.
        stride(int|list|tuple, optional): The pool stride size. If the pool stride size is a tuple or list,
            it must contain three integers, (stride_Depth, stride_Height, stride_Width).
            Otherwise, the pool stride size will be the cube of an int.
            Default None, in which case the stride will be equal to the kernel_size.
        padding(str|int|list|tuple, optional): The padding size. Padding could be in one of the following forms.
            1. A string in ['valid', 'same'].
            2. An int, which means the feature map is zero padded by size of `padding` on every side.
            3. A list[int] or tuple(int) whose length is 3, [pad_depth, pad_height, pad_width], whose value means the padding size of each dimension.
            4. A list[int] or tuple(int) whose length is 6, [pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right], whose value means the padding size of each side.
            5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that the batch dimension and channel dimension should be [0,0] or (0,0).
            The default value is 0.
        ceil_mode(bool, optional): ${ceil_mode_comment}
        return_mask(bool, optional): Whether to return the max indices along with the outputs.
        data_format(str, optional): The data format of the input and output data. An optional string from: `"NCDHW"`,
            `"NDHWC"`. The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of:
            `[batch_size, input_depth, input_height, input_width, input_channels]`. Currently only `"NDHWC"` is supported.
        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`.
            Usually name does not need to be set and is None by default.

    Returns:
        A callable object of MaxPool3D.

    Shape:
        - x(Tensor): The input SparseCooTensor of the max pool3d operator, which is a 5-D tensor.
          The data type can be float32 or float64.
        - output(Tensor): The output tensor of the max pool3d operator, which is a 5-D tensor.
          The data type is the same as the input x.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.fluid.framework import _test_eager_guard

            with _test_eager_guard():
                dense_x = paddle.randn((2, 3, 6, 6, 3))
                sparse_x = dense_x.to_sparse_coo(4)
                max_pool3d = paddle.sparse.MaxPool3D(
                    kernel_size=3, data_format='NDHWC')
                out = max_pool3d(sparse_x)
                #shape=[2, 1, 2, 2, 3]
    """

    def __init__(self,
                 kernel_size,
                 stride=None,
                 padding=0,
                 return_mask=False,
                 ceil_mode=False,
                 data_format="NDHWC",
                 name=None):
        super(MaxPool3D, self).__init__()
        self.ksize = kernel_size
        self.stride = stride
        self.padding = padding
        self.return_mask = return_mask
        self.ceil_mode = ceil_mode
        self.data_format = data_format
        self.name = name

    def forward(self, x):
        return F.max_pool3d(
            x,
            kernel_size=self.ksize,
            stride=self.stride,
            padding=self.padding,
            ceil_mode=self.ceil_mode,
            data_format=self.data_format,
            name=self.name)

    def extra_repr(self):
        return 'kernel_size={ksize}, stride={stride}, padding={padding}'.format(
            **self.__dict__)
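For the shape comment in the example above, a small sanity check of the usual floor-mode pooling arithmetic (an assumed formula for illustration, not code from this patch):

    # Assumed floor-mode output size (ceil_mode=False):
    #   out = (in + 2 * pad - kernel) // stride + 1
    def pooled_size(size, kernel, stride, pad):
        return (size + 2 * pad - kernel) // stride + 1

    # Docstring example: spatial dims (3, 6, 6), kernel_size=3, stride
    # defaulting to kernel_size, padding=0 -> (1, 2, 2), i.e. [2, 1, 2, 2, 3].
    assert [pooled_size(d, 3, 3, 0) for d in (3, 6, 6)] == [1, 2, 2]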
...
@@ -65,3 +65,12 @@
   args : (Tensor x)
   output : Tensor(out@SparseCsrTensor)
   invoke : to_sparse_csr_impl(x)
+
+- api: maxpool
+  args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
+  output : Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
+  kernel :
+    func : sparse_maxpool
+    layout : x
+  intermediate : rulebook
+  backward : sparse_maxpool_grad
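This entry declares the op behind _C_ops.final_state_sparse_maxpool, which pooling.py calls; because rulebook is marked intermediate, only the pooled tensor should be visible from Python. A hypothetical direct call, mirroring the argument order in pooling.py:

    import paddle
    from paddle import _C_ops
    from paddle.fluid.framework import _test_eager_guard

    with _test_eager_guard():
        sparse_x = paddle.randn((1, 4, 4, 4, 3)).to_sparse_coo(4)
        # hypothetical direct invocation; normally reached via
        # paddle.sparse.functional.max_pool3d
        out = _C_ops.final_state_sparse_maxpool(
            sparse_x,   # SparseCooTensor in NDHWC layout
            [3, 3, 3],  # kernel_sizes
            [0, 0, 0],  # paddings
            [1, 1, 1],  # dilations (fixed; see the TODO in pooling.py)
            [1, 1, 1])  # strides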
...
@@ -32,6 +32,13 @@
   output : Tensor(x_grad@DenseTensor)
   invoke : to_dense_impl(out_grad)

+- backward_api : sparse_maxpool_grad
+  forward : sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
+  args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
+  output : Tensor(x_grad@SparseCooTensor)
+  kernel :
+    func : sparse_maxpool_grad
+
 - backward_api : sparse_relu_grad
   forward : sparse_relu(Tensor x) -> Tensor(out@SparseCooTensor)
   args : (Tensor x, Tensor out_grad)
...