未验证 提交 2db79f0a 编写于 作者: Z Zhen Wang 提交者: GitHub

[Cherry-pick]Fix the accuracy problem of allclose op when using float64 data...

[Cherry-pick]Fix the accuracy problem of allclose op when using float64 data type in static mode.(#29890) (#30313)

* Fix the accuracy problem of allclose op when using float64 data type in static mode.

* Format the code style.
上级 7346edc2
......@@ -14,7 +14,9 @@
#include "paddle/fluid/operators/allclose_op.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -63,9 +65,15 @@ class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor, it's data type should be float32, float64.");
AddInput("Other",
"The input tensor, it's data type should be float32, float64.");
AddInput("Rtol", "The relative tolerance.");
AddInput("Atol", "The absolute tolerance.");
AddInput("Rtol", "The relative tolerance.").AsDispensable();
AddInput("Atol", "The absolute tolerance.").AsDispensable();
AddOutput("Out", "The output tensor, it's data type is bool.");
AddAttr<std::string>("rtol",
"The relative tolerance. Default: :math:`1e-5` .")
.SetDefault("1e-5");
AddAttr<std::string>("atol",
"The absolute tolerance. Default: :math:`1e-8` .")
.SetDefault("1e-8");
AddAttr<bool>("equal_nan",
"If :math:`True` , then two :math:`NaNs` will be "
"compared as equal. Default: :math:`False` .")
......@@ -91,8 +99,6 @@ class AllcloseOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose");
OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose");
OP_INOUT_CHECK(ctx->HasInput("Rtol"), "Input", "Rtol", "Allclose");
OP_INOUT_CHECK(ctx->HasInput("Atol"), "Input", "Atol", "Allclose");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose");
auto input_dim = ctx->GetInputDim("Input");
......@@ -153,3 +159,16 @@ REGISTER_OPERATOR(
ops::AllcloseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel<CPU, float>,
ops::AllcloseKernel<CPU, double>);
REGISTER_OP_VERSION(allclose)
.AddCheckpoint(
R"ROC(
Upgrade allclose add 2 attributes [atol, rtol].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("rtol",
"(string) The relative tolerance. Default: :math:`1e-5` .",
std::string("1e-5"))
.NewAttr("atol",
"(string) The absolute tolerance. Default: :math:`1e-8` .",
std::string("1e-8")));
......@@ -14,6 +14,8 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
......@@ -44,14 +46,38 @@ class AllcloseKernel : public framework::OpKernel<T> {
// get input/output
const auto* input = ctx.Input<Tensor>("Input");
const auto* other = ctx.Input<Tensor>("Other");
const auto* rtol = ctx.Input<Tensor>("Rtol");
const auto* atol = ctx.Input<Tensor>("Atol");
auto* out = ctx.Output<Tensor>("Out");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
double atol_v = std::stod(ctx.Attr<std::string>("atol"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
GetTensorValue<DeviceContext, double> get_tensor_value;
double rtol_v = get_tensor_value(dev_ctx, *rtol);
double atol_v = get_tensor_value(dev_ctx, *atol);
if (ctx.HasInput("Rtol")) {
const auto* rtol = ctx.Input<Tensor>("Rtol");
PADDLE_ENFORCE_EQ(
rtol->numel(), 1,
platform::errors::InvalidArgument(
"Input(Rtol) size must be 1, but get %d.", rtol->numel()));
PADDLE_ENFORCE_EQ(rtol->type(), framework::proto::VarType::FP64,
platform::errors::InvalidArgument(
"Input(Rtol) type must be double, but get %s.",
framework::DataTypeToString(rtol->type())));
rtol_v = get_tensor_value(dev_ctx, *rtol);
}
if (ctx.HasInput("Atol")) {
const auto* atol = ctx.Input<Tensor>("Atol");
PADDLE_ENFORCE_EQ(
atol->numel(), 1,
platform::errors::InvalidArgument(
"Input(Atol) size must be 1, but get %d", atol->numel()));
PADDLE_ENFORCE_EQ(atol->type(), framework::proto::VarType::FP64,
platform::errors::InvalidArgument(
"Input(Atol) type must be double, but get %s",
framework::DataTypeToString(atol->type())));
atol_v = get_tensor_value(dev_ctx, *atol);
}
AllcloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v, atol_v,
equal_nan, out);
}
......
......@@ -19,57 +19,81 @@ import numpy as np
class TestAllcloseLayer(unittest.TestCase):
def allclose_check(self, use_cuda):
a = fluid.data(name="a", shape=[2], dtype='float32')
b = fluid.data(name="b", shape=[2], dtype='float32')
def allclose_check(self, use_cuda, dtype='float32'):
a = fluid.data(name="a", shape=[2], dtype=dtype)
b = fluid.data(name="b", shape=[2], dtype=dtype)
result = paddle.allclose(
a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan")
result_nan = paddle.allclose(
a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan")
result_corner = paddle.allclose(
a, b, rtol=0.01, atol=0.0, name="corner_case")
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
x = np.array([10000., 1e-07]).astype("float32")
y = np.array([10000.1, 1e-08]).astype("float32")
x = np.array([10000., 1e-07]).astype(dtype)
y = np.array([10000.1, 1e-08]).astype(dtype)
result_v, result_nan_v = exe.run(feed={'a': x,
'b': y},
fetch_list=[result, result_nan])
self.assertEqual(result_v[0], False)
self.assertEqual(result_nan_v[0], False)
x = np.array([10000., 1e-08]).astype("float32")
y = np.array([10000.1, 1e-09]).astype("float32")
x = np.array([10000., 1e-08]).astype(dtype)
y = np.array([10000.1, 1e-09]).astype(dtype)
result_v, result_nan_v = exe.run(feed={'a': x,
'b': y},
fetch_list=[result, result_nan])
self.assertEqual(result_v[0], True)
self.assertEqual(result_nan_v[0], True)
x = np.array([1.0, float('nan')]).astype("float32")
y = np.array([1.0, float('nan')]).astype("float32")
x = np.array([1.0, float('nan')]).astype(dtype)
y = np.array([1.0, float('nan')]).astype(dtype)
result_v, result_nan_v = exe.run(feed={'a': x,
'b': y},
fetch_list=[result, result_nan])
self.assertEqual(result_v[0], False)
self.assertEqual(result_nan_v[0], True)
def test_allclose_cpu(self):
# for corner case
x = np.array([10.1, 10.1]).astype(dtype)
y = np.array([10, 10]).astype(dtype)
result_c, = exe.run(feed={'a': x, 'b': y}, fetch_list=[result_corner])
corner_res = (dtype == 'float64')
self.assertEqual(result_c[0], corner_res)
def test_allclose_cpu_fp32(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=False, dtype='float32')
def test_allclose_cpu_fp64(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=False)
self.allclose_check(use_cuda=False, dtype='float64')
def test_allclose_gpu_fp32(self):
if fluid.core.is_compiled_with_cuda():
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=True, dtype='float32')
def test_allclose_gpu(self):
def test_allclose_gpu_fp64(self):
if fluid.core.is_compiled_with_cuda():
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.allclose_check(use_cuda=True)
self.allclose_check(use_cuda=True, dtype='float64')
def test_dygraph_mode(self):
x_1 = np.array([10000., 1e-07]).astype("float32")
......@@ -78,10 +102,14 @@ class TestAllcloseLayer(unittest.TestCase):
y_2 = np.array([10000.1, 1e-09]).astype("float32")
x_3 = np.array([1.0, float('nan')]).astype("float32")
y_3 = np.array([1.0, float('nan')]).astype("float32")
x_4 = np.array([10.1]).astype("float32")
y_4 = np.array([10]).astype("float32")
x_5 = np.array([10.1]).astype("float64")
y_5 = np.array([10]).astype("float64")
with fluid.dygraph.guard():
x_v_1 = fluid.dygraph.to_variable(x_1)
y_v_1 = fluid.dygraph.to_variable(y_1)
x_v_1 = paddle.to_tensor(x_1)
y_v_1 = paddle.to_tensor(y_1)
ret_1 = paddle.allclose(
x_v_1,
y_v_1,
......@@ -98,8 +126,8 @@ class TestAllcloseLayer(unittest.TestCase):
equal_nan=True,
name='test_2')
self.assertEqual(ret_1.numpy()[0], False)
x_v_2 = fluid.dygraph.to_variable(x_2)
y_v_2 = fluid.dygraph.to_variable(y_2)
x_v_2 = paddle.to_tensor(x_2)
y_v_2 = paddle.to_tensor(y_2)
ret_2 = paddle.allclose(
x_v_2,
y_v_2,
......@@ -116,8 +144,8 @@ class TestAllcloseLayer(unittest.TestCase):
equal_nan=True,
name='test_4')
self.assertEqual(ret_2.numpy()[0], True)
x_v_3 = fluid.dygraph.to_variable(x_3)
y_v_3 = fluid.dygraph.to_variable(y_3)
x_v_3 = paddle.to_tensor(x_3)
y_v_3 = paddle.to_tensor(y_3)
ret_3 = paddle.allclose(
x_v_3,
y_v_3,
......@@ -134,6 +162,17 @@ class TestAllcloseLayer(unittest.TestCase):
equal_nan=True,
name='test_6')
self.assertEqual(ret_3.numpy()[0], True)
# for corner case
x_v_4 = paddle.to_tensor(x_4)
y_v_4 = paddle.to_tensor(y_4)
ret_4 = paddle.allclose(
x_v_4, y_v_4, rtol=0.01, atol=0.0, name='test_7')
self.assertEqual(ret_4.numpy()[0], False)
x_v_5 = paddle.to_tensor(x_5)
y_v_5 = paddle.to_tensor(y_5)
ret_5 = paddle.allclose(
x_v_5, y_v_5, rtol=0.01, atol=0.0, name='test_8')
self.assertEqual(ret_5.numpy()[0], True)
if __name__ == "__main__":
......
......@@ -51,6 +51,37 @@ class TestAllcloseOp(OpTest):
self.check_output()
class TestAllcloseOpException(TestAllcloseOp):
def test_check_output(self):
def test_rtol_num():
self.inputs['Rtol'] = np.array([1e-05, 1e-05]).astype("float64")
self.inputs['Atol'] = np.array([1e-08]).astype("float64")
self.check_output()
self.assertRaises(ValueError, test_rtol_num)
def test_rtol_type():
self.inputs['Rtol'] = np.array([5]).astype("int32")
self.inputs['Atol'] = np.array([1e-08]).astype("float64")
self.check_output()
self.assertRaises(ValueError, test_rtol_type)
def test_atol_num():
self.inputs['Rtol'] = np.array([1e-05]).astype("float64")
self.inputs['Atol'] = np.array([1e-08, 1e-08]).astype("float64")
self.check_output()
self.assertRaises(ValueError, test_atol_num)
def test_atol_type():
self.inputs['Rtol'] = np.array([1e-05]).astype("float64")
self.inputs['Atol'] = np.array([8]).astype("int32")
self.check_output()
self.assertRaises(ValueError, test_atol_type)
class TestAllcloseOpSmallNum(TestAllcloseOp):
def set_args(self):
self.input = np.array([10000., 1e-08]).astype("float32")
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import to_tensor
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_type, check_variable_and_dtype
from ..fluid.layers.layer_function_generator import templatedoc
......@@ -137,10 +136,9 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
"""
if in_dygraph_mode():
rtol_tensor = to_tensor(rtol, dtype='float64')
atol_tensor = to_tensor(atol, dtype='float64')
return core.ops.allclose(x, y, rtol_tensor, atol_tensor, 'equal_nan',
equal_nan)
return core.ops.allclose(x, y, 'rtol',
str(rtol), 'atol',
str(atol), 'equal_nan', equal_nan)
check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose')
check_variable_and_dtype(y, "input", ['float32', 'float64'], 'allclose')
......@@ -149,26 +147,11 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
check_type(equal_nan, 'equal_nan', bool, 'allclose')
helper = LayerHelper("allclose", **locals())
rtol_var = helper.create_global_variable(
name=fluid.unique_name.generate('rtol'),
persistable=True,
dtype='float64',
shape=[1])
helper.set_variable_initializer(
rtol_var, initializer=fluid.initializer.ConstantInitializer(rtol))
atol_var = helper.create_variable(
name=fluid.unique_name.generate('atol'),
persistable=True,
dtype='float64',
shape=[1])
helper.set_variable_initializer(
atol_var, initializer=fluid.initializer.ConstantInitializer(atol))
out = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'Input': x, 'Other': y, 'Rtol': rtol_var, 'Atol': atol_var}
inputs = {'Input': x, 'Other': y}
outputs = {'Out': out}
attrs = {'equal_nan': equal_nan}
attrs = {'rtol': str(rtol), 'atol': str(atol), 'equal_nan': equal_nan}
helper.append_op(
type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册