未验证 提交 5ebf4078 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

add input type and dtype check for squeeze (#20100)

* Add input check and refine error message

* Refine test case and comments

test=develop
Signed-off-by: Nzhaoyuchen <zhaoyuchen01@baidu.com>
上级 0313b98a
......@@ -35,14 +35,19 @@ class SqueezeOp : public framework::OperatorWithKernel {
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
"ShapeError: the dimensions of Input(X) "
"should be in the range of [1, 6] (Eigen limit)."
"But received X's dimensions = %d, X's shape=[%s].",
x_dims.size(), x_dims);
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The squeeze axis should be less than input "
"tensor's rank.");
PADDLE_ENFORCE_LT(
a, x_dims.size(),
"ShapeError: The squeeze axis should be less than input "
"tensor's dimensions. But received axis = %d, input "
"tensor's dimensions = %d, input tensor's shape = [%s].",
a, x_dims.size(), x_dims);
}
auto out_dims = GetOutputShape(axes, x_dims);
......@@ -73,9 +78,10 @@ class SqueezeOp : public framework::OperatorWithKernel {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE_GE(current, 0,
"Invalid axis, the negative axis is out of range.");
"Invalid axis, the axis should >= 0."
"Current axis is:%d, input tensor's shape = [%s].",
current, in_dims);
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
......@@ -171,14 +177,19 @@ class Squeeze2Op : public framework::OperatorWithKernel {
const auto &x_dims = ctx->GetInputDim("X");
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(), 6,
"Invalid dimnesions, the rank of Input(X) "
"should be in the range of [1, 6] (Eigen limit).");
"ShapeError: the dimensions of Input(X) "
"should be in the range of [1, 6] (Eigen limit)."
"But received X's dimensions = %d, X's shape = [%s].",
x_dims.size(), x_dims);
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
for (int a : axes) {
PADDLE_ENFORCE_LT(a, x_dims.size(),
"The squeeze axis should be less than input "
"tensor's rank.");
PADDLE_ENFORCE_LT(
a, x_dims.size(),
"ShapeError: The squeeze axis should be less than input "
"tensor's dimensions. But received axis = %d, input "
"tensor's dimensions = %d, input tensor's shape = [%s].",
a, x_dims.size(), x_dims);
}
auto out_dims = SqueezeOp::GetOutputShape(axes, x_dims);
......
......@@ -61,13 +61,17 @@ class SqueezeKernel : public framework::OpKernel<T> {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has beed checked in line 36.
PADDLE_ENFORCE_GE(current, 0,
"Invalid axis, the negative axis is out of range.");
"Invalid axis, the axis should >= 0."
"Current axis is:%d, input tensor's shape = [%s].",
current, in_dims);
PADDLE_ENFORCE_EQ(in_dims[current], 1,
"Invalid axis index, the axis that will be squeezed "
"should be equal to 1.");
"should be equal to 1. But current axis = %d,"
"input tensor's shape = [%s].",
in_dims[current], in_dims);
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
......
......@@ -8078,6 +8078,23 @@ def squeeze(input, axes, name=None):
assert not in_dygraph_mode(), (
"squeeze layer is not supported in dygraph mode yet.")
helper = LayerHelper("squeeze", **locals())
if not isinstance(input, Variable):
raise TypeError(
"The type of 'input' in squeeze must be Variable, but received %s" %
(type(input)))
if convert_dtype(input.dtype
) not in ['float32', 'float64', 'int8', 'int32', 'int64']:
raise TypeError(
"The data type of 'input' in squeeze must be float32, float64, int8, int32,"
"int64, but received %s." % (convert_dtype(input.dtype)))
if not isinstance(axes, list):
raise TypeError(
"The type of 'axes' in squeeze must be list, but received %s" %
(type(axes)))
out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from op_test import OpTest
......@@ -68,5 +70,20 @@ class TestSqueezeOp3(TestSqueezeOp):
self.new_shape = (3, 5, 1, 4)
class TestSqueezeOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of softmax_op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.squeeze, x1)
# The input axes of squeeze must be list.
x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
self.assertRaises(TypeError, fluid.layers.squeeze, x2, axes=0)
# The input dtype of squeeze not support float16.
x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
self.assertRaises(TypeError, fluid.layers.squeeze, x3, axes=0)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册