Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions library/proc_macro/src/bridge/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::marker::PhantomData;
use std::sync::atomic::AtomicU32;

use super::*;
use crate::StandaloneLevel;
use crate::bridge::server::{Dispatcher, DispatcherTrait};
use crate::bridge::standalone::NoRustc;

macro_rules! define_client_handles {
(
Expand Down Expand Up @@ -141,7 +144,10 @@ macro_rules! define_client_side {
api_tags::Method::$name(api_tags::$name::$method).encode(&mut buf, &mut ());
$($arg.encode(&mut buf, &mut ());)*

buf = bridge.dispatch.call(buf);
buf = match &mut bridge.dispatch {
DispatchWay::Closure(f) => f.call(buf),
DispatchWay::Directly(disp) => disp.dispatch(buf),
};

let r = Result::<_, PanicMessage>::decode(&mut &buf[..], &mut ());

Expand All @@ -155,13 +161,18 @@ macro_rules! define_client_side {
}
with_api!(self, self, define_client_side);

enum DispatchWay<'a> {
Closure(closure::Closure<'a, Buffer, Buffer>),
Directly(Dispatcher<NoRustc>),
}

struct Bridge<'a> {
/// Reusable buffer (only `clear`-ed, never shrunk), primarily
/// used for making requests.
cached_buffer: Buffer,

/// Server-side function that the client uses to make requests.
dispatch: closure::Closure<'a, Buffer, Buffer>,
dispatch: DispatchWay<'a>,

/// Provided globals for this macro expansion.
globals: ExpnGlobals<Span>,
Expand All @@ -173,12 +184,33 @@ impl<'a> !Sync for Bridge<'a> {}
#[allow(unsafe_code)]
mod state {
use std::cell::{Cell, RefCell};
use std::marker::PhantomData;
use std::ptr;

use super::Bridge;
use crate::StandaloneLevel;
use crate::bridge::buffer::Buffer;
use crate::bridge::client::{COUNTERS, DispatchWay};
use crate::bridge::server::{Dispatcher, HandleStore, MarkedTypes};
use crate::bridge::{ExpnGlobals, Marked, standalone};

thread_local! {
static BRIDGE_STATE: Cell<*const ()> = const { Cell::new(ptr::null()) };
static STANDALONE: RefCell<Bridge<'static>> = RefCell::new(standalone_bridge());
pub(super) static USE_STANDALONE: Cell<StandaloneLevel> = const { Cell::new(StandaloneLevel::Never) };
}

fn standalone_bridge() -> Bridge<'static> {
let mut store = HandleStore::new(&COUNTERS);
let id = store.Span.alloc(Marked { value: standalone::Span::DUMMY, _marker: PhantomData });
let dummy = super::Span { handle: id };
let dispatcher =
Dispatcher { handle_store: store, server: MarkedTypes(standalone::NoRustc) };
Bridge {
cached_buffer: Buffer::new(),
dispatch: DispatchWay::Directly(dispatcher),
globals: ExpnGlobals { call_site: dummy, def_site: dummy, mixed_site: dummy },
}
}

pub(super) fn set<'bridge, R>(state: &RefCell<Bridge<'bridge>>, f: impl FnOnce() -> R) -> R {
Expand All @@ -199,16 +231,23 @@ mod state {
pub(super) fn with<R>(
f: impl for<'bridge> FnOnce(Option<&RefCell<Bridge<'bridge>>>) -> R,
) -> R {
let state = BRIDGE_STATE.get();
// SAFETY: the only place where the pointer is set is in `set`. It puts
// back the previous value after the inner call has returned, so we know
// that as long as the pointer is not null, it came from a reference to
// a `RefCell<Bridge>` that outlasts the call to this function. Since `f`
// works the same for any lifetime of the bridge, including the actual
// one, we can lie here and say that the lifetime is `'static` without
// anyone noticing.
let bridge = unsafe { state.cast::<RefCell<Bridge<'static>>>().as_ref() };
f(bridge)
let level = USE_STANDALONE.get();
if level == StandaloneLevel::Always
|| (level == StandaloneLevel::FallbackOnly && BRIDGE_STATE.get().is_null())
{
STANDALONE.with(|bridge| f(Some(bridge)))
} else {
let state = BRIDGE_STATE.get();
// SAFETY: the only place where the pointer is set is in `set`. It puts
// back the previous value after the inner call has returned, so we know
// that as long as the pointer is not null, it came from a reference to
// a `RefCell<Bridge>` that outlasts the call to this function. Since `f`
// works the same for any lifetime of the bridge, including the actual
// one, we can lie here and say that the lifetime is `'static` without
// anyone noticing.
let bridge = unsafe { state.cast::<RefCell<Bridge<'static>>>().as_ref() };
f(bridge)
}
}
}

Expand All @@ -228,6 +267,10 @@ pub(crate) fn is_available() -> bool {
state::with(|s| s.is_some())
}

pub(crate) fn enable_standalone(level: StandaloneLevel) {
state::USE_STANDALONE.set(level);
}

/// A client-side RPC entry-point, which may be using a different `proc_macro`
/// from the one used by the server, but can be invoked compatibly.
///
Expand Down Expand Up @@ -292,7 +335,11 @@ fn run_client<A: for<'a, 's> Decode<'a, 's, ()>, R: Encode<()>>(
let (globals, input) = <(ExpnGlobals<Span>, A)>::decode(reader, &mut ());

// Put the buffer we used for input back in the `Bridge` for requests.
let state = RefCell::new(Bridge { cached_buffer: buf.take(), dispatch, globals });
let state = RefCell::new(Bridge {
cached_buffer: buf.take(),
dispatch: DispatchWay::Closure(dispatch),
globals,
});

let output = state::set(&state, || f(input));

Expand Down
1 change: 1 addition & 0 deletions library/proc_macro/src/bridge/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ mod rpc;
mod selfless_reify;
#[forbid(unsafe_code)]
pub mod server;
pub(crate) mod standalone;
#[allow(unsafe_code)]
mod symbol;

Expand Down
16 changes: 8 additions & 8 deletions library/proc_macro/src/bridge/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ macro_rules! define_server_handles {
) => {
#[allow(non_snake_case)]
pub(super) struct HandleStore<S: Types> {
$($oty: handle::OwnedStore<S::$oty>,)*
$($ity: handle::InternedStore<S::$ity>,)*
$(pub(super) $oty: handle::OwnedStore<S::$oty>,)*
$(pub(super) $ity: handle::InternedStore<S::$ity>,)*
}

impl<S: Types> HandleStore<S> {
fn new(handle_counters: &'static client::HandleCounters) -> Self {
pub(super) fn new(handle_counters: &'static client::HandleCounters) -> Self {
HandleStore {
$($oty: handle::OwnedStore::new(&handle_counters.$oty),)*
$($ity: handle::InternedStore::new(&handle_counters.$ity),)*
Expand Down Expand Up @@ -119,7 +119,7 @@ macro_rules! declare_server_traits {
}
with_api!(Self, self_, declare_server_traits);

pub(super) struct MarkedTypes<S: Types>(S);
pub(super) struct MarkedTypes<S: Types>(pub(super) S);

impl<S: Server> Server for MarkedTypes<S> {
fn globals(&mut self) -> ExpnGlobals<Self::Span> {
Expand Down Expand Up @@ -150,9 +150,9 @@ macro_rules! define_mark_types_impls {
}
with_api!(Self, self_, define_mark_types_impls);

struct Dispatcher<S: Types> {
handle_store: HandleStore<S>,
server: S,
pub(super) struct Dispatcher<S: Types> {
pub(super) handle_store: HandleStore<MarkedTypes<S>>,
pub(super) server: MarkedTypes<S>,
}

macro_rules! define_dispatcher_impl {
Expand All @@ -167,7 +167,7 @@ macro_rules! define_dispatcher_impl {
fn dispatch(&mut self, buf: Buffer) -> Buffer;
}

impl<S: Server> DispatcherTrait for Dispatcher<MarkedTypes<S>> {
impl<S: Server> DispatcherTrait for Dispatcher<S> {
$(type $name = <MarkedTypes<S> as Types>::$name;)*

fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
Expand Down
Loading
Loading