未验证 提交 1f8de080 编写于 作者: W wawltor 提交者: GitHub

add the support for the bool in compare ops

add the support for the bool in compare ops
上级 606939de
...@@ -135,15 +135,17 @@ class CompareReduceOp : public framework::OperatorWithKernel { ...@@ -135,15 +135,17 @@ class CompareReduceOp : public framework::OperatorWithKernel {
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \ ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \ #define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor) \
REGISTER_OP_CPU_KERNEL( \ REGISTER_OP_CPU_KERNEL( \
op_type, ::paddle::operators::CompareReduceOpKernel< \ op_type, ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int>>, \ ::paddle::platform::CPUDeviceContext, functor<bool>>, \
::paddle::operators::CompareReduceOpKernel< \ ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<int64_t>>, \ ::paddle::platform::CPUDeviceContext, functor<int>>, \
::paddle::operators::CompareReduceOpKernel< \ ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<float>>, \ ::paddle::platform::CPUDeviceContext, functor<int64_t>>, \
::paddle::operators::CompareReduceOpKernel< \ ::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<float>>, \
::paddle::operators::CompareReduceOpKernel< \
::paddle::platform::CPUDeviceContext, functor<double>>); ::paddle::platform::CPUDeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y"); REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y");
......
...@@ -85,15 +85,18 @@ class CompareReduceOpKernel ...@@ -85,15 +85,18 @@ class CompareReduceOpKernel
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \ #define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \ REGISTER_OP_CUDA_KERNEL( \
op_type, paddle::operators::CompareReduceOpKernel< \ op_type, paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int>>, \ paddle::platform::CUDADeviceContext, functor<bool>>, \
paddle::operators::CompareReduceOpKernel< \ paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<int64_t>>, \ paddle::platform::CUDADeviceContext, functor<int>>, \
paddle::operators::CompareReduceOpKernel< \ paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<float>>, \ paddle::platform::CUDADeviceContext, functor<int64_t>>, \
paddle::operators::CompareReduceOpKernel< \ paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<float>>, \
paddle::operators::CompareReduceOpKernel< \
paddle::platform::CUDADeviceContext, functor<double>>); paddle::platform::CUDADeviceContext, functor<double>>);
REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all, REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
paddle::operators::EqualReduceFunctor); paddle::operators::EqualReduceFunctor);
...@@ -82,6 +82,7 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor> ...@@ -82,6 +82,7 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \ #define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \ REGISTER_OP_CUDA_KERNEL( \
op_type, \ op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<bool>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \ ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
......
...@@ -98,6 +98,9 @@ class CompareOpKernel ...@@ -98,6 +98,9 @@ class CompareOpKernel
#define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \ #define REGISTER_COMPARE_KERNEL(op_type, dev, functor, inverse_functor) \
REGISTER_OP_##dev##_KERNEL(op_type, \ REGISTER_OP_##dev##_KERNEL(op_type, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<bool>, inverse_functor<bool>>, \
::paddle::operators::CompareOpKernel< \ ::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \ ::paddle::platform::dev##DeviceContext, \
functor<int>, inverse_functor<int>>, \ functor<int>, inverse_functor<int>>, \
......
...@@ -155,6 +155,38 @@ def create_paddle_case(op_type, callback): ...@@ -155,6 +155,38 @@ def create_paddle_case(op_type, callback):
fetch_list=[out]) fetch_list=[out])
self.assertEqual((res == real_result).all(), True) self.assertEqual((res == real_result).all(), True)
def test_bool_api_4(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[3, 1], dtype='bool')
y = paddle.static.data(name='y', shape=[3, 1], dtype='bool')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
input_x = np.array([True, False, True]).astype(np.bool)
input_y = np.array([True, True, False]).astype(np.bool)
real_result = callback(input_x, input_y)
res, = exe.run(feed={"x": input_x,
"y": input_y},
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_bool_broadcast_api_4(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[3, 1], dtype='bool')
y = paddle.static.data(name='y', shape=[1], dtype='bool')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = paddle.static.Executor(self.place)
input_x = np.array([True, False, True]).astype(np.bool)
input_y = np.array([True]).astype(np.bool)
real_result = callback(input_x, input_y)
res, = exe.run(feed={"x": input_x,
"y": input_y},
fetch_list=[out])
self.assertEqual((res == real_result).all(), True)
def test_attr_name(self): def test_attr_name(self):
paddle.enable_static() paddle.enable_static()
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
...@@ -92,9 +92,28 @@ def create_test_dim1_class(op_type, typename, callback): ...@@ -92,9 +92,28 @@ def create_test_dim1_class(op_type, typename, callback):
globals()[cls_name] = Cls globals()[cls_name] = Cls
def create_test_dim1_class(op_type, typename, callback):
class Cls(op_test.OpTest):
def setUp(self):
x = y = np.random.random(size=(1)).astype(typename)
x = np.array([True, False, True]).astype(typename)
x = np.array([False, False, True]).astype(typename)
z = callback(x, y)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': z}
self.op_type = op_type
def test_output(self):
self.check_output()
cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
Cls.__name__ = cls_name
globals()[cls_name] = Cls
np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y)) np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))
for _type_name in {'float32', 'float64', 'int32', 'int64'}: for _type_name in {'float32', 'float64', 'int32', 'int64', 'bool'}:
create_test_not_equal_class('equal_all', _type_name, np_equal) create_test_not_equal_class('equal_all', _type_name, np_equal)
create_test_equal_class('equal_all', _type_name, np_equal) create_test_equal_class('equal_all', _type_name, np_equal)
create_test_dim1_class('equal_all', _type_name, np_equal) create_test_dim1_class('equal_all', _type_name, np_equal)
...@@ -107,6 +126,14 @@ class TestEqualReduceAPI(unittest.TestCase): ...@@ -107,6 +126,14 @@ class TestEqualReduceAPI(unittest.TestCase):
out = paddle.equal_all(x, y, name='equal_res') out = paddle.equal_all(x, y, name='equal_res')
assert 'equal_res' in out.name assert 'equal_res' in out.name
def test_dynamic_api(self):
paddle.disable_static()
x = paddle.ones(shape=[10, 10], dtype="int32")
y = paddle.ones(shape=[10, 10], dtype="int32")
out = paddle.equal_all(x, y)
assert out.numpy()[0] == True
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -38,8 +38,8 @@ def equal_all(x, y, name=None): ...@@ -38,8 +38,8 @@ def equal_all(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): Tensor, data type is float32, float64, int32, int64. x(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
y(Tensor): Tensor, data type is float32, float64, int32, int64. y(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
...@@ -59,6 +59,8 @@ def equal_all(x, y, name=None): ...@@ -59,6 +59,8 @@ def equal_all(x, y, name=None):
result2 = paddle.equal_all(x, z) result2 = paddle.equal_all(x, z)
print(result2) # result2 = [False ] print(result2) # result2 = [False ]
""" """
if in_dygraph_mode():
return core.ops.equal_all(x, y)
helper = LayerHelper("equal_all", **locals()) helper = LayerHelper("equal_all", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
...@@ -152,8 +154,8 @@ def equal(x, y, name=None): ...@@ -152,8 +154,8 @@ def equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): Tensor, data type is float32, float64, int32, int64. x(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
y(Tensor): Tensor, data type is float32, float64, int32, int64. y(Tensor): Tensor, data type is bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
...@@ -174,10 +176,10 @@ def equal(x, y, name=None): ...@@ -174,10 +176,10 @@ def equal(x, y, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.equal(x, y) return core.ops.equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"equal") x, "x", ["bool", "float32", "float64", "int32", "int64"], "equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"equal") y, "y", ["bool", "float32", "float64", "int32", "int64"], "equal")
helper = LayerHelper("equal", **locals()) helper = LayerHelper("equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True out.stop_gradient = True
...@@ -196,8 +198,8 @@ def greater_equal(x, y, name=None): ...@@ -196,8 +198,8 @@ def greater_equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -216,9 +218,11 @@ def greater_equal(x, y, name=None): ...@@ -216,9 +218,11 @@ def greater_equal(x, y, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.greater_equal(x, y) return core.ops.greater_equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(x, "x",
["bool", "float32", "float64", "int32", "int64"],
"greater_equal") "greater_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(y, "y",
["bool", "float32", "float64", "int32", "int64"],
"greater_equal") "greater_equal")
helper = LayerHelper("greater_equal", **locals()) helper = LayerHelper("greater_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
...@@ -240,8 +244,8 @@ def greater_than(x, y, name=None): ...@@ -240,8 +244,8 @@ def greater_than(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
...@@ -260,9 +264,11 @@ def greater_than(x, y, name=None): ...@@ -260,9 +264,11 @@ def greater_than(x, y, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.greater_than(x, y) return core.ops.greater_than(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(x, "x",
["bool", "float32", "float64", "int32", "int64"],
"greater_than") "greater_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(y, "y",
["bool", "float32", "float64", "int32", "int64"],
"greater_than") "greater_than")
helper = LayerHelper("greater_than", **locals()) helper = LayerHelper("greater_than", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
...@@ -284,8 +290,8 @@ def less_equal(x, y, name=None): ...@@ -284,8 +290,8 @@ def less_equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
...@@ -305,10 +311,10 @@ def less_equal(x, y, name=None): ...@@ -305,10 +311,10 @@ def less_equal(x, y, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.less_equal(x, y) return core.ops.less_equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"less_equal") x, "x", ["bool", "float32", "float64", "int32", "int64"], "less_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"less_equal") y, "y", ["bool", "float32", "float64", "int32", "int64"], "less_equal")
helper = LayerHelper("less_equal", **locals()) helper = LayerHelper("less_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True out.stop_gradient = True
...@@ -327,8 +333,8 @@ def less_than(x, y, name=None): ...@@ -327,8 +333,8 @@ def less_than(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
...@@ -348,10 +354,10 @@ def less_than(x, y, name=None): ...@@ -348,10 +354,10 @@ def less_than(x, y, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.less_than(x, y) return core.ops.less_than(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"less_than") x, "x", ["bool", "float32", "float64", "int32", "int64"], "less_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"less_than") y, "y", ["bool", "float32", "float64", "int32", "int64"], "less_than")
helper = LayerHelper("less_than", **locals()) helper = LayerHelper("less_than", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True out.stop_gradient = True
...@@ -370,8 +376,8 @@ def not_equal(x, y, name=None): ...@@ -370,8 +376,8 @@ def not_equal(x, y, name=None):
**NOTICE**: The output of this OP has no gradient. **NOTICE**: The output of this OP has no gradient.
Args: Args:
x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. x(Tensor): First input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64. y(Tensor): Second input to compare which is N-D tensor. The input data type should be bool, float32, float64, int32, int64.
name(str, optional): The default value is None. Normally there is no need for name(str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`. user to set this property. For more information, please refer to :ref:`api_guide_Name`.
...@@ -391,10 +397,10 @@ def not_equal(x, y, name=None): ...@@ -391,10 +397,10 @@ def not_equal(x, y, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.not_equal(x, y) return core.ops.not_equal(x, y)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"not_equal") x, "x", ["bool", "float32", "float64", "int32", "int64"], "not_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"], check_variable_and_dtype(
"not_equal") y, "y", ["bool", "float32", "float64", "int32", "int64"], "not_equal")
helper = LayerHelper("not_equal", **locals()) helper = LayerHelper("not_equal", **locals())
out = helper.create_variable_for_type_inference(dtype='bool') out = helper.create_variable_for_type_inference(dtype='bool')
out.stop_gradient = True out.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册