未验证 提交 81e8fd4a 编写于 作者: W wangchaochaohu 提交者: GitHub

API(fluid.layers.array_length) error message enhancement (#23547)

上级 1b8fe70e
......@@ -60,8 +60,9 @@ CPU and the length of LoDTensorArray should be used as control variables.
class LoDArrayLengthInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInput("X"));
PADDLE_ENFORCE(context->HasOutput("Out"));
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "LDArrayLength");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out",
"LoDArrayLength");
context->SetOutputDim("Out", {1});
}
};
......
......@@ -1782,12 +1782,19 @@ def array_length(array):
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables.
"""
if in_dygraph_mode():
assert isinstance(
array,
list), "The 'array' in array_write must be a list in dygraph mode"
return len(array)
if not isinstance(
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"array should be tensor array vairable in array_length Op")
helper = LayerHelper('array_length', **locals())
tmp = helper.create_variable_for_type_inference(dtype='int64')
tmp.stop_gradient = True
......
......@@ -18,6 +18,8 @@ import unittest
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
import numpy
......@@ -33,5 +35,14 @@ class TestLoDArrayLength(unittest.TestCase):
self.assertEqual(11, result[0])
class TestLoDArrayLengthOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
#for ci coverage
x1 = numpy.random.randn(2, 4).astype('int32')
self.assertRaises(TypeError, fluid.layers.array_length, array=x1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册