Unverified commit 44e6de98, authored by thunder95, committed by GitHub

[PaddlePaddle Hackathon 4 No.49]: Support the float16 data type for Paddle bce_loss (#50930)

* untracked files

* bce_loss_fp16

* remove unused files

* back: max_rel_error still big

* simplify code

* update

* fix max_relative_error

* restart ci

* Update test_bce_loss.py

* Update test_bce_loss.py

* Update test_bce_loss.py

* Update test_bce_loss.py

* try to pass test

* restore file

* remove error value

* fix bug

---------
Co-authored-by: Zhang Ting <Douyaer2020@qq.com>
Parent: ebc58548
@@ -18,6 +18,8 @@
 #include <vector>
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -26,17 +28,15 @@ namespace phi {
 
 template <typename T>
 struct BCELossGradFunctor {
-  T one;
-  T eps;
-
-  HOSTDEVICE inline BCELossGradFunctor() {
-    one = static_cast<T>(1.0f);
-    eps = static_cast<T>(1e-12);
-  }
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT one = static_cast<MT>(1.0f);
+  MT eps = static_cast<MT>(1e-12);
 
   HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
-    T term1 = max((one - x) * x, eps);
-    return (dout * (x - label) / term1);
+    MT x_mt = static_cast<MT>(x);
+    MT term1 = max((one - x_mt) * x_mt, eps);
+    return static_cast<T>(static_cast<MT>(dout) *
+                          (x_mt - static_cast<MT>(label)) / term1);
   }
 };
@@ -55,5 +55,10 @@ void BCELossGradKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    bce_loss_grad, GPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {}
+PD_REGISTER_KERNEL(bce_loss_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BCELossGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
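The grad kernel above uses Paddle's usual mixed-precision pattern: MPTypeTrait<float16>::Type resolves to float, so the functor upcasts its operands, computes in single precision, and casts only the final result back to T. This matters here because eps = 1e-12 is far below float16's smallest subnormal (about 6e-8) and would flush to zero if stored in half precision, and (one - x) * x loses most of its significant bits near the interval endpoints. A NumPy sketch of the same gradient computation, illustrative only and not part of the commit:

import numpy as np

def bce_grad_ref(x, label, dout, eps=1e-12):
    # Upcast float16 operands to float32, compute, cast back -- the same
    # scheme as BCELossGradFunctor with MT = float.
    mt = np.float32 if x.dtype == np.float16 else x.dtype
    x_mt = x.astype(mt)
    term1 = np.maximum((1.0 - x_mt) * x_mt, eps)
    grad = dout.astype(mt) * (x_mt - label.astype(mt)) / term1
    return grad.astype(x.dtype)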
@@ -18,6 +18,8 @@
 #include <vector>
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -27,22 +29,23 @@ namespace phi {
 
 template <typename T>
 struct BCELossFunctor {
-  T one;
-  T neg_100;
-
-  HOSTDEVICE inline BCELossFunctor() {
-    one = static_cast<T>(1.0f);
-    neg_100 = static_cast<T>(-100.);
-  }
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT zero = static_cast<MT>(0);
+  MT one = static_cast<MT>(1.0f);
+  MT neg_100 = static_cast<MT>(-100.);
 
   HOSTDEVICE inline T operator()(const T x, const T label) const {
+    MT x_mt = static_cast<MT>(x);
+    MT label_mt = static_cast<MT>(label);
     PADDLE_ENFORCE(
-        (x >= static_cast<T>(0)) && (x <= one),
+        (x_mt >= zero) && (x_mt <= one),
         "Input is expected to be within the interval [0, 1], but received %f.",
-        x);
-    T term1 = max(phi::kps::details::Log(x), neg_100);
-    T term2 = max(phi::kps::details::Log(one - x), neg_100);
-    return (((label - one) * term2) - (label * term1));
+        x_mt);
+    MT term1 = max(phi::kps::details::Log(x_mt), neg_100);
+    MT term2 = max(phi::kps::details::Log(one - x_mt), neg_100);
+    return static_cast<T>((label_mt - one) * term2 - label_mt * term1);
   }
 };
@@ -60,5 +63,10 @@ void BCELossKernel(const Context& dev_ctx,
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    bce_loss, GPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {}
+PD_REGISTER_KERNEL(bce_loss,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BCELossKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
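The forward functor applies the same upcast, and clamps each log term at -100 so that inputs of exactly 0 or 1 produce a large but finite loss instead of inf. A NumPy sketch of the formula, which mirrors the functor rather than quoting any file from the commit:

import numpy as np

def bce_ref(x, label):
    # loss = (label - 1) * max(log(1 - x), -100) - label * max(log(x), -100),
    # computed in float32 even when x and label are float16.
    mt = np.float32 if x.dtype == np.float16 else x.dtype
    x_mt, label_mt = x.astype(mt), label.astype(mt)
    with np.errstate(divide='ignore'):  # log(0) -> -inf before clamping
        term1 = np.maximum(np.log(x_mt), -100.0)
        term2 = np.maximum(np.log(1.0 - x_mt), -100.0)
    return ((label_mt - 1.0) * term2 - label_mt * term1).astype(x.dtype)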
@@ -19,6 +19,7 @@ from eager_op_test import OpTest
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 def test_static_layer(
@@ -249,11 +250,12 @@ def bce_wrapper(x, label):
 
 class TestBceLossOp(OpTest):
     def setUp(self):
+        self.init_test_dtype()
         self.init_test_case()
         self.op_type = "bce_loss"
         self.python_api = bce_wrapper
-        input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
-        label_np = np.random.randint(0, 2, self.shape).astype("float64")
+        input_np = np.random.uniform(0.1, 0.8, self.shape).astype(self.dtype)
+        label_np = np.random.randint(0, 2, self.shape).astype(self.dtype)
         output_np = bce_loss(input_np, label_np)
 
         self.inputs = {'X': input_np, 'Label': label_np}
@@ -268,6 +270,9 @@ class TestBceLossOp(OpTest):
     def init_test_case(self):
        self.shape = [10, 10]
 
+    def init_test_dtype(self):
+        self.dtype = "float64"
+
 
 class TestBceLossOpCase1(OpTest):
     def init_test_cast(self):
@@ -279,6 +284,39 @@ class TestBceLossOpCase2(OpTest):
         self.shape = [2, 3, 20]
 
 
+class TestBceLossOpFP16(TestBceLossOp):
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+    def init_test_dtype(self):
+        self.dtype = np.float16
+
+
+class TestBceLossOpStaticFP16(unittest.TestCase):
+    def test_fp16(self):
+        paddle.enable_static()
+        shape = [2, 3, 20]
+        x_data = np.random.uniform(0.1, 0.8, shape).astype("float16")
+        y_data = np.random.randint(0, 2, shape).astype("float16")
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=shape, name='x', dtype='float16')
+            y = paddle.static.data(shape=shape, name='y', dtype='float16')
+            out = paddle.nn.functional.binary_cross_entropy(
+                x, y, reduction="none"
+            )
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                output_pd = exe.run(
+                    feed={'x': x_data, 'y': y_data}, fetch_list=[out]
+                )[0]
+        paddle.disable_static()
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
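Beyond the OpTest coverage above, a quick dygraph sanity check is to push the same data through the float16 and float32 paths and compare. A sketch under stated assumptions (CUDA build available; the rtol/atol values are illustrative, not the tolerances used by this PR):

import numpy as np
import paddle
import paddle.nn.functional as F

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    x_np = np.random.uniform(0.1, 0.8, [2, 3, 20]).astype('float16')
    y_np = np.random.randint(0, 2, [2, 3, 20]).astype('float16')
    out16 = F.binary_cross_entropy(
        paddle.to_tensor(x_np), paddle.to_tensor(y_np), reduction='none'
    )
    out32 = F.binary_cross_entropy(
        paddle.to_tensor(x_np.astype('float32')),
        paddle.to_tensor(y_np.astype('float32')),
        reduction='none',
    )
    # The fp16 kernel accumulates in fp32 internally, so results should
    # agree with the fp32 path up to half-precision rounding.
    np.testing.assert_allclose(
        out16.astype('float32').numpy(), out32.numpy(), rtol=1e-3, atol=1e-3
    )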
...@@ -641,10 +641,10 @@ def binary_cross_entropy( ...@@ -641,10 +641,10 @@ def binary_cross_entropy(
Parameters: Parameters:
input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *], input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *],
N is batch_size, `*` means number of additional dimensions. The ``input`` N is batch_size, `*` means number of additional dimensions. The ``input``
should always be the output of sigmod. Available dtype is float32, float64. should always be the output of sigmod. Available dtype is float16, float32, float64.
label (Tensor): The target labels tensor. 2-D tensor with the same shape as label (Tensor): The target labels tensor. 2-D tensor with the same shape as
``input``. The target labels which values should be numbers between 0 and 1. ``input``. The target labels which values should be numbers between 0 and 1.
Available dtype is float32, float64. Available dtype is float16, float32, float64.
weight (Tensor, optional): A manual rescaling weight given to the loss of each weight (Tensor, optional): A manual rescaling weight given to the loss of each
batch element. If given, has to be a Tensor of size nbatch and the data type batch element. If given, has to be a Tensor of size nbatch and the data type
is float32, float64. Default is ``'None'``. is float32, float64. Default is ``'None'``.
...@@ -694,10 +694,16 @@ def binary_cross_entropy( ...@@ -694,10 +694,16 @@ def binary_cross_entropy(
return out return out
else: else:
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'binary_cross_entropy' input,
'input',
['float16', 'float32', 'float64'],
'binary_cross_entropy',
) )
check_variable_and_dtype( check_variable_and_dtype(
label, 'label', ['float32', 'float64'], 'binary_cross_entropy' label,
'label',
['float16', 'float32', 'float64'],
'binary_cross_entropy',
) )
sub_name = name if weight is None and reduction == 'none' else None sub_name = name if weight is None and reduction == 'none' else None
......
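With the static-graph dtype whitelist extended, a float16 call goes straight through to the newly registered GPU kernel. A minimal dygraph sketch (assumes a CUDA device; the values are arbitrary):

import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')
x = paddle.to_tensor([[0.2, 0.7], [0.5, 0.9]], dtype='float16')
y = paddle.to_tensor([[0.0, 1.0], [0.0, 1.0]], dtype='float16')
loss = F.binary_cross_entropy(x, y, reduction='none')  # float16 tensor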
...@@ -730,8 +730,8 @@ class BCELoss(Layer): ...@@ -730,8 +730,8 @@ class BCELoss(Layer):
For more information, please refer to :ref:`api_guide_Name`. For more information, please refer to :ref:`api_guide_Name`.
Shape: Shape:
- input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmod. Available dtype is float32, float64. - input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmod. Available dtype is float16, float32, float64.
- label (Tensor): 2-D tensor with the same shape as ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float32, float64. - label (Tensor): 2-D tensor with the same shape as ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float16, float32, float64.
- output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar. - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar.
Returns: Returns:
......
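The layer wrapper inherits the same dtype support. A hedged sketch of paddle.nn.BCELoss with float16 inputs (assumes a CUDA device; tensor creation in float32 followed by a cast, since uniform sampling directly in float16 may not be supported everywhere):

import paddle

paddle.set_device('gpu')
bce = paddle.nn.BCELoss(reduction='mean')
pred = (paddle.rand([4, 8]) * 0.8 + 0.1).astype('float16')  # keep in (0, 1)
label = paddle.randint(0, 2, [4, 8]).astype('float16')
loss = bce(pred, label)  # scalar float16 loss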