“ba3b2eb3a5c288bd898d057a77682cecf043836c”上不存在“doc/design/ops/sequence_decoder.html”
未验证 提交 0f79444e 编写于 作者: 张春乔 提交者: GitHub

[phi] add register of accuracy (#51308)

* add REGISTER of float32 in accuracy

* fix something
上级 cc511f24
...@@ -51,7 +51,6 @@ using VariableIdMap = std::map<std::string, std::vector<int>>; ...@@ -51,7 +51,6 @@ using VariableIdMap = std::map<std::string, std::vector<int>>;
// These Op needs set output dtype when register phi kernel, but they didn't // These Op needs set output dtype when register phi kernel, but they didn't
static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"abs", "abs",
"accuracy",
"adam", "adam",
"adamw", "adamw",
"all_close", "all_close",
......
...@@ -96,4 +96,7 @@ PD_REGISTER_KERNEL( ...@@ -96,4 +96,7 @@ PD_REGISTER_KERNEL(
accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) { accuracy, CPU, ALL_LAYOUT, phi::AccuracyRawKernel, float, double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
} }
...@@ -140,4 +140,6 @@ PD_REGISTER_KERNEL(accuracy, ...@@ -140,4 +140,6 @@ PD_REGISTER_KERNEL(accuracy,
double) { double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64); kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64); kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册