Unverified commit d1e2ba8a, authored by sneaxiy, committed by GitHub

Register exp/expm1/logit bf16 activation op kernels (#48702)

* register more bf16 ops

* update to register corresponding backward ops
Parent 5c9bbe89
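For a quick sanity check of what this commit enables, here is a minimal sketch (not from the PR itself) that exercises the newly registered kernels through Paddle's Python API. It assumes a CUDA build of PaddlePaddle with bfloat16 support; the shapes and variable names are illustrative.

```python
# Minimal sketch (illustrative, not part of this commit): exercise the
# bf16 GPU kernels registered here via Paddle's Python API.
import paddle

paddle.set_device("gpu")

# rand() yields values in (0, 1), which keeps logit's input in-domain.
x = paddle.rand([4, 4]).astype("bfloat16")
x.stop_gradient = False

y = paddle.exp(x)    # forward: phi::ExpKernel, now registered for bfloat16
z = paddle.expm1(x)  # forward: phi::Expm1Kernel
w = paddle.logit(x)  # forward: phi::LogitKernel

# The backward pass dispatches to the matching *_grad registrations
# (exp_grad / expm1_grad / logit_grad) added in this commit.
(y + z + w).sum().backward()
print(x.grad.dtype)  # paddle.bfloat16
```

Before this change, only float, double, and float16 (plus the integer types for exp) were registered for these ops on GPU, so dispatching them on bfloat16 tensors would fail to find a kernel.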
@@ -371,7 +371,8 @@ PD_REGISTER_KERNEL(exp_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
@@ -386,7 +387,8 @@ PD_REGISTER_KERNEL(expm1_grad,
                    phi::Expm1GradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(logit_grad,
                    GPU,
@@ -394,7 +396,8 @@ PD_REGISTER_KERNEL(logit_grad,
                    phi::LogitGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(square_grad,
                    GPU,
...
@@ -215,21 +215,24 @@ PD_REGISTER_KERNEL(exp,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(expm1,
                    GPU,
                    ALL_LAYOUT,
                    phi::Expm1Kernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(logit,
                    GPU,
                    ALL_LAYOUT,
                    phi::LogitKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(square,
                    GPU,
                    ALL_LAYOUT,
...