提交 32c46379 编写于 作者: Z zhiqiu

add unit test

上级 e11e41a6
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
import paddle.fluid.initializer as initializer import paddle.fluid.initializer as initializer
...@@ -31,6 +32,14 @@ def check_cast_op(op): ...@@ -31,6 +32,14 @@ def check_cast_op(op):
op.attr('out_dtype') == VarDesc.VarType.FP16 op.attr('out_dtype') == VarDesc.VarType.FP16
def output_hist(out):
hist, _ = np.histogram(out, range=(-1, 1))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class TestConstantInitializer(unittest.TestCase): class TestConstantInitializer(unittest.TestCase):
def test_constant_initializer_default_value(self, dtype="float32"): def test_constant_initializer_default_value(self, dtype="float32"):
"""Test the constant initializer with default value """Test the constant initializer with default value
...@@ -583,5 +592,31 @@ class TestSetGlobalInitializer(unittest.TestCase): ...@@ -583,5 +592,31 @@ class TestSetGlobalInitializer(unittest.TestCase):
fluid.set_global_initializer(None) fluid.set_global_initializer(None)
class TestUniformInitializerDygraph(unittest.TestCase):
def test_uniform_initializer(self, dtype="float32"):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
paddle.disable_static()
tensor = paddle.zeros([1024, 1024])
tensor.stop_gradient = False
self.assertTrue(np.allclose(np.zeros((1024, 1024)), tensor.numpy()))
uniform_ = paddle.nn.initializer.Uniform()
uniform_(tensor)
self.assertEqual(tensor.stop_gradient,
False) # stop_gradient is not changed
hist, prob = output_hist(tensor.numpy())
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=1e-3), "hist: " + str(hist))
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册