未验证 提交 1f8de080 编写于 作者: W wawltor 提交者: GitHub

add the support for the bool in compare ops

add the support for the bool in compare ops
上级 606939de
......@@ -135,15 +135,17 @@ class CompareReduceOp : public framework::OperatorWithKernel {
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \
REGISTER_OP_CPU_KERNEL( \
op_type, ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int64_t>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<float>>, \
::paddle::operators::CompareReduceOpKernel< \
#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \
REGISTER_OP_CPU_KERNEL( \
op_type, ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<bool>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int64_t>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<float>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y");
......
......@@ -85,15 +85,18 @@ class CompareReduceOpKernel
} // namespace operators
} // namespace paddle
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int64_t>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<float>>, \
paddle::operators::CompareReduceOpKernel< \
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<bool>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int64_t>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<float>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
paddle::operators::EqualReduceFunctor);
......@@ -82,6 +82,7 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<bool>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
......
......@@ -98,6 +98,9 @@ class CompareOpKernel
#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \
REGISTER_OP_##dev##_KERNEL(op_type, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<bool>, inverse_functor<bool>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int>, inverse_functor<int>>, \
......
......@@ -155,6 +155,38 @@ def create_paddle_case(op_type, callback):
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_bool_api_4(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[3, 1], dtype='bool')
y = paddle.static.data(name='y', shape=[3, 1], dtype='bool')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
input_x = np.array([True, False, True]).astype(np.bool)
input_y = np.array([True, True, False]).astype(np.bool)
real_result = callback(input_x, input_y)
res, = exe.run(feed={"x": input_x,
"y": input_y},
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_bool_broadcast_api_4(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[3, 1], dtype='bool')
y = paddle.static.data(name='y', shape=[1], dtype='bool')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
input_x = np.array([True, False, True]).astype(np.bool)
input_y = np.array([True]).astype(np.bool)
real_result = callback(input_x, input_y)
res, = exe.run(feed={"x": input_x,
"y": input_y},
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_attr_name(self):
paddle.enable_static()
with program_guard(Program(), Program()):
......
......@@ -92,9 +92,28 @@ def create_test_dim1_class(op_type, typename, callback):
globals()[cls_name] = Cls
def create_test_dim1_class(op_type, typename, callback):
class Cls(op_test.OpTest):
def setUp(self):
x = y = np.random.random(size=(1)).astype(typename)
x = np.array([True, False, True]).astype(typename)
x = np.array([False, False, True]).astype(typename)
z = callback(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type
def test_output(self):
self.check_output()
cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
Cls.__name__ = cls_name
globals()[cls_name] = Cls
np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))
for _type_name in {'float32', 'float64', 'int32', 'int64'}:
for _type_name in {'float32', 'float64', 'int32', 'int64', 'bool'}:
create_test_not_equal_class('equal_all', _type_name, np_equal)
create_test_equal_class('equal_all', _type_name, np_equal)
create_test_dim1_class('equal_all', _type_name, np_equal)
......@@ -107,6 +126,14 @@ class TestEqualReduceAPI(unittest.TestCase):
out = paddle.equal_all(x, y, name='equal_res')
assert 'equal_res' in out.name
def test_dynamic_api(self):
paddle.disable_static()
x = paddle.ones(shape=[10, 10], dtype="int32")
y = paddle.ones(shape=[10, 10], dtype="int32")
out = paddle.equal_all(x, y)
assert out.numpy()[0] == True
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
......@@ -38,8 +38,8 @@ def equal_all(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): Tensor, data type is float32, float64, int32, int64.
y(Tensor): Tensor, data type is float32, float64, int32, int64.
x(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
y(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -59,6 +59,8 @@ def equal_all(x, y, name=None):
result2 = paddle.equal_all(x, z)
print(result2) # result2 = [False ]
"""
if in_dygraph_mode():
return core.ops.equal_all(x, y)
helper = LayerHelper("equal_all", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
......@@ -152,8 +154,8 @@ def equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): Tensor, data type is float32, float64, int32, int64.
y(Tensor): Tensor, data type is float32, float64, int32, int64.
x(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
y(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -174,10 +176,10 @@ def equal(x, y, name=None):
if in_dygraph_mode():
return core.ops.equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"equal")
check_variable_and_dtype(
x, "x", ["bool", "float32", "float64", "int32", "int64"], "equal")
check_variable_and_dtype(
y, "y", ["bool", "float32", "float64", "int32", "int64"], "equal")
helper = LayerHelper("equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True
......@@ -196,8 +198,8 @@ def greater_equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -216,9 +218,11 @@ def greater_equal(x, y, name=None):
if in_dygraph_mode():
return core.ops.greater_equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
check_variable_and_dtype(x, "x",
["bool", "float32", "float64", "int32", "int64"],
"greater_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
check_variable_and_dtype(y, "y",
["bool", "float32", "float64", "int32", "int64"],
"greater_equal")
helper = LayerHelper("greater_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
......@@ -240,8 +244,8 @@ def greater_than(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -260,9 +264,11 @@ def greater_than(x, y, name=None):
if in_dygraph_mode():
return core.ops.greater_than(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
check_variable_and_dtype(x, "x",
["bool", "float32", "float64", "int32", "int64"],
"greater_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
check_variable_and_dtype(y, "y",
["bool", "float32", "float64", "int32", "int64"],
"greater_than")
helper = LayerHelper("greater_than", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
......@@ -284,8 +290,8 @@ def less_equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -305,10 +311,10 @@ def less_equal(x, y, name=None):
if in_dygraph_mode():
return core.ops.less_equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_equal")
check_variable_and_dtype(
x, "x", ["bool", "float32", "float64", "int32", "int64"], "less_equal")
check_variable_and_dtype(
y, "y", ["bool", "float32", "float64", "int32", "int64"], "less_equal")
helper = LayerHelper("less_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True
......@@ -327,8 +333,8 @@ def less_than(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -348,10 +354,10 @@ def less_than(x, y, name=None):
if in_dygraph_mode():
return core.ops.less_than(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_than")
check_variable_and_dtype(
x, "x", ["bool", "float32", "float64", "int32", "int64"], "less_than")
check_variable_and_dtype(
y, "y", ["bool", "float32", "float64", "int32", "int64"], "less_than")
helper = LayerHelper("less_than", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True
......@@ -370,8 +376,8 @@ def not_equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient.
Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -391,10 +397,10 @@ def not_equal(x, y, name=None):
if in_dygraph_mode():
return core.ops.not_equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"not_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"not_equal")
check_variable_and_dtype(
x, "x", ["bool", "float32", "float64", "int32", "int64"], "not_equal")
check_variable_and_dtype(
y, "y", ["bool", "float32", "float64", "int32", "int64"], "not_equal")
helper = LayerHelper("not_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册