From 44e6de9847d68a0d5fda84cbfcfb369351a3c8e6 Mon Sep 17 00:00:00 2001
From: thunder95 <290844930@qq.com>
Date: Mon, 17 Apr 2023 14:27:28 +0800
Subject: [PATCH] 【PaddlePaddle Hackathon 4 No.49】: Support the float16 data type for Paddle bce_loss (#50930)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* untracked files

* bce_loss_fp16

* remove unused files

* back, max_rel_error still big

* simplify code

* upd

* fix max_relative_error

* restart ci

* Update test_bce_loss.py

* Update test_bce_loss.py

* Update test_bce_loss.py

* Update test_bce_loss.py

* try to pass test

* restore file

* remove error value

* fix bug

---------

Co-authored-by: Zhang Ting
---
 .../phi/kernels/gpu/bce_loss_grad_kernel.cu   | 27 +++++++-----
 paddle/phi/kernels/gpu/bce_loss_kernel.cu     | 36 +++++++++-------
 .../fluid/tests/unittests/test_bce_loss.py    | 42 ++++++++++++++++++-
 python/paddle/nn/functional/loss.py           | 14 +++++--
 python/paddle/nn/layer/loss.py                |  4 +-
 5 files changed, 90 insertions(+), 33 deletions(-)

diff --git a/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu
index 94eabac4d13..b50fc130174 100644
--- a/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/bce_loss_grad_kernel.cu
@@ -18,6 +18,8 @@
 #include <vector>
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -26,17 +28,15 @@ namespace phi {
 
 template <typename T>
 struct BCELossGradFunctor {
-  T one;
-  T eps;
-
-  HOSTDEVICE inline BCELossGradFunctor() {
-    one = static_cast<T>(1.0f);
-    eps = static_cast<T>(1e-12);
-  }
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT one = static_cast<MT>(1.0f);
+  MT eps = static_cast<MT>(1e-12);
 
   HOSTDEVICE inline T operator()(const T x, const T label, const T dout) const {
-    T term1 = max((one - x) * x, eps);
-    return (dout * (x - label) / term1);
+    MT x_mt = static_cast<MT>(x);
+    MT term1 = max((one - x_mt) * x_mt, eps);
+    return static_cast<T>(static_cast<MT>(dout) *
+                          (x_mt - static_cast<MT>(label)) / term1);
   }
 };
 
@@ -55,5 +55,10 @@ void BCELossGradKernel(const Context& dev_ctx,
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    bce_loss_grad, GPU, ALL_LAYOUT, phi::BCELossGradKernel, float, double) {}
+PD_REGISTER_KERNEL(bce_loss_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BCELossGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
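In scalar form, the backward functor above computes dx = dout * (x - label) / max((1 - x) * x, eps), with the arithmetic carried out in the higher-precision type MT (float when T is float16). The NumPy sketch below illustrates that formula for reference; the helper name is hypothetical, and this is an illustration rather than code from the patch:

import numpy as np

def bce_grad_reference(x, label, dout, eps=1e-12):
    # Upcast to float32, mirroring MPTypeTrait<float16>::Type == float.
    x32, label32, dout32 = (t.astype(np.float32) for t in (x, label, dout))
    # dx = dout * (x - label) / max((1 - x) * x, eps); eps guards x near 0 or 1.
    denom = np.maximum((1.0 - x32) * x32, eps)
    return (dout32 * (x32 - label32) / denom).astype(x.dtype)

Doing the division in float32 matters here: eps = 1e-12 underflows to zero in float16, and (1 - x) * x can itself round to zero, so computing the quotient directly in half precision could divide by zero.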
diff --git a/paddle/phi/kernels/gpu/bce_loss_kernel.cu b/paddle/phi/kernels/gpu/bce_loss_kernel.cu
index b190bce4742..49191b3e354 100644
--- a/paddle/phi/kernels/gpu/bce_loss_kernel.cu
+++ b/paddle/phi/kernels/gpu/bce_loss_kernel.cu
@@ -18,6 +18,8 @@
 #include <vector>
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -27,22 +29,23 @@ namespace phi {
 
 template <typename T>
 struct BCELossFunctor {
-  T one;
-  T neg_100;
-
-  HOSTDEVICE inline BCELossFunctor() {
-    one = static_cast<T>(1.0f);
-    neg_100 = static_cast<T>(-100.);
-  }
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT zero = static_cast<MT>(0);
+  MT one = static_cast<MT>(1.0f);
+  MT neg_100 = static_cast<MT>(-100.);
 
   HOSTDEVICE inline T operator()(const T x, const T label) const {
+    MT x_mt = static_cast<MT>(x);
+    MT label_mt = static_cast<MT>(label);
+
     PADDLE_ENFORCE(
-        (x >= static_cast<T>(0)) && (x <= one),
+        (x_mt >= zero) && (x_mt <= one),
         "Input is expected to be within the interval [0, 1], but received %f.",
-        x);
-    T term1 = max(phi::kps::details::Log(x), neg_100);
-    T term2 = max(phi::kps::details::Log(one - x), neg_100);
-    return (((label - one) * term2) - (label * term1));
+        x_mt);
+
+    MT term1 = max(phi::kps::details::Log(x_mt), neg_100);
+    MT term2 = max(phi::kps::details::Log(one - x_mt), neg_100);
+    return static_cast<T>((label_mt - one) * term2 - label_mt * term1);
   }
 };
 
@@ -60,5 +63,10 @@ void BCELossKernel(const Context& dev_ctx,
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    bce_loss, GPU, ALL_LAYOUT, phi::BCELossKernel, float, double) {}
+PD_REGISTER_KERNEL(bce_loss,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BCELossKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py
index 91dbec6d06b..38310e5a620 100644
--- a/python/paddle/fluid/tests/unittests/test_bce_loss.py
+++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py
@@ -19,6 +19,7 @@ from eager_op_test import OpTest
 
 import paddle
 from paddle import fluid
+from paddle.fluid import core
 
 
 def test_static_layer(
@@ -249,11 +250,12 @@ def bce_wrapper(x, label):
 
 class TestBceLossOp(OpTest):
     def setUp(self):
+        self.init_test_dtype()
         self.init_test_case()
         self.op_type = "bce_loss"
         self.python_api = bce_wrapper
-        input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
-        label_np = np.random.randint(0, 2, self.shape).astype("float64")
+        input_np = np.random.uniform(0.1, 0.8, self.shape).astype(self.dtype)
+        label_np = np.random.randint(0, 2, self.shape).astype(self.dtype)
         output_np = bce_loss(input_np, label_np)
 
         self.inputs = {'X': input_np, 'Label': label_np}
@@ -268,6 +270,9 @@ class TestBceLossOp(OpTest):
     def init_test_case(self):
         self.shape = [10, 10]
 
+    def init_test_dtype(self):
+        self.dtype = "float64"
+
 
 class TestBceLossOpCase1(OpTest):
     def init_test_cast(self):
         self.shape = [20, 30]
 
 
 class TestBceLossOpCase2(OpTest):
     def init_test_cast(self):
         self.shape = [2, 3, 20]
 
 
+class TestBceLossOpFP16(TestBceLossOp):
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+    def init_test_dtype(self):
+        self.dtype = np.float16
+
+
+class TestBceLossOpStaticFP16(unittest.TestCase):
+    def test_fp16(self):
+        paddle.enable_static()
+        shape = [2, 3, 20]
+        x_data = np.random.uniform(0.1, 0.8, shape).astype("float16")
+        y_data = np.random.randint(0, 2, shape).astype("float16")
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(shape=shape, name='x', dtype='float16')
+            y = paddle.static.data(shape=shape, name='y', dtype='float16')
+            out = paddle.nn.functional.binary_cross_entropy(
+                x, y, reduction="none"
+            )
+            if core.is_compiled_with_cuda():
+                place = paddle.CUDAPlace(0)
+                exe = paddle.static.Executor(place)
+                exe.run(paddle.static.default_startup_program())
+                output_pd = exe.run(
+                    feed={'x': x_data, 'y': y_data}, fetch_list=[out]
+                )[0]
+        paddle.disable_static()
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
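For reference, BCELossFunctor above evaluates loss = (label - 1) * max(log(1 - x), -100) - label * max(log(x), -100), i.e. the usual binary cross-entropy with each log clamped at -100 so that x = 0 or x = 1 yields a large finite loss instead of -inf, again computed in the MT type. A NumPy sketch with a hypothetical helper name, shown for illustration and not part of the patch:

import numpy as np

def bce_forward_reference(x, label):
    # Upcast to float32, mirroring the kernel's MT compute type for float16.
    x32, label32 = x.astype(np.float32), label.astype(np.float32)
    # Clamp each log at -100, as the CUDA functor does, to avoid -inf.
    term1 = np.maximum(np.log(x32), -100.0)
    term2 = np.maximum(np.log(1.0 - x32), -100.0)
    return ((label32 - 1.0) * term2 - label32 * term1).astype(x.dtype)

This matches the test module's own bce_loss reference on inputs strictly inside (0, 1), where the clamps are inactive.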
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index e0f7d874be6..8b82e03f90f 100644
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -641,10 +641,10 @@ def binary_cross_entropy(
     Parameters:
         input (Tensor): The input predications tensor. 2-D tensor with shape: [N, *],
             N is batch_size, `*` means number of additional dimensions. The ``input``
-            should always be the output of sigmod. Available dtype is float32, float64.
+            should always be the output of sigmod. Available dtype is float16, float32, float64.
         label (Tensor): The target labels tensor. 2-D tensor with the same shape as
             ``input``. The target labels which values should be numbers between 0 and 1.
-            Available dtype is float32, float64.
+            Available dtype is float16, float32, float64.
         weight (Tensor, optional): A manual rescaling weight given to the loss of each
             batch element. If given, has to be a Tensor of size nbatch and the data type
             is float32, float64. Default is ``'None'``.
@@ -694,10 +694,16 @@
         return out
     else:
         check_variable_and_dtype(
-            input, 'input', ['float32', 'float64'], 'binary_cross_entropy'
+            input,
+            'input',
+            ['float16', 'float32', 'float64'],
+            'binary_cross_entropy',
         )
         check_variable_and_dtype(
-            label, 'label', ['float32', 'float64'], 'binary_cross_entropy'
+            label,
+            'label',
+            ['float16', 'float32', 'float64'],
+            'binary_cross_entropy',
         )
 
         sub_name = name if weight is None and reduction == 'none' else None
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 6fd186c882b..967d490897f 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -730,8 +730,8 @@ class BCELoss(Layer):
             For more information, please refer to :ref:`api_guide_Name`.
 
     Shape:
-        - input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmod. Available dtype is float32, float64.
-        - label (Tensor): 2-D tensor with the same shape as ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float32, float64.
+        - input (Tensor): 2-D tensor with shape: ``[N, *]``, N is batch_size, `*` means number of additional dimensions. The input ``input`` should always be the output of sigmod. Available dtype is float16, float32, float64.
+        - label (Tensor): 2-D tensor with the same shape as ``input``. The target labels which values should be numbers between 0 and 1. Available dtype is float16, float32, float64.
         - output (Tensor): If ``reduction`` is ``'none'``, the shape of output is same as ``input`` , else the shape of output is scalar.
 
     Returns:
--
GitLab
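With the kernel registrations and dtype checks above in place, float16 tensors can be passed to the loss directly in dynamic graph mode. A minimal usage sketch; the float16 kernels are GPU-only, so it assumes a CUDA build, and the values are purely illustrative:

import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    # Inputs are expected to be sigmoid outputs, i.e. probabilities in [0, 1].
    probs = paddle.to_tensor([[0.2, 0.7], [0.9, 0.4]], dtype='float16')
    labels = paddle.to_tensor([[0.0, 1.0], [1.0, 0.0]], dtype='float16')

    bce = paddle.nn.BCELoss(reduction='mean')
    loss = bce(probs, labels)  # scalar float16 tensor
    print(loss)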