From 8032d57e17b54df2f56e4718ab36d88726f8539a Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 24 May 2023 18:38:01 +0800 Subject: [PATCH] [Sparse]sparse BatchNorm support 2D input (#53893) --- .../tests/unittests/test_sparse_norm_op.py | 33 +++++++++++++++++++ python/paddle/sparse/nn/layer/norm.py | 8 +++-- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py index b86a915f239..c17a252ee75 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py @@ -88,6 +88,39 @@ class TestSparseBatchNorm(unittest.TestCase): assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy()) # [1, 6, 6, 6, 3] + def check(self, shape): + np.random.seed(0) + data = np.random.uniform(-0.01, 0.01, shape).astype("float32") + x = paddle.to_tensor(data) + x.stop_gradient = False + dim = len(shape) + data_format = "NHWC" if dim == 4 else "NDHWC" + if dim == 4: + bn = paddle.nn.BatchNorm2D(shape[-1], data_format=data_format) + else: + bn = paddle.nn.BatchNorm3D(shape[-1], data_format=data_format) + y = bn(x) + y.backward() + + sp_x = paddle.to_tensor(data).to_sparse_coo(dim - 1) + sp_x.stop_gradient = False + sp_bn = paddle.sparse.nn.BatchNorm(shape[-1], data_format=data_format) + sp_y = sp_bn(sp_x) + sp_y.backward() + + np.testing.assert_allclose( + y.numpy(), sp_y.to_dense().numpy(), rtol=1e-5 + ) + np.testing.assert_allclose( + x.grad.numpy(), sp_x.grad.to_dense().numpy(), rtol=1e-5 + ) + + def test_nd(self): + # 2D + self.check([2, 8, 8, 3]) + # 3D + self.check([2, 8, 8, 3, 4]) + class TestSyncBatchNorm(unittest.TestCase): def test_sync_batch_norm(self): diff --git a/python/paddle/sparse/nn/layer/norm.py b/python/paddle/sparse/nn/layer/norm.py index 30e69560428..dc8c2713f45 100644 --- a/python/paddle/sparse/nn/layer/norm.py +++ b/python/paddle/sparse/nn/layer/norm.py @@ -73,7 +73,7 @@ class BatchNorm(paddle.nn.BatchNorm1D): name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: - - x: A SparseCooTensor with layout = 'NDHWC'. + - x: A SparseCooTensor with layout = 'NDHWC' or 'NHWC'. - output: SparseCooTensor with same shape as input x. Returns: @@ -119,8 +119,10 @@ class BatchNorm(paddle.nn.BatchNorm1D): ) def _check_data_format(self, input): - if input != "NDHWC": - raise ValueError('sparse BatchNorm only support layout of "NDHWC"') + if input not in ["NDHWC", "NHWC"]: + raise ValueError( + 'sparse BatchNorm only support layout of "NDHWC" and "NHWC"' + ) def forward(self, input): self._check_data_format(self._data_format) -- GitLab