diff --git a/Cargo.lock b/Cargo.lock index 7d348d6..d79d365 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,6 +56,15 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "cc" +version = "1.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" +dependencies = [ + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.1" @@ -425,6 +434,16 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "libmimalloc-sys" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "667f4fec20f29dfc6bc7357c582d91796c169ad7e2fce709468aefeb2c099870" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -461,6 +480,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mimalloc" +version = "0.1.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1ee66a4b64c74f4ef288bcbb9192ad9c3feaad75193129ac8509af543894fd8" +dependencies = [ + "libmimalloc-sys", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -488,6 +516,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bumpalo", + "mimalloc", "nixjit_context", "nixjit_value", "regex", @@ -815,6 +844,12 @@ dependencies = [ "syn", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "smallvec" version = "1.15.1" diff --git a/TODO.md b/TODO.md index 8152c12..ede9745 100644 --- a/TODO.md +++ b/TODO.md @@ -30,3 +30,9 @@ e.g. 3. HIR -> LIR (resolve var, build graph) 4. LIR -> Value +为每个值分配 ID 的难点在于对动态表达式的引用。 +动态表达式有: + - 依赖于函数参数的表达式 + - 依赖于 with 的表达式 + - 依赖于动态表达式的表达式 +而这些表达式在每一次分配 ValueId 时指向的 ValueId 都不同,因此需要追踪这些变量。 diff --git a/evaluator/nixjit/Cargo.toml b/evaluator/nixjit/Cargo.toml index 3d546c3..42aadc4 100644 --- a/evaluator/nixjit/Cargo.toml +++ b/evaluator/nixjit/Cargo.toml @@ -4,6 +4,8 @@ version = "0.1.0" edition = "2024" [dependencies] +mimalloc = "0.1" + anyhow = "1.0" bumpalo = "3.19" regex = "1.11" diff --git a/evaluator/nixjit/src/lib.rs b/evaluator/nixjit/src/lib.rs index a0fce7a..68889f1 100644 --- a/evaluator/nixjit/src/lib.rs +++ b/evaluator/nixjit/src/lib.rs @@ -4,8 +4,13 @@ //! and evaluating Nix expressions. It integrates all the other `nixjit_*` //! components to provide a complete Nix evaluation environment. +use mimalloc::MiMalloc; + pub use nixjit_context as context; pub use nixjit_value as value; +#[global_allocator] +static GLOBAL: MiMalloc = MiMalloc; + #[cfg(test)] mod test; diff --git a/evaluator/nixjit_builtins/src/lib.rs b/evaluator/nixjit_builtins/src/lib.rs index 2539233..6e6f30b 100644 --- a/evaluator/nixjit_builtins/src/lib.rs +++ b/evaluator/nixjit_builtins/src/lib.rs @@ -1,7 +1,11 @@ +use nixjit_error::Result; +use nixjit_eval::{Args, Value}; use nixjit_macros::builtins; pub trait BuiltinsContext {} +pub type BuiltinFn = fn(&mut Ctx, Args) -> Result; + #[builtins] pub mod builtins { use std::rc::Rc; diff --git a/evaluator/nixjit_context/Cargo.toml b/evaluator/nixjit_context/Cargo.toml index acf1d41..db82b3c 100644 --- a/evaluator/nixjit_context/Cargo.toml +++ b/evaluator/nixjit_context/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -bumpalo = { version = "3.19", features = ["boxed"] } +bumpalo = { version = "3.19", features = ["boxed", "collections"] } derive_more = { version = "2.0", features = ["full"] } hashbrown = "0.15" itertools = "0.14" diff --git a/evaluator/nixjit_context/src/downgrade.rs b/evaluator/nixjit_context/src/downgrade.rs index 8ee06d8..3d9e62e 100644 --- a/evaluator/nixjit_context/src/downgrade.rs +++ b/evaluator/nixjit_context/src/downgrade.rs @@ -43,6 +43,12 @@ impl DowngradeContext for DowngradeCtx<'_, '_> { } fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T { + // SAFETY: This is a common pattern to temporarily bypass the borrow checker. + // We are creating a mutable reference to `self` from a raw pointer. This is safe + // because `self_mut` is only used within the closure `f`, and we are careful + // not to create aliasing mutable references. The `RefCell`'s runtime borrow + // checking further ensures that we don't have multiple mutable borrows of the + // same `Hir` expression simultaneously. unsafe { let self_mut = &mut *(self as *mut Self); f(&mut self.get_ir(id).borrow_mut(), self_mut) @@ -57,6 +63,10 @@ impl DowngradeContext for DowngradeCtx<'_, '_> { for (idx, ir) in self.ctx.hirs.iter().enumerate() { println!( "{:?} {:#?}", + // SAFETY: The index `idx` is obtained from iterating over `self.ctx.hirs`, + // so it is guaranteed to be a valid index. The length of `lirs` is added + // as an offset to ensure the `ExprId` correctly corresponds to its position + // in the combined IR storage. unsafe { ExprId::from_raw(idx + self.ctx.lirs.len()) }, &ir ); diff --git a/evaluator/nixjit_context/src/eval.rs b/evaluator/nixjit_context/src/eval.rs index a44097a..edafbf9 100644 --- a/evaluator/nixjit_context/src/eval.rs +++ b/evaluator/nixjit_context/src/eval.rs @@ -1,131 +1,95 @@ -#[cfg(debug_assertions)] -use std::cell::OnceCell; -#[cfg(not(debug_assertions))] -use std::mem::MaybeUninit; use std::rc::Rc; use hashbrown::HashMap; -use itertools::Itertools; +use petgraph::prelude::DiGraph; -use nixjit_error::Result; -use nixjit_eval::{Args, EvalContext, Evaluate, StackFrame, Value}; -use nixjit_ir::ExprId; +use nixjit_error::{Error, Result}; +use nixjit_eval::{Args, EvalContext, Evaluate, PrimOpApp, Value, ValueId}; +use nixjit_ir::{ExprId, PrimOpId}; use nixjit_jit::JITContext; use nixjit_lir::Lir; -use petgraph::prelude::DiGraph; use super::Context; -struct ValueCache( - #[cfg(debug_assertions)] - Option, - #[cfg(not(debug_assertions))] - MaybeUninit -); - -impl Default for ValueCache { - fn default() -> Self { - #[cfg(debug_assertions)] - { - Self(None) - } - #[cfg(not(debug_assertions))] - { - Self(MaybeUninit::uninit()) - } - } +enum ValueCache { + Expr(ExprId), + BlackHole, + Value(Value), } impl ValueCache { - fn insert(&mut self, val: Value) { - #[cfg(debug_assertions)] - { - assert!(self.0.is_none()); - let _ = self.0.insert(val); - } - #[cfg(not(debug_assertions))] - self.0.write(val); - } -} - -impl Drop for ValueCache { - fn drop(&mut self) { - #[cfg(not(debug_assertions))] - unsafe { - self.0.assume_init_drop(); + fn get_or_eval(&mut self, eval: impl FnOnce(ExprId) -> Result) -> Result<&Value> { + match self { + &mut Self::Expr(id) => { + *self = Self::BlackHole; + match eval(id) { + Ok(value) => { + *self = Self::Value(value); + let Self::Value(value) = self else { + unreachable!() + }; + Ok(value) + } + Err(err) => Err(err), + } + } + Self::Value(value) => Ok(value), + Self::BlackHole => Err(Error::eval_error(format!("infinite recursion encountered"))), } } } pub struct EvalCtx<'ctx, 'bump> { ctx: &'ctx mut Context<'bump>, - graph: DiGraph, - stack: Vec, + graph: DiGraph, ()>, + caches: Vec, with_scopes: Vec>>, } impl<'ctx, 'bump> EvalCtx<'ctx, 'bump> { pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { Self { - ctx, - graph: DiGraph::new(), - stack: Vec::new(), + graph: DiGraph::with_capacity(ctx.graph.node_count(), ctx.graph.edge_count()), + caches: Vec::new(), with_scopes: Vec::new(), + ctx, } } - - fn eval_deps(&mut self, expr: ExprId, arg: Option) -> Result<()> { - let deps = self - .ctx - .graph - .edges(expr) - .sorted_by_key(|(.., idx)| **idx) - .map(|(_, dep, idx)| (dep, *idx)) - .collect_vec(); - let mut frame = (0..deps.len()) - .map(|_| Value::Blackhole) - .collect::(); - dbg!(&deps, &self.stack); - for (dep, idx) in deps { - unsafe { - if matches!(&**self.ctx.lirs.get_unchecked(dep.raw()), Lir::Arg(_)) { - *frame.get_unchecked_mut(idx.raw()) = arg.as_ref().unwrap().clone(); - continue; - } - } - let dep = self.eval(dep)?; - unsafe { - *frame.get_unchecked_mut(idx.raw()) = dep; - } - } - *self.stack.last_mut().unwrap() = frame; - dbg!(&self.stack); - Ok(()) - } } impl EvalContext for EvalCtx<'_, '_> { - fn eval_root(mut self, expr: ExprId) -> Result { - self.stack.push(StackFrame::new()); - self.eval_deps(expr, None)?; - self.eval(expr) - } - fn eval(&mut self, expr: ExprId) -> Result { + // SAFETY: The `expr` `ExprId` is guaranteed to be a valid index into the `lirs` + // vector by the `downgrade` and `resolve` stages, which are responsible for + // creating and managing these IDs. The `get_unchecked` is safe under this invariant. + // The subsequent raw pointer operations are to safely extend the lifetime of the + // `Lir` reference. This is sound because the `lirs` vector is allocated within a + // `Bump` arena, ensuring that the `Lir` objects have a stable memory location + // and will not be deallocated or moved for the lifetime of the context. let idx = unsafe { expr.raw() }; let lir = unsafe { &*(&**self.ctx.lirs.get_unchecked(idx) as *const Lir) }; lir.eval(self) } - fn call(&mut self, func: ExprId, arg: Option, frame: StackFrame) -> Result { - self.stack.push(frame); - if let Err(err) = self.eval_deps(func, arg) { - self.stack.pop(); - return Err(err); - } - let ret = self.eval(func); - self.stack.pop(); - ret + fn resolve(&mut self, id: ExprId) -> Result { + let mut deps = Vec::new(); + + self.caches.push(ValueCache::Expr(id)); + let id = self.graph.add_node(deps); + + // SAFETY: The `id.index()` is guaranteed to be a valid raw ID for a `ValueId` + // because it is generated by the `petgraph::DiGraph`, which manages its own + // internal indices. This ensures that the raw value is unique and corresponds + // to a valid node in the graph. + Ok(unsafe { ValueId::from_raw(id.index()) }) + } + + fn call(&mut self, func: ValueId, arg: Value) -> Result { + todo!() + } + + fn force(&mut self, id: ValueId) -> Result { + todo!() } fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a Value> { @@ -137,35 +101,18 @@ impl EvalContext for EvalCtx<'_, '_> { None } - fn lookup_stack(&self, idx: nixjit_ir::StackIdx) -> &Value { - if cfg!(debug_assertions) { - self.stack - .last() - .unwrap() - .get(unsafe { idx.raw() }) - .unwrap() + fn call_primop(&mut self, id: PrimOpId, args: Args) -> Result { + // SAFETY: The `PrimOpId` is created and managed by the `Context` and is + // guaranteed to be a valid index into the `primops` array. The `get_unchecked` + // is safe under this invariant, avoiding a bounds check for performance. + let &(arity, primop) = unsafe { self.ctx.primops.get_unchecked(id.raw()) }; + if args.len() == arity { + primop(self.ctx, args) } else { - unsafe { - self.stack - .last() - .unwrap_unchecked() - .get_unchecked(idx.raw()) - } + Ok(Value::PrimOpApp(PrimOpApp::new(id, args).into())) } } - fn capture_stack(&self) -> &StackFrame { - self.stack.last().unwrap() - } - - fn call_primop(&mut self, id: nixjit_ir::PrimOpId, args: Args) -> Result { - unsafe { (self.ctx.primops.get_unchecked(id.raw()).1)(self.ctx, args) } - } - - fn get_primop_arity(&self, id: nixjit_ir::PrimOpId) -> usize { - unsafe { self.ctx.primops.get_unchecked(id.raw()).0 } - } - fn with_with_env( &mut self, namespace: Rc>, diff --git a/evaluator/nixjit_context/src/lib.rs b/evaluator/nixjit_context/src/lib.rs index ab18317..11e1ed3 100644 --- a/evaluator/nixjit_context/src/lib.rs +++ b/evaluator/nixjit_context/src/lib.rs @@ -1,4 +1,3 @@ -use std::cell::Cell; use std::ptr::NonNull; use bumpalo::{Bump, boxed::Box}; @@ -7,13 +6,12 @@ use itertools::Itertools; use petgraph::graphmap::DiGraphMap; use nixjit_builtins::{ - Builtins, BuiltinsContext, - builtins::{GLOBAL_LEN, SCOPED_LEN}, + builtins::{GLOBAL_LEN, SCOPED_LEN}, BuiltinFn, Builtins, BuiltinsContext }; use nixjit_error::{Error, Result}; use nixjit_eval::{Args, EvalContext, Value}; use nixjit_hir::{DowngradeContext, Hir}; -use nixjit_ir::{AttrSet, ExprId, Param, PrimOpId, StackIdx}; +use nixjit_ir::{AttrSet, ExprId, Param, PrimOpId}; use nixjit_lir::Lir; use crate::downgrade::DowngradeCtx; @@ -39,16 +37,23 @@ pub struct Context<'bump> { global_scope: NonNull>, /// A dependency graph between expressions. - graph: DiGraphMap, + graph: DiGraphMap, /// A table of primitive operation implementations. - primops: [(usize, fn(&mut Self, Args) -> Result); GLOBAL_LEN + SCOPED_LEN], + primops: [(usize, BuiltinFn); GLOBAL_LEN + SCOPED_LEN], bump: &'bump Bump, } impl Drop for Context<'_> { fn drop(&mut self) { + // SAFETY: `repl_scope` and `global_scope` are `NonNull` pointers to `HashMap`s + // allocated within the `bump` arena. Because `NonNull` does not convey ownership, + // Rust's drop checker will not automatically drop the pointed-to `HashMap`s when + // the `Context` is dropped. We must manually call `drop_in_place` to ensure + // their destructors are run. This is safe because these pointers are guaranteed + // to be valid and non-null for the lifetime of the `Context`, as they are + // initialized in `new()` and never deallocated or changed. unsafe { self.repl_scope.drop_in_place(); self.global_scope.drop_in_place(); @@ -62,10 +67,18 @@ impl<'bump> Context<'bump> { let global_scope = global .iter() .enumerate() - .map(|(idx, (k, _, _))| (*k, unsafe { ExprId::from_raw(idx) })) - .chain(core::iter::once(("builtins", unsafe { - ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN) - }))) + .map(|(idx, (k, _, _))| { + // SAFETY: The index `idx` comes from `enumerate()` on the `global` array, + // so it is guaranteed to be a valid, unique index for a primop LIR. + (*k, unsafe { ExprId::from_raw(idx) }) + }) + .chain(core::iter::once(( + "builtins", + // SAFETY: This ID corresponds to the `builtins` attrset LIR, which is + // constructed and placed after all the global and scoped primop LIRs. + // The index is calculated to be exactly at that position. + unsafe { ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN) }, + ))) .collect(); let primops = global .iter() @@ -74,30 +87,48 @@ impl<'bump> Context<'bump> { .collect_array() .unwrap(); let lirs = (0..global.len()) - .map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx) })) - .chain( - (0..scoped.len()) - .map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx + GLOBAL_LEN) })), - ) + .map(|idx| { + // SAFETY: The index `idx` is guaranteed to be within the bounds of the + // `global` primops array, making it a valid raw ID for a `PrimOpId`. + Lir::PrimOp(unsafe { PrimOpId::from_raw(idx) }) + }) + .chain((0..scoped.len()).map(|idx| { + // SAFETY: The index `idx` is within the bounds of the `scoped` primops + // array. Adding `GLOBAL_LEN` correctly offsets it to its position in + // the combined `primops` table. + Lir::PrimOp(unsafe { PrimOpId::from_raw(idx + GLOBAL_LEN) }) + })) .chain(core::iter::once(Lir::AttrSet(AttrSet { stcs: global .into_iter() .enumerate() - .map(|(idx, (name, ..))| (name.to_string(), unsafe { ExprId::from_raw(idx) })) + .map(|(idx, (name, ..))| { + // SAFETY: `idx` from `enumerate` is a valid index for the LIR + // corresponding to this global primop. + (name.to_string(), unsafe { ExprId::from_raw(idx) }) + }) .chain(scoped.into_iter().enumerate().map(|(idx, (name, ..))| { + // SAFETY: `idx + GLOBAL_LEN` is a valid index for the LIR + // corresponding to this scoped primop. (name.to_string(), unsafe { ExprId::from_raw(idx + GLOBAL_LEN) }) })) - .chain(core::iter::once(("builtins".to_string(), unsafe { - ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN + 1) - }))) + .chain(core::iter::once(( + "builtins".to_string(), + // SAFETY: This ID points to the `Thunk` that wraps this very + // `AttrSet`. The index is calculated to be one position after + // the `AttrSet` itself. + unsafe { ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN + 1) }, + ))) .collect(), ..AttrSet::default() }))) - .chain(core::iter::once(Lir::Thunk(unsafe { - ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN) - }))) + .chain(core::iter::once(Lir::Thunk( + // SAFETY: This ID points to the `builtins` `AttrSet` defined just above. + // Its index is calculated to be at that exact position. + unsafe { ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN) }, + ))) .map(|lir| Box::new_in(lir, bump)) .collect_vec(); Self { @@ -144,8 +175,9 @@ impl<'bump> Context<'bump> { let root = self .downgrade_ctx() .downgrade_root(root.tree().expr().unwrap())?; - self.resolve_ctx(root).resolve_root()?; - Ok(self.eval_ctx().eval_root(root)?.to_public()) + let ctx = self.resolve_ctx(root); + ctx.resolve_root()?; + Ok(self.eval_ctx().eval(root)?.to_public()) } pub fn add_binding(&mut self, ident: &str, expr: &str) -> Result<()> { @@ -157,6 +189,10 @@ impl<'bump> Context<'bump> { .unwrap(); let expr_id = self.downgrade_ctx().downgrade_root(root_expr)?; self.resolve_ctx(expr_id).resolve_root()?; + // SAFETY: `repl_scope` is a `NonNull` pointer that is guaranteed to be valid + // for the lifetime of `Context`. It is initialized in `new()` and the memory + // it points to is managed by the `bump` arena. Therefore, it is safe to + // dereference it to a mutable reference here. unsafe { self.repl_scope.as_mut() }.insert(ident.to_string(), expr_id); Ok(()) } @@ -165,20 +201,15 @@ impl<'bump> Context<'bump> { impl Context<'_> { fn alloc_id(&mut self) -> ExprId { self.ir_count += 1; + // SAFETY: This function is the sole source of new `ExprId`s during the + // downgrade and resolve phases. By monotonically incrementing `ir_count`, + // we guarantee that each ID is unique and corresponds to a valid, soon-to-be- + // allocated slot in the IR vectors. unsafe { ExprId::from_raw(self.ir_count - 1) } } - fn add_dep(&mut self, from: ExprId, to: ExprId, count: &Cell) -> StackIdx { - if let Some(&idx) = self.graph.edge_weight(from, to) { - idx - } else { - let idx = count.get(); - count.set(idx + 1); - let idx = unsafe { StackIdx::from_raw(idx) }; - assert_ne!(from, to); - self.graph.add_edge(from, to, idx); - idx - } + fn add_dep(&mut self, from: ExprId, to: ExprId) { + self.graph.add_edge(from, to, ()); } } diff --git a/evaluator/nixjit_context/src/resolve.rs b/evaluator/nixjit_context/src/resolve.rs index b4b1ed5..64e6b18 100644 --- a/evaluator/nixjit_context/src/resolve.rs +++ b/evaluator/nixjit_context/src/resolve.rs @@ -1,25 +1,26 @@ -use std::cell::{Cell, RefCell}; +use std::{cell::RefCell, ptr::NonNull}; -use bumpalo::boxed::Box; -use derive_more::Unwrap; +use bumpalo::{boxed::Box, collections::Vec}; +use derive_more::{Constructor, Unwrap}; use hashbrown::HashMap; +use replace_with::replace_with_and_return; use nixjit_error::Result; use nixjit_hir::Hir; use nixjit_ir::{Const, ExprId, Param, StackIdx}; use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext}; -use replace_with::replace_with_and_return; use super::Context; #[derive(Clone)] -enum Scope<'ctx> { +enum Scope { /// A `let` binding scope, mapping variable names to their expression IDs. Let(HashMap), /// A function argument scope. `Some` holds the name of the argument set if present. Arg(Option), - Builtins(&'ctx HashMap<&'static str, ExprId>), - Repl(&'ctx HashMap), + // Not using &'ctx HashMap<_, _> because bumpalo's Vec<'bump, T> is invariant over T. + Builtins(NonNull>), + Repl(NonNull>), } /// Represents an expression at different stages of compilation. @@ -31,29 +32,18 @@ enum Ir { Lir(Lir), } +#[derive(Constructor)] struct Closure { id: ExprId, arg: ExprId, - deps: Cell -} - -impl Closure { - fn new(id: ExprId, arg: ExprId) -> Self { - Self { - id, - arg, - deps: Cell::new(0) - } - } } pub struct ResolveCtx<'ctx, 'bump> { ctx: &'ctx mut Context<'bump>, - irs: Vec>>, + irs: Vec<'bump, RefCell>, root: ExprId, - root_deps: Cell, - closures: Vec, - scopes: Vec>, + closures: Vec<'bump, Closure>, + scopes: Vec<'bump, Scope>, has_with: bool, with_used: bool, } @@ -61,37 +51,38 @@ pub struct ResolveCtx<'ctx, 'bump> { impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { pub fn new(ctx: &'ctx mut Context<'bump>, root: ExprId) -> Self { Self { - scopes: vec![ - Scope::Builtins(unsafe { ctx.global_scope.as_ref() }), - Scope::Repl(unsafe { ctx.repl_scope.as_ref() }), - ], + scopes: { + let mut vec = Vec::new_in(ctx.bump); + vec.push(Scope::Builtins(ctx.global_scope)); + vec.push(Scope::Repl(ctx.repl_scope)); + vec + }, has_with: false, with_used: false, - irs: core::mem::take(&mut ctx.hirs) - .into_iter() - .map(|hir| Ir::Hir(hir).into()) - .map(|ir| Box::new_in(ir, ctx.bump)) - .collect(), + irs: Vec::from_iter_in( + core::mem::take(&mut ctx.hirs) + .into_iter() + .map(Ir::Hir) + .map(RefCell::new), + ctx.bump, + ), + closures: Vec::new_in(ctx.bump), ctx, root, - root_deps: Cell::new(0), - closures: Vec::new(), } } pub fn resolve_root(mut self) -> Result<()> { let ret = self.resolve(self.root); - if ret.is_ok() { + ret.map(|_| { self.ctx.lirs.extend( self.irs .into_iter() - .map(Box::into_inner) .map(RefCell::into_inner) .map(Ir::unwrap_lir) .map(|lir| Box::new_in(lir, self.ctx.bump)), ); - } - ret + }) } fn get_ir(&self, id: ExprId) -> &RefCell { @@ -102,21 +93,15 @@ impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { unsafe { self.irs.get_unchecked(idx) } } } - - fn new_lir(&mut self, lir: Lir) -> ExprId { - self.irs - .push(Box::new_in(RefCell::new(Ir::Lir(lir)), self.ctx.bump)); - self.ctx.alloc_id() - } } impl ResolveContext for ResolveCtx<'_, '_> { - fn resolve(&mut self, expr: ExprId) -> Result<()> { - let result = unsafe { + fn resolve(&mut self, expr: ExprId) -> Result { + unsafe { let ctx = &mut *(self as *mut Self); let ir = self.get_ir(expr); if !matches!(ir.try_borrow().as_deref(), Ok(Ir::Hir(_))) { - return Ok(()); + return Ok(expr); } replace_with_and_return( &mut *ir.borrow_mut(), @@ -126,7 +111,8 @@ impl ResolveContext for ResolveCtx<'_, '_> { })) }, |ir| match ir.unwrap_hir().resolve(ctx) { - Ok(lir) => (Ok(()), Ir::Lir(lir)), + Ok(lir @ Lir::ExprRef(expr)) => (Ok(expr), Ir::Lir(lir)), + Ok(lir) => (Ok(expr), Ir::Lir(lir)), Err(err) => ( Err(err), Ir::Hir(Hir::Const(Const { @@ -135,44 +121,44 @@ impl ResolveContext for ResolveCtx<'_, '_> { ), }, ) - }; - result + } } fn lookup(&mut self, name: &str) -> LookupResult { let mut closure_depth = 0; - // Then search from outer to inner scopes for dependencies for scope in self.scopes.iter().rev() { match scope { Scope::Builtins(scope) => { - if let Some(&primop) = scope.get(&name) { - return LookupResult::PrimOp(primop); + if let Some(&primop) = unsafe { scope.as_ref() }.get(&name) { + return LookupResult::Expr(primop); } } - Scope::Let(scope) | &Scope::Repl(scope) => { + Scope::Let(scope) => { if let Some(&dep) = scope.get(name) { - let (expr, deps) = self.closures.last().map_or_else(|| (self.root, &self.root_deps), |closure| (closure.id, &closure.deps)); - let idx = self.ctx.add_dep(expr, dep, deps); - return LookupResult::Stack(idx); + let expr = self + .closures + .last() + .map_or_else(|| self.root, |closure| closure.id); + self.ctx.add_dep(expr, dep); + return LookupResult::Expr(dep); + } + } + &Scope::Repl(scope) => { + if let Some(&dep) = unsafe { scope.as_ref() }.get(name) { + let expr = self + .closures + .last() + .map_or_else(|| self.root, |closure| closure.id); + self.ctx.add_dep(expr, dep); + return LookupResult::Expr(dep); } } Scope::Arg(ident) => { if ident.as_deref() == Some(name) { - // This is an outer function's parameter, treat as dependency - // We need to find the corresponding parameter expression to create dependency - // For now, we need to handle this case by creating a dependency to the parameter - let mut iter = self.closures.iter().rev().take(closure_depth + 1).rev(); - let Closure { id: func, arg, deps: count } = iter.next().unwrap(); - let mut cur = self.ctx.add_dep(*func, *arg, count); - for Closure { id: func, deps: count, .. } in iter { - self.irs.push(Box::new_in( - RefCell::new(Ir::Lir(Lir::StackRef(cur))), - self.ctx.bump, - )); - let idx = self.ctx.alloc_id(); - cur = self.ctx.add_dep(*func, idx, count); - } - return LookupResult::Stack(cur); + let &Closure { id: func, arg } = + self.closures.iter().nth_back(closure_depth).unwrap(); + self.ctx.add_dep(func, arg); + return LookupResult::Expr(arg); } closure_depth += 1; } @@ -186,9 +172,8 @@ impl ResolveContext for ResolveCtx<'_, '_> { } } - fn lookup_arg(&mut self) -> StackIdx { - let Closure { id: func, arg, deps } = self.closures.last().unwrap(); - self.ctx.add_dep(*func, *arg, deps) + fn lookup_arg(&mut self) -> ExprId { + self.closures.last().unwrap().arg } fn new_func(&mut self, body: ExprId, param: Param) { @@ -206,23 +191,23 @@ impl ResolveContext for ResolveCtx<'_, '_> { res } - fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) { + fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> Result<()>) -> Result { let has_with = self.has_with; let with_used = self.with_used; self.has_with = true; self.with_used = false; let res = f(self); self.has_with = has_with; - (core::mem::replace(&mut self.with_used, with_used), res) + res.map(|_| core::mem::replace(&mut self.with_used, with_used)) } fn with_closure_env( &mut self, func: ExprId, + arg: ExprId, ident: Option, f: impl FnOnce(&mut Self) -> T, ) -> T { - let arg = self.new_lir(Lir::Arg(nixjit_ir::Arg)); self.closures.push(Closure::new(func, arg)); self.scopes.push(Scope::Arg(ident)); let res = f(self); diff --git a/evaluator/nixjit_error/src/lib.rs b/evaluator/nixjit_error/src/lib.rs index 2cd1c36..f0b226d 100644 --- a/evaluator/nixjit_error/src/lib.rs +++ b/evaluator/nixjit_error/src/lib.rs @@ -101,7 +101,7 @@ impl std::fmt::Display for Error { write!(f, "\n --> {}:{}", start_line, start_col)?; write!(f, "\n |\n")?; - write!(f, "{:4} | {}\n", start_line, line_str)?; + writeln!(f, "{:4} | {}", start_line, line_str)?; write!( f, " | {}{}", diff --git a/evaluator/nixjit_eval/src/lib.rs b/evaluator/nixjit_eval/src/lib.rs index f1849ef..dbc2a8b 100644 --- a/evaluator/nixjit_eval/src/lib.rs +++ b/evaluator/nixjit_eval/src/lib.rs @@ -12,9 +12,9 @@ use std::rc::Rc; use hashbrown::HashMap; use nixjit_error::{Error, Result}; -use nixjit_ir::{self as ir, ExprId, PrimOpId, StackIdx}; +use nixjit_ir::{self as ir, ExprId, PrimOpId, SymId}; use nixjit_lir as lir; -use nixjit_value::{Const, format_symbol}; +use nixjit_value::format_symbol; pub use crate::value::*; @@ -22,12 +22,13 @@ mod value; /// A trait defining the context in which LIR expressions are evaluated. pub trait EvalContext { - fn eval_root(self, expr: ExprId) -> Result; + fn eval(&mut self, id: ExprId) -> Result; - /// Evaluates an expression by its ID. - fn eval(&mut self, expr: ExprId) -> Result; + fn resolve(&mut self, id: ExprId) -> Result; - fn call(&mut self, func: ExprId, arg: Option, frame: StackFrame) -> Result; + fn force(&mut self, id: ValueId) -> Result; + + fn call(&mut self, func: ValueId, arg: Value) -> Result; /// Enters a `with` scope for the duration of a closure's execution. fn with_with_env( @@ -36,18 +37,14 @@ pub trait EvalContext { f: impl FnOnce(&mut Self) -> T, ) -> T; - /// Looks up a stack slot on the current stack frame. - fn lookup_stack(&self, idx: StackIdx) -> &Value; - - fn capture_stack(&self) -> &StackFrame; - /// Looks up an identifier in the current `with` scope chain. - fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a Value>; + fn lookup_with<'a>(&'a self, ident: SymId) -> Option<&'a Value>; /// Calls a primitive operation (builtin) by its ID. fn call_primop(&mut self, id: PrimOpId, args: Args) -> Result; - fn get_primop_arity(&self, id: PrimOpId) -> usize; + fn new_sym(&mut self, sym: String) -> SymId; + fn get_sym(&self, id: SymId) -> &str; } /// A trait for types that can be evaluated within an `EvalContext`. @@ -83,18 +80,12 @@ impl Evaluate for lir::Lir { Str(x) => x.eval(ctx), Var(x) => x.eval(ctx), Path(x) => x.eval(ctx), - &StackRef(idx) => { - let mut val = ctx.lookup_stack(idx).clone(); - val.force(ctx)?; - Ok(val) - } &ExprRef(expr) => ctx.eval(expr), - &FuncRef(body) => Ok(Value::Closure( - Closure::new(body, ctx.capture_stack().clone()).into(), - )), + &FuncRef(body) => ctx.resolve(body).map(Value::Closure), &Arg(_) => unreachable!(), &PrimOp(primop) => Ok(Value::PrimOp(primop)), - &Thunk(id) => Ok(Value::Thunk(id)), + &Thunk(id) => ctx.resolve(id).map(Value::Thunk), + &StackRef(idx) => todo!(), } } } @@ -105,15 +96,17 @@ impl Evaluate for ir::AttrSet { let mut attrs = AttrSet::new( self.stcs .iter() - .map(|(k, v)| { + .map(|(&k, &v)| { let eval_result = v.eval(ctx); - Ok((k.clone(), eval_result?)) + Ok((k, eval_result?)) }) .collect::>()?, ); for (k, v) in self.dyns.iter() { let v = v.eval(ctx)?; - attrs.push_attr(k.eval(ctx)?.force_string_no_ctx()?, v)?; + let sym = k.eval(ctx)?.force_string_no_ctx()?; + let sym = ctx.new_sym(sym); + attrs.push_attr(sym, v, ctx)?; } let result = Value::AttrSet(attrs.into()); Ok(result) @@ -138,9 +131,14 @@ impl Evaluate for ir::HasAttr { fn eval(&self, ctx: &mut Ctx) -> Result { use ir::Attr::*; let mut val = self.lhs.eval(ctx)?; - val.has_attr(self.rhs.iter().map(|attr| match attr { - Str(ident) => Ok(Value::String(ident.clone())), - Dynamic(expr) => expr.eval(ctx), + val.has_attr(self.rhs.iter().map(|attr| { + match attr { + &Str(ident) => Ok(ident), + Dynamic(expr) => expr + .eval(ctx)? + .force_string_no_ctx() + .map(|sym| ctx.new_sym(sym)), + } }))?; Ok(val) } @@ -165,9 +163,9 @@ impl Evaluate for ir::BinOp { } Mul => lhs.mul(rhs)?, Div => lhs.div(rhs)?, - Eq => Value::eq(&mut lhs, rhs), + Eq => lhs.eq(rhs), Neq => { - Value::eq(&mut lhs, rhs); + lhs.eq(rhs); let _ = lhs.not(); } Lt => lhs.lt(rhs)?, @@ -182,12 +180,12 @@ impl Evaluate for ir::BinOp { } Geq => { lhs.lt(rhs)?; - let _ = lhs.not()?; + let _ = lhs.not(); } And => lhs.and(rhs)?, Or => lhs.or(rhs)?, Impl => { - let _ = lhs.not(); + lhs.not()?; lhs.or(rhs)?; } Con => lhs.concat(rhs)?, @@ -226,12 +224,11 @@ impl Evaluate for ir::Select { use ir::Attr::*; let mut val = self.expr.eval(ctx)?; for attr in self.attrpath.iter() { - let name_val; let name = match attr { - Str(name) => name, + &Str(name) => name, Dynamic(expr) => { - name_val = expr.eval(ctx)?; - &*name_val.force_string_no_ctx()? + let sym = expr.eval(ctx)?.force_string_no_ctx()?; + ctx.new_sym(sym) } }; if let Some(default) = self.default { @@ -338,11 +335,9 @@ impl Evaluate for ir::Var { /// Evaluates a `Var` by looking it up in the `with` scope chain. /// This is for variables that could not be resolved statically. fn eval(&self, ctx: &mut Ctx) -> Result { - ctx.lookup_with(&self.sym) - .ok_or_else(|| { - Error::eval_error(format!("undefined variable '{}'", format_symbol(&self.sym))) - }) - .map(|val| val.clone()) + ctx.lookup_with(self.sym).cloned().ok_or_else(|| { + Error::eval_error(format!("undefined variable '{}'", format_symbol(ctx.get_sym(self.sym)))) + }) } } diff --git a/evaluator/nixjit_eval/src/value/attrset.rs b/evaluator/nixjit_eval/src/value/attrset.rs index 9deef64..c6d862e 100644 --- a/evaluator/nixjit_eval/src/value/attrset.rs +++ b/evaluator/nixjit_eval/src/value/attrset.rs @@ -10,7 +10,7 @@ use hashbrown::hash_map::Entry; use itertools::Itertools; use nixjit_error::{Error, Result}; -use nixjit_ir::ExprId; +use nixjit_ir::{ExprId, SymId}; use nixjit_value::{self as p, format_symbol}; use crate::EvalContext; @@ -24,7 +24,7 @@ use super::Value; #[repr(transparent)] #[derive(Clone, Constructor)] pub struct AttrSet { - data: HashMap, + data: HashMap, } impl Debug for AttrSet { @@ -33,23 +33,23 @@ impl Debug for AttrSet { write!(f, "{{ ")?; for (k, v) in self.data.iter() { match v { - List(_) => write!(f, "{} = [ ... ]; ", format_symbol(k))?, - AttrSet(_) => write!(f, "{} = {{ ... }}; ", format_symbol(k))?, - v => write!(f, "{} = {v:?}; ", format_symbol(k))?, + List(_) => write!(f, "{:?} = [ ... ]; ", k)?, + AttrSet(_) => write!(f, "{:?} = {{ ... }}; ", k)?, + v => write!(f, "{:?} = {v:?}; ", k)?, } } write!(f, "}}") } } -impl From> for AttrSet { - fn from(data: HashMap) -> Self { +impl From> for AttrSet { + fn from(data: HashMap) -> Self { Self { data } } } impl Deref for AttrSet { - type Target = HashMap; + type Target = HashMap; fn deref(&self) -> &Self::Target { &self.data } @@ -64,16 +64,16 @@ impl AttrSet { } /// Inserts an attribute, overwriting any existing attribute with the same name. - pub fn push_attr_force(&mut self, sym: String, val: Value) { + pub fn push_attr_force(&mut self, sym: SymId, val: Value) { self.data.insert(sym, val); } /// Inserts an attribute, returns an error if the attribute is already defined. - pub fn push_attr(&mut self, sym: String, val: Value) -> Result<()> { + pub fn push_attr(&mut self, sym: SymId, val: Value, ctx: &mut impl EvalContext) -> Result<()> { match self.data.entry(sym) { Entry::Occupied(occupied) => Err(Error::eval_error(format!( "attribute '{}' already defined", - format_symbol(occupied.key()) + format_symbol(ctx.get_sym(*occupied.key())) ))), Entry::Vacant(vacant) => { vacant.insert(val); @@ -82,29 +82,29 @@ impl AttrSet { } } - pub fn select(&self, name: &str, ctx: &mut impl EvalContext) -> Result { + pub fn select(&self, name: SymId, ctx: &mut impl EvalContext) -> Result { self.data - .get(name) + .get(&name) .cloned() .map(|attr| match attr { - Value::Thunk(id) => ctx.eval(id), + Value::Thunk(id) => ctx.force(id), val => Ok(val), }) .ok_or_else(|| { - Error::eval_error(format!("attribute '{}' not found", format_symbol(name))) + Error::eval_error(format!("attribute '{}' not found", format_symbol(ctx.get_sym(name)))) })? } pub fn select_or( &self, - name: &str, + name: SymId, default: ExprId, ctx: &mut impl EvalContext, ) -> Result { self.data - .get(name) + .get(&name) .map(|attr| match attr { - &Value::Thunk(id) => ctx.eval(id), + &Value::Thunk(id) => ctx.force(id), val => Ok(val.clone()), }) .unwrap_or_else(|| ctx.eval(default)) @@ -113,19 +113,19 @@ impl AttrSet { /// Checks if an attribute path exists within the set. pub fn has_attr( &self, - mut path: impl DoubleEndedIterator>, + mut path: impl DoubleEndedIterator>, ) -> Result { let mut data = &self.data; let last = path.nth_back(0).unwrap(); for item in path { - let Some(Value::AttrSet(attrs)) = data.get(&item.unwrap().force_string_no_ctx()?) + let Some(Value::AttrSet(attrs)) = data.get(&item?) else { return Ok(Value::Bool(false)); }; data = attrs.as_inner(); } Ok(Value::Bool( - data.get(&last.unwrap().force_string_no_ctx()?).is_some(), + data.get(&last?).is_some(), )) } @@ -138,24 +138,18 @@ impl AttrSet { } /// Returns a reference to the inner `HashMap`. - pub fn as_inner(&self) -> &HashMap { + pub fn as_inner(&self) -> &HashMap { &self.data } /// Converts an `Rc` to an `Rc>` without allocation. - /// - /// # Safety - /// - /// This is safe because `AttrSet` is `#[repr(transparent)]`. pub fn into_inner(self: Rc) -> Rc> { + // SAFETY: This is safe because `AttrSet` is `#[repr(transparent)]` over + // `HashMap`, so `Rc` has the same layout as + // `Rc>`. unsafe { core::mem::transmute(self) } } - /// Creates an `AttrSet` from a `HashMap`. - pub fn from_inner(data: HashMap) -> Self { - Self { data } - } - /// Performs a deep equality comparison between two `AttrSet`s. /// /// It recursively compares the contents of both sets, ensuring that both keys @@ -171,11 +165,11 @@ impl AttrSet { } /// Converts the `AttrSet` to its public-facing representation. - pub fn to_public(self) -> p::Value { + pub fn to_public(self, ctx: &mut impl EvalContext) -> p::Value { p::Value::AttrSet(p::AttrSet::new( self.data .into_iter() - .map(|(sym, value)| (sym.into(), value.to_public())) + .map(|(sym, value)| (ctx.get_sym(sym).into(), value.to_public(ctx))) .collect(), )) } diff --git a/evaluator/nixjit_eval/src/value/closure.rs b/evaluator/nixjit_eval/src/value/closure.rs deleted file mode 100644 index ffa2415..0000000 --- a/evaluator/nixjit_eval/src/value/closure.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Defines the runtime representation of a partially applied function. -use std::rc::Rc; - -use derive_more::Constructor; - -use nixjit_error::Result; -use nixjit_ir::ExprId; - -use super::Value; -use crate::EvalContext; - -pub type StackFrame = smallvec::SmallVec<[Value; 5]>; - -#[derive(Debug, Clone, Constructor)] -pub struct Closure { - pub body: ExprId, - pub frame: StackFrame, -} - -impl Closure { - pub fn call( - self: Rc, - arg: Option, - ctx: &mut Ctx, - ) -> Result { - let Self { body: func, frame } = Rc::unwrap_or_clone(self); - ctx.call(func, arg, frame) - } -} diff --git a/evaluator/nixjit_eval/src/value/list.rs b/evaluator/nixjit_eval/src/value/list.rs index f69aa16..febfcb9 100644 --- a/evaluator/nixjit_eval/src/value/list.rs +++ b/evaluator/nixjit_eval/src/value/list.rs @@ -70,7 +70,7 @@ impl List { self.data .get(idx) .map(|elem| match elem { - &Value::Thunk(id) => ctx.eval(id), + &Value::Thunk(id) => ctx.force(id), val => Ok(val.clone()), }) .ok_or_else(|| { @@ -93,11 +93,11 @@ impl List { } /// Converts the `List` to its public-facing representation. - pub fn to_public(&self) -> PubValue { + pub fn to_public(&self, ctx: &mut impl EvalContext) -> PubValue { PubValue::List(PubList::new( self.data .iter() - .map(|value| value.clone().to_public()) + .map(|value| value.clone().to_public(ctx)) .collect(), )) } diff --git a/evaluator/nixjit_eval/src/value/mod.rs b/evaluator/nixjit_eval/src/value/mod.rs index e46e7e3..fdc178f 100644 --- a/evaluator/nixjit_eval/src/value/mod.rs +++ b/evaluator/nixjit_eval/src/value/mod.rs @@ -13,6 +13,7 @@ use nixjit_ir::ExprId; use nixjit_ir::PrimOpId; use nixjit_error::{Error, Result}; +use nixjit_ir::SymId; use nixjit_value::Const; use nixjit_value::Value as PubValue; use replace_with::replace_with_and_return; @@ -21,13 +22,11 @@ use smallvec::smallvec; use crate::EvalContext; mod attrset; -mod closure; mod list; mod primop; mod string; pub use attrset::AttrSet; -pub use closure::*; pub use list::List; pub use primop::*; @@ -45,14 +44,12 @@ pub enum Value { Bool(bool) = Self::BOOL, String(String) = Self::STRING, Null = Self::NULL, - Thunk(ExprId) = Self::THUNK, - ClosureThunk(Rc) = Self::CLOSURE_THUNK, + Thunk(ValueId) = Self::THUNK, AttrSet(Rc) = Self::ATTRSET, List(Rc) = Self::LIST, PrimOp(PrimOpId) = Self::PRIMOP, PrimOpApp(Rc) = Self::PRIMOP_APP, - Closure(Rc) = Self::CLOSURE, - Blackhole, + Closure(ValueId) = Self::CLOSURE, } impl Debug for Value { @@ -67,11 +64,9 @@ impl Debug for Value { AttrSet(x) => write!(f, "{x:?}"), List(x) => write!(f, "{x:?}"), Thunk(thunk) => write!(f, ""), - ClosureThunk(_) => write!(f, ""), - Closure(func) => write!(f, "", func.body), + Closure(func) => write!(f, "", func), PrimOp(_) => write!(f, ""), PrimOpApp(_) => write!(f, ""), - Blackhole => write!(f, ""), } } } @@ -129,13 +124,11 @@ impl Value { String(_) => "string", Null => "null", Thunk(_) => "thunk", - ClosureThunk(_) => "thunk", AttrSet(_) => "set", List(_) => "list", PrimOp(_) => "lambda", PrimOpApp(_) => "lambda", Closure(..) => "lambda", - Blackhole => unreachable!(), } } @@ -148,8 +141,7 @@ impl Value { self, || Value::Null, |val| match val { - Value::Thunk(id) => map(ctx.eval(id)), - Value::ClosureThunk(thunk) => map(thunk.call(None, ctx)), + Value::Thunk(id) => map(ctx.force(id)), val => (Ok(()), val), }, ) @@ -170,23 +162,9 @@ impl Value { self, || Null, |func| match func { - PrimOp(id) => { - let arity = ctx.get_primop_arity(id); - if arity == 1 { - map(ctx.call_primop(id, smallvec![arg])) - } else { - ( - Ok(()), - Value::PrimOpApp(Rc::new(self::PrimOpApp::new( - arity - 1, - id, - smallvec![arg], - ))), - ) - } - } - PrimOpApp(func) => map(func.call(arg, ctx)), - Closure(func) => map(func.call(Some(arg), ctx)), + PrimOp(id) => map(ctx.call_primop(id, smallvec![arg])), + PrimOpApp(primop) => map(primop.call(arg, ctx)), + Closure(func) => map(ctx.call(func, arg)), _ => ( Err(Error::eval_error( "attempt to call something which is not a function but ...".to_string(), @@ -240,10 +218,7 @@ impl Value { } pub fn eq(&mut self, other: Self) { - use Value::Bool; - *self = match (&*self, other) { - (s, other) => Bool(s.eq_impl(&other)), - }; + *self = Value::Bool(self.eq_impl(&other)); } pub fn lt(&mut self, other: Self) -> Result<()> { @@ -378,7 +353,7 @@ impl Value { } } - pub fn select(&mut self, name: &str, ctx: &mut impl EvalContext) -> Result<()> { + pub fn select(&mut self, name: SymId, ctx: &mut impl EvalContext) -> Result<()> { use Value::*; let val = match self { AttrSet(attrs) => attrs.select(name, ctx), @@ -393,7 +368,7 @@ impl Value { pub fn select_or( &mut self, - name: &str, + name: SymId, default: ExprId, ctx: &mut Ctx, ) -> Result<()> { @@ -411,7 +386,7 @@ impl Value { Ok(()) } - pub fn has_attr(&mut self, path: impl DoubleEndedIterator>) -> Result<()> { + pub fn has_attr(&mut self, path: impl DoubleEndedIterator>) -> Result<()> { use Value::*; if let AttrSet(attrs) = self { let val = attrs.has_attr(path)?; @@ -436,22 +411,46 @@ impl Value { /// Converts the internal `Value` to its public-facing, serializable /// representation from the `nixjit_value` crate. - pub fn to_public(self) -> PubValue { + pub fn to_public(self, ctx: &mut impl EvalContext) -> PubValue { use Value::*; match self { - AttrSet(attrs) => Rc::unwrap_or_clone(attrs).to_public(), - List(list) => Rc::unwrap_or_clone(list.clone()).to_public(), + AttrSet(attrs) => Rc::unwrap_or_clone(attrs).to_public(ctx), + List(list) => Rc::unwrap_or_clone(list.clone()).to_public(ctx), Int(x) => PubValue::Const(Const::Int(x)), Float(x) => PubValue::Const(Const::Float(x)), Bool(x) => PubValue::Const(Const::Bool(x)), String(x) => PubValue::String(x), Null => PubValue::Const(Const::Null), Thunk(_) => PubValue::Thunk, - ClosureThunk(_) => PubValue::Thunk, PrimOp(_) => PubValue::PrimOp, PrimOpApp(_) => PubValue::PrimOpApp, Closure(..) => PubValue::Func, - Blackhole => unreachable!(), } } } + +#[repr(transparent)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)] +pub struct ValueId(usize); + +impl ValueId { + /// Returns the raw `usize` index. + /// + /// # Safety + /// + /// The caller is responsible for using this index correctly and not causing out-of-bounds access. + #[inline(always)] + pub unsafe fn raw(self) -> usize { + self.0 + } + + /// Creates an `ExprId` from a raw `usize` index. + /// + /// # Safety + /// + /// The caller must ensure that the provided index is valid for the expression table. + #[inline(always)] + pub unsafe fn from_raw(id: usize) -> Self { + Self(id) + } +} diff --git a/evaluator/nixjit_eval/src/value/primop.rs b/evaluator/nixjit_eval/src/value/primop.rs index 2daaa6a..256421e 100644 --- a/evaluator/nixjit_eval/src/value/primop.rs +++ b/evaluator/nixjit_eval/src/value/primop.rs @@ -18,8 +18,6 @@ pub type Args = smallvec::SmallVec<[Value; 2]>; /// all, of its required arguments. #[derive(Debug, Clone, Constructor)] pub struct PrimOpApp { - /// The number of remaining arguments the primop expects. - arity: usize, /// The unique ID of the primop. id: PrimOpId, /// The arguments that have already been applied. @@ -27,20 +25,9 @@ pub struct PrimOpApp { } impl PrimOpApp { - /// Applies more arguments to a partially applied primop. - /// - /// If enough arguments are provided to satisfy the primop's arity, it is - /// executed. Otherwise, it returns a new `PrimOpApp` with the combined - /// arguments. pub fn call(self: Rc, arg: Value, ctx: &mut impl EvalContext) -> Result { - let mut primop = Rc::unwrap_or_clone(self); - if primop.arity == 1 { - primop.args.push(arg); - ctx.call_primop(primop.id, primop.args) - } else { - primop.args.push(arg); - primop.arity -= 1; - Ok(Value::PrimOpApp(primop.into())) - } + let PrimOpApp { id, mut args } = Rc::unwrap_or_clone(self); + args.push(arg); + ctx.call_primop(id, args) } } diff --git a/evaluator/nixjit_hir/src/downgrade.rs b/evaluator/nixjit_hir/src/downgrade.rs index eafa89a..a670ab9 100644 --- a/evaluator/nixjit_hir/src/downgrade.rs +++ b/evaluator/nixjit_hir/src/downgrade.rs @@ -150,6 +150,7 @@ impl Downgrade for ast::Literal { impl Downgrade for ast::Ident { fn downgrade(self, ctx: &mut Ctx) -> Result { let sym = self.ident_token().unwrap().to_string(); + let sym = ctx.new_sym(sym); Ok(ctx.new_expr(Var { sym }.to_hir())) } } @@ -237,8 +238,9 @@ impl Downgrade for ast::LegacyLet { let bindings = attrs.stcs.clone(); let body = ctx.new_expr(attrs.to_hir()); let expr = ctx.new_expr(Let { bindings, body }.to_hir()); + let sym = ctx.new_sym("body".into()); // The result of a `legacy let` is the `body` attribute of the resulting set. - let attrpath = vec![Attr::Str("body".into())]; + let attrpath = vec![Attr::Str(sym)]; Ok(ctx.new_expr( Select { expr, @@ -274,6 +276,7 @@ impl Downgrade for ast::Lambda { fn downgrade(self, ctx: &mut Ctx) -> Result { let param = downgrade_param(self.param().unwrap(), ctx)?; let mut body = self.body().unwrap().downgrade(ctx)?; + let arg = ctx.new_expr(Hir::Arg(())); let ident; let required; @@ -281,7 +284,7 @@ impl Downgrade for ast::Lambda { match param { Param::Ident(id) => { // Simple case: `x: body` - ident = Some(id); + ident = Some(ctx.new_sym(id)); required = None; allowed = None; } @@ -291,35 +294,36 @@ impl Downgrade for ast::Lambda { alias, } => { // Complex case: `{ a, b ? 2, ... }@args: body` - ident = alias.clone(); + let alias = alias.map(|sym| ctx.new_sym(sym)); + ident = alias; required = Some( formals .iter() .filter(|(_, default)| default.is_none()) - .map(|(k, _)| k.clone()) + .map(|(k, _)| ctx.new_sym(k.clone())) .collect(), ); allowed = if ellipsis { None // `...` means any attribute is allowed. } else { - Some(formals.iter().map(|(k, _)| k.clone()).collect()) + Some(formals.iter().map(|(k, _)| ctx.new_sym(k.clone())).collect()) }; // Desugar pattern matching in function arguments into a `let` expression. // For example, `({ a, b ? 2 }): a + b` is desugared into: // `arg: let a = arg.a; b = arg.b or 2; in a + b` - let arg = ctx.new_expr(Hir::Arg(Arg)); let mut bindings: HashMap<_, _> = formals .into_iter() .map(|(k, default)| { // For each formal parameter, create a `Select` expression to extract it from the argument set. // `Arg` represents the raw argument (the attribute set) passed to the function. + let k = ctx.new_sym(k); ( - k.clone(), + k, ctx.new_expr( Select { expr: arg, - attrpath: vec![Attr::Str(k.clone())], + attrpath: vec![Attr::Str(k)], default, } .to_hir(), @@ -329,7 +333,7 @@ impl Downgrade for ast::Lambda { .collect(); // If there's an alias (`... }@alias`), bind the alias name to the raw argument set. if let Some(alias) = alias { - bindings.insert(alias.clone(), arg); + bindings.insert(alias, arg); } // Wrap the original function body in the new `let` expression. let let_ = Let { bindings, body }; @@ -343,7 +347,7 @@ impl Downgrade for ast::Lambda { allowed, }; // The function's body and parameters are now stored directly in the `Func` node. - Ok(ctx.new_expr(Func { body, param }.to_hir())) + Ok(ctx.new_expr(Func { body, param, arg }.to_hir())) } } diff --git a/evaluator/nixjit_hir/src/lib.rs b/evaluator/nixjit_hir/src/lib.rs index 5488181..ac910c7 100644 --- a/evaluator/nixjit_hir/src/lib.rs +++ b/evaluator/nixjit_hir/src/lib.rs @@ -17,8 +17,7 @@ use hashbrown::HashMap; use nixjit_error::{Error, Result}; use nixjit_ir::{ - Assert, Attr, AttrSet, BinOp, Call, ConcatStrings, Const, ExprId, Func, HasAttr, If, List, - Param as IrParam, Path, Select, Str, UnOp, Var, With, + Assert, Attr, AttrSet, BinOp, Call, ConcatStrings, Const, ExprId, Func, HasAttr, If, List, Param as IrParam, Path, Select, Str, SymId, UnOp, Var, With }; use nixjit_macros::ir; use nixjit_value::format_symbol; @@ -37,6 +36,10 @@ pub trait DowngradeContext { /// Allocates a new HIR expression in the context and returns its ID. fn new_expr(&mut self, expr: Hir) -> ExprId; + fn new_sym(&mut self, sym: String) -> SymId; + + fn get_sym(&self, id: SymId) -> &str; + /// Provides temporary mutable access to an expression. fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T; @@ -81,17 +84,12 @@ ir! { // Represents a path expression. Path, // Represents a `let ... in ...` binding. - Let { pub bindings: HashMap, pub body: ExprId }, + Let { pub bindings: HashMap, pub body: ExprId }, // Represents a function argument lookup within the body of a function. - Arg, + Arg(()), Thunk(ExprId) } -/// A placeholder struct for the `Arg` HIR variant. It signifies that at this point -/// in the expression tree, we should be looking up a function argument. -#[derive(Debug)] -pub struct Arg; - /// A trait defining operations on attribute sets within the HIR. trait Attrs { /// Inserts a value into the attribute set at a given path. @@ -137,7 +135,7 @@ impl Attrs for AttrSet { // This path segment exists but is not an attrset, which is an error. Error::downgrade_error(format!( "attribute '{}' already defined but is not an attribute set", - format_symbol(ident) + format_symbol(ctx.get_sym(ident)) )) }) .and_then(|attrs| attrs._insert(path, name, value, ctx)) @@ -164,10 +162,10 @@ impl Attrs for AttrSet { // This is the final attribute in the path, so insert the value here. match name { Attr::Str(ident) => { - if self.stcs.insert(ident.clone(), value).is_some() { + if self.stcs.insert(ident, value).is_some() { return Err(Error::downgrade_error(format!( "attribute '{}' already defined", - format_symbol(ident) + format_symbol(ctx.get_sym(ident)) ))); } } diff --git a/evaluator/nixjit_hir/src/utils.rs b/evaluator/nixjit_hir/src/utils.rs index a1e99a5..e8ccf70 100644 --- a/evaluator/nixjit_hir/src/utils.rs +++ b/evaluator/nixjit_hir/src/utils.rs @@ -10,7 +10,7 @@ use nixjit_value::format_symbol; use rnix::ast; use nixjit_error::{Error, Result}; -use nixjit_ir::{Attr, AttrSet, ConcatStrings, ExprId, Select, Str, Var}; +use nixjit_ir::{Attr, AttrSet, ConcatStrings, ExprId, Select, Str, SymId, Var}; use crate::Hir; @@ -121,7 +121,7 @@ pub fn downgrade_attrs( pub fn downgrade_static_attrs( attrs: impl ast::HasEntry, ctx: &mut impl DowngradeContext, -) -> Result> { +) -> Result> { let entries = attrs.entries(); let mut attrs = AttrSet { stcs: HashMap::new(), @@ -145,7 +145,7 @@ pub fn downgrade_static_attrs( /// `inherit a b;` is translated into `a = a; b = b;` (i.e., bringing variables into scope). pub fn downgrade_inherit( inherit: ast::Inherit, - stcs: &mut HashMap, + stcs: &mut HashMap, ctx: &mut impl DowngradeContext, ) -> Result<()> { // Downgrade the `from` expression if it exists. @@ -181,7 +181,7 @@ pub fn downgrade_inherit( Entry::Occupied(occupied) => { return Err(Error::eval_error(format!( "attribute '{}' already defined", - format_symbol(occupied.key()) + format_symbol(ctx.get_sym(*occupied.key())) ))); } Entry::Vacant(vacant) => vacant.insert(ctx.new_expr(expr)), @@ -196,15 +196,15 @@ pub fn downgrade_attr(attr: ast::Attr, ctx: &mut impl DowngradeContext) -> Resul use ast::Attr::*; use ast::InterpolPart::*; match attr { - Ident(ident) => Ok(Attr::Str(ident.to_string())), + Ident(ident) => Ok(Attr::Str(ctx.new_sym(ident.to_string()))), Str(string) => { let parts = string.normalized_parts(); if parts.is_empty() { - Ok(Attr::Str("".into())) + Ok(Attr::Str(ctx.new_sym("".into()))) } else if parts.len() == 1 { // If the string has only one part, it's either a literal or a single interpolation. match parts.into_iter().next().unwrap() { - Literal(ident) => Ok(Attr::Str(ident)), + Literal(ident) => Ok(Attr::Str(ctx.new_sym(ident))), Interpolation(interpol) => { Ok(Attr::Dynamic(interpol.expr().unwrap().downgrade(ctx)?)) } diff --git a/evaluator/nixjit_ir/src/lib.rs b/evaluator/nixjit_ir/src/lib.rs index 1d9e068..9019412 100644 --- a/evaluator/nixjit_ir/src/lib.rs +++ b/evaluator/nixjit_ir/src/lib.rs @@ -27,6 +27,7 @@ impl ExprId { /// Returns the raw `usize` index. /// /// # Safety + /// /// The caller is responsible for using this index correctly and not causing out-of-bounds access. #[inline(always)] pub unsafe fn raw(self) -> usize { @@ -36,6 +37,7 @@ impl ExprId { /// Creates an `ExprId` from a raw `usize` index. /// /// # Safety + /// /// The caller must ensure that the provided index is valid for the expression table. #[inline(always)] pub unsafe fn from_raw(id: usize) -> Self { @@ -43,6 +45,33 @@ impl ExprId { } } +/// A type-safe wrapper for an index into an symbol table. +#[repr(transparent)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct SymId(usize); + +impl SymId { + /// Returns the raw `usize` index. + /// + /// # Safety + /// + /// The caller is responsible for using this index correctly and not causing out-of-bounds access. + #[inline(always)] + pub unsafe fn raw(self) -> usize { + self.0 + } + + /// Creates an `SymId` from a raw `usize` index. + /// + /// # Safety + /// + /// The caller must ensure that the provided index is valid for the symbol table. + #[inline(always)] + pub unsafe fn from_raw(id: usize) -> Self { + Self(id) + } +} + /// A type-safe wrapper for an index into a primop (builtin function) table. #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -52,6 +81,7 @@ impl PrimOpId { /// Returns the raw `usize` index. /// /// # Safety + /// /// The caller is responsible for using this index correctly. #[inline(always)] pub unsafe fn raw(self) -> usize { @@ -61,6 +91,7 @@ impl PrimOpId { /// Creates a `PrimOpId` from a raw `usize` index. /// /// # Safety + /// /// The caller must ensure that the provided index is valid. #[inline(always)] pub unsafe fn from_raw(id: usize) -> Self { @@ -77,6 +108,7 @@ impl StackIdx { /// Returns the raw `usize` index. /// /// # Safety + /// /// The caller is responsible for using this index correctly. #[inline(always)] pub unsafe fn raw(self) -> usize { @@ -86,6 +118,7 @@ impl StackIdx { /// Creates an `StackIdx` from a raw `usize` index. /// /// # Safety + /// /// The caller must ensure that the provided index is valid. #[inline(always)] pub unsafe fn from_raw(idx: usize) -> Self { @@ -100,7 +133,7 @@ pub struct Arg; #[derive(Debug, Default)] pub struct AttrSet { /// Statically known attributes (key is a string). - pub stcs: HashMap, + pub stcs: HashMap, /// Dynamically computed attributes, where both the key and value are expressions. pub dyns: Vec<(ExprId, ExprId)>, } @@ -113,7 +146,7 @@ pub enum Attr { Dynamic(ExprId), /// A static attribute key. /// Example: `attrs.key` - Str(String), + Str(SymId), } /// Represents a Nix list. @@ -246,6 +279,8 @@ pub struct Func { pub body: ExprId, /// The parameter specification for the function. pub param: Param, + + pub arg: ExprId, } /// Describes the parameters of a function. @@ -253,12 +288,12 @@ pub struct Func { pub struct Param { /// The name of the argument if it's a simple identifier (e.g., `x: ...`). /// Also used for the alias in a pattern (e.g., `args @ { ... }`). - pub ident: Option, + pub ident: Option, /// The set of required parameter names for a pattern-matching function. - pub required: Option>, + pub required: Option>, /// The set of all allowed parameter names for a non-ellipsis pattern-matching function. /// If `None`, any attribute is allowed (ellipsis `...` is present). - pub allowed: Option>, + pub allowed: Option>, } /// Represents a function call. @@ -315,7 +350,7 @@ pub struct Str { /// Represents a variable lookup by its name. #[derive(Debug)] pub struct Var { - pub sym: String, + pub sym: SymId, } /// Represents a path literal. diff --git a/evaluator/nixjit_jit/src/compile.rs b/evaluator/nixjit_jit/src/compile.rs index 202b424..c689da0 100644 --- a/evaluator/nixjit_jit/src/compile.rs +++ b/evaluator/nixjit_jit/src/compile.rs @@ -45,7 +45,7 @@ impl JITCompile for AttrSet { /// This creates a new attribute set and compiles all static attributes into it. fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { let attrs = ctx.create_attrs(); - for (k, v) in self.stcs.iter() { + for (&k, v) in self.stcs.iter() { let v = v.compile(ctx, rt_ctx); ctx.push_attr(attrs, k, v); } diff --git a/evaluator/nixjit_jit/src/helpers.rs b/evaluator/nixjit_jit/src/helpers.rs index 994a9a0..565df15 100644 --- a/evaluator/nixjit_jit/src/helpers.rs +++ b/evaluator/nixjit_jit/src/helpers.rs @@ -13,7 +13,7 @@ use hashbrown::HashMap; use nixjit_eval::{AttrSet, EvalContext, List, Value}; use nixjit_ir::ExprId; -use nixjit_ir::StackIdx; +use nixjit_ir::SymId; use super::JITContext; @@ -23,20 +23,12 @@ pub extern "C" fn helper_call( arg: NonNull, ctx: &mut Ctx, ) { + // SAFETY: The `arg` pointer is guaranteed to be valid and non-null by the JIT compiler, + // which allocates it on the stack. The JIT code ensures that the pointer points to a + // valid `Value` and that its lifetime is managed correctly within the compiled function. func.call(unsafe { arg.read() }, ctx).unwrap(); } -/// Helper function to look up a value in the evaluation stack. -/// -/// This function is called from JIT-compiled code to access values in the evaluation stack. -pub extern "C" fn helper_lookup_stack( - ctx: &Ctx, - idx: StackIdx, - ret: &mut MaybeUninit, -) { - ret.write(ctx.lookup_stack(idx).clone()); -} - /// Helper function to look up a function argument. /// /// This function is called from JIT-compiled code to access function arguments. @@ -56,6 +48,9 @@ pub extern "C" fn helper_lookup( ret: &mut MaybeUninit, ) { // TODO: Error Handling + // SAFETY: The `sym_ptr` and `sym_len` are provided by the JIT compiler and are + // guaranteed to form a valid UTF-8 string slice. The string data is embedded + // in the compiled code and has a static lifetime, ensuring the pointer is always valid. unsafe { ret.write( ctx.lookup_with(str::from_utf8_unchecked(slice::from_raw_parts( @@ -78,6 +73,9 @@ pub extern "C" fn helper_select( ctx: &mut Ctx, ) { let path = core::ptr::slice_from_raw_parts_mut(path_ptr, path_len); + // SAFETY: The `path_ptr` is allocated by the JIT compiler using `helper_alloc_array` + // and is guaranteed to be valid for the given length. The `Box::from_raw` call + // correctly takes ownership of the allocated slice, ensuring it is properly deallocated. let path = unsafe { Box::from_raw(path) }; for attr in path { val.select(&attr.force_string_no_ctx().unwrap(), ctx) @@ -97,6 +95,9 @@ pub extern "C" fn helper_select_with_default( ctx: &mut Ctx, ) { let path = core::ptr::slice_from_raw_parts_mut(path_ptr, path_len); + // SAFETY: The `path_ptr` is allocated by the JIT compiler using `helper_alloc_array` + // and is guaranteed to be valid for the given length. The `Box::from_raw` call + // correctly takes ownership of the allocated slice, ensuring it is properly deallocated. let path = unsafe { Box::from_raw(path) }; for attr in path { val.select_or(&attr.force_string_no_ctx().unwrap(), default, ctx) @@ -107,7 +108,10 @@ pub extern "C" fn helper_select_with_default( /// Helper function to check equality between two values. /// /// This function is called from JIT-compiled code to perform equality comparisons. -pub extern "C" fn helper_eq(lhs: &mut Value, rhs: NonNull) { +pub extern "C" fn helper_eq(lhs: &mut Value, rhs: NonNull) { + // SAFETY: The `rhs` pointer is guaranteed to be valid and non-null by the JIT compiler, + // which allocates it on the stack. The JIT code ensures that the pointer points to a + // valid `Value` and that its lifetime is managed correctly within the compiled function. lhs.eq(unsafe { rhs.read() }); } @@ -115,11 +119,14 @@ pub extern "C" fn helper_eq(lhs: &mut Value, rhs: NonNull( +pub unsafe extern "C" fn helper_create_string( ptr: *const u8, len: usize, ret: &mut MaybeUninit, ) { + // SAFETY: The `ptr` and `len` are provided by the JIT compiler and are guaranteed + // to form a valid UTF-8 string slice. The string data is embedded in the compiled + // code and has a static lifetime, ensuring the pointer is always valid. unsafe { ret.write(Value::String( str::from_utf8_unchecked(slice::from_raw_parts(ptr, len)).to_owned(), @@ -131,11 +138,14 @@ pub unsafe extern "C" fn helper_create_string( /// /// This function is called from JIT-compiled code to create list values /// from arrays of values. -pub unsafe extern "C" fn helper_create_list( +pub unsafe extern "C" fn helper_create_list( ptr: *mut Value, len: usize, ret: &mut MaybeUninit, ) { + // SAFETY: The `ptr` is allocated by the JIT compiler using `helper_alloc_array` and + // is guaranteed to be valid for `len` elements. The `Vec::from_raw_parts` call + // correctly takes ownership of the allocated memory, ensuring it is properly managed. unsafe { ret.write(Value::List( List::from(Vec::from_raw_parts(ptr, len, len)).into(), @@ -146,7 +156,7 @@ pub unsafe extern "C" fn helper_create_list( /// Helper function to create an attribute set. /// /// This function is called from JIT-compiled code to create a new, empty attribute set. -pub unsafe extern "C" fn helper_create_attrs( +pub unsafe extern "C" fn helper_create_attrs( ret: &mut MaybeUninit>, ) { ret.write(HashMap::new()); @@ -156,15 +166,17 @@ pub unsafe extern "C" fn helper_create_attrs( /// /// This function is called from JIT-compiled code to insert a key-value pair /// into an attribute set. -pub unsafe extern "C" fn helper_push_attr( - attrs: &mut HashMap, - sym_ptr: *const u8, - sym_len: usize, +pub unsafe extern "C" fn helper_push_attr( + attrs: &mut HashMap, + sym: SymId, val: NonNull, ) { + // SAFETY: The `sym_ptr` and `sym_len` are provided by the JIT compiler and are + // guaranteed to form a valid UTF-8 string slice. The `val` pointer is also + // guaranteed to be valid and non-null by the JIT compiler. unsafe { attrs.insert( - str::from_utf8_unchecked(slice::from_raw_parts(sym_ptr, sym_len)).to_owned(), + sym, val.read(), ); } @@ -174,10 +186,13 @@ pub unsafe extern "C" fn helper_push_attr( /// /// This function is called from JIT-compiled code to convert a HashMap into /// a proper attribute set value. -pub unsafe extern "C" fn helper_finalize_attrs( +pub unsafe extern "C" fn helper_finalize_attrs( attrs: NonNull>, ret: &mut MaybeUninit, ) { + // SAFETY: The `attrs` pointer is guaranteed to be valid and non-null by the JIT + // compiler, which allocates it on the stack. The `read` operation correctly + // takes ownership of the HashMap. ret.write(Value::AttrSet( AttrSet::from(unsafe { attrs.read() }).into(), )); @@ -191,6 +206,8 @@ pub unsafe extern "C" fn helper_enter_with( ctx: &mut Ctx, namespace: NonNull, ) { + // SAFETY: The `namespace` pointer is guaranteed to be valid and non-null by the JIT + // compiler. The `read` operation correctly takes ownership of the `Value`. ctx.enter_with(unsafe { namespace.read() }.unwrap_attr_set().into_inner()); } @@ -205,13 +222,16 @@ pub unsafe extern "C" fn helper_exit_with(ctx: &mut Ctx) { /// /// This function is called from JIT-compiled code to allocate memory for /// arrays of values, such as function arguments or list elements. -pub unsafe extern "C" fn helper_alloc_array(len: usize) -> *mut u8 { +pub unsafe extern "C" fn helper_alloc_array(len: usize) -> *mut u8 { + // SAFETY: The `Layout` is guaranteed to be valid for non-zero `len`. The caller + // is responsible for deallocating the memory, which is typically done by + // `Vec::from_raw_parts` or `Box::from_raw` in other helpers. unsafe { alloc(Layout::array::(len).unwrap()) } } /// Helper function for debugging. /// /// This function is called from JIT-compiled code to print a value for debugging purposes. -pub extern "C" fn helper_dbg(value: &Value) { +pub extern "C" fn helper_dbg(value: &Value) { println!("{value:?}") } diff --git a/evaluator/nixjit_jit/src/lib.rs b/evaluator/nixjit_jit/src/lib.rs index 7d452c8..8e50cb0 100644 --- a/evaluator/nixjit_jit/src/lib.rs +++ b/evaluator/nixjit_jit/src/lib.rs @@ -21,6 +21,7 @@ use cranelift_module::{FuncId, Linkage, Module}; use hashbrown::{HashMap, HashSet}; use nixjit_eval::{EvalContext, Value}; +use nixjit_ir::SymId; use nixjit_lir::Lir; mod compile; @@ -174,25 +175,20 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { slot } - fn push_attr(&mut self, attrs: StackSlot, sym: &str, val: StackSlot) { + fn push_attr(&mut self, attrs: StackSlot, sym: SymId, val: StackSlot) { self.free_slot(attrs); self.free_slot(val); let attrs = self.builder.ins().stack_addr(types::I64, attrs, 0); let val = self.builder.ins().stack_addr(types::I64, val, 0); - let sym = self.strings.get_or_insert_with(sym, |_| sym.to_owned()); - let ptr = self + let sym = self .builder .ins() - .iconst(self.compiler.ptr_type, sym.as_ptr() as i64); - let len = self - .builder - .ins() - .iconst(self.compiler.ptr_type, sym.len() as i64); + .iconst(self.compiler.ptr_type, unsafe { sym.raw() } as i64); let push_attr = self .compiler .module .declare_func_in_func(self.compiler.push_attr, self.builder.func); - self.builder.ins().call(push_attr, &[attrs, ptr, len, val]); + self.builder.ins().call(push_attr, &[attrs, sym, val]); } fn finalize_attrs(&mut self, attrs: StackSlot) -> StackSlot { @@ -281,24 +277,6 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { slot } - fn lookup_stack(&mut self, ctx: ir::Value, idx: usize) -> StackSlot { - let slot = self.alloca(); - let lookup_stack = self - .compiler - .module - .declare_func_in_func(self.compiler.lookup_stack, self.builder.func); - let idx = self - .builder - .ins() - .iconst(self.compiler.ptr_type, idx as i64); - let ptr = self - .builder - .ins() - .stack_addr(self.compiler.ptr_type, slot, 0); - self.builder.ins().call(lookup_stack, &[ctx, idx, ptr]); - slot - } - fn lookup_arg(&mut self, ctx: ir::Value, idx: usize) -> StackSlot { let slot = self.alloca(); let lookup_arg = self @@ -406,7 +384,6 @@ pub struct JITCompiler { func_sig: Signature, call: FuncId, - lookup_stack: FuncId, lookup_arg: FuncId, lookup: FuncId, select: FuncId, @@ -445,7 +422,6 @@ impl JITCompiler { let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names()); builder.symbol("helper_call", helper_call:: as _); - builder.symbol("helper_lookup_stack", helper_lookup_stack:: as _); builder.symbol("helper_lookup_arg", helper_lookup_arg:: as _); builder.symbol("helper_lookup", helper_lookup:: as _); builder.symbol("helper_select", helper_select:: as _); @@ -453,17 +429,17 @@ impl JITCompiler { "helper_select_with_default", helper_select_with_default:: as _, ); - builder.symbol("helper_eq", helper_eq:: as _); + builder.symbol("helper_eq", helper_eq as _); - builder.symbol("helper_alloc_array", helper_alloc_array:: as _); - builder.symbol("helper_create_string", helper_create_string:: as _); - builder.symbol("helper_create_list", helper_create_list:: as _); - builder.symbol("helper_create_attrs", helper_create_attrs:: as _); - builder.symbol("helper_push_attr", helper_push_attr:: as _); - builder.symbol("helper_finalize_attrs", helper_finalize_attrs:: as _); + builder.symbol("helper_alloc_array", helper_alloc_array as _); + builder.symbol("helper_create_string", helper_create_string as _); + builder.symbol("helper_create_list", helper_create_list as _); + builder.symbol("helper_create_attrs", helper_create_attrs as _); + builder.symbol("helper_push_attr", helper_push_attr as _); + builder.symbol("helper_finalize_attrs", helper_finalize_attrs as _); builder.symbol("helper_enter_with", helper_enter_with:: as _); builder.symbol("helper_exit_with", helper_exit_with:: as _); - builder.symbol("helper_dbg", helper_dbg:: as _); + builder.symbol("helper_dbg", helper_dbg as _); let mut module = JITModule::new(builder); let ctx = module.make_context(); @@ -495,18 +471,6 @@ impl JITCompiler { .declare_function("helper_call", Linkage::Import, &call_sig) .unwrap(); - let mut lookup_stack_sig = module.make_signature(); - lookup_stack_sig.params.extend( - [AbiParam { - value_type: ptr_type, - purpose: ArgumentPurpose::Normal, - extension: ArgumentExtension::None, - }; 3], - ); - let lookup_stack = module - .declare_function("helper_lookup_stack", Linkage::Import, &lookup_stack_sig) - .unwrap(); - let mut lookup_arg_sig = module.make_signature(); lookup_arg_sig.params.extend( [AbiParam { @@ -626,7 +590,7 @@ impl JITCompiler { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, - }; 4], + }; 3], ); let push_attr = module .declare_function("helper_push_attr", Linkage::Import, &push_attr_sig) @@ -694,7 +658,6 @@ impl JITCompiler { func_sig, call, - lookup_stack, lookup_arg, lookup, select, @@ -760,6 +723,11 @@ impl JITCompiler { self.ctx.clear(); let _ = self.builder_ctx.insert(builder_ctx); + // SAFETY: The `get_finalized_function` method returns a raw pointer to the + // compiled machine code. We transmute it to the correct function pointer type `F`. + // This is safe because the function was compiled with the signature defined in `self.func_sig`, + // which matches the signature of `F`. The lifetime of the compiled code is managed + // by the `JITModule`, ensuring the pointer remains valid. unsafe { JITFunc { func: std::mem::transmute::<*const u8, F>( diff --git a/evaluator/nixjit_lir/src/lib.rs b/evaluator/nixjit_lir/src/lib.rs index e3a3394..8be849c 100644 --- a/evaluator/nixjit_lir/src/lib.rs +++ b/evaluator/nixjit_lir/src/lib.rs @@ -39,10 +39,10 @@ ir! { Str, Var, Path, - Arg, + Arg(()), PrimOp(PrimOpId), - StackRef(StackIdx), ExprRef(ExprId), + StackRef(StackIdx), FuncRef(ExprId), Thunk(ExprId), } @@ -52,7 +52,7 @@ ir! { pub enum LookupResult { Stack(StackIdx), /// The variable was found and resolved to a specific expression. - PrimOp(ExprId), + Expr(ExprId), /// The variable could not be resolved statically, likely due to a `with` expression. /// The lookup must be performed dynamically at evaluation time. Unknown, @@ -69,20 +69,22 @@ pub trait ResolveContext { fn new_func(&mut self, body: ExprId, param: Param); /// Triggers the resolution of a given expression. - fn resolve(&mut self, expr: ExprId) -> Result<()>; + fn resolve(&mut self, expr: ExprId) -> Result; /// Looks up a variable by name in the current scope. - fn lookup(&mut self, name: &str) -> LookupResult; + fn lookup(&mut self, name: SymId) -> LookupResult; - fn lookup_arg(&mut self) -> StackIdx; + fn get_sym(&self, id: SymId) -> &str; + + fn lookup_arg(&mut self) -> ExprId; /// Enters a `with` scope for the duration of a closure. - fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T); + fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> Result<()>) -> Result; /// Enters a `let` scope with a given set of bindings for the duration of a closure. fn with_let_env( &mut self, - bindings: HashMap, + bindings: HashMap, f: impl FnOnce(&mut Self) -> T, ) -> T; @@ -90,7 +92,8 @@ pub trait ResolveContext { fn with_closure_env( &mut self, func: ExprId, - ident: Option, + arg: ExprId, + ident: Option, f: impl FnOnce(&mut Self) -> T, ) -> T; } @@ -123,26 +126,23 @@ impl Resolve for hir::Hir { Var(x) => x.resolve(ctx), Path(x) => x.resolve(ctx), Let(x) => x.resolve(ctx), - Thunk(x) => { - ctx.resolve(x)?; - Ok(Lir::Thunk(x)) - } - Arg(_) => Ok(Lir::StackRef(ctx.lookup_arg())), + Thunk(x) => ctx.resolve(x).map(Lir::Thunk), + Arg(_) => Ok(Lir::ExprRef(ctx.lookup_arg())), } } } /// Resolves an `AttrSet` by resolving all key and value expressions. impl Resolve for AttrSet { - fn resolve(self, ctx: &mut Ctx) -> Result { - for (_, &v) in self.stcs.iter() { - ctx.resolve(v)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + for (_, v) in self.stcs.iter_mut() { + *v = ctx.resolve(*v)?; } - for &(k, _) in self.dyns.iter() { - ctx.resolve(k)?; + for (k, _) in self.dyns.iter_mut() { + *k = ctx.resolve(*k)?; } - for &(_, v) in self.dyns.iter() { - ctx.resolve(v)?; + for (_, v) in self.dyns.iter_mut() { + *v = ctx.resolve(*v)?; } Ok(self.to_lir()) } @@ -150,9 +150,9 @@ impl Resolve for AttrSet { /// Resolves a `List` by resolving each of its items. impl Resolve for List { - fn resolve(self, ctx: &mut Ctx) -> Result { - for &item in self.items.iter() { - ctx.resolve(item)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + for item in self.items.iter_mut() { + *item = ctx.resolve(*item)?; } Ok(self.to_lir()) } @@ -160,11 +160,11 @@ impl Resolve for List { /// Resolves a `HasAttr` expression by resolving the LHS and any dynamic attributes in the path. impl Resolve for HasAttr { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.lhs)?; - for attr in self.rhs.iter() { - if let &Attr::Dynamic(expr) = attr { - ctx.resolve(expr)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.lhs = ctx.resolve(self.lhs)?; + for attr in self.rhs.iter_mut() { + if let &mut Attr::Dynamic(expr) = attr { + *attr = ctx.resolve(expr).map(Attr::Dynamic)? } } Ok(self.to_lir()) @@ -173,17 +173,17 @@ impl Resolve for HasAttr { /// Resolves a `BinOp` by resolving its left and right hand sides. impl Resolve for BinOp { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.lhs)?; - ctx.resolve(self.rhs)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.lhs = ctx.resolve(self.lhs)?; + self.rhs = ctx.resolve(self.rhs)?; Ok(self.to_lir()) } } /// Resolves a `UnOp` by resolving its right hand side. impl Resolve for UnOp { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.rhs)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.rhs = ctx.resolve(self.rhs)?; Ok(self.to_lir()) } } @@ -191,15 +191,15 @@ impl Resolve for UnOp { /// Resolves a `Select` by resolving the expression being selected from, any dynamic /// attributes in the path, and the default value if it exists. impl Resolve for Select { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.expr)?; - for attr in self.attrpath.iter() { - if let &Attr::Dynamic(expr) = attr { - ctx.resolve(expr)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.expr = ctx.resolve(self.expr)?; + for attr in self.attrpath.iter_mut() { + if let &mut Attr::Dynamic(expr) = attr { + *attr = ctx.resolve(expr).map(Attr::Dynamic)? } } - if let Some(expr) = self.default { - ctx.resolve(expr)?; + if let Some(expr) = &mut self.default { + *expr = ctx.resolve(*expr)?; } Ok(self.to_lir()) } @@ -207,10 +207,10 @@ impl Resolve for Select { /// Resolves an `If` expression by resolving the condition, consequence, and alternative. impl Resolve for If { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.cond)?; - ctx.resolve(self.consq)?; - ctx.resolve(self.alter)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.cond = ctx.resolve(self.cond)?; + self.consq = ctx.resolve(self.consq)?; + self.alter = ctx.resolve(self.alter)?; Ok(self.to_lir()) } } @@ -218,9 +218,10 @@ impl Resolve for If { /// Resolves a `Func` by resolving its body within a new parameter scope. /// It then registers the function with the context. impl Resolve for Func { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.with_closure_env(self.body, self.param.ident.clone(), |ctx| { - ctx.resolve(self.body) + fn resolve(mut self, ctx: &mut Ctx) -> Result { + ctx.with_closure_env(self.body, self.arg, self.param.ident, |ctx| { + self.body = ctx.resolve(self.body)?; + Ok(()) })?; ctx.new_func(self.body, self.param); Ok(Lir::FuncRef(self.body)) @@ -228,9 +229,9 @@ impl Resolve for Func { } impl Resolve for Call { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.func)?; - ctx.resolve(self.arg)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.func = ctx.resolve(self.func)?; + self.func = ctx.resolve(self.arg)?; Ok(self.to_lir()) } } @@ -238,10 +239,12 @@ impl Resolve for Call { /// Resolves a `With` expression by resolving the namespace and the body. /// The body is resolved within a special "with" scope. impl Resolve for With { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.namespace)?; - let (env_used, res) = ctx.with_with_env(|ctx| ctx.resolve(self.expr)); - res?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.namespace = ctx.resolve(self.namespace)?; + let env_used = ctx.with_with_env(|ctx| { + self.expr = ctx.resolve(self.expr)?; + Ok(()) + })?; // Optimization: if the `with` environment was not actually used by any variable // lookup in the body, we can elide the `With` node entirely. if env_used { @@ -254,18 +257,18 @@ impl Resolve for With { /// Resolves an `Assert` by resolving the assertion condition and the body. impl Resolve for Assert { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.assertion)?; - ctx.resolve(self.expr)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.assertion = ctx.resolve(self.assertion)?; + self.expr = ctx.resolve(self.expr)?; Ok(self.to_lir()) } } /// Resolves a `ConcatStrings` by resolving each part. impl Resolve for ConcatStrings { - fn resolve(self, ctx: &mut Ctx) -> Result { - for &part in self.parts.iter() { - ctx.resolve(part)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + for part in self.parts.iter_mut() { + *part = ctx.resolve(*part)?; } Ok(self.to_lir()) } @@ -275,13 +278,13 @@ impl Resolve for ConcatStrings { impl Resolve for Var { fn resolve(self, ctx: &mut Ctx) -> Result { use LookupResult::*; - match ctx.lookup(&self.sym) { + match ctx.lookup(self.sym) { + Expr(id) => Ok(Lir::ExprRef(id)), Stack(idx) => Ok(Lir::StackRef(idx)), - PrimOp(id) => Ok(Lir::ExprRef(id)), Unknown => Ok(self.to_lir()), NotFound => Err(Error::resolution_error(format!( "undefined variable '{}'", - format_symbol(&self.sym) + format_symbol(ctx.get_sym(self.sym)) ))), } } @@ -289,8 +292,8 @@ impl Resolve for Var { /// Resolves a `Path` by resolving the underlying expression that defines the path's content. impl Resolve for Path { - fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.resolve(self.expr)?; + fn resolve(mut self, ctx: &mut Ctx) -> Result { + self.expr = ctx.resolve(self.expr)?; Ok(self.to_lir()) } } @@ -298,12 +301,13 @@ impl Resolve for Path { /// Resolves a `Let` expression by creating a new scope for the bindings, resolving /// the bindings and the body, and then returning a reference to the body. impl Resolve for hir::Let { - fn resolve(self, ctx: &mut Ctx) -> Result { + fn resolve(mut self, ctx: &mut Ctx) -> Result { ctx.with_let_env(self.bindings.clone(), |ctx| { - for &id in self.bindings.values() { - ctx.resolve(id)?; + for id in self.bindings.values_mut() { + *id = ctx.resolve(*id)?; } - ctx.resolve(self.body) + self.body = ctx.resolve(self.body)?; + Ok(()) })?; // The `let` expression itself evaluates to its body. Ok(Lir::ExprRef(self.body)) diff --git a/evaluator/nixjit_macros/src/builtins.rs b/evaluator/nixjit_macros/src/builtins.rs index 2ea18c9..edb262c 100644 --- a/evaluator/nixjit_macros/src/builtins.rs +++ b/evaluator/nixjit_macros/src/builtins.rs @@ -18,7 +18,8 @@ use proc_macro::TokenStream; use proc_macro2::Span; use quote::{ToTokens, format_ident, quote}; use syn::{ - parse_macro_input, FnArg, Item, ItemConst, ItemFn, ItemMod, Pat, PatIdent, PatType, Type, Visibility + FnArg, Item, ItemConst, ItemFn, ItemMod, Pat, PatIdent, PatType, Type, Visibility, + parse_macro_input, }; /// The implementation of the `#[builtins]` macro. @@ -69,14 +70,13 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { // Public functions are added to the global scope, private ones to a scoped set. if matches!(item_fn.vis, Visibility::Public(_)) { global.push(primop); - pub_item_mod.push(quote! { #item_fn }.into()); + pub_item_mod.push(quote! { #item_fn }); } else { scoped.push(primop); pub_item_mod.push( quote! { pub #item_fn } - .into(), ); } wrappers.push(wrapper); @@ -125,7 +125,7 @@ fn generate_const_wrapper( item_const: &ItemConst, ) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> { let const_name = &item_const.ident; - let const_val = &item_const.expr; + let const_val = &item_const.expr; let name_str = const_name .to_string() .from_case(Case::UpperSnake)