From c8bb66314173e68aec897f8e4a3f988ad227adc0 Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 27 Nov 2017 14:21:34 +0800 Subject: [PATCH] Refine roi_pool_op to avoid warning --- paddle/operators/roi_pool_op.h | 49 +++++++++++++++------------------- 1 file changed, 21 insertions(+), 28 deletions(-) mode change 100755 => 100644 paddle/operators/roi_pool_op.h diff --git a/paddle/operators/roi_pool_op.h b/paddle/operators/roi_pool_op.h old mode 100755 new mode 100644 index bd7736d6312..3812c66c654 --- a/paddle/operators/roi_pool_op.h +++ b/paddle/operators/roi_pool_op.h @@ -133,54 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { auto* in = ctx.Input("X"); auto* rois = ctx.Input("ROIs"); auto* argmax = ctx.Input("Argmax"); - auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* x_grad = - ctx.Output(framework::GradVarName("X")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); - if (x_grad) { - int channels = in->dims()[1]; - auto in_stride = framework::stride(in->dims()); - auto roi_stride = framework::stride(rois->dims()); - + if (in_grad) { const int64_t* rois_data = rois->data(); - int rois_num = rois->dims()[0]; - - T* x_grad_data = x_grad->mutable_data(ctx.GetPlace()); + const T* out_grad_data = out_grad->data(); + const int64_t* argmax_data = argmax->data(); + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; - set_zero(ctx.device_context(), x_grad, static_cast(0)); + set_zero(ctx.device_context(), in_grad, static_cast(0)); - size_t roi_offset = roi_stride[0]; - size_t batch_offset = in_stride[0]; - size_t channel_offset = in_stride[1]; + auto in_stride = framework::stride(in->dims()); + auto argmax_stride = framework::stride(argmax->dims()); + auto roi_stride = framework::stride(rois->dims()); + auto out_stride = framework::stride(out_grad->dims()); - const T* out_grad_data = out_grad->data(); - size_t pool_channel_offset = pooled_height * pooled_width; - const int64_t* argmax_data = argmax->data(); + int rois_num = rois->dims()[0]; + int channels = in->dims()[1]; - for (size_t n = 0; n < rois_num; ++n) { - size_t roi_batch_idx = rois_data[0]; - T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx; + for (int n = 0; n < rois_num; ++n) { + int roi_batch_idx = rois_data[0]; + T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0]; for (int c = 0; c < channels; ++c) { for (int ph = 0; ph < pooled_height; ++ph) { for (int pw = 0; pw < pooled_width; ++pw) { - size_t pool_index = ph * pooled_width + pw; - + int pool_index = ph * pooled_width + pw; if (argmax_data[pool_index] >= 0) { - size_t index = static_cast(argmax_data[pool_index]); + auto index = argmax_data[pool_index]; batch_grad_data[index] += out_grad_data[pool_index]; } } } - batch_grad_data += channel_offset; - out_grad_data += pool_channel_offset; - argmax_data += pool_channel_offset; + batch_grad_data += in_stride[1]; + out_grad_data += out_stride[1]; + argmax_data += argmax_stride[1]; } - rois_data += roi_offset; + rois_data += roi_stride[0]; } } } -- GitLab