Skip to content

Commit 575aa54

Browse files
committed
Add PTX calling conventions to generated LLVM IR
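Fixes a TODO. This shouldn't change the generated PTX (NVVM generates PTX according to its own calling-convention rules), but it makes the emitted LLVM IR more accurate and easier to debug.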
1 parent ac2674f commit 575aa54

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

crates/rustc_codegen_nvvm/src/context.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
317317
}
318318
}
319319

320-
/// Declare a function. All functions use the default ABI, NVVM ignores any calling convention markers.
321-
/// All functions calls are generated according to the PTX calling convention.
320+
/// Declare a function with appropriate PTX calling conventions.
322321
/// <https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#calling-conventions>
323322
pub fn declare_fn(
324323
&self,
@@ -332,8 +331,12 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
332331

333332
trace!("Declaring function `{}` with ty `{:?}`", name, ty);
334333

335-
// TODO(RDambrosio016): we should probably still generate accurate calling conv for functions
336-
// just to make it easier to debug IR and/or make it more compatible with compiling using llvm
334+
// Set PTX device calling convention for all functions declared here.
335+
// Kernel functions will have their calling convention overridden in mono_item.rs
336+
unsafe {
337+
llvm::LLVMSetFunctionCallConv(llfn, llvm::PtxCallConv::Device as u32);
338+
}
339+
337340
llvm::SetUnnamedAddress(llfn, llvm::UnnamedAddr::Global);
338341
if let Some(abi) = fn_abi {
339342
abi.apply_attrs_llfn(self, llfn);

crates/rustc_codegen_nvvm/src/llvm.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,20 @@ pub(crate) enum Visibility {
206206
Protected = 2,
207207
}
208208

209+
/// PTX/NVPTX calling conventions from LLVM
210+
/// See: <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h>
211+
///
212+
/// While NVVM doesn't strictly require these calling conventions to be set
213+
/// (it generates PTX according to its own rules), we set them anyway to
214+
/// make the generated LLVM IR more accurate and easier to debug.
215+
#[repr(u32)]
216+
pub(crate) enum PtxCallConv {
217+
/// PTX kernel calling convention
218+
Kernel = 71,
219+
/// PTX device calling convention
220+
Device = 72,
221+
}
222+
209223
/// LLVMUnnamedAddr
210224
#[repr(C)]
211225
pub(crate) enum UnnamedAddr {

crates/rustc_codegen_nvvm/src/mono_item.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
9898
// to nvvm.annotations per the nvvm ir docs.
9999
if nvvm_attrs.kernel {
100100
trace!("Marking function `{:?}` as a kernel", symbol_name);
101+
llvm::LLVMSetFunctionCallConv(lldecl, llvm::PtxCallConv::Kernel as u32);
102+
103+
// Add kernel metadata for NVVM
101104
let kernel = llvm::LLVMMDStringInContext(self.llcx, "kernel".as_ptr().cast(), 6);
102105
let mdvals = &[lldecl, kernel, self.const_i32(1)];
103106
let node =

0 commit comments

Comments
 (0)