提交 5e65c753 编写于 作者: Z zhupengyang 提交者: Tao Luo

enhance input type chec for concat (#20584)

test=develop
上级 443f604c
......@@ -249,10 +249,15 @@ def concat(input, axis=0, name=None):
# [14 15 16]]
"""
helper = LayerHelper('concat', **locals())
if not isinstance(input, list):
warnings.warn(
"The type of input in concat should be list, but received %s." %
(type(input)))
input = [input]
for x in input:
if not isinstance(x, Variable):
raise TypeError(
"The type of x in 'input' in concat must be Variable, but received %s"
"The type of x in 'input' in concat must be Variable, but received %s."
% (type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
......
......@@ -118,13 +118,22 @@ create_test_fp16(TestConcatOp5)
class TestConcatOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of concat_op must be Variable.
x1 = fluid.create_lod_tensor(
# The input type of concat_op should be list.
x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1')
fluid.layers.concat(x1)
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, x1)
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, [x2])
# The input dtype of concat_op must be float16(only support on GPU), float32, float64, int32, int64.
x2 = fluid.layers.data(name='x2', shape=[4], dtype='uint8')
self.assertRaises(TypeError, fluid.layers.concat, x2)
x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
fluid.layers.concat([x6, x7])
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册