From 595a71979578ee0485afd80b0f39c2fe68684765 Mon Sep 17 00:00:00 2001
From: wawltor
Date: Thu, 30 Jul 2020 23:23:13 +0800
Subject: [PATCH] Update the API for the compare_ops

Update the code for the compare_ops; update the API and docs.
---
 cmake/operators.cmake                         |   2 +-
 .../operators/controlflow/CMakeLists.txt      |   2 +-
 ...compare_reduce_op.cc => compare_all_op.cc} |  67 ++--
 ...compare_reduce_op.cu => compare_all_op.cu} |  57 ++--
 .../{compare_reduce_op.h => compare_all_op.h} |   0
 python/paddle/__init__.py                     |   2 +-
 python/paddle/fluid/layers/control_flow.py    |  24 +-
 .../fluid/tests/unittests/test_compare_op.py  |  48 ++-
 .../tests/unittests/test_compare_reduce_op.py |  70 +----
 python/paddle/tensor/__init__.py              |   2 +-
 python/paddle/tensor/logic.py                 | 285 +++++++++++++-----
 11 files changed, 358 insertions(+), 201 deletions(-)
 rename paddle/fluid/operators/controlflow/{compare_reduce_op.cc => compare_all_op.cc} (75%)
 rename paddle/fluid/operators/controlflow/{compare_reduce_op.cu => compare_all_op.cu} (66%)
 rename paddle/fluid/operators/controlflow/{compare_reduce_op.h => compare_all_op.h} (100%)

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 961c1b554a5..5b03cbf8c7f 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -114,7 +114,7 @@ function(op_library TARGET)
   endif()

   # Define operators that don't need pybind here.
-  foreach(manual_pybind_op "compare_reduce_op" "compare_op" "logical_op" "nccl_op"
+  foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op"
 "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
 "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
 "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt
index e1742b03ab7..680abc5ddff 100644
--- a/paddle/fluid/operators/controlflow/CMakeLists.txt
+++ b/paddle/fluid/operators/controlflow/CMakeLists.txt
@@ -9,4 +9,4 @@ cc_test(conditional_block_op_test SRCS conditional_block_op_test.cc DEPS conditi

 target_link_libraries(conditional_block_infer_op conditional_block_op)

-file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_reduce);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
+file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
diff --git a/paddle/fluid/operators/controlflow/compare_reduce_op.cc b/paddle/fluid/operators/controlflow/compare_all_op.cc
similarity index 75%
rename from paddle/fluid/operators/controlflow/compare_reduce_op.cc
rename to paddle/fluid/operators/controlflow/compare_all_op.cc
index 316b46b02ce..adacf70f5e1 100644
--- a/paddle/fluid/operators/controlflow/compare_reduce_op.cc
+++ b/paddle/fluid/operators/controlflow/compare_all_op.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
+#include "paddle/fluid/operators/controlflow/compare_all_op.h"
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
@@ -30,38 +30,44 @@ class CompareReduceOpKernel
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
-    int axis = context.Attr<int>("axis");
+    bool shape_same = true;
     Tensor tmp;
     framework::DDim x_dims = x->dims();
     framework::DDim y_dims = y->dims();
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    std::vector<int> x_dims_array(max_dim);
-    std::vector<int> y_dims_array(max_dim);
-    std::vector<int> tmp_dims_array(max_dim);
-    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
-                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
-                           axis);
-    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
-                           context.GetPlace());
-
-    if (x->numel() == 1 && y->numel() == 1) {
-      bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
-      z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
+
+    // If the shapes of the two inputs differ, the result is simply false.
+    if (x_dims.size() != y_dims.size()) {
+      shape_same = false;
     } else {
-      ElementwiseComputeEx<Functor, DeviceContext, T, bool>(
-          context, x, y, axis, Functor(), &tmp);
+      for (auto i = 0; i < x_dims.size(); i++) {
+        if (x_dims[i] != y_dims[i]) {
+          shape_same = false;
+          break;
+        }
+      }
     }
-    // Reduce by 'logical and' operator
-    z->mutable_data<bool>(context.GetPlace());
-    auto ipt = framework::EigenVector<bool>::Flatten(tmp);
-    auto out = framework::EigenScalar<bool>::From(*z);
-    auto& place = *context.template device_context<DeviceContext>()
-                       .eigen_device();
-    auto reduce_dim = Eigen::array<int, 1>({{0}});
-    out.device(place) = ipt.all(reduce_dim);
+    bool* z_data = z->mutable_data<bool>(context.GetPlace());
+    if (!shape_same) {
+      z_data[0] = false;
+    } else {
+      tmp.mutable_data<bool>(x_dims, context.GetPlace());
+      if (x->numel() == 1 && y->numel() == 1) {
+        bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
+        z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
+      } else {
+        ElementwiseComputeEx<Functor, DeviceContext, T, bool>(
+            context, x, y, 0, Functor(), &tmp);
+      }
+      auto ipt = framework::EigenVector<bool>::Flatten(tmp);
+      auto out = framework::EigenScalar<bool>::From(*z);
+      auto& place =
+          *context.template device_context<DeviceContext>()
               .eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      out.device(place) = ipt.all(reduce_dim);
+    }
   }
 };
@@ -74,11 +80,6 @@ class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker {
                            comment.type));
     AddInput("Y", string::Sprintf("the right hand operand of %s operator",
                                   comment.type));
-    AddAttr<int>(
-        "axis",
-        "The start dimension index for broadcasting Y onto X. [default -1]")
-        .SetDefault(-1)
-        .EqualGreaterThan(-1);
     AddOutput("Out", string::Sprintf(
                          "tensor with a bool element. If all "
                         "elements %s, the Out tensor is [True], else [False]",
                         comment.type));
@@ -144,7 +145,7 @@ class CompareReduceOp : public framework::OperatorWithKernel {
           ::paddle::platform::CPUDeviceContext, functor>,                \
       ::paddle::operators::CompareReduceOpKernel<                        \
           ::paddle::platform::CPUDeviceContext, functor>);
-REGISTER_COMPARE_REDUCE_OP(equal_reduce, "X == Y");
+REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y");

-REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_reduce,
+REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_all,
                                    paddle::operators::EqualReduceFunctor);
diff --git a/paddle/fluid/operators/controlflow/compare_reduce_op.cu b/paddle/fluid/operators/controlflow/compare_all_op.cu
similarity index 66%
rename from paddle/fluid/operators/controlflow/compare_reduce_op.cu
rename to paddle/fluid/operators/controlflow/compare_all_op.cu
index 3adac0d9664..e3c920f78c4 100644
--- a/paddle/fluid/operators/controlflow/compare_reduce_op.cu
+++ b/paddle/fluid/operators/controlflow/compare_all_op.cu
@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
+#include <thrust/fill.h>
+#include "paddle/fluid/operators/controlflow/compare_all_op.h"
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 namespace paddle {
 namespace operators {
@@ -43,31 +44,41 @@ class CompareReduceOpKernel
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
-    int axis = context.Attr<int>("axis");
+    bool shape_same = true;
     Tensor tmp;
     framework::DDim x_dims = x->dims();
     framework::DDim y_dims = y->dims();
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    std::vector<int> x_dims_array(max_dim);
-    std::vector<int> y_dims_array(max_dim);
-    std::vector<int> tmp_dims_array(max_dim);
-    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
-                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
-                           axis);
-    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
-                           context.GetPlace());
-    ElementwiseComputeEx<Functor, platform::CUDADeviceContext, T, bool>(
-        context, x, y, axis, Functor(), &tmp);
-    // Reduce by 'bitwise and' operator
-    std::vector<int> reduce_dims;
-    reduce_dims.resize(tmp.dims().size());
-    for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    auto stream = context.cuda_device_context().stream();
-    TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
-        tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
-        stream);
+
+    // If the shapes of the two inputs differ, the result is simply false.
+    if (x_dims.size() != y_dims.size()) {
+      shape_same = false;
+    } else {
+      for (auto i = 0; i < x_dims.size(); i++) {
+        if (x_dims[i] != y_dims[i]) {
+          shape_same = false;
+          break;
+        }
+      }
+    }
+
+    bool* z_data = z->mutable_data<bool>(context.GetPlace());
+    if (!shape_same) {
+      thrust::device_ptr<bool> z_dev_ptr(z_data);
+      thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
+      return;
+    } else {
+      tmp.mutable_data<bool>(x_dims, context.GetPlace());
+      ElementwiseComputeEx<Functor, platform::CUDADeviceContext, T, bool>(
+          context, x, y, 0, Functor(), &tmp);
+      // Reduce all elements to a single bool with logical AND.
+      std::vector<int> reduce_dims;
+      reduce_dims.resize(tmp.dims().size());
+      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
+      auto stream = context.cuda_device_context().stream();
+      TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
+          tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
+          stream);
+    }
   }
 };
@@ -84,5 +95,5 @@ class CompareReduceOpKernel
           paddle::platform::CUDADeviceContext, functor>,                \
       paddle::operators::CompareReduceOpKernel<                         \
           paddle::platform::CUDADeviceContext, functor>);
-REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_reduce,
+REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
                                     paddle::operators::EqualReduceFunctor);
diff --git a/paddle/fluid/operators/controlflow/compare_reduce_op.h b/paddle/fluid/operators/controlflow/compare_all_op.h
similarity index 100%
rename from paddle/fluid/operators/controlflow/compare_reduce_op.h
rename to paddle/fluid/operators/controlflow/compare_all_op.h
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 0d572599a66..6cc986c61e1 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -98,7 +98,7 @@ from .tensor.logic import not_equal #DEFINE_ALIAS
 from .tensor.logic import reduce_all #DEFINE_ALIAS
 from .tensor.logic import reduce_any #DEFINE_ALIAS
 from .tensor.logic import allclose #DEFINE_ALIAS
-from .tensor.logic import elementwise_equal #DEFINE_ALIAS
+from .tensor.logic import equal_all #DEFINE_ALIAS
 # from .tensor.logic import isnan #DEFINE_ALIAS
 from .tensor.manipulation import cast #DEFINE_ALIAS
 from .tensor.manipulation import concat #DEFINE_ALIAS
diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py
index 294912cd453..b179d006262 100755
--- a/python/paddle/fluid/layers/control_flow.py
+++ b/python/paddle/fluid/layers/control_flow.py
@@ -1580,7 +1580,7 @@ def create_array(dtype):


 @templatedoc()
-def less_than(x, y, force_cpu=None, cond=None):
+def less_than(x, y, force_cpu=None, cond=None, name=None):
     """
     :alias_main: paddle.less_than
     :alias: paddle.less_than,paddle.tensor.less_than,paddle.tensor.logic.less_than
@@ -1595,6 +1595,8 @@
         cond(Variable, optional): Optional output which can be any created Variable
             that meets the requirements to store the result of *less_than*.
             If cond is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         ${out_comment}.
@@ -1649,7 +1651,7 @@


 @templatedoc()
-def less_equal(x, y, cond=None):
+def less_equal(x, y, cond=None, name=None):
     """
     :alias_main: paddle.less_equal
     :alias: paddle.less_equal,paddle.tensor.less_equal,paddle.tensor.logic.less_equal
@@ -1662,6 +1664,8 @@
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *less_equal*.
             If cond is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is the same as input :attr:`x`.
@@ -1701,7 +1705,7 @@


 @templatedoc()
-def greater_than(x, y, cond=None):
+def greater_than(x, y, cond=None, name=None):
     """
     :alias_main: paddle.greater_than
     :alias: paddle.greater_than,paddle.tensor.greater_than,paddle.tensor.logic.greater_than
@@ -1714,6 +1718,8 @@
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
        cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *greater_than*.
            If cond is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is the same as input :attr:`x` .
@@ -1752,7 +1758,7 @@


 @templatedoc()
-def greater_equal(x, y, cond=None):
+def greater_equal(x, y, cond=None, name=None):
     """
     :alias_main: paddle.greater_equal
     :alias: paddle.greater_equal,paddle.tensor.greater_equal,paddle.tensor.logic.greater_equal
@@ -1765,6 +1771,8 @@
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *greater_equal*.
             If cond is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is the same as input :attr:`x`.
@@ -1804,7 +1812,7 @@
     return cond


-def equal(x, y, cond=None):
+def equal(x, y, cond=None, name=None):
     """
     This layer returns the truth value of :math:`x == y` elementwise.

     Args:
         x(Variable): Tensor, data type is float32, float64, int32, int64.
         y(Variable): Tensor, data type is float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable
             that meets the requirements to store the result of *equal*.
             If cond is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         Variable: output Tensor, its shape is the same as the input Tensor,
@@ -1849,7 +1859,7 @@
     return cond


-def not_equal(x, y, cond=None):
+def not_equal(x, y, cond=None, name=None):
     """
     :alias_main: paddle.not_equal
     :alias: paddle.not_equal,paddle.tensor.not_equal,paddle.tensor.logic.not_equal
@@ -1862,6 +1872,8 @@
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *not_equal*.
             If cond is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is the same as input :attr:`x`.
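
Note on the kernel rework above: `equal_all` no longer broadcasts. It returns a single boolean that is false whenever the two input shapes differ, and otherwise the logical AND over the elementwise comparison. A minimal NumPy sketch of these semantics (illustrative only; `equal_all_reference` is a hypothetical helper, not part of this patch):

    import numpy as np

    def equal_all_reference(x, y):
        # A shape mismatch short-circuits to [False]; no broadcasting is
        # attempted, unlike the removed equal_reduce/axis behavior.
        if x.shape != y.shape:
            return np.array([False])
        # Otherwise: compare elementwise, then reduce with logical AND.
        return np.array([bool(np.all(x == y))])

    print(equal_all_reference(np.array([1, 2]), np.array([1, 2])))    # [ True]
    print(equal_all_reference(np.array([1, 2]), np.array([[1, 2]])))  # [False]
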
diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index 9d4a9082b54..ef687ff75c6 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -20,6 +20,7 @@ import numpy import numpy as np import paddle import paddle.fluid as fluid +import paddle.fluid.core as core from paddle.fluid import Program, program_guard @@ -67,6 +68,49 @@ for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) +def create_paddle_case(op_type, callback): + class PaddleCls(unittest.TestCase): + def setUp(self): + self.op_type = op_type + self.input_x = np.array([1, 2, 3, 4]) + self.input_y = np.array([1, 3, 2, 4]) + self.real_result = callback(self.input_x, self.input_y) + + def test_api(self): + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[4], dtype='int64') + y = fluid.layers.data(name='y', shape=[4], dtype='int64') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + place = fluid.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + exe = fluid.Executor(place) + res, = exe.run(feed={"x": self.input_x, + "y": self.input_y}, + fetch_list=[out]) + self.assertEqual((res == self.real_result).all(), True) + + def test_attr_name(self): + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[4], dtype='int32') + y = fluid.layers.data(name='y', shape=[4], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x=x, y=y, name="name_%s" % (self.op_type)) + self.assertEqual("name_%s" % (self.op_type) in out.name, True) + + cls_name = "TestCase_{}".format(op_type) + PaddleCls.__name__ = cls_name + globals()[cls_name] = PaddleCls + + +create_paddle_case('less_equal', lambda _a, _b: _a <= _b) +create_paddle_case('greater_than', lambda _a, _b: _a > _b) +create_paddle_case('greater_equal', lambda _a, _b: _a >= _b) +create_paddle_case('equal', lambda _a, _b: _a == _b) +create_paddle_case('not_equal', lambda _a, _b: _a != _b) + + class TestCompareOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): @@ -82,7 +126,7 @@ class API_TestElementwise_Equal(unittest.TestCase): with fluid.program_guard(fluid.Program(), fluid.Program()): label = fluid.layers.assign(np.array([3, 3], dtype="int32")) limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) - out = paddle.elementwise_equal(x=label, y=limit) + out = paddle.equal(x=label, y=limit) place = fluid.CPUPlace() exe = fluid.Executor(place) res, = exe.run(fetch_list=[out]) @@ -91,7 +135,7 @@ class API_TestElementwise_Equal(unittest.TestCase): with fluid.program_guard(fluid.Program(), fluid.Program()): label = fluid.layers.assign(np.array([3, 3], dtype="int32")) limit = fluid.layers.assign(np.array([3, 3], dtype="int32")) - out = paddle.elementwise_equal(x=label, y=limit) + out = paddle.equal(x=label, y=limit) place = fluid.CPUPlace() exe = fluid.Executor(place) res, = exe.run(fetch_list=[out]) diff --git a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py index d14ff1a4e25..67fe5c81ddc 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py @@ -22,30 +22,29 @@ import paddle.fluid as fluid from paddle.fluid import Program, program_guard -def 
create_test_broadcast_class(op_type, args, callback):
+def create_test_not_equal_class(op_type, typename, callback):
     class Cls(op_test.OpTest):
         def setUp(self):
-            x = np.random.random(size=args['x_size']).astype('int32')
-            y = np.random.random(size=args['y_size']).astype('int32')
+            x = np.random.random(size=(10, 7)).astype(typename)
+            y = np.random.random(size=(10, 7)).astype(typename)
             z = callback(x, y)
             self.inputs = {'X': x, 'Y': y}
             self.outputs = {'Out': z}
             self.op_type = op_type
-            self.axis = args['axis']

         def test_output(self):
             self.check_output()

-    cls_name = "{0}_{1}".format(op_type, 'broadcast')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls


-def create_test_not_equal_class(op_type, typename, callback):
+def create_test_not_shape_equal_class(op_type, typename, callback):
     class Cls(op_test.OpTest):
         def setUp(self):
             x = np.random.random(size=(10, 7)).astype(typename)
-            y = np.random.random(size=(10, 7)).astype(typename)
+            y = np.random.random(size=(10)).astype(typename)
             z = callback(x, y)
             self.inputs = {'X': x, 'Y': y}
             self.outputs = {'Out': z}
@@ -54,7 +53,7 @@ def create_test_not_equal_class(op_type, typename, callback):
         def test_output(self):
             self.check_output()

-    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_shape_equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls

@@ -71,7 +70,7 @@ def create_test_equal_class(op_type, typename, callback):
         def test_output(self):
             self.check_output()

-    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls

@@ -88,7 +87,7 @@ def create_test_dim1_class(op_type, typename, callback):
         def test_output(self):
             self.check_output()

-    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'dim1_equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls

@@ -96,59 +95,16 @@
 np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))

 for _type_name in {'float32', 'float64', 'int32', 'int64'}:
-    create_test_not_equal_class('equal_reduce', _type_name, np_equal)
-    create_test_equal_class('equal_reduce', _type_name, np_equal)
-    create_test_dim1_class('equal_reduce', _type_name, np_equal)
-
-broadcast_args = [{
-    'x_size': (100, 2, 3),
-    'y_size': (100),
-    'axis': 0
-}, {
-    'x_size': (2, 100, 3),
-    'y_size': (100),
-    'axis': 1
-}, {
-    'x_size': (2, 3, 100),
-    'y_size': (1, 1),
-    'axis': -1
-}, {
-    'x_size': (2, 10, 12, 3),
-    'y_size': (10, 12),
-    'axis': 1
-}, {
-    'x_size': (100, 2, 3, 4),
-    'y_size': (100, 1),
-    'axis': 0
-}, {
-    'x_size': (10, 3, 12),
-    'y_size': (10, 1, 12),
-    'axis': -1
-}, {
-    'x_size': (2, 12, 3, 5),
-    'y_size': (2, 12, 1, 5),
-    'axis': -1
-}, {
-    'x_size': (2, 12, 3, 5),
-    'y_size': (3, 5),
-    'axis': 2
-}]
-
-
-def np_broadcast_equal(_x, _y):
-    res = np.all(np.equal(_x, _y))
-    return np.array(res)
-
-
-for args in broadcast_args:
-    create_test_broadcast_class('equal_reduce', args, np_broadcast_equal)
+    create_test_not_equal_class('equal_all', _type_name, np_equal)
+    create_test_equal_class('equal_all', _type_name, np_equal)
+    create_test_dim1_class('equal_all', _type_name, np_equal)


 class TestEqualReduceAPI(unittest.TestCase):
     def test_name(self):
         x = fluid.layers.assign(np.array([3, 4], dtype="int32"))
         y = fluid.layers.assign(np.array([3, 4], dtype="int32"))
-        out = paddle.equal(x, y, name='equal_res')
+        out = paddle.equal_all(x, y, name='equal_res')
         assert 'equal_res' in out.name


diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 8ffe9613995..21cae803716 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -71,7 +71,7 @@ from .logic import not_equal #DEFINE_ALIAS
 from .logic import reduce_all #DEFINE_ALIAS
 from .logic import reduce_any #DEFINE_ALIAS
 from .logic import allclose #DEFINE_ALIAS
-from .logic import elementwise_equal #DEFINE_ALIAS
+from .logic import equal_all #DEFINE_ALIAS
 # from .logic import isnan #DEFINE_ALIAS
 from .manipulation import cast #DEFINE_ALIAS
 from .manipulation import concat #DEFINE_ALIAS
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index 9de2622e7c6..936022dd73b 100644
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -15,24 +15,21 @@
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_type
 from ..fluid.layers.layer_function_generator import templatedoc
+from .. import fluid
 # TODO: define logic functions of a tensor
-from ..fluid.layers import greater_equal #DEFINE_ALIAS
-from ..fluid.layers import greater_than #DEFINE_ALIAS
 from ..fluid.layers import is_empty #DEFINE_ALIAS
 from ..fluid.layers import isfinite #DEFINE_ALIAS
-from ..fluid.layers import less_equal #DEFINE_ALIAS
-from ..fluid.layers import less_than #DEFINE_ALIAS
 from ..fluid.layers import logical_and #DEFINE_ALIAS
 from ..fluid.layers import logical_not #DEFINE_ALIAS
 from ..fluid.layers import logical_or #DEFINE_ALIAS
 from ..fluid.layers import logical_xor #DEFINE_ALIAS
-from ..fluid.layers import not_equal #DEFINE_ALIAS
 from ..fluid.layers import reduce_all #DEFINE_ALIAS
 from ..fluid.layers import reduce_any #DEFINE_ALIAS

 __all__ = [
     'equal',
+    'equal_all',
     'greater_equal',
     'greater_than',
     'is_empty',
@@ -47,78 +44,50 @@ __all__ = [
     'reduce_all',
     'reduce_any',
     'allclose',
-    'elementwise_equal',
     #       'isnan'
 ]


-def equal(x, y, axis=-1, name=None):
+def equal_all(x, y, name=None):
     """
-    :alias_main: paddle.equal
-    :alias: paddle.equal,paddle.tensor.equal,paddle.tensor.logic.equal
+    :alias_main: paddle.equal_all
+    :alias: paddle.equal_all,paddle.tensor.equal_all,paddle.tensor.logic.equal_all

     This OP returns the truth value of :math:`x == y` for the whole tensor: True only if the two inputs have the same shape and all elements are equal, False otherwise.

-    **NOTICE**: The output of this OP has no gradient, and this OP supports broadcasting by :attr:`axis`.
+    **NOTICE**: The output of this OP has no gradient.

     Args:
-        x(Variable): Tensor, data type is float32, float64, int32, int64.
-        y(Variable): Tensor, data type is float32, float64, int32, int64.
-        axis(int32, optional): If X.dimension != Y.dimension, Y.dimension
-            must be a subsequence of x.dimension. And axis is the start
-            dimension index for broadcasting Y onto X. For more detail,
-            please refer to OP:`elementwise_add`.
-        name(str, optional): Normally there is no need for user to set this property.
-            For more information, please refer to :ref:`api_guide_Name`. Default: None.
+        x(Tensor): Tensor, data type is float32, float64, int32, int64.
+        y(Tensor): Tensor, data type is float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
-        Variable: output Tensor, data type is bool, value is [False] or [True].
+        Tensor: output Tensor, data type is bool, value is [False] or [True].

     Examples:
         .. code-block:: python

-            import paddle.fluid as fluid
-            import paddle
             import numpy as np
-
-            label = fluid.layers.assign(np.array([3, 4], dtype="int32"))
-            label_1 = fluid.layers.assign(np.array([1, 2], dtype="int32"))
-            limit = fluid.layers.assign(np.array([3, 4], dtype="int32"))
-            out1 = paddle.equal(x=label, y=limit) #out1=[True]
-            out2 = paddle.equal(x=label_1, y=limit) #out2=[False]
-
-        .. code-block:: python
-
-            import paddle.fluid as fluid
-            import paddle
-            import numpy as np
-
-            def gen_data():
-                return {
-                    "x": np.ones((2, 3, 4, 5)).astype('float32'),
-                    "y": np.zeros((3, 4)).astype('float32')
-                }
-
-            x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
-            y = fluid.data(name="y", shape=[3,4], dtype='float32')
-            out = paddle.equal(x, y, axis=1)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
-
-            res = exe.run(feed=gen_data(),
-                          fetch_list=[out])
-            print(res[0]) #[False]
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 2, 3]))
+            z = imperative.to_variable(np.array([1, 4, 3]))
+            result1 = paddle.equal_all(x, y)
+            print(result1.numpy()) # result1 = [True ]
+            result2 = paddle.equal_all(x, z)
+            print(result2.numpy()) # result2 = [False ]
     """
-    helper = LayerHelper("equal_reduce", **locals())
+
+    helper = LayerHelper("equal_all", **locals())
     out = helper.create_variable_for_type_inference(dtype='bool')
-    attrs = {}
-    attrs['axis'] = axis
     helper.append_op(
-        type='equal_reduce',
-        inputs={'X': [x],
-                'Y': [y]},
-        attrs=attrs,
-        outputs={'Out': [out]})
+        type='equal_all', inputs={'X': [x],
                                  'Y': [y]}, outputs={'Out': [out]})
     return out


@@ -208,41 +177,205 @@ def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     return out


-def elementwise_equal(x, y, name=None):
+@templatedoc()
+def equal(x, y, name=None):
     """
-    :alias_main: paddle.elementwise_equal
-    :alias: paddle.elementwise_equal,paddle.tensor.elementwise_equal,paddle.tensor.logic.elementwise_equal
+    :alias_main: paddle.equal
+    :alias: paddle.equal,paddle.tensor.equal,paddle.tensor.logic.equal

     This layer returns the truth value of :math:`x == y` elementwise.
+    **NOTICE**: The output of this OP has no gradient.

     Args:
-        x(Variable): Tensor, data type is float32, float64, int32, int64.
-        y(Variable): Tensor, data type is float32, float64, int32, int64.
+        x(Tensor): Tensor, data type is float32, float64, int32, int64.
+        y(Tensor): Tensor, data type is float32, float64, int32, int64.
         name(str, optional): The default value is None. Normally there is no need for
             user to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
-        Variable: output Tensor, it's shape is the same as the input's Tensor,
+        Tensor: output Tensor, its shape is the same as the input Tensor,
         and the data type is bool. The result of this op is stop_gradient.

     Examples:
        .. code-block:: python

-            import paddle
-            import paddle.fluid as fluid
             import numpy as np
-            label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
-            limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
-            out1 = paddle.elementwise_equal(x=label, y=limit) #out1=[True, False]
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 3, 2]))
+            result1 = paddle.equal(x, y)
+            print(result1.numpy()) # result1 = [True False False]
     """
-    helper = LayerHelper("elementwise_equal", **locals())
-    out = helper.create_variable_for_type_inference(dtype='bool')
-    out.stop_gradient = True
+    out = fluid.layers.equal(x, y, name=name, cond=None)
+    return out

-    helper.append_op(
-        type='equal',
-        inputs={'X': [x],
-                'Y': [y]},
-        outputs={'Out': [out]},
-        attrs={'force_cpu': False})
+
+@templatedoc()
+def greater_equal(x, y, name=None):
+    """
+    :alias_main: paddle.greater_equal
+    :alias: paddle.greater_equal,paddle.tensor.greater_equal,paddle.tensor.logic.greater_equal
+
+    This OP returns the truth value of :math:`x >= y` elementwise, which is equivalent to the overloaded operator `>=`.
+    **NOTICE**: The output of this OP has no gradient.
+
+    Args:
+        x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+    Returns:
+        Tensor, the output data type is bool: The tensor storing the output, the output shape is the same as input :attr:`x`.
+
+    Examples:
+        .. code-block:: python
+            import numpy as np
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 3, 2]))
+            result1 = paddle.greater_equal(x, y)
+            print(result1.numpy()) # result1 = [True False True]
+    """
+    out = fluid.layers.greater_equal(x, y, name=name, cond=None)
+    return out
+
+
+@templatedoc()
+def greater_than(x, y, name=None):
+    """
+    :alias_main: paddle.greater_than
+    :alias: paddle.greater_than,paddle.tensor.greater_than,paddle.tensor.logic.greater_than
+
+    This OP returns the truth value of :math:`x > y` elementwise, which is equivalent to the overloaded operator `>`.
+    **NOTICE**: The output of this OP has no gradient.
+
+    Args:
+        x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+    Returns:
+        Tensor, the output data type is bool: The tensor storing the output, the output shape is the same as input :attr:`x` .
+
+    Examples:
+        .. code-block:: python
+            import numpy as np
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 3, 2]))
+            result1 = paddle.greater_than(x, y)
+            print(result1.numpy()) # result1 = [False False True]
+    """
+    out = fluid.layers.greater_than(x, y, name=name, cond=None)
+    return out
+
+
+@templatedoc()
+def less_equal(x, y, name=None):
+    """
+    :alias_main: paddle.less_equal
+    :alias: paddle.less_equal,paddle.tensor.less_equal,paddle.tensor.logic.less_equal
+
+    This OP returns the truth value of :math:`x <= y` elementwise, which is equivalent to the overloaded operator `<=`.
+    **NOTICE**: The output of this OP has no gradient.
+
+    Args:
+        x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, the output data type is bool: The tensor storing the output, the output shape is the same as input :attr:`x`.
+
+    Examples:
+        .. code-block:: python
+            import numpy as np
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 3, 2]))
+            result1 = paddle.less_equal(x, y)
+            print(result1.numpy()) # result1 = [True True False]
+    """
+    out = fluid.layers.less_equal(x, y, name=name, cond=None)
+    return out
+
+
+@templatedoc()
+def less_than(x, y, name=None):
+    """
+    :alias_main: paddle.less_than
+    :alias: paddle.less_than,paddle.tensor.less_than,paddle.tensor.logic.less_than
+
+    This OP returns the truth value of :math:`x < y` elementwise, which is equivalent to the overloaded operator `<`.
+    **NOTICE**: The output of this OP has no gradient.
+
+    Args:
+        x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, the output data type is bool: The tensor storing the output, the output shape is the same as input :attr:`x`.
+
+    Examples:
+        .. code-block:: python
+            import numpy as np
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 3, 2]))
+            result1 = paddle.less_than(x, y)
+            print(result1.numpy()) # result1 = [False True False]
+    """
+    out = fluid.layers.less_than(x, y, force_cpu=False, name=name, cond=None)
+    return out
+
+
+@templatedoc()
+def not_equal(x, y, name=None):
+    """
+    :alias_main: paddle.not_equal
+    :alias: paddle.not_equal,paddle.tensor.not_equal,paddle.tensor.logic.not_equal
+
+    This OP returns the truth value of :math:`x != y` elementwise, which is equivalent to the overloaded operator `!=`.
+    **NOTICE**: The output of this OP has no gradient.
+
+    Args:
+        x(Tensor): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        y(Tensor): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor, the output data type is bool: The tensor storing the output, the output shape is the same as input :attr:`x`.
+
+    Examples:
+        .. code-block:: python
+            import numpy as np
+            import paddle
+            import paddle.imperative as imperative
+
+            paddle.enable_imperative()
+            x = imperative.to_variable(np.array([1, 2, 3]))
+            y = imperative.to_variable(np.array([1, 3, 2]))
+            result1 = paddle.not_equal(x, y)
+            print(result1.numpy()) # result1 = [False True True]
+    """
+    out = fluid.layers.not_equal(x, y, name=name, cond=None)
     return out
-- 
GitLab
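
Net effect of the Python-side changes, shown with the imperative-mode API that the new docstrings themselves use (a usage sketch assuming a build that includes this patch):

    import numpy as np
    import paddle
    import paddle.imperative as imperative

    paddle.enable_imperative()
    x = imperative.to_variable(np.array([1, 2, 3]))
    y = imperative.to_variable(np.array([1, 3, 3]))

    # paddle.equal is now purely elementwise: one bool per element.
    print(paddle.equal(x, y).numpy())      # [ True False  True]

    # The old reduce-style comparison lives on as paddle.equal_all:
    # a single bool that is False on any element or shape mismatch.
    print(paddle.equal_all(x, y).numpy())  # [False]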