diff --git a/src/gopt/impl/misc.cpp b/src/gopt/impl/misc.cpp index fd9efe63688421e0ac860ea213852dfbcb46ccd5..6f066e2dd5aaf8ea1d132324bc6cc245fefeab39 100644 --- a/src/gopt/impl/misc.cpp +++ b/src/gopt/impl/misc.cpp @@ -635,15 +635,19 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const { auto on_opr = [&](OperatorNodeBase* opr) { if (auto tc0 = try_cast_as_op(opr)) { - if (auto tc1 = try_cast_as_op(tc0->input(0))) { + auto inp0 = rewriter.get_var(tc0->input(0)); + if (auto tc1 = try_cast_as_op(inp0)) { if (should_remove(tc0->param(), tc1->param())) { + auto inp1 = tc1->input(0); + mgb_assert(!rewriter.has_manual_replace(inp1)); // TypeCvt returns the input var if its dtype is already // dest_type - auto fold = opr::TypeCvt::make(tc1->input(0), tc0->param()); + auto fold = opr::TypeCvt::make(inp1, tc0->param()); rewriter.replace_var( tc0->output(0), fold.node(), mgb_cstr_log("cvt_b(cvt_a(x)) -> cvt_b(x)")); } + return; } } rewriter.auto_replace_outputs(opr); diff --git a/src/gopt/test/misc.cpp b/src/gopt/test/misc.cpp index d42f6defdbb656c3c795df83c10471876b47268d..fcf7fb271ed5165fbf989235da0f71201c6284cf 100644 --- a/src/gopt/test/misc.cpp +++ b/src/gopt/test/misc.cpp @@ -395,6 +395,12 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) { check(x_fp16, x_fp16_fp32_fp16); #endif + auto x_i32 = opr::TypeCvt::make(x, dtype::Int32()); + auto x_i32_i16 = opr::TypeCvt::make(x_i32, dtype::Int16()); + auto x_i32_i16_i8 = opr::TypeCvt::make(x_i32_i16, dtype::Int8()); + auto x_i8 = opr::TypeCvt::make(x, dtype::Int8()); + check(x_i8, x_i32_i16_i8); + auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(0.1f)); auto x_q8_fp32 = opr::TypeCvt::make(x_q8, dtype::Float32()); auto x_q8_fp32_q8 = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(0.1f));