diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc index cd83443f0522f7453da556bc63cf280b4d6a6a61..e452d3c21b8e0b892d528f5f830ce14e5debafa6 100644 --- a/paddle/fluid/operators/allclose_op.cc +++ b/paddle/fluid/operators/allclose_op.cc @@ -14,7 +14,9 @@ #include "paddle/fluid/operators/allclose_op.h" #include +#include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" @@ -63,9 +65,15 @@ class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker { "The input tensor, it's data type should be float32, float64."); AddInput("Other", "The input tensor, it's data type should be float32, float64."); - AddInput("Rtol", "The relative tolerance."); - AddInput("Atol", "The absolute tolerance."); + AddInput("Rtol", "The relative tolerance.").AsDispensable(); + AddInput("Atol", "The absolute tolerance.").AsDispensable(); AddOutput("Out", "The output tensor, it's data type is bool."); + AddAttr("rtol", + "The relative tolerance. Default: :math:`1e-5` .") + .SetDefault("1e-5"); + AddAttr("atol", + "The absolute tolerance. Default: :math:`1e-8` .") + .SetDefault("1e-8"); AddAttr("equal_nan", "If :math:`True` , then two :math:`NaNs` will be " "compared as equal. Default: :math:`False` .") @@ -91,8 +99,6 @@ class AllcloseOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose"); OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose"); - OP_INOUT_CHECK(ctx->HasInput("Rtol"), "Input", "Rtol", "Allclose"); - OP_INOUT_CHECK(ctx->HasInput("Atol"), "Input", "Atol", "Allclose"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose"); auto input_dim = ctx->GetInputDim("Input"); @@ -153,3 +159,16 @@ REGISTER_OPERATOR( ops::AllcloseOpVarTypeInference); REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel, ops::AllcloseKernel); + +REGISTER_OP_VERSION(allclose) + .AddCheckpoint( + R"ROC( + Upgrade allclose add 2 attributes [atol, rtol]. + )ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("rtol", + "(string) The relative tolerance. Default: :math:`1e-5` .", + std::string("1e-5")) + .NewAttr("atol", + "(string) The absolute tolerance. Default: :math:`1e-8` .", + std::string("1e-8"))); diff --git a/paddle/fluid/operators/allclose_op.h b/paddle/fluid/operators/allclose_op.h index a08ddca9eb679237b43697c45e1414f4ecbffa0c..b5683a1d9a93c338c2839b38141510d61cb8231f 100644 --- a/paddle/fluid/operators/allclose_op.h +++ b/paddle/fluid/operators/allclose_op.h @@ -14,6 +14,8 @@ #pragma once +#include +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/place.h" @@ -44,14 +46,38 @@ class AllcloseKernel : public framework::OpKernel { // get input/output const auto* input = ctx.Input("Input"); const auto* other = ctx.Input("Other"); - const auto* rtol = ctx.Input("Rtol"); - const auto* atol = ctx.Input("Atol"); auto* out = ctx.Output("Out"); - auto& dev_ctx = ctx.template device_context(); + double rtol_v = std::stod(ctx.Attr("rtol")); + double atol_v = std::stod(ctx.Attr("atol")); + + auto& dev_ctx = ctx.template device_context(); GetTensorValue get_tensor_value; - double rtol_v = get_tensor_value(dev_ctx, *rtol); - double atol_v = get_tensor_value(dev_ctx, *atol); + if (ctx.HasInput("Rtol")) { + const auto* rtol = ctx.Input("Rtol"); + PADDLE_ENFORCE_EQ( + rtol->numel(), 1, + platform::errors::InvalidArgument( + "Input(Rtol) size must be 1, but get %d.", rtol->numel())); + PADDLE_ENFORCE_EQ(rtol->type(), framework::proto::VarType::FP64, + platform::errors::InvalidArgument( + "Input(Rtol) type must be double, but get %s.", + framework::DataTypeToString(rtol->type()))); + rtol_v = get_tensor_value(dev_ctx, *rtol); + } + if (ctx.HasInput("Atol")) { + const auto* atol = ctx.Input("Atol"); + PADDLE_ENFORCE_EQ( + atol->numel(), 1, + platform::errors::InvalidArgument( + "Input(Atol) size must be 1, but get %d", atol->numel())); + PADDLE_ENFORCE_EQ(atol->type(), framework::proto::VarType::FP64, + platform::errors::InvalidArgument( + "Input(Atol) type must be double, but get %s", + framework::DataTypeToString(atol->type()))); + atol_v = get_tensor_value(dev_ctx, *atol); + } + AllcloseFunctor()(dev_ctx, *input, *other, rtol_v, atol_v, equal_nan, out); } diff --git a/python/paddle/fluid/tests/unittests/test_allclose_layer.py b/python/paddle/fluid/tests/unittests/test_allclose_layer.py index 60fd157d2e74cc2dac375b71bb08049d5759c0e6..c376a5c95c3935630ade6f35adaba51059618ebd 100644 --- a/python/paddle/fluid/tests/unittests/test_allclose_layer.py +++ b/python/paddle/fluid/tests/unittests/test_allclose_layer.py @@ -19,57 +19,81 @@ import numpy as np class TestAllcloseLayer(unittest.TestCase): - def allclose_check(self, use_cuda): - a = fluid.data(name="a", shape=[2], dtype='float32') - b = fluid.data(name="b", shape=[2], dtype='float32') + def allclose_check(self, use_cuda, dtype='float32'): + a = fluid.data(name="a", shape=[2], dtype=dtype) + b = fluid.data(name="b", shape=[2], dtype=dtype) result = paddle.allclose( a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan") result_nan = paddle.allclose( a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan") + result_corner = paddle.allclose( + a, b, rtol=0.01, atol=0.0, name="corner_case") place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - x = np.array([10000., 1e-07]).astype("float32") - y = np.array([10000.1, 1e-08]).astype("float32") + x = np.array([10000., 1e-07]).astype(dtype) + y = np.array([10000.1, 1e-08]).astype(dtype) result_v, result_nan_v = exe.run(feed={'a': x, 'b': y}, fetch_list=[result, result_nan]) self.assertEqual(result_v[0], False) self.assertEqual(result_nan_v[0], False) - x = np.array([10000., 1e-08]).astype("float32") - y = np.array([10000.1, 1e-09]).astype("float32") + x = np.array([10000., 1e-08]).astype(dtype) + y = np.array([10000.1, 1e-09]).astype(dtype) result_v, result_nan_v = exe.run(feed={'a': x, 'b': y}, fetch_list=[result, result_nan]) self.assertEqual(result_v[0], True) self.assertEqual(result_nan_v[0], True) - x = np.array([1.0, float('nan')]).astype("float32") - y = np.array([1.0, float('nan')]).astype("float32") + x = np.array([1.0, float('nan')]).astype(dtype) + y = np.array([1.0, float('nan')]).astype(dtype) result_v, result_nan_v = exe.run(feed={'a': x, 'b': y}, fetch_list=[result, result_nan]) self.assertEqual(result_v[0], False) self.assertEqual(result_nan_v[0], True) - def test_allclose_cpu(self): + # for corner case + x = np.array([10.1, 10.1]).astype(dtype) + y = np.array([10, 10]).astype(dtype) + result_c, = exe.run(feed={'a': x, 'b': y}, fetch_list=[result_corner]) + corner_res = (dtype == 'float64') + self.assertEqual(result_c[0], corner_res) + + def test_allclose_cpu_fp32(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.allclose_check(use_cuda=False, dtype='float32') + + def test_allclose_cpu_fp64(self): main = fluid.Program() startup = fluid.Program() with fluid.unique_name.guard(): with fluid.program_guard(main, startup): - self.allclose_check(use_cuda=False) + self.allclose_check(use_cuda=False, dtype='float64') + + def test_allclose_gpu_fp32(self): + if fluid.core.is_compiled_with_cuda(): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.allclose_check(use_cuda=True, dtype='float32') - def test_allclose_gpu(self): + def test_allclose_gpu_fp64(self): if fluid.core.is_compiled_with_cuda(): main = fluid.Program() startup = fluid.Program() with fluid.unique_name.guard(): with fluid.program_guard(main, startup): - self.allclose_check(use_cuda=True) + self.allclose_check(use_cuda=True, dtype='float64') def test_dygraph_mode(self): x_1 = np.array([10000., 1e-07]).astype("float32") @@ -78,10 +102,14 @@ class TestAllcloseLayer(unittest.TestCase): y_2 = np.array([10000.1, 1e-09]).astype("float32") x_3 = np.array([1.0, float('nan')]).astype("float32") y_3 = np.array([1.0, float('nan')]).astype("float32") + x_4 = np.array([10.1]).astype("float32") + y_4 = np.array([10]).astype("float32") + x_5 = np.array([10.1]).astype("float64") + y_5 = np.array([10]).astype("float64") with fluid.dygraph.guard(): - x_v_1 = fluid.dygraph.to_variable(x_1) - y_v_1 = fluid.dygraph.to_variable(y_1) + x_v_1 = paddle.to_tensor(x_1) + y_v_1 = paddle.to_tensor(y_1) ret_1 = paddle.allclose( x_v_1, y_v_1, @@ -98,8 +126,8 @@ class TestAllcloseLayer(unittest.TestCase): equal_nan=True, name='test_2') self.assertEqual(ret_1.numpy()[0], False) - x_v_2 = fluid.dygraph.to_variable(x_2) - y_v_2 = fluid.dygraph.to_variable(y_2) + x_v_2 = paddle.to_tensor(x_2) + y_v_2 = paddle.to_tensor(y_2) ret_2 = paddle.allclose( x_v_2, y_v_2, @@ -116,8 +144,8 @@ class TestAllcloseLayer(unittest.TestCase): equal_nan=True, name='test_4') self.assertEqual(ret_2.numpy()[0], True) - x_v_3 = fluid.dygraph.to_variable(x_3) - y_v_3 = fluid.dygraph.to_variable(y_3) + x_v_3 = paddle.to_tensor(x_3) + y_v_3 = paddle.to_tensor(y_3) ret_3 = paddle.allclose( x_v_3, y_v_3, @@ -134,6 +162,17 @@ class TestAllcloseLayer(unittest.TestCase): equal_nan=True, name='test_6') self.assertEqual(ret_3.numpy()[0], True) + # for corner case + x_v_4 = paddle.to_tensor(x_4) + y_v_4 = paddle.to_tensor(y_4) + ret_4 = paddle.allclose( + x_v_4, y_v_4, rtol=0.01, atol=0.0, name='test_7') + self.assertEqual(ret_4.numpy()[0], False) + x_v_5 = paddle.to_tensor(x_5) + y_v_5 = paddle.to_tensor(y_5) + ret_5 = paddle.allclose( + x_v_5, y_v_5, rtol=0.01, atol=0.0, name='test_8') + self.assertEqual(ret_5.numpy()[0], True) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_allclose_op.py b/python/paddle/fluid/tests/unittests/test_allclose_op.py index 6441a789f1d680d6f06c8e93fec0dbcf49648916..e96bf951240e788a1158c01ed6cb0574edd75dc5 100644 --- a/python/paddle/fluid/tests/unittests/test_allclose_op.py +++ b/python/paddle/fluid/tests/unittests/test_allclose_op.py @@ -51,6 +51,37 @@ class TestAllcloseOp(OpTest): self.check_output() +class TestAllcloseOpException(TestAllcloseOp): + def test_check_output(self): + def test_rtol_num(): + self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64") + self.inputs['Atol'] = np.array([1e-08]).astype("float64") + self.check_output() + + self.assertRaises(ValueError, test_rtol_num) + + def test_rtol_type(): + self.inputs['Rtol'] = np.array([5]).astype("int32") + self.inputs['Atol'] = np.array([1e-08]).astype("float64") + self.check_output() + + self.assertRaises(ValueError, test_rtol_type) + + def test_atol_num(): + self.inputs['Rtol'] = np.array([1e-05]).astype("float64") + self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64") + self.check_output() + + self.assertRaises(ValueError, test_atol_num) + + def test_atol_type(): + self.inputs['Rtol'] = np.array([1e-05]).astype("float64") + self.inputs['Atol'] = np.array([8]).astype("int32") + self.check_output() + + self.assertRaises(ValueError, test_atol_type) + + class TestAllcloseOpSmallNum(TestAllcloseOp): def set_args(self): self.input = np.array([10000., 1e-08]).astype("float32") diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 210c69114772c46b259aaa744402a165be84aac6..d5989a1b10c6a4a15a8c1bcbb9ce56f55c057bf7 100644 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import to_tensor from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_type, check_variable_and_dtype from ..fluid.layers.layer_function_generator import templatedoc @@ -137,10 +136,9 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): """ if in_dygraph_mode(): - rtol_tensor = to_tensor(rtol, dtype='float64') - atol_tensor = to_tensor(atol, dtype='float64') - return core.ops.allclose(x, y, rtol_tensor, atol_tensor, 'equal_nan', - equal_nan) + return core.ops.allclose(x, y, 'rtol', + str(rtol), 'atol', + str(atol), 'equal_nan', equal_nan) check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose') check_variable_and_dtype(y, "input", ['float32', 'float64'], 'allclose') @@ -149,26 +147,11 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None): check_type(equal_nan, 'equal_nan', bool, 'allclose') helper = LayerHelper("allclose", **locals()) - rtol_var = helper.create_global_variable( - name=fluid.unique_name.generate('rtol'), - persistable=True, - dtype='float64', - shape=[1]) - helper.set_variable_initializer( - rtol_var, initializer=fluid.initializer.ConstantInitializer(rtol)) - atol_var = helper.create_variable( - name=fluid.unique_name.generate('atol'), - persistable=True, - dtype='float64', - shape=[1]) - helper.set_variable_initializer( - atol_var, initializer=fluid.initializer.ConstantInitializer(atol)) - out = helper.create_variable_for_type_inference(dtype='bool') - inputs = {'Input': x, 'Other': y, 'Rtol': rtol_var, 'Atol': atol_var} + inputs = {'Input': x, 'Other': y} outputs = {'Out': out} - attrs = {'equal_nan': equal_nan} + attrs = {'rtol': str(rtol), 'atol': str(atol), 'equal_nan': equal_nan} helper.append_op( type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)