From 0ad377c7cfa22b1c41c65cbf15801cdb278e7cd7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Sep 2021 11:05:21 +0800 Subject: [PATCH] fix(mgb/gopt): add error message when input dtype is not equal to param dtype in BN2Elemwise pass GitOrigin-RevId: 3d09a2a12eaca7cc39c51a5004230a466272e936 --- src/gopt/impl/inference.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index a9ba68a4..e537a502 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1686,6 +1686,13 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { SymbolVar invsqrt_variance = opr::PowC::make(variance + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); auto res = scale * (x - mean) * invsqrt_variance + bias; + if (x.dtype() != res.dtype()) { + mgb_throw(MegBrainError, + "BN's input dtype %s is not compatible with " + "param dtype %s when fusing BN. You may need to " + "dump FP32 model.", + x.dtype().name(), res.dtype().name()); + } rewriter.replace_var( opr->output(4), res.node(), mgb_cstr_log( -- GitLab