From 54d18115b6f59abc5639f978f47a7c27a49cd2cd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 4 Aug 2020 15:12:49 +0800 Subject: [PATCH] fix(imperative): fix grad of BatchNorm GitOrigin-RevId: 1e8d8afaf260bfdaa41ea170404ca6080e54405f --- src/opr/impl/dnn/batch_norm.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp index f85a994a..a8daf566 100644 --- a/src/opr/impl/dnn/batch_norm.cpp +++ b/src/opr/impl/dnn/batch_norm.cpp @@ -232,16 +232,18 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() { #ifdef MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(BatchNormForward) { - mgb_assert(wrt_idx < 5); - if (wrt_idx < 3) { - SymbolVarArray grad = BatchNormBackward::make( + mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING, + "batch norm could only take grad in training mode"); + mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx); + VarNodeArray ret(opr.input().size(), nullptr); + SymbolVarArray grad = BatchNormBackward::make( opr.input(0), out_grad[4], opr.output(2), opr.output(3), opr.input(1), opr.param()); - return grad[(wrt_idx + 2) % 3].node(); - } else { - return nullptr; + for (size_t i = 0; i < 3; ++ i) { + ret[i] = grad[(i + 2) % 3].node(); } + return ret; } #endif -- GitLab