未验证 提交 2876f6f8 编写于 作者: I Infinity_lee 提交者: GitHub

add output defs for histogram kernel (#51317)

上级 376dbb82
......@@ -65,7 +65,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"generate_proposals",
"graph_sample_neighbors",
"group_norm",
"histogram",
"instance_norm",
"lamb",
"layer_norm",
......
......@@ -1626,6 +1626,7 @@ void HistogramInferMeta(
out->set_dims({bins});
out->share_lod(input);
out->set_dtype(DataType::INT64);
}
void IdentityLossInferMeta(const MetaTensor& x,
......
......@@ -85,4 +85,6 @@ PD_REGISTER_KERNEL(histogram,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(paddle::DataType::INT64);
}
......@@ -154,4 +154,6 @@ PD_REGISTER_KERNEL(histogram,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(paddle::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册