未验证 提交 8de89e67 编写于 作者: Z Zhou Wei 提交者: GitHub

support tuple/list init for VarBase (#25231)

* support tuple/list init for VarBase,test=develop

* fix doc of fluid.dygraph.to_variable,test=develop

* fix doc of fluid.dygraph.to_variable,test=develop
上级 f07b25d8
......@@ -23,6 +23,7 @@ from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
from .tracer import Tracer
import logging
import objgraph
from ..data_feeder import convert_dtype
__all__ = [
'no_grad',
......@@ -539,28 +540,34 @@ def grad(outputs,
@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None):
def to_variable(value, name=None, zero_copy=None, dtype=None):
"""
:api_attr: imperative
The API will create a ``Variable`` or ``ComplexVariable`` object from
numpy\.ndarray, Variable or ComplexVariable object.
tuple, list, numpy\.ndarray, Variable or ComplexVariable object.
Parameters:
value(ndarray|Variable|Tensor|ComplexVariable): The numpy\.ndarray, Variable
Tensor or ComplexVariable object that needs to be converted, it can be
multi-dimension, and the data type is one of numpy\.{float16,
float32, float64, int16, int32, int64, uint8, uint16, complex64,
complex128}.
value(tuple|list|ndarray|Variable|Tensor|ComplexVariable): Initial data.
Can be a list, tuple, NumPy ndarray, Variable, Tensor, ComplexVariable.
The shape can be multi-dimensional. The data type is one of
numpy\.{float16, float32, float64, int16, int32, int64,
uint8, uint16, complex64, complex128}.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name` .
zero_copy(bool, optional): Whether to share memory with the input numpy
array. This parameter only works with CPUPlace and will be set to
True when it is None. Default: None.
dtype(str, optional): The desired data type of returned ``Variable`` .
Can be 'bool' , 'float16' , 'float32' , 'float64' , 'int8' , 'int16' ,
'int32' , 'int64' , 'uint8' . Default: None.
Returns:
Variable or ComplexVariable: If ``value`` is a numpy\.ndarray object, return ``Tensor`` created from the specified numpy\.ndarray object, which has same data type and shape with ``value``. If ``value`` is a Variable or ComplexVariable object, just return ``value``.
Variable or ComplexVariable: If ``value`` is a tuple/list/numpy\.ndarray object,
return ``Tensor`` created from the corresponding numpy\.ndarray object, which has
same data type and shape with ``value``. If ``value`` is a Variable or ComplexVariable
object, just return ``value``.
Examples:
......@@ -582,17 +589,41 @@ def to_variable(value, name=None, zero_copy=None):
z = fluid.dygraph.to_variable(c)
z.numpy() # array([2.+1.j, 2.+0.j])
z.dtype # 'complex128'
y = fluid.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
y.shape # [3L, 2L]
y = fluid.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32')
y.shape # [3L, 2L]
"""
if isinstance(value, np.ndarray):
assert framework.in_dygraph_mode(
), "to_variable could only be called in dygraph mode"
support_type = (list, tuple, np.ndarray, core.VarBase, framework.Variable,
framework.ComplexVariable, core.Tensor, core.LoDTensor)
if not isinstance(value, support_type):
raise TypeError(
"The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s."
% (support_type, type(value)))
if isinstance(value, (core.VarBase, framework.Variable,
framework.ComplexVariable)):
return value
elif isinstance(value, (core.Tensor, core.LoDTensor)):
return core.VarBase(value)
else:
if isinstance(framework._current_expected_place(),
framework.core.CPUPlace):
if zero_copy is None:
zero_copy = True
else:
assert not zero_copy, "zero_copy mode can only be used with CPUPlace"
zero_copy = False
if not isinstance(value, np.ndarray):
value = np.array(value)
if dtype is not None:
dtype = convert_dtype(dtype)
if value.dtype != dtype:
value = value.astype(dtype)
if np.iscomplexobj(value):
if not name:
name = framework.unique_name.generate('_generated_var')
......@@ -617,12 +648,3 @@ def to_variable(value, name=None, zero_copy=None):
zero_copy=zero_copy,
name=name if name else '')
return py_var
elif isinstance(value, (core.VarBase, framework.Variable,
framework.ComplexVariable)):
return value
elif isinstance(value, (core.Tensor, core.LoDTensor)):
return core.VarBase(value)
else:
raise TypeError(
"The type of input value is invalid, expected type is 'ndarray', "
"'Variable' or 'ComplexVariable', but received %s." % type(value))
......@@ -47,6 +47,24 @@ class TestVarBase(unittest.TestCase):
linear = fluid.dygraph.Linear(32, 64)
var = linear._helper.to_variable("test", name="abc")
def test_list_to_variable(self):
with fluid.dygraph.guard():
array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]]
var = fluid.dygraph.to_variable(array, dtype='int32')
self.assertTrue(np.array_equal(var.numpy(), array))
self.assertEqual(var.shape, [2, 3, 2])
self.assertEqual(var.dtype, core.VarDesc.VarType.INT32)
self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR)
def test_tuple_to_variable(self):
with fluid.dygraph.guard():
array = (((1, 2), (1, 2), (1, 2)), ((1, 2), (1, 2), (1, 2)))
var = fluid.dygraph.to_variable(array, dtype='float32')
self.assertTrue(np.array_equal(var.numpy(), array))
self.assertEqual(var.shape, [2, 3, 2])
self.assertEqual(var.dtype, core.VarDesc.VarType.FP32)
self.assertEqual(var.type, core.VarDesc.VarType.LOD_TENSOR)
def test_tensor_to_variable(self):
with fluid.dygraph.guard():
t = fluid.Tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册