提交 32c46379 编写于 作者: Z zhiqiu

add unit test

上级 e11e41a6
......@@ -17,6 +17,7 @@ from __future__ import print_function
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.initializer as initializer
......@@ -31,6 +32,14 @@ def check_cast_op(op):
op.attr('out_dtype') == VarDesc.VarType.FP16
def output_hist(out):
hist, _ = np.histogram(out, range=(-1, 1))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class TestConstantInitializer(unittest.TestCase):
def test_constant_initializer_default_value(self, dtype="float32"):
"""Test the constant initializer with default value
......@@ -583,5 +592,31 @@ class TestSetGlobalInitializer(unittest.TestCase):
fluid.set_global_initializer(None)
class TestUniformInitializerDygraph(unittest.TestCase):
def test_uniform_initializer(self, dtype="float32"):
"""
In dygraph mode, we can use initializer directly to initialize a tensor.
"""
paddle.disable_static()
tensor = paddle.zeros([1024, 1024])
tensor.stop_gradient = False
self.assertTrue(np.allclose(np.zeros((1024, 1024)), tensor.numpy()))
uniform_ = paddle.nn.initializer.Uniform()
uniform_(tensor)
self.assertEqual(tensor.stop_gradient,
False) # stop_gradient is not changed
hist, prob = output_hist(tensor.numpy())
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=1e-3), "hist: " + str(hist))
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册