未验证 提交 dd3ae023 编写于 作者: K kinghuin 提交者: GitHub

optimize compare and logical ops error info, add test case for this ops

* optimize compare and logical ops error info
* add out and cond dtype test
上级 c4979136
...@@ -80,14 +80,16 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -80,14 +80,16 @@ class CompareOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* context) const override { void InferShape(framework::InferShapeContext* context) const override {
OpComment comment; OpComment comment;
PADDLE_ENFORCE(context->HasInput("X"), "%s operator must have input X", OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
comment.type); OP_INOUT_CHECK(context->HasInput("Y"), "Output", "Y", comment.type);
PADDLE_ENFORCE(context->HasInput("Y"), "%s operator must have input Y",
comment.type);
auto dim_x = context->GetInputDim("X"); auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y"); auto dim_y = context->GetInputDim("Y");
PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(), PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
"The size of dim_y should not be greater than dim_x's."); platform::errors::InvalidArgument(
"The size of dim_y should not be greater than "
"dim_x's, but received dim_y: %d > dim_x: %d",
dim_y.size(), dim_x.size()));
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
......
...@@ -79,10 +79,7 @@ class UnaryLogicalOp : public LogicalOp { ...@@ -79,10 +79,7 @@ class UnaryLogicalOp : public LogicalOp {
protected: protected:
void InferShape(framework::InferShapeContext *context) const override { void InferShape(framework::InferShapeContext *context) const override {
OpComment comment; OpComment comment;
PADDLE_ENFORCE_EQ( OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
context->HasInput("X"), true,
platform::errors::NotFound("Input(X) of %s operator must not be null",
comment.type));
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
...@@ -96,10 +93,8 @@ class BinaryLogicalOp : public LogicalOp { ...@@ -96,10 +93,8 @@ class BinaryLogicalOp : public LogicalOp {
protected: protected:
void InferShape(framework::InferShapeContext *context) const override { void InferShape(framework::InferShapeContext *context) const override {
OpComment comment; OpComment comment;
PADDLE_ENFORCE_EQ(context->HasInput("X"), true, OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
"Input(X) of %s operator must not be null", comment.type); OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
"Input(Y) of %s operator must not be null", comment.type);
auto dim_x = context->GetInputDim("X"); auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y"); auto dim_y = context->GetInputDim("Y");
...@@ -107,10 +102,11 @@ class BinaryLogicalOp : public LogicalOp { ...@@ -107,10 +102,11 @@ class BinaryLogicalOp : public LogicalOp {
int product_y = framework::product(dim_y); int product_y = framework::product(dim_y);
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0); bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
if (check) { if (check) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(product_x, product_y,
product_x, product_y, platform::errors::InvalidArgument(
"The number of elements in X and Y should be same, %d != %d", "The number of elements in X and Y should be same, "
product_x, product_y); "but received %d != %d",
product_x, product_y));
} }
context->SetOutputDim("Out", context->GetInputDim("X")); context->SetOutputDim("Out", context->GetInputDim("X"));
......
...@@ -1436,6 +1436,15 @@ def less_than(x, y, force_cpu=None, cond=None): ...@@ -1436,6 +1436,15 @@ def less_than(x, y, force_cpu=None, cond=None):
result_value, = exe.run(fluid.default_main_program(), feed={'x':x_i, 'y':y_i}, fetch_list=[result]) result_value, = exe.run(fluid.default_main_program(), feed={'x':x_i, 'y':y_i}, fetch_list=[result])
print(result_value) # [[True, False], [False, False]] print(result_value) # [[True, False], [False, False]]
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_than")
if cond is not None:
check_type(cond, "cond", Variable, "less_than")
if force_cpu != None:
check_type(force_cpu, "force_cpu", bool, "less_than")
helper = LayerHelper("less_than", **locals()) helper = LayerHelper("less_than", **locals())
if cond is None: if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
...@@ -1480,6 +1489,14 @@ def less_equal(x, y, cond=None): ...@@ -1480,6 +1489,14 @@ def less_equal(x, y, cond=None):
out1 = label<= limit #out1=[True, False] out1 = label<= limit #out1=[True, False]
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"less_equal")
helper = LayerHelper("less_equal", **locals()) helper = LayerHelper("less_equal", **locals())
if cond is None: if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
...@@ -1521,6 +1538,14 @@ def greater_than(x, y, cond=None): ...@@ -1521,6 +1538,14 @@ def greater_than(x, y, cond=None):
out = fluid.layers.greater_than(x=label, y=limit) #out=[False, True] out = fluid.layers.greater_than(x=label, y=limit) #out=[False, True]
out1 = label > limit #out1=[False, True] out1 = label > limit #out1=[False, True]
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"greater_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"greater_than")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"greater_than")
helper = LayerHelper("greater_than", **locals()) helper = LayerHelper("greater_than", **locals())
if cond is None: if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
...@@ -1564,6 +1589,14 @@ def greater_equal(x, y, cond=None): ...@@ -1564,6 +1589,14 @@ def greater_equal(x, y, cond=None):
out_1 = label >= limit #out1=[True, False] out_1 = label >= limit #out1=[True, False]
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"greater_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"greater_equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"greater_equal")
helper = LayerHelper("greater_equal", **locals()) helper = LayerHelper("greater_equal", **locals())
if cond is None: if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
...@@ -1607,6 +1640,14 @@ def equal(x, y, cond=None): ...@@ -1607,6 +1640,14 @@ def equal(x, y, cond=None):
out1 = fluid.layers.equal(x=label,y=limit) #out1=[True, False] out1 = fluid.layers.equal(x=label,y=limit) #out1=[True, False]
out2 = fluid.layers.equal(x=label_cond,y=limit, cond=out_cond) #out2=[False, True] out_cond=[False, True] out2 = fluid.layers.equal(x=label_cond,y=limit, cond=out_cond) #out2=[False, True] out_cond=[False, True]
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"equal")
helper = LayerHelper("equal", **locals()) helper = LayerHelper("equal", **locals())
if cond is None: if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
...@@ -1641,6 +1682,14 @@ def not_equal(x, y, cond=None): ...@@ -1641,6 +1682,14 @@ def not_equal(x, y, cond=None):
limit = fluid.layers.fill_constant(shape=[1], value=1, dtype='int64') limit = fluid.layers.fill_constant(shape=[1], value=1, dtype='int64')
out = fluid.layers.not_equal(x=label, y=limit) out = fluid.layers.not_equal(x=label, y=limit)
""" """
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"not_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"not_equal")
if cond is not None:
check_variable_and_dtype(cond, "cond", [convert_dtype(x.dtype)],
"not_equal")
helper = LayerHelper("not_equal", **locals()) helper = LayerHelper("not_equal", **locals())
if cond is None: if cond is None:
cond = helper.create_variable_for_type_inference(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
......
...@@ -11364,6 +11364,12 @@ Examples: ...@@ -11364,6 +11364,12 @@ Examples:
def _logical_op(op_name, x, y, out=None, name=None, binary_op=True): def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
check_variable_and_dtype(x, "x", ["bool"], op_name)
if y is not None:
check_variable_and_dtype(y, "y", ["bool"], op_name)
if out is not None:
check_variable_and_dtype(out, "out", [convert_dtype(x.dtype)], op_name)
helper = LayerHelper(op_name, **locals()) helper = LayerHelper(op_name, **locals())
if binary_op: if binary_op:
......
...@@ -36,6 +36,26 @@ def create_test_class(op_type, typename, callback): ...@@ -36,6 +36,26 @@ def create_test_class(op_type, typename, callback):
def test_output(self): def test_output(self):
self.check_output() self.check_output()
def test_errors(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[2], dtype='int32')
y = fluid.layers.data(name='y', shape=[2], dtype='int32')
a = fluid.layers.data(name='a', shape=[2], dtype='int16')
b = fluid.layers.data(name='b', shape=[2], dtype='int64')
if self.op_type == "less_than":
self.assertRaises(
TypeError,
fluid.layers.less_than,
x=x,
y=y,
force_cpu=1)
op = eval("fluid.layers.%s" % self.op_type)
self.assertRaises(TypeError, op, x=x, y=y, cond=1)
if self.op_type != "less_than":
self.assertRaises(TypeError, op, x=x, y=y, cond=b)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
cls_name = "{0}_{1}".format(op_type, typename) cls_name = "{0}_{1}".format(op_type, typename)
Cls.__name__ = cls_name Cls.__name__ = cls_name
globals()[cls_name] = Cls globals()[cls_name] = Cls
......
...@@ -17,6 +17,8 @@ from __future__ import print_function ...@@ -17,6 +17,8 @@ from __future__ import print_function
import op_test import op_test
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def create_test_class(op_type, callback, binary_op=True): def create_test_class(op_type, callback, binary_op=True):
...@@ -38,6 +40,22 @@ def create_test_class(op_type, callback, binary_op=True): ...@@ -38,6 +40,22 @@ def create_test_class(op_type, callback, binary_op=True):
def test_output(self): def test_output(self):
self.check_output() self.check_output()
def test_error(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[2], dtype='bool')
y = fluid.layers.data(name='y', shape=[2], dtype='bool')
a = fluid.layers.data(name='a', shape=[2], dtype='int32')
op = eval("fluid.layers.%s" % self.op_type)
if self.op_type != "logical_not":
self.assertRaises(TypeError, op, x=x, y=y, out=1)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
self.assertRaises(TypeError, op, x=x, y=y, out=a)
else:
self.assertRaises(TypeError, op, x=x, out=1)
self.assertRaises(TypeError, op, x=x, out=a)
self.assertRaises(TypeError, op, x=a)
Cls.__name__ = op_type Cls.__name__ = op_type
globals()[op_type] = Cls globals()[op_type] = Cls
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册