diff --git a/.github/workflows/ci_windows.yml b/.github/workflows/ci_windows.yml index c166f025..8917729a 100644 --- a/.github/workflows/ci_windows.yml +++ b/.github/workflows/ci_windows.yml @@ -79,7 +79,7 @@ jobs: run: cargo build --all-features -p cust_raw - name: Build - run: cargo build --workspace --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*" --exclude "sha2*" + run: cargo build --workspace --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*" --exclude "sha2*" --exclude "i128*" # Don't currently test because many tests rely on the system having a CUDA GPU # - name: Test @@ -88,7 +88,7 @@ - name: Check documentation env: RUSTDOCFLAGS: -Dwarnings - run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*" --exclude "sha2*" --exclude "cust_raw" + run: cargo doc --workspace --all-features --document-private-items --no-deps --exclude "optix*" --exclude "path-tracer" --exclude "denoiser" --exclude "vecadd*" --exclude "gemm*" --exclude "ex*" --exclude "cudnn*" --exclude "sha2*" --exclude "i128*" --exclude "cust_raw" # Disabled due to dll issues, someone with Windows knowledge needed # - name: Compiletest # run: cargo run -p compiletests --release --no-default-features -- --target-arch compute_61,compute_70,compute_90 diff --git a/Cargo.toml b/Cargo.toml index 18931ef0..56592723 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ members = [ "examples/cuda/path_tracer/kernels", "examples/cuda/sha2_crates_io", "examples/cuda/sha2_crates_io/kernels", + "examples/cuda/i128_demo", + "examples/cuda/i128_demo/kernels", "examples/optix/*", "tests/compiletests", diff --git a/crates/rustc_codegen_nvvm/libintrinsics.bc b/crates/rustc_codegen_nvvm/libintrinsics.bc index c22e92db..ec9a7642 100644 Binary files a/crates/rustc_codegen_nvvm/libintrinsics.bc and b/crates/rustc_codegen_nvvm/libintrinsics.bc differ diff --git a/crates/rustc_codegen_nvvm/libintrinsics.ll b/crates/rustc_codegen_nvvm/libintrinsics.ll index d9cb5e2d..5da51b20 100644 --- a/crates/rustc_codegen_nvvm/libintrinsics.ll +++ b/crates/rustc_codegen_nvvm/libintrinsics.ll @@ -239,113 +239,6 @@ start: } declare {i16, i1} @llvm.umul.with.overflow.i16(i16, i16) #0 -; This is a bit weird, we need to use functions defined in rust crates (compiler_builtins) -; as intrinsics in the codegen, but we can't directly use their name, otherwise we will have -; really odd and incorrect behavior in the crate theyre defined in. So we need to make a wrapper for them that is opaque -; to the codegen, which is what this is doing. 
- -define {<2 x i64>, i1} @__nvvm_i128_addo(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call {<2 x i64>, i1} @__rust_i128_addo(<2 x i64> %0, <2 x i64> %1) - ret {<2 x i64>, i1} %2 -} -declare {<2 x i64>, i1} @__rust_i128_addo(<2 x i64>, <2 x i64>) #0 - -define {<2 x i64>, i1} @__nvvm_u128_addo(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call {<2 x i64>, i1} @__rust_u128_addo(<2 x i64> %0, <2 x i64> %1) - ret {<2 x i64>, i1} %2 -} -declare {<2 x i64>, i1} @__rust_u128_addo(<2 x i64>, <2 x i64>) #0 - -define {<2 x i64>, i1} @__nvvm_i128_subo(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call {<2 x i64>, i1} @__rust_i128_subo(<2 x i64> %0, <2 x i64> %1) - ret {<2 x i64>, i1} %2 -} -declare {<2 x i64>, i1} @__rust_i128_subo(<2 x i64>, <2 x i64>) #0 - -define {<2 x i64>, i1} @__nvvm_u128_subo(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call {<2 x i64>, i1} @__rust_u128_subo(<2 x i64> %0, <2 x i64> %1) - ret {<2 x i64>, i1} %2 -} -declare {<2 x i64>, i1} @__rust_u128_subo(<2 x i64>, <2 x i64>) #0 - -define {<2 x i64>, i1} @__nvvm_i128_mulo(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call {<2 x i64>, i1} @__rust_i128_mulo(<2 x i64> %0, <2 x i64> %1) - ret {<2 x i64>, i1} %2 -} -declare {<2 x i64>, i1} @__rust_i128_mulo(<2 x i64>, <2 x i64>) #0 - -define {<2 x i64>, i1} @__nvvm_u128_mulo(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call {<2 x i64>, i1} @__rust_u128_mulo(<2 x i64> %0, <2 x i64> %1) - ret {<2 x i64>, i1} %2 -} -declare {<2 x i64>, i1} @__rust_u128_mulo(<2 x i64>, <2 x i64>) #0 - -; Division operations from compiler-builtins -define <2 x i64> @__nvvm_divti3(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call <2 x i64> @__divti3(<2 x i64> %0, <2 x i64> %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__divti3(<2 x i64>, <2 x i64>) #0 - -define <2 x i64> @__nvvm_udivti3(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call <2 x i64> @__udivti3(<2 x i64> %0, <2 x i64> %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__udivti3(<2 x i64>, <2 x i64>) #0 - -; Remainder operations from compiler-builtins -define <2 x i64> @__nvvm_modti3(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call <2 x i64> @__modti3(<2 x i64> %0, <2 x i64> %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__modti3(<2 x i64>, <2 x i64>) #0 - -define <2 x i64> @__nvvm_umodti3(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call <2 x i64> @__umodti3(<2 x i64> %0, <2 x i64> %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__umodti3(<2 x i64>, <2 x i64>) #0 - -; Multiplication from compiler-builtins -define <2 x i64> @__nvvm_multi3(<2 x i64>, <2 x i64>) #0 { -start: - %2 = call <2 x i64> @__multi3(<2 x i64> %0, <2 x i64> %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__multi3(<2 x i64>, <2 x i64>) #0 - -; Shift operations from compiler-builtins -define <2 x i64> @__nvvm_ashlti3(<2 x i64>, i32) #0 { -start: - %2 = call <2 x i64> @__ashlti3(<2 x i64> %0, i32 %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__ashlti3(<2 x i64>, i32) #0 - -define <2 x i64> @__nvvm_ashrti3(<2 x i64>, i32) #0 { -start: - %2 = call <2 x i64> @__ashrti3(<2 x i64> %0, i32 %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__ashrti3(<2 x i64>, i32) #0 - -define <2 x i64> @__nvvm_lshrti3(<2 x i64>, i32) #0 { -start: - %2 = call <2 x i64> @__lshrti3(<2 x i64> %0, i32 %1) - ret <2 x i64> %2 -} -declare <2 x i64> @__lshrti3(<2 x i64>, i32) #0 - ; Required because we need to explicitly generate { i32, i1 } for the following intrinsics ; except rustc will not generate them (it will make { i32, i8 }) which libnvvm rejects. 
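The replacement for these wrappers lives in the new crates/rustc_codegen_nvvm/src/builder/emulate_i128.rs module below: each 128-bit value is split into low/high u64 limbs and carries are propagated explicitly. A minimal host-side sketch of that limb strategy, assuming plain Rust on the CPU rather than the codegen's LLVM builder API (the function name add_u128_via_limbs is illustrative, not part of the patch):

// Host-side illustration only (not the codegen itself): add two u128 values by
// splitting them into low/high u64 limbs and propagating the carry, the same
// shape as emulate_i128_add_with_overflow in the new module below.
fn add_u128_via_limbs(a: u128, b: u128) -> (u128, bool) {
    let (a_lo, a_hi) = (a as u64, (a >> 64) as u64);
    let (b_lo, b_hi) = (b as u64, (b >> 64) as u64);

    // overflowing_add plays the role of llvm.uadd.with.overflow.i64 here.
    let (lo, carry_lo) = a_lo.overflowing_add(b_lo);
    let (hi_tmp, carry_hi1) = a_hi.overflowing_add(b_hi);
    let (hi, carry_hi2) = hi_tmp.overflowing_add(carry_lo as u64);

    let result = (lo as u128) | ((hi as u128) << 64);
    (result, carry_hi1 || carry_hi2)
}

fn main() {
    let (a, b) = (u128::MAX - 7, 42u128);
    // The limb version must agree with the native 128-bit add, including the carry-out.
    assert_eq!(add_u128_via_limbs(a, b), a.overflowing_add(b));
}
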
diff --git a/crates/rustc_codegen_nvvm/src/builder.rs b/crates/rustc_codegen_nvvm/src/builder.rs index bee8d182..05f879ac 100644 --- a/crates/rustc_codegen_nvvm/src/builder.rs +++ b/crates/rustc_codegen_nvvm/src/builder.rs @@ -32,6 +32,8 @@ use crate::int_replace::{get_transformed_type, transmute_llval}; use crate::llvm::{self, BasicBlock, Type, Value}; use crate::ty::LayoutLlvmExt; +mod emulate_i128; + pub(crate) enum CountZerosKind { Leading, Trailing, @@ -133,7 +135,7 @@ impl<'ll, 'tcx> Deref for Builder<'_, 'll, 'tcx> { macro_rules! imath_builder_methods { ($($self_:ident.$name:ident($($arg:ident),*) => $llvm_capi:ident => $op:block)+) => { $(fn $name(&mut $self_, $($arg: &'ll Value),*) -> &'ll Value { - // Dispatch to i128 emulation or `compiler_builtins`-based intrinsic + // Dispatch to i128 emulation when any operand is 128 bits wide. if $($self_.is_i128($arg))||* $op else { @@ -319,40 +321,30 @@ impl<'ll, 'tcx, 'a> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { self.unchecked_usub(a, b) => LLVMBuildNUWSub => { self.emulate_i128_sub(a, b) } self.unchecked_ssub(a, b) => LLVMBuildNSWSub => { self.emulate_i128_sub(a, b) } - self.mul(a, b) => LLVMBuildMul => { self.call_intrinsic("__nvvm_multi3", &[a, b]) } - self.unchecked_umul(a, b) => LLVMBuildNUWMul => { - self.call_intrinsic("__nvvm_multi3", &[a, b]) - } - self.unchecked_smul(a, b) => LLVMBuildNSWMul => { - self.call_intrinsic("__nvvm_multi3", &[a, b]) - } + self.mul(a, b) => LLVMBuildMul => { self.emulate_i128_mul(a, b) } + self.unchecked_umul(a, b) => LLVMBuildNUWMul => { self.emulate_i128_mul(a, b) } + self.unchecked_smul(a, b) => LLVMBuildNSWMul => { self.emulate_i128_mul(a, b) } - self.udiv(a, b) => LLVMBuildUDiv => { self.call_intrinsic("__nvvm_udivti3", &[a, b]) } - self.exactudiv(a, b) => LLVMBuildExactUDiv => { - self.call_intrinsic("__nvvm_udivti3", &[a, b]) - } - self.sdiv(a, b) => LLVMBuildSDiv => { self.call_intrinsic("__nvvm_divti3", &[a, b]) } - self.exactsdiv(a, b) => LLVMBuildExactSDiv => { - self.call_intrinsic("__nvvm_divti3", &[a, b]) - } - self.urem(a, b) => LLVMBuildURem => { self.call_intrinsic("__nvvm_umodti3", &[a, b]) } - self.srem(a, b) => LLVMBuildSRem => { self.call_intrinsic("__nvvm_modti3", &[a, b]) } + self.udiv(a, b) => LLVMBuildUDiv => { self.emulate_i128_udiv(a, b) } + self.exactudiv(a, b) => LLVMBuildExactUDiv => { self.emulate_i128_udiv(a, b) } + self.sdiv(a, b) => LLVMBuildSDiv => { self.emulate_i128_sdiv(a, b) } + self.exactsdiv(a, b) => LLVMBuildExactSDiv => { self.emulate_i128_sdiv(a, b) } + self.urem(a, b) => LLVMBuildURem => { self.emulate_i128_urem(a, b) } + self.srem(a, b) => LLVMBuildSRem => { self.emulate_i128_srem(a, b) } self.shl(a, b) => LLVMBuildShl => { - // Convert shift amount to i32 for compiler-builtins. let b = self.trunc(b, self.type_i32()); - self.call_intrinsic("__nvvm_ashlti3", &[a, b]) + self.emulate_i128_shl(a, b) } self.lshr(a, b) => LLVMBuildLShr => { - // Convert shift amount to i32 for compiler-builtins. let b = self.trunc(b, self.type_i32()); - self.call_intrinsic("__nvvm_lshrti3", &[a, b]) + self.emulate_i128_lshr(a, b) } self.ashr(a, b) => LLVMBuildAShr => { - // Convert shift amount to i32 for compiler-builtins. 
let b = self.trunc(b, self.type_i32()); - self.call_intrinsic("__nvvm_ashrti3", &[a, b]) + self.emulate_i128_ashr(a, b) } + self.and(a, b) => LLVMBuildAnd => { self.emulate_i128_and(a, b) } self.or(a, b) => LLVMBuildOr => { self.emulate_i128_or(a, b) } self.xor(a, b) => LLVMBuildXor => { self.emulate_i128_xor(a, b) } @@ -404,19 +396,39 @@ impl<'ll, 'tcx, 'a> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { _ => panic!("tried to get overflow intrinsic for op applied to non-int type"), }; + match (oop, new_kind) { + (OverflowOp::Add, Int(I128)) => { + return self.emulate_i128_add_with_overflow(lhs, rhs, true); + } + (OverflowOp::Add, Uint(U128)) => { + return self.emulate_i128_add_with_overflow(lhs, rhs, false); + } + (OverflowOp::Sub, Int(I128)) => { + return self.emulate_i128_sub_with_overflow(lhs, rhs, true); + } + (OverflowOp::Sub, Uint(U128)) => { + return self.emulate_i128_sub_with_overflow(lhs, rhs, false); + } + (OverflowOp::Mul, Int(I128)) => { + return self.emulate_i128_mul_with_overflow(lhs, rhs, true); + } + (OverflowOp::Mul, Uint(U128)) => { + return self.emulate_i128_mul_with_overflow(lhs, rhs, false); + } + _ => {} + } + let name = match oop { OverflowOp::Add => match new_kind { Int(I8) => "__nvvm_i8_addo", Int(I16) => "llvm.sadd.with.overflow.i16", Int(I32) => "llvm.sadd.with.overflow.i32", Int(I64) => "llvm.sadd.with.overflow.i64", - Int(I128) => "__nvvm_i128_addo", Uint(U8) => "__nvvm_u8_addo", Uint(U16) => "llvm.uadd.with.overflow.i16", Uint(U32) => "llvm.uadd.with.overflow.i32", Uint(U64) => "llvm.uadd.with.overflow.i64", - Uint(U128) => "__nvvm_u128_addo", _ => unreachable!(), }, OverflowOp::Sub => match new_kind { @@ -424,13 +436,11 @@ impl<'ll, 'tcx, 'a> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { Int(I16) => "llvm.ssub.with.overflow.i16", Int(I32) => "llvm.ssub.with.overflow.i32", Int(I64) => "llvm.ssub.with.overflow.i64", - Int(I128) => "__nvvm_i128_subo", Uint(U8) => "__nvvm_u8_subo", Uint(U16) => "llvm.usub.with.overflow.i16", Uint(U32) => "llvm.usub.with.overflow.i32", Uint(U64) => "llvm.usub.with.overflow.i64", - Uint(U128) => "__nvvm_u128_subo", _ => unreachable!(), }, @@ -439,13 +449,11 @@ impl<'ll, 'tcx, 'a> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { Int(I16) => "llvm.smul.with.overflow.i16", Int(I32) => "llvm.smul.with.overflow.i32", Int(I64) => "llvm.smul.with.overflow.i64", - Int(I128) => "__nvvm_i128_mulo", Uint(U8) => "__nvvm_u8_mulo", Uint(U16) => "llvm.umul.with.overflow.i16", Uint(U32) => "llvm.umul.with.overflow.i32", Uint(U64) => "llvm.umul.with.overflow.i64", - Uint(U128) => "__nvvm_u128_mulo", _ => unreachable!(), }, @@ -1333,244 +1341,6 @@ impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> { } impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> { - // Helper function to check if a value is 128-bit integer - fn is_i128(&self, val: &'ll Value) -> bool { - let ty = self.val_ty(val); - if unsafe { llvm::LLVMRustGetTypeKind(ty) == llvm::TypeKind::Integer } { - unsafe { llvm::LLVMGetIntTypeWidth(ty) == 128 } - } else { - false - } - } - - // Helper to split i128 into low and high u64 parts - fn split_i128(&mut self, val: &'ll Value) -> (&'ll Value, &'ll Value) { - let i64_ty = self.type_i64(); - let const_64 = self.const_u128(64); - - let lo = self.trunc(val, i64_ty); - let shifted = unsafe { llvm::LLVMBuildLShr(self.llbuilder, val, const_64, UNNAMED) }; - let hi = self.trunc(shifted, i64_ty); - - (lo, hi) - } - - // Helper to combine two u64 values into i128 - fn combine_i128(&mut self, lo: &'ll Value, hi: &'ll Value) -> &'ll 
Value { - let i128_ty = self.type_i128(); - let const_64 = self.const_u128(64); - - let lo_ext = self.zext(lo, i128_ty); - let hi_ext = self.zext(hi, i128_ty); - let hi_shifted = unsafe { llvm::LLVMBuildShl(self.llbuilder, hi_ext, const_64, UNNAMED) }; - unsafe { llvm::LLVMBuildOr(self.llbuilder, lo_ext, hi_shifted, UNNAMED) } - } - - // Emulate 128-bit addition using two 64-bit additions with carry - fn emulate_i128_add(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { - let i64_ty = self.type_i64(); - let (lhs_lo, lhs_hi) = self.split_i128(lhs); - let (rhs_lo, rhs_hi) = self.split_i128(rhs); - - // Add low parts - let sum_lo = unsafe { llvm::LLVMBuildAdd(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; - - // Check for carry from low addition - let carry = self.icmp(IntPredicate::IntULT, sum_lo, lhs_lo); - let carry_ext = self.zext(carry, i64_ty); - - // Add high parts with carry - let sum_hi_temp = unsafe { llvm::LLVMBuildAdd(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; - let sum_hi = unsafe { llvm::LLVMBuildAdd(self.llbuilder, sum_hi_temp, carry_ext, UNNAMED) }; - - self.combine_i128(sum_lo, sum_hi) - } - - // Emulate 128-bit subtraction - fn emulate_i128_sub(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { - let i64_ty = self.type_i64(); - let (lhs_lo, lhs_hi) = self.split_i128(lhs); - let (rhs_lo, rhs_hi) = self.split_i128(rhs); - - // Subtract low parts - let diff_lo = unsafe { llvm::LLVMBuildSub(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; - - // Check for borrow - let borrow = self.icmp(IntPredicate::IntUGT, rhs_lo, lhs_lo); - let borrow_ext = self.zext(borrow, i64_ty); - - // Subtract high parts with borrow - let diff_hi_temp = unsafe { llvm::LLVMBuildSub(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; - let diff_hi = - unsafe { llvm::LLVMBuildSub(self.llbuilder, diff_hi_temp, borrow_ext, UNNAMED) }; - - self.combine_i128(diff_lo, diff_hi) - } - - // Emulate 128-bit bitwise AND - fn emulate_i128_and(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { - let (lhs_lo, lhs_hi) = self.split_i128(lhs); - let (rhs_lo, rhs_hi) = self.split_i128(rhs); - - let and_lo = unsafe { llvm::LLVMBuildAnd(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; - let and_hi = unsafe { llvm::LLVMBuildAnd(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; - - self.combine_i128(and_lo, and_hi) - } - - // Emulate 128-bit bitwise OR - fn emulate_i128_or(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { - let (lhs_lo, lhs_hi) = self.split_i128(lhs); - let (rhs_lo, rhs_hi) = self.split_i128(rhs); - - let or_lo = unsafe { llvm::LLVMBuildOr(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; - let or_hi = unsafe { llvm::LLVMBuildOr(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; - - self.combine_i128(or_lo, or_hi) - } - - // Emulate 128-bit bitwise XOR - fn emulate_i128_xor(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { - let (lhs_lo, lhs_hi) = self.split_i128(lhs); - let (rhs_lo, rhs_hi) = self.split_i128(rhs); - - let xor_lo = unsafe { llvm::LLVMBuildXor(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; - let xor_hi = unsafe { llvm::LLVMBuildXor(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; - - self.combine_i128(xor_lo, xor_hi) - } - - // Emulate 128-bit bitwise NOT - fn emulate_i128_not(&mut self, val: &'ll Value) -> &'ll Value { - let (lo, hi) = self.split_i128(val); - - let not_lo = unsafe { llvm::LLVMBuildNot(self.llbuilder, lo, UNNAMED) }; - let not_hi = unsafe { llvm::LLVMBuildNot(self.llbuilder, hi, UNNAMED) }; - - self.combine_i128(not_lo, not_hi) - } - - // Emulate 128-bit negation (two's 
complement) - fn emulate_i128_neg(&mut self, val: &'ll Value) -> &'ll Value { - // Two's complement: ~val + 1 - let not_val = self.emulate_i128_not(val); - let one = self.const_u128(1); - self.emulate_i128_add(not_val, one) - } - - pub(crate) fn emulate_i128_bswap(&mut self, val: &'ll Value) -> &'ll Value { - // Split the 128-bit value into two 64-bit halves - let (lo, hi) = self.split_i128(val); - - // Byte-swap each 64-bit half using the LLVM intrinsic (which exists in LLVM 7.1) - let swapped_lo = self.call_intrinsic("llvm.bswap.i64", &[lo]); - let swapped_hi = self.call_intrinsic("llvm.bswap.i64", &[hi]); - - // Swap the halves: the high part becomes low and vice versa - self.combine_i128(swapped_hi, swapped_lo) - } - - pub(crate) fn emulate_i128_count_zeros( - &mut self, - val: &'ll Value, - kind: CountZerosKind, - is_nonzero: bool, - ) -> &'ll Value { - // Split the 128-bit value into two 64-bit halves - let (lo, hi) = self.split_i128(val); - - match kind { - CountZerosKind::Leading => { - // Count leading zeros: check high part first - let hi_is_zero = self.icmp(IntPredicate::IntEQ, hi, self.const_u64(0)); - let hi_ctlz = - self.call_intrinsic("llvm.ctlz.i64", &[hi, self.const_bool(is_nonzero)]); - let lo_ctlz = - self.call_intrinsic("llvm.ctlz.i64", &[lo, self.const_bool(is_nonzero)]); - - // If high part is zero, result is 64 + ctlz(lo), otherwise ctlz(hi) - let lo_ctlz_plus_64 = self.add(lo_ctlz, self.const_u64(64)); - let result_64 = self.select(hi_is_zero, lo_ctlz_plus_64, hi_ctlz); - - // Zero-extend to i128 - self.zext(result_64, self.type_i128()) - } - CountZerosKind::Trailing => { - // Count trailing zeros: check low part first - let lo_is_zero = self.icmp(IntPredicate::IntEQ, lo, self.const_u64(0)); - let lo_cttz = - self.call_intrinsic("llvm.cttz.i64", &[lo, self.const_bool(is_nonzero)]); - let hi_cttz = - self.call_intrinsic("llvm.cttz.i64", &[hi, self.const_bool(is_nonzero)]); - - // If low part is zero, result is 64 + cttz(hi), otherwise cttz(lo) - let hi_cttz_plus_64 = self.add(hi_cttz, self.const_u64(64)); - let result_64 = self.select(lo_is_zero, hi_cttz_plus_64, lo_cttz); - - // Zero-extend to i128 - self.zext(result_64, self.type_i128()) - } - } - } - - pub(crate) fn emulate_i128_ctpop(&mut self, val: &'ll Value) -> &'ll Value { - // Split the 128-bit value into two 64-bit halves - let (lo, hi) = self.split_i128(val); - - // Count population (number of 1 bits) in each half - let lo_popcount = self.call_intrinsic("llvm.ctpop.i64", &[lo]); - let hi_popcount = self.call_intrinsic("llvm.ctpop.i64", &[hi]); - - // Add the two counts - let total_64 = self.add(lo_popcount, hi_popcount); - - // Zero-extend to i128 - self.zext(total_64, self.type_i128()) - } - - pub(crate) fn emulate_i128_rotate( - &mut self, - val: &'ll Value, - shift: &'ll Value, - is_left: bool, - ) -> &'ll Value { - // Rotate is implemented as: (val << shift) | (val >> (128 - shift)) - // For rotate right: (val >> shift) | (val << (128 - shift)) - - // Ensure shift is i128 - let shift_128 = if self.val_ty(shift) == self.type_i128() { - shift - } else { - self.zext(shift, self.type_i128()) - }; - - // Calculate 128 - shift for the complementary shift - let bits_128 = self.const_u128(128); - let shift_complement = self.sub(bits_128, shift_128); - - // Perform the two shifts - let (first_shift, second_shift) = if is_left { - (self.shl(val, shift_128), self.lshr(val, shift_complement)) - } else { - (self.lshr(val, shift_128), self.shl(val, shift_complement)) - }; - - // Combine with OR - 
self.or(first_shift, second_shift) - } - - pub(crate) fn emulate_i128_bitreverse(&mut self, val: &'ll Value) -> &'ll Value { - // Split the 128-bit value into two 64-bit halves - let (lo, hi) = self.split_i128(val); - - // Reverse bits in each half using the 64-bit intrinsic - let reversed_lo = self.call_intrinsic("llvm.bitreverse.i64", &[lo]); - let reversed_hi = self.call_intrinsic("llvm.bitreverse.i64", &[hi]); - - // Swap the halves: reversed high becomes low and vice versa - self.combine_i128(reversed_hi, reversed_lo) - } - fn with_cx(cx: &'a CodegenCx<'ll, 'tcx>) -> Self { // Create a fresh builder from the crate context. let llbuilder = unsafe { llvm::LLVMCreateBuilderInContext(cx.llcx) }; diff --git a/crates/rustc_codegen_nvvm/src/builder/emulate_i128.rs b/crates/rustc_codegen_nvvm/src/builder/emulate_i128.rs new file mode 100644 index 00000000..a836c50a --- /dev/null +++ b/crates/rustc_codegen_nvvm/src/builder/emulate_i128.rs @@ -0,0 +1,557 @@ +use rustc_abi::{Align, Size}; +use rustc_codegen_ssa::common::IntPredicate; +use rustc_codegen_ssa::traits::*; + +use crate::llvm::{self, Type, Value}; + +use super::{Builder, CountZerosKind, UNNAMED}; + +impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> { + pub(super) fn is_i128(&self, val: &'ll Value) -> bool { + let ty = self.val_ty(val); + if ty == self.type_i128() { + return true; + } + if unsafe { llvm::LLVMRustGetTypeKind(ty) == llvm::TypeKind::Integer } { + unsafe { llvm::LLVMGetIntTypeWidth(ty) == 128 } + } else { + false + } + } + + // Helper to split i128 into low and high u64 parts + fn split_i128(&mut self, val: &'ll Value) -> (&'ll Value, &'ll Value) { + let vec_ty = self.type_vector(self.type_i64(), 2); + let bitcast = unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, vec_ty, UNNAMED) }; + let lo = unsafe { + llvm::LLVMBuildExtractElement(self.llbuilder, bitcast, self.cx.const_i32(0), UNNAMED) + }; + let hi = unsafe { + llvm::LLVMBuildExtractElement(self.llbuilder, bitcast, self.cx.const_i32(1), UNNAMED) + }; + + (lo, hi) + } + + fn ensure_i128(&mut self, val: &'ll Value) -> &'ll Value { + if self.val_ty(val) == self.type_i128() { + val + } else { + unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, self.type_i128(), UNNAMED) } + } + } + + fn call_compiler_builtin( + &mut self, + name: &str, + ret_ty: &'ll Type, + args: &[&'ll Value], + ) -> &'ll Value { + let arg_tys: Vec<_> = args.iter().map(|&arg| self.val_ty(arg)).collect(); + let fn_ty = self.type_func(&arg_tys, ret_ty); + let llfn = self.cx.declare_fn(name, fn_ty, None); + self.call(fn_ty, None, None, llfn, args, None, None) + } + + fn trap_if(&mut self, cond: &'ll Value, label: &str) { + let trap_label = format!("{label}_trap"); + let cont_label = format!("{label}_cont"); + let trap_bb = self.append_sibling_block(&trap_label); + let cont_bb = self.append_sibling_block(&cont_label); + + self.cond_br(cond, trap_bb, cont_bb); + + let mut trap_bx = Self::build(self.cx, trap_bb); + trap_bx.call_intrinsic("llvm.trap", &[]); + trap_bx.unreachable(); + + let cont_bx = Self::build(self.cx, cont_bb); + *self = cont_bx; + } + + fn uadd_with_overflow_i64( + &mut self, + lhs: &'ll Value, + rhs: &'ll Value, + ) -> (&'ll Value, &'ll Value) { + let call = self.call_intrinsic("llvm.uadd.with.overflow.i64", &[lhs, rhs]); + (self.extract_value(call, 0), self.extract_value(call, 1)) + } + + fn usub_with_overflow_i64( + &mut self, + lhs: &'ll Value, + rhs: &'ll Value, + ) -> (&'ll Value, &'ll Value) { + let call = self.call_intrinsic("llvm.usub.with.overflow.i64", &[lhs, rhs]); + 
(self.extract_value(call, 0), self.extract_value(call, 1)) + } + + // Helper to combine two u64 values into i128 + fn combine_i128(&mut self, lo: &'ll Value, hi: &'ll Value) -> &'ll Value { + let vec_ty = self.type_vector(self.type_i64(), 2); + let mut vec = self.const_undef(vec_ty); + vec = unsafe { + llvm::LLVMBuildInsertElement(self.llbuilder, vec, lo, self.cx.const_i32(0), UNNAMED) + }; + vec = unsafe { + llvm::LLVMBuildInsertElement(self.llbuilder, vec, hi, self.cx.const_i32(1), UNNAMED) + }; + unsafe { llvm::LLVMBuildBitCast(self.llbuilder, vec, self.type_i128(), UNNAMED) } + } + + // Multiply two u64 values and return the full 128-bit product as (lo, hi) + fn mul_u64_to_u128(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> (&'ll Value, &'ll Value) { + let i64_ty = self.type_i64(); + let i32_ty = self.type_i32(); + let shift_32 = self.const_u64(32); + let mask_32 = self.const_u64(0xFFFF_FFFF); + + // Split the operands into 32-bit halves + let lhs_lo32 = self.trunc(lhs, i32_ty); + let lhs_shifted = self.lshr(lhs, shift_32); + let lhs_hi32 = self.trunc(lhs_shifted, i32_ty); + let rhs_lo32 = self.trunc(rhs, i32_ty); + let rhs_shifted = self.lshr(rhs, shift_32); + let rhs_hi32 = self.trunc(rhs_shifted, i32_ty); + + // Extend halves to 64 bits for the partial products + let lhs_lo64 = self.zext(lhs_lo32, i64_ty); + let lhs_hi64 = self.zext(lhs_hi32, i64_ty); + let rhs_lo64 = self.zext(rhs_lo32, i64_ty); + let rhs_hi64 = self.zext(rhs_hi32, i64_ty); + + // Compute partial products (32-bit x 32-bit -> 64-bit) + let p0 = unsafe { llvm::LLVMBuildMul(self.llbuilder, lhs_lo64, rhs_lo64, UNNAMED) }; + let p1 = unsafe { llvm::LLVMBuildMul(self.llbuilder, lhs_lo64, rhs_hi64, UNNAMED) }; + let p2 = unsafe { llvm::LLVMBuildMul(self.llbuilder, lhs_hi64, rhs_lo64, UNNAMED) }; + let p3 = unsafe { llvm::LLVMBuildMul(self.llbuilder, lhs_hi64, rhs_hi64, UNNAMED) }; + + // Sum cross terms and track the carry that escapes the low 64 bits + let (cross, cross_carry_bit) = self.uadd_with_overflow_i64(p1, p2); + + let cross_low = self.and(cross, mask_32); + let cross_low_shifted = self.shl(cross_low, shift_32); + let (lo, lo_carry_bit) = self.uadd_with_overflow_i64(p0, cross_low_shifted); + + let cross_high = self.lshr(cross, shift_32); + let cross_carry_ext = self.zext(cross_carry_bit, i64_ty); + let cross_carry_high = self.shl(cross_carry_ext, shift_32); + let cross_total_high = self.add(cross_high, cross_carry_high); + + let (hi_temp, _) = self.uadd_with_overflow_i64(p3, cross_total_high); + let lo_carry = self.zext(lo_carry_bit, i64_ty); + let (hi, _) = self.uadd_with_overflow_i64(hi_temp, lo_carry); + + (lo, hi) + } + + // Emulate 128-bit addition using compiler-builtins + pub(super) fn emulate_i128_add(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + let lhs = self.ensure_i128(lhs); + let rhs = self.ensure_i128(rhs); + let args = [lhs, rhs]; + self.call_compiler_builtin("__rust_i128_add", self.type_i128(), &args) + } + + // Emulate 128-bit subtraction using compiler-builtins + pub(super) fn emulate_i128_sub(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + let lhs = self.ensure_i128(lhs); + let rhs = self.ensure_i128(rhs); + let args = [lhs, rhs]; + self.call_compiler_builtin("__rust_i128_sub", self.type_i128(), &args) + } + + // Emulate 128-bit bitwise AND + pub(super) fn emulate_i128_and(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + let (lhs_lo, lhs_hi) = self.split_i128(lhs); + let (rhs_lo, rhs_hi) = self.split_i128(rhs); + + let and_lo = unsafe { 
llvm::LLVMBuildAnd(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; + let and_hi = unsafe { llvm::LLVMBuildAnd(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; + + self.combine_i128(and_lo, and_hi) + } + + // Emulate 128-bit bitwise OR + pub(super) fn emulate_i128_or(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + let (lhs_lo, lhs_hi) = self.split_i128(lhs); + let (rhs_lo, rhs_hi) = self.split_i128(rhs); + + let or_lo = unsafe { llvm::LLVMBuildOr(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; + let or_hi = unsafe { llvm::LLVMBuildOr(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; + + self.combine_i128(or_lo, or_hi) + } + + // Emulate 128-bit bitwise XOR + pub(super) fn emulate_i128_xor(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + let (lhs_lo, lhs_hi) = self.split_i128(lhs); + let (rhs_lo, rhs_hi) = self.split_i128(rhs); + + let xor_lo = unsafe { llvm::LLVMBuildXor(self.llbuilder, lhs_lo, rhs_lo, UNNAMED) }; + let xor_hi = unsafe { llvm::LLVMBuildXor(self.llbuilder, lhs_hi, rhs_hi, UNNAMED) }; + + self.combine_i128(xor_lo, xor_hi) + } + + // Emulate 128-bit multiplication using compiler-builtins + pub(super) fn emulate_i128_mul(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + let lhs = self.ensure_i128(lhs); + let rhs = self.ensure_i128(rhs); + let args = [lhs, rhs]; + self.call_compiler_builtin("__multi3", self.type_i128(), &args) + } + + fn emulate_i128_udivrem( + &mut self, + num: &'ll Value, + den: &'ll Value, + ) -> (&'ll Value, &'ll Value) { + let num = self.ensure_i128(num); + let den = self.ensure_i128(den); + + let zero = self.const_u128(0); + let denom_is_zero = self.icmp(IntPredicate::IntEQ, den, zero); + self.trap_if(denom_is_zero, "i128_udiv_zero"); + + let i128_align = Align::from_bytes(16).expect("align 16"); + let rem_slot = self.alloca(Size::from_bits(128), i128_align); + let rem_ptr = self.pointercast(rem_slot, self.cx.type_ptr_to(self.type_i128())); + + let args = [num, den, rem_ptr]; + let quot = self.call_compiler_builtin("__udivmodti4", self.type_i128(), &args); + let rem = self.load(self.type_i128(), rem_ptr, i128_align); + + (quot, rem) + } + + pub(super) fn emulate_i128_udiv(&mut self, num: &'ll Value, den: &'ll Value) -> &'ll Value { + self.emulate_i128_udivrem(num, den).0 + } + + pub(super) fn emulate_i128_urem(&mut self, num: &'ll Value, den: &'ll Value) -> &'ll Value { + self.emulate_i128_udivrem(num, den).1 + } + + fn emulate_i128_sdivrem( + &mut self, + num: &'ll Value, + den: &'ll Value, + ) -> (&'ll Value, &'ll Value) { + let num = self.ensure_i128(num); + let den = self.ensure_i128(den); + + let zero = self.const_u128(0); + let denom_is_zero = self.icmp(IntPredicate::IntEQ, den, zero); + self.trap_if(denom_is_zero, "i128_sdiv_zero"); + + let min_i128 = self.const_u128(1u128 << 127); + let neg_one_i128 = self.const_u128(u128::MAX); + let num_is_min = self.icmp(IntPredicate::IntEQ, num, min_i128); + let den_is_neg_one = self.icmp(IntPredicate::IntEQ, den, neg_one_i128); + let overflow_case = self.and(num_is_min, den_is_neg_one); + self.trap_if(overflow_case, "i128_sdiv_overflow"); + + let args = [num, den]; + let quot = self.call_compiler_builtin("__divti3", self.type_i128(), &args); + let rem = self.call_compiler_builtin("__modti3", self.type_i128(), &args); + + (quot, rem) + } + + pub(super) fn emulate_i128_sdiv(&mut self, num: &'ll Value, den: &'ll Value) -> &'ll Value { + self.emulate_i128_sdivrem(num, den).0 + } + + pub(super) fn emulate_i128_srem(&mut self, num: &'ll Value, den: &'ll Value) -> &'ll Value { + 
self.emulate_i128_sdivrem(num, den).1 + } + + pub(super) fn emulate_i128_add_with_overflow( + &mut self, + lhs: &'ll Value, + rhs: &'ll Value, + signed: bool, + ) -> (&'ll Value, &'ll Value) { + let i64_ty = self.type_i64(); + let (lhs_lo, lhs_hi) = self.split_i128(lhs); + let (rhs_lo, rhs_hi) = self.split_i128(rhs); + + let (sum_lo, carry_lo) = self.uadd_with_overflow_i64(lhs_lo, rhs_lo); + let carry_lo_ext = self.zext(carry_lo, i64_ty); + + let (sum_hi_temp, carry_hi1) = self.uadd_with_overflow_i64(lhs_hi, rhs_hi); + let (sum_hi, carry_hi2) = self.uadd_with_overflow_i64(sum_hi_temp, carry_lo_ext); + + let unsigned_overflow = self.or(carry_hi1, carry_hi2); + + let overflow_flag = if signed { + let zero = self.const_u64(0); + let lhs_neg = self.icmp(IntPredicate::IntSLT, lhs_hi, zero); + let rhs_neg = self.icmp(IntPredicate::IntSLT, rhs_hi, zero); + let res_neg = self.icmp(IntPredicate::IntSLT, sum_hi, zero); + let same_sign = self.icmp(IntPredicate::IntEQ, lhs_neg, rhs_neg); + let sign_diff = self.icmp(IntPredicate::IntNE, lhs_neg, res_neg); + self.and(same_sign, sign_diff) + } else { + unsigned_overflow + }; + + let result = self.combine_i128(sum_lo, sum_hi); + (result, overflow_flag) + } + + pub(super) fn emulate_i128_sub_with_overflow( + &mut self, + lhs: &'ll Value, + rhs: &'ll Value, + signed: bool, + ) -> (&'ll Value, &'ll Value) { + let i64_ty = self.type_i64(); + let (lhs_lo, lhs_hi) = self.split_i128(lhs); + let (rhs_lo, rhs_hi) = self.split_i128(rhs); + + let (diff_lo, borrow_lo) = self.usub_with_overflow_i64(lhs_lo, rhs_lo); + let borrow_lo_ext = self.zext(borrow_lo, i64_ty); + + let (diff_hi_temp, borrow_hi1) = self.usub_with_overflow_i64(lhs_hi, rhs_hi); + let (diff_hi, borrow_hi2) = self.usub_with_overflow_i64(diff_hi_temp, borrow_lo_ext); + + let unsigned_overflow = self.or(borrow_hi1, borrow_hi2); + + let overflow_flag = if signed { + let zero = self.const_u64(0); + let lhs_neg = self.icmp(IntPredicate::IntSLT, lhs_hi, zero); + let rhs_neg = self.icmp(IntPredicate::IntSLT, rhs_hi, zero); + let res_neg = self.icmp(IntPredicate::IntSLT, diff_hi, zero); + let lhs_diff_rhs = self.icmp(IntPredicate::IntNE, lhs_neg, rhs_neg); + let res_diff_lhs = self.icmp(IntPredicate::IntNE, res_neg, lhs_neg); + self.and(lhs_diff_rhs, res_diff_lhs) + } else { + unsigned_overflow + }; + + let result = self.combine_i128(diff_lo, diff_hi); + (result, overflow_flag) + } + + pub(super) fn emulate_i128_mul_with_overflow( + &mut self, + lhs: &'ll Value, + rhs: &'ll Value, + signed: bool, + ) -> (&'ll Value, &'ll Value) { + let i64_ty = self.type_i64(); + let (lhs_lo, lhs_hi) = self.split_i128(lhs); + let (rhs_lo, rhs_hi) = self.split_i128(rhs); + + let (ll_lo, ll_hi) = self.mul_u64_to_u128(lhs_lo, rhs_lo); + let (lh_lo, lh_hi) = self.mul_u64_to_u128(lhs_lo, rhs_hi); + let (hl_lo, hl_hi) = self.mul_u64_to_u128(lhs_hi, rhs_lo); + let (hh_lo, hh_hi) = self.mul_u64_to_u128(lhs_hi, rhs_hi); + + let (mid_temp, carry_mid1) = self.uadd_with_overflow_i64(ll_hi, lh_lo); + let (mid_sum, carry_mid2) = self.uadd_with_overflow_i64(mid_temp, hl_lo); + let carry_mid1_ext = self.zext(carry_mid1, i64_ty); + let carry_mid2_ext = self.zext(carry_mid2, i64_ty); + let mid_carry_sum = self.add(carry_mid1_ext, carry_mid2_ext); + + let result = self.combine_i128(ll_lo, mid_sum); + + let (upper_temp0, carry_upper1) = self.uadd_with_overflow_i64(lh_hi, hl_hi); + let (upper_temp1, carry_upper2) = self.uadd_with_overflow_i64(upper_temp0, hh_lo); + let (upper_low, carry_upper3) = self.uadd_with_overflow_i64(upper_temp1, 
mid_carry_sum); + + let carry_upper1_ext = self.zext(carry_upper1, i64_ty); + let carry_upper2_ext = self.zext(carry_upper2, i64_ty); + let carry_upper3_ext = self.zext(carry_upper3, i64_ty); + let carries_sum = self.add(carry_upper1_ext, carry_upper2_ext); + let carries_sum = self.add(carries_sum, carry_upper3_ext); + + let (upper_high, carry_upper4) = self.uadd_with_overflow_i64(hh_hi, carries_sum); + + let zero = self.const_u64(0); + let upper_low_nonzero = self.icmp(IntPredicate::IntNE, upper_low, zero); + let upper_high_nonzero = self.icmp(IntPredicate::IntNE, upper_high, zero); + let upper_nonzero = self.or(upper_low_nonzero, upper_high_nonzero); + let unsigned_overflow = self.or(upper_nonzero, carry_upper4); + + let overflow_flag = if signed { + let max = self.const_u64(u64::MAX); + let (_, res_hi) = self.split_i128(result); + let res_neg = self.icmp(IntPredicate::IntSLT, res_hi, zero); + + let upper_low_zero = self.icmp(IntPredicate::IntEQ, upper_low, zero); + let upper_high_zero = self.icmp(IntPredicate::IntEQ, upper_high, zero); + let high_all_zero = self.and(upper_low_zero, upper_high_zero); + + let upper_low_ones = self.icmp(IntPredicate::IntEQ, upper_low, max); + let upper_high_ones = self.icmp(IntPredicate::IntEQ, upper_high, max); + let high_all_ones = self.and(upper_low_ones, upper_high_ones); + + let not_high_all_zero = self.not(high_all_zero); + let overflow_pos = self.or(not_high_all_zero, carry_upper4); + let not_high_all_ones = self.not(high_all_ones); + let overflow_neg = self.or(not_high_all_ones, carry_upper4); + + self.select(res_neg, overflow_neg, overflow_pos) + } else { + unsigned_overflow + }; + + (result, overflow_flag) + } + + pub(super) fn emulate_i128_shl(&mut self, val: &'ll Value, shift: &'ll Value) -> &'ll Value { + let val = self.ensure_i128(val); + let shift = self.intcast(shift, self.cx.type_i32(), false); + let args = [val, shift]; + self.call_compiler_builtin("__ashlti3", self.type_i128(), &args) + } + + pub(super) fn emulate_i128_lshr(&mut self, val: &'ll Value, shift: &'ll Value) -> &'ll Value { + let val = self.ensure_i128(val); + let shift = self.intcast(shift, self.cx.type_i32(), false); + let args = [val, shift]; + self.call_compiler_builtin("__lshrti3", self.type_i128(), &args) + } + + pub(super) fn emulate_i128_ashr(&mut self, val: &'ll Value, shift: &'ll Value) -> &'ll Value { + let val = self.ensure_i128(val); + let shift = self.intcast(shift, self.cx.type_i32(), false); + let args = [val, shift]; + self.call_compiler_builtin("__ashrti3", self.type_i128(), &args) + } + + // Emulate 128-bit bitwise NOT + pub(super) fn emulate_i128_not(&mut self, val: &'ll Value) -> &'ll Value { + let (lo, hi) = self.split_i128(val); + + let not_lo = unsafe { llvm::LLVMBuildNot(self.llbuilder, lo, UNNAMED) }; + let not_hi = unsafe { llvm::LLVMBuildNot(self.llbuilder, hi, UNNAMED) }; + + self.combine_i128(not_lo, not_hi) + } + + // Emulate 128-bit negation (two's complement) + pub(super) fn emulate_i128_neg(&mut self, val: &'ll Value) -> &'ll Value { + // Two's complement: ~val + 1 + let not_val = self.emulate_i128_not(val); + let one = self.const_u128(1); + self.emulate_i128_add(not_val, one) + } + + pub(crate) fn emulate_i128_bswap(&mut self, val: &'ll Value) -> &'ll Value { + // Split the 128-bit value into two 64-bit halves + let (lo, hi) = self.split_i128(val); + + // Byte-swap each 64-bit half using the LLVM intrinsic (which exists in LLVM 7.1) + let swapped_lo = self.call_intrinsic("llvm.bswap.i64", &[lo]); + let swapped_hi = 
self.call_intrinsic("llvm.bswap.i64", &[hi]); + + // Swap the halves: the high part becomes low and vice versa + self.combine_i128(swapped_hi, swapped_lo) + } + + pub(crate) fn emulate_i128_count_zeros( + &mut self, + val: &'ll Value, + kind: CountZerosKind, + is_nonzero: bool, + ) -> &'ll Value { + // Split the 128-bit value into two 64-bit halves + let (lo, hi) = self.split_i128(val); + + match kind { + CountZerosKind::Leading => { + // Count leading zeros: check high part first + let hi_is_zero = self.icmp(IntPredicate::IntEQ, hi, self.const_u64(0)); + let hi_ctlz = + self.call_intrinsic("llvm.ctlz.i64", &[hi, self.const_bool(is_nonzero)]); + let lo_ctlz = + self.call_intrinsic("llvm.ctlz.i64", &[lo, self.const_bool(is_nonzero)]); + + // If high part is zero, result is 64 + ctlz(lo), otherwise ctlz(hi) + let lo_ctlz_plus_64 = self.add(lo_ctlz, self.const_u64(64)); + let result_64 = self.select(hi_is_zero, lo_ctlz_plus_64, hi_ctlz); + + // Zero-extend to i128 + self.zext(result_64, self.type_i128()) + } + CountZerosKind::Trailing => { + // Count trailing zeros: check low part first + let lo_is_zero = self.icmp(IntPredicate::IntEQ, lo, self.const_u64(0)); + let lo_cttz = + self.call_intrinsic("llvm.cttz.i64", &[lo, self.const_bool(is_nonzero)]); + let hi_cttz = + self.call_intrinsic("llvm.cttz.i64", &[hi, self.const_bool(is_nonzero)]); + + // If low part is zero, result is 64 + cttz(hi), otherwise cttz(lo) + let hi_cttz_plus_64 = self.add(hi_cttz, self.const_u64(64)); + let result_64 = self.select(lo_is_zero, hi_cttz_plus_64, lo_cttz); + + // Zero-extend to i128 + self.zext(result_64, self.type_i128()) + } + } + } + + pub(crate) fn emulate_i128_ctpop(&mut self, val: &'ll Value) -> &'ll Value { + // Split the 128-bit value into two 64-bit halves + let (lo, hi) = self.split_i128(val); + + // Count population (number of 1 bits) in each half + let lo_popcount = self.call_intrinsic("llvm.ctpop.i64", &[lo]); + let hi_popcount = self.call_intrinsic("llvm.ctpop.i64", &[hi]); + + // Add the two counts + let total_64 = self.add(lo_popcount, hi_popcount); + + // Zero-extend to i128 + self.zext(total_64, self.type_i128()) + } + + pub(crate) fn emulate_i128_rotate( + &mut self, + val: &'ll Value, + shift: &'ll Value, + is_left: bool, + ) -> &'ll Value { + // Rotate is implemented as: (val << shift) | (val >> (128 - shift)) + // For rotate right: (val >> shift) | (val << (128 - shift)) + + // Ensure shift is i128 + let shift_128 = if self.val_ty(shift) == self.type_i128() { + shift + } else { + self.zext(shift, self.type_i128()) + }; + + // Calculate 128 - shift for the complementary shift + let bits_128 = self.const_u128(128); + let shift_complement = self.sub(bits_128, shift_128); + + // Perform the two shifts + let (first_shift, second_shift) = if is_left { + (self.shl(val, shift_128), self.lshr(val, shift_complement)) + } else { + (self.lshr(val, shift_128), self.shl(val, shift_complement)) + }; + + // Combine with OR + self.or(first_shift, second_shift) + } + + pub(crate) fn emulate_i128_bitreverse(&mut self, val: &'ll Value) -> &'ll Value { + // Split the 128-bit value into two 64-bit halves + let (lo, hi) = self.split_i128(val); + + // Reverse bits in each half using the 64-bit intrinsic + let reversed_lo = self.call_intrinsic("llvm.bitreverse.i64", &[lo]); + let reversed_hi = self.call_intrinsic("llvm.bitreverse.i64", &[hi]); + + // Swap the halves: reversed high becomes low and vice versa + self.combine_i128(reversed_hi, reversed_lo) + } +} diff --git 
a/crates/rustc_codegen_nvvm/src/ctx_intrinsics.rs b/crates/rustc_codegen_nvvm/src/ctx_intrinsics.rs index 56e45053..d0464212 100644 --- a/crates/rustc_codegen_nvvm/src/ctx_intrinsics.rs +++ b/crates/rustc_codegen_nvvm/src/ctx_intrinsics.rs @@ -14,7 +14,6 @@ impl<'ll> CodegenCx<'ll, '_> { #[rustfmt::skip] // stop rustfmt from making this 2k lines pub(crate) fn build_intrinsics_map(&mut self) { let mut map = self.intrinsics_map.borrow_mut(); - let mut remapped = self.remapped_integer_args.borrow_mut(); macro_rules! ifn { ($map:expr, $($name:literal)|*, fn($($arg:expr),*) -> $ret:expr) => { @@ -24,9 +23,6 @@ impl<'ll> CodegenCx<'ll, '_> { }; } - let real_t_i128 = self.type_i128(); - let real_t_i128_i1 = self.type_struct(&[real_t_i128, self.type_i1()], false); - let i8p = self.type_i8p(); let void = self.type_void(); let i1 = self.type_i1(); @@ -34,7 +30,6 @@ impl<'ll> CodegenCx<'ll, '_> { let t_i16 = self.type_i16(); let t_i32 = self.type_i32(); let t_i64 = self.type_i64(); - let t_i128 = self.type_vector(t_i64, 2); let t_f32 = self.type_f32(); let t_f64 = self.type_f64(); let t_isize = self.type_isize(); @@ -43,7 +38,6 @@ impl<'ll> CodegenCx<'ll, '_> { let t_i16_i1 = self.type_struct(&[t_i16, i1], false); let t_i32_i1 = self.type_struct(&[t_i32, i1], false); let t_i64_i1 = self.type_struct(&[t_i64, i1], false); - let t_i128_i1 = self.type_struct(&[t_i128, i1], false); let voidp = self.voidp(); @@ -75,34 +69,6 @@ impl<'ll> CodegenCx<'ll, '_> { ifn!(map, "llvm.umul.with.overflow.i32", fn(t_i32, t_i32) -> t_i32_i1); ifn!(map, "llvm.umul.with.overflow.i64", fn(t_i64, t_i64) -> t_i64_i1); - let i128_checked_binops = [ - "__nvvm_i128_addo", - "__nvvm_u128_addo", - "__nvvm_i128_subo", - "__nvvm_u128_subo", - "__nvvm_i128_mulo", - "__nvvm_u128_mulo" - ]; - - for binop in i128_checked_binops { - map.insert(binop, (vec![t_i128, t_i128], t_i128_i1)); - let llfn_ty = self.type_func(&[t_i128, t_i128], t_i128_i1); - remapped.insert(llfn_ty, (Some(real_t_i128_i1), vec![(0, real_t_i128), (1, real_t_i128)])); - } - - let i128_saturating_ops = [ - "llvm.sadd.sat.i128", - "llvm.uadd.sat.i128", - "llvm.ssub.sat.i128", - "llvm.usub.sat.i128", - ]; - - for binop in i128_saturating_ops { - map.insert(binop, (vec![t_i128, t_i128], t_i128)); - let llfn_ty = self.type_func(&[t_i128, t_i128], t_i128); - remapped.insert(llfn_ty, (Some(real_t_i128), vec![(0, real_t_i128), (1, real_t_i128)])); - } - // for some very strange reason, they arent supported for i8 either, but that case // is easy to handle and we declare our own functions for that which just // zext to i16, use the i16 intrinsic, then trunc back to i8 @@ -115,34 +81,6 @@ impl<'ll> CodegenCx<'ll, '_> { ifn!(map, "__nvvm_i8_mulo", fn(t_i8, t_i8) -> t_i8_i1); ifn!(map, "__nvvm_u8_mulo", fn(t_i8, t_i8) -> t_i8_i1); - // i128 arithmetic operations from compiler-builtins - // Division and remainder - ifn!(map, "__nvvm_divti3", fn(t_i128, t_i128) -> t_i128); - ifn!(map, "__nvvm_udivti3", fn(t_i128, t_i128) -> t_i128); - ifn!(map, "__nvvm_modti3", fn(t_i128, t_i128) -> t_i128); - ifn!(map, "__nvvm_umodti3", fn(t_i128, t_i128) -> t_i128); - - // Multiplication - ifn!(map, "__nvvm_multi3", fn(t_i128, t_i128) -> t_i128); - - // Shift operations - ifn!(map, "__nvvm_ashlti3", fn(t_i128, t_i32) -> t_i128); - ifn!(map, "__nvvm_ashrti3", fn(t_i128, t_i32) -> t_i128); - ifn!(map, "__nvvm_lshrti3", fn(t_i128, t_i32) -> t_i128); - - // Add remapping for i128 binary operations (division, remainder, multiplication) - // All have the same signature: (i128, i128) -> i128 - let 
i128_binary_llfn_ty = self.type_func(&[t_i128, t_i128], t_i128); - remapped.insert(i128_binary_llfn_ty, (Some(real_t_i128), vec![(0, real_t_i128), (1, real_t_i128)])); - - // Add remapping for i128 shift operations - // All have the same signature: (i128, i32) -> i128 - let i128_shift_llfn_ty = self.type_func(&[t_i128, t_i32], t_i128); - remapped.insert(i128_shift_llfn_ty, (Some(real_t_i128), vec![(0, real_t_i128), (1, t_i32)])); - - // see comment in libintrinsics.ll - // ifn!(map, "__nvvm_i128_trap", fn(t_i128, t_i128) -> t_i128); - ifn!(map, "llvm.sadd.sat.i8", fn(t_i8, t_i8) -> t_i8); ifn!(map, "llvm.sadd.sat.i16", fn(t_i16, t_i16) -> t_i16); ifn!(map, "llvm.sadd.sat.i32", fn(t_i32, t_i32) -> t_i32); diff --git a/crates/rustc_codegen_nvvm/src/intrinsic.rs b/crates/rustc_codegen_nvvm/src/intrinsic.rs index 0a805b89..7718d072 100644 --- a/crates/rustc_codegen_nvvm/src/intrinsic.rs +++ b/crates/rustc_codegen_nvvm/src/intrinsic.rs @@ -428,6 +428,15 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let pair = self.insert_value(pair, low, 0); self.insert_value(pair, high, 1) } + sym::unchecked_shl => self.shl(args[0].immediate(), args[1].immediate()), + sym::unchecked_shr => { + let lhs = args[0].immediate(); + let rhs = args[1].immediate(); + match arg_tys[0].kind() { + ty::Int(_) => self.ashr(lhs, rhs), + _ => self.lshr(lhs, rhs), + } + } sym::ctlz | sym::ctlz_nonzero | sym::cttz diff --git a/examples/cuda/i128_demo/Cargo.toml b/examples/cuda/i128_demo/Cargo.toml new file mode 100644 index 00000000..4e4f373b --- /dev/null +++ b/examples/cuda/i128_demo/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "i128_demo" +version = "0.1.0" +edition = "2024" + +[dependencies] +cust = { path = "../../../crates/cust" } + +[build-dependencies] +cuda_builder = { workspace = true, default-features = false } diff --git a/examples/cuda/i128_demo/build.rs b/examples/cuda/i128_demo/build.rs new file mode 100644 index 00000000..c87b560c --- /dev/null +++ b/examples/cuda/i128_demo/build.rs @@ -0,0 +1,18 @@ +use std::env; +use std::path::PathBuf; + +use cuda_builder::CudaBuilder; + +fn main() { + println!("cargo::rerun-if-changed=build.rs"); + println!("cargo::rerun-if-changed=kernels"); + + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + + CudaBuilder::new(manifest_dir.join("kernels")) + .copy_to(out_path.join("kernels.ptx")) + .final_module_path(out_path.join("final_module.ll")) + .build() + .unwrap(); +} diff --git a/examples/cuda/i128_demo/kernels/Cargo.toml b/examples/cuda/i128_demo/kernels/Cargo.toml new file mode 100644 index 00000000..58edca9d --- /dev/null +++ b/examples/cuda/i128_demo/kernels/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "i128-demo-kernels" +version = "0.1.0" +edition = "2024" + +[dependencies] +cuda_std = { path = "../../../../crates/cuda_std" } + +[lib] +crate-type = ["cdylib", "rlib"] diff --git a/examples/cuda/i128_demo/kernels/src/lib.rs b/examples/cuda/i128_demo/kernels/src/lib.rs new file mode 100644 index 00000000..027b649b --- /dev/null +++ b/examples/cuda/i128_demo/kernels/src/lib.rs @@ -0,0 +1,48 @@ +#![no_std] + +use cuda_std::prelude::*; + +#[kernel] +#[allow(improper_ctypes_definitions, clippy::missing_safety_doc)] +pub unsafe fn i128_ops( + a: &[u128], + b: &[u128], + add_out: *mut u128, + sub_out: *mut u128, + mul_out: *mut u128, + and_out: *mut u128, + xor_out: *mut u128, + shl_out: *mut u128, + lshr_out: *mut u128, + ashr_out: *mut u128, + 
udiv_out: *mut u128, + sdiv_out: *mut u128, + urem_out: *mut u128, + srem_out: *mut u128, +) { + let idx = thread::index_1d() as usize; + if idx >= a.len() || idx >= b.len() { + return; + } + + let av = a[idx]; + let bv = b[idx]; + let shift = (bv & 127) as u32; + let signed = av as i128; + let signed_b = bv as i128; + + unsafe { + *add_out.add(idx) = av.wrapping_add(bv); + *sub_out.add(idx) = av.wrapping_sub(bv); + *mul_out.add(idx) = av.wrapping_mul(bv); + *and_out.add(idx) = av & bv; + *xor_out.add(idx) = av ^ bv; + *shl_out.add(idx) = av.wrapping_shl(shift); + *lshr_out.add(idx) = av.wrapping_shr(shift); + *ashr_out.add(idx) = (signed.wrapping_shr(shift)) as u128; + *udiv_out.add(idx) = av / bv; + *sdiv_out.add(idx) = (signed / signed_b) as u128; + *urem_out.add(idx) = av % bv; + *srem_out.add(idx) = (signed % signed_b) as u128; + } +} diff --git a/examples/cuda/i128_demo/src/main.rs b/examples/cuda/i128_demo/src/main.rs new file mode 100644 index 00000000..3ed9f92d --- /dev/null +++ b/examples/cuda/i128_demo/src/main.rs @@ -0,0 +1,293 @@ +use cust::prelude::*; +use std::error::Error; + +static PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/kernels.ptx")); + +const INPUT_PAIRS: [(u128, u128); 23] = [ + // Basic non-zero divisor sanity check. + (0, 3), + // Simple add/sub with small + patterned mask divisor. + (1, 0x1111_1111_1111_1111_1111_1111_1111_1111), + // Max magnitude dividend exercising unsigned wraparound with patterned divisor. + (u128::MAX, 0x2222_2222_2222_2222_2222_2222_2222_2222), + // Near max positive signed plus odd offset to stress signed division. + ( + (1u128 << 127) + 123_456_789, + 0x3333_3333_3333_3333_3333_3333_3333_3333, + ), + // Mixed hex patterns to ensure bitwise ops and shifts propagate carries between halves. + ( + 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef, + 0x4444_4444_4444_4444_4444_4444_4444_4444, + ), + // Alternating pattern stressing xor/and/or combinations across words. + ( + 0xfedc_ba98_7654_3210_fedc_ba98_7654_3210, + 0x5555_5555_5555_5555_5555_5555_5555_5555, + ), + // Low-half mask to hit low→high shift carries. + ( + 0x0000_ffff_0000_ffff_0000_ffff_0000_ffff, + 0x6666_6666_6666_6666_6666_6666_6666_6666, + ), + // Random-looking pattern to detect misplaced limb ordering. + ( + 0xabcd_ef12_3456_789a_bcde_f012_3456_789a, + 0x7777_7777_7777_7777_7777_7777_7777_7777, + ), + // Pure power-of-two vs small divisor for shifts and div edge cases. + (1u128 << 127, 5), + // Signed overflow boundary vs max unsigned divisor. + (i128::MAX as u128, u128::MAX), + // Distinct power-of-two limbs to check cross-term multiplications. + (1u128 << 64, (1u128 << 63) + 1), + // Odd division with high-bit divisor to stress udiv/sdiv paths. + (u128::MAX / 3, 0x8000_0000_0000_0000_0000_0000_0000_0001), + // Near-overflow positive dividends paired with non power-of-two divisors. + (0x7fff_ffff_ffff_ffff_0000_0000_0000_0001, u128::MAX / 5), + // Signed negative boundary mixed with patterned divisor. + ( + 0x8000_0000_0000_0000_8000_0000_0000_0000, + 0xffff_ffff_0000_0000_ffff_ffff_0000_0001, + ), + // Arbitrary large magnitudes to sanity check arithmetic stability. + ( + 123_456_789_012_345_678_901_234_567_890u128, + (1u128 << 127) - 3, + ), + // Values near u128::MAX to ensure carry/borrow propagation. + (u128::MAX - 999_999_999, u128::MAX - 2), + // Exercises the lower-half cross-carry path in the mul emulation. + ( + 0x0000_0000_0000_0000_ffff_ffff_ffff_ffff, + 0x0000_0000_0000_0000_ffff_ffff_ffff_ffff, + ), + // Check emulated mul path edgecase. 
+ (u128::MAX, u128::MAX), + // Shift exactly 64 bits with positive divisor. + ( + 0x0123_4567_89ab_cdef_fedc_ba98_7654_3210, + (1u128 << 120) | 64, + ), + // Shift exactly 64 bits with negative divisor to stress signed paths. + ( + 0xfedc_ba98_7654_3210_0123_4567_89ab_cdef, + (1u128 << 127) | 64, + ), + // Shift just below the limb boundary. + ( + 0xaaaa_aaaa_5555_5555_ffff_ffff_0000_0000, + (1u128 << 96) | 63, + ), + // Shift just above the limb boundary. + ( + 0x0001_0203_0405_0607_0809_0a0b_0c0d_0e0f, + (1u128 << 80) | 65, + ), + // Maximum masked shift amount with high-bit divisor. + ( + 0xffff_0000_0000_0000_ffff_0000_0000_0001, + (1u128 << 127) | 127, + ), +]; + +fn main() -> Result<(), Box<dyn Error>> { + let _ctx = cust::quick_init()?; + + let module = Module::from_ptx(PTX, &[])?; + let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?; + let kernel = module.get_function("i128_ops")?; + + let (host_a, host_b): (Vec<u128>, Vec<u128>) = INPUT_PAIRS.iter().copied().unzip(); + + let len = host_a.len(); + assert_eq!(len, host_b.len()); + + let a_gpu = DeviceBuffer::from_slice(&host_a)?; + let b_gpu = DeviceBuffer::from_slice(&host_b)?; + + let add_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let sub_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let mul_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let and_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let xor_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let shl_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let lshr_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let ashr_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let udiv_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let sdiv_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let urem_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + let srem_gpu = DeviceBuffer::from_slice(&vec![0u128; len])?; + + let block_size = 128u32; + let grid_size = (len as u32).div_ceil(block_size); + + unsafe { + launch!( + kernel<<<grid_size, block_size, 0, stream>>>( + a_gpu.as_device_ptr(), + a_gpu.len(), + b_gpu.as_device_ptr(), + b_gpu.len(), + add_gpu.as_device_ptr(), + sub_gpu.as_device_ptr(), + mul_gpu.as_device_ptr(), + and_gpu.as_device_ptr(), + xor_gpu.as_device_ptr(), + shl_gpu.as_device_ptr(), + lshr_gpu.as_device_ptr(), + ashr_gpu.as_device_ptr(), + udiv_gpu.as_device_ptr(), + sdiv_gpu.as_device_ptr(), + urem_gpu.as_device_ptr(), + srem_gpu.as_device_ptr() + ) + )?; + } + + stream.synchronize()?; + + let mut gpu_add = vec![0u128; len]; + let mut gpu_sub = vec![0u128; len]; + let mut gpu_mul = vec![0u128; len]; + let mut gpu_and = vec![0u128; len]; + let mut gpu_xor = vec![0u128; len]; + let mut gpu_shl = vec![0u128; len]; + let mut gpu_lshr = vec![0u128; len]; + let mut gpu_ashr = vec![0u128; len]; + let mut gpu_udiv = vec![0u128; len]; + let mut gpu_sdiv = vec![0u128; len]; + let mut gpu_urem = vec![0u128; len]; + let mut gpu_srem = vec![0u128; len]; + + add_gpu.copy_to(&mut gpu_add)?; + sub_gpu.copy_to(&mut gpu_sub)?; + mul_gpu.copy_to(&mut gpu_mul)?; + and_gpu.copy_to(&mut gpu_and)?; + xor_gpu.copy_to(&mut gpu_xor)?; + shl_gpu.copy_to(&mut gpu_shl)?; + lshr_gpu.copy_to(&mut gpu_lshr)?; + ashr_gpu.copy_to(&mut gpu_ashr)?; + udiv_gpu.copy_to(&mut gpu_udiv)?; + sdiv_gpu.copy_to(&mut gpu_sdiv)?; + urem_gpu.copy_to(&mut gpu_urem)?; + srem_gpu.copy_to(&mut gpu_srem)?; + + let mut cpu_add = vec![0u128; len]; + let mut cpu_sub = vec![0u128; len]; + let mut cpu_mul = vec![0u128; len]; + let mut cpu_and = vec![0u128; len]; + let mut cpu_xor = vec![0u128; len]; + let mut 
cpu_shl = vec![0u128; len]; + let mut cpu_lshr = vec![0u128; len]; + let mut cpu_ashr = vec![0u128; len]; + let mut cpu_udiv = vec![0u128; len]; + let mut cpu_sdiv = vec![0u128; len]; + let mut cpu_urem = vec![0u128; len]; + let mut cpu_srem = vec![0u128; len]; + + for (i, (&av, &bv)) in host_a.iter().zip(host_b.iter()).enumerate() { + let shift = (bv & 127) as u32; + let signed = av as i128; + let signed_b = bv as i128; + cpu_add[i] = av.wrapping_add(bv); + cpu_sub[i] = av.wrapping_sub(bv); + cpu_mul[i] = av.wrapping_mul(bv); + cpu_and[i] = av & bv; + cpu_xor[i] = av ^ bv; + cpu_shl[i] = av.wrapping_shl(shift); + cpu_lshr[i] = av.wrapping_shr(shift); + cpu_ashr[i] = (signed.wrapping_shr(shift)) as u128; + cpu_udiv[i] = av / bv; + cpu_sdiv[i] = (signed / signed_b) as u128; + cpu_urem[i] = av % bv; + cpu_srem[i] = (signed % signed_b) as u128; + } + + let mut all_ok = true; + all_ok &= compare_results("add", &gpu_add, &cpu_add); + all_ok &= compare_results("sub", &gpu_sub, &cpu_sub); + all_ok &= compare_results("mul", &gpu_mul, &cpu_mul); + all_ok &= compare_results("and", &gpu_and, &cpu_and); + all_ok &= compare_results("xor", &gpu_xor, &cpu_xor); + all_ok &= compare_results("shl", &gpu_shl, &cpu_shl); + all_ok &= compare_results("lshr", &gpu_lshr, &cpu_lshr); + all_ok &= compare_results("ashr", &gpu_ashr, &cpu_ashr); + all_ok &= compare_results("udiv", &gpu_udiv, &cpu_udiv); + all_ok &= compare_results("sdiv", &gpu_sdiv, &cpu_sdiv); + all_ok &= compare_results("urem", &gpu_urem, &cpu_urem); + all_ok &= compare_results("srem", &gpu_srem, &cpu_srem); + + if !all_ok { + return Err("Mismatch between GPU and CPU i128 results".into()); + } + + // Ensure signed overflow (`i128::MIN / -1`) traps on the device. + let trap_stream = Stream::new(StreamFlags::NON_BLOCKING, None)?; + let trap_a = DeviceBuffer::from_slice(&[i128::MIN as u128])?; + let trap_b = DeviceBuffer::from_slice(&[u128::MAX])?; + let trap_add = DeviceBuffer::from_slice(&[0u128])?; + let trap_sub = DeviceBuffer::from_slice(&[0u128])?; + let trap_mul = DeviceBuffer::from_slice(&[0u128])?; + let trap_and = DeviceBuffer::from_slice(&[0u128])?; + let trap_xor = DeviceBuffer::from_slice(&[0u128])?; + let trap_shl = DeviceBuffer::from_slice(&[0u128])?; + let trap_lshr = DeviceBuffer::from_slice(&[0u128])?; + let trap_ashr = DeviceBuffer::from_slice(&[0u128])?; + let trap_udiv = DeviceBuffer::from_slice(&[0u128])?; + let trap_sdiv = DeviceBuffer::from_slice(&[0u128])?; + let trap_urem = DeviceBuffer::from_slice(&[0u128])?; + let trap_srem = DeviceBuffer::from_slice(&[0u128])?; + + let trap_launch = unsafe { + launch!( + kernel<<<1u32, 1u32, 0, trap_stream>>>( + trap_a.as_device_ptr(), + trap_a.len(), + trap_b.as_device_ptr(), + trap_b.len(), + trap_add.as_device_ptr(), + trap_sub.as_device_ptr(), + trap_mul.as_device_ptr(), + trap_and.as_device_ptr(), + trap_xor.as_device_ptr(), + trap_shl.as_device_ptr(), + trap_lshr.as_device_ptr(), + trap_ashr.as_device_ptr(), + trap_udiv.as_device_ptr(), + trap_sdiv.as_device_ptr(), + trap_urem.as_device_ptr(), + trap_srem.as_device_ptr() + ) + ) + }; + + let trap_result = match trap_launch { + Ok(()) => trap_stream.synchronize(), + Err(e) => Err(e), + }; + + match trap_result { + Err(e) => println!("Correctly got expected trap for i128::MIN / -1: {e}"), + Ok(()) => return Err("Expected trap for i128::MIN / -1 not triggered".into()), + } + + println!("All i128 GPU results match CPU computations."); + Ok(()) +} + +fn compare_results(name: &str, gpu: &[u128], cpu: &[u128]) -> bool { + let mut ok = true; 
+ for (idx, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() { + if g != c { + println!("[{name}] mismatch at index {idx}: gpu={g:#034x}, cpu={c:#034x}"); + ok = false; + } + } + + if ok { + println!("[{name}] results match"); + } + + ok +}
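The trickiest piece of the emulation is the 64x64 -> 128-bit product that mul_u64_to_u128 assembles from 32-bit partial products. A hedged host-side sketch of the same reassembly, checked against Rust's native widening multiply (mul_u64_wide is an illustrative name, not part of the patch):

// Host-side sketch of the partial-product scheme used by mul_u64_to_u128:
// split each operand into 32-bit halves, form the four partial products, and
// fold the cross terms into low/high 64-bit limbs while tracking carries.
fn mul_u64_wide(lhs: u64, rhs: u64) -> (u64, u64) {
    let (a_lo, a_hi) = (lhs & 0xFFFF_FFFF, lhs >> 32);
    let (b_lo, b_hi) = (rhs & 0xFFFF_FFFF, rhs >> 32);

    let p0 = a_lo * b_lo; // contributes to bits 0..64
    let p1 = a_lo * b_hi; // contributes to bits 32..96
    let p2 = a_hi * b_lo; // contributes to bits 32..96
    let p3 = a_hi * b_hi; // contributes to bits 64..128

    // Sum the cross terms, remembering the carry that escapes 64 bits.
    let (cross, cross_carry) = p1.overflowing_add(p2);
    // Low limb: p0 plus the low 32 bits of the cross sum shifted into place.
    let (lo, lo_carry) = p0.overflowing_add(cross << 32);
    // High limb: p3 plus the high 32 bits of the cross sum and both carries;
    // this sum equals the true high limb, so it cannot overflow a u64.
    let hi = p3 + (cross >> 32) + ((cross_carry as u64) << 32) + lo_carry as u64;

    (lo, hi)
}

fn main() {
    let (a, b) = (0xdead_beef_cafe_f00d_u64, 0xffff_ffff_ffff_fff1_u64);
    let wide = (a as u128) * (b as u128);
    assert_eq!(mul_u64_wide(a, b), (wide as u64, (wide >> 64) as u64));
}
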
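The signed overflow flags in emulate_i128_add_with_overflow likewise reduce to a sign rule rather than a wider intermediate type: overflow happens exactly when both operands share a sign and the wrapped result does not. A small host-side check of that rule against checked_add (signed_add_overflows is an illustrative name, not part of the patch):

// Signed add overflow rule mirrored by emulate_i128_add_with_overflow:
// overflow occurs exactly when both operands share a sign and the wrapped
// result's sign differs from theirs.
fn signed_add_overflows(a: i128, b: i128) -> bool {
    let sum = a.wrapping_add(b);
    let same_sign = (a < 0) == (b < 0);
    same_sign && ((a < 0) != (sum < 0))
}

fn main() {
    let cases = [(i128::MAX, 1), (i128::MIN, -1), (-5, 3), (i128::MAX, i128::MIN)];
    for (a, b) in cases {
        // The sign rule must agree with the standard library's checked_add.
        assert_eq!(signed_add_overflows(a, b), a.checked_add(b).is_none());
    }
}
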