diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 76c351835893b1f2f80341e2c13b104dbdf737ea..7d972cbbd09b95e5d7476837cb3f3318526deed8 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -23,6 +23,7 @@ from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar from .tracer import Tracer import logging import objgraph +from ..data_feeder import convert_dtype __all__ = [ 'no_grad', @@ -539,28 +540,34 @@ def grad(outputs, @framework.dygraph_only -def to_variable(value, name=None, zero_copy=None): +def to_variable(value, name=None, zero_copy=None, dtype=None): """ :api_attr: imperative The API will create a ``Variable`` or ``ComplexVariable`` object from - numpy\.ndarray, Variable or ComplexVariable object. + tuple, list, numpy\.ndarray, Variable or ComplexVariable object. Parameters: - value(ndarray|Variable|Tensor|ComplexVariable): The numpy\.ndarray, Variable - Tensor or ComplexVariable object that needs to be converted, it can be - multi-dimension, and the data type is one of numpy\.{float16, - float32, float64, int16, int32, int64, uint8, uint16, complex64, - complex128}. + value(tuple|list|ndarray|Variable|Tensor|ComplexVariable): Initial data. + Can be a list, tuple, NumPy ndarray, Variable, Tensor, ComplexVariable. + The shape can be multi-dimensional. The data type is one of + numpy\.{float16, float32, float64, int16, int32, int64, + uint8, uint16, complex64, complex128}. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . zero_copy(bool, optional): Whether to share memory with the input numpy array. This parameter only works with CPUPlace and will be set to True when it is None. Default: None. + dtype(str, optional): The desired data type of returned ``Variable`` . + Can be 'bool' , 'float16' , 'float32' , 'float64' , 'int8' , 'int16' , + 'int32' , 'int64' , 'uint8' . Default: None. Returns: - Variable or ComplexVariable: If ``value`` is a numpy\.ndarray object, return ``Tensor`` created from the specified numpy\.ndarray object, which has same data type and shape with ``value``. If ``value`` is a Variable or ComplexVariable object, just return ``value``. + Variable or ComplexVariable: If ``value`` is a tuple/list/numpy\.ndarray object, + return ``Tensor`` created from the corresponding numpy\.ndarray object, which has + same data type and shape with ``value``. If ``value`` is a Variable or ComplexVariable + object, just return ``value``. Examples: @@ -582,17 +589,41 @@ def to_variable(value, name=None, zero_copy=None): z = fluid.dygraph.to_variable(c) z.numpy() # array([2.+1.j, 2.+0.j]) z.dtype # 'complex128' + + y = fluid.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]) + y.shape # [3L, 2L] + + y = fluid.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32') + y.shape # [3L, 2L] + """ - if isinstance(value, np.ndarray): - assert framework.in_dygraph_mode( - ), "to_variable could only be called in dygraph mode" + support_type = (list, tuple, np.ndarray, core.VarBase, framework.Variable, + framework.ComplexVariable, core.Tensor, core.LoDTensor) + if not isinstance(value, support_type): + raise TypeError( + "The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s." + % (support_type, type(value))) + if isinstance(value, (core.VarBase, framework.Variable, + framework.ComplexVariable)): + return value + elif isinstance(value, (core.Tensor, core.LoDTensor)): + return core.VarBase(value) + else: if isinstance(framework._current_expected_place(), framework.core.CPUPlace): if zero_copy is None: zero_copy = True else: assert not zero_copy, "zero_copy mode can only be used with CPUPlace" - zero_copy = False + + if not isinstance(value, np.ndarray): + value = np.array(value) + + if dtype is not None: + dtype = convert_dtype(dtype) + if value.dtype != dtype: + value = value.astype(dtype) + if np.iscomplexobj(value): if not name: name = framework.unique_name.generate('_generated_var') @@ -617,12 +648,3 @@ def to_variable(value, name=None, zero_copy=None): zero_copy=zero_copy, name=name if name else '') return py_var - elif isinstance(value, (core.VarBase, framework.Variable, - framework.ComplexVariable)): - return value - elif isinstance(value, (core.Tensor, core.LoDTensor)): - return core.VarBase(value) - else: - raise TypeError( - "The type of input value is invalid, expected type is 'ndarray', " - "'Variable' or 'ComplexVariable', but received %s." % type(value)) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index fcf8bc46f592395275b2a82861c3ba7929a236f7..ea81fcb17c2c967f814a8b67c8c0efef2ae2e9bf 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -47,6 +47,24 @@ class TestVarBase(unittest.TestCase): linear = fluid.dygraph.Linear(32, 64) var = linear._helper.to_variable("test", name="abc") + def test_list_to_variable(self): + with fluid.dygraph.guard(): + array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]] + var = fluid.dygraph.to_variable(array, dtype='int32') + self.assertTrue(np.array_equal(var.numpy(), array)) + self.assertEqual(var.shape, [2, 3, 2]) + self.assertEqual(var.dtype, core.VarDesc.VarType.INT32) + self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) + + def test_tuple_to_variable(self): + with fluid.dygraph.guard(): + array = (((1, 2), (1, 2), (1, 2)), ((1, 2), (1, 2), (1, 2))) + var = fluid.dygraph.to_variable(array, dtype='float32') + self.assertTrue(np.array_equal(var.numpy(), array)) + self.assertEqual(var.shape, [2, 3, 2]) + self.assertEqual(var.dtype, core.VarDesc.VarType.FP32) + self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR) + def test_tensor_to_variable(self): with fluid.dygraph.guard(): t = fluid.Tensor()