未验证 提交 dd3ae023 编写于 作者: K kinghuin 提交者: GitHub

optimize compare and logical ops error info, add test case for this ops

* optimize compare and logical ops error info
* add out and cond dtype test
上级 c4979136
......@@ -80,14 +80,16 @@ class CompareOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* context) const override {
OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"), "%s operator must have input X",
comment.type);
PADDLE_ENFORCE(context->HasInput("Y"), "%s operator must have input Y",
comment.type);
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
OP_INOUT_CHECK(context->HasInput("Y"), "Output", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
"The size of dim_y should not be greater than dim_x's.");
platform::errors::InvalidArgument(
"The size of dim_y should not be greater than "
"dim_x's, but received dim_y: %d > dim_x: %d",
dim_y.size(), dim_x.size()));
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
......
......@@ -79,10 +79,7 @@ class UnaryLogicalOp : public LogicalOp {
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE_EQ(
context->HasInput("X"), true,
platform::errors::NotFound("Input(X) of %s operator must not be null",
comment.type));
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
......@@ -96,10 +93,8 @@ class BinaryLogicalOp : public LogicalOp {
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
"Input(X) of %s operator must not be null", comment.type);
PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
"Input(Y) of %s operator must not be null", comment.type);
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
......@@ -107,10 +102,11 @@ class BinaryLogicalOp : public LogicalOp {
int product_y = framework::product(dim_y);
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
if (check) {
PADDLE_ENFORCE_EQ(
product_x, product_y,
"The number of elements in X and Y should be same, %d != %d",
product_x, product_y);
PADDLE_ENFORCE_EQ(product_x, product_y,
platform::errors::InvalidArgument(
"The number of elements in X and Y should be same, "
"but received %d != %d",
product_x, product_y));
}
context->SetOutputDim("Out", context->GetInputDim("X"));
......
......@@ -1436,6 +1436,15 @@ def less_than(x, y, force_cpu=None, cond=None):
result_value, = exe.run(fluid.default_main_program(), feed={'x':x_i, 'y':y_i}, fetch_list=[result])
print(result_value) # [[True, False], [False, False]]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_than")
if cond is not None:
check_type(cond, "cond", Variable, "less_than")
if force_cpu != None:
check_type(force_cpu, "force_cpu", bool, "less_than")
helper = LayerHelper("less_than", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
......@@ -1480,6 +1489,14 @@ def less_equal(x, y, cond=None):
out1 = label<= limit #out1=[True, False]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"less_equal")
helper = LayerHelper("less_equal", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
......@@ -1521,6 +1538,14 @@ def greater_than(x, y, cond=None):
out = fluid.layers.greater_than(x=label, y=limit) #out=[False, True]
out1 = label > limit #out1=[False, True]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"greater_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"greater_than")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"greater_than")
helper = LayerHelper("greater_than", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
......@@ -1564,6 +1589,14 @@ def greater_equal(x, y, cond=None):
out_1 = label >= limit #out1=[True, False]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"greater_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"greater_equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"greater_equal")
helper = LayerHelper("greater_equal", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
......@@ -1607,6 +1640,14 @@ def equal(x, y, cond=None):
out1 = fluid.layers.equal(x=label,y=limit) #out1=[True, False]
out2 = fluid.layers.equal(x=label_cond,y=limit, cond=out_cond) #out2=[False, True] out_cond=[False, True]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"equal")
helper = LayerHelper("equal", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
......@@ -1641,6 +1682,14 @@ def not_equal(x, y, cond=None):
limit = fluid.layers.fill_constant(shape=[1], value=1, dtype='int64')
out = fluid.layers.not_equal(x=label, y=limit)
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"not_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"not_equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"not_equal")
helper = LayerHelper("not_equal", **locals())
if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool')
......
......@@ -11364,6 +11364,12 @@ Examples:
def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
check_variable_and_dtype(x, "x", ["bool"], op_name)
if y is not None:
check_variable_and_dtype(y, "y", ["bool"], op_name)
if out is not None:
check_variable_and_dtype(out, "out", [convert_dtype(x.dtype)], op_name)
helper = LayerHelper(op_name, **locals())
if binary_op:
......
......@@ -36,6 +36,26 @@ def create_test_class(op_type, typename, callback):
def test_output(self):
self.check_output()
def test_errors(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[2], dtype='int32')
y = fluid.layers.data(name='y', shape=[2], dtype='int32')
a = fluid.layers.data(name='a', shape=[2], dtype='int16')
b = fluid.layers.data(name='b', shape=[2], dtype='int64')
if self.op_type == "less_than":
self.assertRaises(
TypeError,
fluid.layers.less_than,
x=x,
y=y,
force_cpu=1)
op = eval("fluid.layers.%s" % self.op_type)
self.assertRaises(TypeError, op, x=x, y=y, cond=1)
if self.op_type != "less_than":
self.assertRaises(TypeError, op, x=x, y=y, cond=b)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
cls_name = "{0}_{1}".format(op_type, typename)
Cls.__name__ = cls_name
globals()[cls_name] = Cls
......
......@@ -17,6 +17,8 @@ from __future__ import print_function
import op_test
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def create_test_class(op_type, callback, binary_op=True):
......@@ -38,6 +40,22 @@ def create_test_class(op_type, callback, binary_op=True):
def test_output(self):
self.check_output()
def test_error(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[2], dtype='bool')
y = fluid.layers.data(name='y', shape=[2], dtype='bool')
a = fluid.layers.data(name='a', shape=[2], dtype='int32')
op = eval("fluid.layers.%s" % self.op_type)
if self.op_type != "logical_not":
self.assertRaises(TypeError, op, x=x, y=y, out=1)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
self.assertRaises(TypeError, op, x=x, y=y, out=a)
else:
self.assertRaises(TypeError, op, x=x, out=1)
self.assertRaises(TypeError, op, x=x, out=a)
self.assertRaises(TypeError, op, x=a)
Cls.__name__ = op_type
globals()[op_type] = Cls
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册