提交 99d533d0 编写于 作者: Q Qiao Longfei

update TestNumpyArrayInitializer test=develop

上级 0aecf7c7
......@@ -734,7 +734,7 @@ class NumpyArrayInitializer(Initializer):
outputs={'Out': var},
attrs={
'dtype': dtype,
'shape': list(input.shape),
'shape': list(self._value.shape),
value_name: values
},
stop_gradient=True)
......
......@@ -427,8 +427,8 @@ class TestNumpyArrayInitializer(unittest.TestCase):
import numpy
program = framework.Program()
block = program.global_block()
np_array = numpy.random.random((10000)).astype("float32")
for _ in range(2):
np_array = numpy.array([1, 2, 3, 4]).astype('float32')
block.create_parameter(
dtype=np_array.dtype,
shape=np_array.shape,
......@@ -438,6 +438,7 @@ class TestNumpyArrayInitializer(unittest.TestCase):
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'assign_value')
assert (init_op.attr('fp32_values') == np_array).all()
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册