未验证 提交 d2200e97 编写于 作者: A andyjpaddle 提交者: GitHub

Add isclose op (#37135)

* add isclose op, test=develop

* add isclose op, test=develop

* add isclose api, test=develop

* rm useless code

* rm useless code

* update python api of isclose

* add some unittest of isclose op, test=develop
上级 65c242ed
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/isclose_op.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T>
struct GetTensorValue<platform::CPUDeviceContext, T> {
T operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor& tensor) const {
return *(tensor.data<T>());
}
};
template <typename T>
struct IscloseFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& other,
const double rtol, const double atol, bool equal_nan,
framework::Tensor* output) {
auto* in_a = in.data<T>();
auto* in_b = other.data<T>();
auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
auto num = in.numel();
// *out_data = true;
for (int i = 0; i < num; i++) {
out_data[i] = true;
}
for (int i = 0; i < num; i++) {
const T a = in_a[i], b = in_b[i];
bool val;
if (std::isnan(a) || std::isnan(b)) {
val = equal_nan && std::isnan(a) == std::isnan(b);
} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T diff = (left > right ? left - right : right - left);
val = a == b || left <= right || diff <= 1e-15;
}
// *out_data &= val;
out_data[i] = val;
}
}
};
class IscloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"The input tensor, it's data type should be float32, float64.");
AddInput("Other",
"The input tensor, it's data type should be float32, float64.");
AddInput("Rtol", "The relative tolerance.").AsDispensable();
AddInput("Atol", "The absolute tolerance.").AsDispensable();
AddOutput("Out", "The output tensor, it's data type is bool.");
AddAttr<std::string>("rtol",
"The relative tolerance. Default: :math:`1e-5` .")
.SetDefault("1e-5");
AddAttr<std::string>("atol",
"The absolute tolerance. Default: :math:`1e-8` .")
.SetDefault("1e-8");
AddAttr<bool>("equal_nan",
"If :math:`True` , then two :math:`NaNs` will be "
"compared as equal. Default: :math:`False` .")
.SetDefault(false);
AddComment(R"DOC(
This operator checks if all :math:`x` and :math:`y` satisfy the condition:
.. math::
\left| x - y \right| \leq atol + rtol \times \left| y \right|
elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this
operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if
two tensors are elementwise equal within a tolerance.
)DOC");
}
};
class IscloseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose");
OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose");
auto input_dim = ctx->GetInputDim("Input");
auto other_dim = ctx->GetInputDim("Other");
PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
platform::errors::PreconditionNotMet(
"Input(Input) and Input(Other) must have the same "
"dimension size."));
int n = input_dim.size();
bool is_runtime = ctx->IsRuntime();
for (int i = 0; i < n; i++) {
if (is_runtime) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
} else {
if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
}
}
}
ctx->SetOutputDim("Out", input_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class IscloseOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
isclose, ops::IscloseOp, ops::IscloseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::IscloseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(isclose, ops::IscloseKernel<CPU, float>,
ops::IscloseKernel<CPU, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/isclose_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct GetTensorValue<platform::CUDADeviceContext, T> {
T operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor& tensor) const {
const T* data = tensor.data<T>();
T value;
const auto gpu_place =
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T),
dev_ctx.stream());
return value;
}
};
template <typename T>
__global__ void IscloseCUDAKernel(const T* in_data, const T* other_data,
const double rtol, const double atol,
bool equal_nan, int num, bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
const T a = in_data[i], b = other_data[i];
if (isnan(a) || isnan(b)) {
val = equal_nan && isnan(a) == isnan(b);
} else {
T left = (a > b ? a - b : b - a);
T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
T diff = (left > right ? left - right : right - left);
val = a == b || left <= right || diff <= 1e-15;
}
out_data[i] = val;
// if (!val) *out_data = false;
}
}
template <typename T>
struct IscloseFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor& in, const framework::Tensor& other,
const double rtol, const double atol, bool equal_nan,
framework::Tensor* output) {
int num = in.numel();
const T* in_data = in.data<T>();
const T* other_data = other.data<T>();
bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
int block = 1024;
int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
#ifdef PADDLE_WITH_HIP
hipMemset(out_data, true, num * sizeof(bool));
#else
cudaMemset(out_data, true, num * sizeof(bool));
#endif
IscloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, other_data, rtol, atol, equal_nan, num, out_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(isclose, ops::IscloseKernel<CUDA, float>,
ops::IscloseKernel<CUDA, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
struct GetTensorValue {
T operator()(const platform::DeviceContext& ctx,
const framework::Tensor& tensor) const;
};
template <typename DeviceContext, typename T>
struct IscloseFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& other, const float rtol,
const float atol, bool equal_nan, framework::Tensor* output);
};
template <typename DeviceContext, typename T>
class IscloseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// get attrs
bool equal_nan = ctx.Attr<bool>("equal_nan");
// get input/output
const auto* input = ctx.Input<Tensor>("Input");
const auto* other = ctx.Input<Tensor>("Other");
auto* out = ctx.Output<Tensor>("Out");
double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
double atol_v = std::stod(ctx.Attr<std::string>("atol"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
GetTensorValue<DeviceContext, double> get_tensor_value;
if (ctx.HasInput("Rtol")) {
const auto* rtol = ctx.Input<Tensor>("Rtol");
PADDLE_ENFORCE_EQ(
rtol->numel(), 1,
platform::errors::InvalidArgument(
"Input(Rtol) size must be 1, but get %d.", rtol->numel()));
PADDLE_ENFORCE_EQ(rtol->type(), framework::proto::VarType::FP64,
platform::errors::InvalidArgument(
"Input(Rtol) type must be double, but get %s.",
framework::DataTypeToString(rtol->type())));
rtol_v = get_tensor_value(dev_ctx, *rtol);
}
if (ctx.HasInput("Atol")) {
const auto* atol = ctx.Input<Tensor>("Atol");
PADDLE_ENFORCE_EQ(
atol->numel(), 1,
platform::errors::InvalidArgument(
"Input(Atol) size must be 1, but get %d", atol->numel()));
PADDLE_ENFORCE_EQ(atol->type(), framework::proto::VarType::FP64,
platform::errors::InvalidArgument(
"Input(Atol) type must be double, but get %s",
framework::DataTypeToString(atol->type())));
atol_v = get_tensor_value(dev_ctx, *atol);
}
IscloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v, atol_v,
equal_nan, out);
}
};
} // namespace operators
} // namespace paddle
......@@ -117,6 +117,7 @@ from .tensor.logic import bitwise_or # noqa: F401
from .tensor.logic import bitwise_xor # noqa: F401
from .tensor.logic import not_equal # noqa: F401
from .tensor.logic import allclose # noqa: F401
from .tensor.logic import isclose # noqa: F401
from .tensor.logic import equal_all # noqa: F401
from .tensor.logic import is_tensor # noqa: F401
from .tensor.manipulation import cast # noqa: F401
......@@ -322,6 +323,7 @@ __all__ = [ # noqa
'complex128',
'addmm',
'allclose',
'isclose',
't',
'add',
'subtract',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
class TestIscloseOp(OpTest):
def set_args(self):
self.input = np.array([10000., 1e-07]).astype("float32")
self.other = np.array([10000.1, 1e-08]).astype("float32")
self.rtol = np.array([1e-05]).astype("float64")
self.atol = np.array([1e-08]).astype("float64")
self.equal_nan = False
def setUp(self):
paddle.enable_static()
self.set_args()
self.op_type = "isclose"
self.inputs = {
'Input': self.input,
'Other': self.other,
"Rtol": self.rtol,
"Atol": self.atol
}
self.attrs = {'equal_nan': self.equal_nan}
self.outputs = {
'Out': np.array([
np.isclose(
self.inputs['Input'],
self.inputs['Other'],
rtol=self.rtol,
atol=self.atol,
equal_nan=self.equal_nan)
])
}
def test_check_output(self):
self.check_output()
class TestIscloseOpException(TestIscloseOp):
def test_check_output(self):
def test_rtol_num():
self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64")
self.inputs['Atol'] = np.array([1e-08]).astype("float64")
self.check_output()
self.assertRaises(ValueError, test_rtol_num)
def test_rtol_type():
self.inputs['Rtol'] = np.array([5]).astype("int32")
self.inputs['Atol'] = np.array([1e-08]).astype("float64")
self.check_output()
self.assertRaises(ValueError, test_rtol_type)
def test_atol_num():
self.inputs['Rtol'] = np.array([1e-05]).astype("float64")
self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64")
self.check_output()
self.assertRaises(ValueError, test_atol_num)
def test_atol_type():
self.inputs['Rtol'] = np.array([1e-05]).astype("float64")
self.inputs['Atol'] = np.array([8]).astype("int32")
self.check_output()
self.assertRaises(ValueError, test_atol_type)
class TestIscloseOpSmallNum(TestIscloseOp):
def set_args(self):
self.input = np.array([10000., 1e-08]).astype("float32")
self.other = np.array([10000.1, 1e-09]).astype("float32")
self.rtol = np.array([1e-05]).astype("float64")
self.atol = np.array([1e-08]).astype("float64")
self.equal_nan = False
class TestIscloseOpNanFalse(TestIscloseOp):
def set_args(self):
self.input = np.array([1.0, float('nan')]).astype("float32")
self.other = np.array([1.0, float('nan')]).astype("float32")
self.rtol = np.array([1e-05]).astype("float64")
self.atol = np.array([1e-08]).astype("float64")
self.equal_nan = False
class TestIscloseOpNanTrue(TestIscloseOp):
def set_args(self):
self.input = np.array([1.0, float('nan')]).astype("float32")
self.other = np.array([1.0, float('nan')]).astype("float32")
self.rtol = np.array([1e-05]).astype("float64")
self.atol = np.array([1e-08]).astype("float64")
self.equal_nan = True
class TestIscloseStatic(unittest.TestCase):
def test_api_case(self):
paddle.enable_static()
x_data = np.random.rand(10, 10)
y_data = np.random.rand(10, 10)
places = [paddle.fluid.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.fluid.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
result = paddle.isclose(x, y)
exe = paddle.fluid.Executor(place)
fetches = exe.run(paddle.fluid.default_main_program(),
feed={"x": x_data,
"y": y_data},
fetch_list=[result])
expected_out = np.isclose(x_data, y_data)
self.assertTrue((fetches[0] == expected_out).all(), True)
class TestIscloseDygraph(unittest.TestCase):
def test_api_case(self):
places = [paddle.CPUPlace()]
if paddle.fluid.core.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
paddle.disable_static()
x_data = np.random.rand(10, 10)
y_data = np.random.rand(10, 10)
x = paddle.to_tensor(x_data, place=place)
y = paddle.to_tensor(y_data, place=place)
out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
expected_out = np.isclose(x_data, y_data, rtol=1e-05, atol=1e-08)
self.assertTrue((out.numpy() == expected_out).all(), True)
paddle.enable_static()
class TestIscloseError(unittest.TestCase):
def test_input_dtype(self):
paddle.enable_static()
def test_x_dtype():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
result = paddle.isclose(x, y)
self.assertRaises(TypeError, test_x_dtype)
def test_y_dtype():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
y = paddle.fluid.data(name='y', shape=[10, 10], dtype='int32')
result = paddle.isclose(x, y)
self.assertRaises(TypeError, test_y_dtype)
def test_attr(self):
paddle.enable_static()
x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float64')
y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
def test_rtol():
result = paddle.isclose(x, y, rtol=True)
self.assertRaises(TypeError, test_rtol)
def test_atol():
result = paddle.isclose(x, y, rtol=True)
self.assertRaises(TypeError, test_atol)
def test_equal_nan():
result = paddle.isclose(x, y, equal_nan=1)
self.assertRaises(TypeError, test_equal_nan)
class TestIscloseOpFloat32(TestIscloseOp):
def set_args(self):
self.input = np.array([10.1]).astype("float32")
self.other = np.array([10]).astype("float32")
self.rtol = np.array([0.01]).astype("float64")
self.atol = np.array([0]).astype("float64")
self.equal_nan = False
class TestIscloseOpFloat64(TestIscloseOp):
def set_args(self):
self.input = np.array([10.1]).astype("float64")
self.other = np.array([10]).astype("float64")
self.rtol = np.array([0.01]).astype("float64")
self.atol = np.array([0]).astype("float64")
self.equal_nan = False
class TestIscloseOpLargeDimInput(TestIscloseOp):
def set_args(self):
self.input = np.array(np.zeros([2048, 1024])).astype("float64")
self.other = np.array(np.zeros([2048, 1024])).astype("float64")
self.input[-1][-1] = 100
self.rtol = np.array([1e-05]).astype("float64")
self.atol = np.array([1e-08]).astype("float64")
self.equal_nan = False
if __name__ == "__main__":
unittest.main()
......@@ -72,6 +72,7 @@ from .logic import bitwise_xor # noqa: F401
from .logic import bitwise_not # noqa: F401
from .logic import not_equal # noqa: F401
from .logic import allclose # noqa: F401
from .logic import isclose # noqa: F401
from .logic import equal_all # noqa: F401
from .logic import is_tensor # noqa: F401
from .manipulation import cast # noqa: F401
......@@ -331,6 +332,7 @@ tensor_method_func = [ #noqa
'logical_xor',
'not_equal',
'allclose',
'isclose',
'is_tensor',
'cast',
'concat',
......
......@@ -583,3 +583,77 @@ def bitwise_not(x, out=None, name=None):
return _bitwise_op(
op_name="bitwise_not", x=x, y=None, name=name, out=out, binary_op=False)
@templatedoc()
def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""
${comment}
Args:
x(Tensor): ${input_comment}.
y(Tensor): ${other_comment}.
rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
equal_nan(equalnantype, optional): ${equal_nan_comment}.
name (str, optional): Name for the operation. For more information, please
refer to :ref:`api_guide_Name`. Default: None.
Returns:
Tensor: ${out_comment}.
Raises:
TypeError: The data type of ``x`` must be one of float32, float64.
TypeError: The data type of ``y`` must be one of float32, float64.
TypeError: The type of ``rtol`` must be float.
TypeError: The type of ``atol`` must be float.
TypeError: The type of ``equal_nan`` must be bool.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([10000., 1e-07])
y = paddle.to_tensor([10000.1, 1e-08])
result1 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
equal_nan=False, name="ignore_nan")
np_result1 = result1.numpy()
# [True, False]
result2 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
equal_nan=True, name="equal_nan")
np_result2 = result2.numpy()
# [True, False]
x = paddle.to_tensor([1.0, float('nan')])
y = paddle.to_tensor([1.0, float('nan')])
result1 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
equal_nan=False, name="ignore_nan")
np_result1 = result1.numpy()
# [True, False]
result2 = paddle.isclose(x, y, rtol=1e-05, atol=1e-08,
equal_nan=True, name="equal_nan")
np_result2 = result2.numpy()
# [True, True]
"""
if in_dygraph_mode():
return _C_ops.isclose(x, y, 'rtol',
str(rtol), 'atol',
str(atol), 'equal_nan', equal_nan)
check_variable_and_dtype(x, "input", ['float32', 'float64'], 'isclose')
check_variable_and_dtype(y, "input", ['float32', 'float64'], 'isclose')
check_type(rtol, 'rtol', float, 'isclose')
check_type(atol, 'atol', float, 'isclose')
check_type(equal_nan, 'equal_nan', bool, 'isclose')
helper = LayerHelper("isclose", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'Input': x, 'Other': y}
outputs = {'Out': out}
attrs = {'rtol': str(rtol), 'atol': str(atol), 'equal_nan': equal_nan}
helper.append_op(
type='isclose', inputs=inputs, outputs=outputs, attrs=attrs)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册