Skip to content

Add a compiler intrinsic to back bigint_helper_methods #133663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Override carrying_mul_add in cg_llvm
  • Loading branch information
scottmcm committed Dec 27, 2024
commit 4669c0d756ddfbd3df0ee1d5c7a4b1cdaabf5945
31 changes: 31 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,37 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
self.const_i32(cache_type),
])
}
sym::carrying_mul_add => {
let (size, signed) = fn_args.type_at(0).int_size_and_signed(self.tcx);

let wide_llty = self.type_ix(size.bits() * 2);
let args = args.as_array().unwrap();
let [a, b, c, d] = args.map(|a| self.intcast(a.immediate(), wide_llty, signed));

let wide = if signed {
let prod = self.unchecked_smul(a, b);
let acc = self.unchecked_sadd(prod, c);
self.unchecked_sadd(acc, d)
} else {
let prod = self.unchecked_umul(a, b);
let acc = self.unchecked_uadd(prod, c);
self.unchecked_uadd(acc, d)
};

let narrow_llty = self.type_ix(size.bits());
let low = self.trunc(wide, narrow_llty);
let bits_const = self.const_uint(wide_llty, size.bits());
// No need for ashr when signed; LLVM changes it to lshr anyway.
let high = self.lshr(wide, bits_const);
// FIXME: could be `trunc nuw`, even for signed.
let high = self.trunc(high, narrow_llty);

let pair_llty = self.type_struct(&[narrow_llty, narrow_llty], false);
let pair = self.const_poison(pair_llty);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just comparing to my version, this doesn't really matter that much, but const_undef feels slightly more accurate here just semantically. Not actually sure if this affects codegen though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poison is always preferred where possible, because it doesn't have the same complications as undef. (undef << 1) & 1 is guaranteed to be 0, but (poison << 1) & 1 is still poison. You'll also see that what we emit for tuples and what the optimizer does for tuples both use poison for this: https://rust.godbolt.org/z/WTnW4GPE6

let pair = self.insert_value(pair, low, 0);
let pair = self.insert_value(pair, high, 1);
pair
}
sym::ctlz
| sym::ctlz_nonzero
| sym::cttz
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#![feature(iter_intersperse)]
#![feature(let_chains)]
#![feature(rustdoc_internals)]
#![feature(slice_as_array)]
#![feature(try_blocks)]
#![warn(unreachable_pub)]
// tidy-alphabetical-end
Expand Down
4 changes: 2 additions & 2 deletions library/core/src/intrinsics/fallback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ impl const CarryingMulAdd for i128 {
fn carrying_mul_add(self, b: i128, c: i128, d: i128) -> (u128, i128) {
let (low, high) = wide_mul_u128(self as u128, b as u128);
let mut high = high as i128;
high = high.wrapping_add((self >> 127) * b);
high = high.wrapping_add(self * (b >> 127));
high = high.wrapping_add(i128::wrapping_mul(self >> 127, b));
high = high.wrapping_add(i128::wrapping_mul(self, b >> 127));
let (low, carry) = u128::overflowing_add(low, c as u128);
high = high.wrapping_add((carry as i128) + (c >> 127));
let (low, carry) = u128::overflowing_add(low, d as u128);
Expand Down
10 changes: 10 additions & 0 deletions library/core/tests/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ fn carrying_mul_add_fallback_i32() {

#[test]
fn carrying_mul_add_fallback_u128() {
assert_eq!(fallback_cma::<u128>(u128::MAX, u128::MAX, 0, 0), (1, u128::MAX - 1));
assert_eq!(fallback_cma::<u128>(1, 1, 1, 1), (3, 0));
assert_eq!(fallback_cma::<u128>(0, 0, u128::MAX, u128::MAX), (u128::MAX - 1, 1));
assert_eq!(
Expand All @@ -178,8 +179,17 @@ fn carrying_mul_add_fallback_u128() {

#[test]
fn carrying_mul_add_fallback_i128() {
    // (-1) * (-1) + 0 + 0 == 1: product fits entirely in the low half.
    assert_eq!(fallback_cma::<i128>(-1, -1, 0, 0), (1, 0));
    // 1 + (-1) + (-1) == -1: all-ones low half, high half -1.
    assert_eq!(fallback_cma::<i128>(-1, -1, -1, -1), (u128::MAX, -1));
    // (-1) + 1 + 1 == 1.
    assert_eq!(fallback_cma::<i128>(1, -1, 1, 1), (1, 0));
    // Extreme operands: MAX*MAX + MAX + MAX == 2^254 - 1,
    // i.e. low = u128::MAX and high = 2^126 - 1 == i128::MAX / 2.
    let max_case = fallback_cma::<i128>(i128::MAX, i128::MAX, i128::MAX, i128::MAX);
    assert_eq!(max_case, (u128::MAX, i128::MAX / 2));
    // MIN*MIN + MAX + MAX == 2^254 + 2^128 - 2,
    // i.e. low = u128::MAX - 1 and high = 2^126 == -(i128::MIN / 2).
    let min_case = fallback_cma::<i128>(i128::MIN, i128::MIN, i128::MAX, i128::MAX);
    assert_eq!(min_case, (u128::MAX - 1, -(i128::MIN / 2)));
}
137 changes: 137 additions & 0 deletions tests/codegen/intrinsics/carrying_mul_add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//@ revisions: RAW OPT
//@ compile-flags: -C opt-level=1
//@[RAW] compile-flags: -C no-prepopulate-passes
//@[OPT] min-llvm-version: 19

#![crate_type = "lib"]
#![feature(core_intrinsics)]
#![feature(core_intrinsics_fallbacks)]

// Note that LLVM seems to sometimes permute the order of arguments to mul and add,
// so these tests don't check the arguments in the optimized revision.

use std::intrinsics::{carrying_mul_add, fallback};

// The fallbacks are emitted even when they're never used, but optimize out.

// RAW: wide_mul_u128
// OPT-NOT: wide_mul_u128

// CHECK-LABEL: @cma_u8
#[no_mangle]
pub unsafe fn cma_u8(a: u8, b: u8, c: u8, d: u8) -> (u8, u8) {
    // Verifies the unsigned lowering: all four operands are zero-extended to
    // the double-width i16, multiplied/accumulated with `nuw`, and the i16
    // result is split back into (low, high) i8 halves via trunc + lshr.
    // CHECK: [[A:%.+]] = zext i8 %a to i16
    // CHECK: [[B:%.+]] = zext i8 %b to i16
    // CHECK: [[C:%.+]] = zext i8 %c to i16
    // CHECK: [[D:%.+]] = zext i8 %d to i16
    // CHECK: [[AB:%.+]] = mul nuw i16
    // RAW-SAME: [[A]], [[B]]
    // CHECK: [[ABC:%.+]] = add nuw i16
    // RAW-SAME: [[AB]], [[C]]
    // CHECK: [[ABCD:%.+]] = add nuw i16
    // RAW-SAME: [[ABC]], [[D]]
    // CHECK: [[LOW:%.+]] = trunc i16 [[ABCD]] to i8
    // CHECK: [[HIGHW:%.+]] = lshr i16 [[ABCD]], 8
    // The `trunc nuw` only appears after optimization (see the FIXME in cg_llvm).
    // RAW: [[HIGH:%.+]] = trunc i16 [[HIGHW]] to i8
    // OPT: [[HIGH:%.+]] = trunc nuw i16 [[HIGHW]] to i8
    // CHECK: [[PAIR0:%.+]] = insertvalue { i8, i8 } poison, i8 [[LOW]], 0
    // CHECK: [[PAIR1:%.+]] = insertvalue { i8, i8 } [[PAIR0]], i8 [[HIGH]], 1
    // OPT: ret { i8, i8 } [[PAIR1]]
    carrying_mul_add(a, b, c, d)
}

// CHECK-LABEL: @cma_u32
#[no_mangle]
pub unsafe fn cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) {
    // Same shape as cma_u8, but widening u32 -> i64: zext all operands,
    // mul/add with `nuw`, then split the i64 into (low, high) i32 halves.
    // CHECK: [[A:%.+]] = zext i32 %a to i64
    // CHECK: [[B:%.+]] = zext i32 %b to i64
    // CHECK: [[C:%.+]] = zext i32 %c to i64
    // CHECK: [[D:%.+]] = zext i32 %d to i64
    // CHECK: [[AB:%.+]] = mul nuw i64
    // RAW-SAME: [[A]], [[B]]
    // CHECK: [[ABC:%.+]] = add nuw i64
    // RAW-SAME: [[AB]], [[C]]
    // CHECK: [[ABCD:%.+]] = add nuw i64
    // RAW-SAME: [[ABC]], [[D]]
    // CHECK: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32
    // CHECK: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32
    // RAW: [[HIGH:%.+]] = trunc i64 [[HIGHW]] to i32
    // OPT: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32
    // CHECK: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0
    // CHECK: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1
    // OPT: ret { i32, i32 } [[PAIR1]]
    carrying_mul_add(a, b, c, d)
}

// CHECK-LABEL: @cma_u128
// The (u128, u128) pair is too big to return in registers, so it comes back
// through an sret pointer argument.
// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d
#[no_mangle]
pub unsafe fn cma_u128(a: u128, b: u128, c: u128, d: u128) -> (u128, u128) {
    // Widening u128 -> i256 still works; the raw IR builds an insertvalue
    // pair, while the optimized IR stores the two halves straight into the
    // sret slot instead.
    // CHECK: [[A:%.+]] = zext i128 %a to i256
    // CHECK: [[B:%.+]] = zext i128 %b to i256
    // CHECK: [[C:%.+]] = zext i128 %c to i256
    // CHECK: [[D:%.+]] = zext i128 %d to i256
    // CHECK: [[AB:%.+]] = mul nuw i256
    // RAW-SAME: [[A]], [[B]]
    // CHECK: [[ABC:%.+]] = add nuw i256
    // RAW-SAME: [[AB]], [[C]]
    // CHECK: [[ABCD:%.+]] = add nuw i256
    // RAW-SAME: [[ABC]], [[D]]
    // CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128
    // CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128
    // RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128
    // OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128
    // RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0
    // RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1
    // OPT: store i128 [[LOW]], ptr %_0
    // OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, {{i32|i64}} 16
    // OPT: store i128 [[HIGH]], ptr [[P1]]
    // CHECK: ret void
    carrying_mul_add(a, b, c, d)
}

// CHECK-LABEL: @cma_i128
// Like cma_u128 the pair is returned via sret.
// CHECK-SAME: sret{{.+}}dereferenceable(32){{.+}}%_0,{{.+}}%a,{{.+}}%b,{{.+}}%c,{{.+}}%d
#[no_mangle]
pub unsafe fn cma_i128(a: i128, b: i128, c: i128, d: i128) -> (u128, i128) {
    // Signed lowering differs from the unsigned one only in sext instead of
    // zext and `nsw` instead of `nuw`; the high half is still extracted with
    // lshr (not ashr) since only the low 128 bits of the shift result survive
    // the trunc.
    // CHECK: [[A:%.+]] = sext i128 %a to i256
    // CHECK: [[B:%.+]] = sext i128 %b to i256
    // CHECK: [[C:%.+]] = sext i128 %c to i256
    // CHECK: [[D:%.+]] = sext i128 %d to i256
    // CHECK: [[AB:%.+]] = mul nsw i256
    // RAW-SAME: [[A]], [[B]]
    // CHECK: [[ABC:%.+]] = add nsw i256
    // RAW-SAME: [[AB]], [[C]]
    // CHECK: [[ABCD:%.+]] = add nsw i256
    // RAW-SAME: [[ABC]], [[D]]
    // CHECK: [[LOW:%.+]] = trunc i256 [[ABCD]] to i128
    // CHECK: [[HIGHW:%.+]] = lshr i256 [[ABCD]], 128
    // RAW: [[HIGH:%.+]] = trunc i256 [[HIGHW]] to i128
    // OPT: [[HIGH:%.+]] = trunc nuw i256 [[HIGHW]] to i128
    // RAW: [[PAIR0:%.+]] = insertvalue { i128, i128 } poison, i128 [[LOW]], 0
    // RAW: [[PAIR1:%.+]] = insertvalue { i128, i128 } [[PAIR0]], i128 [[HIGH]], 1
    // OPT: store i128 [[LOW]], ptr %_0
    // OPT: [[P1:%.+]] = getelementptr inbounds i8, ptr %_0, {{i32|i64}} 16
    // OPT: store i128 [[HIGH]], ptr [[P1]]
    // CHECK: ret void
    carrying_mul_add(a, b, c, d)
}

// CHECK-LABEL: @fallback_cma_u32
#[no_mangle]
pub unsafe fn fallback_cma_u32(a: u32, b: u32, c: u32, d: u32) -> (u32, u32) {
    // Calls the portable fallback trait impl directly (not the intrinsic) and
    // checks that, after optimization, it produces the same widening IR as
    // the intrinsic lowering. Only the OPT revision is checked, and with -DAG
    // because the optimizer may reorder these instructions.
    // OPT-DAG: [[A:%.+]] = zext i32 %a to i64
    // OPT-DAG: [[B:%.+]] = zext i32 %b to i64
    // OPT-DAG: [[AB:%.+]] = mul nuw i64
    // OPT-DAG: [[C:%.+]] = zext i32 %c to i64
    // OPT-DAG: [[ABC:%.+]] = add nuw i64{{.+}}[[C]]
    // OPT-DAG: [[D:%.+]] = zext i32 %d to i64
    // OPT-DAG: [[ABCD:%.+]] = add nuw i64{{.+}}[[D]]
    // OPT-DAG: [[LOW:%.+]] = trunc i64 [[ABCD]] to i32
    // OPT-DAG: [[HIGHW:%.+]] = lshr i64 [[ABCD]], 32
    // OPT-DAG: [[HIGH:%.+]] = trunc nuw i64 [[HIGHW]] to i32
    // OPT-DAG: [[PAIR0:%.+]] = insertvalue { i32, i32 } poison, i32 [[LOW]], 0
    // OPT-DAG: [[PAIR1:%.+]] = insertvalue { i32, i32 } [[PAIR0]], i32 [[HIGH]], 1
    // OPT-DAG: ret { i32, i32 } [[PAIR1]]
    fallback::CarryingMulAdd::carrying_mul_add(a, b, c, d)
}
Loading