diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index e2df3265f6f7d..566877f4a1ec5 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -3,8 +3,9 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; +use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv}; use rustc_middle::{bug, ty}; +use rustc_target::callconv::PassMode; use tracing::debug; use crate::builder::{Builder, PlaceRef, UNNAMED}; @@ -16,9 +17,12 @@ use crate::value::Value; pub(crate) fn adjust_activity_to_abi<'tcx>( tcx: TyCtxt<'tcx>, - fn_ty: Ty<'tcx>, + instance: Instance<'tcx>, + typing_env: TypingEnv<'tcx>, da: &mut Vec, ) { + let fn_ty = instance.ty(tcx, typing_env); + if !matches!(fn_ty.kind(), ty::FnDef(..)) { bug!("expected fn def for autodiff, got {:?}", fn_ty); } @@ -27,8 +31,16 @@ pub(crate) fn adjust_activity_to_abi<'tcx>( // All we do is decide how to handle the arguments. let sig = fn_ty.fn_sig(tcx).skip_binder(); + // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions + let Ok(fn_abi) = + tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty()))) + else { + bug!("failed to get fn_abi of instance with empty varargs"); + }; + let mut new_activities = vec![]; let mut new_positions = vec![]; + let mut del_activities = 0; for (i, ty) in sig.inputs().iter().enumerate() { if let Some(inner_ty) = ty.builtin_deref(true) { if inner_ty.is_slice() { @@ -80,6 +92,34 @@ pub(crate) fn adjust_activity_to_abi<'tcx>( continue; } } + + let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty }; + + let layout = match tcx.layout_of(pci) { + Ok(layout) => layout.layout, + Err(_) => { + bug!("failed to compute layout for type {:?}", ty); + } + }; + + let pass_mode = &fn_abi.args[i].mode; + + // For ZST, just ignore and don't add its activity, as this arg won't be present + // in the LLVM passed to Enzyme. + // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg + // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const` + if *pass_mode == PassMode::Ignore { + del_activities += 1; + da.remove(i); + } + + // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity. + // Otherwise, the number of activities won't match the number of LLVM arguments and + // this will lead to errors when verifying the Enzyme call. + if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() { + new_activities.push(da[i].clone()); + new_positions.push(i + 1 - del_activities); + } } // now add the extra activities coming from slices // Reverse order to not invalidate the indices diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 06c3d8ed6bc2d..9e6e760649120 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1198,7 +1198,8 @@ fn codegen_autodiff<'ll, 'tcx>( adjust_activity_to_abi( tcx, - fn_source.ty(tcx, TypingEnv::fully_monomorphized()), + fn_source, + TypingEnv::fully_monomorphized(), &mut diff_attrs.input_activity, ); diff --git a/tests/codegen-llvm/autodiff/abi_handling.rs b/tests/codegen-llvm/autodiff/abi_handling.rs new file mode 100644 index 0000000000000..454ec698b917c --- /dev/null +++ b/tests/codegen-llvm/autodiff/abi_handling.rs @@ -0,0 +1,210 @@ +//@ revisions: debug release + +//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat +//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test checks that Rust types are lowered to LLVM-IR types in a way +// we expect and Enzyme can handle. We explicitly check release mode to +// ensure that LLVM's O3 pipeline doesn't rewrite function signatures +// into forms that Enzyme can't process correctly. + +#![feature(autodiff)] + +use std::autodiff::{autodiff_forward, autodiff_reverse}; + +#[derive(Copy, Clone)] +struct Input { + x: f32, + y: f32, +} + +#[derive(Copy, Clone)] +struct Wrapper { + z: f32, +} + +#[derive(Copy, Clone)] +struct NestedInput { + x: f32, + y: Wrapper, +} + +fn square(x: f32) -> f32 { + x * x +} + +// CHECK-LABEL: ; abi_handling::df1 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0) +// release-NEXT: define internal fastcc float +// release-SAME: (float %x.0.val, float %x.4.val) + +// CHECK-LABEL: ; abi_handling::f1 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (ptr align 4 %x) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float %x.0.val, float %x.4.val) +#[autodiff_forward(df1, Dual, Dual)] +#[inline(never)] +fn f1(x: &[f32; 2]) -> f32 { + x[0] + x[1] +} + +// CHECK-LABEL: ; abi_handling::df2 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (ptr %f, float %x, float %dret) +// release-NEXT: define internal fastcc float +// release-SAME: (float noundef %x) + +// CHECK-LABEL: ; abi_handling::f2 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (ptr %f, float %x) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float noundef %x) +#[autodiff_reverse(df2, Const, Active, Active)] +#[inline(never)] +fn f2(f: fn(f32) -> f32, x: f32) -> f32 { + f(x) +} + +// CHECK-LABEL: ; abi_handling::df3 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0, ptr align 4 %y, ptr align 4 %by_0) +// release-NEXT: define internal fastcc { float, float } +// release-SAME: (float %x.0.val) + +// CHECK-LABEL: ; abi_handling::f3 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (ptr align 4 %x, ptr align 4 %y) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float %x.0.val) +#[autodiff_forward(df3, Dual, Dual, Dual)] +#[inline(never)] +fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 { + *x * *y +} + +// CHECK-LABEL: ; abi_handling::df4 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (float %x.0, float %x.1, float %bx_0.0, float %bx_0.1) +// release-NEXT: define internal fastcc { float, float } +// release-SAME: (float noundef %x.0, float noundef %x.1) + +// CHECK-LABEL: ; abi_handling::f4 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (float %x.0, float %x.1) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float noundef %x.0, float noundef %x.1) +#[autodiff_forward(df4, Dual, Dual)] +#[inline(never)] +fn f4(x: (f32, f32)) -> f32 { + x.0 * x.1 +} + +// CHECK-LABEL: ; abi_handling::df5 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1) +// release-NEXT: define internal fastcc { float, float } +// release-SAME: (float noundef %i.0, float noundef %i.1) + +// CHECK-LABEL: ; abi_handling::f5 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (float %i.0, float %i.1) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float noundef %i.0, float noundef %i.1) +#[autodiff_forward(df5, Dual, Dual)] +#[inline(never)] +fn f5(i: Input) -> f32 { + i.x + i.y +} + +// CHECK-LABEL: ; abi_handling::df6 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (float %i.0, float %i.1, float %bi_0.0, float %bi_0.1) +// release-NEXT: define internal fastcc { float, float } +// release-SAME: float noundef %i.0, float noundef %i.1 +// release-SAME: float noundef %bi_0.0, float noundef %bi_0.1 + +// CHECK-LABEL: ; abi_handling::f6 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (float %i.0, float %i.1) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float noundef %i.0, float noundef %i.1) +#[autodiff_forward(df6, Dual, Dual)] +#[inline(never)] +fn f6(i: NestedInput) -> f32 { + i.x + i.y.z * i.y.z +} + +// CHECK-LABEL: ; abi_handling::df7 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal { float, float } +// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1, ptr align 4 %bx_0.0, ptr align 4 %bx_0.1) +// release-NEXT: define internal fastcc { float, float } +// release-SAME: (float %x.0.0.val, float %x.1.0.val) + +// CHECK-LABEL: ; abi_handling::f7 +// CHECK-NEXT: Function Attrs +// debug-NEXT: define internal float +// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1) +// release-NEXT: define internal fastcc noundef float +// release-SAME: (float %x.0.0.val, float %x.1.0.val) +#[autodiff_forward(df7, Dual, Dual)] +#[inline(never)] +fn f7(x: (&f32, &f32)) -> f32 { + x.0 * x.1 +} + +fn main() { + let x = std::hint::black_box(2.0); + let y = std::hint::black_box(3.0); + let z = std::hint::black_box(4.0); + static Y: f32 = std::hint::black_box(3.2); + + let in_f1 = [x, y]; + dbg!(f1(&in_f1)); + let res_f1 = df1(&in_f1, &[1.0, 0.0]); + dbg!(res_f1); + + dbg!(f2(square, x)); + let res_f2 = df2(square, x, 1.0); + dbg!(res_f2); + + dbg!(f3(&x, &Y)); + let res_f3 = df3(&x, &Y, &1.0, &0.0); + dbg!(res_f3); + + let in_f4 = (x, y); + dbg!(f4(in_f4)); + let res_f4 = df4(in_f4, (1.0, 0.0)); + dbg!(res_f4); + + let in_f5 = Input { x, y }; + dbg!(f5(in_f5)); + let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 }); + dbg!(res_f5); + + let in_f6 = NestedInput { x, y: Wrapper { z: y } }; + dbg!(f6(in_f6)); + let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } }); + dbg!(res_f6); + + let in_f7 = (&x, &y); + dbg!(f7(in_f7)); + let res_f7 = df7(in_f7, (&1.0, &0.0)); + dbg!(res_f7); +} diff --git a/tests/ui/autodiff/zst.rs b/tests/ui/autodiff/zst.rs new file mode 100644 index 0000000000000..7b9b5f5f20bdc --- /dev/null +++ b/tests/ui/autodiff/zst.rs @@ -0,0 +1,17 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +//@ build-pass + +// Check that differentiating functions with ZST args does not break + +#![feature(autodiff)] + +#[core::autodiff::autodiff_forward(fd_inner, Const, Dual)] +fn f(_zst: (), _x: &mut f64) {} + +fn fd(x: &mut f64, xd: &mut f64) { + fd_inner((), x, xd); +} + +fn main() {}