未验证 提交 b647c2f0 编写于 作者: P PuQing 提交者: GitHub

[PHI] Add multinomial output defs (#51357)

* add multinomial output defs

* fix register on gpu
上级 189b086b
......@@ -70,7 +70,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"merged_adam",
"momentum",
"multiclass_nms3",
"multinomial",
"nanmedian",
"sync_batch_norm_grad",
"unique",
......
......@@ -45,4 +45,6 @@ void MultinomialKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
multinomial, CPU, ALL_LAYOUT, phi::MultinomialKernel, float, double) {}
multinomial, CPU, ALL_LAYOUT, phi::MultinomialKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
......@@ -296,6 +296,8 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
ALL_LAYOUT,
phi::MultinomialKernel,
float,
double) {}
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册