diff --git a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
index b86a915f239a4c61a6936a823c084fa90120308b..c17a252ee75a6d7713f6ee0085c6267d0d032fb6 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_norm_op.py
@@ -88,6 +88,39 @@ class TestSparseBatchNorm(unittest.TestCase):
         assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy())
         # [1, 6, 6, 6, 3]
 
+    def check(self, shape):
+        np.random.seed(0)
+        data = np.random.uniform(-0.01, 0.01, shape).astype("float32")
+        x = paddle.to_tensor(data)
+        x.stop_gradient = False
+        dim = len(shape)
+        data_format = "NHWC" if dim == 4 else "NDHWC"
+        if dim == 4:
+            bn = paddle.nn.BatchNorm2D(shape[-1], data_format=data_format)
+        else:
+            bn = paddle.nn.BatchNorm3D(shape[-1], data_format=data_format)
+        y = bn(x)
+        y.backward()
+
+        sp_x = paddle.to_tensor(data).to_sparse_coo(dim - 1)
+        sp_x.stop_gradient = False
+        sp_bn = paddle.sparse.nn.BatchNorm(shape[-1], data_format=data_format)
+        sp_y = sp_bn(sp_x)
+        sp_y.backward()
+
+        np.testing.assert_allclose(
+            y.numpy(), sp_y.to_dense().numpy(), rtol=1e-5
+        )
+        np.testing.assert_allclose(
+            x.grad.numpy(), sp_x.grad.to_dense().numpy(), rtol=1e-5
+        )
+
+    def test_nd(self):
+        # 2D
+        self.check([2, 8, 8, 3])
+        # 3D
+        self.check([2, 8, 8, 3, 4])
+
 
 class TestSyncBatchNorm(unittest.TestCase):
     def test_sync_batch_norm(self):
diff --git a/python/paddle/sparse/nn/layer/norm.py b/python/paddle/sparse/nn/layer/norm.py
index 30e69560428306e68ca1c91a7987ed23516b0452..dc8c2713f45a4ca6d9a1b99d09d252d87ed6437b 100644
--- a/python/paddle/sparse/nn/layer/norm.py
+++ b/python/paddle/sparse/nn/layer/norm.py
@@ -73,7 +73,7 @@ class BatchNorm(paddle.nn.BatchNorm1D):
         name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
 
     Shape:
-        - x: A SparseCooTensor with layout = 'NDHWC'.
+        - x: A SparseCooTensor with layout = 'NDHWC' or 'NHWC'.
         - output: SparseCooTensor with same shape as input x.
 
     Returns:
@@ -119,8 +119,10 @@ class BatchNorm(paddle.nn.BatchNorm1D):
            )
 
    def _check_data_format(self, input):
-        if input != "NDHWC":
-            raise ValueError('sparse BatchNorm only support layout of "NDHWC"')
+        if input not in ["NDHWC", "NHWC"]:
+            raise ValueError(
+                'sparse BatchNorm only support layout of "NDHWC" and "NHWC"'
+            )
 
    def forward(self, input):
        self._check_data_format(self._data_format)
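
For context, the sketch below is not part of the diff; it shows how the newly permitted "NHWC" layout would be exercised against the dense BatchNorm2D reference, mirroring the check() helper added in the test above. The input shape and tolerance are illustrative assumptions, not values mandated by the change.

# Minimal usage sketch (not part of the diff), assuming the NHWC support added here.
import numpy as np
import paddle

np.random.seed(0)
data = np.random.uniform(-0.01, 0.01, [2, 8, 8, 3]).astype("float32")

# Dense reference: channels-last 2-D batch norm over an NHWC tensor.
dense_y = paddle.nn.BatchNorm2D(3, data_format="NHWC")(paddle.to_tensor(data))

# Sparse path: N, H, W are sparse dims; the trailing channel dim stays dense.
sp_x = paddle.to_tensor(data).to_sparse_coo(3)
sp_y = paddle.sparse.nn.BatchNorm(3, data_format="NHWC")(sp_x)

# Dense and sparse results should agree on the stored values.
np.testing.assert_allclose(dense_y.numpy(), sp_y.to_dense().numpy(), rtol=1e-5)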