未验证 提交 3094d475 编写于 作者: P PuQing 提交者: GitHub

[PHI] Add rnn and searchsorted output defs (#51360)

* add rnn and searchsorted output defs

* add gpu kernel
上级 907433a7
......@@ -72,8 +72,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"multiclass_nms3",
"multinomial",
"nanmedian",
"rnn",
"search_sort",
"sync_batch_norm_grad",
"unique",
"unique_consecutive_flattened_tensor",
......
......@@ -952,4 +952,6 @@ void RnnKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(rnn, CPU, ALL_LAYOUT, phi::RnnKernel, float, double) {}
PD_REGISTER_KERNEL(rnn, CPU, ALL_LAYOUT, phi::RnnKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
......@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(searchsorted,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -400,7 +400,11 @@ void RnnKernel(const Context &dev_ctx,
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float) {}
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#else
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float, double) {}
PD_REGISTER_KERNEL(rnn, GPU, ALL_LAYOUT, phi::RnnKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#endif
......@@ -25,4 +25,6 @@ PD_REGISTER_KERNEL(searchsorted,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -224,4 +224,6 @@ void RnnKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(rnn, XPU, ALL_LAYOUT, phi::RnnKernel, float) {}
PD_REGISTER_KERNEL(rnn, XPU, ALL_LAYOUT, phi::RnnKernel, float) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册