未验证 提交 81e8fd4a 编写于 作者: W wangchaochaohu 提交者: GitHub

API(fluid.layers.array_length) error message enhancement (#23547)

上级 1b8fe70e
...@@ -60,8 +60,9 @@ CPU and the length of LoDTensorArray should be used as control variables. ...@@ -60,8 +60,9 @@ CPU and the length of LoDTensorArray should be used as control variables.
class LoDArrayLengthInferShape : public framework::InferShapeBase { class LoDArrayLengthInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *context) const override { void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X")); OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LDArrayLength");
PADDLE_ENFORCE(context->HasOutput("Out")); OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out",
"LoDArrayLength");
context->SetOutputDim("Out", {1}); context->SetOutputDim("Out", {1});
} }
}; };
......
...@@ -1782,12 +1782,19 @@ def array_length(array): ...@@ -1782,12 +1782,19 @@ def array_length(array):
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux, # so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables. # and '__int64' on Windows. They both represent 64-bit integer variables.
""" """
if in_dygraph_mode(): if in_dygraph_mode():
assert isinstance( assert isinstance(
array, array,
list), "The 'array' in array_write must be a list in dygraph mode" list), "The 'array' in array_write must be a list in dygraph mode"
return len(array) return len(array)
if not isinstance(
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"array should be tensor array vairable in array_length Op")
helper = LayerHelper('array_length', **locals()) helper = LayerHelper('array_length', **locals())
tmp = helper.create_variable_for_type_inference(dtype='int64') tmp = helper.create_variable_for_type_inference(dtype='int64')
tmp.stop_gradient = True tmp.stop_gradient = True
......
...@@ -18,6 +18,8 @@ import unittest ...@@ -18,6 +18,8 @@ import unittest
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import numpy import numpy
...@@ -33,5 +35,14 @@ class TestLoDArrayLength(unittest.TestCase): ...@@ -33,5 +35,14 @@ class TestLoDArrayLength(unittest.TestCase):
self.assertEqual(11, result[0]) self.assertEqual(11, result[0])
class TestLoDArrayLengthOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
#for ci coverage
x1 = numpy.random.randn(2, 4).astype('int32')
self.assertRaises(TypeError, fluid.layers.array_length, array=x1)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册