提交 0aecf7c7 编写于 作者: Q Qiao Longfei

add TestNumpyArrayInitializer

上级 a1326cf3
...@@ -420,5 +420,25 @@ class TestMSRAInitializer(unittest.TestCase): ...@@ -420,5 +420,25 @@ class TestMSRAInitializer(unittest.TestCase):
self.assertEqual(init_op.type, 'assign_value') self.assertEqual(init_op.type, 'assign_value')
class TestNumpyArrayInitializer(unittest.TestCase):
def test_numpy_array_initializer(self):
"""Test the numpy array initializer with supplied arguments
"""
import numpy
program = framework.Program()
block = program.global_block()
for _ in range(2):
np_array = numpy.array([1, 2, 3, 4]).astype('float32')
block.create_parameter(
dtype=np_array.dtype,
shape=np_array.shape,
lod_level=0,
name="param",
initializer=initializer.NumpyArrayInitializer(np_array))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'assign_value')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册