提交 0ad377c7 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): add error message when input dtype is not equal to param dtype in BN2Elemwise pass

GitOrigin-RevId: 3d09a2a12eaca7cc39c51a5004230a466272e936
上级 93f70a95
...@@ -1686,6 +1686,13 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { ...@@ -1686,6 +1686,13 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
SymbolVar invsqrt_variance = opr::PowC::make(variance SymbolVar invsqrt_variance = opr::PowC::make(variance
+ variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5});
auto res = scale * (x - mean) * invsqrt_variance + bias; auto res = scale * (x - mean) * invsqrt_variance + bias;
if (x.dtype() != res.dtype()) {
mgb_throw(MegBrainError,
"BN's input dtype %s is not compatible with "
"param dtype %s when fusing BN. You may need to "
"dump FP32 model.",
x.dtype().name(), res.dtype().name());
}
rewriter.replace_var( rewriter.replace_var(
opr->output(4), res.node(), opr->output(4), res.node(),
mgb_cstr_log( mgb_cstr_log(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册