提交 237385cf 编写于 作者: X xuwei06

Correctly handle int values for assign_value_op

上级 7306aab6
......@@ -148,10 +148,15 @@ def assign(input, output):
dtype = convert_np_dtype_to_dtype_(input.dtype)
if dtype == DataType.FP32:
value_name = "fp32_values"
values = [float(v) for v in input.flat]
elif dtype == DataType.INT32:
value_name = "int32_values"
values = [int(v) for v in input.flat]
else:
raise ValueError("Unsupported dtype %s", input.dtype)
if input.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it")
helper.append_op(
type='assign_value',
......@@ -159,7 +164,7 @@ def assign(input, output):
attrs={
'dtype': dtype,
'shape': list(input.shape),
value_name: [float(v) for v in input.flat]
value_name: values
})
else:
raise ValueError("Wrong type for assign input: %s" % type(input))
......
......@@ -22,16 +22,18 @@ class TestAssignValueOp(op_test.OpTest):
self.check_output()
def test_assign(self):
val = numpy.random.random(size=(2, 5)).astype(numpy.float32)
val = (
-100 + 200 * numpy.random.random(size=(2, 5))).astype(numpy.int32)
x = layers.create_tensor(dtype="float32")
layers.assign(input=val, output=x)
exe = fluid.Executor(fluid.CPUPlace())
fetched_x = exe.run(fluid.default_main_program(),
feed={},
fetch_list=[x])
fetch_list=[x])[0]
self.assertTrue(
numpy.allclose(fetched_x, val),
numpy.array_equal(fetched_x, val),
"fetch_x=%s val=%s" % (fetched_x, val))
self.assertEqual(fetched_x.dtype, val.dtype)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册