未验证 提交 3429e65a 编写于 作者: Q qingqing01 提交者: GitHub

Fix CPU implementation of roi_align_op backward (#18728)

上级 70b03760
...@@ -154,6 +154,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -154,6 +154,8 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
if (rois_num == 0) return;
auto in_stride = framework::stride(in_dims); auto in_stride = framework::stride(in_dims);
auto roi_stride = framework::stride(rois->dims()); auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out->dims()); auto out_stride = framework::stride(out->dims());
...@@ -278,6 +280,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> { ...@@ -278,6 +280,10 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace()); T* in_grad_data = in_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, in_grad, static_cast<T>(0));
auto in_stride = framework::stride(in->dims()); auto in_stride = framework::stride(in->dims());
auto roi_stride = framework::stride(rois->dims()); auto roi_stride = framework::stride(rois->dims());
auto out_stride = framework::stride(out_grad->dims()); auto out_stride = framework::stride(out_grad->dims());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册