From d0da4cfee7d95d37bdf41a53081996c31ec26d2f Mon Sep 17 00:00:00 2001 From: Brendan Zabarauskas Date: Fri, 2 May 2014 11:04:46 -0700 Subject: [PATCH] Implement comparison operators for int and uint SIMD vectors --- src/librustc/middle/trans/base.rs | 38 +++++++++++++++++++++++++ src/librustc/middle/trans/expr.rs | 27 +++++++++--------- src/librustc/middle/typeck/check/mod.rs | 23 +++++++++++++-- src/test/compile-fail/simd-binop.rs | 37 ++++++++++++++++++++++++ src/test/run-pass/simd-binop.rs | 30 ++++++++++++++++--- 5 files changed, 135 insertions(+), 20 deletions(-) create mode 100644 src/test/compile-fail/simd-binop.rs diff --git a/src/librustc/middle/trans/base.rs b/src/librustc/middle/trans/base.rs index 0999da60ad3..7e38ae29d30 100644 --- a/src/librustc/middle/trans/base.rs +++ b/src/librustc/middle/trans/base.rs @@ -619,6 +619,44 @@ fn die(cx: &Block) -> ! { } } +pub fn compare_simd_types( + cx: &Block, + lhs: ValueRef, + rhs: ValueRef, + t: ty::t, + size: uint, + op: ast::BinOp) + -> ValueRef { + match ty::get(t).sty { + ty::ty_float(_) => { + // The comparison operators for floating point vectors are challenging. + // LLVM outputs a `< size x i1 >`, but if we perform a sign extension + // then bitcast to a floating point vector, the result will be `-NaN` + // for each truth value. Because of this they are unsupported. + cx.sess().bug("compare_simd_types: comparison operators \ + not supported for floating point SIMD types") + }, + ty::ty_uint(_) | ty::ty_int(_) => { + let cmp = match op { + ast::BiEq => lib::llvm::IntEQ, + ast::BiNe => lib::llvm::IntNE, + ast::BiLt => lib::llvm::IntSLT, + ast::BiLe => lib::llvm::IntSLE, + ast::BiGt => lib::llvm::IntSGT, + ast::BiGe => lib::llvm::IntSGE, + _ => cx.sess().bug("compare_simd_types: must be a comparison operator"), + }; + let return_ty = Type::vector(&type_of(cx.ccx(), t), size as u64); + // LLVM outputs an `< size x i1 >`, so we need to perform a sign extension + // to get the correctly sized type. This will compile to a single instruction + // once the IR is converted to assembly if the SIMD instruction is supported + // by the target architecture. + SExt(cx, ICmp(cx, cmp, lhs, rhs), return_ty) + }, + _ => cx.sess().bug("compare_simd_types: invalid SIMD type"), + } +} + pub type val_and_ty_fn<'r,'b> = |&'b Block<'b>, ValueRef, ty::t|: 'r -> &'b Block<'b>; diff --git a/src/librustc/middle/trans/expr.rs b/src/librustc/middle/trans/expr.rs index 3fa5a9e085a..2e8c60c5dc5 100644 --- a/src/librustc/middle/trans/expr.rs +++ b/src/librustc/middle/trans/expr.rs @@ -1259,16 +1259,15 @@ fn trans_eager_binop<'a>( -> DatumBlock<'a, Expr> { let _icx = push_ctxt("trans_eager_binop"); - let mut intype = { + let tcx = bcx.tcx(); + let is_simd = ty::type_is_simd(tcx, lhs_t); + let intype = { if ty::type_is_bot(lhs_t) { rhs_t } + else if is_simd { ty::simd_type(tcx, lhs_t) } else { lhs_t } }; - let tcx = bcx.tcx(); - if ty::type_is_simd(tcx, intype) { - intype = ty::simd_type(tcx, intype); - } let is_float = ty::type_is_fp(intype); - let signed = ty::type_is_signed(intype); + let is_signed = ty::type_is_signed(intype); let rhs = base::cast_shift_expr_rhs(bcx, op, lhs, rhs); @@ -1293,7 +1292,7 @@ fn trans_eager_binop<'a>( // Only zero-check integers; fp /0 is NaN bcx = base::fail_if_zero(bcx, binop_expr.span, op, rhs, rhs_t); - if signed { + if is_signed { SDiv(bcx, lhs, rhs) } else { UDiv(bcx, lhs, rhs) @@ -1307,7 +1306,7 @@ fn trans_eager_binop<'a>( // Only zero-check integers; fp %0 is NaN bcx = base::fail_if_zero(bcx, binop_expr.span, op, rhs, rhs_t); - if signed { + if is_signed { SRem(bcx, lhs, rhs) } else { URem(bcx, lhs, rhs) @@ -1319,21 +1318,21 @@ fn trans_eager_binop<'a>( ast::BiBitXor => Xor(bcx, lhs, rhs), ast::BiShl => Shl(bcx, lhs, rhs), ast::BiShr => { - if signed { + if is_signed { AShr(bcx, lhs, rhs) } else { LShr(bcx, lhs, rhs) } } ast::BiEq | ast::BiNe | ast::BiLt | ast::BiGe | ast::BiLe | ast::BiGt => { if ty::type_is_bot(rhs_t) { C_bool(bcx.ccx(), false) - } else { - if !ty::type_is_scalar(rhs_t) { - bcx.tcx().sess.span_bug(binop_expr.span, - "non-scalar comparison"); - } + } else if ty::type_is_scalar(rhs_t) { let cmpr = base::compare_scalar_types(bcx, lhs, rhs, rhs_t, op); bcx = cmpr.bcx; ZExt(bcx, cmpr.val, Type::i8(bcx.ccx())) + } else if is_simd { + base::compare_simd_types(bcx, lhs, rhs, intype, ty::simd_size(tcx, lhs_t), op) + } else { + bcx.tcx().sess.span_bug(binop_expr.span, "comparison operator unsupported for type") } } _ => { diff --git a/src/librustc/middle/typeck/check/mod.rs b/src/librustc/middle/typeck/check/mod.rs index a0847baaea2..e7d1a85957f 100644 --- a/src/librustc/middle/typeck/check/mod.rs +++ b/src/librustc/middle/typeck/check/mod.rs @@ -2102,8 +2102,27 @@ fn check_binop(fcx: &FnCtxt, let result_t = match op { ast::BiEq | ast::BiNe | ast::BiLt | ast::BiLe | ast::BiGe | - ast::BiGt => ty::mk_bool(), - _ => lhs_t + ast::BiGt => { + if ty::type_is_simd(tcx, lhs_t) { + if ty::type_is_fp(ty::simd_type(tcx, lhs_t)) { + fcx.type_error_message(expr.span, + |actual| { + format!("binary comparison operation `{}` not supported \ + for floating point SIMD vector `{}`", + ast_util::binop_to_str(op), actual) + }, + lhs_t, + None + ); + ty::mk_err() + } else { + lhs_t + } + } else { + ty::mk_bool() + } + }, + _ => lhs_t, }; fcx.write_ty(expr.id, result_t); diff --git a/src/test/compile-fail/simd-binop.rs b/src/test/compile-fail/simd-binop.rs new file mode 100644 index 00000000000..281e879592d --- /dev/null +++ b/src/test/compile-fail/simd-binop.rs @@ -0,0 +1,37 @@ +// Copyright 2014 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// ignore-tidy-linelength + +#![allow(experimental)] + +use std::unstable::simd::f32x4; + +fn main() { + + let _ = f32x4(0.0, 0.0, 0.0, 0.0) == f32x4(0.0, 0.0, 0.0, 0.0); + //~^ ERROR binary comparison operation `==` not supported for floating point SIMD vector `std::unstable::simd::f32x4` + + let _ = f32x4(0.0, 0.0, 0.0, 0.0) != f32x4(0.0, 0.0, 0.0, 0.0); + //~^ ERROR binary comparison operation `!=` not supported for floating point SIMD vector `std::unstable::simd::f32x4` + + let _ = f32x4(0.0, 0.0, 0.0, 0.0) < f32x4(0.0, 0.0, 0.0, 0.0); + //~^ ERROR binary comparison operation `<` not supported for floating point SIMD vector `std::unstable::simd::f32x4` + + let _ = f32x4(0.0, 0.0, 0.0, 0.0) <= f32x4(0.0, 0.0, 0.0, 0.0); + //~^ ERROR binary comparison operation `<=` not supported for floating point SIMD vector `std::unstable::simd::f32x4` + + let _ = f32x4(0.0, 0.0, 0.0, 0.0) >= f32x4(0.0, 0.0, 0.0, 0.0); + //~^ ERROR binary comparison operation `>=` not supported for floating point SIMD vector `std::unstable::simd::f32x4` + + let _ = f32x4(0.0, 0.0, 0.0, 0.0) > f32x4(0.0, 0.0, 0.0, 0.0); + //~^ ERROR binary comparison operation `>` not supported for floating point SIMD vector `std::unstable::simd::f32x4` + +} diff --git a/src/test/run-pass/simd-binop.rs b/src/test/run-pass/simd-binop.rs index 30eda1296d1..efcd99a04ce 100644 --- a/src/test/run-pass/simd-binop.rs +++ b/src/test/run-pass/simd-binop.rs @@ -25,6 +25,8 @@ fn eq_i32x4(i32x4(x0, x1, x2, x3): i32x4, i32x4(y0, y1, y2, y3): i32x4) -> bool } pub fn main() { + // arithmetic operators + assert!(eq_u32x4(u32x4(1, 2, 3, 4) + u32x4(4, 3, 2, 1), u32x4(5, 5, 5, 5))); assert!(eq_u32x4(u32x4(4, 5, 6, 7) - u32x4(4, 3, 2, 1), u32x4(0, 2, 4, 6))); assert!(eq_u32x4(u32x4(1, 2, 3, 4) * u32x4(4, 3, 2, 1), u32x4(4, 6, 6, 4))); @@ -43,8 +45,28 @@ pub fn main() { assert!(eq_i32x4(i32x4(1, 2, 3, 4) << i32x4(4, 3, 2, 1), i32x4(16, 16, 12, 8))); assert!(eq_i32x4(i32x4(1, 2, 3, 4) >> i32x4(4, 3, 2, 1), i32x4(0, 0, 0, 2))); - assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) + f32x4(4.0, 3.0, 2.0, 1.0), f32x4(5.0, 5.0, 5.0, 5.0))); - assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) - f32x4(4.0, 3.0, 2.0, 1.0), f32x4(-3.0, -1.0, 1.0, 3.0))); - assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) * f32x4(4.0, 3.0, 2.0, 1.0), f32x4(4.0, 6.0, 6.0, 4.0))); - assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) / f32x4(4.0, 4.0, 2.0, 1.0), f32x4(0.25, 0.5, 1.5, 4.0))); + assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) + f32x4(4.0, 3.0, 2.0, 1.0), + f32x4(5.0, 5.0, 5.0, 5.0))); + assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) - f32x4(4.0, 3.0, 2.0, 1.0), + f32x4(-3.0, -1.0, 1.0, 3.0))); + assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) * f32x4(4.0, 3.0, 2.0, 1.0), + f32x4(4.0, 6.0, 6.0, 4.0))); + assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) / f32x4(4.0, 4.0, 2.0, 1.0), + f32x4(0.25, 0.5, 1.5, 4.0))); + + // comparison operators + + assert!(eq_u32x4(u32x4(1, 2, 3, 4) == u32x4(3, 2, 1, 0), u32x4(0, !0, 0, 0))); + assert!(eq_u32x4(u32x4(1, 2, 3, 4) != u32x4(3, 2, 1, 0), u32x4(!0, 0, !0, !0))); + assert!(eq_u32x4(u32x4(1, 2, 3, 4) < u32x4(3, 2, 1, 0), u32x4(!0, 0, 0, 0))); + assert!(eq_u32x4(u32x4(1, 2, 3, 4) <= u32x4(3, 2, 1, 0), u32x4(!0, !0, 0, 0))); + assert!(eq_u32x4(u32x4(1, 2, 3, 4) >= u32x4(3, 2, 1, 0), u32x4(0, !0, !0, !0))); + assert!(eq_u32x4(u32x4(1, 2, 3, 4) > u32x4(3, 2, 1, 0), u32x4(0, 0, !0, !0))); + + assert!(eq_i32x4(i32x4(1, 2, 3, 4) == i32x4(3, 2, 1, 0), i32x4(0, !0, 0, 0))); + assert!(eq_i32x4(i32x4(1, 2, 3, 4) != i32x4(3, 2, 1, 0), i32x4(!0, 0, !0, !0))); + assert!(eq_i32x4(i32x4(1, 2, 3, 4) < i32x4(3, 2, 1, 0), i32x4(!0, 0, 0, 0))); + assert!(eq_i32x4(i32x4(1, 2, 3, 4) <= i32x4(3, 2, 1, 0), i32x4(!0, !0, 0, 0))); + assert!(eq_i32x4(i32x4(1, 2, 3, 4) >= i32x4(3, 2, 1, 0), i32x4(0, !0, !0, !0))); + assert!(eq_i32x4(i32x4(1, 2, 3, 4) > i32x4(3, 2, 1, 0), i32x4(0, 0, !0, !0))); } -- GitLab