rusteron_code_gen/
parser.rs

1use crate::generator::{CBinding, CWrapper, Method};
2use crate::{Arg, ArgProcessing, CHandler};
3use itertools::Itertools;
4use quote::ToTokens;
5use std::collections::{BTreeMap, BTreeSet};
6use std::fs;
7use std::path::PathBuf;
8use syn::{Attribute, Item, ItemForeignMod, ItemStruct, ItemType, Lit, Meta, MetaNameValue};
9
10pub fn parse_bindings(out: &PathBuf) -> CBinding {
11    let file_content = fs::read_to_string(out.clone()).expect("Unable to read file");
12    let syntax_tree = syn::parse_file(&file_content).expect("Unable to parse file");
13    let mut wrappers = BTreeMap::new();
14    let mut methods = Vec::new();
15    let mut handlers = Vec::new();
16
17    // Iterate through the items in the file
18    for item in syntax_tree.items {
19        match item {
20            Item::Struct(s) => {
21                process_struct(&mut wrappers, &s);
22            }
23            Item::Type(ty) => {
24                process_type(&mut wrappers, &mut handlers, &ty);
25            }
26            Item::ForeignMod(fm) => {
27                process_c_method(&mut wrappers, &mut methods, fm);
28            }
29            _ => {}
30        }
31    }
32
33    /*    // need to filter out args which don't match
34        for wrapper in wrappers.values_mut() {
35            for method in wrapper.methods.iter_mut() {
36                let method_debug = format!("{:?}", method);
37                for arg in method.arguments.iter_mut() {
38                    if let ArgProcessing::Handler(args) = &arg.processing {
39                        let handler = args.get(0).unwrap();
40                        if !handlers.iter().any(|h| h.type_name == handler.c_type) {
41                            log::info!("replacing {} back to default", method_debug);
42                            // arg.processing = ArgProcessing::Default;
43                        }
44                    }
45                }
46            }
47        }
48    */
49    let bindings = CBinding {
50        wrappers: wrappers
51            .into_iter()
52            .filter(|(_, wrapper)| {
53                // these are from media driver and do not follow convention
54                ![
55                    "aeron_thread",
56                    "aeron_command",
57                    "aeron_executor",
58                    "aeron_name_resolver",
59                    "aeron_udp_channel_transport", // this one I have issues with handlers
60                    "aeron_udp_transport",         // this one I have issues with handlers
61                ]
62                .iter()
63                .any(|&filter| wrapper.type_name.starts_with(filter))
64            })
65            .collect(),
66        methods,
67        handlers: handlers
68            .into_iter()
69            .filter(|h| {
70                !["aeron_udp_channel", "aeron_udp_transport"]
71                    .iter()
72                    .any(|&filter| h.type_name.starts_with(filter))
73            })
74            .collect(),
75    };
76
77    let mismatched_types = bindings
78        .wrappers
79        .iter()
80        .filter(|(key, w)| key.as_str() != w.type_name)
81        .map(|(a, b)| (a.clone(), b.clone()))
82        .collect_vec();
83    assert_eq!(Vec::<(String, CWrapper)>::new(), mismatched_types);
84    bindings
85}
86
87fn process_c_method(
88    wrappers: &mut BTreeMap<String, CWrapper>,
89    methods: &mut Vec<Method>,
90    fm: ItemForeignMod,
91) {
92    // Extract functions inside extern "C" blocks
93    if fm.abi.name.is_some() && fm.abi.name.as_ref().unwrap().value() == "C" {
94        for foreign_item in fm.items {
95            if let syn::ForeignItem::Fn(f) = foreign_item {
96                let docs = get_doc_comments(&f.attrs);
97                let fn_name = f.sig.ident.to_string();
98
99                // Get function arguments and return type as Rust code
100                let args = extract_function_arguments(&f.sig.inputs);
101                let ret = extract_return_type(&f.sig.output);
102
103                let option = if let Some(arg) = args
104                    .iter()
105                    .skip_while(|a| a.is_mut_pointer() && a.is_primitive())
106                    .next()
107                {
108                    let ty = &arg.c_type;
109                    let ty = ty.split(' ').last().map(|t| t.to_string()).unwrap();
110                    if wrappers.contains_key(&ty) {
111                        Some(ty)
112                    } else {
113                        find_closest_wrapper_from_method_name(wrappers, &fn_name)
114                    }
115                } else {
116                    find_closest_wrapper_from_method_name(wrappers, &fn_name)
117                };
118
119                match option {
120                    Some(key) => {
121                        let wrapper = wrappers.get_mut(&key).unwrap();
122                        wrapper.methods.push(Method {
123                            fn_name: fn_name.clone(),
124                            struct_method_name: fn_name
125                                .replace(&wrapper.type_name[..wrapper.type_name.len() - 1], "")
126                                .to_string(),
127                            return_type: Arg {
128                                name: "".to_string(),
129                                c_type: ret.clone(),
130                                processing: ArgProcessing::Default,
131                            },
132                            arguments: process_types(args.clone()),
133                            docs: docs.clone(),
134                        });
135                    }
136                    None => methods.push(Method {
137                        fn_name: fn_name.clone(),
138                        struct_method_name: "".to_string(),
139                        return_type: Arg {
140                            name: "".to_string(),
141                            c_type: ret.clone(),
142                            processing: ArgProcessing::Default,
143                        },
144                        arguments: process_types(args.clone()),
145                        docs: docs.clone(),
146                    }),
147                }
148            }
149        }
150    }
151}
152
153fn find_closest_wrapper_from_method_name(
154    wrappers: &mut BTreeMap<String, CWrapper>,
155    fn_name: &String,
156) -> Option<String> {
157    let type_names = get_possible_wrappers(&fn_name);
158
159    let mut value = None;
160    for ty in type_names {
161        if wrappers.contains_key(&ty) {
162            value = Some(ty);
163            break;
164        }
165    }
166
167    value
168}
169
/// Lists candidate wrapper type names for a C function name, longest first.
///
/// Every `_` in `fn_name` marks a possible end of the type prefix. The text
/// before it is suffixed with `_t`, e.g.
/// `aeron_publication_offer` -> `["aeron_publication_t", "aeron_t"]`.
/// A name without `_` yields an empty list.
pub fn get_possible_wrappers(fn_name: &str) -> Vec<String> {
    // `rmatch_indices` walks the separators right-to-left, giving longest prefixes first.
    fn_name
        .rmatch_indices('_')
        .map(|(split_at, _)| format!("{}_t", &fn_name[..split_at]))
        .collect()
}
178
179fn process_type(
180    wrappers: &mut BTreeMap<String, CWrapper>,
181    handlers: &mut Vec<CHandler>,
182    ty: &ItemType,
183) {
184    // Handle type definitions and get docs
185    let docs = get_doc_comments(&ty.attrs);
186
187    let type_name = ty.ident.to_string();
188    let class_name = snake_to_pascal_case(&type_name);
189
190    if ty.to_token_stream().to_string().contains("_stct") {
191        wrappers
192            .entry(type_name.clone())
193            .or_insert(CWrapper {
194                class_name,
195                without_name: type_name[..type_name.len() - 2].to_string(),
196                type_name,
197                ..Default::default()
198            })
199            .docs
200            .extend(docs);
201    } else {
202        // Parse the function pointer type -> it is typically used for handlers/callbacks
203        if let syn::Type::Path(type_path) = &*ty.ty {
204            if let Some(segment) = type_path.path.segments.last() {
205                if segment.ident.to_string() == "Option" {
206                    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
207                        if let Some(syn::GenericArgument::Type(syn::Type::BareFn(bare_fn))) =
208                            args.args.first()
209                        {
210                            let args: Vec<Arg> = bare_fn
211                                .inputs
212                                .iter()
213                                .map(|arg| {
214                                    let arg_name = match &arg.name {
215                                        Some((ident, _)) => ident.to_string(),
216                                        None => "".to_string(),
217                                    };
218                                    let arg_type = arg.ty.to_token_stream().to_string();
219                                    (arg_name, arg_type)
220                                })
221                                .map(|(field_name, field_type)| Arg {
222                                    name: field_name,
223                                    c_type: field_type,
224                                    processing: ArgProcessing::Default,
225                                })
226                                .collect();
227                            let string = bare_fn.output.to_token_stream().to_string();
228                            let mut return_type = string.trim();
229
230                            if return_type.starts_with("-> ") {
231                                return_type = &return_type[3..];
232                            }
233
234                            if return_type.is_empty() {
235                                return_type = "()";
236                            }
237
238                            if args.iter().filter(|a| a.is_c_void()).count() == 1 {
239                                let value = CHandler {
240                                    type_name: ty.ident.to_string(),
241                                    args: process_types(args),
242                                    return_type: Arg {
243                                        name: "".to_string(),
244                                        c_type: return_type.to_string(),
245                                        processing: ArgProcessing::Default,
246                                    },
247                                    docs: docs.clone(),
248                                    fn_mut_signature: Default::default(),
249                                    closure_type_name: Default::default(),
250                                };
251                                handlers.push(value);
252                            }
253                        }
254                    }
255                }
256            }
257        }
258    }
259}
260
261fn process_struct(wrappers: &mut BTreeMap<String, CWrapper>, s: &ItemStruct) {
262    // Print the struct name and its doc comments
263    let docs = get_doc_comments(&s.attrs);
264    let type_name = s.ident.to_string().replace("_stct", "_t");
265    let class_name = snake_to_pascal_case(&type_name);
266
267    let fields: Vec<Arg> = s
268        .fields
269        .iter()
270        .map(|f| {
271            let field_name = f.ident.as_ref().unwrap().to_string();
272            let field_type = f.ty.to_token_stream().to_string();
273            (field_name, field_type)
274        })
275        .map(|(field_name, field_type)| Arg {
276            name: field_name,
277            c_type: field_type,
278            processing: ArgProcessing::Default,
279        })
280        .collect();
281
282    let w = wrappers.entry(type_name.to_string()).or_insert(CWrapper {
283        class_name,
284        without_name: type_name[..type_name.len() - 2].to_string(),
285        type_name,
286        ..Default::default()
287    });
288    w.docs.extend(docs);
289    w.fields = process_types(fields);
290}
291
292fn process_types(mut name_and_type: Vec<Arg>) -> Vec<Arg> {
293    // now mark arguments which can be reduced
294    for i in 1..name_and_type.len() {
295        let param1 = &name_and_type[i - 1];
296        let param2 = &name_and_type[i];
297
298        let is_int = param2.c_type == "usize" || param2.c_type == "i32";
299        let length_field = param2.name == "length"
300            || param2.name == "len"
301            || (param2.name.ends_with("_length") && param2.name.starts_with(&param1.name));
302        if param2.is_c_void() && !param1.is_mut_pointer() && param1.c_type.ends_with("_t") {
303            // closures
304            //         handler: aeron_on_available_counter_t,
305            //         clientd: *mut ::std::os::raw::c_void,
306            let processing = ArgProcessing::Handler(vec![param1.clone(), param2.clone()]);
307            name_and_type[i - 1].processing = processing.clone();
308            name_and_type[i].processing = processing.clone();
309        } else if param1.is_c_string_any() && !param1.is_mut_pointer() && is_int && length_field {
310            //     pub stripped_channel: *mut ::std::os::raw::c_char,
311            //     pub stripped_channel_length: usize,
312            let processing = ArgProcessing::StringWithLength(vec![param1.clone(), param2.clone()]);
313            name_and_type[i - 1].processing = processing.clone();
314            name_and_type[i].processing = processing.clone();
315        } else if param1.is_byte_array()
316            // && !param1.is_mut_pointer()
317            && is_int
318            && length_field
319        {
320            //         key_buffer: *const u8,
321            //         key_buffer_length: usize,
322            let processing =
323                ArgProcessing::ByteArrayWithLength(vec![param1.clone(), param2.clone()]);
324            name_and_type[i - 1].processing = processing.clone();
325            name_and_type[i].processing = processing.clone();
326        }
327
328        //
329    }
330
331    name_and_type
332}
333
334// Helper function to extract doc comments
335fn get_doc_comments(attrs: &[Attribute]) -> BTreeSet<String> {
336    attrs
337        .iter()
338        .filter_map(|attr| {
339            // Parse the attribute meta to check if it is a `Meta::NameValue`
340            if let Meta::NameValue(MetaNameValue {
341                path,
342                value: syn::Expr::Lit(expr_lit),
343                ..
344            }) = &attr.meta
345            {
346                // Check if the path is "doc"
347                if path.is_ident("doc") {
348                    // Check if the literal is a string and return its value
349                    if let Lit::Str(lit_str) = &expr_lit.lit {
350                        return Some(lit_str.value().trim().to_string());
351                    }
352                }
353            }
354            None
355        })
356        .collect()
357}
358
/// Converts a C type name in snake case to a PascalCase Rust name.
///
/// A trailing `_t` is dropped first. Words equal to `on` are omitted, and each
/// remaining word has its first character upper-cased, e.g.
/// `aeron_on_available_counter_t` -> `AeronAvailableCounter`.
pub fn snake_to_pascal_case(snake: &str) -> String {
    let base = snake.strip_suffix("_t").unwrap_or(snake);
    let mut pascal = String::with_capacity(base.len());
    for word in base.split('_').filter(|word| *word != "on") {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}
376
377// Helper function to extract function arguments as Rust code
378fn extract_function_arguments(
379    inputs: &syn::punctuated::Punctuated<syn::FnArg, syn::token::Comma>,
380) -> Vec<Arg> {
381    inputs
382        .iter()
383        .map(|arg| match arg {
384            syn::FnArg::Receiver(_) => "self".to_string(), // Handle self receiver
385            syn::FnArg::Typed(pat_type) => pat_type.to_token_stream().to_string(), // Convert the pattern and type to Rust code
386        })
387        .map(|arg| {
388            arg.splitn(2, ':')
389                .map(|s| s.trim().to_string())
390                .collect_tuple()
391                .unwrap()
392        })
393        .map(|(name, ty)| Arg {
394            name,
395            c_type: ty,
396            processing: ArgProcessing::Default,
397        })
398        .collect_vec()
399}
400
401// Helper function to extract return type as Rust code
402fn extract_return_type(output: &syn::ReturnType) -> String {
403    match output {
404        syn::ReturnType::Default => "()".to_string(), // No return type, equivalent to ()
405        syn::ReturnType::Type(_, ty) => ty.to_token_stream().to_string(), // Convert the type to Rust code
406    }
407}
408
#[cfg(test)]
mod tests {
    use crate::parser::parse_bindings;

    /// Parses one of the checked-in bindgen outputs under `bindings/`.
    fn parse_checked_in(file_name: &str) -> crate::generator::CBinding {
        let path = format!("../rusteron-code-gen/bindings/{}", file_name);
        parse_bindings(&path.into())
    }

    /// Class name generated for the `aeron_image_fragment_assembler_t` wrapper.
    fn fragment_assembler_class(bindings: &crate::generator::CBinding) -> &str {
        &bindings
            .wrappers
            .get("aeron_image_fragment_assembler_t")
            .unwrap()
            .class_name
    }

    #[test]
    fn media_driver() {
        let bindings = parse_checked_in("media-driver.rs");
        assert_eq!(
            "AeronImageFragmentAssembler",
            fragment_assembler_class(&bindings)
        );
    }

    #[test]
    fn client() {
        let bindings = parse_checked_in("client.rs");
        assert_eq!(
            "AeronImageFragmentAssembler",
            fragment_assembler_class(&bindings)
        );
        assert!(bindings.handlers.len() > 1);
    }
}