diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 4fca3ccc3af4d93622a46417bfc6070974ce51fb..11fb551b20d685b9a353d28932720dc992b298d3 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -2472,6 +2472,58 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
 };
 
 #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
+
+template <typename T>
+struct CudaLogitFunctor : public BaseActivationFunctor<T> {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  MT zero = static_cast<MT>(0.0f);
+  MT one = static_cast<MT>(1.0f);
+  float eps;
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"eps", &eps}};
+  }
+
+  // logit(x) = ln(x/(1-x))
+  __device__ __forceinline__ T operator()(const T arg_x) const {
+    MT x = static_cast<MT>(arg_x);
+    MT y = min(x, (one - static_cast<MT>(eps)));
+    y = max(y, static_cast<MT>(eps));
+
+    if (!eps) {
+      y = x < zero || x > one ? static_cast<MT>(NAN) : log(y / (one - y));
+    } else {
+      y = log(y / (one - y));
+    }
+    return static_cast<T>(y);
+  }
+};
+
+template <typename T>
+struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  float eps;
+  MT zero = static_cast<MT>(0.0f);
+  MT one = static_cast<MT>(1.0f);
+
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"eps", &eps}};
+  }
+  // logit(x)' = 1/(x*(1-x))
+  __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
+    MT x = static_cast<MT>(arg_x);
+    MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+                ? zero
+                : (static_cast<MT>(dout) / (x * (one - x)));
+    return static_cast<T>(dx);
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
+
 template <typename T>
 struct CudaReluFunctor : public BaseActivationFunctor<T> {
   T zero = static_cast<T>(0.0f);
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 617fbd45f05c5c2e4e683dd694f4d2c110bdfedf..5573d666776b7ffee1d6aa72726bcb5017afc083 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -228,6 +228,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
 DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6,
                                                  CudaRelu6GradFunctor,
                                                  threshold);
+DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(LogitCUDA,
+                                                 CudaLogitGradFunctor,
+                                                 eps);
 
 DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh,
                                                CudaHardTanhGradFunctor,
@@ -382,6 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(logit_grad, LogitCUDAGradKernel)
 
 PD_REGISTER_KERNEL(expm1_grad,
                    GPU,
@@ -392,15 +396,6 @@ PD_REGISTER_KERNEL(expm1_grad,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
 
-PD_REGISTER_KERNEL(logit_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::LogitGradKernel,
-                   float,
-                   double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
-
 PD_REGISTER_KERNEL(square_grad,
                    GPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index c60a93725504d10a86fafb7744d920715b2e2710..cf3c66f53de2d56663c75e9b5ecbc17798716dfb 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -109,6 +109,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
 DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
 
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
+DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
 DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                      CudaThresholdedReluFunctor,
                                      threshold)
@@ -225,14 +226,6 @@ PD_REGISTER_KERNEL(expm1,
                    double,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
-PD_REGISTER_KERNEL(logit,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::LogitKernel,
-                   float,
-                   double,
-                   phi::dtype::float16,
-                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(square,
                    GPU,
                    ALL_LAYOUT,
@@ -263,6 +256,8 @@ PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)
+
 PD_REGISTER_KERNEL(pow,
                    GPU,
                    ALL_LAYOUT,
diff --git a/python/paddle/fluid/tests/unittests/test_logit_op.py b/python/paddle/fluid/tests/unittests/test_logit_op.py
index 597f8fe197f23c0ba6af2a089e64b830e842c78e..8623304548def3d67003bdac37fc795e323b4dda 100644
--- a/python/paddle/fluid/tests/unittests/test_logit_op.py
+++ b/python/paddle/fluid/tests/unittests/test_logit_op.py
@@ -18,6 +18,8 @@ import numpy as np
 from op_test import OpTest
 
 import paddle
+from paddle.fluid import core
+from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16
 
 np.random.seed(10)
 
@@ -43,9 +45,6 @@ class TestLogitOp(OpTest):
     def setUp(self):
         self.op_type = 'logit'
         self.python_api = paddle.logit
-        self.dtype = np.float64
-        self.shape = [120]
-        self.eps = 1e-8
         self.set_attrs()
         x = np.random.uniform(-1.0, 1.0, self.shape).astype(self.dtype)
         out = logit(x, self.eps)
@@ -55,7 +54,39 @@ class TestLogitOp(OpTest):
         self.attrs = {'eps': self.eps}
 
     def set_attrs(self):
-        pass
+        self.dtype = np.float64
+        self.shape = [120]
+        self.eps = 1e-8
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True
+        )
+
+
+class TestLogitOpFp32(TestLogitOp):
+    def set_attrs(self):
+        self.dtype = np.float32
+        self.shape = [120]
+        self.eps = 1e-8
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True
+        )
+
+
+class TestLogitOpFp16(TestLogitOp):
+    def set_attrs(self):
+        self.dtype = np.float16
+        self.shape = [120]
+        self.eps = 1e-8
 
     def test_check_output(self):
         self.check_output(check_eager=True)
 
     def test_check_grad(self):
         self.check_grad(
             ['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True
         )
 
 
@@ -66,13 +97,56 @@
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestLogitOpBf16(OpTest):
+    def setUp(self):
+        self.op_type = 'logit'
+        self.python_api = paddle.logit
+        self.set_attrs()
+        x = np.random.uniform(-0.5, 0.5, self.shape).astype(np.float32)
+        out = logit(x, self.eps)
+        self.x_grad = logit_grad(x, self.eps)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+        self.attrs = {'eps': self.eps}
+
+    def set_attrs(self):
+        self.dtype = np.uint16
+        self.shape = [120]
+        self.eps = 1e-8
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, check_eager=True)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place,
+                ['X'],
+                ['Out'],
+                user_defined_grads=[self.x_grad],
+                check_eager=True,
+            )
+
+
 class TestLogitShape(TestLogitOp):
     def set_attrs(self):
+        self.dtype = np.float64
         self.shape = [2, 60]
+        self.eps = 1e-8
 
 
 class TestLogitEps(TestLogitOp):
     def set_attrs(self):
+        self.dtype = np.float32
+        self.shape = [120]
         self.eps = 1e-8
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
index e4f2356faf4f4a73322ac205e2848c956235e190..ced30722cf2792259265c79f9982a6ff3ac0fc8e 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
@@ -38,6 +38,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
     'increment',
     'l1_norm',
     'log_loss',
+    'logit',
     'lrn',
     'margin_rank_loss',
     'match_matrix_tensor',
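
Note: below is a minimal NumPy sketch of the forward/backward math that the new CudaLogitFunctor and CudaLogitGradFunctor implement, useful for sanity-checking the kernels by hand. The helper names np_logit and np_logit_grad are illustrative only; they are not the logit / logit_grad reference helpers already defined in test_logit_op.py, which this diff does not show.

    import numpy as np

    def np_logit(x, eps=0.0):
        # Mirror CudaLogitFunctor: clamp x into [eps, 1 - eps], then take ln(x / (1 - x)).
        # With eps == 0 the functor does not clamp; inputs outside [0, 1] yield NaN.
        x = np.asarray(x, dtype=np.float64)
        y = np.clip(x, eps, 1.0 - eps)
        with np.errstate(divide="ignore", invalid="ignore"):
            out = np.log(y / (1.0 - y))
        if eps == 0.0:
            out = np.where((x < 0.0) | (x > 1.0), np.nan, out)
        return out

    def np_logit_grad(dout, x, eps=0.0):
        # Mirror CudaLogitGradFunctor: dx = dout / (x * (1 - x)) for x inside
        # [eps, 1 - eps] (strict comparisons, as in the functor), zero elsewhere.
        x = np.asarray(x, dtype=np.float64)
        inside = ~((x < eps) | (x > 1.0 - eps))
        with np.errstate(divide="ignore", invalid="ignore"):
            dx = np.asarray(dout, dtype=np.float64) / (x * (1.0 - x))
        return np.where(inside, dx, 0.0)

With the patch applied on a CUDA build, the same values can be compared against paddle.logit(paddle.to_tensor(x), eps) to spot-check the float16 and bfloat16 paths that the new tests cover.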