Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions compiler/rustc_mir_transform/src/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use tracing::debug;
/// once with `apply`. This is useful for MIR transformation passes.
pub(crate) struct MirPatch<'tcx> {
term_patch_map: FxHashMap<BasicBlock, TerminatorKind<'tcx>>,
/// Set of statements that should be replaced by `Nop`.
nop_statements: Vec<Location>,
new_blocks: Vec<BasicBlockData<'tcx>>,
new_statements: Vec<(Location, StatementKind<'tcx>)>,
new_locals: Vec<LocalDecl<'tcx>>,
Expand All @@ -33,6 +35,7 @@ impl<'tcx> MirPatch<'tcx> {
pub(crate) fn new(body: &Body<'tcx>) -> Self {
let mut result = MirPatch {
term_patch_map: Default::default(),
nop_statements: vec![],
new_blocks: vec![],
new_statements: vec![],
new_locals: vec![],
Expand Down Expand Up @@ -212,6 +215,15 @@ impl<'tcx> MirPatch<'tcx> {
self.term_patch_map.insert(block, new);
}

/// Mark given statement to be replaced by a `Nop`.
///
/// This method only works on statements from the initial body, and cannot be used to remove
/// statements from `add_statement` or `add_assign`.
#[tracing::instrument(level = "debug", skip(self))]
pub(crate) fn nop_statement(&mut self, loc: Location) {
self.nop_statements.push(loc);
}

/// Queues the insertion of a statement at a given location. The statement
/// currently at that location, and all statements that follow, are shifted
/// down. If multiple statements are queued for addition at the same
Expand Down Expand Up @@ -257,11 +269,8 @@ impl<'tcx> MirPatch<'tcx> {
bbs.extend(self.new_blocks);
body.local_decls.extend(self.new_locals);

// The order in which we patch terminators does not change the result.
#[allow(rustc::potential_query_instability)]
for (src, patch) in self.term_patch_map {
debug!("MirPatch: patching block {:?}", src);
bbs[src].terminator_mut().kind = patch;
for loc in self.nop_statements {
bbs[loc.block].statements[loc.statement_index].make_nop();
}

let mut new_statements = self.new_statements;
Expand All @@ -285,6 +294,17 @@ impl<'tcx> MirPatch<'tcx> {
.insert(loc.statement_index, Statement::new(source_info, stmt));
delta += 1;
}

// The order in which we patch terminators does not change the result.
#[allow(rustc::potential_query_instability)]
for (src, patch) in self.term_patch_map {
debug!("MirPatch: patching block {:?}", src);
let bb = &mut bbs[src];
if let TerminatorKind::Unreachable = patch {
bb.statements.clear();
}
bb.terminator_mut().kind = patch;
}
}

fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
Expand Down
19 changes: 12 additions & 7 deletions compiler/rustc_mir_transform/src/simplify_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;
use tracing::trace;

use crate::patch::MirPatch;

pub(super) enum SimplifyConstCondition {
AfterConstProp,
Final,
Expand All @@ -19,26 +21,27 @@ impl<'tcx> crate::MirPass<'tcx> for SimplifyConstCondition {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("Running SimplifyConstCondition on {:?}", body.source);
let typing_env = body.typing_env(tcx);
'blocks: for block in body.basic_blocks_mut() {
for stmt in block.statements.iter_mut() {
let mut patch = MirPatch::new(body);

'blocks: for (bb, block) in body.basic_blocks.iter_enumerated() {
for (statement_index, stmt) in block.statements.iter().enumerate() {
// Simplify `assume` of a known value: either a NOP or unreachable.
if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind
&& let NonDivergingIntrinsic::Assume(discr) = intrinsic
&& let Operand::Constant(c) = discr
&& let Some(constant) = c.const_.try_eval_bool(tcx, typing_env)
{
if constant {
stmt.make_nop();
patch.nop_statement(Location { block: bb, statement_index });
} else {
block.statements.clear();
block.terminator_mut().kind = TerminatorKind::Unreachable;
patch.patch_terminator(bb, TerminatorKind::Unreachable);
continue 'blocks;
}
}
}

let terminator = block.terminator_mut();
terminator.kind = match terminator.kind {
let terminator = block.terminator();
let terminator = match terminator.kind {
TerminatorKind::SwitchInt {
discr: Operand::Constant(ref c), ref targets, ..
} => {
Expand All @@ -58,7 +61,9 @@ impl<'tcx> crate::MirPass<'tcx> for SimplifyConstCondition {
},
_ => continue,
};
patch.patch_terminator(bb, terminator);
}
patch.apply(body);
}

fn is_required(&self) -> bool {
Expand Down
Loading