Unverified commit 4f25604e, authored by umiswing, committed by GitHub

[Sparse] Support sparse conv 2d. (#54158)

Parent 0df9e4ce
......@@ -28,11 +28,11 @@
namespace phi {
namespace sparse {
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
template <typename T, typename IntT, typename Context>
void ReshapeCooCPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
// TODO(OccupyMars2025): Currently, reshape is only applicable to sparse dims
int64_t x_nnz = x.nnz();
......@@ -48,15 +48,15 @@ void ReshapeCooKernel(const Context& dev_ctx,
for (int i = 0; i < out_dims.size() - x.dense_dim(); ++i) {
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
DenseTensor out_indices = Empty<IntT, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of indices
const DenseTensor& x_indices = x.indices();
const auto* x_indices_data = x_indices.data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const auto* x_indices_data = x_indices.data<IntT>();
auto* out_indices_data = out_indices.data<IntT>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
......@@ -75,6 +75,17 @@ void ReshapeCooKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "ReshapeCooCPUKernel", ([&] {
ReshapeCooCPUKernel<T, data_t, Context>(dev_ctx, x, shape, out);
}));
}
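PD_VISIT_BASE_INTEGRAL_TYPES switches on the runtime dtype of x.indices() (int32 or int64) and instantiates the typed kernel with the matching data_t. A rough Python analogue of that dispatch (hypothetical names, for illustration only):

import numpy as np

def visit_base_integral_types(dtype, kernel_name, fn):
    # Map the runtime dtype to a concrete index type, then call the
    # "templated" kernel with it -- roughly what the C++ macro expands to.
    mapping = {'int32': np.int32, 'int64': np.int64}
    if dtype not in mapping:
        raise TypeError(f"{kernel_name}: unsupported index dtype {dtype}")
    return fn(mapping[dtype])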
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
......
......@@ -26,30 +26,33 @@
namespace phi {
namespace sparse {
__global__ void ReshapeCooCudaKernel(const int64_t* x_indices_data,
template <typename IntT>
__global__ void ReshapeCooCudaKernel(const IntT* x_indices_data,
const int num_x_sparse_part_dims,
const int num_out_sparse_part_dims,
const int64_t x_nnz,
const int64_t* x_sparse_part_strides,
const int64_t* out_sparse_part_strides,
int64_t* out_indices_data) {
IntT* out_indices_data) {
CUDA_KERNEL_LOOP_TYPE(j, x_nnz, int64_t) {
int64_t location = 0;
IntT location = 0;
for (int i = 0; i < num_x_sparse_part_dims; ++i) {
location += x_indices_data[i * x_nnz + j] * x_sparse_part_strides[i];
location += x_indices_data[i * x_nnz + j] *
static_cast<IntT>(x_sparse_part_strides[i]);
}
for (int i = 0; i < num_out_sparse_part_dims; ++i) {
out_indices_data[i * x_nnz + j] = location / out_sparse_part_strides[i];
location %= out_sparse_part_strides[i];
out_indices_data[i * x_nnz + j] =
location / static_cast<IntT>(out_sparse_part_strides[i]);
location %= static_cast<IntT>(out_sparse_part_strides[i]);
}
}
}
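The remapping above (shared by the CPU and GPU paths) is pure stride arithmetic: each nonzero's sparse-part indices are flattened into a linear location under the old strides, then decomposed under the new strides. A minimal NumPy sketch of the same computation, using a hypothetical helper that operates only on the sparse-part dims:

import numpy as np

def reshape_coo_indices(x_indices, x_sparse_dims, out_sparse_dims):
    # Row-major strides, as computed by phi::stride().
    def strides(dims):
        s = np.ones(len(dims), dtype=np.int64)
        s[:-1] = np.cumprod(np.asarray(dims[::-1], dtype=np.int64))[::-1][1:]
        return s
    x_strides, out_strides = strides(x_sparse_dims), strides(out_sparse_dims)
    # Flatten each nonzero's indices into a linear location under the old strides...
    location = (np.asarray(x_indices) * x_strides[:, None]).sum(axis=0)
    # ...then decompose it digit-by-digit under the new strides.
    return (location[None, :] // out_strides[:, None]) % np.asarray(out_sparse_dims)[:, None]

# e.g. reshaping sparse dims (2, 3) -> (3, 2):
# reshape_coo_indices([[0, 1], [2, 0]], (2, 3), (3, 2)) -> [[1, 1], [0, 1]]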
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
template <typename T, typename IntT, typename Context>
void ReshapeCooGPUKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
int64_t x_nnz = x.nnz();
std::vector<int> new_shape(shape.GetData().begin(), shape.GetData().end());
phi::DDim out_dims = x.dims().reshape(new_shape);
......@@ -63,14 +66,14 @@ void ReshapeCooKernel(const Context& dev_ctx,
out_sparse_part_dims.push_back(out_dims[i]);
}
DenseTensor out_indices = Empty<int64_t, Context>(
DenseTensor out_indices = Empty<IntT, Context>(
dev_ctx, {static_cast<int64_t>(out_sparse_part_dims.size()), x_nnz});
DenseTensor out_values(x.values());
out->SetMember(out_indices, out_values, out_dims, x.coalesced());
// compute values of out indices
const auto* x_indices_data = x.indices().data<int64_t>();
auto* out_indices_data = out_indices.data<int64_t>();
const auto* x_indices_data = x.indices().data<IntT>();
auto* out_indices_data = out_indices.data<IntT>();
const phi::DDim& x_sparse_part_strides =
phi::stride(phi::make_ddim(x_sparse_part_dims));
const phi::DDim& out_sparse_part_strides =
......@@ -119,6 +122,17 @@ void ReshapeCooKernel(const Context& dev_ctx,
out_indices_data);
}
template <typename T, typename Context>
void ReshapeCooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const phi::IntArray& shape,
SparseCooTensor* out) {
PD_VISIT_BASE_INTEGRAL_TYPES(
x.indices().dtype(), "ReshapeCooGPUKernel", ([&] {
ReshapeCooGPUKernel<T, data_t, Context>(dev_ctx, x, shape, out);
}));
}
// Just copied from paddle/phi/kernels/sparse/cpu/reshape_kernel.cc
template <typename T, typename Context>
void ReshapeCsrKernel(const Context& dev_ctx,
......
......@@ -20,6 +20,8 @@ from .layer.activation import Softmax
from .layer.activation import ReLU6
from .layer.activation import LeakyReLU
from .layer.conv import Conv3D
from .layer.conv import Conv2D
from .layer.conv import SubmConv2D
from .layer.conv import SubmConv3D
from .layer.pooling import MaxPool3D
......@@ -30,7 +32,9 @@ __all__ = [
'Softmax',
'BatchNorm',
'SyncBatchNorm',
'Conv2D',
'Conv3D',
'SubmConv2D',
'SubmConv3D',
'MaxPool3D',
]
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .conv import conv2d # noqa: F401
from .conv import conv3d # noqa: F401
from .conv import subm_conv2d # noqa: F401
from .conv import subm_conv3d # noqa: F401
from .transformer import attention # noqa: F401
from .pooling import max_pool3d # noqa: F401
......@@ -22,7 +24,9 @@ from .activation import leaky_relu # noqa: F401
from .activation import softmax # noqa: F401
__all__ = [
'conv2d',
'conv3d',
'subm_conv2d',
'subm_conv3d',
'max_pool3d',
'relu',
......
......@@ -14,12 +14,14 @@
__all__ = []
import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.functional.conv import _update_padding_nd
from paddle.utils import convert_to_list
from ...binary import add
from ...unary import reshape
def _conv3d(
......@@ -115,6 +117,100 @@ def _conv3d(
return pre_bias
def _conv2d(
x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
subm=False,
key=None,
data_format="NHWC",
name=None,
):
assert groups == 1, "Currently, only support groups=1"
dims = 2
# Currently, only support 'NHWC'
if data_format not in ["NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NHWC'. Received "
"Attr(data_format): {}.".format(data_format)
)
if len(x.shape) != 4:
raise ValueError(
"Input x should be 4D tensor, but received x with the shape of {}".format(
x.shape
)
)
channel_last = data_format == "NHWC"
n_dim = 0
channel_dim = -1 if channel_last else 1
h_dim = 1 if channel_last else 2
w_dim = 2 if channel_last else -1
n = x.shape[n_dim]
d = 1
h = x.shape[h_dim]
w = x.shape[w_dim]
num_channels = x.shape[channel_dim]
if num_channels < 0:
raise ValueError(
"The channel dimension of the input({}) should be defined. "
"Received: {}.".format(x.shape, num_channels)
)
padding, padding_algorithm = _update_padding_nd(padding, channel_last, dims)
stride = convert_to_list(stride, dims, 'stride')
dilation = convert_to_list(dilation, dims, 'dilation')
padding.insert(0, 0)
stride.insert(0, 1)
dilation.insert(0, 1)
x = reshape(x, [n, d, h, w, num_channels])
h_filter = weight.shape[0]
w_filter = weight.shape[1]
c_filter = weight.shape[2]
m_filter = weight.shape[3]
weight = paddle.reshape(weight, [d, h_filter, w_filter, c_filter, m_filter])
if in_dynamic_mode():
pre_bias = _C_ops.sparse_conv3d(
x,
weight,
padding,
dilation,
stride,
groups,
subm,
key if key is not None else "",
)
x = reshape(x, [n, h, w, -1])
weight = paddle.reshape(
weight, [h_filter, w_filter, c_filter, m_filter]
)
n_out = pre_bias.shape[0]
h_out = pre_bias.shape[2]
w_out = pre_bias.shape[3]
channels_out = pre_bias.shape[4]
pre_bias = reshape(pre_bias, [n_out, h_out, w_out, channels_out])
if bias is not None:
return add(pre_bias, bias)
else:
return pre_bias
else:
raise ValueError("Only support dynamic_mode now.")
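Note that _conv2d does not ship its own kernel: it lowers the 2-D case onto the existing sparse_conv3d op by inserting a depth dimension of size 1 into the input, the weight, and every conv attribute, then squeezing it back out of the result. A minimal sketch of just that shape bookkeeping (hypothetical helper):

def lift_2d_to_3d(x_shape, w_shape, stride, padding, dilation):
    n, h, w, c = x_shape                # input [N, H, W, C]
    kh, kw, c_in, m = w_shape           # kernel [kH, kW, C, M]
    x3d = [n, 1, h, w, c]               # add a depth dim of size 1
    w3d = [1, kh, kw, c_in, m]          # kernel depth = 1
    stride3d = [1] + list(stride)       # never stride along depth
    padding3d = [0] + list(padding)     # never pad along depth
    dilation3d = [1] + list(dilation)
    return x3d, w3d, stride3d, padding3d, dilation3d

# lift_2d_to_3d([1, 3, 4, 1], [3, 3, 1, 1], [1, 1], [0, 0], [1, 1])
# -> ([1, 1, 3, 4, 1], [1, 3, 3, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1])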
def conv3d(
x,
weight,
......@@ -331,3 +427,216 @@ def subm_conv3d(
data_format,
name,
)
def conv2d(
x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
data_format="NHWC",
name=None,
):
r"""
The sparse convolution2d functional calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
Output(Output) are multidimensional SparseCooTensors with a shape of
:math:`[N, H, W, C]` . Where N is batch size, C is the number of
channels, H is the height of the feature,
and W is the width of the feature. If bias is provided, it is
added to the output of the convolution.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma (W \ast X + b)
In the above equation:
* :math:`X`: Input value, a tensor with NHWC format.
* :math:`W`: Filter value, a tensor with HWCM format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 1-D tensor with shape [M].
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Args:
x (Tensor): The input is 4-D SparseCooTensor with shape [N, H, W, C], the data
type of input is float16 or float32 or float64.
weight (Tensor): The convolution kernel, a Tensor with shape [kH, kW, C/g, M],
where M is the number of filters (output channels), g is the number of groups,
and kH, kW are the filter's height and width respectively.
bias (Tensor, optional): The bias, a Tensor of shape [M].
stride (int|list|tuple, optional): The stride size. It means the stride in convolution. If stride is a
list/tuple, it must contain two integers, (stride_height, stride_width).
Otherwise, stride_height = stride_width = stride. Default: stride = 1.
padding (string|int|list|tuple, optional): The padding size. It means the number of zero-paddings
on both sides for each dimension. If `padding` is a string, either 'VALID' or
'SAME' which is the padding algorithm. If padding size is a tuple or list,
it could be in three forms: `[pad_height, pad_width]` or
`[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
when `data_format` is `"NHWC"`, `padding` can be in the form
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
dilation (int|list|tuple, optional): The dilation size. It means the spacing between the kernel points.
If dilation is a list/tuple, it must contain two integers, (dilation_height,
dilation_width). Otherwise, dilation_height = dilation_width = dilation.
Default: dilation = 1.
groups (int, optional): The groups number of the Conv2D Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1. Currently, only groups=1 is supported.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NHWC"`.
The default is `"NHWC"`. When it is `"NHWC"`, the data is stored in the order of:
`[batch_size, input_height, input_width, input_channels]`.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name does not need to be set and
is None by default.
Returns:
A SparseCooTensor representing the result of conv2d, whose data type is the same as the input.
Examples:
.. code-block:: python
import paddle
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True)
weight = paddle.randn((3, 3, 1, 1), dtype='float32')
y = paddle.sparse.nn.functional.conv2d(sparse_x, weight)
print(y.shape)
# (1, 1, 2, 1)
"""
return _conv2d(
x,
weight,
bias,
stride,
padding,
dilation,
groups,
False,
None,
data_format,
name,
)
def subm_conv2d(
x,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
data_format="NHWC",
key=None,
name=None,
):
r"""
The sparse submanifold convolution2d functional calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
Output(Output) are multidimensional SparseCooTensors with a shape of
:math:`[N, H, W, C]` . Where N is batch size, C is the number of
channels, H is the height of the feature,
and W is the width of the feature. If bias is provided, it is
added to the output of the convolution.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma (W \ast X + b)
In the above equation:
* :math:`X`: Input value, a tensor with NHWC format.
* :math:`W`: Filter value, a tensor with HWCM format.
* :math:`\\ast`: Submanifold Convolution operation, refer to the paper: https://arxiv.org/abs/1706.01307.
* :math:`b`: Bias value, a 1-D tensor with shape [M].
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Args:
x (Tensor): The input is 4-D SparseCooTensor with shape [N, H, W, C], the data
type of input is float16 or float32 or float64.
weight (Tensor): The convolution kernel, a Tensor with shape [kH, kW, C/g, M],
where M is the number of filters (output channels), g is the number of groups,
and kH, kW are the filter's height and width respectively.
bias (Tensor, optional): The bias, a Tensor of shape [M].
stride (int|list|tuple, optional): The stride size. It means the stride in convolution. If stride is a
list/tuple, it must contain two integers, (stride_height, stride_width).
Otherwise, stride_height = stride_width = stride. Default: stride = 1.
padding (string|int|list|tuple, optional): The padding size. It means the number of zero-paddings
on both sides for each dimension. If `padding` is a string, either 'VALID' or
'SAME' which is the padding algorithm. If padding size is a tuple or list,
it could be in three forms: `[pad_height, pad_width]` or
`[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
when `data_format` is `"NHWC"`, `padding` can be in the form
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Default: padding = 0.
dilation (int|list|tuple, optional): The dilation size. It means the spacing between the kernel points.
If dilation is a list/tuple, it must contain two integers, (dilation_height,
dilation_width). Otherwise, dilation_height = dilation_width = dilation.
Default: dilation = 1.
groups (int, optional): The groups number of the Conv2D Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1. Currently, only groups=1 is supported.
data_format (str, optional): Specify the data format of the input, and the data format of the output
will be consistent with that of the input. An optional string from: `"NHWC"`.
The default is `"NHWC"`. When it is `"NHWC"`, the data is stored in the order of:
`[batch_size, input_height, input_width, input_channels]`.
key(str, optional): the key is used to save or use the same rulebook,
the definition and role of rulebook refers to
https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The
default value is None.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name does not need to be set and
is None by default.
Returns:
A SparseCooTensor representing the result of subm_conv2d, whose data type is the same as the input.
Examples:
.. code-block:: python
import paddle
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True)
weight = paddle.randn((3, 3, 1, 1), dtype='float32')
y = paddle.sparse.nn.functional.subm_conv2d(sparse_x, weight)
print(y.shape)
# (1, 3, 4, 1)
"""
return _conv2d(
x,
weight,
bias,
stride,
padding,
dilation,
groups,
True,
key,
data_format,
name,
)
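Per the docstring examples above, the two functionals differ in output geometry: a regular sparse conv shrinks the spatial extent, while the submanifold variant keeps the input's shape and sparsity pattern. A small side-by-side sketch (shapes taken from those examples; the key name is arbitrary):

import paddle

indices = paddle.to_tensor([[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]], dtype='int32')
values = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0]])
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, [1, 3, 4, 1])
weight = paddle.randn((3, 3, 1, 1), dtype='float32')

y = paddle.sparse.nn.functional.conv2d(sparse_x, weight)
y_subm = paddle.sparse.nn.functional.subm_conv2d(sparse_x, weight, key='demo')
print(y.shape)       # (1, 1, 2, 1): regular conv shrinks the spatial extent
print(y_subm.shape)  # (1, 3, 4, 1): subm conv keeps the input shape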
......@@ -130,6 +130,112 @@ class _Conv3D(Layer):
return main_str.format(**self.__dict__)
class _Conv2D(Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
subm=False,
key=None,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format="NHWC",
):
super().__init__()
assert (
weight_attr is not False
), "weight_attr should not be False in Conv."
self._param_attr = weight_attr
self._bias_attr = bias_attr
self._groups = groups
self._in_channels = in_channels
self._out_channels = out_channels
self._data_format = data_format
self._subm = subm
self._key = key
assert (
padding_mode == 'zeros'
), "Currently, only support padding_mode='zeros'"
assert groups == 1, "Currently, only support groups=1"
valid_format = {'NHWC'}
if data_format not in valid_format:
raise ValueError(
"data_format must be one of {}, but got data_format='{}'".format(
valid_format, data_format
)
)
channel_last = data_format == "NHWC"
dims = 2
self._stride = convert_to_list(stride, dims, 'stride')
self._dilation = convert_to_list(dilation, dims, 'dilation')
self._kernel_size = convert_to_list(kernel_size, dims, 'kernel_size')
self._padding = padding
self._padding_mode = padding_mode
self._updated_padding, self._padding_algorithm = _update_padding_nd(
padding, channel_last, dims
)
# sparse conv restricts the filter shape to [H, W, in_channels, out_channels]
filter_shape = self._kernel_size + [
self._in_channels,
self._out_channels,
]
def _get_default_param_initializer():
filter_elem_num = np.prod(self._kernel_size) * self._in_channels
std = (2.0 / filter_elem_num) ** 0.5
return Normal(0.0, std)
self.weight = self.create_parameter(
shape=filter_shape,
attr=self._param_attr,
default_initializer=_get_default_param_initializer(),
)
self.bias = self.create_parameter(
attr=self._bias_attr, shape=[self._out_channels], is_bias=True
)
def forward(self, x):
out = F.conv._conv2d(
x,
self.weight,
bias=self.bias,
stride=self._stride,
padding=self._updated_padding,
dilation=self._dilation,
groups=self._groups,
subm=self._subm,
key=self._key,
data_format=self._data_format,
)
return out
def extra_repr(self):
main_str = '{_in_channels}, {_out_channels}, kernel_size={_kernel_size}'
if self._stride != [1] * len(self._stride):
main_str += ', stride={_stride}'
if self._padding != 0:
main_str += ', padding={_padding}'
if self._padding_mode != 'zeros':
main_str += ', padding_mode={_padding_mode}'
if self._dilation != [1] * len(self._dilation):
main_str += ', dilation={_dilation}'
if self._groups != 1:
main_str += ', groups={_groups}'
main_str += ', data_format={_data_format}'
return main_str.format(**self.__dict__)
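The default weight initializer in _Conv2D above follows the Kaiming-normal rule std = sqrt(2 / filter_elem_num); a worked instance:

# std = sqrt(2 / filter_elem_num), with filter_elem_num = kH * kW * C_in.
kernel_size, in_channels = (3, 3), 16
filter_elem_num = kernel_size[0] * kernel_size[1] * in_channels   # 144
std = (2.0 / filter_elem_num) ** 0.5                              # ~0.1179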
class Conv3D(_Conv3D):
r"""
**Sparse Convlution3d Layer**
......@@ -265,6 +371,141 @@ class Conv3D(_Conv3D):
)
class Conv2D(_Conv2D):
r"""
**Sparse Convolution2d Layer**
The Sparse convolution2d layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
Output(Output) are multidimensional SparseCooTensors with a shape of
:math:`[N, H, W, C]` . Where N is batch size, C is the number of
channels, H is the height of the feature,
and W is the width of the feature. If bias is provided, it is
added to the output of the convolution.
For each input :math:`X`, the equation is:
.. math::
Out = W \ast X + b
In the above equation:
* :math:`X`: Input value, a tensor with NHWC format.
* :math:`W`: Filter value, a tensor with HWCM format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 1-D tensor with shape [M].
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Parameters:
in_channels(int): The number of input channels in the input image.
out_channels(int): The number of output channels produced by the convolution.
kernel_size(int|list|tuple): The size of the convolving kernel.
stride(int|list|tuple, optional): The stride size. If stride is a list/tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. The default value is 1.
padding(int|str|tuple|list, optional): The padding size. Padding could be in one of the following forms.
1. a string in ['valid', 'same'].
2. an int, which means each spatial dimension (height, width) is zero-padded by size of `padding`.
3. a list[int] or tuple[int] whose length is the number of spatial dimensions, which contains the amount of padding on each side for each spatial dimension. It has the form [pad_d1, pad_d2, ...].
4. a list[int] or tuple[int] whose length is 2 * number of spatial dimensions. It has the form [pad_before, pad_after, pad_before, pad_after, ...] for all spatial dimensions.
5. a list or tuple of pairs of ints. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...].
Note that the batch dimension and channel dimension are also included. Each pair of integers corresponds to the amount of padding for a dimension of the input. Padding in the batch dimension and channel dimension should be [0, 0] or (0, 0).
The default value is 0.
dilation(int|list|tuple, optional): The dilation size. If dilation is a list/tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. The default value is 1.
groups(int, optional): The groups number of the Conv2D Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. The default value is 1. Currently, only groups=1 is supported.
padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently, only ``'zeros'`` is supported.
weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If it is set to None, the parameter
is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is
:math:`(\frac{2.0 }{filter\_elem\_num})^{0.5}`. The default value is None.
bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. The default value is None.
data_format(str, optional): Data format that specifies the layout of input.
It can be "NCHW" or "NHWC". Currently, only "NHWC" is supported.
Attribute:
**weight** (Parameter): the learnable weights of filters of this layer.
**bias** (Parameter): the learnable bias of this layer.
Shape:
- x: :math:`(N, H_{in}, W_{in}, C_{in})`
- weight: :math:`(K_{h}, K_{w}, C_{in}, C_{out})`
- bias: :math:`(C_{out})`
- output: :math:`(N, H_{out}, W_{out}, C_{out})`
Where
.. math::
H_{out}&= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (kernel\_size[0] - 1) + 1))}{strides[0]} + 1
W_{out}&= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (kernel\_size[1] - 1) + 1))}{strides[1]} + 1
Examples:
.. code-block:: python
import paddle
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True)
conv = paddle.sparse.nn.Conv2D(1, 1, (3, 3))
y = conv(sparse_x)
print(y.shape)
# (1, 1, 2, 1)
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format="NHWC",
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
subm=False,
key=None,
padding_mode=padding_mode,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format,
)
class SubmConv3D(_Conv3D):
r"""
**Submanifold Sparse Convlution3d Layer**
......@@ -403,3 +644,143 @@ class SubmConv3D(_Conv3D):
bias_attr=bias_attr,
data_format=data_format,
)
class SubmConv2D(_Conv2D):
r"""
**Submanifold Sparse Convolution2d Layer**
The submanifold sparse convolution2d layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
Output(Output) are multidimensional SparseCooTensors with a shape of
:math:`[N, H, W, C]` . Where N is batch size, C is the number of
channels, H is the height of the feature,
and W is the width of the feature. If bias is provided, it is
added to the output of the convolution.
For each input :math:`X`, the equation is:
.. math::
Out = W \ast X + b
In the above equation:
* :math:`X`: Input value, a tensor with NHWC format.
* :math:`W`: Filter value, a tensor with HWCM format.
* :math:`\\ast`: Submanifold Convolution operation, refer to the paper: https://arxiv.org/abs/1706.01307.
* :math:`b`: Bias value, a 1-D tensor with shape [M].
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Parameters:
in_channels(int): The number of input channels in the input image.
out_channels(int): The number of output channels produced by the convolution.
kernel_size(int|list|tuple): The size of the convolving kernel.
stride(int|list|tuple, optional): The stride size. If stride is a list/tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. The default value is 1.
padding(int|str|tuple|list, optional): The padding size. Padding could be in one of the following forms.
1. a string in ['valid', 'same'].
2. an int, which means each spatial dimension (height, width) is zero-padded by size of `padding`.
3. a list[int] or tuple[int] whose length is the number of spatial dimensions, which contains the amount of padding on each side for each spatial dimension. It has the form [pad_d1, pad_d2, ...].
4. a list[int] or tuple[int] whose length is 2 * number of spatial dimensions. It has the form [pad_before, pad_after, pad_before, pad_after, ...] for all spatial dimensions.
5. a list or tuple of pairs of ints. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...].
Note that the batch dimension and channel dimension are also included. Each pair of integers corresponds to the amount of padding for a dimension of the input. Padding in the batch dimension and channel dimension should be [0, 0] or (0, 0).
The default value is 0.
dilation(int|list|tuple, optional): The dilation size. If dilation is a list/tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. The default value is 1.
groups(int, optional): The groups number of the Conv2D Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. The default value is 1. Currently, only groups=1 is supported.
padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently, only ``'zeros'`` is supported.
key(str, optional): the key is used to save or use the same rulebook,
the definition and role of rulebook refers to
https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The
default value is None.
weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If it is set to None, the parameter
is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is
:math:`(\frac{2.0 }{filter\_elem\_num})^{0.5}`. The default value is None.
bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. The default value is None.
data_format(str, optional): Data format that specifies the layout of input.
It can be "NCHW" or "NHWC". Currently, only "NHWC" is supported.
Attribute:
**weight** (Parameter): the learnable weights of filters of this layer.
**bias** (Parameter): the learnable bias of this layer.
Shape:
- x: :math:`(N, H_{in}, W_{in}, C_{in})`
- weight: :math:`(K_{h}, K_{w}, C_{in}, C_{out})`
- bias: :math:`(C_{out})`
- output: :math:`(N, H_{out}, W_{out}, C_{out})`
Where
.. math::
H_{out}&= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (kernel\_size[0] - 1) + 1))}{strides[0]} + 1
W_{out}&= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (kernel\_size[1] - 1) + 1))}{strides[1]} + 1
Examples:
.. code-block:: python
import paddle
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
dense_shape = [1, 3, 4, 1]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape, stop_gradient=True)
subm_conv = paddle.sparse.nn.SubmConv2D(1, 1, (3, 3))
y = subm_conv(sparse_x)
print(y.shape)
# (1, 3, 4, 1)
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
key=None,
weight_attr=None,
bias_attr=None,
data_format="NHWC",
):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
subm=True,
key=key,
padding_mode=padding_mode,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format,
)
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import unittest
import numpy as np
......@@ -20,8 +21,47 @@ import paddle
from paddle import sparse
from paddle.fluid import core
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO
)
logger = logging.getLogger(__name__)
class TestSparseConv(unittest.TestCase):
def test_conv2d(self):
kernel = [[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]
dense_kernel = paddle.to_tensor(
kernel, dtype='float32', stop_gradient=False
)
dense_kernel = paddle.reshape(dense_kernel, [3, 3, 1, 1])
paddings = [0, 0]
strides = [1, 1]
dilations = [1, 1]
bias = [1]
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [1, 2, 3, 4]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
correct_out_values = [[5], [11]]
sparse_input = core.eager.sparse_coo_tensor(
indices, values, dense_shape, False
)
out = paddle.sparse.nn.functional.conv2d(
sparse_input,
dense_kernel,
bias=paddle.to_tensor(bias, dtype='float32'),
stride=strides,
padding=paddings,
dilation=dilations,
groups=1,
data_format="NHWC",
)
out.backward(out)
out = paddle.sparse.coalesce(out)
np.testing.assert_array_equal(correct_out_values, out.values().numpy())
def test_conv3d(self):
kernel = [[[[[1], [1], [1]], [[1], [1], [1]], [[1], [1], [1]]]]]
dense_kernel = paddle.to_tensor(
......@@ -56,6 +96,23 @@ class TestSparseConv(unittest.TestCase):
out = paddle.sparse.coalesce(out)
assert np.array_equal(correct_out_values, out.values().numpy())
def test_subm_conv2d(self):
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
sparse_x = paddle.sparse.sparse_coo_tensor(
indices, values, dense_shape, stop_gradient=True
)
weight = paddle.randn((1, 3, 3, 1), dtype='float32')
y = paddle.sparse.nn.functional.subm_conv2d(
sparse_x, weight, key='subm_conv'
)
np.testing.assert_array_equal(
sparse_x.indices().numpy(), y.indices().numpy()
)
def test_subm_conv3d(self):
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
......@@ -71,6 +128,30 @@ class TestSparseConv(unittest.TestCase):
)
assert np.array_equal(sparse_x.indices().numpy(), y.indices().numpy())
def test_Conv2D(self):
# (3, non_zero_num), 3-D:(N, H, W)
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
# (non_zero_num, C)
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
correct_out_values = [[4], [10]]
sparse_input = paddle.sparse.sparse_coo_tensor(
indices, values, dense_shape, False
)
sparse_conv2d = paddle.sparse.nn.Conv2D(
1, 1, (3, 3), data_format='NHWC'
)
sparse_out = sparse_conv2d(sparse_input)
# test errors
with self.assertRaises(ValueError):
# Currently, only data_format='NHWC' is supported
conv2d = paddle.sparse.nn.SubmConv2D(
1, 1, (3, 3), data_format='NCHW', key='subm_conv'
)
def test_Conv3D(self):
# (4, non_zero_num), 4-D:(N, D, H, W)
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
......@@ -95,6 +176,34 @@ class TestSparseConv(unittest.TestCase):
1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv'
)
def test_SubmConv2D(self):
indices = [[0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 3, 4, 1]
correct_out_values = [[4], [10]]
sparse_input = paddle.sparse.sparse_coo_tensor(
indices, values, dense_shape, False
)
subm_conv2d = paddle.sparse.nn.SubmConv2D(
1, 1, (3, 3), data_format='NHWC', key='subm_conv'
)
# test extra_repr
logger.info(subm_conv2d.extra_repr())
sparse_out = subm_conv2d(sparse_input)
# the output indices (and hence shape) of subm_conv are the same as the input's
np.testing.assert_array_equal(indices, sparse_out.indices().numpy())
# test errors
with self.assertRaises(ValueError):
# Currently, only support data_format='NHWC'
conv2d = paddle.sparse.nn.SubmConv2D(
1, 1, (3, 3), data_format='NCHW', key='subm_conv'
)
def test_SubmConv3D(self):
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [[1], [2], [3], [4]]
......@@ -123,6 +232,45 @@ class TestSparseConv(unittest.TestCase):
1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv'
)
def test_Conv2D_bias(self):
paddle.seed(0)
shape = [1, 4, 4, 3]
x = paddle.randn(shape)
sp_x = x.to_sparse_coo(3)
conv2d = paddle.nn.Conv2D(3, 2, 3, data_format='NHWC')
sp_conv2d = paddle.sparse.nn.Conv2D(3, 2, 3, data_format='NHWC')
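# dense Conv2D stores weights as [M, C, kH, kW]; the sparse layer expects
# [kH, kW, C, M], hence the transpose(2, 3, 1, 0) here and in the gradient check below.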
sp_conv2d.weight.set_value(
paddle.to_tensor(conv2d.weight.numpy().transpose(2, 3, 1, 0))
)
sp_conv2d.bias.set_value(paddle.to_tensor(conv2d.bias.numpy()))
x.stop_gradient = False
out = conv2d(x)
loss = out.mean()
loss.backward()
sp_x.stop_gradient = False
sp_out = sp_conv2d(sp_x)
dense_out = sp_out.to_dense()
sp_loss = dense_out.mean()
sp_loss.backward()
np.testing.assert_allclose(
out.numpy(), dense_out.numpy(), atol=1e-3, rtol=1e-3
)
np.testing.assert_allclose(
conv2d.weight.grad.numpy().transpose(2, 3, 1, 0),
sp_conv2d.weight.grad.numpy(),
atol=1e-3,
rtol=1e-3,
)
np.testing.assert_allclose(
conv2d.bias.grad.numpy(),
sp_conv2d.bias.grad.numpy(),
atol=1e-5,
rtol=1e-5,
)
def test_Conv3D_bias(self):
paddle.seed(0)
shape = [1, 4, 4, 4, 3]
......