Unverified · Commit 24258c27 · Authored by 陈沧夜 · Committed by GitHub

No.54: Implement float16 data type support for the Paddle allclose and isclose operators (#51168)

Parent 07d8770f
paddle/phi/kernels/gpu/allclose_kernel.cu:

```diff
@@ -16,6 +16,8 @@
 #include "glog/logging.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"
```
```diff
@@ -31,14 +33,16 @@ __global__ void AllcloseCUDAKernel(const T* in_data,
                                    bool* out_data) {
   unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
   bool val;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
+    const MPType a = static_cast<MPType>(in_data[i]);
+    const MPType b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
       val = equal_nan && isnan(a) == isnan(b);
     } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
+      MPType left = (a > b ? a - b : b - a);
+      MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      MPType diff = (left > right ? left - right : right - left);
       val = a == b || left <= right || diff <= 1e-15;
     }
     if (!val) *out_data = false;
```
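The substantive change in this hunk: `MPTypeTrait<T>::Type` selects a higher-precision compute type (`float` when `T` is `phi::dtype::float16`), so the subtraction `a - b` and the product `rtol * b` run in fp32 instead of overflowing or losing precision in half precision. A minimal NumPy sketch of the same predicate; the names are illustrative, not Paddle API:

```python
import numpy as np

def allclose_fp16_sketch(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    # Promote to fp32, mirroring static_cast<MPType>(...) in the kernel.
    a32 = a.astype(np.float32)
    b32 = b.astype(np.float32)
    nan_a, nan_b = np.isnan(a32), np.isnan(b32)
    # NaN branch: val = equal_nan && (isnan(a) == isnan(b))
    nan_val = equal_nan & (nan_a == nan_b)
    left = np.abs(a32 - b32)
    right = atol + rtol * np.abs(b32)
    # Non-NaN branch, including the kernel's 1e-15 slack on |left - right|.
    close = (a32 == b32) | (left <= right) | (np.abs(left - right) <= 1e-15)
    return bool(np.all(np.where(nan_a | nan_b, nan_val, close)))

# fp16 rounds 10.1 to ~10.1016, so rtol=0.02 comfortably covers the gap.
print(allclose_fp16_sketch(np.float16([10.1]), np.float16([10.0]), rtol=0.02))  # True
```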
```diff
@@ -92,7 +96,12 @@ void AllCloseKernel(const Context& dev_ctx,
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
+PD_REGISTER_KERNEL(allclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AllCloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
```
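With `phi::dtype::float16` added to the registration list, the GPU op accepts half-precision tensors directly. A usage sketch, assuming a CUDA build of Paddle that includes this patch:

```python
import numpy as np
import paddle

paddle.set_device('gpu')  # the float16 kernel is registered for GPU only
x = paddle.to_tensor(np.array([10.1], dtype='float16'))
y = paddle.to_tensor(np.array([10.0], dtype='float16'))
# rtol=0.02 covers the fp16 rounding of 10.1 (~10.1016).
print(paddle.allclose(x, y, rtol=0.02, atol=0.0))  # Tensor([True])
```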
paddle/phi/kernels/gpu/isclose_kernel.cu:

```diff
@@ -15,8 +15,14 @@
 #include "paddle/phi/kernels/isclose_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/isclose_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
+PD_REGISTER_KERNEL(isclose,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::IscloseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
```
paddle/phi/kernels/impl/isclose_kernel_impl.h:

```diff
@@ -18,6 +18,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
```
```diff
@@ -109,14 +110,16 @@ __global__ void IscloseCUDAKernel(const T* in_data,
                                   bool* out_data) {
   unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
   bool val;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
+    const MPType a = static_cast<MPType>(in_data[i]);
+    const MPType b = static_cast<MPType>(other_data[i]);
     if (isnan(a) || isnan(b)) {
       val = equal_nan && isnan(a) == isnan(b);
     } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
+      MPType left = (a > b ? a - b : b - a);
+      MPType right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      MPType diff = (left > right ? left - right : right - left);
       val = a == b || left <= right || diff <= 1e-15;
     }
     out_data[i] = val;
```
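`IscloseCUDAKernel` applies the identical promoted comparison but writes one boolean per element instead of and-reducing to a single flag. The behavioural contrast, sketched with NumPy, whose `isclose`/`allclose` semantics match these kernels closely enough for illustration:

```python
import numpy as np

a = np.array([10.1, np.nan], dtype=np.float16)
b = np.array([10.0, np.nan], dtype=np.float16)

# isclose: elementwise bool output, one value per pair.
print(np.isclose(a, b, rtol=0.02, atol=0.0, equal_nan=True))  # [ True  True]
# allclose: a single bool; with the default equal_nan=False the NaN pair fails it.
print(np.allclose(a, b, rtol=0.02, atol=0.0))                 # False
```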
python/paddle/fluid/tests/unittests/test_allclose_op.py:

```diff
@@ -18,6 +18,7 @@ import numpy as np
 from op_test import OpTest

 import paddle
+import paddle.fluid.core as core


 class TestAllcloseOp(OpTest):
```
```diff
@@ -134,7 +135,7 @@ class TestAllcloseError(unittest.TestCase):
             with paddle.static.program_guard(
                 paddle.static.Program(), paddle.static.Program()
             ):
-                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
                 y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
                 result = paddle.allclose(x, y)
```
```diff
@@ -170,6 +171,36 @@ class TestAllcloseError(unittest.TestCase):
         self.assertRaises(TypeError, test_equal_nan)


+class TestAllcloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.allclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestAllcloseOpFloat16(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
 class TestAllcloseOpFloat32(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([10.1]).astype("float32")
```
python/paddle/fluid/tests/unittests/test_isclose_op.py:

```diff
@@ -18,6 +18,7 @@ import numpy as np
 from op_test import OpTest

 import paddle
+import paddle.fluid.core as core


 class TestIscloseOp(OpTest):
```
```diff
@@ -166,7 +167,7 @@ class TestIscloseError(unittest.TestCase):
             with paddle.static.program_guard(
                 paddle.static.Program(), paddle.static.Program()
             ):
-                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='float16')
+                x = paddle.fluid.data(name='x', shape=[10, 10], dtype='int32')
                 y = paddle.fluid.data(name='y', shape=[10, 10], dtype='float64')
                 result = paddle.isclose(x, y)
```
```diff
@@ -203,6 +204,36 @@ class TestIscloseError(unittest.TestCase):
         self.assertRaises(TypeError, test_equal_nan)


+class TestIscloseOpFp16(unittest.TestCase):
+    def test_fp16(self):
+        x_data = np.random.rand(10, 10).astype('float16')
+        y_data = np.random.rand(10, 10).astype('float16')
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=[10, 10], name='x', dtype='float16')
+            y = paddle.static.data(shape=[10, 10], name='y', dtype='float16')
+            out = paddle.isclose(x, y, rtol=1e-05, atol=1e-08)
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                out = exe.run(feed={'x': x_data, 'y': y_data}, fetch_list=[out])
+
+
+class TestIscloseOpFloat16(TestIscloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float16")
+        self.other = np.array([10]).astype("float16")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, check_eager=True)
+
+
 class TestIscloseOpFloat32(TestIscloseOp):
     def set_args(self):
         self.input = np.array([10.1]).astype("float32")
```
python/paddle/tensor/logic.py:

```diff
@@ -361,8 +361,8 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.

     Args:
-        x(Tensor): The input tensor, it's data type should be float32, float64..
-        y(Tensor): The input tensor, it's data type should be float32, float64..
+        x(Tensor): The input tensor, its data type should be float16, float32, float64.
+        y(Tensor): The input tensor, its data type should be float16, float32, float64.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): ${equal_nan_comment}.
```
```diff
@@ -401,8 +401,12 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     if in_dygraph_mode():
         return _C_ops.allclose(x, y, rtol, atol, equal_nan)
     else:
-        check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose')
-        check_variable_and_dtype(y, "input", ['float32', 'float64'], 'allclose')
+        check_variable_and_dtype(
+            x, "input", ['float16', 'float32', 'float64'], 'allclose'
+        )
+        check_variable_and_dtype(
+            y, "input", ['float16', 'float32', 'float64'], 'allclose'
+        )
         check_type(rtol, 'rtol', float, 'allclose')
         check_type(atol, 'atol', float, 'allclose')
         check_type(equal_nan, 'equal_nan', bool, 'allclose')
```
```diff
@@ -989,8 +993,8 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     two tensors are elementwise equal within a tolerance.

     Args:
-        x(Tensor): The input tensor, it's data type should be float32, float64.
-        y(Tensor): The input tensor, it's data type should be float32, float64.
+        x(Tensor): The input tensor, its data type should be float16, float32, float64.
+        y(Tensor): The input tensor, its data type should be float16, float32, float64.
         rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
         atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): If :math:`True` , then two :math:`NaNs` will be compared as equal. Default: :math:`False` .
```
```diff
@@ -1027,8 +1031,12 @@ def isclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     if in_dygraph_mode():
         return _C_ops.isclose(x, y, rtol, atol, equal_nan)
     else:
-        check_variable_and_dtype(x, "input", ['float32', 'float64'], 'isclose')
-        check_variable_and_dtype(y, "input", ['float32', 'float64'], 'isclose')
+        check_variable_and_dtype(
+            x, "input", ['float16', 'float32', 'float64'], 'isclose'
+        )
+        check_variable_and_dtype(
+            y, "input", ['float16', 'float32', 'float64'], 'isclose'
+        )
         check_type(rtol, 'rtol', float, 'isclose')
         check_type(atol, 'atol', float, 'isclose')
         check_type(equal_nan, 'equal_nan', bool, 'isclose')
```
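In static graph mode the dtype whitelist is enforced when the op is built, so unsupported dtypes fail with `TypeError` before the program runs. A sketch of what the widened whitelist accepts and rejects, assuming a Paddle build with this patch (the `int32` case mirrors the updated error tests above):

```python
import paddle

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.static.data(name='x', shape=[4], dtype='float16')
    y = paddle.static.data(name='y', shape=[4], dtype='float16')
    out = paddle.isclose(x, y)  # float16 now passes check_variable_and_dtype

    xi = paddle.static.data(name='xi', shape=[4], dtype='int32')
    yf = paddle.static.data(name='yf', shape=[4], dtype='float64')
    try:
        paddle.isclose(xi, yf)  # int32 is still outside the whitelist
    except TypeError as err:
        print('rejected:', err.__class__.__name__)
```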