diff --git a/src/opr/impl/dnn/roi_align.cpp b/src/opr/impl/dnn/roi_align.cpp index 3c29417ff72b1f29c55434f498f3bf9b4c6dc77e..e89ec8dcdd65f1c50d71b3bc54072f86ecc6cd23 100644 --- a/src/opr/impl/dnn/roi_align.cpp +++ b/src/opr/impl/dnn/roi_align.cpp @@ -42,9 +42,6 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois, #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIAlignForward) { - if (out_grad[1]) { - return InvalidGrad::make(opr, wrt_idx); - } if (wrt_idx == 0) { // wrt src SymbolVar grad = diff --git a/src/opr/impl/dnn/roi_pooling.cpp b/src/opr/impl/dnn/roi_pooling.cpp index ab2801d55d054e108df4c789145673fbcae5707b..7c2d3df9e25e46717de9608d83780c8817821ad7 100644 --- a/src/opr/impl/dnn/roi_pooling.cpp +++ b/src/opr/impl/dnn/roi_pooling.cpp @@ -86,7 +86,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes( #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(ROIPoolingForward) { - if (out_grad[1] || wrt_idx == 2) { + if (wrt_idx == 2) { return InvalidGrad::make(opr, wrt_idx); } if (wrt_idx == 0) {