未验证 提交 3094d475 编写于 作者: P PuQing 提交者: GitHub

[PHI] Add rnn and searchsorted output defs (#51360)

* add rnn and searchsorted output defs

* add gpu kernel
上级 907433a7
...@@ -72,8 +72,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -72,8 +72,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"multiclass_nms3", "multiclass_nms3",
"multinomial", "multinomial",
"nanmedian", "nanmedian",
"rnn",
"search_sort",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"unique", "unique",
"unique_consecutive_flattened_tensor", "unique_consecutive_flattened_tensor",
......
...@@ -952,4 +952,6 @@ void RnnKernel(const Context& dev_ctx, ...@@ -952,4 +952,6 @@ void RnnKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(rnn, CPU, ALL_LAYOUT, phi::RnnKernel, float, double) {} PD_REGISTER_KERNEL(rnn, CPU, ALL_LAYOUT, phi::RnnKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
...@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(searchsorted, ...@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(searchsorted,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -400,7 +400,11 @@ void RnnKernel(const Context &dev_ctx, ...@@ -400,7 +400,11 @@ void RnnKernel(const Context &dev_ctx,
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// MIOPEN do not support double // MIOPEN do not support double
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float) {} PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#else #else
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float, double) {} PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#endif #endif
...@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(searchsorted, ...@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(searchsorted,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -224,4 +224,6 @@ void RnnKernel(const Context& dev_ctx, ...@@ -224,4 +224,6 @@ void RnnKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(rnn, XPU, ALL_LAYOUT, phi::RnnKernel, float) {} PD_REGISTER_KERNEL(rnn, XPU, ALL_LAYOUT, phi::RnnKernel, float) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册