未验证 提交 389f8c5e 编写于 作者: Q Qi Li 提交者: GitHub

[OP] fix histogram op when input tensor is empty, test=develop (#33970)

上级 7a476608
......@@ -81,6 +81,13 @@ class HistogramCUDAKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
const int input_numel = input->numel();
int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, int64_t>()(
context.template device_context<platform::CUDADeviceContext>(), output,
static_cast<int64_t>(0));
if (input_data == nullptr) return;
T output_min = static_cast<T>(minval);
T output_max = static_cast<T>(maxval);
......@@ -126,11 +133,6 @@ class HistogramCUDAKernel : public framework::OpKernel<T> {
"But received max is %d, min is %d",
maxval, minval));
int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, int64_t>()(
context.template device_context<platform::CUDADeviceContext>(), output,
static_cast<int64_t>(0));
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
KernelHistogram<
......
......@@ -38,6 +38,13 @@ class HistogramKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
auto input_numel = input->numel();
int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output,
static_cast<int64_t>(0));
if (input_data == nullptr) return;
T output_min = static_cast<T>(minval);
T output_max = static_cast<T>(maxval);
if (output_min == output_max) {
......@@ -63,11 +70,6 @@ class HistogramKernel : public framework::OpKernel<T> {
"But received max is %d, min is %d",
maxval, minval));
int64_t* out_data = output->mutable_data<int64_t>(context.GetPlace());
math::SetConstant<DeviceContext, int64_t>()(
context.template device_context<DeviceContext>(), output,
static_cast<int64_t>(0));
for (int64_t i = 0; i < input_numel; i++) {
if (input_data[i] >= output_min && input_data[i] <= output_max) {
const int64_t bin = (int64_t)((input_data[i] - output_min) * nbins /
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册