From 389f8c5e011c8e748f829c45b2f4495ca86a3fcb Mon Sep 17 00:00:00 2001 From: Qi Li Date: Tue, 6 Jul 2021 10:22:50 +0800 Subject: [PATCH] [OP] fix histogram op when input tensor is empty, test=develop (#33970) --- paddle/fluid/operators/histogram_op.cu | 12 +++++++----- paddle/fluid/operators/histogram_op.h | 12 +++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/histogram_op.cu b/paddle/fluid/operators/histogram_op.cu index 5f86f8d72c..6a9183a8b4 100644 --- a/paddle/fluid/operators/histogram_op.cu +++ b/paddle/fluid/operators/histogram_op.cu @@ -81,6 +81,13 @@ class HistogramCUDAKernel : public framework::OpKernel { const T* input_data = input->data(); const int input_numel = input->numel(); + int64_t* out_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + if (input_data == nullptr) return; + T output_min = static_cast(minval); T output_max = static_cast(maxval); @@ -126,11 +133,6 @@ class HistogramCUDAKernel : public framework::OpKernel { "But received max is %d, min is %d", maxval, minval)); - int64_t* out_data = output->mutable_data(context.GetPlace()); - math::SetConstant()( - context.template device_context(), output, - static_cast(0)); - auto stream = context.template device_context().stream(); KernelHistogram< diff --git a/paddle/fluid/operators/histogram_op.h b/paddle/fluid/operators/histogram_op.h index 6e48c86d02..a6f4448cbc 100644 --- a/paddle/fluid/operators/histogram_op.h +++ b/paddle/fluid/operators/histogram_op.h @@ -38,6 +38,13 @@ class HistogramKernel : public framework::OpKernel { const T* input_data = input->data(); auto input_numel = input->numel(); + int64_t* out_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + if (input_data == nullptr) return; + T output_min = static_cast(minval); T output_max = static_cast(maxval); if (output_min == output_max) { @@ -63,11 +70,6 @@ class HistogramKernel : public framework::OpKernel { "But received max is %d, min is %d", maxval, minval)); - int64_t* out_data = output->mutable_data(context.GetPlace()); - math::SetConstant()( - context.template device_context(), output, - static_cast(0)); - for (int64_t i = 0; i < input_numel; i++) { if (input_data[i] >= output_min && input_data[i] <= output_max) { const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins / -- GitLab