From d2200e974f8a7c32a4c4ef39fee7ad4204b58bcb Mon Sep 17 00:00:00 2001
From: andyjpaddle <87074272+andyjpaddle@users.noreply.github.com>
Date: Mon, 22 Nov 2021 19:43:40 +0800
Subject: [PATCH] Add isclose op (#37135)

* add isclose op, test=develop
* add isclose op, test=develop
* add isclose api, test=develop
* rm useless code
* rm useless code
* update python api of isclose
* add some unittest of isclose op, test=develop
---
 paddle/fluid/operators/isclose_op.cc          | 165 +++++++++++++
 paddle/fluid/operators/isclose_op.cu          |  86 +++++++
 paddle/fluid/operators/isclose_op.h           |  87 +++++++
 python/paddle/__init__.py                     |   2 +
 .../fluid/tests/unittests/test_isclose_op.py  | 225 ++++++++++++++++++
 python/paddle/tensor/__init__.py              |   2 +
 python/paddle/tensor/logic.py                 |  74 ++++++
 7 files changed, 641 insertions(+)
 create mode 100644 paddle/fluid/operators/isclose_op.cc
 create mode 100644 paddle/fluid/operators/isclose_op.cu
 create mode 100644 paddle/fluid/operators/isclose_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_isclose_op.py

diff --git a/paddle/fluid/operators/isclose_op.cc b/paddle/fluid/operators/isclose_op.cc
new file mode 100644
index 00000000000..0ae7a9fa02f
--- /dev/null
+++ b/paddle/fluid/operators/isclose_op.cc
@@ -0,0 +1,165 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/isclose_op.h"
+#include <cmath>
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct GetTensorValue<platform::CPUDeviceContext, T> {
+  T operator()(const platform::CPUDeviceContext& dev_ctx,
+               const framework::Tensor& tensor) const {
+    return *(tensor.data<T>());
+  }
+};
+
+template <typename T>
+struct IscloseFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& other,
+                  const double rtol, const double atol, bool equal_nan,
+                  framework::Tensor* output) {
+    auto* in_a = in.data<T>();
+    auto* in_b = other.data<T>();
+    auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
+    auto num = in.numel();
+    // Initialize every output element to true, then compute the real
+    // elementwise result below.
+    for (int i = 0; i < num; i++) {
+      out_data[i] = true;
+    }
+    for (int i = 0; i < num; i++) {
+      const T a = in_a[i], b = in_b[i];
+      bool val;
+      if (std::isnan(a) || std::isnan(b)) {
+        val = equal_nan && std::isnan(a) == std::isnan(b);
+      } else {
+        T left = (a > b ? a - b : b - a);
+        T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+        T diff = (left > right ? left - right : right - left);
+        val = a == b || left <= right || diff <= 1e-15;
+      }
+      out_data[i] = val;
+    }
+  }
+};
+
+class IscloseOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input",
+             "The input tensor, its data type should be float32, float64.");
+    AddInput("Other",
+             "The input tensor, its data type should be float32, float64.");
+    AddInput("Rtol", "The relative tolerance.").AsDispensable();
+    AddInput("Atol", "The absolute tolerance.").AsDispensable();
+    AddOutput("Out", "The output tensor, its data type is bool.");
+    AddAttr<std::string>("rtol",
+                         "The relative tolerance. Default: :math:`1e-5` .")
+        .SetDefault("1e-5");
+    AddAttr<std::string>("atol",
+                         "The absolute tolerance. Default: :math:`1e-8` .")
+        .SetDefault("1e-8");
+    AddAttr<bool>("equal_nan",
+                  "If :math:`True` , then two :math:`NaNs` will be "
+                  "compared as equal. Default: :math:`False` .")
+        .SetDefault(false);
+
+    AddComment(R"DOC(
+This operator checks, elementwise, whether :math:`x` and :math:`y` satisfy the
+condition:
+
+.. math::
+    \left| x - y \right| \leq atol + rtol \times \left| y \right|
+
+The behaviour of this operator is analogous to :math:`numpy.isclose`: it
+returns :math:`True` for each pair of elements that are equal within a
+tolerance.
+)DOC");
+  }
+};
+
+class IscloseOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose");
+    OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose");
+
+    auto input_dim = ctx->GetInputDim("Input");
+    auto other_dim = ctx->GetInputDim("Other");
+    PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
+                      platform::errors::PreconditionNotMet(
+                          "Input(Input) and Input(Other) must have the same "
+                          "dimension size."));
+    int n = input_dim.size();
+    bool is_runtime = ctx->IsRuntime();
+    for (int i = 0; i < n; i++) {
+      if (is_runtime) {
+        PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
+                          platform::errors::PreconditionNotMet(
+                              "The value at dim %d of Input(Input) is not "
+                              "equal to the Input(Other): %ld != %ld.",
+                              i, input_dim[i], other_dim[i]));
+      } else {
+        // At compile time, skip the check for dims that are still unknown
+        // (negative), since they can only be verified at runtime.
+        if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
+          PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
+                            platform::errors::PreconditionNotMet(
+                                "The value at dim %d of Input(Input) is not "
+                                "equal to the Input(Other): %ld != %ld.",
+                                i, input_dim[i], other_dim[i]));
+        }
+      }
+    }
+
+    ctx->SetOutputDim("Out", input_dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
+  }
+};
+
+class IscloseOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext* ctx) const override {
+    ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(
+    isclose, ops::IscloseOp, ops::IscloseOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::IscloseOpVarTypeInference);
+REGISTER_OP_CPU_KERNEL(isclose, ops::IscloseKernel<CPU, float>,
+                       ops::IscloseKernel<CPU, double>);
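For intuition, the elementwise predicate the CPU functor implements can be sketched in plain Python (a minimal sketch, not part of the patch; the 1e-15 term mirrors the rounding guard in the functor above):

    import math

    def is_close(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
        # Two NaNs compare equal only when equal_nan is set; a single NaN
        # is never close to a finite number.
        if math.isnan(a) or math.isnan(b):
            return equal_nan and math.isnan(a) == math.isnan(b)
        left = abs(a - b)              # |a - b|
        right = atol + rtol * abs(b)   # atol + rtol * |b|
        # |left - right| <= 1e-15 absorbs floating-point rounding error,
        # as in the C++ functor above.
        return a == b or left <= right or abs(left - right) <= 1e-15

    print(is_close(10000.0, 10000.1))  # True
    print(is_close(1e-07, 1e-08))      # False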
diff --git a/paddle/fluid/operators/isclose_op.cu b/paddle/fluid/operators/isclose_op.cu
new file mode 100644
index 00000000000..77295414eb9
--- /dev/null
+++ b/paddle/fluid/operators/isclose_op.cu
@@ -0,0 +1,86 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/isclose_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct GetTensorValue<platform::CUDADeviceContext, T> {
+  T operator()(const platform::CUDADeviceContext& dev_ctx,
+               const framework::Tensor& tensor) const {
+    const T* data = tensor.data<T>();
+    T value;
+    const auto gpu_place =
+        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
+    memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T),
+                 dev_ctx.stream());
+    return value;
+  }
+};
+
+template <typename T>
+__global__ void IscloseCUDAKernel(const T* in_data, const T* other_data,
+                                  const double rtol, const double atol,
+                                  bool equal_nan, int num, bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  // Grid-stride loop: each thread handles every
+  // (blockDim.x * gridDim.x)-th element.
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const T a = in_data[i], b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      T left = (a > b ? a - b : b - a);
+      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      T diff = (left > right ? left - right : right - left);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    out_data[i] = val;
+  }
+}
+
+template <typename T>
+struct IscloseFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& dev_ctx,
+                  const framework::Tensor& in, const framework::Tensor& other,
+                  const double rtol, const double atol, bool equal_nan,
+                  framework::Tensor* output) {
+    int num = in.numel();
+    const T* in_data = in.data<T>();
+    const T* other_data = other.data<T>();
+    bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+    grid = (grid > block) ? block : grid;
+#ifdef PADDLE_WITH_HIP
+    hipMemset(out_data, true, num * sizeof(bool));
+#else
+    cudaMemset(out_data, true, num * sizeof(bool));
+#endif
+    IscloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+        in_data, other_data, rtol, atol, equal_nan, num, out_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CUDA = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(isclose, ops::IscloseKernel<CUDA, float>,
+                        ops::IscloseKernel<CUDA, double>);
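The CUDA functor launches at most 1024 blocks of 1024 threads and relies on the grid-stride loop to cover any remainder. A quick sanity check of that arithmetic (a sketch using the same constants as the code above):

    def launch_config(num, block=1024, max_grid=1024):
        # One thread per element, rounded up to whole blocks, with the grid
        # capped; the grid-stride loop picks up any leftover elements.
        grid = (num + block - 1) // block
        return min(grid, max_grid), block

    # The 2048 x 1024 large-dim unit test below saturates the cap, so each
    # thread processes two elements via the stride loop.
    print(launch_config(2048 * 1024))  # (1024, 1024)
    print(launch_config(10))           # (1, 1024)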
diff --git a/paddle/fluid/operators/isclose_op.h b/paddle/fluid/operators/isclose_op.h
new file mode 100644
index 00000000000..7f5052c1e66
--- /dev/null
+++ b/paddle/fluid/operators/isclose_op.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+struct GetTensorValue {
+  T operator()(const platform::DeviceContext& ctx,
+               const framework::Tensor& tensor) const;
+};
+
+template <typename DeviceContext, typename T>
+struct IscloseFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& other, const double rtol,
+                  const double atol, bool equal_nan,
+                  framework::Tensor* output);
+};
+
+template <typename DeviceContext, typename T>
+class IscloseKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    // get attrs
+    bool equal_nan = ctx.Attr<bool>("equal_nan");
+    // get input/output
+    const auto* input = ctx.Input<Tensor>("Input");
+    const auto* other = ctx.Input<Tensor>("Other");
+    auto* out = ctx.Output<Tensor>("Out");
+
+    // Tolerances arrive as string attributes; parse them to double.
+    double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
+    double atol_v = std::stod(ctx.Attr<std::string>("atol"));
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    GetTensorValue<DeviceContext, double> get_tensor_value;
+    // The dispensable tensor inputs, when fed, override the attributes.
+    if (ctx.HasInput("Rtol")) {
+      const auto* rtol = ctx.Input<Tensor>("Rtol");
+      PADDLE_ENFORCE_EQ(
+          rtol->numel(), 1,
+          platform::errors::InvalidArgument(
+              "Input(Rtol) size must be 1, but got %d.", rtol->numel()));
+      PADDLE_ENFORCE_EQ(rtol->type(), framework::proto::VarType::FP64,
+                        platform::errors::InvalidArgument(
+                            "Input(Rtol) type must be double, but got %s.",
+                            framework::DataTypeToString(rtol->type())));
+      rtol_v = get_tensor_value(dev_ctx, *rtol);
+    }
+    if (ctx.HasInput("Atol")) {
+      const auto* atol = ctx.Input<Tensor>("Atol");
+      PADDLE_ENFORCE_EQ(
+          atol->numel(), 1,
+          platform::errors::InvalidArgument(
+              "Input(Atol) size must be 1, but got %d.", atol->numel()));
+      PADDLE_ENFORCE_EQ(atol->type(), framework::proto::VarType::FP64,
+                        platform::errors::InvalidArgument(
+                            "Input(Atol) type must be double, but got %s.",
+                            framework::DataTypeToString(atol->type())));
+      atol_v = get_tensor_value(dev_ctx, *atol);
+    }
+
+    IscloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v,
+                                       atol_v, equal_nan, out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index e46e347d517..5823cf460ee 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -117,6 +117,7 @@ from .tensor.logic import bitwise_or  # noqa: F401
 from .tensor.logic import bitwise_xor  # noqa: F401
 from .tensor.logic import not_equal  # noqa: F401
 from .tensor.logic import allclose  # noqa: F401
+from .tensor.logic import isclose  # noqa: F401
 from .tensor.logic import equal_all  # noqa: F401
 from .tensor.logic import is_tensor  # noqa: F401
 from .tensor.manipulation import cast  # noqa: F401
@@ -322,6 +323,7 @@ __all__ = [  # noqa
     'complex128',
     'addmm',
     'allclose',
+    'isclose',
     't',
     'add',
     'subtract',
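One detail worth noting in IscloseKernel::Compute above: rtol and atol travel as string attributes (parsed with std::stod), and the dispensable Rtol/Atol tensor inputs override them when fed. A sketch of that resolution order (resolve_tolerances is a hypothetical helper for illustration, not a Paddle API):

    def resolve_tolerances(rtol_attr="1e-5", atol_attr="1e-8",
                           rtol_tensor=None, atol_tensor=None):
        # String attributes supply the defaults ...
        rtol, atol = float(rtol_attr), float(atol_attr)
        # ... and 1-element float64 tensor inputs, if present, win.
        if rtol_tensor is not None:
            rtol = float(rtol_tensor)
        if atol_tensor is not None:
            atol = float(atol_tensor)
        return rtol, atol

    print(resolve_tolerances())                  # (1e-05, 1e-08)
    print(resolve_tolerances(rtol_tensor=0.01))  # (0.01, 1e-08)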
diff --git a/python/paddle/fluid/tests/unittests/test_isclose_op.py b/python/paddle/fluid/tests/unittests/test_isclose_op.py
new file mode 100644
index 00000000000..aa39284d113
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_isclose_op.py
@@ -0,0 +1,225 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+
+
+class TestIscloseOp(OpTest):
+    def set_args(self):
+        self.input = np.array([10000., 1e-07]).astype("float32")
+        self.other = np.array([10000.1, 1e-08]).astype("float32")
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = False
+
+    def setUp(self):
+        paddle.enable_static()
+        self.set_args()
+        self.op_type = "isclose"
+        self.inputs = {
+            'Input': self.input,
+            'Other': self.other,
+            "Rtol": self.rtol,
+            "Atol": self.atol
+        }
+        self.attrs = {'equal_nan': self.equal_nan}
+        self.outputs = {
+            'Out': np.array([
+                np.isclose(
+                    self.inputs['Input'],
+                    self.inputs['Other'],
+                    rtol=self.rtol,
+                    atol=self.atol,
+                    equal_nan=self.equal_nan)
+            ])
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestIscloseOpException(TestIscloseOp):
+    def test_check_output(self):
+        def test_rtol_num():
+            self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64")
+            self.inputs['Atol'] = np.array([1e-08]).astype("float64")
+            self.check_output()
+
+        self.assertRaises(ValueError, test_rtol_num)
+
+        def test_rtol_type():
+            self.inputs['Rtol'] = np.array([5]).astype("int32")
+            self.inputs['Atol'] = np.array([1e-08]).astype("float64")
+            self.check_output()
+
+        self.assertRaises(ValueError, test_rtol_type)
+
+        def test_atol_num():
+            self.inputs['Rtol'] = np.array([1e-05]).astype("float64")
+            self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64")
+            self.check_output()
+
+        self.assertRaises(ValueError, test_atol_num)
+
+        def test_atol_type():
+            self.inputs['Rtol'] = np.array([1e-05]).astype("float64")
+            self.inputs['Atol'] = np.array([8]).astype("int32")
+            self.check_output()
+
+        self.assertRaises(ValueError, test_atol_type)
+
+
+class TestIscloseOpSmallNum(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10000., 1e-08]).astype("float32")
+        self.other = np.array([10000.1, 1e-09]).astype("float32")
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = False
+
+
+class TestIscloseOpNanFalse(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([1.0, float('nan')]).astype("float32")
+        self.other = np.array([1.0, float('nan')]).astype("float32")
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = False
+
+
+class TestIscloseOpNanTrue(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([1.0, float('nan')]).astype("float32")
+        self.other = np.array([1.0, float('nan')]).astype("float32")
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = True
+
+
+class TestIscloseStatic(unittest.TestCase):
+    def test_api_case(self):
+        paddle.enable_static()
+        x_data = np.random.rand(10, 10)
+        y_data = np.random.rand(10, 10)
+        places = [paddle.fluid.CPUPlace()]
+        if paddle.fluid.core.is_compiled_with_cuda():
+            places.append(paddle.fluid.CUDAPlace(0))
+        for place in places:
+            with paddle.static.program_guard(paddle.static.Program(),
+                                             paddle.static.Program()):
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
+                y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
+                result = paddle.isclose(x, y)
+                exe = paddle.fluid.Executor(place)
+                fetches = exe.run(paddle.fluid.default_main_program(),
+                                  feed={"x": x_data,
+                                        "y": y_data},
+                                  fetch_list=[result])
+                expected_out = np.isclose(x_data, y_data)
+                self.assertTrue((fetches[0] == expected_out).all())
+
+
+class TestIscloseDygraph(unittest.TestCase):
+    def test_api_case(self):
+        places = [paddle.CPUPlace()]
+        if paddle.fluid.core.is_compiled_with_cuda():
+            places.append(paddle.CUDAPlace(0))
+        for place in places:
+            paddle.disable_static()
+            x_data = np.random.rand(10, 10)
+            y_data = np.random.rand(10, 10)
+            x = paddle.to_tensor(x_data, place=place)
+            y = paddle.to_tensor(y_data, place=place)
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            expected_out = np.isclose(x_data, y_data, rtol=1e-05, atol=1e-08)
+            self.assertTrue((out.numpy() == expected_out).all())
+            paddle.enable_static()
+
+
+class TestIscloseError(unittest.TestCase):
+    def test_input_dtype(self):
+        paddle.enable_static()
+
+        def test_x_dtype():
+            with paddle.static.program_guard(paddle.static.Program(),
+                                             paddle.static.Program()):
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+                y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
+                result = paddle.isclose(x, y)
+
+        self.assertRaises(TypeError, test_x_dtype)
+
+        def test_y_dtype():
+            with paddle.static.program_guard(paddle.static.Program(),
+                                             paddle.static.Program()):
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
+                y = paddle.fluid.data(name='y', shape=[10, 10], dtype='int32')
+                result = paddle.isclose(x, y)
+
+        self.assertRaises(TypeError, test_y_dtype)
+
+    def test_attr(self):
+        paddle.enable_static()
+        x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
+        y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
+
+        def test_rtol():
+            result = paddle.isclose(x, y, rtol=True)
+
+        self.assertRaises(TypeError, test_rtol)
+
+        def test_atol():
+            result = paddle.isclose(x, y, atol=True)
+
+        self.assertRaises(TypeError, test_atol)
+
+        def test_equal_nan():
+            result = paddle.isclose(x, y, equal_nan=1)
+
+        self.assertRaises(TypeError, test_equal_nan)
+
+
+class TestIscloseOpFloat32(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float32")
+        self.other = np.array([10]).astype("float32")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestIscloseOpFloat64(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float64")
+        self.other = np.array([10]).astype("float64")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestIscloseOpLargeDimInput(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array(np.zeros([2048, 1024])).astype("float64")
+        self.other = np.array(np.zeros([2048, 1024])).astype("float64")
+        self.input[-1][-1] = 100
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = False
+
+
+if __name__ == "__main__":
+    unittest.main()
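The TestIscloseOpNanFalse / TestIscloseOpNanTrue cases above pin down the NaN semantics, with NumPy as the reference; the same behaviour can be confirmed directly (a standalone check, not part of the test file):

    import numpy as np

    x = np.array([1.0, float('nan')], dtype="float32")
    y = np.array([1.0, float('nan')], dtype="float32")
    # By default NaN is not close to anything, including another NaN ...
    print(np.isclose(x, y, equal_nan=False))  # [ True False]
    # ... unless equal_nan is set, which the isclose op mirrors.
    print(np.isclose(x, y, equal_nan=True))   # [ True  True]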
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index cb7b2928d02..21d1dd1793b 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -72,6 +72,7 @@ from .logic import bitwise_xor  # noqa: F401
 from .logic import bitwise_not  # noqa: F401
 from .logic import not_equal  # noqa: F401
 from .logic import allclose  # noqa: F401
+from .logic import isclose  # noqa: F401
 from .logic import equal_all  # noqa: F401
 from .logic import is_tensor  # noqa: F401
 from .manipulation import cast  # noqa: F401
@@ -331,6 +332,7 @@ tensor_method_func = [  #noqa
     'logical_xor',
     'not_equal',
     'allclose',
+    'isclose',
     'is_tensor',
     'cast',
     'concat',
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index f944813f8ed..a9ec4891182 100755
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -583,3 +583,77 @@ def bitwise_not(x, out=None, name=None):
     return _bitwise_op(
         op_name="bitwise_not", x=x, y=None, name=name, out=out, binary_op=False)
+
+
+@templatedoc()
+def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
+    """
+    ${comment}
+
+    Args:
+        x(Tensor): ${input_comment}.
+        y(Tensor): ${other_comment}.
+        rtol(float, optional): The relative tolerance. Default: :math:`1e-5` .
+        atol(float, optional): The absolute tolerance. Default: :math:`1e-8` .
+        equal_nan(bool, optional): ${equal_nan_comment}.
+        name (str, optional): Name for the operation. For more information,
+            please refer to :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        Tensor: ${out_comment}.
+
+    Raises:
+        TypeError: The data type of ``x`` must be one of float32, float64.
+        TypeError: The data type of ``y`` must be one of float32, float64.
+        TypeError: The type of ``rtol`` must be float.
+        TypeError: The type of ``atol`` must be float.
+        TypeError: The type of ``equal_nan`` must be bool.
+
+    Examples:
+        .. code-block:: python
+
+          import paddle
+
+          x = paddle.to_tensor([10000., 1e-07])
+          y = paddle.to_tensor([10000.1, 1e-08])
+          result1 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+                                   equal_nan=False, name="ignore_nan")
+          np_result1 = result1.numpy()
+          # [True, False]
+          result2 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+                                   equal_nan=True, name="equal_nan")
+          np_result2 = result2.numpy()
+          # [True, False]
+
+          x = paddle.to_tensor([1.0, float('nan')])
+          y = paddle.to_tensor([1.0, float('nan')])
+          result1 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+                                   equal_nan=False, name="ignore_nan")
+          np_result1 = result1.numpy()
+          # [True, False]
+          result2 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
+                                   equal_nan=True, name="equal_nan")
+          np_result2 = result2.numpy()
+          # [True, True]
+    """
+
+    if in_dygraph_mode():
+        return _C_ops.isclose(x, y, 'rtol',
+                              str(rtol), 'atol',
+                              str(atol), 'equal_nan', equal_nan)
+
+    check_variable_and_dtype(x, "input", ['float32', 'float64'], 'isclose')
+    check_variable_and_dtype(y, "other", ['float32', 'float64'], 'isclose')
+    check_type(rtol, 'rtol', float, 'isclose')
+    check_type(atol, 'atol', float, 'isclose')
+    check_type(equal_nan, 'equal_nan', bool, 'isclose')
+
+    helper = LayerHelper("isclose", **locals())
+    out = helper.create_variable_for_type_inference(dtype='bool')
+
+    inputs = {'Input': x, 'Other': y}
+    outputs = {'Out': out}
+    attrs = {'rtol': str(rtol), 'atol': str(atol), 'equal_nan': equal_nan}
+    helper.append_op(
+        type='isclose', inputs=inputs, outputs=outputs, attrs=attrs)
+    return out
-- 
GitLab
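Taken together, a quick end-to-end exercise of the new API (a sketch assuming a Paddle build with this patch applied; the expected values follow from the tolerance formula in the op doc):

    import numpy as np
    import paddle

    x = paddle.to_tensor([10000., 1e-07])
    y = paddle.to_tensor([10000.1, 1e-08])

    # Elementwise result: one bool per element ...
    print(paddle.isclose(x, y).numpy())   # [ True False]
    # ... unlike the pre-existing allclose, which reduces over all
    # elements to a single False here.
    print(paddle.allclose(x, y).numpy())
    # Agrees with NumPy's reference semantics.
    print(np.isclose(x.numpy(), y.numpy()))  # [ True False]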