Unverified commit 3e8708bc, authored by zhangkaihuo, committed by GitHub

Fix test and doc (#44735)

* fix test and doc

Parent: cd94be61
@@ -18,7 +18,6 @@ import numpy as np
 import paddle
 from paddle.incubate.sparse import nn
 import paddle.fluid as fluid
-from paddle.fluid.framework import _test_eager_guard
 import copy
@@ -26,79 +25,83 @@ class TestSparseBatchNorm(unittest.TestCase):

     def test(self):
         fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
-        with _test_eager_guard():
-            paddle.seed(0)
-            channels = 4
-            shape = [2, 3, 6, 6, channels]
-            #there is no zero in dense_x
-            dense_x = paddle.randn(shape)
-            dense_x.stop_gradient = False
-
-            batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
-            dense_y = batch_norm(dense_x)
-            dense_y.backward(dense_y)
-
-            sparse_dim = 4
-            dense_x2 = copy.deepcopy(dense_x)
-            dense_x2.stop_gradient = False
-            sparse_x = dense_x2.to_sparse_coo(sparse_dim)
-            sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
-            # set same params
-            sparse_batch_norm._mean.set_value(batch_norm._mean)
-            sparse_batch_norm._variance.set_value(batch_norm._variance)
-            sparse_batch_norm.weight.set_value(batch_norm.weight)
-
-            sparse_y = sparse_batch_norm(sparse_x)
-            # compare the result with dense batch_norm
-            assert np.allclose(dense_y.flatten().numpy(),
-                               sparse_y.values().flatten().numpy(),
-                               atol=1e-5,
-                               rtol=1e-5)
-
-            # test backward
-            sparse_y.backward(sparse_y)
-            assert np.allclose(dense_x.grad.flatten().numpy(),
-                               sparse_x.grad.values().flatten().numpy(),
-                               atol=1e-5,
-                               rtol=1e-5)
-            fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+        paddle.seed(0)
+        channels = 4
+        shape = [2, 3, 6, 6, channels]
+        #there is no zero in dense_x
+        dense_x = paddle.randn(shape)
+        dense_x.stop_gradient = False
+
+        batch_norm = paddle.nn.BatchNorm3D(channels, data_format="NDHWC")
+        dense_y = batch_norm(dense_x)
+        dense_y.backward(dense_y)
+
+        sparse_dim = 4
+        dense_x2 = copy.deepcopy(dense_x)
+        dense_x2.stop_gradient = False
+        sparse_x = dense_x2.to_sparse_coo(sparse_dim)
+        sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
+        # set same params
+        sparse_batch_norm._mean.set_value(batch_norm._mean)
+        sparse_batch_norm._variance.set_value(batch_norm._variance)
+        sparse_batch_norm.weight.set_value(batch_norm.weight)
+
+        sparse_y = sparse_batch_norm(sparse_x)
+        # compare the result with dense batch_norm
+        assert np.allclose(dense_y.flatten().numpy(),
+                           sparse_y.values().flatten().numpy(),
+                           atol=1e-5,
+                           rtol=1e-5)
+
+        # test backward
+        sparse_y.backward(sparse_y)
+        assert np.allclose(dense_x.grad.flatten().numpy(),
+                           sparse_x.grad.values().flatten().numpy(),
+                           atol=1e-5,
+                           rtol=1e-5)
+        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
     def test_error_layout(self):
-        with _test_eager_guard():
-            with self.assertRaises(ValueError):
-                shape = [2, 3, 6, 6, 3]
-                x = paddle.randn(shape)
-                sparse_x = x.to_sparse_coo(4)
-                sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(
-                    3, data_format='NCDHW')
-                sparse_batch_norm(sparse_x)
+        with self.assertRaises(ValueError):
+            shape = [2, 3, 6, 6, 3]
+            x = paddle.randn(shape)
+            sparse_x = x.to_sparse_coo(4)
+            sparse_batch_norm = paddle.incubate.sparse.nn.BatchNorm(
+                3, data_format='NCDHW')
+            sparse_batch_norm(sparse_x)
     def test2(self):
-        with _test_eager_guard():
-            paddle.seed(123)
-            channels = 3
-            x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
-            dense_x = paddle.to_tensor(x_data)
-            sparse_x = dense_x.to_sparse_coo(4)
-            batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
-            batch_norm_out = batch_norm(sparse_x)
-            print(batch_norm_out.shape)
-            # [1, 6, 6, 6, 3]
+        paddle.seed(123)
+        channels = 3
+        x_data = paddle.randn((1, 6, 6, 6, channels)).astype('float32')
+        dense_x = paddle.to_tensor(x_data)
+        sparse_x = dense_x.to_sparse_coo(4)
+        batch_norm = paddle.incubate.sparse.nn.BatchNorm(channels)
+        batch_norm_out = batch_norm(sparse_x)
+        # batch_norm_out.shape: [1, 6, 6, 6, 3]
+        dense_bn = paddle.nn.BatchNorm1D(channels)
+        dense_x = dense_x.reshape((-1, dense_x.shape[-1]))
+        dense_out = dense_bn(dense_x)
+        assert np.allclose(dense_out.numpy(), batch_norm_out.values().numpy())
 class TestSyncBatchNorm(unittest.TestCase):

     def test_sync_batch_norm(self):
-        with _test_eager_guard():
-            x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
-                           [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
-            x = paddle.to_tensor(x)
-            x = x.to_sparse_coo(len(x.shape) - 1)
-
-            if paddle.is_compiled_with_cuda():
-                sync_batch_norm = nn.SyncBatchNorm(2)
-                hidden1 = sync_batch_norm(x)
-                print(hidden1)
+        x = np.array([[[[0.3, 0.4], [0.3, 0.07]],
+                       [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
+        x = paddle.to_tensor(x)
+        sparse_x = x.to_sparse_coo(len(x.shape) - 1)
+
+        if paddle.is_compiled_with_cuda():
+            sparse_sync_bn = nn.SyncBatchNorm(2)
+            sparse_hidden = sparse_sync_bn(sparse_x)
+
+            dense_sync_bn = paddle.nn.SyncBatchNorm(2)
+            x = x.reshape((-1, x.shape[-1]))
+            dense_hidden = dense_sync_bn(x)
+            assert np.allclose(sparse_hidden.values().numpy(),
+                               dense_hidden.numpy())
     def test_convert(self):
         base_model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5),
...
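The assertions added in this file all rely on the same equivalence: converting a fully dense NDHWC tensor to COO format with sparse_dim=4 keeps the channel axis dense, so values() on the sparse tensor is exactly the (num_points, channels) view of the original data, and sparse BatchNorm must agree with dense BatchNorm1D applied to that view. A minimal standalone sketch of this check (illustrative only, assuming a Paddle build that ships the paddle.incubate.sparse API used in this diff):

import numpy as np
import paddle

paddle.seed(123)
channels = 3
dense_x = paddle.randn((1, 6, 6, 6, channels))

# sparse_dim=4 keeps the trailing channel axis dense; since randn output
# has no zeros, values() holds every point as a (num_points, channels) view.
sparse_x = dense_x.to_sparse_coo(4)

sparse_out = paddle.incubate.sparse.nn.BatchNorm(channels)(sparse_x)
dense_out = paddle.nn.BatchNorm1D(channels)(dense_x.reshape((-1, channels)))

# In training mode both layers normalize over the same set of points,
# so the outputs should match up to floating-point tolerance.
assert np.allclose(dense_out.numpy(), sparse_out.values().numpy(),
                   atol=1e-5, rtol=1e-5)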
@@ -229,6 +229,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
     Shapes:
         input: Tensor whose dimension is from 2 to 5.
         output: Tensor with the same shape as the input.
+
     Examples:
@@ -278,7 +279,7 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):

     @classmethod
     def convert_sync_batchnorm(cls, layer):
-        """
+        r"""
         Helper function to convert :class:`paddle.incubate.sparse.nn.BatchNorm` layers in the model to :class:`paddle.incubate.sparse.nn.SyncBatchNorm` layers.

         Parameters:
@@ -290,13 +291,14 @@ class SyncBatchNorm(paddle.nn.SyncBatchNorm):
         Examples:
             .. code-block:: python

                 import paddle
                 import paddle.incubate.sparse.nn as nn

                 model = paddle.nn.Sequential(nn.Conv3D(3, 5, 3), nn.BatchNorm(5))
                 sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
         """
         layer_output = layer
         if isinstance(layer, _BatchNormBase):
             if layer._weight_attr != None and not isinstance(
...
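The quote change from """ to r""" above is the usual raw-string docstring fix: Sphinx and LaTeX markup in docstrings often contains backslashes (for example \epsilon in math), which a plain string literal would interpret as escape sequences. A tiny generic illustration (plain Python, not part of this diff):

# In a plain string literal the backslash escape is processed; in a raw
# string the backslash survives, which is what backslash-bearing
# Sphinx/LaTeX markup inside docstrings needs.
plain = "a\tb"   # 3 characters: 'a', TAB, 'b'
raw = r"a\tb"    # 4 characters: 'a', '\', 't', 'b'
assert len(plain) == 3 and len(raw) == 4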