提交 237385cf 编写于 作者: X xuwei06

Correctly handle int values for assign_value_op

上级 7306aab6
...@@ -148,10 +148,15 @@ def assign(input, output): ...@@ -148,10 +148,15 @@ def assign(input, output):
dtype = convert_np_dtype_to_dtype_(input.dtype) dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == DataType.FP32: if dtype == DataType.FP32:
value_name = "fp32_values" value_name = "fp32_values"
values = [float(v) for v in input.flat]
elif dtype == DataType.INT32: elif dtype == DataType.INT32:
value_name = "int32_values" value_name = "int32_values"
values = [int(v) for v in input.flat]
else: else:
raise ValueError("Unsupported dtype %s", input.dtype) raise ValueError("Unsupported dtype %s", input.dtype)
if input.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it")
helper.append_op( helper.append_op(
type='assign_value', type='assign_value',
...@@ -159,7 +164,7 @@ def assign(input, output): ...@@ -159,7 +164,7 @@ def assign(input, output):
attrs={ attrs={
'dtype': dtype, 'dtype': dtype,
'shape': list(input.shape), 'shape': list(input.shape),
value_name: [float(v) for v in input.flat] value_name: values
}) })
else: else:
raise ValueError("Wrong type for assign input: %s" % type(input)) raise ValueError("Wrong type for assign input: %s" % type(input))
......
...@@ -22,16 +22,18 @@ class TestAssignValueOp(op_test.OpTest): ...@@ -22,16 +22,18 @@ class TestAssignValueOp(op_test.OpTest):
self.check_output() self.check_output()
def test_assign(self): def test_assign(self):
val = numpy.random.random(size=(2, 5)).astype(numpy.float32) val = (
-100 + 200 * numpy.random.random(size=(2, 5))).astype(numpy.int32)
x = layers.create_tensor(dtype="float32") x = layers.create_tensor(dtype="float32")
layers.assign(input=val, output=x) layers.assign(input=val, output=x)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
fetched_x = exe.run(fluid.default_main_program(), fetched_x = exe.run(fluid.default_main_program(),
feed={}, feed={},
fetch_list=[x]) fetch_list=[x])[0]
self.assertTrue( self.assertTrue(
numpy.allclose(fetched_x, val), numpy.array_equal(fetched_x, val),
"fetch_x=%s val=%s" % (fetched_x, val)) "fetch_x=%s val=%s" % (fetched_x, val))
self.assertEqual(fetched_x.dtype, val.dtype)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册