未验证 提交 0f79444e 编写于 作者: 张春乔 提交者: GitHub

[phi] add register of accuracy (#51308)

* add REGISTER of float32 in accuracy

* fix something
上级 cc511f24
......@@ -51,7 +51,6 @@ using VariableIdMap = std::map<std::string, std::vector<int>>;
// These Op needs set output dtype when register phi kernel, but they didn't
static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"abs",
"accuracy",
"adam",
"adamw",
"all_close",
......
......@@ -96,4 +96,7 @@ PD_REGISTER_KERNEL(
accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
......@@ -140,4 +140,6 @@ PD_REGISTER_KERNEL(accuracy,
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册