提交 cff5e2c1 编写于 作者: F FDInSky 提交者: wangguanzhong

fix roi_align_op cpu backward's bug (#18789)

* test=develop fix cpu roi_align_op backward bug
上级 9dbb62ee
...@@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -256,13 +256,15 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
auto spatial_scale = ctx.Attr<float>("spatial_scale"); auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio"); auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto in_dims = in->dims(); auto in_dims = in->dims();
if (!in_grad) {
return;
}
int channels = in_dims[1]; int channels = in_dims[1];
int height = in_dims[2]; int height = in_dims[2];
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (!in_grad) {
return;
}
Tensor roi_batch_id_list; Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num}); roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data = int* roi_batch_id_data =
...@@ -275,15 +277,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -275,15 +277,21 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_data[i] = n; roi_batch_id_data[i] = n;
} }
} }
in_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, in_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
if ((!out_grad->IsInitialized()) || (output_grad_size <= 0)) {
return;
}
const T* rois_data = rois->data<T>(); const T* rois_data = rois->data<T>();
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace()); T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, in_grad, static_cast<T>(0));
auto in_stride = framework::stride(in->dims()); auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims()); auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims()); auto out_stride = framework::stride(out_grad->dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册