提交 ccc95ad2 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mgb/gopt): fix remove redundant typecvt pass

GitOrigin-RevId: 6a7957e362931302ac883e05031b6be0ca42b0d2
上级 9b413219
......@@ -635,15 +635,19 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
auto on_opr = [&](OperatorNodeBase* opr) {
if (auto tc0 = try_cast_as_op<opr::TypeCvt>(opr)) {
if (auto tc1 = try_cast_as_op<opr::TypeCvt>(tc0->input(0))) {
auto inp0 = rewriter.get_var(tc0->input(0));
if (auto tc1 = try_cast_as_op<opr::TypeCvt>(inp0)) {
if (should_remove(tc0->param(), tc1->param())) {
auto inp1 = tc1->input(0);
mgb_assert(!rewriter.has_manual_replace(inp1));
// TypeCvt returns the input var if its dtype is already
// dest_type
auto fold = opr::TypeCvt::make(tc1->input(0), tc0->param());
auto fold = opr::TypeCvt::make(inp1, tc0->param());
rewriter.replace_var(
tc0->output(0), fold.node(),
mgb_cstr_log("cvt_b(cvt_a(x)) -> cvt_b(x)"));
}
return;
}
}
rewriter.auto_replace_outputs(opr);
......
......@@ -395,6 +395,12 @@ TEST_PASS(RemoveRedundantTypeCvtPass, Basic) {
check(x_fp16, x_fp16_fp32_fp16);
#endif
auto x_i32 = opr::TypeCvt::make(x, dtype::Int32());
auto x_i32_i16 = opr::TypeCvt::make(x_i32, dtype::Int16());
auto x_i32_i16_i8 = opr::TypeCvt::make(x_i32_i16, dtype::Int8());
auto x_i8 = opr::TypeCvt::make(x, dtype::Int8());
check(x_i8, x_i32_i16_i8);
auto x_q8 = opr::TypeCvt::make(x, dtype::QuantizedS8(0.1f));
auto x_q8_fp32 = opr::TypeCvt::make(x_q8, dtype::Float32());
auto x_q8_fp32_q8 = opr::TypeCvt::make(x_q8_fp32, dtype::QuantizedS8(0.1f));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册