// crux_cli/codegen/formatter.rs

1use std::collections::BTreeMap;
2
3use ascent::ascent;
4use rustdoc_types::{GenericArg, GenericArgs, Item, ItemEnum, Type};
5
6use crate::codegen::collect;
7
8use super::{
9    indexed::Indexed,
10    item::*,
11    node::ItemNode,
12    serde::case::RenameRule,
13    serde_generate::format::{ContainerFormat, Format, Named, VariantFormat},
14};
15
// Datalog program, written with the `ascent` crate, that derives serde
// container formats from a graph of rustdoc items.
//
// Input: the `edge` facts (the only relation without a deriving rule), which
// are inserted by whoever runs this program.
// Output: the `container` relation, one `ContainerFormat` per type name.
//
// `collect` is the aggregator imported from `crate::codegen`; the `make_*`
// helpers below receive its output as slices of 1-tuples
// (e.g. `&[(&Indexed<Format>,)]`).
ascent! {
    // NOTE(review): presumably enables ascent's per-rule timing instrumentation — confirm.
    #![measure_rule_times]
    pub struct Formatter;

    // ------- facts ------------------
    // Edges between item nodes. Which items are connected (and what an edge
    // means) is decided by the caller that fills this relation.
    relation edge(ItemNode, ItemNode);

    // ------- rules ------------------

    // Structs that are the source of at least one edge, classified by shape
    // using the `is_struct_*` predicates.
    relation struct_unit(ItemNode);
    struct_unit(s) <-- edge(s, _), if is_struct_unit(&s.item);

    relation struct_plain(ItemNode);
    struct_plain(s) <-- edge(s, _), if is_struct_plain(&s.item);

    relation struct_tuple(ItemNode);
    struct_tuple(s) <-- edge(s, _), if is_struct_tuple(&s.item);

    // `field(x, f)`: edge target `f` is a field of `x` (per `x.has_field(f)`).
    relation field(ItemNode, ItemNode);
    field(x, f) <-- edge(x, f), if x.has_field(f);

    // `fields(x, fs)`: all fields of `x`, collected and passed through
    // `x.fields(..)`. A field's position in this list becomes its index in
    // `make_format`, so the list is presumably put into declaration order by
    // `ItemNode::fields` — verify there.
    relation fields(ItemNode, Vec<ItemNode>);
    fields(x, fields) <--
        // Supplies `x` to the aggregation below, which collects every field of that `x`.
        field(x, f),
        agg fs = collect(f) in field(x, f),
        let fields = x.fields(fs);

    // `variant(e, v)`: edge target `v` is a variant of `e` (per `e.has_variant(v)`).
    relation variant(ItemNode, ItemNode);
    variant(e, v) <-- edge(e, v), if e.has_variant(v);

    // `variants(e, vs)`: all variants of `e`, passed through `e.variants(..)`;
    // a variant's position in this list becomes its index in the `make_*_variant_format` helpers.
    relation variants(ItemNode, Vec<ItemNode>);
    variants(e, variants) <--
        variant(e, v),
        agg vs = collect(v) in variant(e, v),
        let variants = e.variants(vs);

    // Variants classified by shape: unit-like, tuple-like and struct-like.
    relation variant_plain(ItemNode, ItemNode);
    variant_plain(e, v) <-- variant(e, v), if is_plain_variant(&v.item);

    relation variant_tuple(ItemNode, ItemNode);
    variant_tuple(e, v) <-- variant(e, v), if is_tuple_variant(&v.item);

    relation variant_struct(ItemNode, ItemNode);
    variant_struct(e, v) <-- variant(e, v), if is_struct_variant(&v.item);

    // Positional (unnamed) format of each field of `x`; consumed by tuple
    // structs and tuple variants below.
    relation format(ItemNode, Indexed<Format>);
    format(x, format) <--
        field(x, field),
        fields(x, fields),
        if let Some(format) = make_format(field, fields);

    // Named format of each field of `x` (serde renames applied); consumed by
    // plain structs and struct variants below.
    relation format_named(ItemNode, Indexed<Named<Format>>);
    format_named(x, format) <--
        field(x, field),
        fields(x, fields),
        if let Some(format) = make_named_format(field, fields, x);

    // Per-variant formats of enum `e`, one relation per variant shape.
    relation format_plain_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_plain_variant(e, format) <--
        variant_plain(e, v),
        variants(e, variants),
        if let Some(format) = make_plain_variant_format(v, variants, e);

    relation format_tuple_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_tuple_variant(e, format) <--
        variant_tuple(e, v),
        variants(e, variants),
        // The variant's own fields, as positional formats.
        agg formats = collect(format) in format(v, format),
        if let Some(format) = make_tuple_variant_format(v, &formats, variants, e);

    relation format_struct_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_struct_variant(e, format) <--
        variant_struct(e, v),
        variants(e, variants),
        // The variant's own fields, as named formats.
        agg formats = collect(format) in format_named(v, format),
        if let Some(format) = make_struct_variant_format(v, &formats, variants, e);

    // Union of the three variant shapes above.
    relation format_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_variant(e, format) <-- format_plain_variant(e, format);
    format_variant(e, format) <-- format_tuple_variant(e, format);
    format_variant(e, format) <-- format_struct_variant(e, format);

    // Final output: container format keyed by type name.
    relation container(String, ContainerFormat);
    // Named plain structs, built from their named field formats.
    container(name, container) <--
        struct_plain(s),
        if let Some(name) = s.name(),
        agg field_formats = collect(format) in format_named(s, format),
        let container = make_struct_plain(&field_formats);
    // Named unit structs.
    container(name, container) <--
        struct_unit(s),
        if let Some(name) = s.name(),
        let container = make_struct_unit();
    // Named tuple structs, built from their positional field formats.
    container(name, container) <--
        struct_tuple(s),
        if let Some(name) = s.name(),
        agg field_formats = collect(format) in format(s, format),
        let container = make_struct_tuple(&field_formats);
    // Every named item with at least one variant becomes an enum container.
    container(name, container) <--
        variant(e, _),
        if let Some(name) = e.name(),
        agg variant_formats = collect(format) in format_variant(e, format),
        let container = make_enum(&variant_formats);
    // Any field whose type `is_range()` contributes a synthesized `Range`
    // container (`start`/`end`, see `make_range`).
    container("Range".to_string(), container) <--
        field(_, f), if f.is_range(),
        if let Some(container) = make_range(f);
    // A `Request` container (`id` + `effect`, see `make_request`) is always emitted.
    container("Request".to_string(), container) <--
        let container = make_request();
}
124
125fn make_format(field: &ItemNode, all_fields: &[ItemNode]) -> Option<Indexed<Format>> {
126    let index = all_fields.iter().position(|f| f == field)?;
127    match &field.item.inner {
128        ItemEnum::StructField(type_) => Some(Indexed {
129            index: index as u32,
130            value: {
131                if let Some((_whole, serde_with)) = field.item.attrs.iter().find_map(|attr| {
132                    lazy_regex::regex_captures!(r#"\[serde\(with\s*=\s*"(\w+)"\)\]"#, attr)
133                }) {
134                    match serde_with {
135                        "serde_bytes" => Format::Bytes, // e.g. HttpRequest.body, HttpResponse.body
136                        _ => todo!(),
137                    }
138                } else {
139                    type_.into()
140                }
141            },
142        }),
143        _ => None,
144    }
145}
146
147fn make_named_format(
148    field: &ItemNode,
149    all_fields: &[ItemNode],
150    struct_: &ItemNode,
151) -> Option<Indexed<Named<Format>>> {
152    match field.name() {
153        Some(name) => match make_format(field, all_fields) {
154            Some(Indexed { index, value }) => Some(Indexed {
155                index,
156                value: Named {
157                    name: field_name(name, &field.item.attrs, &struct_.item.attrs),
158                    value,
159                },
160            }),
161            _ => None,
162        },
163        _ => None,
164    }
165}
166
167fn make_plain_variant_format(
168    variant: &ItemNode,
169    all_variants: &[ItemNode],
170    enum_: &ItemNode,
171) -> Option<Indexed<Named<VariantFormat>>> {
172    let index = all_variants.iter().position(|f| f == variant)?;
173    match &variant.item {
174        Item {
175            name: Some(name),
176            inner: ItemEnum::Variant(_),
177            ..
178        } => Some(Indexed {
179            index: index as u32,
180            value: Named {
181                name: variant_name(name, &variant.item.attrs, &enum_.item.attrs),
182                value: VariantFormat::Unit,
183            },
184        }),
185        _ => None,
186    }
187}
188
189fn make_struct_variant_format(
190    variant: &ItemNode,
191    fields: &[(&Indexed<Named<Format>>,)],
192    all_variants: &[ItemNode],
193    enum_: &ItemNode,
194) -> Option<Indexed<Named<VariantFormat>>> {
195    let index = all_variants.iter().position(|f| f == variant)?;
196    match &variant.item {
197        Item {
198            name: Some(name),
199            inner: ItemEnum::Variant(_),
200            ..
201        } => {
202            let mut fields = fields.to_owned();
203            fields.sort();
204            let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
205            Some(Indexed {
206                index: index as u32,
207                value: Named {
208                    name: variant_name(name, &variant.item.attrs, &enum_.item.attrs),
209                    value: VariantFormat::Struct(fields),
210                },
211            })
212        }
213        _ => None,
214    }
215}
216
217fn make_tuple_variant_format(
218    variant: &ItemNode,
219    fields: &[(&Indexed<Format>,)],
220    all_variants: &[ItemNode],
221    enum_: &ItemNode,
222) -> Option<Indexed<Named<VariantFormat>>> {
223    let index = all_variants.iter().position(|v| v == variant)?;
224    match &variant.item {
225        Item {
226            name: Some(name),
227            inner: ItemEnum::Variant(_),
228            ..
229        } => {
230            let mut fields = fields.to_owned();
231            fields.sort();
232            let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
233            let value = match fields.len() {
234                0 => VariantFormat::Unit,
235                1 => VariantFormat::NewType(Box::new(fields[0].clone())),
236                _ => VariantFormat::Tuple(fields),
237            };
238            Some(Indexed {
239                index: index as u32,
240                value: Named {
241                    name: variant_name(name, &variant.item.attrs, &enum_.item.attrs),
242                    value,
243                },
244            })
245        }
246        _ => None,
247    }
248}
249
250fn make_struct_unit() -> ContainerFormat {
251    ContainerFormat::UnitStruct
252}
253
254fn make_struct_plain(fields: &[(&Indexed<Named<Format>>,)]) -> ContainerFormat {
255    let mut fields = fields.to_owned();
256    fields.sort();
257    let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
258    match fields.len() {
259        0 => ContainerFormat::UnitStruct,
260        _ => ContainerFormat::Struct(fields),
261    }
262}
263
264fn make_struct_tuple(fields: &[(&Indexed<Format>,)]) -> ContainerFormat {
265    let mut fields = fields.to_owned();
266    fields.sort();
267    let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
268    match fields.len() {
269        0 => ContainerFormat::UnitStruct,
270        1 => ContainerFormat::NewTypeStruct(Box::new(fields[0].clone())),
271        _ => ContainerFormat::TupleStruct(fields),
272    }
273}
274
275fn make_enum(formats: &[(&Indexed<Named<VariantFormat>>,)]) -> ContainerFormat {
276    let mut map = BTreeMap::default();
277    for (Indexed { index, value },) in formats {
278        map.insert(*index, value.clone());
279    }
280    ContainerFormat::Enum(map)
281}
282
283fn make_range(field: &ItemNode) -> Option<ContainerFormat> {
284    match &field.item.inner {
285        ItemEnum::StructField(range_type) => {
286            let field_format: Option<Format> = match range_type {
287                Type::ResolvedPath(path) => match &path.args {
288                    Some(args) => match args.as_ref() {
289                        GenericArgs::AngleBracketed { args, .. } => {
290                            let type_ = args.iter().next()?;
291                            match type_ {
292                                GenericArg::Type(ref type_) => Some(type_.into()),
293                                _ => None,
294                            }
295                        }
296                        _ => None,
297                    },
298                    _ => None,
299                },
300                _ => None,
301            };
302            field_format.map(|f| {
303                ContainerFormat::Struct(vec![
304                    Named {
305                        name: "start".to_string(),
306                        value: f.clone(),
307                    },
308                    Named {
309                        name: "end".to_string(),
310                        value: f.clone(),
311                    },
312                ])
313            })
314        }
315        _ => None,
316    }
317}
318
319fn make_request() -> ContainerFormat {
320    ContainerFormat::Struct(vec![
321        Named {
322            name: "id".to_string(),
323            value: Format::U32,
324        },
325        Named {
326            name: "effect".to_string(),
327            value: Format::TypeName("Effect".to_string()),
328        },
329    ])
330}
331
332impl From<&Type> for Format {
333    fn from(type_: &Type) -> Self {
334        match type_ {
335            Type::ResolvedPath(path) => {
336                let name = path_to_string(path);
337                if let Some(args) = &path.args {
338                    match args.as_ref() {
339                        GenericArgs::AngleBracketed {
340                            args,
341                            constraints: _,
342                        } => match name.as_str() {
343                            "Option" => {
344                                let format = match args[0] {
345                                    GenericArg::Type(ref type_) => type_.into(),
346                                    _ => todo!(),
347                                };
348                                Format::Option(Box::new(format))
349                            }
350                            "String" => Format::Str,
351                            "Vec" => {
352                                let format = match args[0] {
353                                    GenericArg::Type(ref type_) => type_.into(),
354                                    _ => todo!(),
355                                };
356                                Format::Seq(Box::new(format))
357                            }
358                            _ => Format::TypeName(name),
359                        },
360                        GenericArgs::Parenthesized {
361                            inputs: _,
362                            output: _,
363                        } => todo!(),
364                        GenericArgs::ReturnTypeNotation => todo!(),
365                    }
366                } else {
367                    Format::TypeName(name)
368                }
369            }
370            Type::DynTrait(_dyn_trait) => todo!(),
371            Type::Generic(_param_name) => todo!(),
372            Type::Primitive(s) => match s.as_ref() {
373                "bool" => Format::Bool,
374                "char" => Format::Char,
375                "isize" => match std::mem::size_of::<isize>() {
376                    4 => Format::I32,
377                    8 => Format::I64,
378                    _ => panic!("unsupported isize size"),
379                },
380                "i8" => Format::I8,
381                "i16" => Format::I16,
382                "i32" => Format::I32,
383                "i64" => Format::I64,
384                "i128" => Format::I128,
385                "usize" => match std::mem::size_of::<usize>() {
386                    4 => Format::U32,
387                    8 => Format::U64,
388                    _ => panic!("unsupported usize size"),
389                },
390                "u8" => Format::U8,
391                "u16" => Format::U16,
392                "u32" => Format::U32,
393                "u64" => Format::U64,
394                "u128" => Format::U128,
395                s => panic!("need to implement primitive {s}"),
396            },
397            Type::FunctionPointer(_function_pointer) => todo!(),
398            Type::Tuple(vec) => Format::Tuple(vec.iter().map(|t| t.into()).collect()),
399            Type::Slice(_) => todo!(),
400            Type::Array { type_: _, len: _ } => todo!(),
401            Type::Pat {
402                type_: _,
403                __pat_unstable_do_not_use,
404            } => todo!(),
405            Type::ImplTrait(_vec) => todo!(),
406            Type::Infer => todo!(),
407            Type::RawPointer {
408                is_mutable: _,
409                type_: _,
410            } => todo!(),
411            Type::BorrowedRef {
412                lifetime: _,
413                is_mutable: _,
414                type_: _,
415            } => todo!(),
416            Type::QualifiedPath {
417                name,
418                args: _,
419                self_type: _,
420                trait_: _,
421            } => Format::TypeName(name.to_string()),
422        }
423    }
424}
425
426fn path_to_string(path: &rustdoc_types::Path) -> String {
427    if let Some((_mod, name)) = path.path.rsplit_once("::") {
428        name.to_string()
429    } else {
430        path.path.clone()
431    }
432}
433
434fn variant_name<T>(name: &str, variant_attrs: &[T], enum_attrs: &[T]) -> String
435where
436    T: AsRef<str>,
437{
438    if let Some((_whole, rename)) = variant_attrs.iter().find_map(|attr| {
439        lazy_regex::regex_captures!(r#"\[serde\(rename\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
440    }) {
441        return rename.to_string();
442    }
443
444    if let Some((_whole, rename_all)) = enum_attrs.iter().find_map(|attr| {
445        lazy_regex::regex_captures!(r#"\[serde\(rename_all\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
446    }) {
447        return RenameRule::from_str(rename_all)
448            .unwrap_or(RenameRule::None)
449            .apply_to_variant(name);
450    }
451
452    name.to_string()
453}
454
455fn field_name<T>(name: &str, field_attrs: &[T], struct_attrs: &[T]) -> String
456where
457    T: AsRef<str>,
458{
459    if let Some((_whole, rename)) = field_attrs.iter().find_map(|attr| {
460        lazy_regex::regex_captures!(r#"\[serde\(rename\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
461    }) {
462        return rename.to_string();
463    }
464
465    if let Some((_whole, rename_all)) = struct_attrs.iter().find_map(|attr| {
466        lazy_regex::regex_captures!(r#"\[serde\(rename_all\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
467    }) {
468        return RenameRule::from_str(rename_all)
469            .unwrap_or(RenameRule::None)
470            .apply_to_field(name);
471    }
472
473    name.to_string()
474}
475
// Unit tests for the serde rename handling in `variant_name` and `field_name`.
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    // Cases: no attributes, explicit `rename`, container `rename_all`, a
    // `rename_all` preceded by an unrelated `serde(with = ...)` attribute, and
    // `rename` taking precedence over `rename_all`.
    #[rstest]
    #[case("foo", &[""], &[], "foo")]
    #[case("foo", &["#[serde(rename = \"bar\")]"], &[], "bar")]
    #[case("FooBar", &[], &["#[serde(rename_all = \"camelCase\")]"], "fooBar")]
    #[case("FooBar", &[""], &["#[serde(with = \"something\")]",
        "#[serde(rename_all = \"snake_case\")]"], "foo_bar")]
    #[case("FooBar", &["#[serde(rename = \"bar\")]"], &["#[serde(with = \"something\")]",
        "#[serde(rename_all = \"snake_case\")]"], "bar")]
    fn variant_renaming<T: AsRef<str>>(
        #[case] name: &str,
        #[case] variant_attrs: &[T],
        #[case] enum_attrs: &[T],
        #[case] expected: String,
    ) {
        assert_eq!(variant_name(name, variant_attrs, enum_attrs), expected);
    }

    // Same scenarios as above, for struct fields (snake_case source names).
    #[rstest]
    #[case("foo", &[""], &[], "foo")]
    #[case("foo", &["#[serde(rename = \"bar\")]"], &[], "bar")]
    #[case("foo_bar", &[], &["#[serde(rename_all = \"camelCase\")]"], "fooBar")]
    #[case("foo_bar", &[""], &["#[serde(with = \"something\")]",
        "#[serde(rename_all = \"PascalCase\")]"], "FooBar")]
    #[case("foo_bar", &["#[serde(rename = \"bar\")]"], &["#[serde(with = \"something\")]",
        "#[serde(rename_all = \"PascalCase\")]"], "bar")]
    fn field_renaming<T: AsRef<str>>(
        #[case] name: &str,
        #[case] field_attrs: &[T],
        #[case] struct_attrs: &[T],
        #[case] expected: String,
    ) {
        assert_eq!(field_name(name, field_attrs, struct_attrs), expected);
    }
}