From 2876f6f8a9f48a65c3ec4a358707e8de1d781bca Mon Sep 17 00:00:00 2001 From: Infinity_lee Date: Tue, 14 Mar 2023 14:16:03 +0800 Subject: [PATCH] add output defs for histogram kernel (#51317) --- .../framework/new_executor/interpreter/interpreter_util.cc | 1 - paddle/phi/infermeta/unary.cc | 1 + paddle/phi/kernels/cpu/histogram_kernel.cc | 4 +++- paddle/phi/kernels/gpu/histogram_kernel.cu | 4 +++- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index b837c24c4fe..b2ef8c2f7dc 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -65,7 +65,6 @@ static std::set OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { "generate_proposals", "graph_sample_neighbors", "group_norm", - "histogram", "instance_norm", "lamb", "layer_norm", diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 68d2231b2fa..b2edfd5a2ae 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1626,6 +1626,7 @@ void HistogramInferMeta( out->set_dims({bins}); out->share_lod(input); + out->set_dtype(DataType::INT64); } void IdentityLossInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/cpu/histogram_kernel.cc b/paddle/phi/kernels/cpu/histogram_kernel.cc index 4c04566b8b0..030dee9908b 100644 --- a/paddle/phi/kernels/cpu/histogram_kernel.cc +++ b/paddle/phi/kernels/cpu/histogram_kernel.cc @@ -85,4 +85,6 @@ PD_REGISTER_KERNEL(histogram, float, double, int, - int64_t) {} + int64_t) { + kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/histogram_kernel.cu b/paddle/phi/kernels/gpu/histogram_kernel.cu index cdcd70363dc..111b13f11dd 100644 --- a/paddle/phi/kernels/gpu/histogram_kernel.cu +++ b/paddle/phi/kernels/gpu/histogram_kernel.cu @@ -154,4 +154,6 @@ PD_REGISTER_KERNEL(histogram, float, double, int, - int64_t) {} + int64_t) { + kernel->OutputAt(0).SetDataType(paddle::DataType::INT64); +} -- GitLab