diff --git a/paddle/fluid/operators/histogram_op.cu b/paddle/fluid/operators/histogram_op.cu index 5f86f8d72c079dd554482685403a74d14934336e..6a9183a8b465b7526f956b84b23b3d2be6c0f141 100644 --- a/paddle/fluid/operators/histogram_op.cu +++ b/paddle/fluid/operators/histogram_op.cu @@ -81,6 +81,13 @@ class HistogramCUDAKernel : public framework::OpKernel { const T* input_data = input->data(); const int input_numel = input->numel(); + int64_t* out_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + if (input_data == nullptr) return; + T output_min = static_cast(minval); T output_max = static_cast(maxval); @@ -126,11 +133,6 @@ class HistogramCUDAKernel : public framework::OpKernel { "But received max is %d, min is %d", maxval, minval)); - int64_t* out_data = output->mutable_data(context.GetPlace()); - math::SetConstant()( - context.template device_context(), output, - static_cast(0)); - auto stream = context.template device_context().stream(); KernelHistogram< diff --git a/paddle/fluid/operators/histogram_op.h b/paddle/fluid/operators/histogram_op.h index 6e48c86d022bda78c5f24a53679b6437c38f0e92..a6f4448cbcb17e7b596514a967da9c7c748c69a6 100644 --- a/paddle/fluid/operators/histogram_op.h +++ b/paddle/fluid/operators/histogram_op.h @@ -38,6 +38,13 @@ class HistogramKernel : public framework::OpKernel { const T* input_data = input->data(); auto input_numel = input->numel(); + int64_t* out_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + if (input_data == nullptr) return; + T output_min = static_cast(minval); T output_max = static_cast(maxval); if (output_min == output_max) { @@ -63,11 +70,6 @@ class HistogramKernel : public framework::OpKernel { "But received max is %d, min is %d", maxval, minval)); - int64_t* out_data = output->mutable_data(context.GetPlace()); - math::SetConstant()( - context.template device_context(), output, - static_cast(0)); - for (int64_t i = 0; i < input_numel; i++) { if (input_data[i] >= output_min && input_data[i] <= output_max) { const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /