From fade97d4ef9c1c709b97643b2e965dc116898bf9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 16 Sep 2020 10:50:09 +0800 Subject: [PATCH] fix(mgb/gopt): fix convert batchnorm to elemwise pass issue GitOrigin-RevId: eda7f1ab95f9f15448cdcafa2197fa1edd3e7946 --- src/gopt/impl/inference.cpp | 3 ++- src/gopt/test/inference.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index a7739d588..9e40def94 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1592,7 +1592,8 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { SymbolVar bias = {rewriter.get_var(bn->input(2))}; SymbolVar mean = {rewriter.get_var(bn->input(3))}; SymbolVar variance = {rewriter.get_var(bn->input(4))}; - SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5}); + SymbolVar invsqrt_variance = opr::PowC::make(variance + + variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); auto res = scale * (x - mean) * invsqrt_variance + bias; rewriter.replace_var( opr->output(4), res.node(), diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index c6a7d7950..44989430f 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1404,7 +1404,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) { auto func = graph->compile({make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); func->execute(); - MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); } TEST(TestGoptInference, ConvBiasNonlinearityFusePass) { -- GitLab