Unverified commit 8a6456db, authored by: Z zhangkaihuo, committed by: GitHub

Add Sparse BatchNorm and fix two bugs (#42013)

Parent 281a5be7
......@@ -44,7 +44,7 @@ void CoalescedCPUKernel(const CPUContext& dev_ctx,
const T* x_values_ptr = x_values.data<T>();
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
std::map<IntT, std::vector<int64_t>> indices_to_index;
for (uint64_t i = 0; i < x_indexs.size(); i++) {
......
......@@ -125,7 +125,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
T* out_ptr = out->data<T>();
memset(out_ptr, static_cast<T>(0), out->numel() * sizeof(T));
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
const T* in_ptr = x.non_zero_elements().data<T>();
// TODO(zhangkaihuo): multithreading can be used for acceleration
for (uint64_t i = 0; i < mask_indexs.size(); i++) {
......
......@@ -76,7 +76,7 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
// 2. get the address of each non-zero value
const T* x_values_ptr = x_values.data<T>();
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
DenseTensor values_indexs = phi::Empty(
dev_ctx, DenseTensorMeta(DataType::INT32, {nnz}, DataLayout::NCHW));
int* values_indexs_ptr = values_indexs.data<int>();
......
......@@ -231,7 +231,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
T* out_ptr = out->data<T>();
const int64_t stride =
x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim;
x.dims().size() == sparse_dim ? 1 : x.non_zero_elements().dims()[1];
SparseMaskCopyKernel<<<config.block_per_grid,
config.thread_per_block,
......
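
The four kernel hunks above share one fix: when the non-zero values are vectors of shape [nnz, C] rather than scalars, the per-entry copy stride must be the width of a values row, not the rank difference. A minimal sketch in plain Python with illustrative numbers (not part of the patch):

import numpy as np

# a COO tensor with sparse_dim = 2 whose values tensor has shape [nnz, 2]
values = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
sparse_dim = 2
dense_ndim = 3  # x.dims().size()

old_stride = dense_ndim - sparse_dim  # old code: rank difference -> 1
new_stride = values.shape[1]          # new code: non_zero_elements().dims()[1] -> 2

# copying entry i must move new_stride elements of values[i]; with the old
# stride, half of each length-2 vector would be dropped here
assert (old_stride, new_stride) == (1, 2)
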
......@@ -31,19 +31,21 @@ class TestSparseConv(unittest.TestCase):
paddings = [0, 0, 0]
strides = [1, 1, 1]
dilations = [1, 1, 1]
bias = [1]
indices = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]]
values = [1, 2, 3, 4]
indices = paddle.to_tensor(indices, dtype='int32')
values = paddle.to_tensor(values, dtype='float32')
dense_shape = [1, 1, 3, 4, 1]
correct_out_values = [[4], [10]]
correct_out_values = [[5], [11]]
sparse_input = core.eager.sparse_coo_tensor(indices, values,
dense_shape, False)
out = paddle.sparse.functional.conv3d(
sparse_input,
dense_kernel,
bias=None,
bias=paddle.to_tensor(
bias, dtype='float32'),
stride=strides,
padding=paddings,
dilation=dilations,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard
import copy
class TestSparseBatchNorm(unittest.TestCase):
def test(self):
with _test_eager_guard():
paddle.seed(0)
channels = 4
shape = [2, 3, 6, 6, channels]
# there is no zero in dense_x
dense_x = paddle.randn(shape)
dense_x.stop_gradient = False
batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
dense_y = batch_norm(dense_x)
dense_y.backward(dense_y)
sparse_dim = 4
dense_x2 = copy.deepcopy(dense_x)
dense_x2.stop_gradient = False
sparse_x = dense_x2.to_sparse_coo(sparse_dim)
sparse_batch_norm = paddle.sparse.BatchNorm(channels)
# set same params
sparse_batch_norm._mean.set_value(batch_norm._mean)
sparse_batch_norm._variance.set_value(batch_norm._variance)
sparse_batch_norm.weight.set_value(batch_norm.weight)
sparse_y = sparse_batch_norm(sparse_x)
# compare the result with dense batch_norm
assert np.allclose(
dense_y.flatten().numpy(),
sparse_y.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
# test backward
sparse_y.backward(sparse_y)
assert np.allclose(
dense_x.grad.flatten().numpy(),
sparse_x.grad.values().flatten().numpy(),
atol=1e-5,
rtol=1e-5)
def test_error_layout(self):
with _test_eager_guard():
with self.assertRaises(ValueError):
shape = [2, 3, 6, 6, 3]
x = paddle.randn(shape)
sparse_x = x.to_sparse_coo(4)
sparse_batch_norm = paddle.sparse.BatchNorm(
3, data_format='NCDHW')
sparse_batch_norm(sparse_x)
def test2(self):
with _test_eager_guard():
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.sparse.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
print(batch_norm_out.shape)
# [1, 6, 6, 6, 3]
if __name__ == "__main__":
unittest.main()
......@@ -208,6 +208,20 @@ class TestSparseConvert(unittest.TestCase):
# test coo_values_grad
values_tensor.backward(paddle.to_tensor(out_grad))
assert np.array_equal(out_grad, sparse_x.grad.values().numpy())
indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
values = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0],
[5.0, 5.0]]
sparse_x = paddle.sparse.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values),
shape=[3, 4, 2],
stop_gradient=False)
values_tensor = sparse_x.values()
out_grad = [[2.0, 2.0], [3.0, 3.0], [5.0, 5.0], [8.0, 8.0],
[9.0, 9.0]]
# test coo_values_grad
values_tensor.backward(paddle.to_tensor(out_grad))
assert np.array_equal(out_grad, sparse_x.grad.values().numpy())
def test_sparse_coo_tensor_grad(self):
with _test_eager_guard():
......@@ -233,6 +247,21 @@ class TestSparseConvert(unittest.TestCase):
assert np.array_equal(correct_values_grad,
values.grad.numpy())
# test the case where the non-zero values are vectors
values = [[1, 1], [2, 2]]
values = paddle.to_tensor(
values, dtype='float32', stop_gradient=False)
sparse_x = paddle.sparse.sparse_coo_tensor(
indices, values, shape=[2, 2, 2], stop_gradient=False)
grad_values = [[2, 2], [3, 3]]
grad_values = paddle.to_tensor(grad_values, dtype='float32')
sparse_out_grad = paddle.sparse.sparse_coo_tensor(
grad_indices, grad_values, shape=[2, 2, 2])
sparse_x.backward(sparse_out_grad)
correct_values_grad = [[0, 0], [3, 3]]
assert np.array_equal(correct_values_grad,
values.grad.numpy())
def test_sparse_coo_tensor_sorted(self):
with _test_eager_guard():
for device in devices:
......@@ -252,6 +281,16 @@ class TestSparseConvert(unittest.TestCase):
assert np.array_equal(values_sorted,
sparse_x.values().numpy())
# test the case where the non-zero values are vectors
values = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
values = paddle.to_tensor(values, dtype='float32')
sparse_x = paddle.sparse.sparse_coo_tensor(indices, values)
values_sorted = [[5.0, 5.0], [1.0, 1.0]]
assert np.array_equal(indices_sorted,
sparse_x.indices().numpy())
assert np.array_equal(values_sorted,
sparse_x.values().numpy())
class TestCooError(unittest.TestCase):
def test_small_shape(self):
......
......@@ -15,9 +15,12 @@
from .creation import sparse_coo_tensor
from .creation import sparse_csr_tensor
from .layer.activation import ReLU
from .layer.norm import BatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
__all__ = [
'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D'
'sparse_coo_tensor', 'sparse_csr_tensor', 'ReLU', 'Conv3D', 'SubmConv3D',
'BatchNorm'
]
......@@ -20,6 +20,8 @@ from ..tensor import to_tensor
from ..tensor import max
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
import numpy as np
__all__ = [
'sparse_coo_tensor',
'sparse_csr_tensor',
......@@ -33,11 +35,14 @@ def _handle_dtype(data, dtype):
return data
def _infer_dense_shape(indices):
def _infer_dense_shape(indices, values):
assert len(indices.shape) == 2
lens = max(indices, axis=1)
lens = lens + 1
return list(lens.numpy())
lens = lens.numpy()
if len(values.shape) > 1:
lens = np.append(lens, values.shape[1:])
return list(lens)
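
A small usage sketch of the updated shape inference, with illustrative values and the eager-mode guard used elsewhere in this patch: when no shape is passed, the trailing dims of vector-valued entries are now appended to the inferred dense shape.

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    indices = paddle.to_tensor([[0, 1], [1, 2]])
    # each non-zero entry is a length-2 vector
    values = paddle.to_tensor([[1.0, 1.0], [2.0, 2.0]], dtype='float32')
    sparse_x = paddle.sparse.sparse_coo_tensor(indices, values)
    # max index + 1 per sparse dim gives [2, 3]; values.shape[1:] == [2]
    # is appended, so the inferred dense shape is [2, 3, 2]
    print(sparse_x.shape)  # [2, 3, 2]
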
def _get_place(place):
......@@ -106,7 +111,7 @@ def sparse_coo_tensor(indices,
with _test_eager_guard():
indices = [[0, 1, 2], [1, 2, 0]]
values = [1.0, 2.0, 3.0]
dense_shape = [2, 3]
dense_shape = [3, 3]
coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
# print(coo)
# Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
......@@ -145,7 +150,8 @@ def sparse_coo_tensor(indices,
values = _handle_dtype(values, dtype)
values.stop_gradient = stop_gradient
min_shape = _infer_dense_shape(indices)
min_shape = _infer_dense_shape(indices, values)
if shape is None:
shape = min_shape
else:
......
......@@ -16,6 +16,8 @@ __all__ = []
from paddle import _C_ops, in_dynamic_mode
from ...fluid.layers.utils import convert_to_list
from ...fluid.layers.nn import elementwise_add
from .. import sparse_coo_tensor
from paddle.nn.functional.conv import _update_padding_nd
......@@ -30,7 +32,6 @@ def _conv3d(x,
data_format="NDHWC",
name=None):
assert in_dynamic_mode(), "Currently, only support dynamic mode"
assert bias == None, "Currently, sparse_conv3d does not support bias"
assert groups == 1, "Currently, only support groups=1"
dims = 3
......@@ -61,8 +62,18 @@ def _conv3d(x,
dilation = convert_to_list(dilation, dims, 'dilation')
op_type = "conv3d"
return _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
stride, groups, subm)
pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
stride, groups, subm)
if bias is not None:
values = pre_bias.values()
add_bias = elementwise_add(values, bias, axis=1)
return sparse_coo_tensor(
pre_bias.indices(),
add_bias,
shape=pre_bias.shape,
stop_gradient=pre_bias.stop_gradient)
else:
return pre_bias
def conv3d(x,
......
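
A usage sketch of the new bias path, mirroring the conv3d unit test updated earlier in this diff; the all-ones kernel with assumed layout [kd, kh, kw, in_channels, out_channels] = [1, 3, 3, 1, 1] is chosen so the expected output matches that test's correct_out_values.

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    indices = paddle.to_tensor(
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]], dtype='int32')
    values = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0]])  # [nnz, in_channels]
    sparse_x = paddle.sparse.sparse_coo_tensor(
        indices, values, shape=[1, 1, 3, 4, 1])
    weight = paddle.ones((1, 3, 3, 1, 1), dtype='float32')  # assumed all-ones kernel
    bias = paddle.to_tensor([1.0])
    out = paddle.sparse.functional.conv3d(
        sparse_x, weight, bias=bias,
        stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1])
    # without bias the two active outputs would be [[4.], [10.]];
    # the bias is broadcast over out_channels, giving [[5.], [11.]]
    print(out.values().numpy())
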
......@@ -13,6 +13,7 @@
# limitations under the License.
from .activation import ReLU
from .norm import BatchNorm
from .conv import Conv3D
from .conv import SubmConv3D
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import warnings
class BatchNorm(paddle.nn.BatchNorm1D):
r"""
Applies Batch Normalization over a SparseCooTensor as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
When use_global_stats = False, the :math:`\mu_{\beta}`
and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
Calculated as follows:
.. math::
\mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &//\
\ mini-batch\ mean \\
\sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \
\mu_{\beta})^2 \qquad &//\ mini-batch\ variance \\
When use_global_stats = True, the :math:`\mu_{\beta}`
and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
They are global or running statistics (moving_mean and moving_variance), usually obtained from a
pre-trained model. Calculated as follows:
.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global \ mean \\
moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global \ variance \\
The normalization function formula is as follows:
.. math::
\hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift
- :math:`\epsilon` : a small value added to the variance to prevent division by zero
- :math:`\gamma` : trainable scale parameter
- :math:`\beta` : trainable shift parameter
Parameters:
num_features(int): Indicate the number of channels of the input ``Tensor``.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as weight_attr. If it is set to False, the weight is not learnable.
If the Initializer of the weight_attr is not set, the parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None.
data_format(str, optional): Specify the input data format. Only "NDHWC" is supported. Default: "NDHWC".
use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch, if set to True, use the global statistics, if set to None, use global statistics in the test phase and use the statistics of one mini-batch in the training phase. Default: None.
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.
Shape:
- x: A SparseCooTensor with layout = 'NDHWC'.
- output: SparseCooTensor with same shape as input x.
Returns:
None.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
paddle.seed(123)
channels = 3
x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
dense_x = paddle.to_tensor(x_data)
sparse_x = dense_x.to_sparse_coo(4)
batch_norm = paddle.sparse.BatchNorm(channels)
batch_norm_out = batch_norm(sparse_x)
print(batch_norm_out.shape)
# [1, 6, 6, 6, 3]
"""
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NDHWC',
use_global_stats=None,
name=None):
super(BatchNorm, self).__init__(
num_features,
momentum=momentum,
epsilon=epsilon,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format,
use_global_stats=use_global_stats,
name=name)
def _check_data_format(self, input):
if input != "NDHWC":
raise ValueError('sparse BatchNorm only support layout of "NDHWC"')
def forward(self, input):
values = input.values()
self._check_data_format(self._data_format)
if len(values.shape) != 2:
raise ValueError('expected 2D input.values() (got {}D)'.format(
len(values.shape)))
if self.training:
warnings.warn(
"When training, we now always track global mean and variance.")
batch_norm_out = paddle.nn.functional.batch_norm(
values,
self._mean,
self._variance,
weight=self.weight,
bias=self.bias,
training=self.training,
momentum=self._momentum,
epsilon=self._epsilon,
data_format='NC',
use_global_stats=self._use_global_stats)
return paddle.sparse.sparse_coo_tensor(
input.indices(),
batch_norm_out,
shape=input.shape,
stop_gradient=input.stop_gradient)