diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index a9ba68a423cd0e80b138caa74e31ff7a0eedbf5b..e537a5029493ae7c1eb017e29ecbc13bfd2491a0 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1686,6 +1686,13 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { SymbolVar invsqrt_variance = opr::PowC::make(variance + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); auto res = scale * (x - mean) * invsqrt_variance + bias; + if (x.dtype() != res.dtype()) { + mgb_throw(MegBrainError, + "BN's input dtype %s is not compatible with " + "param dtype %s when fusing BN. You may need to " + "dump FP32 model.", + x.dtype().name(), res.dtype().name()); + } rewriter.replace_var( opr->output(4), res.node(), mgb_cstr_log(