未验证 提交 56b50c97 编写于 作者: Z Zhen Wang 提交者: GitHub

Add allclose_op (#23335)

* Add allclose Op, and its function is analogous to numpy.allclose. It returns True if two tensors are elementwise equal within a tolerance.
上级 948c57d8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/allclose_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "The first input tensor to compare.");
AddInput("Other", "The second input tensor to compare.");
AddOutput("Out", "The output tensor of allclose op.");
AddAttr<float>("rtol", "The relative tolerance. Default: :math:`1e-5` .")
.SetDefault(1e-5);
AddAttr<float>("atol", "The absolute tolerance. Default: :math:`1e-8` .")
.SetDefault(1e-8);
AddAttr<bool>("equal_nan",
"If :math:`True` , then two :math:`NaNs` will be "
"compared as equal. Default: :math:`False` .")
.SetDefault(false);
AddComment(R"DOC(
This operator checks if all :math:`input` and :math:`other` satisfy the condition:
:math:`\left| input - other \right| \leq atol + rtol \times \left| other \right|`
elementwise, for all elements of :math:`input` and :math:`other`. The behaviour of this
operator is analogous to :math:`numpy.allclose`, namely that it returns :math:`True` if
two tensors are elementwise equal within a tolerance.
)DOC");
}
};
class AllcloseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
platform::errors::NotFound(
"Input(Input) of allclose op should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Other"), true,
platform::errors::NotFound(
"Input(Other) of allclose op should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"The output(Out) of allclose op must not be null."));
auto input_dim = ctx->GetInputDim("Input");
auto other_dim = ctx->GetInputDim("Other");
PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
platform::errors::PreconditionNotMet(
"Input(Input) and Input(Other) must have the same "
"dimension size."));
int n = input_dim.size();
bool is_runtime = ctx->IsRuntime();
for (int i = 0; i < n; i++) {
if (is_runtime) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
} else {
if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
platform::errors::PreconditionNotMet(
"The value at dim %d of Input(Input) is not "
"equal to the Input(Other): %ld != %ld.",
i, input_dim[i], other_dim[i]));
}
}
}
ctx->SetOutputDim("Out", framework::make_ddim({1}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
class AllcloseOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
allclose, ops::AllcloseOp, ops::AllcloseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::AllcloseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel<CPU, float>,
ops::AllcloseKernel<CPU, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#define EIGEN_USE_GPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/allclose_op.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel<CUDA, float>,
ops::AllcloseKernel<CUDA, double>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class AllcloseKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// get attrs
float rtol = ctx.Attr<float>("rtol");
float atol = ctx.Attr<float>("atol");
bool equal_nan = ctx.Attr<bool>("equal_nan");
// get input/output
auto* input = ctx.Input<Tensor>("Input");
auto* other = ctx.Input<Tensor>("Other");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<bool>(ctx.GetPlace());
// get place
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto input_v = framework::EigenVector<T>::Flatten(*input);
auto other_v = framework::EigenVector<T>::Flatten(*other);
auto out_v = framework::EigenScalar<bool>::From(*out);
auto left = (input_v - other_v).abs();
auto right = static_cast<T>(atol) + static_cast<T>(rtol) * other_v.abs();
auto compare_res = left <= right;
if (equal_nan) {
auto input_nan = input_v.isnan();
auto other_nan = other_v.isnan();
out_v.device(place) =
(input_nan == other_nan).all() && (compare_res != input_nan).all();
} else {
out_v.device(place) = compare_res.all();
}
}
};
} // namespace operators
} // namespace paddle
......@@ -76,7 +76,7 @@ from .tensor.logic import equal #DEFINE_ALIAS
# from .tensor.logic import not_equal #DEFINE_ALIAS
# from .tensor.logic import reduce_all #DEFINE_ALIAS
# from .tensor.logic import reduce_any #DEFINE_ALIAS
# from .tensor.logic import allclose #DEFINE_ALIAS
from .tensor.logic import allclose #DEFINE_ALIAS
# from .tensor.logic import elementwise_equal #DEFINE_ALIAS
# from .tensor.logic import isnan #DEFINE_ALIAS
# from .tensor..tensor import Tensor #DEFINE_ALIAS
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import unittest
import numpy as np
class TestAllcloseLayer(unittest.TestCase):
def allclose_check(self, use_cuda):
a = fluid.data(name="a", shape=[2], dtype='float32')
b = fluid.data(name="b", shape=[2], dtype='float32')
result = paddle.allclose(
a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan")
result_nan = paddle.allclose(
a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan")
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x = np.array([10000., 1e-07]).astype("float32")
y = np.array([10000.1, 1e-08]).astype("float32")
result_v, result_nan_v = exe.run(feed={'a': x,
'b': y},
fetch_list=[result, result_nan])
self.assertEqual(result_v[0], False)
self.assertEqual(result_nan_v[0], False)
x = np.array([10000., 1e-08]).astype("float32")
y = np.array([10000.1, 1e-09]).astype("float32")
result_v, result_nan_v = exe.run(feed={'a': x,
'b': y},
fetch_list=[result, result_nan])
self.assertEqual(result_v[0], True)
self.assertEqual(result_nan_v[0], True)
x = np.array([1.0, float('nan')]).astype("float32")
y = np.array([1.0, float('nan')]).astype("float32")
result_v, result_nan_v = exe.run(feed={'a': x,
'b': y},
fetch_list=[result, result_nan])
self.assertEqual(result_v[0], False)
self.assertEqual(result_nan_v[0], True)
def test_allclose_cpu(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=False)
def test_allclose_gpu(self):
if fluid.core.is_compiled_with_cuda():
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=True)
def test_dygraph_mode(self):
x_1 = np.array([10000., 1e-07]).astype("float32")
y_1 = np.array([10000.1, 1e-08]).astype("float32")
x_2 = np.array([10000., 1e-08]).astype("float32")
y_2 = np.array([10000.1, 1e-09]).astype("float32")
x_3 = np.array([1.0, float('nan')]).astype("float32")
y_3 = np.array([1.0, float('nan')]).astype("float32")
with fluid.dygraph.guard():
x_v_1 = fluid.dygraph.to_variable(x_1)
y_v_1 = fluid.dygraph.to_variable(y_1)
ret_1 = paddle.allclose(
x_v_1,
y_v_1,
rtol=1e-05,
atol=1e-08,
equal_nan=False,
name='test_1')
self.assertEqual(ret_1.numpy()[0], False)
ret_1 = paddle.allclose(
x_v_1,
y_v_1,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
name='test_2')
self.assertEqual(ret_1.numpy()[0], False)
x_v_2 = fluid.dygraph.to_variable(x_2)
y_v_2 = fluid.dygraph.to_variable(y_2)
ret_2 = paddle.allclose(
x_v_2,
y_v_2,
rtol=1e-05,
atol=1e-08,
equal_nan=False,
name='test_3')
self.assertEqual(ret_2.numpy()[0], True)
ret_2 = paddle.allclose(
x_v_2,
y_v_2,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
name='test_4')
self.assertEqual(ret_2.numpy()[0], True)
x_v_3 = fluid.dygraph.to_variable(x_3)
y_v_3 = fluid.dygraph.to_variable(y_3)
ret_3 = paddle.allclose(
x_v_3,
y_v_3,
rtol=1e-05,
atol=1e-08,
equal_nan=False,
name='test_5')
self.assertEqual(ret_3.numpy()[0], False)
ret_3 = paddle.allclose(
x_v_3,
y_v_3,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
name='test_6')
self.assertEqual(ret_3.numpy()[0], True)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestAllcloseOp(OpTest):
def set_args(self):
self.input = np.array([10000., 1e-07]).astype("float32")
self.other = np.array([10000.1, 1e-08]).astype("float32")
self.rtol = 1e-05
self.atol = 1e-08
self.equal_nan = False
def setUp(self):
self.set_args()
self.op_type = "allclose"
self.inputs = {'Input': self.input, 'Other': self.other}
self.attrs = {
'rtol': self.rtol,
'atol': self.atol,
'equal_nan': self.equal_nan
}
self.outputs = {
'Out': np.array([
np.allclose(
self.inputs['Input'],
self.inputs['Other'],
rtol=self.rtol,
atol=self.atol,
equal_nan=self.equal_nan)
])
}
def test_check_output(self):
self.check_output()
class TestAllcloseOpSmallNum(TestAllcloseOp):
def set_args(self):
self.input = np.array([10000., 1e-08]).astype("float32")
self.other = np.array([10000.1, 1e-09]).astype("float32")
self.rtol = 1e-05
self.atol = 1e-08
self.equal_nan = False
class TestAllcloseOpNanFalse(TestAllcloseOp):
def set_args(self):
self.input = np.array([1.0, float('nan')]).astype("float32")
self.other = np.array([1.0, float('nan')]).astype("float32")
self.rtol = 1e-05
self.atol = 1e-08
self.equal_nan = False
class TestAllcloseOpNanTrue(TestAllcloseOp):
def set_args(self):
self.input = np.array([1.0, float('nan')]).astype("float32")
self.other = np.array([1.0, float('nan')]).astype("float32")
self.rtol = 1e-05
self.atol = 1e-08
self.equal_nan = True
if __name__ == "__main__":
unittest.main()
......@@ -53,7 +53,7 @@ from .logic import equal #DEFINE_ALIAS
# from .logic import not_equal #DEFINE_ALIAS
# from .logic import reduce_all #DEFINE_ALIAS
# from .logic import reduce_any #DEFINE_ALIAS
# from .logic import allclose #DEFINE_ALIAS
from .logic import allclose #DEFINE_ALIAS
# from .logic import elementwise_equal #DEFINE_ALIAS
# from .logic import isnan #DEFINE_ALIAS
# from . import Tensor #DEFINE_ALIAS
......
......@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.common_ops_import import *
import paddle.fluid as fluid
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_type
from ..fluid.layers.layer_function_generator import templatedoc
# TODO: define logic functions of a tensor
__all__ = [
......@@ -31,7 +32,7 @@ __all__ = [
# 'not_equal',
# 'reduce_all',
# 'reduce_any',
# 'allclose',
'allclose',
# 'elementwise_equal',
# 'isnan'
]
......@@ -102,3 +103,86 @@ def equal(x, y, axis=-1, name=None):
attrs=attrs,
outputs={'Out': [out]})
return out
@templatedoc()
def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""
${comment}
Args:
input(inputtype):{input_comment}.
other(othertype):{other_comment}.
rtol(rtoltype,optional):{rtol_comment}.
atol(atoltype,optional):{atol_comment}.
equal_nan(equalnantype,optional):{equal_nan_comment}.
name(STR, optional): The default value is None.
Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
Returns:
${out_comment}.
Return Type:
${out_type}
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
use_cuda = fluid.core.is_compiled_with_cuda()
a = fluid.data(name="a", shape=[2], dtype='float32')
b = fluid.data(name="b", shape=[2], dtype='float32')
result = paddle.allclose(a, b, rtol=1e-05, atol=1e-08,
equal_nan=False, name="ignore_nan")
result_nan = paddle.allclose(a, b, rtol=1e-05, atol=1e-08,
equal_nan=True, name="equal_nan")
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x = np.array([10000., 1e-07]).astype("float32")
y = np.array([10000.1, 1e-08]).astype("float32")
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y},
fetch_list=[result, result_nan])
print(result_v, result_nan_v)
# Output: (array([False]), array([False]))
x = np.array([10000., 1e-08]).astype("float32")
y = np.array([10000.1, 1e-09]).astype("float32")
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y},
fetch_list=[result, result_nan])
print(result_v, result_nan_v)
# Output: (array([ True]), array([ True]))
x = np.array([1.0, float('nan')]).astype("float32")
y = np.array([1.0, float('nan')]).astype("float32")
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y},
fetch_list=[result, result_nan])
print(result_v, result_nan_v)
# Output: (array([False]), array([ True]))
"""
check_type(rtol, 'rtol', float, 'allclose')
check_type(atol, 'atol', float, 'allclose')
check_type(equal_nan, 'equal_nan', bool, 'allclose')
helper = LayerHelper("allclose", **locals())
out = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'Input': input, 'Other': other}
outputs = {'Out': out}
attrs = {'rtol': rtol, 'atol': atol, 'equal_nan': equal_nan}
helper.append_op(
type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册