未验证 提交 3b90a7f3 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Register half datatype for Roll Kernel (#49192)

* register half datatype

* register roll grad fp16 kernel
上级 eaf90003
...@@ -80,6 +80,7 @@ PD_REGISTER_KERNEL(roll_grad, ...@@ -80,6 +80,7 @@ PD_REGISTER_KERNEL(roll_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::RollGradKernel, phi::RollGradKernel,
phi::dtype::float16,
float, float,
double, double,
int, int,
......
...@@ -82,6 +82,7 @@ PD_REGISTER_KERNEL(roll, ...@@ -82,6 +82,7 @@ PD_REGISTER_KERNEL(roll,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::RollKernel, phi::RollKernel,
phi::dtype::float16,
float, float,
double, double,
int, int,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册