diff --git a/dnn/src/naive/deformable_ps_roi_pooling/opr_impl.cpp b/dnn/src/naive/deformable_ps_roi_pooling/opr_impl.cpp index 643e066689140f9ff3233cddf230ba9f42b4dec3..e8c04fa24eda7fe9c25d8276d47f9fbdf37fe655 100644 --- a/dnn/src/naive/deformable_ps_roi_pooling/opr_impl.cpp +++ b/dnn/src/naive/deformable_ps_roi_pooling/opr_impl.cpp @@ -293,7 +293,7 @@ void Fwd::exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, float trans_std = param.trans_std, scale = param.spatial_scale; size_t nr_bbox = rois.layout[0]; - size_t nr_cls = no_trans ? 1 : trans.layout[0]; + size_t nr_cls = no_trans ? 1 : trans.layout[1] / 2; size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; const float* data_ptr = data.ptr(); @@ -339,7 +339,7 @@ void Bwd::exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, float trans_std = param.trans_std, scale = param.spatial_scale; size_t nr_bbox = rois.layout[0]; - size_t nr_cls = no_trans ? 1 : trans.layout[0]; + size_t nr_cls = no_trans ? 1 : trans.layout[1] / 2; size_t IC = data.layout[1], IH = data.layout[2], IW = data.layout[3]; const float* data_ptr = data.ptr();