未验证 提交 896bda0c 编写于 作者: H HappyAngel 提交者: GitHub

python API(get_tensor_from_selected_rows and tensor_array_to_tensor)error...

python API(get_tensor_from_selected_rows and tensor_array_to_tensor)error message enhance, test=develop (#23636)
上级 9b3086cf
...@@ -12586,6 +12586,11 @@ def get_tensor_from_selected_rows(x, name=None): ...@@ -12586,6 +12586,11 @@ def get_tensor_from_selected_rows(x, name=None):
out = fluid.layers.get_tensor_from_selected_rows(input) out = fluid.layers.get_tensor_from_selected_rows(input)
""" """
check_type(x, 'x', Variable, 'get_tensor_from_selected_rows')
if x.type != core.VarDesc.VarType.SELECTED_ROWS:
raise TypeError(
"The type of 'x' in get_tensor_from_selected_rows must be SELECTED_ROWS."
)
helper = LayerHelper('get_tensor_from_selected_rows', **locals()) helper = LayerHelper('get_tensor_from_selected_rows', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
......
...@@ -440,6 +440,11 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False): ...@@ -440,6 +440,11 @@ def tensor_array_to_tensor(input, axis=1, name=None, use_stack=False):
numpy.array(list(map(lambda x: int(x.shape[axis]), input)))) numpy.array(list(map(lambda x: int(x.shape[axis]), input))))
return res, sizes return res, sizes
check_type(input, 'input', (list, Variable), 'tensor_array_to_tensor')
if isinstance(input, list):
for i, input_x in enumerate(input):
check_type(input_x, 'input[' + str(i) + ']', Variable,
'tensor_array_to_tensor')
helper = LayerHelper('tensor_array_to_tensor', **locals()) helper = LayerHelper('tensor_array_to_tensor', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
out_index = helper.create_variable_for_type_inference(dtype="int32") out_index = helper.create_variable_for_type_inference(dtype="int32")
......
...@@ -17,7 +17,28 @@ from __future__ import print_function ...@@ -17,7 +17,28 @@ from __future__ import print_function
import unittest import unittest
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
import paddle.fluid as fluid
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from paddle.fluid import Program, program_guard
class TestGetTensorFromSelectedRowsError(unittest.TestCase):
"""get_tensor_from_selected_rows error message enhance"""
def test_errors(self):
with program_guard(Program()):
x_var = fluid.data('X', [2, 3])
x_data = np.random.random((2, 4)).astype("float32")
def test_Variable():
fluid.layers.get_tensor_from_selected_rows(x=x_data)
self.assertRaises(TypeError, test_Variable)
def test_SELECTED_ROWS():
fluid.layers.get_tensor_from_selected_rows(x=x_var)
self.assertRaises(TypeError, test_SELECTED_ROWS)
class TestGetTensorFromSelectedRows(unittest.TestCase): class TestGetTensorFromSelectedRows(unittest.TestCase):
......
...@@ -20,6 +20,25 @@ import paddle.fluid as fluid ...@@ -20,6 +20,25 @@ import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.op import Operator from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid import Program, program_guard
class TestTensorArrayToTensorError(unittest.TestCase):
"""Tensor_array_to_tensor error message enhance"""
def test_errors(self):
with program_guard(Program()):
input_data = numpy.random.random((2, 4)).astype("float32")
def test_Variable():
fluid.layers.tensor_array_to_tensor(input=input_data)
self.assertRaises(TypeError, test_Variable)
def test_list_Variable():
fluid.layers.tensor_array_to_tensor(input=[input_data])
self.assertRaises(TypeError, test_list_Variable)
class TestLoDTensorArrayConcat(unittest.TestCase): class TestLoDTensorArrayConcat(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册