未验证 提交 b647c2f0 编写于 作者: P PuQing 提交者: GitHub

[PHI] Add multinomial output defs (#51357)

* add multinomial output defs

* fix register on gpu
上级 189b086b
...@@ -70,7 +70,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -70,7 +70,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"merged_adam", "merged_adam",
"momentum", "momentum",
"multiclass_nms3", "multiclass_nms3",
"multinomial",
"nanmedian", "nanmedian",
"sync_batch_norm_grad", "sync_batch_norm_grad",
"unique", "unique",
......
...@@ -45,4 +45,6 @@ void MultinomialKernel(const Context& dev_ctx, ...@@ -45,4 +45,6 @@ void MultinomialKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
multinomial, CPU, ALL_LAYOUT, phi::MultinomialKernel, float, double) {} multinomial, CPU, ALL_LAYOUT, phi::MultinomialKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
...@@ -296,6 +296,8 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only ...@@ -296,6 +296,8 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
ALL_LAYOUT, ALL_LAYOUT,
phi::MultinomialKernel, phi::MultinomialKernel,
float, float,
double) {} double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册