diff --git a/Cargo.lock b/Cargo.lock index 83e50db..5107cd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -396,15 +396,8 @@ dependencies = [ name = "nixjit" version = "0.1.0" dependencies = [ - "hashbrown 0.15.4", "nixjit_context", - "nixjit_error", - "nixjit_eval", - "nixjit_hir", - "nixjit_ir", - "nixjit_lir", "nixjit_value", - "rnix", ] [[package]] @@ -425,6 +418,7 @@ dependencies = [ "cranelift-jit", "cranelift-module", "cranelift-native", + "derive_more", "hashbrown 0.15.4", "itertools", "nixjit_builtins", @@ -436,6 +430,8 @@ dependencies = [ "nixjit_lir", "nixjit_value", "petgraph", + "replace_with", + "rnix", ] [[package]] diff --git a/evaluator/nixjit/Cargo.toml b/evaluator/nixjit/Cargo.toml index 4e0db2e..7b49985 100644 --- a/evaluator/nixjit/Cargo.toml +++ b/evaluator/nixjit/Cargo.toml @@ -4,13 +4,5 @@ version = "0.1.0" edition = "2024" [dependencies] -rnix = "0.12" -hashbrown = "0.15" - nixjit_context = { path = "../nixjit_context" } -nixjit_error = { path = "../nixjit_error" } -nixjit_eval = { path = "../nixjit_eval" } -nixjit_hir = { path = "../nixjit_hir" } -nixjit_ir = { path = "../nixjit_ir" } -nixjit_lir = { path = "../nixjit_lir" } nixjit_value = { path = "../nixjit_value" } diff --git a/evaluator/nixjit/src/lib.rs b/evaluator/nixjit/src/lib.rs index 37285d4..f16cea5 100644 --- a/evaluator/nixjit/src/lib.rs +++ b/evaluator/nixjit/src/lib.rs @@ -12,4 +12,3 @@ #[cfg(test)] mod test; - diff --git a/evaluator/nixjit/src/test.rs b/evaluator/nixjit/src/test.rs index eb8a851..5426ab0 100644 --- a/evaluator/nixjit/src/test.rs +++ b/evaluator/nixjit/src/test.rs @@ -1,22 +1,17 @@ #![allow(unused_macros)] - use std::collections::BTreeMap; -use hashbrown::HashSet; use nixjit_context::Context; -use nixjit_eval::EvalContext; -use nixjit_hir::Downgrade; -use nixjit_lir::ResolveContext; use nixjit_value::{AttrSet, Const, List, Symbol, Value}; #[inline] fn test_expr(expr: &str, expected: Value) { println!("{expr}"); - let mut ctx = Context::new(); - let expr = rnix::Root::parse(expr).tree().expr().unwrap().downgrade(&mut ctx).unwrap(); - ctx.resolve(expr).unwrap(); - assert_eq!(ctx.eval(expr).unwrap().to_public(&ctx, &mut HashSet::new()), expected); + assert_eq!( + Context::new().eval(expr).unwrap(), + expected + ); } macro_rules! map { @@ -59,7 +54,7 @@ macro_rules! string { macro_rules! symbol { ($e:expr) => { - Symbol::from($e.to_string()) + Symbol::from($e) }; } @@ -206,7 +201,10 @@ fn test_func() { "(inputs@{ x, y, ... }: x + inputs.y) { x = 1; y = 2; z = 3; }", int!(3), ); - test_expr("let fix = f: let x = f x; in x; in (fix (self: { x = 1; y = self.x + 1; })).y", int!(2)); + test_expr( + "let fix = f: let x = f x; in x; in (fix (self: { x = 1; y = self.x + 1; })).y", + int!(2), + ); } #[test] diff --git a/evaluator/nixjit_builtins/src/lib.rs b/evaluator/nixjit_builtins/src/lib.rs index 85fdad7..60a8ceb 100644 --- a/evaluator/nixjit_builtins/src/lib.rs +++ b/evaluator/nixjit_builtins/src/lib.rs @@ -1,10 +1,9 @@ use nixjit_macros::builtins; -use nixjit_eval::EvalContext; -pub trait BuiltinsContext: EvalContext {} +pub trait BuiltinsContext {} #[builtins] -mod builtins { +pub mod builtins { use nixjit_error::{Error, Result}; use nixjit_eval::Value; use nixjit_value::Const; @@ -15,7 +14,7 @@ mod builtins { const FALSE: Const = Const::Bool(false); const NULL: Const = Const::Null; - fn add(a: Value, b: Value) -> Result> { + fn add(a: Value, b: Value) -> Result { use Value::*; Ok(match (a, b) { (Int(a), Int(b)) => Int(a + b), @@ -28,11 +27,11 @@ mod builtins { }) } - pub fn import(ctx: &mut Ctx, path: Value) -> Result> { + pub fn import(ctx: &mut Ctx, path: Value) -> Result { todo!() } - fn elem_at(list: Value, idx: Value) -> Result> { + fn elem_at(list: Value, idx: Value) -> Result { todo!() } } diff --git a/evaluator/nixjit_context/Cargo.toml b/evaluator/nixjit_context/Cargo.toml index 210d788..033dcea 100644 --- a/evaluator/nixjit_context/Cargo.toml +++ b/evaluator/nixjit_context/Cargo.toml @@ -4,9 +4,12 @@ version = "0.1.0" edition = "2024" [dependencies] +derive_more = { version = "2.0", features = ["full"] } hashbrown = "0.15" itertools = "0.14" petgraph = "0.8" +replace_with = "0.1" +rnix = "0.12" cranelift = "0.122" cranelift-module = "0.122" diff --git a/evaluator/nixjit_context/src/lib.rs b/evaluator/nixjit_context/src/lib.rs index 196079c..a8b1c88 100644 --- a/evaluator/nixjit_context/src/lib.rs +++ b/evaluator/nixjit_context/src/lib.rs @@ -1,50 +1,162 @@ -use core::mem::MaybeUninit; use std::cell::{OnceCell, RefCell}; use std::rc::Rc; +use derive_more::Unwrap; use hashbrown::{HashMap, HashSet}; -use petgraph::algo::toposort; +use itertools::Itertools; use petgraph::graph::{DiGraph, NodeIndex}; +use nixjit_builtins::{ + Builtins, BuiltinsContext, + builtins::{CONSTS_LEN, GLOBAL_LEN, SCOPED_LEN}, +}; use nixjit_error::{Error, Result}; use nixjit_eval::{EvalContext, Evaluate, Value}; -use nixjit_hir::{DowngradeContext, Hir}; -use nixjit_ir::{ArgIdx, ExprId, Param}; +use nixjit_hir::{Downgrade, DowngradeContext, Hir}; +use nixjit_ir::{ArgIdx, Const, ExprId, Param, PrimOp, PrimOpId}; use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext}; use nixjit_jit::{JITCompiler, JITContext, JITFunc}; +use replace_with::replace_with_and_return; -#[derive(Debug)] -struct Frame { - values: Vec>, - left: usize, +enum Scope { + With, + Let(HashMap), + Arg(Option), +} + +#[derive(Debug, Unwrap)] +enum Ir { + Hir(Hir), + Lir(Lir), +} + +impl Ir { + unsafe fn unwrap_hir_ref_unchecked(&self) -> &Hir { + if let Self::Hir(hir) = self { + hir + } else { + unsafe { core::hint::unreachable_unchecked() } + } + } + + unsafe fn unwrap_hir_mut_unchecked(&mut self) -> &mut Hir { + #[cfg(debug_assertions)] + if let Self::Hir(hir) = self { + hir + } else { + unsafe { core::hint::unreachable_unchecked() } + } + #[cfg(not(debug_assertions))] + if let Self::Hir(hir) = self { + hir + } else { + unsafe { core::hint::unreachable_unchecked() } + } + } + + unsafe fn unwrap_hir_unchecked(self) -> Hir { + if cfg!(debug_assertions) { + self.unwrap_hir() + } else if let Self::Hir(hir) = self { + hir + } else { + unsafe { core::hint::unreachable_unchecked() } + } + } + + unsafe fn unwrap_lir_ref_unchecked(&self) -> &Lir { + #[cfg(debug_assertions)] + if let Self::Lir(lir) = self { + lir + } else { + unsafe { core::hint::unreachable_unchecked() } + } + #[cfg(not(debug_assertions))] + if let Self::Lir(lir) = self { + lir + } else { + panic!() + } + } } -#[derive(Default)] pub struct Context { - hirs: Vec>, - lirs: Vec>, + irs: Vec>, resolved: Vec, - scopes: Vec>, + scopes: Vec, + args_count: usize, + primops: Vec) -> Result>, funcs: HashMap, graph: DiGraph, nodes: Vec, - stack: Vec, - with_scopes: Vec>>>, + stack: Vec>, + with_scopes: Vec>>, jit: JITCompiler, compiled: Vec>>, } -impl Drop for Context { - fn drop(&mut self) { - for (i, lir) in self.lirs.iter_mut().enumerate() { - if self.resolved[i] { - unsafe { - lir.assume_init_drop(); - } - } +impl Default for Context { + fn default() -> Self { + let Builtins { + consts, + global, + scoped, + } = Builtins::new(); + let global_scope = Scope::Let( + consts + .iter() + .enumerate() + .map(|(id, (k, _))| (k.to_string(), unsafe { ExprId::from(id) })) + .chain(global.iter().enumerate().map(|(idx, (k, _, _))| { + (k.to_string(), unsafe { + ExprId::from(idx + CONSTS_LEN) + }) + })) + .chain(core::iter::once(("builtins".to_string(), unsafe { + ExprId::from(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN) + }))) + .collect(), + ); + let primops = global + .iter() + .map(|&(_, _, f)| f) + .chain(scoped.iter().map(|&(_, _, f)| f)) + .collect(); + let irs = consts + .into_iter() + .map(|(_, val)| Ir::Lir(Lir::Const(Const { val }))) + .chain( + global + .into_iter() + .enumerate() + .map(|(idx, (name, arity, _))| { + Ir::Lir(Lir::PrimOp(PrimOp { + name, + arity, + id: unsafe { PrimOpId::from(idx) }, + })) + }), + ) + .map(RefCell::new) + .collect(); + Self { + irs, + resolved: Vec::new(), + scopes: vec![global_scope], + args_count: 0, + primops, + funcs: HashMap::new(), + graph: DiGraph::new(), + nodes: Vec::new(), + + stack: Vec::new(), + with_scopes: Vec::new(), + + jit: JITCompiler::new(), + compiled: Vec::new(), } } } @@ -53,114 +165,166 @@ impl Context { pub fn new() -> Self { Self::default() } + + pub fn eval(mut self, expr: &str) -> Result { + let root = rnix::Root::parse(expr); + if !root.errors().is_empty() { + return Err(Error::ParseError( + root.errors().iter().map(|err| err.to_string()).join(";"), + )); + } + let root = root.tree().expr().unwrap().downgrade(&mut self)?; + self.resolve(&root)?; + Ok(EvalContext::eval(&mut self, &root)?.to_public(&mut HashSet::new())) + } } impl DowngradeContext for Context { fn new_expr(&mut self, expr: Hir) -> ExprId { - let id = ExprId::from(self.hirs.len()); - self.hirs.push(expr.into()); - self.lirs.push(MaybeUninit::uninit()); - self.nodes.push(self.graph.add_node(id)); + let id = unsafe { ExprId::from(self.irs.len()) }; + self.irs.push(Ir::Hir(expr).into()); + self.nodes.push(self.graph.add_node(unsafe { id.clone() })); self.resolved.push(false); self.compiled.push(OnceCell::new()); id } fn with_expr(&self, id: ExprId, f: impl FnOnce(&Hir, &Self) -> T) -> T { - let idx = usize::from(id); - f(&self.hirs[idx].borrow(), self) - } - fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T { - let idx = usize::from(id); unsafe { + let idx = id.raw(); + f(&self.irs[idx].borrow().unwrap_hir_ref_unchecked(), self) + } + } + fn with_expr_mut(&mut self, id: &ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T { + unsafe { + let idx = id.clone().raw(); let self_mut = &mut *(self as *mut Self); - f(&mut self.hirs.get_unchecked_mut(idx).borrow_mut(), self_mut) + f( + &mut self + .irs + .get_unchecked_mut(idx) + .borrow_mut() + .unwrap_hir_mut_unchecked(), + self_mut, + ) } } } impl ResolveContext for Context { - fn lookup(&self, name: &str) -> nixjit_lir::LookupResult { + fn lookup(&self, name: &str) -> LookupResult { + let mut arg_idx = 0; for scope in self.scopes.iter().rev() { - if let Some(val) = scope.get(name) { - return val.clone(); + match scope { + Scope::Let(scope) => { + if let Some(expr) = scope.get(name) { + return LookupResult::Expr(unsafe { expr.clone() }); + } + } + Scope::Arg(ident) => { + if ident.as_deref() == Some(name) { + return LookupResult::Arg(unsafe { ArgIdx::from(arg_idx) }); + } + arg_idx += 1; + } + Scope::With => return LookupResult::Unknown, } } LookupResult::NotFound } - fn new_dep(&mut self, expr: ExprId, dep: ExprId) { - let expr = *self.nodes.get(usize::from(expr)).unwrap(); - let dep = *self.nodes.get(usize::from(dep)).unwrap(); - self.graph.add_edge(expr, dep, ()); + fn new_dep(&mut self, expr: &ExprId, dep: ExprId) { + unsafe { + let expr = expr.clone().raw(); + let dep = dep.raw(); + let expr = *self.nodes.get_unchecked(expr); + let dep = *self.nodes.get_unchecked(dep); + self.graph.add_edge(expr, dep, ()); + } } - fn resolve(&mut self, expr: ExprId) -> Result<()> { - let idx = usize::from(expr); - if self.resolved[idx] { - return Ok(()); + fn resolve(&mut self, expr: &ExprId) -> Result<()> { + unsafe { + let idx = expr.clone().raw(); + let self_mut = &mut *(self as *mut Self); + replace_with_and_return( + &mut *self.irs.get_unchecked(idx).borrow_mut(), + || { + Ir::Hir(Hir::Const(Const { + val: nixjit_value::Const::Null, + })) + }, + |ir| { + let hir = ir.unwrap_hir_unchecked(); + match hir.resolve(self_mut) { + Ok(lir) => (Ok(()), Ir::Lir(lir)), + Err(err) => ( + Err(err), + Ir::Hir(Hir::Const(Const { + val: nixjit_value::Const::Null, + })), + ), + } + }, + )?; } - - let hir = self.hirs[idx].replace(nixjit_hir::Hir::Const(nixjit_ir::Const::from(false))); - let lir = hir.resolve(self)?; - self.lirs[idx].write(lir); - self.resolved[idx] = true; Ok(()) } - fn new_func(&mut self, body: ExprId, param: Param) { - self.funcs.insert(body, param); + fn new_func(&mut self, body: &ExprId, param: Param) { + self.funcs.insert(unsafe { body.clone() }, param); } fn with_let_env<'a, T>( &mut self, - bindings: impl IntoIterator, + bindings: impl Iterator, f: impl FnOnce(&mut Self) -> T, ) -> T { let mut scope = HashMap::new(); for (name, expr) in bindings { - scope.insert(name.clone(), LookupResult::Expr(*expr)); + scope.insert(name.clone(), unsafe { expr.clone() }); } - self.scopes.push(scope); + self.scopes.push(Scope::Let(scope)); let res = f(self); self.scopes.pop(); res } fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) { - self.scopes.push(HashMap::new()); + self.scopes.push(Scope::With); let res = f(self); self.scopes.pop(); (true, res) } - fn with_param_env<'a, T>( - &mut self, - ident: Option<&'a str>, - f: impl FnOnce(&mut Self) -> T, - ) -> T { - let mut scope = HashMap::new(); - if let Some(ident) = ident { - scope.insert(ident.to_string(), LookupResult::Arg(ArgIdx::from(0))); - } - self.scopes.push(scope); + fn with_param_env(&mut self, ident: Option, f: impl FnOnce(&mut Self) -> T) -> T { + self.scopes.push(Scope::Arg(ident)); + self.args_count += 1; let res = f(self); + self.args_count -= 1; self.scopes.pop(); res } } impl EvalContext for Context { - fn eval(&mut self, expr: ExprId) -> Result> { - let idx = usize::from(expr); - let lir = unsafe { &*(self.lirs[idx].assume_init_ref() as *const Lir) }; + fn eval(&mut self, expr: &ExprId) -> Result { + let idx = unsafe { expr.clone().raw() }; + let lir = unsafe { + &*(self + .irs + .get_unchecked(idx) + .borrow() + .unwrap_lir_ref_unchecked() as *const Lir) + }; + println!("{:#?}", self.irs); lir.eval(self) } - fn pop_frame(&mut self) -> Vec> { - self.stack.pop().unwrap().values + fn pop_frame(&mut self) -> Vec { + self.stack.pop().unwrap() } - fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a nixjit_eval::Value> { + fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a nixjit_eval::Value> { for scope in self.with_scopes.iter().rev() { if let Some(val) = scope.get(ident) { return Some(val); @@ -169,13 +333,17 @@ impl EvalContext for Context { None } - fn lookup_arg<'a>(&'a self, offset: usize) -> Option<&'a Value> { - self.stack.last().and_then(|frame| frame.values.get(offset)) + fn lookup_arg<'a>(&'a self, idx: ArgIdx) -> &'a Value { + unsafe { + let values = self.stack.last().unwrap_unchecked(); + dbg!(values, idx); + &values[values.len() - idx.raw() - 1] + } } fn with_with_env( &mut self, - namespace: std::rc::Rc>>, + namespace: std::rc::Rc>, f: impl FnOnce(&mut Self) -> T, ) -> T { self.with_scopes.push(namespace); @@ -186,53 +354,33 @@ impl EvalContext for Context { fn with_args_env( &mut self, - args: Vec>, + args: Vec, f: impl FnOnce(&mut Self) -> T, - ) -> (Vec>, T) { - self.stack.push(Frame { - left: args.len(), - values: args, - }); + ) -> (Vec, T) { + self.stack.push(args); let res = f(self); - (self.stack.pop().unwrap().values, res) + let frame = self.stack.pop().unwrap(); + (frame, res) } - fn consume_arg(&mut self, func: ExprId) -> Result { - let Some(frame) = self.stack.last_mut() else { - return Ok(false); - }; - if frame.left == 0 { - return Ok(false); + fn call_primop(&mut self, id: nixjit_ir::PrimOpId, args: Vec) -> Result { + unsafe { + (self.primops.get_unchecked(id.raw()))(self, args) } - frame.left -= 1; - let param = self.funcs.get(&func).unwrap(); - if let Some(required) = ¶m.required { - let attrs = frame.values[frame.values.len() - frame.left - 1] - .as_ref() - .try_unwrap_attr_set() - .map_err(|_| Error::EvalError(format!("expected a set but found ...")))?; - if required.iter().any(|attr| attrs.get(attr).is_none()) - || param.allowed.as_ref().map_or(false, |allowed| { - attrs.iter().any(|(attr, _)| allowed.get(attr).is_none()) - }) - { - return Err(Error::EvalError(format!("TODO"))); - } - } - Ok(true) } } impl JITContext for Context { - fn lookup_arg(&self, offset: usize) -> &nixjit_eval::Value { - &self.stack.last().unwrap().values[offset] + fn lookup_arg(&self, offset: usize) -> &nixjit_eval::Value { + let values = self.stack.last().unwrap(); + &values[values.len() - offset - 1] } - fn lookup_stack(&self, offset: usize) -> &nixjit_eval::Value { + fn lookup_stack(&self, offset: usize) -> &nixjit_eval::Value { todo!() } - fn enter_with(&mut self, namespace: std::rc::Rc>>) { + fn enter_with(&mut self, namespace: std::rc::Rc>) { self.with_scopes.push(namespace); } @@ -240,3 +388,5 @@ impl JITContext for Context { self.with_scopes.pop(); } } + +impl BuiltinsContext for Context {} diff --git a/evaluator/nixjit_eval/src/lib.rs b/evaluator/nixjit_eval/src/lib.rs index bf56d2c..354dbab 100644 --- a/evaluator/nixjit_eval/src/lib.rs +++ b/evaluator/nixjit_eval/src/lib.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use hashbrown::HashMap; use nixjit_error::{Error, Result}; -use nixjit_ir::{self as ir, ExprId}; +use nixjit_ir::{self as ir, ArgIdx, ExprId, PrimOpId}; use nixjit_lir as lir; use nixjit_value::{Const, Symbol}; @@ -12,35 +12,31 @@ pub use crate::value::*; mod value; pub trait EvalContext: Sized { - fn eval(&mut self, expr: ExprId) -> Result>; + fn eval(&mut self, expr: &ExprId) -> Result; fn with_with_env( &mut self, - namespace: Rc>>, + namespace: Rc>, f: impl FnOnce(&mut Self) -> T, ) -> T; - fn with_args_env( - &mut self, - args: Vec>, - f: impl FnOnce(&mut Self) -> T, - ) -> (Vec>, T); - fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a Value>; - fn lookup_arg<'a>(&'a self, offset: usize) -> Option<&'a Value>; - fn pop_frame(&mut self) -> Vec>; - fn consume_arg(&mut self, func: ExprId) -> Result; + fn with_args_env(&mut self, args: Vec, f: impl FnOnce(&mut Self) -> T) -> (Vec, T); + fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a Value>; + fn lookup_arg<'a>(&'a self, idx: ArgIdx) -> &'a Value; + fn pop_frame(&mut self) -> Vec; + fn call_primop(&mut self, id: PrimOpId, args: Vec) -> Result; } pub trait Evaluate { - fn eval(&self, ctx: &mut Ctx) -> Result>; + fn eval(&self, ctx: &mut Ctx) -> Result; } impl Evaluate for ExprId { - fn eval(&self, ctx: &mut Ctx) -> Result> { - ctx.eval(*self) + fn eval(&self, ctx: &mut Ctx) -> Result { + ctx.eval(self) } } impl Evaluate for lir::Lir { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { use lir::Lir::*; match self { AttrSet(x) => x.eval(ctx), @@ -58,26 +54,16 @@ impl Evaluate for lir::Lir { Str(x) => x.eval(ctx), Var(x) => x.eval(ctx), Path(x) => x.eval(ctx), - ExprRef(expr) => expr.eval(ctx), - &FuncRef(func) => { - if ctx.consume_arg(func)? { - ctx.eval(func) - } else { - Ok(Value::Func(func)) - } - }, - &ArgRef(arg) => { - let idx: usize = unsafe { core::mem::transmute(arg) }; - ctx.lookup_arg(idx) - .cloned() - .ok_or_else(|| Error::EvalError("argument not found".to_string())) - } + ExprRef(expr) => ctx.eval(expr), + FuncRef(func) => Ok(Value::Func(unsafe { func.clone() })), + &ArgRef(idx) => Ok(ctx.lookup_arg(idx).clone()), + &PrimOp(primop) => Ok(Value::PrimOp(primop)), } } } impl Evaluate for ir::AttrSet { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { let mut attrs = AttrSet::new( self.stcs .iter() @@ -99,7 +85,7 @@ impl Evaluate for ir::AttrSet { } impl Evaluate for ir::List { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { let items = self .items .iter() @@ -111,7 +97,7 @@ impl Evaluate for ir::List { } impl Evaluate for ir::HasAttr { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { use ir::Attr::*; let mut val = self.lhs.eval(ctx)?; val.has_attr(self.rhs.iter().map(|attr| { @@ -130,7 +116,7 @@ impl Evaluate for ir::HasAttr { } impl Evaluate for ir::BinOp { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { use ir::BinOpKind::*; let mut lhs = self.lhs.eval(ctx)?; let mut rhs = self.rhs.eval(ctx)?; @@ -169,9 +155,9 @@ impl Evaluate for ir::BinOp { } Con => lhs.concat(rhs), Upd => lhs.update(rhs), - PipeL => lhs.call(vec![rhs], ctx)?, + PipeL => lhs.call(core::iter::once(Ok(rhs)), ctx)?, PipeR => { - rhs.call(vec![lhs], ctx)?; + rhs.call(core::iter::once(Ok(lhs)), ctx)?; lhs = rhs; } } @@ -180,7 +166,7 @@ impl Evaluate for ir::BinOp { } impl Evaluate for ir::UnOp { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { use ir::UnOpKind::*; let mut rhs = self.rhs.eval(ctx)?; match self.kind { @@ -196,7 +182,7 @@ impl Evaluate for ir::UnOp { } impl Evaluate for ir::Select { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { use ir::Attr::*; let mut val = self.expr.eval(ctx)?; if let Some(default) = &self.default { @@ -232,7 +218,7 @@ impl Evaluate for ir::Select { } impl Evaluate for ir::If { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { // TODO: Error Handling let cond = self.cond.eval(ctx)?; let cond = cond @@ -248,21 +234,22 @@ impl Evaluate for ir::If { } impl Evaluate for ir::Call { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { let mut func = self.func.eval(ctx)?; + // FIXME: ? + let ctx_mut = unsafe { &mut *(ctx as *mut Ctx) }; func.call( self.args .iter() - .map(|arg| arg.eval(ctx)) - .collect::>()?, - ctx, + .map(|arg| arg.eval(ctx)), + ctx_mut, )?; Ok(func.ok().unwrap()) } } impl Evaluate for ir::With { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { let namespace = self.namespace.eval(ctx)?; ctx.with_with_env( namespace @@ -275,7 +262,7 @@ impl Evaluate for ir::With { } impl Evaluate for ir::Assert { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { let cond = self.assertion.eval(ctx)?; let cond = cond .try_unwrap_bool() @@ -289,7 +276,7 @@ impl Evaluate for ir::Assert { } impl Evaluate for ir::ConcatStrings { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { let mut parts = self .parts .iter() @@ -310,14 +297,14 @@ impl Evaluate for ir::ConcatStrings { } impl Evaluate for ir::Str { - fn eval(&self, _: &mut Ctx) -> Result> { + fn eval(&self, _: &mut Ctx) -> Result { let result = Value::String(self.val.clone()).ok(); Ok(result.unwrap()) } } impl Evaluate for ir::Const { - fn eval(&self, _: &mut Ctx) -> Result> { + fn eval(&self, _: &mut Ctx) -> Result { let result = match self.val { Const::Null => Value::Null, Const::Int(x) => Value::Int(x), @@ -330,7 +317,7 @@ impl Evaluate for ir::Const { } impl Evaluate for ir::Var { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { ctx.lookup_with(&self.sym) .ok_or_else(|| { Error::EvalError(format!( @@ -343,7 +330,7 @@ impl Evaluate for ir::Var { } impl Evaluate for ir::Path { - fn eval(&self, ctx: &mut Ctx) -> Result> { + fn eval(&self, ctx: &mut Ctx) -> Result { todo!() } } diff --git a/evaluator/nixjit_eval/src/value/attrset.rs b/evaluator/nixjit_eval/src/value/attrset.rs index 2447f09..e0fd691 100644 --- a/evaluator/nixjit_eval/src/value/attrset.rs +++ b/evaluator/nixjit_eval/src/value/attrset.rs @@ -1,6 +1,6 @@ use core::ops::Deref; -use std::rc::Rc; use std::fmt::Debug; +use std::rc::Rc; use derive_more::Constructor; use hashbrown::{HashMap, HashSet}; @@ -15,11 +15,11 @@ use crate::EvalContext; #[repr(transparent)] #[derive(Constructor, PartialEq)] -pub struct AttrSet { - data: HashMap>, +pub struct AttrSet { + data: HashMap, } -impl Debug for AttrSet { +impl Debug for AttrSet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use Value::*; write!(f, "{{ ")?; @@ -34,7 +34,7 @@ impl Debug for AttrSet { } } -impl Clone for AttrSet { +impl Clone for AttrSet { fn clone(&self) -> Self { AttrSet { data: self.data.clone(), @@ -42,31 +42,31 @@ impl Clone for AttrSet { } } -impl From>> for AttrSet { - fn from(data: HashMap>) -> Self { +impl From> for AttrSet { + fn from(data: HashMap) -> Self { Self { data } } } -impl Deref for AttrSet { - type Target = HashMap>; +impl Deref for AttrSet { + type Target = HashMap; fn deref(&self) -> &Self::Target { &self.data } } -impl AttrSet { +impl AttrSet { pub fn with_capacity(cap: usize) -> Self { AttrSet { data: HashMap::with_capacity(cap), } } - pub fn push_attr_force(&mut self, sym: String, val: Value) { + pub fn push_attr_force(&mut self, sym: String, val: Value) { self.data.insert(sym, val); } - pub fn push_attr(&mut self, sym: String, val: Value) { + pub fn push_attr(&mut self, sym: String, val: Value) { if self.data.get(&sym).is_some() { todo!() } @@ -76,7 +76,7 @@ impl AttrSet { pub fn select( &self, mut path: impl DoubleEndedIterator>, - ) -> Result> { + ) -> Result { let mut data = &self.data; let last = path.nth_back(0).unwrap(); for item in path { @@ -116,15 +116,15 @@ impl AttrSet { } } - pub fn as_inner(&self) -> &HashMap> { + pub fn as_inner(&self) -> &HashMap { &self.data } - pub fn into_inner(self: Rc) -> Rc>> { + pub fn into_inner(self: Rc) -> Rc> { unsafe { core::mem::transmute(self) } } - pub fn from_inner(data: HashMap>) -> Self { + pub fn from_inner(data: HashMap) -> Self { Self { data } } @@ -137,11 +137,11 @@ impl AttrSet { .all(|((k1, v1), (k2, v2))| k1 == k2 && v1.eq_impl(v2)) } - pub fn to_public(&self, ctx: &Ctx, seen: &mut HashSet>) -> p::Value { + pub fn to_public(&self, seen: &mut HashSet) -> p::Value { p::Value::AttrSet(p::AttrSet::new( self.data .iter() - .map(|(sym, value)| (sym.as_str().into(), value.to_public(ctx, seen))) + .map(|(sym, value)| (sym.as_str().into(), value.to_public(seen))) .collect(), )) } diff --git a/evaluator/nixjit_eval/src/value/func.rs b/evaluator/nixjit_eval/src/value/func.rs index cf6886f..ffc9966 100644 --- a/evaluator/nixjit_eval/src/value/func.rs +++ b/evaluator/nixjit_eval/src/value/func.rs @@ -9,39 +9,59 @@ use super::Value; use crate::EvalContext; #[derive(Debug, Constructor)] -pub struct FuncApp { +pub struct FuncApp { pub body: ExprId, - pub args: Vec>, - pub frame: Vec>, + pub args: Vec, + pub frame: Vec, } -impl Clone for FuncApp { +impl Clone for FuncApp { fn clone(&self) -> Self { Self { - body: self.body, + body: unsafe { self.body.clone() }, args: self.args.clone(), frame: self.frame.clone(), } } } -impl FuncApp { - pub fn call( +impl FuncApp { + pub fn call( self: &mut Rc, - new_args: Vec>, + mut iter: impl Iterator> + ExactSizeIterator, ctx: &mut Ctx, - ) -> Result> { - let FuncApp { body: expr, args, frame } = Rc::make_mut(self); - args.extend(new_args); - let (args, ret) = ctx.with_args_env(core::mem::take(args), |ctx| ctx.eval(*expr)); - let mut ret = ret?; - if let Value::Func(expr) = ret { - let frame = ctx.pop_frame(); - ret = Value::FuncApp(FuncApp::new(expr, args, frame).into()); - } else if let Value::FuncApp(func) = &mut ret { - todo!(); - let func = Rc::make_mut(func); + ) -> Result { + let FuncApp { + body: expr, + args, + frame, + } = Rc::make_mut(self); + let mut val; + let mut args = core::mem::take(args); + args.push(iter.next().unwrap()?); + let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(expr)); + args = ret_args; + val = ret?; + loop { + if !matches!(val, Value::Func(_) | Value::FuncApp(_)) { + break; + } + let Some(arg) = iter.next() else { + break; + }; + args.push(arg?); + if let Value::Func(expr) = val { + let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(&expr)); + args = ret_args; + val = ret?; + } else if let Value::FuncApp(func) = val { + let mut func = Rc::unwrap_or_clone(func); + func.args.push(args.pop().unwrap()); + let (ret_args, ret) = ctx.with_args_env(func.args, |ctx| ctx.eval(&func.body)); + args = ret_args; + val = ret?; + } } - ret.ok() + val.ok() } } diff --git a/evaluator/nixjit_eval/src/value/list.rs b/evaluator/nixjit_eval/src/value/list.rs index 7a17937..7dee374 100644 --- a/evaluator/nixjit_eval/src/value/list.rs +++ b/evaluator/nixjit_eval/src/value/list.rs @@ -1,5 +1,5 @@ -use std::ops::Deref; use std::fmt::Debug; +use std::ops::Deref; use hashbrown::HashSet; @@ -10,11 +10,11 @@ use super::Value; use crate::EvalContext; #[derive(Default)] -pub struct List { - data: Vec>, +pub struct List { + data: Vec, } -impl Debug for List { +impl Debug for List { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[ ")?; for v in self.data.iter() { @@ -24,7 +24,7 @@ impl Debug for List { } } -impl Clone for List { +impl Clone for List { fn clone(&self) -> Self { Self { data: self.data.clone(), @@ -32,20 +32,20 @@ impl Clone for List { } } -impl>>> From for List { +impl>> From for List { fn from(value: T) -> Self { Self { data: value.into() } } } -impl Deref for List { - type Target = [Value]; +impl Deref for List { + type Target = [Value]; fn deref(&self) -> &Self::Target { &self.data } } -impl List { +impl List { pub fn new() -> Self { List { data: Vec::new() } } @@ -56,7 +56,7 @@ impl List { } } - pub fn push(&mut self, elem: Value) { + pub fn push(&mut self, elem: Value) { self.data.push(elem); } @@ -66,7 +66,7 @@ impl List { } } - pub fn into_inner(self) -> Vec> { + pub fn into_inner(self) -> Vec { self.data } @@ -75,11 +75,11 @@ impl List { && core::iter::zip(self.iter(), other.iter()).all(|(a, b)| a.eq_impl(b)) } - pub fn to_public(&self, engine: &Ctx, seen: &mut HashSet>) -> PubValue { + pub fn to_public(&self, seen: &mut HashSet) -> PubValue { PubValue::List(PubList::new( self.data .iter() - .map(|value| value.clone().to_public(engine, seen)) + .map(|value| value.clone().to_public(seen)) .collect(), )) } diff --git a/evaluator/nixjit_eval/src/value/mod.rs b/evaluator/nixjit_eval/src/value/mod.rs index c937b78..8cf3997 100644 --- a/evaluator/nixjit_eval/src/value/mod.rs +++ b/evaluator/nixjit_eval/src/value/mod.rs @@ -1,14 +1,13 @@ +use std::fmt::Debug; use std::hash::Hash; use std::process::abort; use std::rc::Rc; -use std::fmt::{write, Debug}; use derive_more::TryUnwrap; use derive_more::{IsVariant, Unwrap}; -use func::FuncApp; use hashbrown::HashSet; -use nixjit_ir::ExprId; -use replace_with::{replace_with_and_return, replace_with_or_abort}; +use nixjit_ir::{ExprId, PrimOp}; +use replace_with::replace_with_and_return; use nixjit_error::{Error, Result}; use nixjit_value::Const; @@ -23,28 +22,29 @@ mod primop; mod string; pub use attrset::*; +pub use func::*; pub use list::List; pub use primop::*; #[repr(C, u64)] #[derive(IsVariant, TryUnwrap, Unwrap)] -pub enum Value { +pub enum Value { Int(i64), Float(f64), Bool(bool), String(String), Null, Thunk(usize), - AttrSet(Rc>), - List(Rc>), + AttrSet(Rc), + List(Rc), Catchable(String), - PrimOp(Rc>), - PrimOpApp(Rc>), + PrimOp(PrimOp), + PrimOpApp(Rc), Func(ExprId), - FuncApp(Rc>), + FuncApp(Rc), } -impl Debug for Value { +impl Debug for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use Value::*; match self { @@ -65,28 +65,28 @@ impl Debug for Value { } } -impl Clone for Value { +impl Clone for Value { fn clone(&self) -> Self { use Value::*; match self { AttrSet(attrs) => AttrSet(attrs.clone()), List(list) => List(list.clone()), Catchable(catchable) => Catchable(catchable.clone()), - Int(x) => Int(*x), - Float(x) => Float(*x), - Bool(x) => Bool(*x), + &Int(x) => Int(x), + &Float(x) => Float(x), + &Bool(x) => Bool(x), String(x) => String(x.clone()), Null => Null, - Thunk(expr) => Thunk(*expr), - PrimOp(primop) => PrimOp(primop.clone()), + &Thunk(expr) => Thunk(expr), + &PrimOp(primop) => PrimOp(primop), PrimOpApp(primop) => PrimOpApp(primop.clone()), - Func(expr) => Func(*expr), + Func(expr) => Func(unsafe { expr.clone() }), FuncApp(func) => FuncApp(func.clone()), } } } -impl Hash for Value { +impl Hash for Value { fn hash(&self, state: &mut H) { use Value::*; std::mem::discriminant(self).hash(state); @@ -98,7 +98,7 @@ impl Hash for Value { } } -impl Value { +impl Value { pub const INT: u64 = 0; pub const FLOAT: u64 = 1; pub const BOOL: u64 = 2; @@ -130,7 +130,7 @@ impl Value { } } -impl PartialEq for Value { +impl PartialEq for Value { fn eq(&self, other: &Self) -> bool { use Value::*; match (self, other) { @@ -141,27 +141,27 @@ impl PartialEq for Value { } } -impl Eq for Value {} +impl Eq for Value {} #[derive(IsVariant, TryUnwrap, Unwrap, Clone)] -pub enum ValueAsRef<'v, Ctx: EvalContext> { +pub enum ValueAsRef<'v> { Int(i64), Float(f64), Bool(bool), String(&'v String), Null, Thunk(usize), - AttrSet(&'v AttrSet), - List(&'v List), + AttrSet(&'v AttrSet), + List(&'v List), Catchable(&'v str), - PrimOp(&'v PrimOp), - PartialPrimOp(&'v PrimOpApp), - Func(ExprId), - PartialFunc(&'v FuncApp), + PrimOp(&'v PrimOp), + PartialPrimOp(&'v PrimOpApp), + Func(&'v ExprId), + PartialFunc(&'v FuncApp), } -impl Value { - pub fn as_ref(&self) -> ValueAsRef<'_, Ctx> { +impl Value { + pub fn as_ref(&self) -> ValueAsRef<'_> { use Value::*; use ValueAsRef as R; match self { @@ -176,12 +176,12 @@ impl Value { Catchable(x) => R::Catchable(x), PrimOp(x) => R::PrimOp(x), PrimOpApp(x) => R::PartialPrimOp(x), - Func(x) => R::Func(*x), + Func(x) => R::Func(x), FuncApp(x) => R::PartialFunc(x), } } } -impl Value { +impl Value { pub fn ok(self) -> Result { Ok(self) } @@ -213,32 +213,58 @@ impl Value { } } - pub fn call(&mut self, args: Vec, ctx: &mut Ctx) -> Result<()> { + pub fn call(&mut self, mut iter: impl Iterator> + ExactSizeIterator, ctx: &mut Ctx) -> Result<()> { use Value::*; - for arg in args.iter() { - if matches!(arg, Value::Catchable(_)) { - *self = arg.clone(); - return Ok(()); - } - } *self = match self { - PrimOp(func) => func.call(args, ctx), - PrimOpApp(func) => func.call(args, ctx), - FuncApp(func) => func.call(args, ctx), - &mut Func(expr) => { - let (args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(expr)); - let mut ret = ret?; - if let Value::Func(expr) = ret { - let frame = ctx.pop_frame(); - ret = Value::FuncApp(self::FuncApp::new(expr, args, frame).into()); - } else if let Value::FuncApp(func) = &mut ret { - todo!(); - let func = Rc::make_mut(func); + &mut PrimOp(func) => { + if iter.len() > func.arity { + todo!() + } + if func.arity > iter.len() { + Value::PrimOpApp(Rc::new(self::PrimOpApp::new( + func.name, + func.arity - iter.len(), + func.id, + iter.collect::>()?, + ))) + .ok() + } else { + ctx.call_primop(func.id, iter.collect::>()?) } - ret.ok() } + Func(expr) => { + let mut val; + let mut args = Vec::with_capacity(iter.len()); + args.push(iter.next().unwrap()?); + let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(expr)); + args = ret_args; + val = ret?; + loop { + if !matches!(val, Value::Func(_) | Value::FuncApp(_)) { + break; + } + let Some(arg) = iter.next() else { + break; + }; + args.push(arg?); + if let Value::Func(expr) = val { + let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(&expr)); + args = ret_args; + val = ret?; + } else if let Value::FuncApp(func) = val { + let mut func = Rc::unwrap_or_clone(func); + func.args.push(args.pop().unwrap()); + let (ret_args, ret) = ctx.with_args_env(func.args, |ctx| ctx.eval(&func.body)); + args = ret_args; + val = ret?; + } + } + val.ok() + } + PrimOpApp(func) => func.call(iter.collect::>()?, ctx), + FuncApp(func) => func.call(iter, ctx), Catchable(_) => return Ok(()), - other => todo!("{}", other.typename()), + _ => Err(Error::EvalError("attempt to call something which is not a function but ...".to_string())) }?; Ok(()) } @@ -310,22 +336,26 @@ impl Value { pub fn add(&mut self, other: Self) -> Result<()> { use Value::*; - replace_with_and_return(self, || abort(), |a| { - let val = match (a, other) { - (Int(a), Int(b)) => Int(a + b), - (Int(a), Float(b)) => Float(a as f64 + b), - (Float(a), Int(b)) => Float(a + b as f64), - (Float(a), Float(b)) => Float(a + b), - (String(mut a), String(b)) => { - a.push_str(&b); - String(a) - } - (a @ Value::Catchable(_), _) => a, - (_, x @ Value::Catchable(_)) => x, - _ => return (Err(Error::EvalError(format!(""))), Value::Null), - }; - (Ok(()), val) - }) + replace_with_and_return( + self, + || abort(), + |a| { + let val = match (a, other) { + (Int(a), Int(b)) => Int(a + b), + (Int(a), Float(b)) => Float(a as f64 + b), + (Float(a), Int(b)) => Float(a + b as f64), + (Float(a), Float(b)) => Float(a + b), + (String(mut a), String(b)) => { + a.push_str(&b); + String(a) + } + (a @ Value::Catchable(_), _) => a, + (_, x @ Value::Catchable(_)) => x, + _ => return (Err(Error::EvalError(format!(""))), Value::Null), + }; + (Ok(()), val) + }, + ) } pub fn mul(&mut self, other: Self) { @@ -490,7 +520,7 @@ impl Value { self } - pub fn to_public(&self, ctx: &Ctx, seen: &mut HashSet>) -> PubValue { + pub fn to_public(&self, seen: &mut HashSet) -> PubValue { use Value::*; if seen.contains(self) { return PubValue::Repeated; @@ -498,11 +528,11 @@ impl Value { match self { AttrSet(attrs) => { seen.insert(self.clone()); - attrs.to_public(ctx, seen) + attrs.to_public(seen) } List(list) => { seen.insert(self.clone()); - list.to_public(ctx, seen) + list.to_public(seen) } Catchable(catchable) => PubValue::Catchable(catchable.clone().into()), Int(x) => PubValue::Const(Const::Int(*x)), diff --git a/evaluator/nixjit_eval/src/value/primop.rs b/evaluator/nixjit_eval/src/value/primop.rs index eed6f83..06faee3 100644 --- a/evaluator/nixjit_eval/src/value/primop.rs +++ b/evaluator/nixjit_eval/src/value/primop.rs @@ -3,67 +3,34 @@ use std::rc::Rc; use derive_more::Constructor; use nixjit_error::Result; +use nixjit_ir::PrimOpId; use super::Value; use crate::EvalContext; #[derive(Debug, Clone, Constructor)] -pub struct PrimOp { +pub struct PrimOpApp { pub name: &'static str, arity: usize, - func: fn(Vec>, &Ctx) -> Result>, + id: PrimOpId, + args: Vec, } -impl PrimOp { - pub fn call(&self, args: Vec>, ctx: &Ctx) -> Result> { - if args.len() > self.arity { - todo!() - } - if self.arity > args.len() { - Value::PrimOpApp(Rc::new(PrimOpApp { - name: self.name, - arity: self.arity - args.len(), - args, - func: self.func, - })) - .ok() - } else { - (self.func)(args, ctx) - } - } -} - -#[derive(Debug)] -pub struct PrimOpApp { - pub name: &'static str, - arity: usize, - args: Vec>, - func: fn(Vec>, &Ctx) -> Result>, -} - -impl Clone for PrimOpApp { - fn clone(&self) -> Self { - Self { - name: self.name, - arity: self.arity, - args: self.args.clone(), - func: self.func, - } - } -} - -impl PrimOpApp { - pub fn call(self: &mut Rc, args: Vec>, ctx: &Ctx) -> Result> { +impl PrimOpApp { + pub fn call( + self: &mut Rc, + args: Vec, + ctx: &mut impl EvalContext, + ) -> Result { if self.arity < args.len() { todo!() } - let func = self.func; let Some(ret) = ({ let self_mut = Rc::make_mut(self); self_mut.arity -= args.len(); self_mut.args.extend(args); if self_mut.arity == 0 { - Some(func(std::mem::take(&mut self_mut.args), ctx)) + Some(ctx.call_primop(self_mut.id, std::mem::take(&mut self_mut.args))) } else { None } diff --git a/evaluator/nixjit_hir/src/downgrade.rs b/evaluator/nixjit_hir/src/downgrade.rs index ab92ef3..a2fcd19 100644 --- a/evaluator/nixjit_hir/src/downgrade.rs +++ b/evaluator/nixjit_hir/src/downgrade.rs @@ -256,40 +256,6 @@ impl Downgrade for ast::Lambda { let param = downgrade_param(self.param().unwrap(), ctx)?; let mut body = self.body().unwrap().downgrade(ctx)?; - // Desugar pattern matching in function arguments into a `let` expression. - // For example, `({ a, b ? 2 }): a + b` is desugared into: - // `arg: let a = arg.a; b = arg.b or 2; in a + b` - if let Param::Formals { formals, alias, .. } = ¶m { - // `Arg` represents the raw argument (the attribute set) passed to the function. - let arg = ctx.new_expr(Hir::Arg(Arg)); - let mut bindings: HashMap<_, _> = formals - .iter() - .map(|&(ref k, default)| { - // For each formal parameter, create a `Select` expression to extract it from the argument set. - ( - k.clone(), - ctx.new_expr( - Select { - expr: arg, - attrpath: vec![Attr::Str(k.clone())], - default, - } - .to_hir(), - ), - ) - }) - .collect(); - // If there's an alias (`... }@alias`), bind the alias name to the raw argument set. - if let Some(alias) = alias { - bindings.insert( - alias.clone(), - ctx.new_expr(Var { sym: alias.clone() }.to_hir()), - ); - } - // Wrap the original function body in the new `let` expression. - let let_ = Let { bindings, body }; - body = ctx.new_expr(let_.to_hir()); - } let ident; let required; let allowed; @@ -304,22 +270,55 @@ impl Downgrade for ast::Lambda { ellipsis, alias, } => { - ident = alias; + ident = alias.clone(); required = Some( formals .iter() - .cloned() .filter(|(_, default)| default.is_none()) - .map(|(k, _)| k) + .map(|(k, _)| k.clone()) .collect(), ); allowed = if ellipsis { None } else { - Some(formals.into_iter().map(|(k, _)| k).collect()) + Some(formals.iter().map(|(k, _)| k.clone()).collect()) }; + + // Desugar pattern matching in function arguments into a `let` expression. + // For example, `({ a, b ? 2 }): a + b` is desugared into: + // `arg: let a = arg.a; b = arg.b or 2; in a + b` + let mut bindings: HashMap<_, _> = formals + .into_iter() + .map(|(k, default)| { + // For each formal parameter, create a `Select` expression to extract it from the argument set. + // `Arg` represents the raw argument (the attribute set) passed to the function. + let arg = ctx.new_expr(Hir::Arg(Arg)); + ( + k.clone(), + ctx.new_expr( + Select { + expr: arg, + attrpath: vec![Attr::Str(k.clone())], + default, + } + .to_hir(), + ), + ) + }) + .collect(); + // If there's an alias (`... }@alias`), bind the alias name to the raw argument set. + if let Some(alias) = alias { + bindings.insert( + alias.clone(), + ctx.new_expr(Var { sym: alias.clone() }.to_hir()), + ); + } + // Wrap the original function body in the new `let` expression. + let let_ = Let { bindings, body }; + body = ctx.new_expr(let_.to_hir()); } } + let param = ir::Param { ident, required, diff --git a/evaluator/nixjit_hir/src/lib.rs b/evaluator/nixjit_hir/src/lib.rs index cf8f6c9..d42174f 100644 --- a/evaluator/nixjit_hir/src/lib.rs +++ b/evaluator/nixjit_hir/src/lib.rs @@ -40,7 +40,7 @@ pub trait DowngradeContext { fn with_expr(&self, id: ExprId, f: impl FnOnce(&Hir, &Self) -> T) -> T; /// Provides temporary mutable access to an expression. - fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T; + fn with_expr_mut(&mut self, id: &ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T; } ir! { @@ -123,7 +123,7 @@ impl Attrs for AttrSet { match attr { Attr::Str(ident) => { // If the next attribute is a static string. - if let Some(&id) = self.stcs.get(&ident) { + if let Some(id) = self.stcs.get(&ident) { // If a sub-attrset already exists, recurse into it. ctx.with_expr_mut(id, |expr, ctx| { expr.as_mut() @@ -186,7 +186,7 @@ impl Attrs for AttrSet { } } -#[derive(Clone, Debug)] +#[derive(Debug)] enum Param { /// A simple parameter, e.g., `x: ...`. Ident(String), diff --git a/evaluator/nixjit_hir/src/utils.rs b/evaluator/nixjit_hir/src/utils.rs index c5004f9..4156d5a 100644 --- a/evaluator/nixjit_hir/src/utils.rs +++ b/evaluator/nixjit_hir/src/utils.rs @@ -125,13 +125,13 @@ pub fn downgrade_inherit( )); } }; - let expr = from.map_or_else( + let expr = from.as_ref().map_or_else( // If `from` is None, `inherit foo;` becomes `foo = foo;`. || Var { sym: ident.clone() }.to_hir(), // If `from` is Some, `inherit (from) foo;` becomes `foo = from.foo;`. |expr| { Select { - expr, + expr: unsafe { expr.clone() }, attrpath: vec![Attr::Str(ident.clone())], default: None, } diff --git a/evaluator/nixjit_ir/src/lib.rs b/evaluator/nixjit_ir/src/lib.rs index dfc45c4..529001f 100644 --- a/evaluator/nixjit_ir/src/lib.rs +++ b/evaluator/nixjit_ir/src/lib.rs @@ -3,7 +3,7 @@ //! The IR provides a simplified, language-agnostic representation of Nix expressions, //! serving as a bridge between the high-level representation (HIR) and the low-level //! representation (LIR). It defines the fundamental building blocks like expression IDs, -//! function IDs, and structures for various expression types (e.g., binary operations, +//! argument indexes, and structures for various expression types (e.g., binary operations, //! attribute sets, function calls). //! //! These structures are designed to be generic and reusable across different stages of @@ -18,35 +18,40 @@ use nixjit_value::Const as PubConst; /// A type-safe wrapper for an index into an expression table. #[repr(transparent)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ExprId(usize); -impl From for ExprId { - fn from(id: usize) -> Self { - ExprId(id) +impl ExprId { + #[inline(always)] + pub unsafe fn clone(&self) -> Self { + Self(self.0) + } + + #[inline(always)] + pub unsafe fn raw(self) -> usize { + self.0 + } + + #[inline(always)] + pub unsafe fn from(id: usize) -> Self { + Self(id) } } -impl From for usize { - fn from(id: ExprId) -> Self { - id.0 - } -} - -/// A type-safe wrapper for an index into a function table. +/// A type-safe wrapper for an index into a primop (builtin function) table. #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct FuncId(usize); +pub struct PrimOpId(usize); -impl From for FuncId { - fn from(id: usize) -> Self { - FuncId(id) +impl PrimOpId { + #[inline(always)] + pub unsafe fn raw(self) -> usize { + self.0 } -} -impl From for usize { - fn from(id: FuncId) -> Self { - id.0 + #[inline(always)] + pub unsafe fn from(id: usize) -> Self { + Self(id) } } @@ -55,15 +60,15 @@ impl From for usize { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ArgIdx(usize); -impl From for ArgIdx { - fn from(id: usize) -> Self { - ArgIdx(id) +impl ArgIdx { + #[inline(always)] + pub unsafe fn raw(self) -> usize { + self.0 } -} -impl From for usize { - fn from(id: ArgIdx) -> Self { - id.0 + #[inline(always)] + pub unsafe fn from(idx: usize) -> Self { + Self(idx) } } @@ -79,7 +84,7 @@ pub struct AttrSet { } /// Represents a key in an attribute path. -#[derive(Clone, Debug, TryUnwrap)] +#[derive(Debug, TryUnwrap)] pub enum Attr { /// A dynamic attribute key, which is an expression that must evaluate to a string. Dynamic(ExprId), @@ -234,6 +239,14 @@ pub struct Call { pub args: Vec, } +// Represents a primitive operation (builtin function) +#[derive(Debug, Clone, Copy)] +pub struct PrimOp { + pub name: &'static str, + pub id: PrimOpId, + pub arity: usize, +} + /// Represents a `with` expression. #[derive(Debug)] pub struct With { diff --git a/evaluator/nixjit_jit/src/compile.rs b/evaluator/nixjit_jit/src/compile.rs index a6b6a93..a41c678 100644 --- a/evaluator/nixjit_jit/src/compile.rs +++ b/evaluator/nixjit_jit/src/compile.rs @@ -88,10 +88,10 @@ impl JITCompile for BinOp { let float_block = ctx.builder.create_block(); let float_check_block = ctx.builder.create_block(); - let is_int = - ctx.builder - .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::INT as i64); + let is_int = ctx + .builder + .ins() + .icmp_imm(IntCC::Equal, lhs_tag, Value::INT as i64); ctx.builder .ins() .brif(is_int, int_block, [], float_check_block, []); @@ -109,7 +109,7 @@ impl JITCompile for BinOp { let is_float = ctx.builder .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::FLOAT as i64); + .icmp_imm(IntCC::Equal, lhs_tag, Value::FLOAT as i64); ctx.builder .ins() .brif(is_float, float_block, [], default_block, []); @@ -141,10 +141,10 @@ impl JITCompile for BinOp { let float_block = ctx.builder.create_block(); let float_check_block = ctx.builder.create_block(); - let is_int = - ctx.builder - .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::INT as i64); + let is_int = ctx + .builder + .ins() + .icmp_imm(IntCC::Equal, lhs_tag, Value::INT as i64); ctx.builder .ins() .brif(is_int, int_block, [], float_check_block, []); @@ -162,7 +162,7 @@ impl JITCompile for BinOp { let is_float = ctx.builder .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::FLOAT as i64); + .icmp_imm(IntCC::Equal, lhs_tag, Value::FLOAT as i64); ctx.builder .ins() .brif(is_float, float_block, [], default_block, []); @@ -194,10 +194,10 @@ impl JITCompile for BinOp { let float_block = ctx.builder.create_block(); let float_check_block = ctx.builder.create_block(); - let is_int = - ctx.builder - .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::INT as i64); + let is_int = ctx + .builder + .ins() + .icmp_imm(IntCC::Equal, lhs_tag, Value::INT as i64); ctx.builder .ins() .brif(is_int, int_block, [], float_check_block, []); @@ -215,7 +215,7 @@ impl JITCompile for BinOp { let is_float = ctx.builder .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::FLOAT as i64); + .icmp_imm(IntCC::Equal, lhs_tag, Value::FLOAT as i64); ctx.builder .ins() .brif(is_float, float_block, [], default_block, []); @@ -245,10 +245,10 @@ impl JITCompile for BinOp { let bool_block = ctx.builder.create_block(); let non_bool_block = ctx.builder.create_block(); - let is_bool = - ctx.builder - .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::BOOL as i64); + let is_bool = ctx + .builder + .ins() + .icmp_imm(IntCC::Equal, lhs_tag, Value::BOOL as i64); ctx.builder .ins() .brif(is_bool, bool_block, [], non_bool_block, []); @@ -275,10 +275,10 @@ impl JITCompile for BinOp { let bool_block = ctx.builder.create_block(); let non_bool_block = ctx.builder.create_block(); - let is_bool = - ctx.builder - .ins() - .icmp_imm(IntCC::Equal, lhs_tag, Value::::BOOL as i64); + let is_bool = ctx + .builder + .ins() + .icmp_imm(IntCC::Equal, lhs_tag, Value::BOOL as i64); ctx.builder .ins() .brif(is_bool, bool_block, [], non_bool_block, []); @@ -378,10 +378,10 @@ impl JITCompile for If { let judge_block = ctx.builder.create_block(); let slot = ctx.alloca(); - let is_bool = - ctx.builder - .ins() - .icmp_imm(IntCC::Equal, cond_type, Value::::BOOL as i64); + let is_bool = ctx + .builder + .ins() + .icmp_imm(IntCC::Equal, cond_type, Value::BOOL as i64); ctx.builder .ins() .brif(is_bool, judge_block, [], error_block, []); @@ -480,37 +480,25 @@ impl JITCompile for Const { let slot = ctx.alloca(); match self.val { Bool(x) => { - let tag = ctx - .builder - .ins() - .iconst(types::I64, Value::::BOOL as i64); + let tag = ctx.builder.ins().iconst(types::I64, Value::BOOL as i64); let val = ctx.builder.ins().iconst(types::I64, x as i64); ctx.builder.ins().stack_store(tag, slot, 0); ctx.builder.ins().stack_store(val, slot, 8); } Int(x) => { - let tag = ctx - .builder - .ins() - .iconst(types::I64, Value::::INT as i64); + let tag = ctx.builder.ins().iconst(types::I64, Value::INT as i64); let val = ctx.builder.ins().iconst(types::I64, x); ctx.builder.ins().stack_store(tag, slot, 0); ctx.builder.ins().stack_store(val, slot, 8); } Float(x) => { - let tag = ctx - .builder - .ins() - .iconst(types::I64, Value::::FLOAT as i64); + let tag = ctx.builder.ins().iconst(types::I64, Value::FLOAT as i64); let val = ctx.builder.ins().f64const(x); ctx.builder.ins().stack_store(tag, slot, 0); ctx.builder.ins().stack_store(val, slot, 8); } Null => { - let tag = ctx - .builder - .ins() - .iconst(types::I64, Value::::NULL as i64); + let tag = ctx.builder.ins().iconst(types::I64, Value::NULL as i64); ctx.builder.ins().stack_store(tag, slot, 0); } } diff --git a/evaluator/nixjit_jit/src/helpers.rs b/evaluator/nixjit_jit/src/helpers.rs index d628fc1..4512823 100644 --- a/evaluator/nixjit_jit/src/helpers.rs +++ b/evaluator/nixjit_jit/src/helpers.rs @@ -11,21 +11,21 @@ use nixjit_eval::{AttrSet, EvalContext, List, Value}; use super::JITContext; pub extern "C" fn helper_call( - func: &mut Value, - args_ptr: *mut Value, + func: &mut Value, + args_ptr: *mut Value, args_len: usize, ctx: &mut Ctx, ) { // TODO: Error Handling let args = core::ptr::slice_from_raw_parts_mut(args_ptr, args_len); let args = unsafe { Box::from_raw(args) }; - func.call(args.into_iter().collect(), ctx).unwrap(); + func.call(args.into_iter().map(Ok), ctx).unwrap(); } pub extern "C" fn helper_lookup_stack( ctx: &Ctx, offset: usize, - ret: &mut MaybeUninit>, + ret: &mut MaybeUninit, ) { ret.write(ctx.lookup_stack(offset).clone()); } @@ -33,7 +33,7 @@ pub extern "C" fn helper_lookup_stack( pub extern "C" fn helper_lookup_arg( ctx: &Ctx, offset: usize, - ret: &mut MaybeUninit>, + ret: &mut MaybeUninit, ) { ret.write(JITContext::lookup_arg(ctx, offset).clone()); } @@ -42,7 +42,7 @@ pub extern "C" fn helper_lookup( ctx: &Ctx, sym_ptr: *const u8, sym_len: usize, - ret: &mut MaybeUninit>, + ret: &mut MaybeUninit, ) { // TODO: Error Handling unsafe { @@ -57,8 +57,8 @@ pub extern "C" fn helper_lookup( } pub extern "C" fn helper_select( - val: &mut Value, - path_ptr: *mut Value, + val: &mut Value, + path_ptr: *mut Value, path_len: usize, ) { let path = core::ptr::slice_from_raw_parts_mut(path_ptr, path_len); @@ -71,10 +71,10 @@ pub extern "C" fn helper_select( } pub extern "C" fn helper_select_with_default( - val: &mut Value, - path_ptr: *mut Value, + val: &mut Value, + path_ptr: *mut Value, path_len: usize, - default: NonNull>, + default: NonNull, ) { let path = core::ptr::slice_from_raw_parts_mut(path_ptr, path_len); let path = unsafe { Box::from_raw(path) }; @@ -88,14 +88,14 @@ pub extern "C" fn helper_select_with_default( .unwrap(); } -pub extern "C" fn helper_eq(lhs: &mut Value, rhs: &Value) { +pub extern "C" fn helper_eq(lhs: &mut Value, rhs: &Value) { lhs.eq(rhs); } pub unsafe extern "C" fn helper_create_string( ptr: *const u8, len: usize, - ret: &mut MaybeUninit>, + ret: &mut MaybeUninit, ) { unsafe { ret.write(Value::String( @@ -105,9 +105,9 @@ pub unsafe extern "C" fn helper_create_string( } pub unsafe extern "C" fn helper_create_list( - ptr: *mut Value, + ptr: *mut Value, len: usize, - ret: &mut MaybeUninit>, + ret: &mut MaybeUninit, ) { unsafe { ret.write(Value::List( @@ -117,16 +117,16 @@ pub unsafe extern "C" fn helper_create_list( } pub unsafe extern "C" fn helper_create_attrs( - ret: &mut MaybeUninit>>, + ret: &mut MaybeUninit>, ) { ret.write(HashMap::new()); } pub unsafe extern "C" fn helper_push_attr( - attrs: &mut HashMap>, + attrs: &mut HashMap, sym_ptr: *const u8, sym_len: usize, - val: NonNull>, + val: NonNull, ) { unsafe { attrs.insert( @@ -137,8 +137,8 @@ pub unsafe extern "C" fn helper_push_attr( } pub unsafe extern "C" fn helper_finalize_attrs( - attrs: NonNull>>, - ret: &mut MaybeUninit>, + attrs: NonNull>, + ret: &mut MaybeUninit, ) { ret.write(Value::AttrSet( AttrSet::from(unsafe { attrs.read() }).into(), @@ -147,7 +147,7 @@ pub unsafe extern "C" fn helper_finalize_attrs( pub unsafe extern "C" fn helper_enter_with( ctx: &mut Ctx, - namespace: NonNull>, + namespace: NonNull, ) { ctx.enter_with(unsafe { namespace.read() }.unwrap_attr_set().into_inner()); } @@ -157,9 +157,9 @@ pub unsafe extern "C" fn helper_exit_with(ctx: &mut Ctx) { } pub unsafe extern "C" fn helper_alloc_array(len: usize) -> *mut u8 { - unsafe { alloc(Layout::array::>(len).unwrap()) } + unsafe { alloc(Layout::array::(len).unwrap()) } } -pub extern "C" fn helper_dbg(value: &Value) { +pub extern "C" fn helper_dbg(value: &Value) { println!("{value:?}") } diff --git a/evaluator/nixjit_jit/src/lib.rs b/evaluator/nixjit_jit/src/lib.rs index e1ed25e..f80b689 100644 --- a/evaluator/nixjit_jit/src/lib.rs +++ b/evaluator/nixjit_jit/src/lib.rs @@ -19,13 +19,13 @@ pub use compile::JITCompile; use helpers::*; pub trait JITContext: EvalContext + Sized { - fn lookup_stack(&self, offset: usize) -> &Value; - fn lookup_arg(&self, offset: usize) -> &Value; - fn enter_with(&mut self, namespace: Rc>>); + fn lookup_stack(&self, offset: usize) -> &Value; + fn lookup_arg(&self, offset: usize) -> &Value; + fn enter_with(&mut self, namespace: Rc>); fn exit_with(&mut self); } -type F = unsafe extern "C" fn(*const Ctx, *mut Value); +type F = unsafe extern "C" fn(*const Ctx, *mut Value); pub struct JITFunc { func: F, diff --git a/evaluator/nixjit_lir/src/lib.rs b/evaluator/nixjit_lir/src/lib.rs index 534cfa0..2209c91 100644 --- a/evaluator/nixjit_lir/src/lib.rs +++ b/evaluator/nixjit_lir/src/lib.rs @@ -24,12 +24,16 @@ ir! { Str, Var, Path, + PrimOp, ExprRef(ExprId), FuncRef(ExprId), ArgRef(ArgIdx), } -#[derive(Debug, Clone, Copy)] +#[derive(Debug)] +pub struct Builtins; + +#[derive(Debug)] pub enum LookupResult { Expr(ExprId), Arg(ArgIdx), @@ -38,21 +42,17 @@ pub enum LookupResult { } pub trait ResolveContext { - fn new_dep(&mut self, expr: ExprId, dep: ExprId); - fn new_func(&mut self, body: ExprId, param: Param); - fn resolve(&mut self, expr: ExprId) -> Result<()>; + fn new_dep(&mut self, expr: &ExprId, dep: ExprId); + fn new_func(&mut self, body: &ExprId, param: Param); + fn resolve(&mut self, expr: &ExprId) -> Result<()>; fn lookup(&self, name: &str) -> LookupResult; fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T); fn with_let_env<'a, T>( &mut self, - bindings: impl IntoIterator, - f: impl FnOnce(&mut Self) -> T, - ) -> T; - fn with_param_env<'a, T>( - &mut self, - ident: Option<&'a str>, + bindings: impl Iterator, f: impl FnOnce(&mut Self) -> T, ) -> T; + fn with_param_env(&mut self, ident: Option, f: impl FnOnce(&mut Self) -> T) -> T; } pub trait Resolve { @@ -80,7 +80,7 @@ impl Resolve for hir::Hir { Var(x) => x.resolve(ctx), Path(x) => x.resolve(ctx), Let(x) => x.resolve(ctx), - Arg(_) => todo!(), + Arg(_) => unsafe { Ok(Lir::ArgRef(ArgIdx::from(0))) }, } } } @@ -90,10 +90,10 @@ impl Resolve for AttrSet { if self.rec { todo!() } else { - for (_, &v) in self.stcs.iter() { + for (_, v) in self.stcs.iter() { ctx.resolve(v)?; } - for &(k, v) in self.dyns.iter() { + for (k, v) in self.dyns.iter() { ctx.resolve(k)?; ctx.resolve(v)?; } @@ -104,7 +104,7 @@ impl Resolve for AttrSet { impl Resolve for List { fn resolve(self, ctx: &mut Ctx) -> Result { - for &item in self.items.iter() { + for item in self.items.iter() { ctx.resolve(item)?; } Ok(self.to_lir()) @@ -113,10 +113,10 @@ impl Resolve for List { impl Resolve for HasAttr { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.lhs)?; + ctx.resolve(&self.lhs)?; for attr in self.rhs.iter() { - if let &Attr::Dynamic(expr) = attr { - ctx.resolve(expr)?; + if let Attr::Dynamic(expr) = attr { + ctx.resolve(&expr)?; } } Ok(self.to_lir()) @@ -125,28 +125,28 @@ impl Resolve for HasAttr { impl Resolve for BinOp { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.lhs)?; - ctx.resolve(self.rhs)?; + ctx.resolve(&self.lhs)?; + ctx.resolve(&self.rhs)?; Ok(self.to_lir()) } } impl Resolve for UnOp { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.rhs)?; + ctx.resolve(&self.rhs)?; Ok(self.to_lir()) } } impl Resolve for Select { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.expr)?; + ctx.resolve(&self.expr)?; for attr in self.attrpath.iter() { - if let &Attr::Dynamic(expr) = attr { - ctx.resolve(expr)?; + if let Attr::Dynamic(expr) = attr { + ctx.resolve(&expr)?; } } - if let Some(expr) = self.default { + if let Some(ref expr) = self.default { ctx.resolve(expr)?; } Ok(self.to_lir()) @@ -155,25 +155,25 @@ impl Resolve for Select { impl Resolve for If { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.cond)?; - ctx.resolve(self.consq)?; - ctx.resolve(self.alter)?; + ctx.resolve(&self.cond)?; + ctx.resolve(&self.consq)?; + ctx.resolve(&self.alter)?; Ok(self.to_lir()) } } impl Resolve for Func { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.with_param_env(self.param.ident.as_deref(), |ctx| ctx.resolve(self.body))?; - ctx.new_func(self.body, self.param); + ctx.with_param_env(self.param.ident.clone(), |ctx| ctx.resolve(&self.body))?; + ctx.new_func(&self.body, self.param); Ok(Lir::FuncRef(self.body)) } } impl Resolve for Call { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.func)?; - for &arg in self.args.iter() { + ctx.resolve(&self.func)?; + for arg in self.args.iter() { ctx.resolve(arg)?; } Ok(self.to_lir()) @@ -182,8 +182,8 @@ impl Resolve for Call { impl Resolve for With { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.namespace)?; - let (env_used, res) = ctx.with_with_env(|ctx| ctx.resolve(self.expr)); + ctx.resolve(&self.namespace)?; + let (env_used, res) = ctx.with_with_env(|ctx| ctx.resolve(&self.expr)); res?; if env_used { Ok(self.to_lir()) @@ -195,15 +195,15 @@ impl Resolve for With { impl Resolve for Assert { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.assertion)?; - ctx.resolve(self.expr)?; + ctx.resolve(&self.assertion)?; + ctx.resolve(&self.expr)?; Ok(self.to_lir()) } } impl Resolve for ConcatStrings { fn resolve(self, ctx: &mut Ctx) -> Result { - for &part in self.parts.iter() { + for part in self.parts.iter() { ctx.resolve(part)?; } Ok(self.to_lir()) @@ -227,7 +227,7 @@ impl Resolve for Var { impl Resolve for Path { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.expr)?; + ctx.resolve(&self.expr)?; Ok(self.to_lir()) } } @@ -235,10 +235,10 @@ impl Resolve for Path { impl Resolve for hir::Let { fn resolve(self, ctx: &mut Ctx) -> Result { ctx.with_let_env(self.bindings.iter(), |ctx| { - for &id in self.bindings.values() { + for id in self.bindings.values() { ctx.resolve(id)?; } - ctx.resolve(self.body) + ctx.resolve(&self.body) })?; Ok(Lir::ExprRef(self.body)) } diff --git a/evaluator/nixjit_macros/src/builtins.rs b/evaluator/nixjit_macros/src/builtins.rs index 983d4ee..a9079b4 100644 --- a/evaluator/nixjit_macros/src/builtins.rs +++ b/evaluator/nixjit_macros/src/builtins.rs @@ -1,16 +1,17 @@ use convert_case::{Case, Casing}; use proc_macro::TokenStream; use proc_macro2::Span; -use quote::{format_ident, quote, ToTokens}; +use quote::{ToTokens, format_ident, quote}; use syn::{ - parse_macro_input, FnArg, Item, ItemConst, ItemFn, ItemMod, Pat, PatType, Type, - Visibility, + FnArg, Item, ItemFn, ItemMod, Pat, PatType, Type, Visibility, parse_macro_input, }; pub fn builtins_impl(input: TokenStream) -> TokenStream { let item_mod = parse_macro_input!(input as ItemMod); + let mod_name = &item_mod.ident; + let visibility = &item_mod.vis; - let (_brace, items) = match item_mod.content.clone() { + let (_brace, items) = match item_mod.content { Some(content) => content, None => { return syn::Error::new_spanned( @@ -23,63 +24,76 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { }; let mut pub_item_mod: Vec = Vec::new(); - let mut const_inserters = Vec::new(); - let mut global_inserters = Vec::new(); - let mut scoped_inserters = Vec::new(); + let mut consts = Vec::new(); + let mut global = Vec::new(); + let mut scoped = Vec::new(); let mut wrappers = Vec::new(); for item in &items { match item { Item::Const(item_const) => { - let inserter = generate_const_inserter(item_const); - const_inserters.push(inserter); - pub_item_mod.push(quote! { - pub #item_const - }.into()); + let name_str = item_const + .ident + .to_string() + .from_case(Case::UpperSnake) + .to_case(Case::Camel); + let const_name = &item_const.ident; + consts.push(quote! { (#name_str, builtins::#const_name) }); + pub_item_mod.push( + quote! { + pub #item_const + } + .into(), + ); } Item::Fn(item_fn) => { - let (inserter, wrapper) = match generate_fn_wrapper(item_fn) { + let (primop, wrapper) = match generate_primop_wrapper(item_fn) { Ok(result) => result, Err(e) => return e.to_compile_error().into(), }; if matches!(item_fn.vis, Visibility::Public(_)) { - global_inserters.push(inserter); + global.push(primop); pub_item_mod.push(quote! { #item_fn }.into()); } else { - scoped_inserters.push(inserter); - pub_item_mod.push(quote! { - pub #item_fn - }.into()); + scoped.push(primop); + pub_item_mod.push( + quote! { + pub #item_fn + } + .into(), + ); } wrappers.push(wrapper); } - item => pub_item_mod.push(item.to_token_stream()) + item => pub_item_mod.push(item.to_token_stream()), } } + let consts_len = consts.len(); + let global_len = global.len(); + let scoped_len = scoped.len(); let output = quote! { - mod builtins { + #visibility mod #mod_name { #(#pub_item_mod)* #(#wrappers)* + pub const CONSTS_LEN: usize = #consts_len; + pub const GLOBAL_LEN: usize = #global_len; + pub const SCOPED_LEN: usize = #scoped_len; } pub struct Builtins { - pub consts: ::std::vec::Vec<(String, ::nixjit_value::Const)>, - pub global: ::std::vec::Vec<(String, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>)>, - pub scoped: ::std::vec::Vec<(String, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>)>, + pub consts: [(&'static str, ::nixjit_value::Const); #mod_name::CONSTS_LEN], + pub global: [(&'static str, usize, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::GLOBAL_LEN], + pub scoped: [(&'static str, usize, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::SCOPED_LEN], } impl Builtins { pub fn new() -> Self { - let mut consts = ::std::vec::Vec::new(); - let mut global = ::std::vec::Vec::new(); - let mut scoped = ::std::vec::Vec::new(); - - #(#const_inserters)* - #(#global_inserters)* - #(#scoped_inserters)* - - Self { consts, global, scoped } + Self { + consts: [#(#consts,)*], + global: [#(#global,)*], + scoped: [#(#scoped,)*], + } } } }; @@ -87,26 +101,17 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { output.into() } -fn generate_const_inserter( - item_const: &ItemConst, -) -> proc_macro2::TokenStream { - let name_str = item_const.ident.to_string().from_case(Case::UpperSnake).to_case(Case::Camel); - let const_name = &item_const.ident; - - quote! { - consts.push((#name_str.to_string(), builtins::#const_name)); - } -} - -fn generate_fn_wrapper( +fn generate_primop_wrapper( item_fn: &ItemFn, ) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> { let fn_name = &item_fn.sig.ident; - let name_str = fn_name.to_string().from_case(Case::Snake).to_case(Case::Camel); + let name_str = fn_name + .to_string() + .from_case(Case::Snake) + .to_case(Case::Camel); let wrapper_name = format_ident!("wrapper_{}", fn_name); let mod_name = format_ident!("builtins"); - let is_pub = matches!(item_fn.vis, Visibility::Public(_)); let mut user_args = item_fn.sig.inputs.iter().peekable(); let has_ctx = if let Some(FnArg::Typed(first_arg)) = user_args.peek() { @@ -185,15 +190,13 @@ fn generate_fn_wrapper( quote! { Ok(#fn_name(#call_args).into()) } }; - let fn_type = quote! { fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> }; - let inserter = if is_pub { - quote! { global.push((#name_str.to_string(), #mod_name::#wrapper_name as #fn_type)); } - } else { - quote! { scoped.push((#name_str.to_string(), #mod_name::#wrapper_name as #fn_type)); } - }; + let arity = arg_names.len(); + let fn_type = quote! { fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> }; + let primop = + quote! { (#name_str, #arity, #mod_name::#wrapper_name as #fn_type) }; let wrapper = quote! { - pub fn #wrapper_name(ctx: &mut Ctx, mut args: Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> { + pub fn #wrapper_name(ctx: &mut Ctx, mut args: Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> { if args.len() != #arg_count { return Err(::nixjit_error::Error::EvalError(format!("Function '{}' expects {} arguments, but received {}", #name_str, #arg_count, args.len()))); } @@ -203,5 +206,5 @@ fn generate_fn_wrapper( } }; - Ok((inserter, wrapper)) + Ok((primop, wrapper)) }