diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp
index f85a994a3aa593016a82fa0ab6a84a7e6a4f2428..a8daf566670f66709fd46d02e26f682ad03c8ab6 100644
--- a/src/opr/impl/dnn/batch_norm.cpp
+++ b/src/opr/impl/dnn/batch_norm.cpp
@@ -232,16 +232,18 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
 }
 
 #ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
-    mgb_assert(wrt_idx < 5);
-    if (wrt_idx < 3) {
-        SymbolVarArray grad = BatchNormBackward::make(
+    mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
+               "batch norm could only take grad in training mode");
+    mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
+    VarNodeArray ret(opr.input().size(), nullptr);
+    SymbolVarArray grad = BatchNormBackward::make(
             opr.input(0), out_grad[4], opr.output(2), opr.output(3),
             opr.input(1), opr.param());
-        return grad[(wrt_idx + 2) % 3].node();
-    } else {
-        return nullptr;
+    for (size_t i = 0; i < 3; ++ i) {
+        ret[i] = grad[(i + 2) % 3].node();
     }
+    return ret;
 }
 #endif
 
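Note on the gradient mapping above: the patched MGB_IMPL_OPR_GRAD body now builds BatchNormBackward once and returns a VarNodeArray covering every input, instead of re-deriving a single gradient per wrt_idx. Only the first three inputs receive gradients; inputs 3 and 4 (the running statistics) keep the nullptr the array was initialized with. The sketch below is a minimal, self-contained illustration of the (i + 2) % 3 rotation, assuming the forward inputs start with {x, scale, bias} and that BatchNormBackward yields {d_scale, d_bias, d_x} in that order; those orderings are assumptions for illustration, not stated in this diff, and the snippet is standalone demo code, not MegBrain code.

    // Hypothetical standalone sketch (not MegBrain code): check which
    // backward output the (i + 2) % 3 rotation assigns to each input.
    #include <array>
    #include <cstdio>

    int main() {
        // Assumed orderings; only the index rotation comes from the diff.
        const std::array<const char*, 3> fwd_inputs{"x", "scale", "bias"};
        const std::array<const char*, 3> bwd_outputs{"d_scale", "d_bias", "d_x"};
        for (std::size_t i = 0; i < 3; ++i) {
            // Same index arithmetic as ret[i] = grad[(i + 2) % 3].node();
            std::printf("grad(%s) <- %s\n", fwd_inputs[i],
                        bwd_outputs[(i + 2) % 3]);
        }
        // Inputs 3 and 4 (running mean/variance) would keep a null
        // gradient, mirroring the nullptr-initialized VarNodeArray.
        return 0;
    }

Under those assumptions the sketch prints grad(x) <- d_x, grad(scale) <- d_scale, grad(bias) <- d_bias, which is why the rotation, rather than an identity mapping, lines the two arrays up.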