crux_macros/
capability.rs

1use darling::{ast, util, FromDeriveInput, FromField, ToTokens};
2use proc_macro2::TokenStream;
3use proc_macro_error::{abort, OptionExt};
4use quote::quote;
5use syn::{DeriveInput, GenericArgument, Ident, PathArguments, Type};
6
7#[derive(FromDeriveInput, Debug)]
8#[darling(supports(struct_named))]
9struct CapabilityStructReceiver {
10    ident: Ident,
11    data: ast::Data<util::Ignored, CapabilityFieldReceiver>,
12}
13
14#[derive(FromField, Debug)]
15pub struct CapabilityFieldReceiver {
16    ident: Option<Ident>,
17    ty: Type,
18}
19
20impl ToTokens for CapabilityStructReceiver {
21    fn to_tokens(&self, tokens: &mut TokenStream) {
22        let name = &self.ident;
23        let operation_type = self
24            .data
25            .as_ref()
26            .take_struct()
27            .expect_or_abort("should be a struct")
28            .fields
29            .iter()
30            .find(|f| f.ident.as_ref().unwrap() == "context")
31            .map(|f| first_generic_parameter(&f.ty))
32            .expect_or_abort("could not find a field named `context`");
33
34        tokens.extend(quote! {
35          impl<Ev> crux_core::capability::Capability<Ev> for #name<Ev> {
36            type Operation = #operation_type;
37            type MappedSelf<MappedEv> = #name<MappedEv>;
38
39            fn map_event<F, NewEv>(&self, f: F) -> Self::MappedSelf<NewEv>
40            where
41                F: Fn(NewEv) -> Ev + Send + Sync + 'static,
42                Ev: 'static,
43                NewEv: 'static + Send,
44            {
45              #name::new(self.context.map_event(f))
46            }
47          }
48        })
49    }
50}
51
52pub(crate) fn capability_impl(input: &DeriveInput) -> TokenStream {
53    let input = match CapabilityStructReceiver::from_derive_input(input) {
54        Ok(v) => v,
55        Err(e) => {
56            return e.write_errors();
57        }
58    };
59
60    quote!(#input)
61}
62
63fn first_generic_parameter(ty: &Type) -> Type {
64    let generic_param = match ty.clone() {
65        Type::Path(mut path) if path.qself.is_none() => {
66            // Get the last segment of the path where the generic parameters should be
67            let last = path.path.segments.last_mut().expect("type has no segments");
68            let type_params = std::mem::take(&mut last.arguments);
69
70            let first_type_parameter = match type_params {
71                PathArguments::AngleBracketed(params) => params.args.first().cloned(),
72                _ => None,
73            };
74
75            // This argument must be a type
76            match first_type_parameter {
77                Some(GenericArgument::Type(t2)) => Some(t2),
78                _ => None,
79            }
80        }
81        _ => None,
82    };
83    let Some(generic_param) = generic_param else {
84        abort!(ty, "context field type should have generic type parameters");
85    };
86    generic_param
87}
88
#[cfg(test)]
mod tests {
    use darling::FromDeriveInput;
    use quote::quote;
    use syn::{parse_str, DeriveInput, Type};

    use super::{first_generic_parameter, CapabilityStructReceiver};

    /// Snapshots the impl generated for a minimal `Render` capability.
    #[test]
    fn test_derive() {
        let source = r#"
            #[derive(Capability)]
            pub struct Render<Ev> {
              context: CapabilityContext<RenderOperation, Ev>,
            }
        "#;
        let derive_input: DeriveInput = parse_str(source).unwrap();
        let receiver = CapabilityStructReceiver::from_derive_input(&derive_input).unwrap();

        let actual = quote!(#receiver);

        insta::assert_snapshot!(pretty_print(&actual), @r###"
        impl<Ev> crux_core::capability::Capability<Ev> for Render<Ev> {
            type Operation = RenderOperation;
            type MappedSelf<MappedEv> = Render<MappedEv>;
            fn map_event<F, NewEv>(&self, f: F) -> Self::MappedSelf<NewEv>
            where
                F: Fn(NewEv) -> Ev + Send + Sync + 'static,
                Ev: 'static,
                NewEv: 'static + Send,
            {
                Render::new(self.context.map_event(f))
            }
        }
        "###);
    }

    /// The operation type is the first argument, even when it is a multi-segment path.
    #[test]
    fn test_first_generic_parameter() {
        let context_type: Type = parse_str("CapabilityContext<my_mod::MyOperation, Ev>").unwrap();

        let operation_type = first_generic_parameter(&context_type);

        let expected = quote!(my_mod::MyOperation).to_string();
        assert_eq!(quote!(#operation_type).to_string(), expected);
    }

    /// Formats generated tokens as Rust source so snapshots are readable.
    fn pretty_print(tokens: &proc_macro2::TokenStream) -> String {
        let parsed = syn::parse_file(&tokens.to_string()).unwrap();
        prettyplease::unparse(&parsed)
    }
}