//! Implements the `#[builtins]` procedural macro attribute.
//!
//! This macro simplifies the process of defining built-in functions (primops)
//! for the Nix interpreter. It inspects the functions and constants inside a
//! `mod` block and generates the necessary boilerplate to make them callable
//! from Nix code.
//!
//! Specifically, it generates:
//!
//! 1. A `Builtins` struct containing arrays of constant values and function pointers.
//! 2. A wrapper function for each user-defined function. This wrapper handles:
//!    - Arity (argument count) checking.
//!    - Type conversion from the generic `nixjit_eval::Value` into the
//!      specific types expected by the user's function.
//!    - Calling the user's function with the converted arguments.
//!    - Wrapping the return value back into a `Result<nixjit_eval::Value>`.
|
|
|
|
use convert_case::{Case, Casing};
|
|
use proc_macro::TokenStream;
|
|
use proc_macro2::Span;
|
|
use quote::{ToTokens, format_ident, quote};
|
|
use syn::{FnArg, Item, ItemFn, ItemMod, Pat, PatType, Type, Visibility, parse_macro_input};
|
|
|
|
/// The implementation of the `#[builtins]` macro.
|
|
pub fn builtins_impl(input: TokenStream) -> TokenStream {
|
|
let item_mod = parse_macro_input!(input as ItemMod);
|
|
let mod_name = &item_mod.ident;
|
|
let visibility = &item_mod.vis;
|
|
|
|
let (_brace, items) = match item_mod.content {
|
|
Some(content) => content,
|
|
None => {
|
|
return syn::Error::new_spanned(
|
|
item_mod,
|
|
"`#[builtins]` macro can only be used on an inline module: `mod name { ... }`",
|
|
)
|
|
.to_compile_error()
|
|
.into();
|
|
}
|
|
};
|
|
|
|
let mut pub_item_mod: Vec<proc_macro2::TokenStream> = Vec::new();
|
|
let mut consts = Vec::new();
|
|
let mut global = Vec::new();
|
|
let mut scoped = Vec::new();
|
|
let mut wrappers = Vec::new();
|
|
|
|
// Iterate over the items (functions, consts) in the user's module.
|
|
for item in &items {
|
|
match item {
|
|
Item::Const(item_const) => {
|
|
// Handle `const` definitions. These are exposed as constants in Nix.
|
|
let name_str = item_const
|
|
.ident
|
|
.to_string()
|
|
.from_case(Case::UpperSnake)
|
|
.to_case(Case::Camel);
|
|
let const_name = &item_const.ident;
|
|
consts.push(quote! { (#name_str, builtins::#const_name) });
|
|
pub_item_mod.push(
|
|
quote! {
|
|
pub #item_const
|
|
}
|
|
.into(),
|
|
);
|
|
}
|
|
Item::Fn(item_fn) => {
|
|
// Handle function definitions. These become primops.
|
|
let (primop, wrapper) = match generate_primop_wrapper(item_fn) {
|
|
Ok(result) => result,
|
|
Err(e) => return e.to_compile_error().into(),
|
|
};
|
|
// Public functions are added to the global scope, private ones to a scoped set.
|
|
if matches!(item_fn.vis, Visibility::Public(_)) {
|
|
global.push(primop);
|
|
pub_item_mod.push(quote! { #item_fn }.into());
|
|
} else {
|
|
scoped.push(primop);
|
|
pub_item_mod.push(
|
|
quote! {
|
|
pub #item_fn
|
|
}
|
|
.into(),
|
|
);
|
|
}
|
|
wrappers.push(wrapper);
|
|
}
|
|
// Other items are passed through unchanged.
|
|
item => pub_item_mod.push(item.to_token_stream()),
|
|
}
|
|
}
|
|
|
|
let consts_len = consts.len();
|
|
let global_len = global.len();
|
|
let scoped_len = scoped.len();
|
|
|
|
// Assemble the final generated code.
|
|
let output = quote! {
|
|
// Re-create the user's module, now with generated wrappers.
|
|
#visibility mod #mod_name {
|
|
#(#pub_item_mod)*
|
|
#(#wrappers)*
|
|
pub const CONSTS_LEN: usize = #consts_len;
|
|
pub const GLOBAL_LEN: usize = #global_len;
|
|
pub const SCOPED_LEN: usize = #scoped_len;
|
|
}
|
|
|
|
/// A struct containing all the built-in constants and functions.
|
|
pub struct Builtins<Ctx: BuiltinsContext> {
|
|
/// Constant values available in the global scope.
|
|
pub consts: [(&'static str, ::nixjit_value::Const); #mod_name::CONSTS_LEN],
|
|
/// Global functions available in the global scope.
|
|
pub global: [(&'static str, usize, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::GLOBAL_LEN],
|
|
/// Scoped functions, typically available under the `builtins` attribute set.
|
|
pub scoped: [(&'static str, usize, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::SCOPED_LEN],
|
|
}
|
|
|
|
impl<Ctx: BuiltinsContext> Builtins<Ctx> {
|
|
/// Creates a new instance of the `Builtins` struct.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
consts: [#(#consts,)*],
|
|
global: [#(#global,)*],
|
|
scoped: [#(#scoped,)*],
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
output.into()
|
|
}
|
|
|
|
/// Generates the primop metadata and the wrapper function for a single user-defined function.
|
|
fn generate_primop_wrapper(
|
|
item_fn: &ItemFn,
|
|
) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> {
|
|
let fn_name = &item_fn.sig.ident;
|
|
let name_str = fn_name
|
|
.to_string()
|
|
.from_case(Case::Snake)
|
|
.to_case(Case::Camel);
|
|
let wrapper_name = format_ident!("wrapper_{}", fn_name);
|
|
let mod_name = format_ident!("builtins");
|
|
|
|
let mut user_args = item_fn.sig.inputs.iter().peekable();
|
|
|
|
// Check if the first argument is a context `&mut Ctx`.
|
|
let has_ctx = if let Some(FnArg::Typed(first_arg)) = user_args.peek() {
|
|
if let Type::Reference(_) = *first_arg.ty {
|
|
user_args.next(); // Consume the context argument
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
} else {
|
|
return Err(syn::Error::new_spanned(
|
|
fn_name,
|
|
"A builtin function must not have a receiver argument",
|
|
));
|
|
};
|
|
|
|
// Collect the remaining arguments.
|
|
let arg_pats: Vec<_> = user_args.rev().collect();
|
|
let arg_count = arg_pats.len();
|
|
|
|
// Generate code to unpack and convert arguments from the `Vec<Value>`.
|
|
let arg_unpacks = arg_pats.iter().enumerate().map(|(i, arg)| {
|
|
let arg_name = match &arg {
|
|
FnArg::Typed(PatType { pat, .. }) => {
|
|
if let Pat::Ident(pat_ident) = &**pat {
|
|
pat_ident.ident.clone()
|
|
} else {
|
|
// Create a placeholder name if the pattern is not a simple ident.
|
|
format_ident!("arg{}", i, span = Span::call_site())
|
|
}
|
|
}
|
|
_ => format_ident!("arg{}", i, span = Span::call_site()),
|
|
};
|
|
let arg_ty = match &arg {
|
|
FnArg::Typed(PatType { ty, .. }) => ty,
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
quote! {
|
|
let #arg_name: #arg_ty = args.pop().ok_or_else(|| ::nixjit_error::Error::EvalError("Not enough arguments provided".to_string()))?
|
|
.try_into().map_err(|e| ::nixjit_error::Error::EvalError(format!("Argument type conversion failed: {}", e)))?;
|
|
}
|
|
});
|
|
|
|
// Get the names of the arguments to pass to the user's function.
|
|
let arg_names: Vec<_> = arg_pats
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, arg)| match &arg {
|
|
FnArg::Typed(PatType { pat, .. }) => {
|
|
if let Pat::Ident(pat_ident) = &**pat {
|
|
pat_ident.ident.clone()
|
|
} else {
|
|
format_ident!("arg{}", i, span = Span::call_site())
|
|
}
|
|
}
|
|
_ => unreachable!(),
|
|
})
|
|
.rev()
|
|
.collect();
|
|
|
|
// Construct the argument list for the final call.
|
|
let mut call_args = quote! { #(#arg_names),* };
|
|
if has_ctx {
|
|
call_args = quote! { ctx, #(#arg_names),* };
|
|
}
|
|
|
|
// Check if the user's function already returns a `Result`.
|
|
let returns_result = match &item_fn.sig.output {
|
|
syn::ReturnType::Type(_, ty) => {
|
|
if let Type::Path(type_path) = &**ty {
|
|
type_path.path.segments.iter().any(|s| s.ident == "Result")
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
_ => false,
|
|
};
|
|
|
|
// Wrap the call expression in `Ok(...)` if it doesn't return a `Result`.
|
|
let call_expr = if returns_result {
|
|
quote! { #fn_name(#call_args) }
|
|
} else {
|
|
quote! { Ok(#fn_name(#call_args).into()) }
|
|
};
|
|
|
|
let arity = arg_names.len();
|
|
let fn_type = quote! { fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> };
|
|
|
|
// The primop metadata tuple: (name, arity, wrapper_function_pointer)
|
|
let primop = quote! { (#name_str, #arity, #mod_name::#wrapper_name as #fn_type) };
|
|
|
|
// The generated wrapper function.
|
|
let wrapper = quote! {
|
|
pub fn #wrapper_name<Ctx: BuiltinsContext>(ctx: &mut Ctx, mut args: Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> {
|
|
if args.len() != #arg_count {
|
|
return Err(::nixjit_error::Error::EvalError(format!("Function '{}' expects {} arguments, but received {}", #name_str, #arg_count, args.len())));
|
|
}
|
|
#(#arg_unpacks)*
|
|
|
|
#call_expr
|
|
}
|
|
};
|
|
|
|
Ok((primop, wrapper))
|
|
}
|