提交 fade97d4 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): fix convert batchnorm to elemwise pass issue

GitOrigin-RevId: eda7f1ab95f9f15448cdcafa2197fa1edd3e7946
上级 81b6a733
...@@ -1592,7 +1592,8 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { ...@@ -1592,7 +1592,8 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
SymbolVar bias = {rewriter.get_var(bn->input(2))}; SymbolVar bias = {rewriter.get_var(bn->input(2))};
SymbolVar mean = {rewriter.get_var(bn->input(3))}; SymbolVar mean = {rewriter.get_var(bn->input(3))};
SymbolVar variance = {rewriter.get_var(bn->input(4))}; SymbolVar variance = {rewriter.get_var(bn->input(4))};
SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5}); SymbolVar invsqrt_variance = opr::PowC::make(variance
+ variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5});
auto res = scale * (x - mean) * invsqrt_variance + bias; auto res = scale * (x - mean) * invsqrt_variance + bias;
rewriter.replace_var( rewriter.replace_var(
opr->output(4), res.node(), opr->output(4), res.node(),
......
...@@ -1404,7 +1404,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) { ...@@ -1404,7 +1404,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) {
auto func = graph->compile({make_callback_copy(y, host_y), auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)}); make_callback_copy(y_opt, host_y_opt)});
func->execute(); func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
} }
TEST(TestGoptInference, ConvBiasNonlinearityFusePass) { TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册