From 237385cf414bf2e176a52c46e650e37a2cfc40a7 Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Wed, 10 Jan 2018 11:39:03 -0800 Subject: [PATCH] Correctly handle int values for assign_value_op --- python/paddle/v2/fluid/layers/tensor.py | 7 ++++++- python/paddle/v2/fluid/tests/test_assign_value_op.py | 8 +++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index 639f8b03ede..57668a7983b 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -148,10 +148,15 @@ def assign(input, output): dtype = convert_np_dtype_to_dtype_(input.dtype) if dtype == DataType.FP32: value_name = "fp32_values" + values = [float(v) for v in input.flat] elif dtype == DataType.INT32: value_name = "int32_values" + values = [int(v) for v in input.flat] else: raise ValueError("Unsupported dtype %s", input.dtype) + if input.size > 1024 * 1024: + raise ValueError("The size of input is too big. Please consider " + "saving it to file and 'load_op' to load it") helper.append_op( type='assign_value', @@ -159,7 +164,7 @@ def assign(input, output): attrs={ 'dtype': dtype, 'shape': list(input.shape), - value_name: [float(v) for v in input.flat] + value_name: values }) else: raise ValueError("Wrong type for assign input: %s" % type(input)) diff --git a/python/paddle/v2/fluid/tests/test_assign_value_op.py b/python/paddle/v2/fluid/tests/test_assign_value_op.py index c3f3f87839a..51b99d09182 100644 --- a/python/paddle/v2/fluid/tests/test_assign_value_op.py +++ b/python/paddle/v2/fluid/tests/test_assign_value_op.py @@ -22,16 +22,18 @@ class TestAssignValueOp(op_test.OpTest): self.check_output() def test_assign(self): - val = numpy.random.random(size=(2, 5)).astype(numpy.float32) + val = ( + -100 + 200 * numpy.random.random(size=(2, 5))).astype(numpy.int32) x = layers.create_tensor(dtype="float32") layers.assign(input=val, output=x) exe = fluid.Executor(fluid.CPUPlace()) fetched_x = exe.run(fluid.default_main_program(), feed={}, - fetch_list=[x]) + fetch_list=[x])[0] self.assertTrue( - numpy.allclose(fetched_x, val), + numpy.array_equal(fetched_x, val), "fetch_x=%s val=%s" % (fetched_x, val)) + self.assertEqual(fetched_x.dtype, val.dtype) if __name__ == '__main__': -- GitLab