提交 db61de41 编写于 作者: Z zhupengyang 提交者: Tao Luo

[cherry-pick] python input check for concat op (#20601)

* add input type and dtype check, enhance shape error message for concat_op (#20101)

* add input type and dtype check, enhance shape error message for concat_op
test=develop

* enhance shape check
test=develop

* improve coverage

test=develop

* enhance input type chec for concat (#20584)

test=develop
上级 fed1263c
......@@ -41,8 +41,9 @@ class ConcatOp : public framework::OperatorWithKernel {
static_cast<int64_t>(ins[0].size()));
const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
PADDLE_ENFORCE_GT(n, 0,
"ShapeError: Input tensors count should > 0. But "
"recevied inputs' length is 0.");
if (n == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory";
}
......@@ -66,9 +67,14 @@ class ConcatOp : public framework::OperatorWithKernel {
ctx->IsRuntime() || (out_dims[j] > 0 && ins[i][j] > 0);
if (check_shape) {
// check all shape in run time
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
"Input tensors should have the same "
"elements except the specify axis.");
PADDLE_ENFORCE_EQ(
out_dims[j], ins[i][j],
"ShapeError: Input tensors should have same "
"dimensions(or specific dimension = -1) except the axis. "
"But recevied axis = %s, input[0]'s shape = "
"[%s], input[%s]'s shape = [%s], the \"%s\" "
"dimension of input[%s] is unexpected",
axis, ins[0], i, ins[j], j, i);
}
}
}
......
......@@ -23,6 +23,8 @@ from ..core import VarDesc
from .layer_function_generator import templatedoc
from ..data_feeder import convert_dtype
import numpy
import warnings
from ..data_feeder import convert_dtype
__all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
......@@ -247,6 +249,26 @@ def concat(input, axis=0, name=None):
# [14 15 16]]
"""
helper = LayerHelper('concat', **locals())
if not isinstance(input, list):
warnings.warn(
"The type of input in concat should be list, but received %s." %
(type(input)))
input = [input]
for x in input:
if not isinstance(x, Variable):
raise TypeError(
"The type of x in 'input' in concat must be Variable, but received %s."
% (type(x)))
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of x in 'input' in concat only support float16 on GPU now."
)
if convert_dtype(x.dtype) not in [
'float16', 'float32', 'float64', 'int32', 'int64'
]:
raise TypeError(
"The data type of x in 'input' in concat must be float16(only support on GPU), float32, float64, int32, int64, but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
type='concat',
......
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestConcatOp(OpTest):
......@@ -112,5 +114,27 @@ create_test_fp16(TestConcatOp3)
create_test_fp16(TestConcatOp4)
create_test_fp16(TestConcatOp5)
class TestConcatOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of concat_op should be list.
x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1')
fluid.layers.concat(x1)
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, [x2])
# The input dtype of concat_op must be float16(only support on GPU), float32, float64, int32, int64.
x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
fluid.layers.concat([x6, x7])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册