From 32c463797d0de1ec2f20aaa299aa29f9d4bc74df Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Sun, 27 Sep 2020 16:53:32 +0800 Subject: [PATCH] add unit test --- .../fluid/tests/unittests/test_initializer.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py index 4c76af616f4..92135b113a0 100644 --- a/python/paddle/fluid/tests/unittests/test_initializer.py +++ b/python/paddle/fluid/tests/unittests/test_initializer.py @@ -17,6 +17,7 @@ from __future__ import print_function import numpy as np import unittest +import paddle import paddle.fluid as fluid import paddle.fluid.framework as framework import paddle.fluid.initializer as initializer @@ -31,6 +32,14 @@ def check_cast_op(op): op.attr('out_dtype') == VarDesc.VarType.FP16 +def output_hist(out): + hist, _ = np.histogram(out, range=(-1, 1)) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.1 * np.ones((10)) + return hist, prob + + class TestConstantInitializer(unittest.TestCase): def test_constant_initializer_default_value(self, dtype="float32"): """Test the constant initializer with default value @@ -583,5 +592,31 @@ class TestSetGlobalInitializer(unittest.TestCase): fluid.set_global_initializer(None) +class TestUniformInitializerDygraph(unittest.TestCase): + def test_uniform_initializer(self, dtype="float32"): + """ + In dygraph mode, we can use initializer directly to initialize a tensor. + """ + paddle.disable_static() + + tensor = paddle.zeros([1024, 1024]) + tensor.stop_gradient = False + self.assertTrue(np.allclose(np.zeros((1024, 1024)), tensor.numpy())) + + uniform_ = paddle.nn.initializer.Uniform() + uniform_(tensor) + + self.assertEqual(tensor.stop_gradient, + False) # stop_gradient is not changed + + hist, prob = output_hist(tensor.numpy()) + + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=1e-3), "hist: " + str(hist)) + + paddle.enable_static() + + if __name__ == '__main__': unittest.main() -- GitLab