diff --git a/src/librustc_trans/trans/base.rs b/src/librustc_trans/trans/base.rs index f584de7c47f3b88a22befaa1328e41c4d56053bb..cea3adccda4a11a91f758dca5563b9963cf372fa 100644 --- a/src/librustc_trans/trans/base.rs +++ b/src/librustc_trans/trans/base.rs @@ -756,7 +756,7 @@ fn iter_variant<'blk, 'tcx, F>(cx: Block<'blk, 'tcx>, } pub fn cast_shift_expr_rhs(cx: Block, - op: ast::BinOp, + op: ast::BinOp_, lhs: ValueRef, rhs: ValueRef) -> ValueRef { @@ -765,24 +765,24 @@ pub fn cast_shift_expr_rhs(cx: Block, |a,b| ZExt(cx, a, b)) } -pub fn cast_shift_const_rhs(op: ast::BinOp, +pub fn cast_shift_const_rhs(op: ast::BinOp_, lhs: ValueRef, rhs: ValueRef) -> ValueRef { cast_shift_rhs(op, lhs, rhs, |a, b| unsafe { llvm::LLVMConstTrunc(a, b.to_ref()) }, |a, b| unsafe { llvm::LLVMConstZExt(a, b.to_ref()) }) } -pub fn cast_shift_rhs(op: ast::BinOp, - lhs: ValueRef, - rhs: ValueRef, - trunc: F, - zext: G) - -> ValueRef where +fn cast_shift_rhs(op: ast::BinOp_, + lhs: ValueRef, + rhs: ValueRef, + trunc: F, + zext: G) + -> ValueRef where F: FnOnce(ValueRef, Type) -> ValueRef, G: FnOnce(ValueRef, Type) -> ValueRef, { // Shifts may have any size int on the rhs - if ast_util::is_shift_binop(op.node) { + if ast_util::is_shift_binop(op) { let mut rhs_llty = val_ty(rhs); let mut lhs_llty = val_ty(lhs); if rhs_llty.kind() == Vector { rhs_llty = rhs_llty.element_type() } diff --git a/src/librustc_trans/trans/consts.rs b/src/librustc_trans/trans/consts.rs index 2a3fcd66195b340e456ac5301002f19058a9506f..c95b29f4e7afca2fad2ee6aeb9ca91551d46ec90 100644 --- a/src/librustc_trans/trans/consts.rs +++ b/src/librustc_trans/trans/consts.rs @@ -376,7 +376,7 @@ fn const_expr_unadjusted<'a, 'tcx>(cx: &CrateContext<'a, 'tcx>, let signed = ty::type_is_signed(intype); let (te2, _) = const_expr(cx, &**e2, param_substs); - let te2 = base::cast_shift_const_rhs(b, te1, te2); + let te2 = base::cast_shift_const_rhs(b.node, te1, te2); match b.node { ast::BiAdd => { diff --git a/src/librustc_trans/trans/expr.rs b/src/librustc_trans/trans/expr.rs index c316308c618bcb42a12ec9646e99ffd7d1663ab8..9dd3f60ec4fcbd85cd655aaf91cf705b53fadd35 100644 --- a/src/librustc_trans/trans/expr.rs +++ b/src/librustc_trans/trans/expr.rs @@ -1765,7 +1765,6 @@ fn trans_eager_binop<'blk, 'tcx>(bcx: Block<'blk, 'tcx>, }; let is_float = ty::type_is_fp(intype); let is_signed = ty::type_is_signed(intype); - let rhs = base::cast_shift_expr_rhs(bcx, op, lhs, rhs); let info = expr_info(binop_expr); let binop_debug_loc = binop_expr.debug_loc(); @@ -1838,13 +1837,17 @@ fn trans_eager_binop<'blk, 'tcx>(bcx: Block<'blk, 'tcx>, ast::BiBitOr => Or(bcx, lhs, rhs, binop_debug_loc), ast::BiBitAnd => And(bcx, lhs, rhs, binop_debug_loc), ast::BiBitXor => Xor(bcx, lhs, rhs, binop_debug_loc), - ast::BiShl => Shl(bcx, lhs, rhs, binop_debug_loc), + ast::BiShl => { + let (newbcx, res) = with_overflow_check( + bcx, OverflowOp::Shl, info, lhs_t, lhs, rhs, binop_debug_loc); + bcx = newbcx; + res + } ast::BiShr => { - if is_signed { - AShr(bcx, lhs, rhs, binop_debug_loc) - } else { - LShr(bcx, lhs, rhs, binop_debug_loc) - } + let (newbcx, res) = with_overflow_check( + bcx, OverflowOp::Shr, info, lhs_t, lhs, rhs, binop_debug_loc); + bcx = newbcx; + res } ast::BiEq | ast::BiNe | ast::BiLt | ast::BiGe | ast::BiLe | ast::BiGt => { if is_simd { @@ -2384,9 +2387,38 @@ enum OverflowOp { Add, Sub, Mul, + Shl, + Shr, } impl OverflowOp { + fn codegen_strategy(&self) -> OverflowCodegen { + use self::OverflowCodegen::{ViaIntrinsic, ViaInputCheck}; + match *self { + OverflowOp::Add => ViaIntrinsic(OverflowOpViaIntrinsic::Add), + OverflowOp::Sub => ViaIntrinsic(OverflowOpViaIntrinsic::Sub), + OverflowOp::Mul => ViaIntrinsic(OverflowOpViaIntrinsic::Mul), + + OverflowOp::Shl => ViaInputCheck(OverflowOpViaInputCheck::Shl), + OverflowOp::Shr => ViaInputCheck(OverflowOpViaInputCheck::Shr), + } + } +} + +enum OverflowCodegen { + ViaIntrinsic(OverflowOpViaIntrinsic), + ViaInputCheck(OverflowOpViaInputCheck), +} + +enum OverflowOpViaInputCheck { Shl, Shr, } + +enum OverflowOpViaIntrinsic { Add, Sub, Mul, } + +impl OverflowOpViaIntrinsic { + fn to_intrinsic<'blk, 'tcx>(&self, bcx: Block<'blk, 'tcx>, lhs_ty: Ty) -> ValueRef { + let name = self.to_intrinsic_name(bcx.tcx(), lhs_ty); + bcx.ccx().get_intrinsic(&name) + } fn to_intrinsic_name(&self, tcx: &ty::ctxt, ty: Ty) -> &'static str { use syntax::ast::IntTy::*; use syntax::ast::UintTy::*; @@ -2408,7 +2440,7 @@ fn to_intrinsic_name(&self, tcx: &ty::ctxt, ty: Ty) -> &'static str { }; match *self { - OverflowOp::Add => match new_sty { + OverflowOpViaIntrinsic::Add => match new_sty { ty_int(TyI8) => "llvm.sadd.with.overflow.i8", ty_int(TyI16) => "llvm.sadd.with.overflow.i16", ty_int(TyI32) => "llvm.sadd.with.overflow.i32", @@ -2421,7 +2453,7 @@ fn to_intrinsic_name(&self, tcx: &ty::ctxt, ty: Ty) -> &'static str { _ => unreachable!(), }, - OverflowOp::Sub => match new_sty { + OverflowOpViaIntrinsic::Sub => match new_sty { ty_int(TyI8) => "llvm.ssub.with.overflow.i8", ty_int(TyI16) => "llvm.ssub.with.overflow.i16", ty_int(TyI32) => "llvm.ssub.with.overflow.i32", @@ -2434,7 +2466,7 @@ fn to_intrinsic_name(&self, tcx: &ty::ctxt, ty: Ty) -> &'static str { _ => unreachable!(), }, - OverflowOp::Mul => match new_sty { + OverflowOpViaIntrinsic::Mul => match new_sty { ty_int(TyI8) => "llvm.smul.with.overflow.i8", ty_int(TyI16) => "llvm.smul.with.overflow.i16", ty_int(TyI32) => "llvm.smul.with.overflow.i32", @@ -2449,16 +2481,14 @@ fn to_intrinsic_name(&self, tcx: &ty::ctxt, ty: Ty) -> &'static str { }, } } -} - -fn with_overflow_check<'a, 'b>(bcx: Block<'a, 'b>, oop: OverflowOp, info: NodeIdAndSpan, - lhs_t: Ty, lhs: ValueRef, rhs: ValueRef, binop_debug_loc: DebugLoc) - -> (Block<'a, 'b>, ValueRef) { - if bcx.unreachable.get() { return (bcx, _Undef(lhs)); } - if bcx.ccx().check_overflow() { - let name = oop.to_intrinsic_name(bcx.tcx(), lhs_t); - let llfn = bcx.ccx().get_intrinsic(&name); + fn build_intrinsic_call<'blk, 'tcx>(&self, bcx: Block<'blk, 'tcx>, + info: NodeIdAndSpan, + lhs_t: Ty<'tcx>, lhs: ValueRef, + rhs: ValueRef, + binop_debug_loc: DebugLoc) + -> (Block<'blk, 'tcx>, ValueRef) { + let llfn = self.to_intrinsic(bcx, lhs_t); let val = Call(bcx, llfn, &[lhs, rhs], None, binop_debug_loc); let result = ExtractValue(bcx, val, 0); // iN operation result @@ -2477,11 +2507,118 @@ fn with_overflow_check<'a, 'b>(bcx: Block<'a, 'b>, oop: OverflowOp, info: NodeId InternedString::new("arithmetic operation overflowed"))); (bcx, result) + } +} + +impl OverflowOpViaInputCheck { + fn build_with_input_check<'blk, 'tcx>(&self, + bcx: Block<'blk, 'tcx>, + info: NodeIdAndSpan, + lhs_t: Ty<'tcx>, + lhs: ValueRef, + rhs: ValueRef, + binop_debug_loc: DebugLoc) + -> (Block<'blk, 'tcx>, ValueRef) + { + let lhs_llty = val_ty(lhs); + let rhs_llty = val_ty(rhs); + + // Panic if any bits are set outside of bits that we always + // mask in. + // + // Note that the mask's value is derived from the LHS type + // (since that is where the 32/64 distinction is relevant) but + // the mask's type must match the RHS type (since they will + // both be fed into a and-binop) + let invert_mask = !shift_mask_val(lhs_llty); + let invert_mask = C_integral(rhs_llty, invert_mask, true); + + let outer_bits = And(bcx, rhs, invert_mask, binop_debug_loc); + let cond = ICmp(bcx, llvm::IntNE, outer_bits, + C_integral(rhs_llty, 0, false), binop_debug_loc); + let result = match *self { + OverflowOpViaInputCheck::Shl => + build_unchecked_lshift(bcx, lhs, rhs, binop_debug_loc), + OverflowOpViaInputCheck::Shr => + build_unchecked_rshift(bcx, lhs_t, lhs, rhs, binop_debug_loc), + }; + let bcx = + base::with_cond(bcx, cond, |bcx| + controlflow::trans_fail(bcx, info, + InternedString::new("shift operation overflowed"))); + + (bcx, result) + } +} + +fn shift_mask_val(llty: Type) -> u64 { + // i8/u8 can shift by at most 7, i16/u16 by at most 15, etc. + llty.int_width() - 1 +} + +// To avoid UB from LLVM, these two functions mask RHS with an +// appropriate mask unconditionally (i.e. the fallback behavior for +// all shifts). For 32- and 64-bit types, this matches the semantics +// of Java. (See related discussion on #1877 and #10183.) + +fn build_unchecked_lshift<'blk, 'tcx>(bcx: Block<'blk, 'tcx>, + lhs: ValueRef, + rhs: ValueRef, + binop_debug_loc: DebugLoc) -> ValueRef { + let rhs = base::cast_shift_expr_rhs(bcx, ast::BinOp_::BiShl, lhs, rhs); + // #1877, #10183: Ensure that input is always valid + let rhs = shift_mask_rhs(bcx, rhs, binop_debug_loc); + Shl(bcx, lhs, rhs, binop_debug_loc) +} + +fn build_unchecked_rshift<'blk, 'tcx>(bcx: Block<'blk, 'tcx>, + lhs_t: Ty<'tcx>, + lhs: ValueRef, + rhs: ValueRef, + binop_debug_loc: DebugLoc) -> ValueRef { + let rhs = base::cast_shift_expr_rhs(bcx, ast::BinOp_::BiShr, lhs, rhs); + // #1877, #10183: Ensure that input is always valid + let rhs = shift_mask_rhs(bcx, rhs, binop_debug_loc); + let is_signed = ty::type_is_signed(lhs_t); + if is_signed { + AShr(bcx, lhs, rhs, binop_debug_loc) + } else { + LShr(bcx, lhs, rhs, binop_debug_loc) + } +} + +fn shift_mask_rhs<'blk, 'tcx>(bcx: Block<'blk, 'tcx>, + rhs: ValueRef, + debug_loc: DebugLoc) -> ValueRef { + let rhs_llty = val_ty(rhs); + let mask = shift_mask_val(rhs_llty); + And(bcx, rhs, C_integral(rhs_llty, mask, false), debug_loc) +} + +fn with_overflow_check<'blk, 'tcx>(bcx: Block<'blk, 'tcx>, oop: OverflowOp, info: NodeIdAndSpan, + lhs_t: Ty<'tcx>, lhs: ValueRef, + rhs: ValueRef, + binop_debug_loc: DebugLoc) + -> (Block<'blk, 'tcx>, ValueRef) { + if bcx.unreachable.get() { return (bcx, _Undef(lhs)); } + if bcx.ccx().check_overflow() { + + match oop.codegen_strategy() { + OverflowCodegen::ViaIntrinsic(oop) => + oop.build_intrinsic_call(bcx, info, lhs_t, lhs, rhs, binop_debug_loc), + OverflowCodegen::ViaInputCheck(oop) => + oop.build_with_input_check(bcx, info, lhs_t, lhs, rhs, binop_debug_loc), + } } else { let res = match oop { OverflowOp::Add => Add(bcx, lhs, rhs, binop_debug_loc), OverflowOp::Sub => Sub(bcx, lhs, rhs, binop_debug_loc), OverflowOp::Mul => Mul(bcx, lhs, rhs, binop_debug_loc), + + OverflowOp::Shl => + build_unchecked_lshift(bcx, lhs, rhs, binop_debug_loc), + OverflowOp::Shr => + build_unchecked_rshift(bcx, lhs_t, lhs, rhs, binop_debug_loc), }; (bcx, res) }