From 0aecf7c70e52e99bf7decda820f18039b3f373e6 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Tue, 22 Jan 2019 10:46:48 +0800 Subject: [PATCH] add TestNumpyArrayInitializer --- .../fluid/tests/unittests/test_initializer.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py index ab7183f88df..2e70175d439 100644 --- a/python/paddle/fluid/tests/unittests/test_initializer.py +++ b/python/paddle/fluid/tests/unittests/test_initializer.py @@ -420,5 +420,25 @@ class TestMSRAInitializer(unittest.TestCase): self.assertEqual(init_op.type, 'assign_value') +class TestNumpyArrayInitializer(unittest.TestCase): + def test_numpy_array_initializer(self): + """Test the numpy array initializer with supplied arguments + """ + import numpy + program = framework.Program() + block = program.global_block() + for _ in range(2): + np_array = numpy.array([1, 2, 3, 4]).astype('float32') + block.create_parameter( + dtype=np_array.dtype, + shape=np_array.shape, + lod_level=0, + name="param", + initializer=initializer.NumpyArrayInitializer(np_array)) + self.assertEqual(len(block.ops), 1) + init_op = block.ops[0] + self.assertEqual(init_op.type, 'assign_value') + + if __name__ == '__main__': unittest.main() -- GitLab