Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ either = "1.8.0"
indexmap = "2.6.0"
rspirv = "0.12"
rustc_codegen_spirv-types.workspace = true
spirv-std-types = { workspace = true, features = ["std"] }
rustc-demangle = "0.1.21"
sanitize-filename = "0.6.0"
smallvec = { version = "1.6.1", features = ["const_generics", "const_new", "union"] }
Expand Down
339 changes: 336 additions & 3 deletions crates/rustc_codegen_spirv/src/attr.rs

Large diffs are not rendered by default.

348 changes: 14 additions & 334 deletions crates/rustc_codegen_spirv/src/symbols.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion crates/spirv-std/macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ workspace = true
proc-macro = true

[dependencies]
spirv-std-types.workspace = true
spirv-std-types = { workspace = true, features = ["std"] }
proc-macro2 = "1.0.24"
quote = "1.0.8"
syn = { version = "2.0.90", features = ["full", "visit-mut"] }
97 changes: 77 additions & 20 deletions crates/spirv-std/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@
mod image;

use proc_macro::TokenStream;
use proc_macro2::{Delimiter, Group, Span, TokenTree};
use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};

use syn::{ImplItemFn, visit_mut::VisitMut};

use quote::{ToTokens, quote};
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
use std::fmt::Write;

/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
Expand Down Expand Up @@ -144,45 +145,101 @@ pub fn Image(item: TokenStream) -> TokenStream {
/// `#[cfg_attr(target_arch="spirv", rust_gpu::spirv(..))]`.
#[proc_macro_attribute]
pub fn spirv(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut tokens: Vec<TokenTree> = Vec::new();
let spirv = format_ident!("{}", &spirv_attr_with_version());

// prepend with #[rust_gpu::spirv(..)]
let attr: proc_macro2::TokenStream = attr.into();
tokens.extend(quote! { #[cfg_attr(target_arch="spirv", rust_gpu::spirv(#attr))] });
let mut tokens = quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] };

let item: proc_macro2::TokenStream = item.into();
for tt in item {
match tt {
TokenTree::Group(group) if group.delimiter() == Delimiter::Parenthesis => {
let mut sub_tokens = Vec::new();
let mut group_tokens = proc_macro2::TokenStream::new();
let mut last_token_hashtag = false;
for tt in group.stream() {
let is_token_hashtag =
matches!(&tt, TokenTree::Punct(punct) if punct.as_char() == '#');
match tt {
TokenTree::Group(group)
if group.delimiter() == Delimiter::Bracket
&& matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv")
&& matches!(sub_tokens.last(), Some(TokenTree::Punct(p)) if p.as_char() == '#') =>
&& last_token_hashtag
&& matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
{
// group matches [spirv ...]
let inner = group.stream(); // group stream doesn't include the brackets
sub_tokens.extend(
quote! { [cfg_attr(target_arch="spirv", rust_gpu::#inner)] },
// group stream doesn't include the brackets
let inner = group
.stream()
.into_iter()
.skip(1)
.collect::<proc_macro2::TokenStream>();
group_tokens.extend(
quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] },
);
}
_ => sub_tokens.push(tt),
_ => group_tokens.append(tt),
}
last_token_hashtag = is_token_hashtag;
}
tokens.push(TokenTree::from(Group::new(
Delimiter::Parenthesis,
sub_tokens.into_iter().collect(),
)));
let mut out = Group::new(Delimiter::Parenthesis, group_tokens);
out.set_span(group.span());
tokens.append(out);
}
_ => tokens.push(tt),
_ => tokens.append(tt),
}
}
tokens
.into_iter()
.collect::<proc_macro2::TokenStream>()
.into()
tokens.into()
}

/// For testing only! Is not reexported in `spirv-std`, but reachable via
/// `spirv_std::macros::spirv_recursive_for_testing`.
///
/// May be more expensive than plain `spirv`, since we're checking a lot more symbols. So I've opted to
/// have this be a separate macro, instead of modifying the standard `spirv` one.
#[proc_macro_attribute]
pub fn spirv_recursive_for_testing(attr: TokenStream, item: TokenStream) -> TokenStream {
fn recurse(spirv: &Ident, stream: proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let mut last_token_hashtag = false;
stream.into_iter().map(|tt| {
let mut is_token_hashtag = false;
let out = match tt {
TokenTree::Group(group)
if group.delimiter() == Delimiter::Bracket
&& last_token_hashtag
&& matches!(group.stream().into_iter().next(), Some(TokenTree::Ident(ident)) if ident == "spirv") =>
{
// group matches [spirv ...]
// group stream doesn't include the brackets
let inner = group
.stream()
.into_iter()
.skip(1)
.collect::<proc_macro2::TokenStream>();
quote! { [cfg_attr(target_arch="spirv", rust_gpu::#spirv #inner)] }
},
TokenTree::Group(group) => {
let mut out = Group::new(group.delimiter(), recurse(spirv, group.stream()));
out.set_span(group.span());
TokenTree::Group(out).into()
},
TokenTree::Punct(punct) => {
is_token_hashtag = punct.as_char() == '#';
TokenTree::Punct(punct).into()
}
tt => tt.into(),
};
last_token_hashtag = is_token_hashtag;
out
}).collect()
}

let attr: proc_macro2::TokenStream = attr.into();
let item: proc_macro2::TokenStream = item.into();

// prepend with #[rust_gpu::spirv(..)]
let spirv = format_ident!("{}", &spirv_attr_with_version());
let inner = recurse(&spirv, item);
quote! { #[cfg_attr(target_arch="spirv", rust_gpu::#spirv(#attr))] #inner }.into()
}

/// Marks a function as runnable only on the GPU, and will panic on
Expand Down
5 changes: 4 additions & 1 deletion crates/spirv-std/shared/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "spirv-std-types"
description = "SPIR-V types shared between spirv-std and spirv-std-macros"
description = "SPIR-V types shared between spirv-std, spirv-std-macros and rustc_codegen_spirv"
version.workspace = true
authors.workspace = true
edition.workspace = true
Expand All @@ -9,3 +9,6 @@ repository.workspace = true

[lints]
workspace = true

[features]
std = []
4 changes: 3 additions & 1 deletion crates/spirv-std/shared/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# `spirv-std-types`

Small shared crate, to share definitions between [`spirv-std`](https://docs.rs/spirv-std/) and [`spirv-std-macros`](https://docs.rs/spirv-std-macros/). Please refer to [`spirv-std`](https://docs.rs/spirv-std/) for more information.
Small shared crate, to share definitions between [`spirv-std`](https://docs.rs/spirv-std/), [`spirv-std-macros`](https://docs.rs/spirv-std-macros/) and [`rustc_codegen_spirv`](https://docs.rs/rustc_codegen_spirv/). Should only contain symbols that compile in a `no_std` context.

Please refer to [`spirv-std`](https://docs.rs/spirv-std/) for more information.
4 changes: 3 additions & 1 deletion crates/spirv-std/shared/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#![doc = include_str!("../README.md")]
#![no_std]
#![cfg_attr(not(feature = "std"), no_std)]

pub mod image_params;
#[cfg(feature = "std")]
pub mod spirv_attr_version;
18 changes: 18 additions & 0 deletions crates/spirv-std/shared/src/spirv_attr_version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/// Returns the `spirv` attribute name with version tag.
///
/// The `#[spirv()]` attribute in `spirv_std` expands to this spirv attribute name,
/// including the version of `spirv_std`. `rustc_codegen_spirv` verifies that the
/// version matches with the codegen backend, to prevent accidental mismatches.
///
/// ```no_run
/// # use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
/// let spirv = spirv_attr_with_version();
/// let attr = format!("#[rust_gpu::{spirv}(vertex)]");
/// // version here may be out-of-date
/// assert_eq!("#[rust_gpu::spirv_v0_9(vertex)]", attr);
/// ```
pub fn spirv_attr_with_version() -> String {
let major: u32 = env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap();
let minor: u32 = env!("CARGO_PKG_VERSION_MINOR").parse().unwrap();
format!("spirv_v{major}_{minor}")
}
Loading