未验证 提交 91f7b4e0 编写于 作者: L Leo Chen 提交者: GitHub

refine as_lodtensor, test=develop (#25286)

* refine as_lodtensor, test=develop

* fix test, test=develop

* add unittest, test=develop

* handle nested_list, test=develop

* handle nested_list, test=develop
上级 0459af5c
......@@ -404,29 +404,35 @@ def _as_lodtensor(data, place, dtype=None):
>>> ...
Args:
data(numpy.ndarray): a instance of array
data(numpy.ndarray|list|tuple|scalar): a instance of array, scalar, list or tuple
data(core.Place): the place of created tensor
dtype(core.VarDesc.VarType): the expected data type of created tensor
dtype(core.VarDesc.VarType|str): the expected data type of created tensor
Returns:
LoDTensor
"""
if isinstance(data, list):
raise RuntimeError("Some of your feed data hold LoD information. \
They can not be completely cast from a list of Python \
ndarray to LoDTensor. Please convert data to LoDTensor \
directly before feeding the data.\
")
#NOTE(zhiqiu): convert python builtin ,like float and int, to numpy array
#NOTE(zhiqiu): convert python builtin, like float, int, and list, to numpy ndarray
if not isinstance(data, np.ndarray):
if np.isscalar(data):
assert dtype is not None, 'dtype should be given when casting python scalar to tensor'
assert dtype is not None, 'The dtype should be given when feed data is not np.ndarray'
dtype = convert_dtype(dtype) if isinstance(
dtype, core.VarDesc.VarType) else dtype
if np.isscalar(data):
data = np.array([data]).astype(dtype)
elif isinstance(data, (list, tuple)):
data = np.array(data)
if data.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"Please consider using 'fluid.create_lod_tensor' to convert it to a LoD-Tensor."
)
data = data.astype(dtype)
else:
raise TypeError(
"Convert data of type {} to Tensor is not supported".format(
type(data)))
# single tensor case
# convert numpy.ndarray to tensor
tensor = core.LoDTensor()
tensor.set(data, place)
return tensor
......
......@@ -84,6 +84,28 @@ class TestExecutor(unittest.TestCase):
self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32)
self.assertEqual(type(a), int)
def test_program_feed_list(self):
main_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
with fluid.scope_guard(scope):
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
lr, cost = self.net()
exe.run(startup_program)
train_data = [[1.0], [2.0], [3.0], [4.0]]
y_true = [[2.0], [4.0], [6.0], [8.0]]
a = 0
_lr, _ = exe.run(feed={'x': train_data,
'y': y_true,
'lr': a},
fetch_list=[lr, cost],
return_numpy=False)
self.assertEqual(_lr._dtype(), lr.dtype)
self.assertEqual(_lr._dtype(), fluid.core.VarDesc.VarType.FP32)
self.assertEqual(type(y_true), list)
def test_compiled_program_feed_scalar(self):
main_program = fluid.Program()
startup_program = fluid.Program()
......@@ -125,10 +147,32 @@ class TestAsLodTensor(unittest.TestCase):
fluid.core.VarDesc.VarType.FP64)
self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64)
def test_as_lodtensor_error(self):
def test_as_lodtensor_assertion_error(self):
cpu = fluid.CPUPlace()
self.assertRaises(AssertionError, fluid.executor._as_lodtensor, 1, cpu)
def test_as_lodtensor_type_error(self):
cpu = fluid.CPUPlace()
self.assertRaises(TypeError, fluid.executor._as_lodtensor, {"a": 1},
cpu, fluid.core.VarDesc.VarType.INT32)
def test_as_lodtensor_list(self):
cpu = fluid.CPUPlace()
tensor = fluid.executor._as_lodtensor([1, 2], cpu,
fluid.core.VarDesc.VarType.FP64)
self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64)
def test_as_lodtensor_tuple(self):
cpu = fluid.CPUPlace()
tensor = fluid.executor._as_lodtensor((1, 2), cpu,
fluid.core.VarDesc.VarType.FP64)
self.assertEqual(tensor._dtype(), fluid.core.VarDesc.VarType.FP64)
def test_as_lodtensor_nested_list(self):
cpu = fluid.CPUPlace()
self.assertRaises(TypeError, fluid.executor._as_lodtensor,
[[1], [1, 2]], cpu, fluid.core.VarDesc.VarType.INT32)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册