未验证 提交 0ce9bf77 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #5931 from guoshengCS/fix-ROIPoolOP-warn

Refine roi_pool_op to avoid warning
...@@ -133,53 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> { ...@@ -133,53 +133,47 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::Tensor>("ROIs"); auto* rois = ctx.Input<framework::Tensor>("ROIs");
auto* argmax = ctx.Input<framework::Tensor>("Argmax"); auto* argmax = ctx.Input<framework::Tensor>("Argmax");
auto* out_grad = auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out")); ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* in_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto pooled_height = ctx.Attr<int>("pooled_height"); auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width"); auto pooled_width = ctx.Attr<int>("pooled_width");
if (x_grad) { if (in_grad) {
int channels = in->dims()[1];
auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims());
const int64_t* rois_data = rois->data<int64_t>(); const int64_t* rois_data = rois->data<int64_t>();
int rois_num = rois->dims()[0]; const T* out_grad_data = out_grad->data<T>();
const int64_t* argmax_data = argmax->data<int64_t>();
T* x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace()); T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero; math::SetConstant<Place, T> set_zero;
set_zero(ctx.device_context(), x_grad, static_cast<T>(0)); set_zero(ctx.device_context(), in_grad, static_cast<T>(0));
size_t roi_offset = roi_stride[0]; auto in_stride = framework::stride(in->dims());
size_t batch_offset = in_stride[0]; auto argmax_stride = framework::stride(argmax->dims());
size_t channel_offset = in_stride[1]; auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims());
const T* out_grad_data = out_grad->data<T>(); int rois_num = rois->dims()[0];
size_t pool_channel_offset = pooled_height * pooled_width; int channels = in->dims()[1];
const int64_t* argmax_data = argmax->data<int64_t>();
for (size_t n = 0; n < rois_num; ++n) { for (int n = 0; n < rois_num; ++n) {
size_t roi_batch_idx = rois_data[0]; int roi_batch_idx = rois_data[0];
T* batch_grad_data = x_grad_data + batch_offset * roi_batch_idx; T* batch_grad_data = in_grad_data + roi_batch_idx * in_stride[0];
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
for (int ph = 0; ph < pooled_height; ++ph) { for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) { for (int pw = 0; pw < pooled_width; ++pw) {
size_t pool_index = ph * pooled_width + pw; int pool_index = ph * pooled_width + pw;
if (argmax_data[pool_index] >= 0) { if (argmax_data[pool_index] >= 0) {
size_t index = static_cast<size_t>(argmax_data[pool_index]); auto index = argmax_data[pool_index];
batch_grad_data[index] += out_grad_data[pool_index]; batch_grad_data[index] += out_grad_data[pool_index];
} }
} }
} }
batch_grad_data += channel_offset; batch_grad_data += in_stride[1];
out_grad_data += pool_channel_offset; out_grad_data += out_stride[1];
argmax_data += pool_channel_offset; argmax_data += argmax_stride[1];
} }
rois_data += roi_offset; rois_data += roi_stride[0];
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册