From cbe2f6095a3a15318aa54362beae3535a7b049a2 Mon Sep 17 00:00:00 2001 From: Wesley Wiser Date: Sun, 20 Oct 2019 23:48:31 -0400 Subject: [PATCH] Implement pass to remove branches on uninhabited variants --- src/librustc_mir/lib.rs | 1 + src/librustc_mir/transform/mod.rs | 3 + .../transform/uninhabited_enum_branching.rs | 126 ++++++++++ .../mir-opt/uninhabited_enum_branching.rs | 224 ++++++++++++++++++ 4 files changed, 354 insertions(+) create mode 100644 src/librustc_mir/transform/uninhabited_enum_branching.rs create mode 100644 src/test/mir-opt/uninhabited_enum_branching.rs diff --git a/src/librustc_mir/lib.rs b/src/librustc_mir/lib.rs index 4d604cb025c..be3bbf46f1c 100644 --- a/src/librustc_mir/lib.rs +++ b/src/librustc_mir/lib.rs @@ -26,6 +26,7 @@ #![feature(associated_type_bounds)] #![feature(range_is_empty)] #![feature(stmt_expr_attributes)] +#![feature(bool_to_option)] #![recursion_limit="256"] diff --git a/src/librustc_mir/transform/mod.rs b/src/librustc_mir/transform/mod.rs index dbe6c784592..e51dd719ae2 100644 --- a/src/librustc_mir/transform/mod.rs +++ b/src/librustc_mir/transform/mod.rs @@ -36,6 +36,7 @@ pub mod generator; pub mod inline; pub mod uniform_array_move_out; +pub mod uninhabited_enum_branching; pub(crate) fn provide(providers: &mut Providers<'_>) { self::qualify_consts::provide(providers); @@ -257,6 +258,8 @@ fn run_optimization_passes<'tcx>( // Optimizations begin. + &uninhabited_enum_branching::UninhabitedEnumBranching, + &simplify::SimplifyCfg::new("after-uninhabited-enum-branching"), &uniform_array_move_out::RestoreSubsliceArrayMoveOut::new(tcx), &inline::Inline, diff --git a/src/librustc_mir/transform/uninhabited_enum_branching.rs b/src/librustc_mir/transform/uninhabited_enum_branching.rs new file mode 100644 index 00000000000..a6c18aee6a8 --- /dev/null +++ b/src/librustc_mir/transform/uninhabited_enum_branching.rs @@ -0,0 +1,126 @@ +//! A pass that eliminates branches on uninhabited enum variants. + +use crate::transform::{MirPass, MirSource}; +use rustc::mir::{ + BasicBlock, BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind, +}; +use rustc::ty::layout::{Abi, TyLayout, Variants}; +use rustc::ty::{Ty, TyCtxt}; + +pub struct UninhabitedEnumBranching; + +fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option { + if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator { + p.as_local() + } else { + None + } +} + +/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the +/// discriminant is read from. Otherwise, returns None. +fn get_switched_on_type<'tcx>( + block_data: &BasicBlockData<'tcx>, + body: &Body<'tcx>, +) -> Option> { + let terminator = block_data.terminator(); + + // Only bother checking blocks which terminate by switching on a local. + if let Some(local) = get_discriminant_local(&terminator.kind) { + let stmt_before_term = (block_data.statements.len() > 0) + .then_with(|| &block_data.statements[block_data.statements.len() - 1].kind); + + if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term + { + if l.as_local() == Some(local) { + if let Some(r_local) = place.as_local() { + let ty = body.local_decls[r_local].ty; + + if ty.is_enum() { + return Some(ty); + } + } + } + } + } + + None +} + +fn variant_discriminants<'tcx>( + layout: &TyLayout<'tcx>, + ty: Ty<'tcx>, + tcx: TyCtxt<'tcx>, +) -> Vec { + match &layout.details.variants { + Variants::Single { index } => vec![index.as_u32() as u128], + Variants::Multiple { variants, .. } => variants + .iter_enumerated() + .filter_map(|(idx, layout)| { + (layout.abi != Abi::Uninhabited) + .then_with(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) + }) + .collect(), + } +} + +impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { + fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) { + if source.promoted.is_some() { + return; + } + + trace!("UninhabitedEnumBranching starting for {:?}", source); + + let basic_block_count = body.basic_blocks().len(); + + for bb in 0..basic_block_count { + let bb = BasicBlock::from_usize(bb); + trace!("processing block {:?}", bb); + + let discriminant_ty = + if let Some(ty) = get_switched_on_type(&body.basic_blocks()[bb], body) { + ty + } else { + continue; + }; + + let layout = tcx.layout_of(tcx.param_env(source.def_id()).and(discriminant_ty)); + + let allowed_variants = if let Ok(layout) = layout { + variant_discriminants(&layout, discriminant_ty, tcx) + } else { + continue; + }; + + trace!("allowed_variants = {:?}", allowed_variants); + + if let TerminatorKind::SwitchInt { values, targets, .. } = + &mut body.basic_blocks_mut()[bb].terminator_mut().kind + { + let vals = &*values; + let zipped = vals.iter().zip(targets.into_iter()); + + let mut matched_values = Vec::with_capacity(allowed_variants.len()); + let mut matched_targets = Vec::with_capacity(allowed_variants.len() + 1); + + for (val, target) in zipped { + if allowed_variants.contains(val) { + matched_values.push(*val); + matched_targets.push(*target); + } else { + trace!("eliminating {:?} -> {:?}", val, target); + } + } + + // handle the "otherwise" branch + matched_targets.push(targets.pop().unwrap()); + + *values = matched_values.into(); + *targets = matched_targets; + } else { + unreachable!() + } + } + } +} diff --git a/src/test/mir-opt/uninhabited_enum_branching.rs b/src/test/mir-opt/uninhabited_enum_branching.rs new file mode 100644 index 00000000000..1f37ff1498d --- /dev/null +++ b/src/test/mir-opt/uninhabited_enum_branching.rs @@ -0,0 +1,224 @@ +enum Empty { } + +// test matching an enum with uninhabited variants +enum Test1 { + A(Empty), + B(Empty), + C +} + +// test an enum where the discriminants don't match the variant indexes +// (the optimization should do nothing here) +enum Test2 { + D = 4, + E = 5, +} + +fn main() { + match Test1::C { + Test1::A(_) => "A(Empty)", + Test1::B(_) => "B(Empty)", + Test1::C => "C", + }; + + match Test2::D { + Test2::D => "D", + Test2::E => "E", + }; +} + +// END RUST SOURCE +// +// START rustc.main.UninhabitedEnumBranching.before.mir +// let mut _0: (); +// let _1: &str; +// let mut _2: Test1; +// let mut _3: isize; +// let mut _4: &str; +// let mut _5: &str; +// let _6: &str; +// let mut _7: Test2; +// let mut _8: isize; +// let mut _9: &str; +// bb0: { +// StorageLive(_1); +// StorageLive(_2); +// _2 = Test1::C; +// _3 = discriminant(_2); +// switchInt(move _3) -> [0isize: bb3, 1isize: bb4, 2isize: bb1, otherwise: bb2]; +// } +// bb1: { +// StorageLive(_5); +// _5 = const "C"; +// _1 = &(*_5); +// StorageDead(_5); +// goto -> bb5; +// } +// bb2: { +// unreachable; +// } +// bb3: { +// _1 = const "A(Empty)"; +// goto -> bb5; +// } +// bb4: { +// StorageLive(_4); +// _4 = const "B(Empty)"; +// _1 = &(*_4); +// StorageDead(_4); +// goto -> bb5; +// } +// bb5: { +// StorageDead(_2); +// StorageDead(_1); +// StorageLive(_6); +// StorageLive(_7); +// _7 = Test2::D; +// _8 = discriminant(_7); +// switchInt(move _8) -> [4isize: bb8, 5isize: bb6, otherwise: bb7]; +// } +// bb6: { +// StorageLive(_9); +// _9 = const "E"; +// _6 = &(*_9); +// StorageDead(_9); +// goto -> bb9; +// } +// bb7: { +// unreachable; +// } +// bb8: { +// _6 = const "D"; +// goto -> bb9; +// } +// bb9: { +// StorageDead(_7); +// StorageDead(_6); +// _0 = (); +// return; +// } +// END rustc.main.UninhabitedEnumBranching.before.mir +// START rustc.main.UninhabitedEnumBranching.after.mir +// let mut _0: (); +// let _1: &str; +// let mut _2: Test1; +// let mut _3: isize; +// let mut _4: &str; +// let mut _5: &str; +// let _6: &str; +// let mut _7: Test2; +// let mut _8: isize; +// let mut _9: &str; +// bb0: { +// StorageLive(_1); +// StorageLive(_2); +// _2 = Test1::C; +// _3 = discriminant(_2); +// switchInt(move _3) -> [2isize: bb1, otherwise: bb2]; +// } +// bb1: { +// StorageLive(_5); +// _5 = const "C"; +// _1 = &(*_5); +// StorageDead(_5); +// goto -> bb5; +// } +// bb2: { +// unreachable; +// } +// bb3: { +// _1 = const "A(Empty)"; +// goto -> bb5; +// } +// bb4: { +// StorageLive(_4); +// _4 = const "B(Empty)"; +// _1 = &(*_4); +// StorageDead(_4); +// goto -> bb5; +// } +// bb5: { +// StorageDead(_2); +// StorageDead(_1); +// StorageLive(_6); +// StorageLive(_7); +// _7 = Test2::D; +// _8 = discriminant(_7); +// switchInt(move _8) -> [4isize: bb8, 5isize: bb6, otherwise: bb7]; +// } +// bb6: { +// StorageLive(_9); +// _9 = const "E"; +// _6 = &(*_9); +// StorageDead(_9); +// goto -> bb9; +// } +// bb7: { +// unreachable; +// } +// bb8: { +// _6 = const "D"; +// goto -> bb9; +// } +// bb9: { +// StorageDead(_7); +// StorageDead(_6); +// _0 = (); +// return; +// } +// END rustc.main.UninhabitedEnumBranching.after.mir +// START rustc.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir +// let mut _0: (); +// let _1: &str; +// let mut _2: Test1; +// let mut _3: isize; +// let mut _4: &str; +// let mut _5: &str; +// let _6: &str; +// let mut _7: Test2; +// let mut _8: isize; +// let mut _9: &str; +// bb0: { +// StorageLive(_1); +// StorageLive(_2); +// _2 = Test1::C; +// _3 = discriminant(_2); +// switchInt(move _3) -> [2isize: bb1, otherwise: bb2]; +// } +// bb1: { +// StorageLive(_5); +// _5 = const "C"; +// _1 = &(*_5); +// StorageDead(_5); +// StorageDead(_2); +// StorageDead(_1); +// StorageLive(_6); +// StorageLive(_7); +// _7 = Test2::D; +// _8 = discriminant(_7); +// switchInt(move _8) -> [4isize: bb5, 5isize: bb3, otherwise: bb4]; +// } +// bb2: { +// unreachable; +// } +// bb3: { +// StorageLive(_9); +// _9 = const "E"; +// _6 = &(*_9); +// StorageDead(_9); +// goto -> bb6; +// } +// bb4: { +// unreachable; +// } +// bb5: { +// _6 = const "D"; +// goto -> bb6; +// } +// bb6: { +// StorageDead(_7); +// StorageDead(_6); +// _0 = (); +// return; +// } +// END rustc.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir -- GitLab