crux_cli/codegen/formatter.rs

1#![allow(clippy::no_effect_underscore_binding)]
2
3use std::{collections::BTreeMap, convert::Into};
4
5use ascent::ascent;
6use rustdoc_types::{GenericArg, GenericArgs, Item, ItemEnum, Type};
7
8use crate::codegen::collect;
9
10use super::{
11    indexed::Indexed,
12    item::{
13        is_plain_variant, is_struct_plain, is_struct_tuple, is_struct_unit, is_struct_variant,
14        is_tuple_variant,
15    },
16    node::ItemNode,
17    serde::case::RenameRule,
18    serde_generate::format::{ContainerFormat, Format, Named, VariantFormat},
19};
20
// Datalog program (built with the `ascent` crate) that derives serde-generate
// `ContainerFormat`s from `edge` facts between `ItemNode`s. The derived output is the
// `container(name, format)` relation, keyed by type name.
ascent! {
    #![measure_rule_times]
    pub struct Formatter;

    // ------- facts ------------------
    // Input relation; no rule in this program derives it. Presumably the caller seeds
    // it before running the program (TODO confirm).
    relation edge(ItemNode, ItemNode);

    // ------- rules ------------------

    // Classify every item that is the source of at least one `edge` by struct kind.
    // NOTE(review): an item with no outgoing edge is never classified. Confirm that the
    // graph builder emits an edge for unit structs, which have no fields.
    relation struct_unit(ItemNode);
    struct_unit(s) <-- edge(s, _), if is_struct_unit(&s.item);

    relation struct_plain(ItemNode);
    struct_plain(s) <-- edge(s, _), if is_struct_plain(&s.item);

    relation struct_tuple(ItemNode);
    struct_tuple(s) <-- edge(s, _), if is_struct_tuple(&s.item);

    // `field(x, f)`: `f` is a field of `x`, according to `ItemNode::has_field`.
    relation field(ItemNode, ItemNode);
    field(x, f) <-- edge(x, f), if x.has_field(f);

    // `fields(x, fs)`: the fields of `x`, as returned by `ItemNode::fields`. A field's
    // position in this list becomes its index in the generated format (`make_format`).
    // NOTE(review): `f` is already bound by the preceding `field(x, f)` when the aggregate
    // runs. Confirm that ascent collects every field of `x` here rather than grouping by `f`.
    relation fields(ItemNode, Vec<ItemNode>);
    fields(x, fields) <--
        field(x, f),
        agg fs = collect(f) in field(x, f),
        let fields = x.fields(&fs);

    // `variant(e, v)`: `v` is a variant of `e`, according to `ItemNode::has_variant`.
    relation variant(ItemNode, ItemNode);
    variant(e, v) <-- edge(e, v), if e.has_variant(v);

    // `variants(e, vs)`: the variants of `e`, as returned by `ItemNode::variants`. A
    // variant's position in this list becomes its index in the enum format.
    // NOTE(review): the same aggregate-binding question as `fields` applies to `v`.
    relation variants(ItemNode, Vec<ItemNode>);
    variants(e, variants) <--
        variant(e, v),
        agg vs = collect(v) in variant(e, v),
        let variants = e.variants(&vs);

    // Split variants by shape: unit-like, tuple-like and struct-like.
    relation variant_plain(ItemNode, ItemNode);
    variant_plain(e, v) <-- variant(e, v), if is_plain_variant(&v.item);

    relation variant_tuple(ItemNode, ItemNode);
    variant_tuple(e, v) <-- variant(e, v), if is_tuple_variant(&v.item);

    relation variant_struct(ItemNode, ItemNode);
    variant_struct(e, v) <-- variant(e, v), if is_struct_variant(&v.item);

    // Positional (unnamed) format of each field of `x`. Consumed by tuple structs and
    // tuple variants below.
    relation format(ItemNode, Indexed<Format>);
    format(x, format) <--
        field(x, field),
        fields(x, fields),
        if let Some(format) = make_format(field, fields);

    // Named format of each field of `x`, with serde field renaming applied. Consumed by
    // plain structs and struct variants below.
    relation format_named(ItemNode, Indexed<Named<Format>>);
    format_named(x, format) <--
        field(x, field),
        fields(x, fields),
        if let Some(format) = make_named_format(field, fields, x);

    relation format_plain_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_plain_variant(e, format) <--
        variant_plain(e, v),
        variants(e, variants),
        if let Some(format) = make_plain_variant_format(v, variants, e);

    // Tuple variants collect the positional formats of their own fields (`format(v, _)`).
    relation format_tuple_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_tuple_variant(e, format) <--
        variant_tuple(e, v),
        variants(e, variants),
        agg formats = collect(format) in format(v, format),
        if let Some(format) = make_tuple_variant_format(v, &formats, variants, e);

    // Struct variants collect the named formats of their own fields (`format_named(v, _)`).
    relation format_struct_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_struct_variant(e, format) <--
        variant_struct(e, v),
        variants(e, variants),
        agg formats = collect(format) in format_named(v, format),
        if let Some(format) = make_struct_variant_format(v, &formats, variants, e);

    // Union of the three variant shapes for each enum `e`.
    relation format_variant(ItemNode, Indexed<Named<VariantFormat>>);
    format_variant(e, format) <-- format_plain_variant(e, format);
    format_variant(e, format) <-- format_tuple_variant(e, format);
    format_variant(e, format) <-- format_struct_variant(e, format);

    // Program output: one container format per type name. Unnamed items are skipped.
    relation container(String, ContainerFormat);
    container(name, container) <--
        struct_plain(s),
        if let Some(name) = s.name(),
        agg field_formats = collect(format) in format_named(s, format),
        let container = make_struct_plain(&field_formats);
    container(name, container) <--
        struct_unit(s),
        if let Some(name) = s.name(),
        let container = make_struct_unit();
    container(name, container) <--
        struct_tuple(s),
        if let Some(name) = s.name(),
        agg field_formats = collect(format) in format(s, format),
        let container = make_struct_tuple(&field_formats);
    // Any item with at least one variant is emitted as an enum.
    container(name, container) <--
        variant(e, _),
        if let Some(name) = e.name(),
        agg variant_formats = collect(format) in format_variant(e, format),
        let container = make_enum(&variant_formats);
    // A `Range` container is derived from any field for which `is_range()` holds.
    // NOTE(review): fields of different `Range<T>` element types would produce several
    // `container("Range", _)` facts. Confirm that downstream handles or rejects this.
    container("Range".to_string(), container) <--
        field(_, f), if f.is_range(),
        if let Some(container) = make_range(f);
    // `Request` is emitted unconditionally.
    container("Request".to_string(), container) <--
        let container = make_request();
}
129
130fn make_format(field: &ItemNode, all_fields: &[ItemNode]) -> Option<Indexed<Format>> {
131    let index = all_fields.iter().position(|f| f == field)?;
132    match &field.item.inner {
133        ItemEnum::StructField(type_) => Some(Indexed {
134            index,
135            value: {
136                if let Some((_whole, serde_with)) = field.item.attrs.iter().find_map(|attr| {
137                    lazy_regex::regex_captures!(r#"\[serde\(with\s*=\s*"(\w+)"\)\]"#, attr)
138                }) {
139                    match serde_with {
140                        "serde_bytes" => Format::Bytes, // e.g. HttpRequest.body, HttpResponse.body
141                        _ => todo!(),
142                    }
143                } else {
144                    type_.into()
145                }
146            },
147        }),
148        _ => None,
149    }
150}
151
152fn make_named_format(
153    field: &ItemNode,
154    all_fields: &[ItemNode],
155    struct_: &ItemNode,
156) -> Option<Indexed<Named<Format>>> {
157    match field.name() {
158        Some(name) => match make_format(field, all_fields) {
159            Some(Indexed { index, value }) => Some(Indexed {
160                index,
161                value: Named {
162                    name: field_name(name, &field.item.attrs, &struct_.item.attrs),
163                    value,
164                },
165            }),
166            _ => None,
167        },
168        _ => None,
169    }
170}
171
172fn make_plain_variant_format(
173    variant: &ItemNode,
174    all_variants: &[ItemNode],
175    enum_: &ItemNode,
176) -> Option<Indexed<Named<VariantFormat>>> {
177    let index = all_variants.iter().position(|f| f == variant)?;
178    match &variant.item {
179        Item {
180            name: Some(name),
181            inner: ItemEnum::Variant(_),
182            ..
183        } => Some(Indexed {
184            index,
185            value: Named {
186                name: variant_name(name, &variant.item.attrs, &enum_.item.attrs),
187                value: VariantFormat::Unit,
188            },
189        }),
190        _ => None,
191    }
192}
193
194fn make_struct_variant_format(
195    variant: &ItemNode,
196    fields: &[(&Indexed<Named<Format>>,)],
197    all_variants: &[ItemNode],
198    enum_: &ItemNode,
199) -> Option<Indexed<Named<VariantFormat>>> {
200    let index = all_variants.iter().position(|f| f == variant)?;
201    match &variant.item {
202        Item {
203            name: Some(name),
204            inner: ItemEnum::Variant(_),
205            ..
206        } => {
207            let mut fields = fields.to_owned();
208            fields.sort();
209            let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
210            Some(Indexed {
211                index,
212                value: Named {
213                    name: variant_name(name, &variant.item.attrs, &enum_.item.attrs),
214                    value: VariantFormat::Struct(fields),
215                },
216            })
217        }
218        _ => None,
219    }
220}
221
222fn make_tuple_variant_format(
223    variant: &ItemNode,
224    fields: &[(&Indexed<Format>,)],
225    all_variants: &[ItemNode],
226    enum_: &ItemNode,
227) -> Option<Indexed<Named<VariantFormat>>> {
228    let index = all_variants.iter().position(|v| v == variant)?;
229    match &variant.item {
230        Item {
231            name: Some(name),
232            inner: ItemEnum::Variant(_),
233            ..
234        } => {
235            let mut fields = fields.to_owned();
236            fields.sort();
237            let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
238            let value = match fields.len() {
239                0 => VariantFormat::Unit,
240                1 => VariantFormat::NewType(Box::new(fields[0].clone())),
241                _ => VariantFormat::Tuple(fields),
242            };
243            Some(Indexed {
244                index,
245                value: Named {
246                    name: variant_name(name, &variant.item.attrs, &enum_.item.attrs),
247                    value,
248                },
249            })
250        }
251        _ => None,
252    }
253}
254
255fn make_struct_unit() -> ContainerFormat {
256    ContainerFormat::UnitStruct
257}
258
259fn make_struct_plain(fields: &[(&Indexed<Named<Format>>,)]) -> ContainerFormat {
260    let mut fields = fields.to_owned();
261    fields.sort();
262    let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
263    match fields.len() {
264        0 => ContainerFormat::UnitStruct,
265        _ => ContainerFormat::Struct(fields),
266    }
267}
268
269fn make_struct_tuple(fields: &[(&Indexed<Format>,)]) -> ContainerFormat {
270    let mut fields = fields.to_owned();
271    fields.sort();
272    let fields = fields.iter().map(|(f,)| f.inner()).collect::<Vec<_>>();
273    match fields.len() {
274        0 => ContainerFormat::UnitStruct,
275        1 => ContainerFormat::NewTypeStruct(Box::new(fields[0].clone())),
276        _ => ContainerFormat::TupleStruct(fields),
277    }
278}
279
280fn make_enum(formats: &[(&Indexed<Named<VariantFormat>>,)]) -> ContainerFormat {
281    let mut map = BTreeMap::default();
282    for (Indexed { index, value },) in formats {
283        map.insert(*index, value.clone());
284    }
285    ContainerFormat::Enum(map)
286}
287
288fn make_range(field: &ItemNode) -> Option<ContainerFormat> {
289    match &field.item.inner {
290        ItemEnum::StructField(range_type) => {
291            let field_format: Option<Format> = match range_type {
292                Type::ResolvedPath(path) => match &path.args {
293                    Some(args) => match args.as_ref() {
294                        GenericArgs::AngleBracketed { args, .. } => {
295                            let type_ = args.iter().next()?;
296                            match type_ {
297                                GenericArg::Type(ref type_) => Some(type_.into()),
298                                _ => None,
299                            }
300                        }
301                        GenericArgs::Parenthesized { .. } => None,
302                    },
303                    _ => None,
304                },
305                _ => None,
306            };
307            field_format.map(|f| {
308                ContainerFormat::Struct(vec![
309                    Named {
310                        name: "start".to_string(),
311                        value: f.clone(),
312                    },
313                    Named {
314                        name: "end".to_string(),
315                        value: f.clone(),
316                    },
317                ])
318            })
319        }
320        _ => None,
321    }
322}
323
324fn make_request() -> ContainerFormat {
325    ContainerFormat::Struct(vec![
326        Named {
327            name: "id".to_string(),
328            value: Format::U32,
329        },
330        Named {
331            name: "effect".to_string(),
332            value: Format::TypeName("Effect".to_string()),
333        },
334    ])
335}
336
337impl From<&Type> for Format {
338    fn from(type_: &Type) -> Self {
339        match type_ {
340            Type::ResolvedPath(path) => {
341                let name = path_to_string(path);
342                if let Some(args) = &path.args {
343                    match args.as_ref() {
344                        GenericArgs::AngleBracketed {
345                            args,
346                            constraints: _,
347                        } => match name.as_str() {
348                            "Option" => {
349                                let format = match args.first() {
350                                    Some(GenericArg::Type(ref type_)) => type_.into(),
351                                    _ => todo!(),
352                                };
353                                Format::Option(Box::new(format))
354                            }
355                            "String" => Format::Str,
356                            "Vec" => {
357                                let format = match args.first() {
358                                    Some(GenericArg::Type(ref type_)) => type_.into(),
359                                    _ => todo!(),
360                                };
361                                Format::Seq(Box::new(format))
362                            }
363                            _ => Format::TypeName(name),
364                        },
365                        GenericArgs::Parenthesized {
366                            inputs: _,
367                            output: _,
368                        } => todo!(),
369                    }
370                } else {
371                    Format::TypeName(name)
372                }
373            }
374            Type::DynTrait(_dyn_trait) => todo!(),
375            Type::Generic(_param_name) => todo!(),
376            Type::Primitive(s) => match s.as_ref() {
377                "bool" => Format::Bool,
378                "char" => Format::Char,
379                "isize" => match std::mem::size_of::<isize>() {
380                    4 => Format::I32,
381                    8 => Format::I64,
382                    _ => panic!("unsupported isize size"),
383                },
384                "i8" => Format::I8,
385                "i16" => Format::I16,
386                "i32" => Format::I32,
387                "i64" => Format::I64,
388                "i128" => Format::I128,
389                "usize" => match std::mem::size_of::<usize>() {
390                    4 => Format::U32,
391                    8 => Format::U64,
392                    _ => panic!("unsupported usize size"),
393                },
394                "u8" => Format::U8,
395                "u16" => Format::U16,
396                "u32" => Format::U32,
397                "u64" => Format::U64,
398                "u128" => Format::U128,
399                s => panic!("need to implement primitive {s}"),
400            },
401            Type::FunctionPointer(_function_pointer) => todo!(),
402            Type::Tuple(vec) => Format::Tuple(vec.iter().map(Into::into).collect()),
403            Type::Slice(_) => todo!(),
404            Type::Array { type_: _, len: _ } => todo!(),
405            Type::Pat {
406                type_: _,
407                __pat_unstable_do_not_use,
408            } => todo!(),
409            Type::ImplTrait(_vec) => todo!(),
410            Type::Infer => todo!(),
411            Type::RawPointer {
412                is_mutable: _,
413                type_: _,
414            } => todo!(),
415            Type::BorrowedRef {
416                lifetime: _,
417                is_mutable: _,
418                type_: _,
419            } => todo!(),
420            Type::QualifiedPath {
421                name,
422                args: _,
423                self_type: _,
424                trait_: _,
425            } => Format::TypeName(name.to_string()),
426        }
427    }
428}
429
430fn path_to_string(path: &rustdoc_types::Path) -> String {
431    if let Some((_mod, name)) = path.path.rsplit_once("::") {
432        name.to_string()
433    } else {
434        path.path.clone()
435    }
436}
437
438fn variant_name<T>(name: &str, variant_attrs: &[T], enum_attrs: &[T]) -> String
439where
440    T: AsRef<str>,
441{
442    if let Some((_whole, rename)) = variant_attrs.iter().find_map(|attr| {
443        lazy_regex::regex_captures!(r#"\[serde\(rename\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
444    }) {
445        return rename.to_string();
446    }
447
448    if let Some((_whole, rename_all)) = enum_attrs.iter().find_map(|attr| {
449        lazy_regex::regex_captures!(r#"\[serde\(rename_all\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
450    }) {
451        return RenameRule::from_str(rename_all)
452            .unwrap_or(RenameRule::None)
453            .apply_to_variant(name);
454    }
455
456    name.to_string()
457}
458
459fn field_name<T>(name: &str, field_attrs: &[T], struct_attrs: &[T]) -> String
460where
461    T: AsRef<str>,
462{
463    if let Some((_whole, rename)) = field_attrs.iter().find_map(|attr| {
464        lazy_regex::regex_captures!(r#"\[serde\(rename\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
465    }) {
466        return rename.to_string();
467    }
468
469    if let Some((_whole, rename_all)) = struct_attrs.iter().find_map(|attr| {
470        lazy_regex::regex_captures!(r#"\[serde\(rename_all\s*=\s*"(\w+)"\)\]"#, attr.as_ref())
471    }) {
472        return RenameRule::from_str(rename_all)
473            .unwrap_or(RenameRule::None)
474            .apply_to_field(name);
475    }
476
477    name.to_string()
478}
479
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    // Variant naming: an explicit `rename` beats the enum's `rename_all`, and serde
    // attributes that are not renames are skipped.
    #[rstest]
    #[case::no_attributes("foo", &[""], &[], "foo")]
    #[case::explicit_rename("foo", &["#[serde(rename = \"bar\")]"], &[], "bar")]
    #[case::rename_all_camel("FooBar", &[], &["#[serde(rename_all = \"camelCase\")]"], "fooBar")]
    #[case::rename_all_after_other_attr(
        "FooBar",
        &[""],
        &["#[serde(with = \"something\")]", "#[serde(rename_all = \"snake_case\")]"],
        "foo_bar"
    )]
    #[case::rename_beats_rename_all(
        "FooBar",
        &["#[serde(rename = \"bar\")]"],
        &["#[serde(with = \"something\")]", "#[serde(rename_all = \"snake_case\")]"],
        "bar"
    )]
    fn variant_renaming<T: AsRef<str>>(
        #[case] rust_name: &str,
        #[case] own_attrs: &[T],
        #[case] parent_attrs: &[T],
        #[case] expected: String,
    ) {
        let actual = variant_name(rust_name, own_attrs, parent_attrs);
        assert_eq!(actual, expected);
    }

    // Field naming follows the same precedence, using the struct's `rename_all`.
    #[rstest]
    #[case::no_attributes("foo", &[""], &[], "foo")]
    #[case::explicit_rename("foo", &["#[serde(rename = \"bar\")]"], &[], "bar")]
    #[case::rename_all_camel("foo_bar", &[], &["#[serde(rename_all = \"camelCase\")]"], "fooBar")]
    #[case::rename_all_after_other_attr(
        "foo_bar",
        &[""],
        &["#[serde(with = \"something\")]", "#[serde(rename_all = \"PascalCase\")]"],
        "FooBar"
    )]
    #[case::rename_beats_rename_all(
        "foo_bar",
        &["#[serde(rename = \"bar\")]"],
        &["#[serde(with = \"something\")]", "#[serde(rename_all = \"PascalCase\")]"],
        "bar"
    )]
    fn field_renaming<T: AsRef<str>>(
        #[case] rust_name: &str,
        #[case] own_attrs: &[T],
        #[case] parent_attrs: &[T],
        #[case] expected: String,
    ) {
        let actual = field_name(rust_name, own_attrs, parent_attrs);
        assert_eq!(actual, expected);
    }
}