未验证 提交 289677e2 编写于 作者: B Bo Zhang 提交者: GitHub

【AMP OP&Test】unit test for test_logit_op (#51051)

* test_logit_op

* add cudaKernel to replace eigen impl

* bf16 unit test CI
上级 de2166c0
......@@ -2472,6 +2472,58 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
};
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaLogitFunctor : public BaseActivationFunctor<T> {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT zero = static_cast<MT>(0.0f);
MT one = static_cast<MT>(1.0f);
float eps;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"eps", &eps}};
}
// logit(x) = ln(x/(1-x))
__device__ __forceinline__ T operator()(const T arg_x) const {
MT x = static_cast<MT>(arg_x);
MT y = min(x, (one - static_cast<MT>(eps)));
y = max(y, static_cast<MT>(eps));
if (!eps) {
y = x < zero || x > one ? static_cast<T>(NAN) : log(y / (one - y));
} else {
y = log(y / (one - y));
}
return static_cast<T>(y);
}
};
template <typename T>
struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
float eps;
MT zero = static_cast<MT>(0.0f);
MT one = static_cast<MT>(1.0f);
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"eps", &eps}};
}
// logit(x)' = 1/(x*(1-x))
__device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
MT x = static_cast<MT>(arg_x);
MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
? zero
: (static_cast<MT>(dout) / (x * (one - x)));
return static_cast<T>(dx);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
......
......@@ -228,6 +228,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6,
CudaRelu6GradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(LogitCUDA,
CudaLogitGradFunctor,
eps);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh,
CudaHardTanhGradFunctor,
......@@ -382,6 +385,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(logit_grad, LogitCUDAGradKernel)
PD_REGISTER_KERNEL(expm1_grad,
GPU,
......@@ -392,15 +396,6 @@ PD_REGISTER_KERNEL(expm1_grad,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(logit_grad,
GPU,
ALL_LAYOUT,
phi::LogitGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square_grad,
GPU,
ALL_LAYOUT,
......
......@@ -109,6 +109,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor,
threshold)
......@@ -225,14 +226,6 @@ PD_REGISTER_KERNEL(expm1,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(logit,
GPU,
ALL_LAYOUT,
phi::LogitKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square,
GPU,
ALL_LAYOUT,
......@@ -263,6 +256,8 @@ PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(logit, LogitCUDAKernel)
PD_REGISTER_KERNEL(pow,
GPU,
ALL_LAYOUT,
......
......@@ -18,6 +18,8 @@ import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid import core
from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16
np.random.seed(10)
......@@ -43,9 +45,6 @@ class TestLogitOp(OpTest):
def setUp(self):
self.op_type = 'logit'
self.python_api = paddle.logit
self.dtype = np.float64
self.shape = [120]
self.eps = 1e-8
self.set_attrs()
x = np.random.uniform(-1.0, 1.0, self.shape).astype(self.dtype)
out = logit(x, self.eps)
......@@ -55,7 +54,39 @@ class TestLogitOp(OpTest):
self.attrs = {'eps': self.eps}
def set_attrs(self):
pass
self.dtype = np.float64
self.shape = [120]
self.eps = 1e-8
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(
['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True
)
class TestLogitOpFp32(TestLogitOp):
def set_attrs(self):
self.dtype = np.float32
self.shape = [120]
self.eps = 1e-8
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(
['X'], ['Out'], user_defined_grads=[self.x_grad], check_eager=True
)
class TestLogitOpFp16(TestLogitOp):
def set_attrs(self):
self.dtype = np.float16
self.shape = [120]
self.eps = 1e-8
def test_check_output(self):
self.check_output(check_eager=True)
......@@ -66,13 +97,56 @@ class TestLogitOp(OpTest):
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA and not support the bfloat16",
)
class TestLogitOpBf16(OpTest):
def setUp(self):
self.op_type = 'logit'
self.python_api = paddle.logit
self.set_attrs()
x = np.random.uniform(-0.5, 0.5, self.shape).astype(np.float32)
out = logit(x, self.eps)
self.x_grad = logit_grad(x, self.eps)
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': convert_float_to_uint16(out)}
self.attrs = {'eps': self.eps}
def set_attrs(self):
self.dtype = np.uint16
self.shape = [120]
self.eps = 1e-8
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_eager=True)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X'],
['Out'],
user_defined_grads=[self.x_grad],
check_eager=True,
)
class TestLogitShape(TestLogitOp):
def set_attrs(self):
self.dtype = np.float64
self.shape = [2, 60]
self.eps = 1e-8
class TestLogitEps(TestLogitOp):
def set_attrs(self):
self.dtype = np.float32
self.shape = [120]
self.eps = 1e-8
......
......@@ -38,6 +38,7 @@ NO_FP64_CHECK_GRAD_OP_LIST = [
'increment',
'l1_norm',
'log_loss',
'logit',
'lrn',
'margin_rank_loss',
'match_matrix_tensor',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册