未验证 提交 d98e1182 编写于 作者: D danleifeng 提交者: GitHub

fix check and error message for flatten hash is_empty op (#24434)

fix check info for flatten hash is_empty op; test=develop
上级 30efee33
......@@ -29,17 +29,17 @@ class FlattenOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input (X) of Flatten op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Output) of Flatten op should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Flatten");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Flatten");
const auto &axis = ctx->Attrs().Get<int>("axis");
const auto &in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(axis, 0,
"The axis should be greater than or equal to 0.");
platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0."));
PADDLE_ENFORCE_LE(
axis, in_dims.size(),
"The axis should be less than or equal to input tensor's rank.");
platform::errors::InvalidArgument(
"The axis should be less than or equal to input tensor's rank."));
const auto &out_dims = GetOutputShape(axis, in_dims);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
......@@ -161,17 +161,17 @@ class Flatten2Op : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input (X) of Flatten op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output (Output) of Flatten op should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Flatten2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Flatten2");
const auto &axis = ctx->Attrs().Get<int>("axis");
const auto &in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(axis, 0,
"The axis should be greater than or equal to 0.");
platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0."));
PADDLE_ENFORCE_LE(
axis, in_dims.size(),
"The axis should be less than or equal to input tensor's rank.");
platform::errors::InvalidArgument(
"The axis should be less than or equal to input tensor's rank"));
const auto &out_dims = FlattenOp::GetOutputShape(axis, in_dims);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
......@@ -181,8 +181,7 @@ class Flatten2Op : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out");
}
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
"Output (XShape) of Flatten op should not be null.");
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Flatten2");
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
......@@ -223,10 +222,10 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
"Input(XShape) shouldn't be null.");
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
"Input(Out@GRAD) shouldn't be null.");
OP_INOUT_CHECK(context->HasInput("XShape"), "Input", "XShape",
"Flatten2Grad");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "Flatten2Grad");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
......
......@@ -26,14 +26,13 @@ class HashOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of HashOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of HashOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Hash");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Hash");
auto dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dims.size(), 2UL,
"The input of hash_op's dimensions must be 2");
platform::errors::InvalidArgument(
"The input of hash_op's dimensions must be 2"));
std::vector<int64_t> out_dims;
int num_hash = ctx->Attrs().Get<int>("num_hash");
HashOutputSize(dims, out_dims, num_hash);
......
......@@ -25,10 +25,8 @@ class IsEmptyOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of IsEmptyOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of IsEmptyOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "IsEmpty");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "IsEmpty");
ctx->SetOutputDim("Out", {1});
}
......
......@@ -26,7 +26,7 @@ import numpy
import warnings
import six
from functools import reduce, partial
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ... import compat as cpt
from ..backward import _infer_var_data_type_shape_
......@@ -3725,15 +3725,15 @@ def is_empty(x, cond=None):
# fluid.layers.is_empty(x=input, cond=res)
"""
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'is_empty')
check_type(cond, 'cond', (Variable, type(None)), 'is_empty')
helper = LayerHelper("is_empty", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
elif not isinstance(cond, Variable):
raise TypeError("cond takes a variable")
elif cond.dtype != 'bool':
raise TypeError("The data type of cond must be bool")
else:
check_dtype(cond.dtype, 'cond', ['bool'], 'is_empty')
helper.append_op(
type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]})
return cond
......@@ -9628,6 +9628,8 @@ def flatten(x, axis=1, name=None):
out = fluid.layers.flatten(x=x, axis=2)
# out shape is [16, 3]
"""
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64'], 'flatten')
helper = LayerHelper('flatten', **locals())
if not (isinstance(x, Variable)):
......@@ -12466,6 +12468,9 @@ def hash(input, hash_size, num_hash=1, name=None):
# [386]
# [901]]]
"""
check_variable_and_dtype(input, 'input', ['int32', 'int64'], 'hash')
check_type(hash_size, 'hash_size', ['int32', 'int64'], 'hash')
check_type(num_hash, 'num_hash', ['int32', 'int64'], 'hash')
helper = LayerHelper('hash', **locals())
out = helper.create_variable_for_type_inference(
helper.input_dtype(), stop_gradient=True)
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest
......@@ -69,5 +69,25 @@ class TestFlattenOpSixDims(TestFlattenOp):
self.new_shape = (36, 16)
class TestFlatten2OpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data = np.random.random((3, 2, 4, 5)).astype("float64")
def test_Variable():
# the input type must be Variable
fluid.layers.flatten(input_data, axis=1)
self.assertRaises(TypeError, test_Variable)
def test_type():
# dtype must be float32, float64, int8, int32, int64.
x2 = fluid.layers.data(
name='x2', shape=[3, 2, 4, 5], dtype='float16')
fluid.layers.flatten(x2, axis=1)
self.assertRaises(TypeError, test_type)
if __name__ == "__main__":
unittest.main()
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
class TestHashOp(OpTest):
......@@ -102,5 +103,41 @@ class TestHashOp3(TestHashOp):
self.check_output()
class TestHashOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data = np.random.randint(0, 10, (8, 1)).astype("int32")
def test_Variable():
# the input type must be Variable
fluid.layers.hash(input=input_data, hash_size=2**32)
self.assertRaises(TypeError, test_Variable)
def test_type():
# dtype must be int32, int64.
x2 = fluid.layers.data(
name='x2', shape=[1], dtype="float32", lod_level=1)
fluid.layers.hash(input=x2, hash_size=2**32)
self.assertRaises(TypeError, test_type)
def test_hash_size_type():
# hash_size dtype must be int32, int64.
x3 = fluid.layers.data(
name='x3', shape=[1], dtype="int32", lod_level=1)
fluid.layers.hash(input=x3, hash_size=1024.5)
self.assertRaises(TypeError, test_hash_size_type)
def test_num_hash_type():
# num_hash dtype must be int32, int64.
x4 = fluid.layers.data(
name='x4', shape=[1], dtype="int32", lod_level=1)
fluid.layers.hash(input=x4, hash_size=2**32, num_hash=2.5)
self.assertRaises(TypeError, test_num_hash_type)
if __name__ == "__main__":
unittest.main()
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
class TestEmpty(OpTest):
......@@ -36,5 +37,42 @@ class TestNotEmpty(TestEmpty):
self.outputs = {'Out': np.array([True])}
class TestIsEmptyOpError(unittest.TestCase):
def test_errors(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input_data = np.random.random((3, 2)).astype("float64")
def test_Variable():
# the input type must be Variable
fluid.layers.is_empty(x=input_data)
self.assertRaises(TypeError, test_Variable)
def test_cond_Variable():
# cond type must be Variable or None
x2 = fluid.layers.data(name="x2", shape=[3, 2], dtype="float32")
cond_data = np.random.random((3, 2)).astype("float32")
fluid.layers.is_empty(x=x2, cond=cond_data)
self.assertRaises(TypeError, test_cond_Variable)
def test_type():
# dtype must be float32, float64, int32, int64
x3 = fluid.layers.data(
name="x3", shape=[4, 32, 32], dtype="bool")
res = fluid.layers.is_empty(x=x3)
self.assertRaises(TypeError, test_type)
def test_cond_type():
# cond dtype must be bool.
x4 = fluid.layers.data(name="x4", shape=[3, 2], dtype="float32")
cond = fluid.layers.data(
name="cond", shape=[1], dtype="float32")
fluid.layers.is_empty(x=x4, cond=cond)
self.assertRaises(TypeError, test_cond_type)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册