未验证 提交 91f7b4e0 编写于 作者: L Leo Chen 提交者: GitHub

refine as_lodtensor, test=develop (#25286)

* refine as_lodtensor, test=develop

* fix test, test=develop

* add unittest, test=develop

* handle nested_list, test=develop

* handle nested_list, test=develop
上级 0459af5c
......@@ -404,29 +404,35 @@ def _as_lodtensor(data, place, dtype=None):
>>> ...
Args:
data(numpy.ndarray): a instance of array
data(numpy.ndarray|list|tuple|scalar): a instance of array, scalar, list or tuple
data(core.Place): the place of created tensor
dtype(core.VarDesc.VarType): the expected data type of created tensor
dtype(core.VarDesc.VarType|str): the expected data type of created tensor
Returns:
LoDTensor
"""
if isinstance(data, list):
raise RuntimeError("Some of your feed data hold LoD information. \
They can not be completely cast from a list of Python \
ndarray to LoDTensor. Please convert data to LoDTensor \
directly before feeding the data.\
")
#NOTE(zhiqiu): convert python builtin ,like float and int, to numpy array
#NOTE(zhiqiu): convert python builtin, like float, int, and list, to numpy ndarray
if not isinstance(data, np.ndarray):
assert dtype is not None, 'The dtype should be given when feed data is not np.ndarray'
dtype = convert_dtype(dtype) if isinstance(
dtype, core.VarDesc.VarType) else dtype
if np.isscalar(data):
assert dtype is not None, 'dtype should be given when casting python scalar to tensor'
dtype = convert_dtype(dtype) if isinstance(
dtype, core.VarDesc.VarType) else dtype
data = np.array([data]).astype(dtype)
elif isinstance(data, (list, tuple)):
data = np.array(data)
if data.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"Please consider using 'fluid.create_lod_tensor' to convert it to a LoD-Tensor."
)
data = data.astype(dtype)
else:
raise TypeError(
"Convert data of type {} to Tensor is not supported".format(
type(data)))
# single tensor case
# convert numpy.ndarray to tensor
tensor = core.LoDTensor()
tensor.set(data, place)
return tensor
......
......@@ -84,6 +84,28 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32)
self.assertEqual(type(a), int)
def test_program_feed_list(self):
main_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
with fluid.scope_guard(scope):
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
lr, cost = self.net()
exe.run(startup_program)
train_data = [[1.0], [2.0], [3.0], [4.0]]
y_true = [[2.0], [4.0], [6.0], [8.0]]
a = 0
_lr, _ = exe.run(feed={'x': train_data,
'y': y_true,
'lr': a},
fetch_list=[lr, cost],
return_numpy=False)
self.assertEqual(_lr._dtype(), lr.dtype)
self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32)
self.assertEqual(type(y_true), list)
def test_compiled_program_feed_scalar(self):
main_program = fluid.Program()
startup_program = fluid.Program()
......@@ -125,10 +147,32 @@ class TestAsLodTensor(unittest.TestCase):
fluid.core.VarDesc.VarType.FP64)
self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64)
def test_as_lodtensor_error(self):
def test_as_lodtensor_assertion_error(self):
cpu = fluid.CPUPlace()
self.assertRaises(AssertionError, fluid.executor._as_lodtensor, 1, cpu)
def test_as_lodtensor_type_error(self):
cpu = fluid.CPUPlace()
self.assertRaises(TypeError, fluid.executor._as_lodtensor, {"a": 1},
cpu, fluid.core.VarDesc.VarType.INT32)
def test_as_lodtensor_list(self):
cpu = fluid.CPUPlace()
tensor = fluid.executor._as_lodtensor([1, 2], cpu,
fluid.core.VarDesc.VarType.FP64)
self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64)
def test_as_lodtensor_tuple(self):
cpu = fluid.CPUPlace()
tensor = fluid.executor._as_lodtensor((1, 2), cpu,
fluid.core.VarDesc.VarType.FP64)
self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64)
def test_as_lodtensor_nested_list(self):
cpu = fluid.CPUPlace()
self.assertRaises(TypeError, fluid.executor._as_lodtensor,
[[1], [1, 2]], cpu, fluid.core.VarDesc.VarType.INT32)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册