feat(builtins): macro

This commit is contained in:
2025-08-05 23:54:10 +08:00
parent 64f650b695
commit 32c602f21c
12 changed files with 426 additions and 132 deletions

View File

@@ -0,0 +1,207 @@
use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote, ToTokens};
use syn::{
parse_macro_input, FnArg, Item, ItemConst, ItemFn, ItemMod, Pat, PatType, Type,
Visibility,
};
pub fn builtins_impl(input: TokenStream) -> TokenStream {
let item_mod = parse_macro_input!(input as ItemMod);
let (_brace, items) = match item_mod.content.clone() {
Some(content) => content,
None => {
return syn::Error::new_spanned(
item_mod,
"`#[builtins]` macro can only be used on an inline module: `mod name { ... }`",
)
.to_compile_error()
.into();
}
};
let mut pub_item_mod: Vec<proc_macro2::TokenStream> = Vec::new();
let mut const_inserters = Vec::new();
let mut global_inserters = Vec::new();
let mut scoped_inserters = Vec::new();
let mut wrappers = Vec::new();
for item in &items {
match item {
Item::Const(item_const) => {
let inserter = generate_const_inserter(item_const);
const_inserters.push(inserter);
pub_item_mod.push(quote! {
pub #item_const
}.into());
}
Item::Fn(item_fn) => {
let (inserter, wrapper) = match generate_fn_wrapper(item_fn) {
Ok(result) => result,
Err(e) => return e.to_compile_error().into(),
};
if matches!(item_fn.vis, Visibility::Public(_)) {
global_inserters.push(inserter);
pub_item_mod.push(quote! { #item_fn }.into());
} else {
scoped_inserters.push(inserter);
pub_item_mod.push(quote! {
pub #item_fn
}.into());
}
wrappers.push(wrapper);
}
item => pub_item_mod.push(item.to_token_stream())
}
}
let output = quote! {
mod builtins {
#(#pub_item_mod)*
#(#wrappers)*
}
pub struct Builtins<Ctx: BuiltinsContext> {
pub consts: ::std::vec::Vec<(String, ::nixjit_value::Const)>,
pub global: ::std::vec::Vec<(String, fn(&mut Ctx, Vec<::nixjit_eval::Value<Ctx>>) -> ::nixjit_error::Result<::nixjit_eval::Value<Ctx>>)>,
pub scoped: ::std::vec::Vec<(String, fn(&mut Ctx, Vec<::nixjit_eval::Value<Ctx>>) -> ::nixjit_error::Result<::nixjit_eval::Value<Ctx>>)>,
}
impl<Ctx: BuiltinsContext> Builtins<Ctx> {
pub fn new() -> Self {
let mut consts = ::std::vec::Vec::new();
let mut global = ::std::vec::Vec::new();
let mut scoped = ::std::vec::Vec::new();
#(#const_inserters)*
#(#global_inserters)*
#(#scoped_inserters)*
Self { consts, global, scoped }
}
}
};
output.into()
}
fn generate_const_inserter(
item_const: &ItemConst,
) -> proc_macro2::TokenStream {
let name_str = item_const.ident.to_string().from_case(Case::UpperSnake).to_case(Case::Camel);
let const_name = &item_const.ident;
quote! {
consts.push((#name_str.to_string(), builtins::#const_name));
}
}
fn generate_fn_wrapper(
item_fn: &ItemFn,
) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
let fn_name = &item_fn.sig.ident;
let name_str = fn_name.to_string().from_case(Case::Snake).to_case(Case::Camel);
let wrapper_name = format_ident!("wrapper_{}", fn_name);
let mod_name = format_ident!("builtins");
let is_pub = matches!(item_fn.vis, Visibility::Public(_));
let mut user_args = item_fn.sig.inputs.iter().peekable();
let has_ctx = if let Some(FnArg::Typed(first_arg)) = user_args.peek() {
if let Type::Reference(_) = *first_arg.ty {
user_args.next();
true
} else {
false
}
} else {
return Err(syn::Error::new_spanned(
fn_name,
"A builtin function must not have a receiver argument",
));
};
let arg_pats: Vec<_> = user_args.rev().collect();
let arg_count = arg_pats.len();
let arg_unpacks = arg_pats.iter().enumerate().map(|(i, arg)| {
let arg_name = match &arg {
FnArg::Typed(PatType { pat, .. }) => {
if let Pat::Ident(pat_ident) = &**pat {
pat_ident.ident.clone()
} else {
format_ident!("arg{}", i, span = Span::call_site())
}
}
_ => format_ident!("arg{}", i, span = Span::call_site()),
};
let arg_ty = match &arg {
FnArg::Typed(PatType { ty, .. }) => ty,
_ => unreachable!(),
};
quote! {
let #arg_name: #arg_ty = args.pop().ok_or_else(|| ::nixjit_error::Error::EvalError("Not enough arguments provided".to_string()))?
.try_into().map_err(|e| ::nixjit_error::Error::EvalError(format!("Argument type conversion failed: {}", e)))?;
}
});
let arg_names: Vec<_> = arg_pats
.iter()
.enumerate()
.map(|(i, arg)| match &arg {
FnArg::Typed(PatType { pat, .. }) => {
if let Pat::Ident(pat_ident) = &**pat {
pat_ident.ident.clone()
} else {
format_ident!("arg{}", i, span = Span::call_site())
}
}
_ => unreachable!(),
})
.rev()
.collect();
let mut call_args = quote! { #(#arg_names),* };
if has_ctx {
call_args = quote! { ctx, #(#arg_names),* };
}
let returns_result = match &item_fn.sig.output {
syn::ReturnType::Type(_, ty) => {
if let Type::Path(type_path) = &**ty {
type_path.path.segments.iter().any(|s| s.ident == "Result")
} else {
false
}
}
_ => false,
};
let call_expr = if returns_result {
quote! { #fn_name(#call_args) }
} else {
quote! { Ok(#fn_name(#call_args).into()) }
};
let fn_type = quote! { fn(&mut Ctx, Vec<::nixjit_eval::Value<Ctx>>) -> ::nixjit_error::Result<::nixjit_eval::Value<Ctx>> };
let inserter = if is_pub {
quote! { global.push((#name_str.to_string(), #mod_name::#wrapper_name as #fn_type)); }
} else {
quote! { scoped.push((#name_str.to_string(), #mod_name::#wrapper_name as #fn_type)); }
};
let wrapper = quote! {
pub fn #wrapper_name<Ctx: BuiltinsContext>(ctx: &mut Ctx, mut args: Vec<::nixjit_eval::Value<Ctx>>) -> ::nixjit_error::Result<::nixjit_eval::Value<Ctx>> {
if args.len() != #arg_count {
return Err(::nixjit_error::Error::EvalError(format!("Function '{}' expects {} arguments, but received {}", #name_str, #arg_count, args.len())));
}
#(#arg_unpacks)*
#call_expr
}
};
Ok((inserter, wrapper))
}