diff --git a/paddle/operators/roi_pool_op.h b/paddle/operators/roi_pool_op.h old mode 100755 new mode 100644 index bd7736d63125f1be57c8af5141208f66d0592adb..3812c66c65457b9d1337690d1a82759aab9a9732 --- a/paddle/operators/roi_pool_op.h +++ b/paddle/operators/roi_pool_op.h @@ -133,54 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel { auto* in = ctx.Input("X"); auto* rois = ctx.Input("ROIs"); auto* argmax = ctx.Input("Argmax"); - auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* x_grad = - ctx.Output(framework::GradVarName("X")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); auto pooled_height = ctx.Attr("pooled_height"); auto pooled_width = ctx.Attr("pooled_width"); - if (x_grad) { - int channels = in->dims()[1]; - auto in_stride = framework::stride(in->dims()); - auto roi_stride = framework::stride(rois->dims()); - + if (in_grad) { const int64_t* rois_data = rois->data(); - int rois_num = rois->dims()[0]; - - T* x_grad_data = x_grad->mutable_data(ctx.GetPlace()); + const T* out_grad_data = out_grad->data(); + const int64_t* argmax_data = argmax->data(); + T* in_grad_data = in_grad->mutable_data(ctx.GetPlace()); math::SetConstant set_zero; - set_zero(ctx.device_context(), x_grad, static_cast(0)); + set_zero(ctx.device_context(), in_grad, static_cast(0)); - size_t roi_offset = roi_stride[0]; - size_t batch_offset = in_stride[0]; - size_t channel_offset = in_stride[1]; + auto in_stride = framework::stride(in->dims()); + auto argmax_stride = framework::stride(argmax->dims()); + auto roi_stride = framework::stride(rois->dims()); + auto out_stride = framework::stride(out_grad->dims()); - const T* out_grad_data = out_grad->data(); - size_t pool_channel_offset = pooled_height * pooled_width; - const int64_t* argmax_data = argmax->data(); + int rois_num = rois->dims()[0]; + int channels = in->dims()[1]; - for (size_t n = 0; n < rois_num; ++n) { - size_t roi_batch_idx = rois_data[0]; - T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx; + for (int n = 0; n < rois_num; ++n) { + int roi_batch_idx = rois_data[0]; + T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0]; for (int c = 0; c < channels; ++c) { for (int ph = 0; ph < pooled_height; ++ph) { for (int pw = 0; pw < pooled_width; ++pw) { - size_t pool_index = ph * pooled_width + pw; - + int pool_index = ph * pooled_width + pw; if (argmax_data[pool_index] >= 0) { - size_t index = static_cast(argmax_data[pool_index]); + auto index = argmax_data[pool_index]; batch_grad_data[index] += out_grad_data[pool_index]; } } } - batch_grad_data += channel_offset; - out_grad_data += pool_channel_offset; - argmax_data += pool_channel_offset; + batch_grad_data += in_stride[1]; + out_grad_data += out_stride[1]; + argmax_data += argmax_stride[1]; } - rois_data += roi_offset; + rois_data += roi_stride[0]; } } }