未验证 提交 3ea6899f 编写于 作者: W wangchaochaohu 提交者: GitHub

API(fluid.layers.array_read/array_write) error message enhancement (#23568)

上级 eda7ff05
...@@ -1314,7 +1314,15 @@ def array_write(x, i, array=None): ...@@ -1314,7 +1314,15 @@ def array_write(x, i, array=None):
array.append(x) array.append(x)
return array return array
check_variable_and_dtype(i, 'i', ['int64'], 'array_write')
check_type(x, 'x', (Variable), 'array_write')
helper = LayerHelper('array_write', **locals()) helper = LayerHelper('array_write', **locals())
if array is not None:
if not isinstance(
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"array should be tensor array vairable in array_write Op")
if array is None: if array is None:
array = helper.create_variable( array = helper.create_variable(
name="{0}.out".format(helper.name), name="{0}.out".format(helper.name),
...@@ -1686,6 +1694,7 @@ def array_read(array, i): ...@@ -1686,6 +1694,7 @@ def array_read(array, i):
i = i.numpy()[0] i = i.numpy()[0]
return array[i] return array[i]
check_variable_and_dtype(i, 'i', ['int64'], 'array_read')
helper = LayerHelper('array_read', **locals()) helper = LayerHelper('array_read', **locals())
if not isinstance( if not isinstance(
array, array,
......
...@@ -21,6 +21,7 @@ import paddle.fluid.layers as layers ...@@ -21,6 +21,7 @@ import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import default_main_program from paddle.fluid.framework import default_main_program
from paddle.fluid import compiler, Program, program_guard
import numpy import numpy
...@@ -125,5 +126,19 @@ class TestArrayReadWrite(unittest.TestCase): ...@@ -125,5 +126,19 @@ class TestArrayReadWrite(unittest.TestCase):
self.assertAlmostEqual(1.0, g_out_sum_dygraph, delta=0.1) self.assertAlmostEqual(1.0, g_out_sum_dygraph, delta=0.1)
class TestArrayReadWriteOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
#for ci coverage
x1 = numpy.random.randn(2, 4).astype('int32')
x2 = fluid.layers.fill_constant(shape=[1], dtype='int32', value=1)
x3 = numpy.random.randn(2, 4).astype('int32')
self.assertRaises(
TypeError, fluid.layers.array_read, array=x1, i=x2)
self.assertRaises(
TypeError, fluid.layers.array_write, array=x1, i=x2, out=x3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册