未验证 提交 39899d79 编写于 作者: L Little-chick 提交者: GitHub

add register of auc (#51451)

* Update interpreter_util.cc

* Update auc_kernel.cc

* Update auc_kernel.cu

* Update auc_kernel.cc

* Update auc_kernel.cu
上级 37662dd1
...@@ -56,7 +56,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -56,7 +56,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"any_raw", "any_raw",
"arg_sort", "arg_sort",
"atan2", "atan2",
"auc",
"clip_by_norm", "clip_by_norm",
"complex", "complex",
"conv3d_coo", "conv3d_coo",
......
...@@ -207,4 +207,8 @@ void AucKernel(const Context &dev_ctx, ...@@ -207,4 +207,8 @@ void AucKernel(const Context &dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(auc, CPU, ALL_LAYOUT, phi::AucKernel, float) {} PD_REGISTER_KERNEL(auc, CPU, ALL_LAYOUT, phi::AucKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
...@@ -273,4 +273,8 @@ void AucKernel(const Context &dev_ctx, ...@@ -273,4 +273,8 @@ void AucKernel(const Context &dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(auc, GPU, ALL_LAYOUT, phi::AucKernel, float) {} PD_REGISTER_KERNEL(auc, GPU, ALL_LAYOUT, phi::AucKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册