未验证 提交 2876f6f8 编写于 作者: I Infinity_lee 提交者: GitHub

add output defs for histogram kernel (#51317)

上级 376dbb82
...@@ -65,7 +65,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -65,7 +65,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"generate_proposals", "generate_proposals",
"graph_sample_neighbors", "graph_sample_neighbors",
"group_norm", "group_norm",
"histogram",
"instance_norm", "instance_norm",
"lamb", "lamb",
"layer_norm", "layer_norm",
......
...@@ -1626,6 +1626,7 @@ void HistogramInferMeta( ...@@ -1626,6 +1626,7 @@ void HistogramInferMeta(
out->set_dims({bins}); out->set_dims({bins});
out->share_lod(input); out->share_lod(input);
out->set_dtype(DataType::INT64);
} }
void IdentityLossInferMeta(const MetaTensor& x, void IdentityLossInferMeta(const MetaTensor& x,
......
...@@ -85,4 +85,6 @@ PD_REGISTER_KERNEL(histogram, ...@@ -85,4 +85,6 @@ PD_REGISTER_KERNEL(histogram,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(0).SetDataType(paddle::DataType::INT64);
}
...@@ -154,4 +154,6 @@ PD_REGISTER_KERNEL(histogram, ...@@ -154,4 +154,6 @@ PD_REGISTER_KERNEL(histogram,
float, float,
double, double,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(0).SetDataType(paddle::DataType::INT64);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册