未验证 提交 8032d57e 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse]sparse BatchNorm support 2D input (#53893)

上级 f55f9d79
...@@ -88,6 +88,39 @@ class TestSparseBatchNorm(unittest.TestCase): ...@@ -88,6 +88,39 @@ class TestSparseBatchNorm(unittest.TestCase):
assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy()) assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy())
# [1, 6, 6, 6, 3] # [1, 6, 6, 6, 3]
def check(self, shape):
np.random.seed(0)
data = np.random.uniform(-0.01, 0.01, shape).astype("float32")
x = paddle.to_tensor(data)
x.stop_gradient = False
dim = len(shape)
data_format = "NHWC" if dim == 4 else "NDHWC"
if dim == 4:
bn = paddle.nn.BatchNorm2D(shape[-1], data_format=data_format)
else:
bn = paddle.nn.BatchNorm3D(shape[-1], data_format=data_format)
y = bn(x)
y.backward()
sp_x = paddle.to_tensor(data).to_sparse_coo(dim - 1)
sp_x.stop_gradient = False
sp_bn = paddle.sparse.nn.BatchNorm(shape[-1], data_format=data_format)
sp_y = sp_bn(sp_x)
sp_y.backward()
np.testing.assert_allclose(
y.numpy(), sp_y.to_dense().numpy(), rtol=1e-5
)
np.testing.assert_allclose(
x.grad.numpy(), sp_x.grad.to_dense().numpy(), rtol=1e-5
)
def test_nd(self):
# 2D
self.check([2, 8, 8, 3])
# 3D
self.check([2, 8, 8, 3, 4])
class TestSyncBatchNorm(unittest.TestCase): class TestSyncBatchNorm(unittest.TestCase):
def test_sync_batch_norm(self): def test_sync_batch_norm(self):
......
...@@ -73,7 +73,7 @@ class BatchNorm(paddle.nn.BatchNorm1D): ...@@ -73,7 +73,7 @@ class BatchNorm(paddle.nn.BatchNorm1D):
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
Shape: Shape:
- x: A SparseCooTensor with layout = 'NDHWC'. - x: A SparseCooTensor with layout = 'NDHWC' or 'NHWC'.
- output: SparseCooTensor with same shape as input x. - output: SparseCooTensor with same shape as input x.
Returns: Returns:
...@@ -119,8 +119,10 @@ class BatchNorm(paddle.nn.BatchNorm1D): ...@@ -119,8 +119,10 @@ class BatchNorm(paddle.nn.BatchNorm1D):
) )
def _check_data_format(self, input): def _check_data_format(self, input):
if input != "NDHWC": if input not in ["NDHWC", "NHWC"]:
raise ValueError('sparse BatchNorm only support layout of "NDHWC"') raise ValueError(
'sparse BatchNorm only support layout of "NDHWC" and "NHWC"'
)
def forward(self, input): def forward(self, input):
self._check_data_format(self._data_format) self._check_data_format(self._data_format)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册