未验证 提交 abc2cc57 编写于 作者: R Roc 提交者: GitHub

fix transform data (#44266)

* fix transform data

* fix dropout kernel

* Revert "fix transform data"

This reverts commit ada75ecd169ea194ce43f7ed75dcc968f5ed2fb9.
上级 469d5ab4
...@@ -84,7 +84,9 @@ PD_REGISTER_KERNEL(dropout, ...@@ -84,7 +84,9 @@ PD_REGISTER_KERNEL(dropout,
float, float,
double, double,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(dropout_nd, PD_REGISTER_KERNEL(dropout_nd,
GPU, GPU,
...@@ -93,4 +95,6 @@ PD_REGISTER_KERNEL(dropout_nd, ...@@ -93,4 +95,6 @@ PD_REGISTER_KERNEL(dropout_nd,
float, float,
double, double,
phi::dtype::bfloat16, phi::dtype::bfloat16,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册