From 56b50c97f877be2501709191c1e324b049e067e1 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Sat, 4 Apr 2020 11:47:06 +0800 Subject: [PATCH] Add allclose_op (#23335) * Add allclose Op, and its function is analogous to numpy.allclose. It returns True if two tensors are elementwise equal within a tolerance. --- paddle/fluid/operators/allclose_op.cc | 123 +++++++++++++++ paddle/fluid/operators/allclose_op.cu | 24 +++ paddle/fluid/operators/allclose_op.h | 61 ++++++++ python/paddle/__init__.py | 2 +- .../tests/unittests/test_allclose_layer.py | 140 ++++++++++++++++++ .../fluid/tests/unittests/test_allclose_op.py | 80 ++++++++++ python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/logic.py | 90 ++++++++++- 8 files changed, 517 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/allclose_op.cc create mode 100644 paddle/fluid/operators/allclose_op.cu create mode 100644 paddle/fluid/operators/allclose_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_allclose_layer.py create mode 100644 python/paddle/fluid/tests/unittests/test_allclose_op.py diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc new file mode 100644 index 00000000000..dcfb4104c31 --- /dev/null +++ b/paddle/fluid/operators/allclose_op.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/allclose_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "The first input tensor to compare."); + AddInput("Other", "The second input tensor to compare."); + AddOutput("Out", "The output tensor of allclose op."); + + AddAttr("rtol", "The relative tolerance. Default: :math:`1e-5` .") + .SetDefault(1e-5); + AddAttr("atol", "The absolute tolerance. Default: :math:`1e-8` .") + .SetDefault(1e-8); + AddAttr("equal_nan", + "If :math:`True` , then two :math:`NaNs` will be " + "compared as equal. Default: :math:`False` .") + .SetDefault(false); + + AddComment(R"DOC( +This operator checks if all :math:`input` and :math:`other` satisfy the condition: + +:math:`\left| input - other \right| \leq atol + rtol \times \left| other \right|` + +elementwise, for all elements of :math:`input` and :math:`other`. The behaviour of this +operator is analogous to :math:`numpy.allclose`, namely that it returns :math:`True` if +two tensors are elementwise equal within a tolerance. +)DOC"); + } +}; + +class AllcloseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true, + platform::errors::NotFound( + "Input(Input) of allclose op should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Other"), true, + platform::errors::NotFound( + "Input(Other) of allclose op should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::NotFound( + "The output(Out) of allclose op must not be null.")); + + auto input_dim = ctx->GetInputDim("Input"); + auto other_dim = ctx->GetInputDim("Other"); + PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(), + platform::errors::PreconditionNotMet( + "Input(Input) and Input(Other) must have the same " + "dimension size.")); + int n = input_dim.size(); + bool is_runtime = ctx->IsRuntime(); + for (int i = 0; i < n; i++) { + if (is_runtime) { + PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], + platform::errors::PreconditionNotMet( + "The value at dim %d of Input(Input) is not " + "equal to the Input(Other): %ld != %ld.", + i, input_dim[i], other_dim[i])); + } else { + if (!(input_dim[i] < 0 || other_dim[i] < 0)) { + PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], + platform::errors::PreconditionNotMet( + "The value at dim %d of Input(Input) is not " + "equal to the Input(Other): %ld != %ld.", + i, input_dim[i], other_dim[i])); + } + } + } + + ctx->SetOutputDim("Out", framework::make_ddim({1})); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); + } +}; + +class AllcloseOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto out_var_name = ctx->Output("Out").front(); + ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR( + allclose, ops::AllcloseOp, ops::AllcloseOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::AllcloseOpVarTypeInference); +REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel, + ops::AllcloseKernel); diff --git a/paddle/fluid/operators/allclose_op.cu b/paddle/fluid/operators/allclose_op.cu new file mode 100644 index 00000000000..aaca4e5b122 --- /dev/null +++ b/paddle/fluid/operators/allclose_op.cu @@ -0,0 +1,24 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define EIGEN_USE_GPU + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/allclose_op.h" + +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; +REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel, + ops::AllcloseKernel); diff --git a/paddle/fluid/operators/allclose_op.h b/paddle/fluid/operators/allclose_op.h new file mode 100644 index 00000000000..51893c087ce --- /dev/null +++ b/paddle/fluid/operators/allclose_op.h @@ -0,0 +1,61 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class AllcloseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // get attrs + float rtol = ctx.Attr("rtol"); + float atol = ctx.Attr("atol"); + bool equal_nan = ctx.Attr("equal_nan"); + // get input/output + auto* input = ctx.Input("Input"); + auto* other = ctx.Input("Other"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + // get place + auto& place = *ctx.template device_context().eigen_device(); + + auto input_v = framework::EigenVector::Flatten(*input); + auto other_v = framework::EigenVector::Flatten(*other); + auto out_v = framework::EigenScalar::From(*out); + + auto left = (input_v - other_v).abs(); + auto right = static_cast(atol) + static_cast(rtol) * other_v.abs(); + auto compare_res = left <= right; + + if (equal_nan) { + auto input_nan = input_v.isnan(); + auto other_nan = other_v.isnan(); + out_v.device(place) = + (input_nan == other_nan).all() && (compare_res != input_nan).all(); + } else { + out_v.device(place) = compare_res.all(); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index c8deb79c85a..24db112564b 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -76,7 +76,7 @@ from .tensor.logic import equal #DEFINE_ALIAS # from .tensor.logic import not_equal #DEFINE_ALIAS # from .tensor.logic import reduce_all #DEFINE_ALIAS # from .tensor.logic import reduce_any #DEFINE_ALIAS -# from .tensor.logic import allclose #DEFINE_ALIAS +from .tensor.logic import allclose #DEFINE_ALIAS # from .tensor.logic import elementwise_equal #DEFINE_ALIAS # from .tensor.logic import isnan #DEFINE_ALIAS # from .tensor..tensor import Tensor #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_allclose_layer.py b/python/paddle/fluid/tests/unittests/test_allclose_layer.py new file mode 100644 index 00000000000..60fd157d2e7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_allclose_layer.py @@ -0,0 +1,140 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import unittest +import numpy as np + + +class TestAllcloseLayer(unittest.TestCase): + def allclose_check(self, use_cuda): + a = fluid.data(name="a", shape=[2], dtype='float32') + b = fluid.data(name="b", shape=[2], dtype='float32') + + result = paddle.allclose( + a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan") + result_nan = paddle.allclose( + a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + x = np.array([10000., 1e-07]).astype("float32") + y = np.array([10000.1, 1e-08]).astype("float32") + result_v, result_nan_v = exe.run(feed={'a': x, + 'b': y}, + fetch_list=[result, result_nan]) + self.assertEqual(result_v[0], False) + self.assertEqual(result_nan_v[0], False) + + x = np.array([10000., 1e-08]).astype("float32") + y = np.array([10000.1, 1e-09]).astype("float32") + result_v, result_nan_v = exe.run(feed={'a': x, + 'b': y}, + fetch_list=[result, result_nan]) + self.assertEqual(result_v[0], True) + self.assertEqual(result_nan_v[0], True) + + x = np.array([1.0, float('nan')]).astype("float32") + y = np.array([1.0, float('nan')]).astype("float32") + result_v, result_nan_v = exe.run(feed={'a': x, + 'b': y}, + fetch_list=[result, result_nan]) + self.assertEqual(result_v[0], False) + self.assertEqual(result_nan_v[0], True) + + def test_allclose_cpu(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.allclose_check(use_cuda=False) + + def test_allclose_gpu(self): + if fluid.core.is_compiled_with_cuda(): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.allclose_check(use_cuda=True) + + def test_dygraph_mode(self): + x_1 = np.array([10000., 1e-07]).astype("float32") + y_1 = np.array([10000.1, 1e-08]).astype("float32") + x_2 = np.array([10000., 1e-08]).astype("float32") + y_2 = np.array([10000.1, 1e-09]).astype("float32") + x_3 = np.array([1.0, float('nan')]).astype("float32") + y_3 = np.array([1.0, float('nan')]).astype("float32") + + with fluid.dygraph.guard(): + x_v_1 = fluid.dygraph.to_variable(x_1) + y_v_1 = fluid.dygraph.to_variable(y_1) + ret_1 = paddle.allclose( + x_v_1, + y_v_1, + rtol=1e-05, + atol=1e-08, + equal_nan=False, + name='test_1') + self.assertEqual(ret_1.numpy()[0], False) + ret_1 = paddle.allclose( + x_v_1, + y_v_1, + rtol=1e-05, + atol=1e-08, + equal_nan=True, + name='test_2') + self.assertEqual(ret_1.numpy()[0], False) + x_v_2 = fluid.dygraph.to_variable(x_2) + y_v_2 = fluid.dygraph.to_variable(y_2) + ret_2 = paddle.allclose( + x_v_2, + y_v_2, + rtol=1e-05, + atol=1e-08, + equal_nan=False, + name='test_3') + self.assertEqual(ret_2.numpy()[0], True) + ret_2 = paddle.allclose( + x_v_2, + y_v_2, + rtol=1e-05, + atol=1e-08, + equal_nan=True, + name='test_4') + self.assertEqual(ret_2.numpy()[0], True) + x_v_3 = fluid.dygraph.to_variable(x_3) + y_v_3 = fluid.dygraph.to_variable(y_3) + ret_3 = paddle.allclose( + x_v_3, + y_v_3, + rtol=1e-05, + atol=1e-08, + equal_nan=False, + name='test_5') + self.assertEqual(ret_3.numpy()[0], False) + ret_3 = paddle.allclose( + x_v_3, + y_v_3, + rtol=1e-05, + atol=1e-08, + equal_nan=True, + name='test_6') + self.assertEqual(ret_3.numpy()[0], True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_allclose_op.py b/python/paddle/fluid/tests/unittests/test_allclose_op.py new file mode 100644 index 00000000000..5b5ed264188 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_allclose_op.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestAllcloseOp(OpTest): + def set_args(self): + self.input = np.array([10000., 1e-07]).astype("float32") + self.other = np.array([10000.1, 1e-08]).astype("float32") + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = False + + def setUp(self): + self.set_args() + self.op_type = "allclose" + self.inputs = {'Input': self.input, 'Other': self.other} + self.attrs = { + 'rtol': self.rtol, + 'atol': self.atol, + 'equal_nan': self.equal_nan + } + self.outputs = { + 'Out': np.array([ + np.allclose( + self.inputs['Input'], + self.inputs['Other'], + rtol=self.rtol, + atol=self.atol, + equal_nan=self.equal_nan) + ]) + } + + def test_check_output(self): + self.check_output() + + +class TestAllcloseOpSmallNum(TestAllcloseOp): + def set_args(self): + self.input = np.array([10000., 1e-08]).astype("float32") + self.other = np.array([10000.1, 1e-09]).astype("float32") + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = False + + +class TestAllcloseOpNanFalse(TestAllcloseOp): + def set_args(self): + self.input = np.array([1.0, float('nan')]).astype("float32") + self.other = np.array([1.0, float('nan')]).astype("float32") + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = False + + +class TestAllcloseOpNanTrue(TestAllcloseOp): + def set_args(self): + self.input = np.array([1.0, float('nan')]).astype("float32") + self.other = np.array([1.0, float('nan')]).astype("float32") + self.rtol = 1e-05 + self.atol = 1e-08 + self.equal_nan = True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 36b7868dba6..65758e0a921 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -53,7 +53,7 @@ from .logic import equal #DEFINE_ALIAS # from .logic import not_equal #DEFINE_ALIAS # from .logic import reduce_all #DEFINE_ALIAS # from .logic import reduce_any #DEFINE_ALIAS -# from .logic import allclose #DEFINE_ALIAS +from .logic import allclose #DEFINE_ALIAS # from .logic import elementwise_equal #DEFINE_ALIAS # from .logic import isnan #DEFINE_ALIAS # from . import Tensor #DEFINE_ALIAS diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 13a4c5cdc69..71fa2d52427 100644 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.common_ops_import import * -import paddle.fluid as fluid +from ..fluid.layer_helper import LayerHelper +from ..fluid.data_feeder import check_type +from ..fluid.layers.layer_function_generator import templatedoc # TODO: define logic functions of a tensor __all__ = [ @@ -31,7 +32,7 @@ __all__ = [ # 'not_equal', # 'reduce_all', # 'reduce_any', - # 'allclose', + 'allclose', # 'elementwise_equal', # 'isnan' ] @@ -102,3 +103,86 @@ def equal(x, y, axis=-1, name=None): attrs=attrs, outputs={'Out': [out]}) return out + + +@templatedoc() +def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): + """ + ${comment} + + Args: + input(inputtype):{input_comment}. + other(othertype):{other_comment}. + rtol(rtoltype,optional):{rtol_comment}. + atol(atoltype,optional):{atol_comment}. + equal_nan(equalnantype,optional):{equal_nan_comment}. + name(STR, optional): The default value is None. + Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + ${out_comment}. + + Return Type: + ${out_type} + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + use_cuda = fluid.core.is_compiled_with_cuda() + + a = fluid.data(name="a", shape=[2], dtype='float32') + b = fluid.data(name="b", shape=[2], dtype='float32') + + result = paddle.allclose(a, b, rtol=1e-05, atol=1e-08, + equal_nan=False, name="ignore_nan") + result_nan = paddle.allclose(a, b, rtol=1e-05, atol=1e-08, + equal_nan=True, name="equal_nan") + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + x = np.array([10000., 1e-07]).astype("float32") + y = np.array([10000.1, 1e-08]).astype("float32") + result_v, result_nan_v = exe.run( + feed={'a': x, 'b': y}, + fetch_list=[result, result_nan]) + print(result_v, result_nan_v) + # Output: (array([False]), array([False])) + + x = np.array([10000., 1e-08]).astype("float32") + y = np.array([10000.1, 1e-09]).astype("float32") + result_v, result_nan_v = exe.run( + feed={'a': x, 'b': y}, + fetch_list=[result, result_nan]) + print(result_v, result_nan_v) + # Output: (array([ True]), array([ True])) + + x = np.array([1.0, float('nan')]).astype("float32") + y = np.array([1.0, float('nan')]).astype("float32") + result_v, result_nan_v = exe.run( + feed={'a': x, 'b': y}, + fetch_list=[result, result_nan]) + print(result_v, result_nan_v) + # Output: (array([False]), array([ True])) + """ + + check_type(rtol, 'rtol', float, 'allclose') + check_type(atol, 'atol', float, 'allclose') + check_type(equal_nan, 'equal_nan', bool, 'allclose') + + helper = LayerHelper("allclose", **locals()) + out = helper.create_variable_for_type_inference(dtype='bool') + + inputs = {'Input': input, 'Other': other} + outputs = {'Out': out} + attrs = {'rtol': rtol, 'atol': atol, 'equal_nan': equal_nan} + helper.append_op( + type='allclose', inputs=inputs, outputs=outputs, attrs=attrs) + + return out -- GitLab