diff --git a/utils/databake/derive/src/lib.rs b/utils/databake/derive/src/lib.rs index 74a01339e0c..eb0c8b372a4 100644 --- a/utils/databake/derive/src/lib.rs +++ b/utils/databake/derive/src/lib.rs @@ -6,6 +6,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; +use quote::format_ident; use quote::quote; use syn::{ parse::{Parse, ParseStream}, @@ -31,6 +32,40 @@ use synstructure::{AddBounds, Structure}; /// pub age: u32, /// } /// ``` +/// +/// # Custom baked type +/// +/// To bake to a different type than this, use `custom_bake` +/// and implement `CustomBake`. +/// +/// ```rust +/// use databake::Bake; +/// use databake::CustomBake; +/// +/// #[derive(Bake)] +/// #[databake(path = bar::module)] +/// #[databake(path = custom_bake)] +/// pub struct Message<'a> { +/// pub message: &'a str, +/// } +/// +/// // Bake to a string: +/// impl CustomBake for Message<'_> { +/// type BakedType<'a> = &'a str where Self: 'a; +/// fn to_custom_bake(&self) -> Self::BakedType<'_> { +/// &self.message +/// } +/// } +/// +/// impl<'a> Message<'a> { +/// pub fn from_custom_bake(message: &'a str) -> Self { +/// Self { message } +/// } +/// } +/// ``` +/// +/// If the constructor is unsafe, use `custom_bake_unsafe` +/// and implement `CustomBakeUnsafe`. #[proc_macro_derive(Bake, attributes(databake))] pub fn bake_derive(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -40,44 +75,89 @@ pub fn bake_derive(input: TokenStream) -> TokenStream { fn bake_derive_impl(input: &DeriveInput) -> TokenStream2 { let mut structure = Structure::new(input); - struct PathAttr(Punctuated); + enum DatabakeAttr { + Path(Punctuated), + CustomBake, + CustomBakeUnsafe, + } - impl Parse for PathAttr { + impl Parse for DatabakeAttr { fn parse(input: ParseStream<'_>) -> syn::parse::Result { let i: Ident = input.parse()?; - if i != "path" { - return Err(input.error(format!("expected token \"path\", found {i:?}"))); + if i == "path" { + input.parse::()?; + Ok(Self::Path(input.parse::()?.segments)) + } else if i == "custom_bake" { + Ok(Self::CustomBake) + } else if i == "custom_bake_unsafe" { + Ok(Self::CustomBakeUnsafe) + } else { + Err(input.error(format!("expected token \"path\", found {i:?}"))) } - input.parse::()?; - Ok(Self(input.parse::()?.segments)) } } - let path = input + let attrs = input .attrs .iter() - .find(|a| a.path().is_ident("databake")) - .expect("missing databake(path = ...) attribute") - .parse_args::() - .unwrap() - .0; + .filter(|a| a.path().is_ident("databake")) + .map(|a| a.parse_args::().unwrap()) + .collect::>(); - let bake_body = structure.each_variant(|vi| { - let recursive_calls = vi.bindings().iter().map(|b| { - let ident = b.binding.clone(); - quote! { let #ident = #ident.bake(env); } - }); + let path = attrs + .iter() + .filter_map(|a| match a { + DatabakeAttr::Path(path) => Some(path), + _ => None, + }) + .next() + .expect("missing databake(path = ...) attribute"); - let constructor = vi.construct(|_, i| { - let ident = &vi.bindings()[i].binding; - quote! { # #ident } - }); + let is_custom_bake = attrs.iter().any(|a| matches!(a, DatabakeAttr::CustomBake)); - quote! { - #(#recursive_calls)* - databake::quote! { #path::#constructor } + let is_custom_bake_unsafe = attrs + .iter() + .any(|a| matches!(a, DatabakeAttr::CustomBakeUnsafe)); + + let bake_body = if is_custom_bake || is_custom_bake_unsafe { + let type_ident = &structure.ast().ident; + let baked_ident = format_ident!("baked"); + if is_custom_bake_unsafe { + quote! { + x => { + let baked = databake::CustomBake::to_custom_bake(x).bake(env); + databake::quote! { + // Safety: the bake is generated from `CustomBakeUnsafe::to_custom_bake` + unsafe { #path::#type_ident::from_custom_bake(##baked_ident) } + } + } + } + } else { + quote! { + x => { + let baked = databake::CustomBake::to_custom_bake(x).bake(env); + databake::quote! { #path::#type_ident::from_custom_bake(##baked_ident) } + } + } } - }); + } else { + structure.each_variant(|vi| { + let recursive_calls = vi.bindings().iter().map(|b| { + let ident = b.binding.clone(); + quote! { let #ident = #ident.bake(env); } + }); + + let constructor = vi.construct(|_, i| { + let ident = &vi.bindings()[i].binding; + quote! { # #ident } + }); + + quote! { + #(#recursive_calls)* + databake::quote! { #path::#constructor } + } + }) + }; let borrows_size_body = structure.each_variant(|vi| { let recursive_calls = vi.bindings().iter().map(|b| { diff --git a/utils/databake/derive/tests/derive.rs b/utils/databake/derive/tests/derive.rs index fb9be2bdac0..8365fea2424 100644 --- a/utils/databake/derive/tests/derive.rs +++ b/utils/databake/derive/tests/derive.rs @@ -60,3 +60,90 @@ fn test_cow_example() { test ); } + +#[derive(Bake)] +#[databake(path = test)] +#[databake(custom_bake)] +pub struct CustomBakeExample<'a> { + x: usize, + y: alloc::borrow::Cow<'a, str>, +} + +impl CustomBake for CustomBakeExample<'_> { + type BakedType<'a> + = (usize, &'a str) + where + Self: 'a; + fn to_custom_bake(&self) -> Self::BakedType<'_> { + (self.x, &*self.y) + } +} + +impl<'a> CustomBakeExample<'a> { + pub const fn from_custom_bake(baked: ::BakedType<'a>) -> Self { + Self { + x: baked.0, + y: alloc::borrow::Cow::Borrowed(baked.1), + } + } +} + +#[test] +fn test_custom_bake_example() { + test_bake!( + CustomBakeExample<'static>, + const, + crate::CustomBakeExample { + x: 51423usize, + y: alloc::borrow::Cow::Borrowed("bar"), + }, + crate::CustomBakeExample::from_custom_bake((51423usize, "bar")), + test + ); +} + +#[derive(Bake)] +#[databake(path = test)] +#[databake(custom_bake_unsafe)] +pub struct CustomBakeUnsafeExample<'a> { + x: usize, + y: alloc::borrow::Cow<'a, str>, +} + +impl CustomBake for CustomBakeUnsafeExample<'_> { + type BakedType<'a> + = (usize, &'a str) + where + Self: 'a; + fn to_custom_bake(&self) -> Self::BakedType<'_> { + (self.x, &*self.y) + } +} + +// Safety: The type has a `from_custom_bake` fn of the required signature. +unsafe impl CustomBakeUnsafe for CustomBakeExample<'_> {} + +impl<'a> CustomBakeUnsafeExample<'a> { + /// # Safety + /// The argument MUST have been returned from [`Self::to_custom_bake`]. + pub const unsafe fn from_custom_bake(baked: ::BakedType<'a>) -> Self { + Self { + x: baked.0, + y: alloc::borrow::Cow::Borrowed(baked.1), + } + } +} + +#[test] +fn test_custom_bake_unsafe_example() { + test_bake!( + CustomBakeUnsafeExample<'static>, + const, + crate::CustomBakeUnsafeExample { + x: 51423usize, + y: alloc::borrow::Cow::Borrowed("bar"), + }, + unsafe { crate::CustomBakeUnsafeExample::from_custom_bake((51423usize, "bar")) }, + test + ); +} diff --git a/utils/databake/src/custom_bake.rs b/utils/databake/src/custom_bake.rs new file mode 100644 index 00000000000..5582a90e493 --- /dev/null +++ b/utils/databake/src/custom_bake.rs @@ -0,0 +1,39 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use crate::Bake; + +/// A trait for an item that can bake to something other than itself. +/// +/// For an unsafe version of this trait, see [`CustomBakeUnsafe`]. +/// +/// The type implementing this trait should have an associated function +/// with the following signature: +/// +/// ```ignore +/// /// The argument should have been returned from [`Self::to_custom_bake`]. +/// pub fn from_custom_bake(baked: CustomBake::BakedType) -> Self +/// ``` +pub trait CustomBake { + /// The type of the custom bake. + type BakedType<'a>: Bake + where + Self: 'a; + /// Returns `self` as the custom bake type. + fn to_custom_bake(&self) -> Self::BakedType<'_>; +} + +/// Same as [`CustomBake`] but allows for the constructor to be `unsafe`. +/// +/// # Safety +/// +/// The type implementing this trait MUST have an associated unsafe function +/// with the following signature: +/// +/// ```ignore +/// /// # Safety +/// /// The argument MUST have been returned from [`Self::to_custom_bake`]. +/// pub unsafe fn from_custom_bake(baked: CustomBakeUnsafe::BakedType) -> Self +/// ``` +pub unsafe trait CustomBakeUnsafe: CustomBake {} diff --git a/utils/databake/src/lib.rs b/utils/databake/src/lib.rs index ce55b7c7372..8d8afd3d5bd 100644 --- a/utils/databake/src/lib.rs +++ b/utils/databake/src/lib.rs @@ -74,6 +74,7 @@ mod alloc; pub mod converter; +mod custom_bake; mod primitives; #[doc(no_inline)] @@ -85,6 +86,8 @@ pub use quote::quote; #[cfg(feature = "derive")] pub use databake_derive::Bake; +pub use custom_bake::*; + use std::collections::HashSet; use std::sync::Mutex; @@ -145,6 +148,27 @@ pub trait BakeSize: Sized + Bake { /// test_bake!(usize, const, 18usize); /// ``` /// +/// ## Custom baked type +/// +/// If the baked type is different than the input type, pass both as arguments: +/// +/// ```no_run +/// # use databake::*; +/// # struct MyStruct(usize); +/// # impl Bake for MyStruct { +/// # fn bake(&self, _: &CrateEnv) -> TokenStream { unimplemented!() } +/// # } +/// # // We need an explicit main to put the struct at the crate root +/// # fn main() { +/// test_bake!( +/// MyStruct, +/// crate::MyStruct(42usize), +/// crate::MyStruct::from_custom_bake(42usize), +/// my_crate, +/// ); +/// # } +/// ``` +/// /// ## Crates and imports /// /// As most output will need to reference its crate, and its not possible to name a crate from @@ -172,16 +196,24 @@ pub trait BakeSize: Sized + Bake { #[macro_export] macro_rules! test_bake { ($type:ty, const, $expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => { - const _: &$type = &$expr; - $crate::test_bake!($type, $expr $(, $krate)? $(, [$($env_crate),+])?); + $crate::test_bake!($type, const, $expr, $expr $(, $krate)? $(, [$($env_crate),+])?); }; ($type:ty, $expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => { + $crate::test_bake!($type, $expr, $expr $(, $krate)? $(, [$($env_crate),+])?); + }; + + ($type:ty, const, $init_expr:expr, $baked_expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => { + const _: &$type = &$baked_expr; + $crate::test_bake!($type, $init_expr, $baked_expr $(, $krate)? $(, [$($env_crate),+])?); + }; + + ($type:ty, $init_expr:expr, $baked_expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => { let env = Default::default(); - let expr: &$type = &$expr; + let expr: &$type = &$init_expr; let bake = $crate::Bake::bake(expr, &env).to_string(); // For some reason `TokenStream` behaves differently in this line - let expected_bake = $crate::quote!($expr).to_string().replace("::<", ":: <").replace(">::", "> ::"); + let expected_bake = $crate::quote!($baked_expr).to_string().replace("::<", ":: <").replace(">::", "> ::"); // Trailing commas are a mess as well let bake = bake.replace(" ,)", ")").replace(" ,]", "]").replace(" , }", " }").replace(" , >", " >"); let expected_bake = expected_bake.replace(" ,)", ")").replace(" ,]", "]").replace(" , }", " }").replace(" , >", " >");