diff --git a/Cargo.lock b/Cargo.lock index 5107cd6..7d348d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" + [[package]] name = "bumpalo" version = "3.19.0" @@ -56,6 +62,21 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + +[[package]] +name = "clipboard-win" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" +dependencies = [ + "error-code", +] + [[package]] name = "convert_case" version = "0.7.1" @@ -279,18 +300,51 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "error-code" +version = "3.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" + [[package]] name = "fallible-iterator" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" +[[package]] +name = "fd-lock" +version = "4.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" +dependencies = [ + "cfg-if", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "fixedbitset" version = "0.5.7" @@ -331,6 +385,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "home" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "indexmap" version = "2.10.0" @@ -362,6 +425,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "log" version = "0.4.27" @@ -392,12 +461,37 @@ dependencies = [ "autocfg", ] +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + +[[package]] +name = "nix" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nixjit" version = "0.1.0" dependencies = [ + "anyhow", + "bumpalo", "nixjit_context", "nixjit_value", + "regex", + "rustyline", ] [[package]] @@ -414,6 +508,7 @@ dependencies = [ name = "nixjit_context" version = "0.1.0" dependencies = [ + "bumpalo", "cranelift", "cranelift-jit", "cranelift-module", @@ -438,6 +533,7 @@ dependencies = [ name = "nixjit_error" version = "0.1.0" dependencies = [ + "rnix", "thiserror", ] @@ -453,6 +549,7 @@ dependencies = [ "nixjit_lir", "nixjit_value", "replace_with", + "smallvec", ] [[package]] @@ -461,7 +558,6 @@ version = "0.1.0" dependencies = [ "derive_more", "hashbrown 0.15.4", - "itertools", "nixjit_error", "nixjit_ir", "nixjit_macros", @@ -475,7 +571,6 @@ version = "0.1.0" dependencies = [ "derive_more", "hashbrown 0.15.4", - "nixjit_error", "nixjit_value", "rnix", ] @@ -560,6 +655,16 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + [[package]] name = "regalloc2" version = "0.12.2" @@ -609,7 +714,7 @@ version = "3.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6b6ebd13bc009aef9cd476c1310d49ac354d36e240cf1bd753290f3dc7199a7" dependencies = [ - "bitflags", + "bitflags 1.3.2", "libc", "mach2", "windows-sys 0.52.0", @@ -655,6 +760,41 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" +dependencies = [ + "bitflags 2.9.1", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustyline" +version = "14.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7803e8936da37efd9b6d4478277f4b2b9bb5cdb37a113e8d63222e58da647e63" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "clipboard-win", + "fd-lock", + "home", + "libc", + "log", + "memchr", + "nix", + "radix_trie", + "unicode-segmentation", + "unicode-width", + "utf8parse", + "windows-sys 0.52.0", +] + [[package]] name = "serde" version = "1.0.219" @@ -742,12 +882,24 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unicode-xid" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "wasmtime-internal-jit-icache-coherence" version = "35.0.0" diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..9cc44a3 --- /dev/null +++ b/TODO.md @@ -0,0 +1,5 @@ +- [x] downgrade stage +- [ ] resolve stage + - [ ] dynamic attr resolution + - [ ] import resolution + - [ ] path resolution diff --git a/evaluator/nixjit/Cargo.toml b/evaluator/nixjit/Cargo.toml index 7b49985..3d546c3 100644 --- a/evaluator/nixjit/Cargo.toml +++ b/evaluator/nixjit/Cargo.toml @@ -4,5 +4,10 @@ version = "0.1.0" edition = "2024" [dependencies] +anyhow = "1.0" +bumpalo = "3.19" +regex = "1.11" +rustyline = "14.0" + nixjit_context = { path = "../nixjit_context" } nixjit_value = { path = "../nixjit_value" } diff --git a/evaluator/nixjit/src/lib.rs b/evaluator/nixjit/src/lib.rs index df87da3..a0fce7a 100644 --- a/evaluator/nixjit/src/lib.rs +++ b/evaluator/nixjit/src/lib.rs @@ -4,5 +4,8 @@ //! and evaluating Nix expressions. It integrates all the other `nixjit_*` //! components to provide a complete Nix evaluation environment. +pub use nixjit_context as context; +pub use nixjit_value as value; + #[cfg(test)] mod test; diff --git a/evaluator/nixjit/src/main.rs b/evaluator/nixjit/src/main.rs new file mode 100644 index 0000000..ce225fb --- /dev/null +++ b/evaluator/nixjit/src/main.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use bumpalo::Bump; +use regex::Regex; +use rustyline::DefaultEditor; +use rustyline::error::ReadlineError; + +use nixjit::context::Context; + +fn main() -> Result<()> { + let mut rl = DefaultEditor::new()?; + let bump = Bump::new(); + let mut context = Context::new(&bump); + let re = Regex::new(r"^\s*([a-zA-Z_][a-zA-Z0-9_'-]*)\s*=(.*)$").unwrap(); + loop { + let readline = rl.readline("nixjit-repl> "); + match readline { + Ok(line) => { + if line.trim().is_empty() { + continue; + } + let _ = rl.add_history_entry(line.as_str()); + if let Some(caps) = re.captures(&line) { + let ident = caps.get(1).unwrap().as_str(); + let expr = caps.get(2).unwrap().as_str().trim(); + if expr.is_empty() { + eprintln!("Error: missing expression after '='"); + continue; + } + if let Err(err) = context.add_binding(ident, expr) { + eprintln!("Error: {}", err); + } + } else { + match context.eval(&line) { + Ok(value) => println!("{}", value), + Err(err) => eprintln!("Error: {}", err), + } + } + } + Err(ReadlineError::Interrupted) => { + println!(); + } + Err(ReadlineError::Eof) => { + println!("CTRL-D"); + break; + } + Err(err) => { + eprintln!("Error: {:?}", err); + break; + } + } + } + Ok(()) +} diff --git a/evaluator/nixjit/src/test.rs b/evaluator/nixjit/src/test.rs index ff8df6e..35751ae 100644 --- a/evaluator/nixjit/src/test.rs +++ b/evaluator/nixjit/src/test.rs @@ -2,13 +2,14 @@ use std::collections::BTreeMap; +use bumpalo::Bump; use nixjit_context::Context; use nixjit_value::{AttrSet, Const, List, Symbol, Value}; #[inline] fn test_expr(expr: &str, expected: Value) { println!("{expr}"); - assert_eq!(Context::new().eval(expr).unwrap(), expected); + assert_eq!(Context::new(&Bump::new()).eval(expr).unwrap(), expected); } macro_rules! map { @@ -197,6 +198,10 @@ fn test_func() { "(inputs@{ x, y, ... }: x + inputs.y) { x = 1; y = 2; z = 3; }", int!(3), ); + test_expr( + "((f: let x = f x; in x) (self: { x = 1; y = self.x + 1; })).y", + int!(2), + ); test_expr( "let fix = f: let x = f x; in x; in (fix (self: { x = 1; y = self.x + 1; })).y", int!(2), diff --git a/evaluator/nixjit_builtins/src/lib.rs b/evaluator/nixjit_builtins/src/lib.rs index e382728..2539233 100644 --- a/evaluator/nixjit_builtins/src/lib.rs +++ b/evaluator/nixjit_builtins/src/lib.rs @@ -4,8 +4,10 @@ pub trait BuiltinsContext {} #[builtins] pub mod builtins { + use std::rc::Rc; + use nixjit_error::{Error, Result}; - use nixjit_eval::Value; + use nixjit_eval::{List, Value}; use nixjit_value::Const; use super::BuiltinsContext; @@ -21,7 +23,24 @@ pub mod builtins { (Int(a), Float(b)) => Float(a as f64 + b), (Float(a), Int(b)) => Float(a + b as f64), (Float(a), Float(b)) => Float(a + b), - _ => return Err(Error::EvalError(format!(""))), + (Int(_), b) => { + return Err(Error::eval_error(format!( + "expected an integer but found {}", + b.typename() + ))); + } + (Float(_), b) => { + return Err(Error::eval_error(format!( + "expected an float but found {}", + b.typename() + ))); + } + (a, _) => { + return Err(Error::eval_error(format!( + "expected an integer but found {}", + a.typename() + ))); + } }) } @@ -29,16 +48,10 @@ pub mod builtins { todo!() } - fn elem_at(list: Value, idx: Value) -> Result { - let list = list - .try_unwrap_list() - .map_err(|_| Error::EvalError("expected a list but found ...".to_string()))?; - let idx = idx - .try_unwrap_int() - .map_err(|_| Error::EvalError("expected a int but found ...".to_string()))?; + fn elem_at(list: Rc, idx: i64) -> Result { list.get(idx as usize) .ok_or_else(|| { - Error::EvalError(format!( + Error::eval_error(format!( "'builtins.elemAt' called with index {idx} on a list of size {}", list.len() )) @@ -46,7 +59,7 @@ pub mod builtins { .cloned() } - fn elem(elem: Value, list: Value) -> Result { + fn elem(elem: Value, list: Rc) -> Result { todo!() } } diff --git a/evaluator/nixjit_context/Cargo.toml b/evaluator/nixjit_context/Cargo.toml index 033dcea..acf1d41 100644 --- a/evaluator/nixjit_context/Cargo.toml +++ b/evaluator/nixjit_context/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +bumpalo = { version = "3.19", features = ["boxed"] } derive_more = { version = "2.0", features = ["full"] } hashbrown = "0.15" itertools = "0.14" diff --git a/evaluator/nixjit_context/src/downgrade.rs b/evaluator/nixjit_context/src/downgrade.rs new file mode 100644 index 0000000..3661546 --- /dev/null +++ b/evaluator/nixjit_context/src/downgrade.rs @@ -0,0 +1,61 @@ +use std::cell::RefCell; + +use nixjit_error::Result; +use nixjit_hir::{Downgrade, DowngradeContext, Hir}; +use nixjit_ir::ExprId; + +use super::Context; + +pub struct DowngradeCtx<'ctx, 'bump> { + ctx: &'ctx mut Context<'bump>, + irs: Vec>, +} + +impl<'ctx, 'bump> DowngradeCtx<'ctx, 'bump> { + pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { + Self { + ctx, + irs: Vec::new(), + } + } +} + +impl DowngradeCtx<'_, '_> { + fn get_ir(&self, id: ExprId) -> &RefCell { + let idx = unsafe { id.raw() } - self.ctx.lirs.len() - self.ctx.hirs.len(); + if cfg!(debug_assertions) { + self.irs.get(idx).unwrap() + } else { + unsafe { self.irs.get_unchecked(idx) } + } + } +} + +impl DowngradeContext for DowngradeCtx<'_, '_> { + fn new_expr(&mut self, expr: Hir) -> ExprId { + self.irs.push(expr.into()); + unsafe { ExprId::from_raw(self.ctx.lirs.len() + self.ctx.hirs.len() + self.irs.len() - 1) } + } + + fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T { + unsafe { + let self_mut = &mut *(self as *mut Self); + f(&mut self.get_ir(id).borrow_mut(), self_mut) + } + } + + fn downgrade_root(mut self, root: rnix::ast::Expr) -> Result { + let id = root.downgrade(&mut self)?; + self.ctx + .hirs + .extend(self.irs.into_iter().map(RefCell::into_inner)); + for (idx, ir) in self.ctx.hirs.iter().enumerate() { + println!( + "{:?} {:#?}", + unsafe { ExprId::from_raw(idx + self.ctx.lirs.len()) }, + &ir + ); + } + Ok(id) + } +} diff --git a/evaluator/nixjit_context/src/eval.rs b/evaluator/nixjit_context/src/eval.rs new file mode 100644 index 0000000..ed5e8ef --- /dev/null +++ b/evaluator/nixjit_context/src/eval.rs @@ -0,0 +1,144 @@ +use std::rc::Rc; + +use hashbrown::HashMap; +use itertools::Itertools; + +use nixjit_error::Result; +use nixjit_eval::{Args, EvalContext, Evaluate, StackFrame, Value}; +use nixjit_ir::ExprId; +use nixjit_jit::JITContext; +use nixjit_lir::Lir; + +use super::Context; + +pub struct EvalCtx<'ctx, 'bump> { + ctx: &'ctx mut Context<'bump>, + stack: Vec, + with_scopes: Vec>>, +} + +impl<'ctx, 'bump> EvalCtx<'ctx, 'bump> { + pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { + Self { + ctx, + stack: Vec::new(), + with_scopes: Vec::new(), + } + } + + fn eval_deps(&mut self, expr: ExprId, arg: Option) -> Result<()> { + let deps = self + .ctx + .graph + .edges(expr) + .sorted_by_key(|(.., idx)| **idx) + .map(|(_, dep, idx)| (dep, *idx)) + .collect_vec(); + let mut frame = (0..deps.len()) + .map(|_| Value::Blackhole) + .collect::(); + dbg!(&deps, &self.stack); + for (dep, idx) in deps { + unsafe { + if matches!( + &**self.ctx.lirs.get_unchecked(dep.raw()), + Lir::Arg(_) + ) { + *frame.get_unchecked_mut(idx.raw()) = arg.as_ref().unwrap().clone(); + continue; + } + } + let dep = self.eval(dep)?; + unsafe { + *frame.get_unchecked_mut(idx.raw()) = dep; + } + } + *self.stack.last_mut().unwrap() = frame; + dbg!(&self.stack); + Ok(()) + } +} + +impl EvalContext for EvalCtx<'_, '_> { + fn eval_root(mut self, expr: ExprId) -> Result { + self.stack.push(StackFrame::new()); + self.eval_deps(expr, None)?; + self.eval(expr) + } + + fn eval(&mut self, expr: ExprId) -> Result { + let idx = unsafe { expr.raw() }; + let lir = unsafe { &*(&**self.ctx.lirs.get_unchecked(idx) as *const Lir) }; + lir.eval(self) + } + + fn call(&mut self, func: ExprId, arg: Option, frame: StackFrame) -> Result { + self.stack.push(frame); + if let Err(err) = self.eval_deps(func, arg) { + self.stack.pop(); + return Err(err) + } + let ret = self.eval(func); + self.stack.pop(); + ret + } + + fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a Value> { + for scope in self.with_scopes.iter().rev() { + if let Some(val) = scope.get(ident) { + return Some(val); + } + } + None + } + + fn lookup_stack(&self, idx: nixjit_ir::StackIdx) -> &Value { + if cfg!(debug_assertions) { + self.stack + .last() + .unwrap() + .get(unsafe { idx.raw() }) + .unwrap() + } else { + unsafe { + self.stack + .last() + .unwrap_unchecked() + .get_unchecked(idx.raw()) + } + } + } + + fn capture_stack(&self) -> &StackFrame { + self.stack.last().unwrap() + } + + fn call_primop(&mut self, id: nixjit_ir::PrimOpId, args: Args) -> Result { + unsafe { (self.ctx.primops.get_unchecked(id.raw()).1)(self.ctx, args) } + } + + fn get_primop_arity(&self, id: nixjit_ir::PrimOpId) -> usize { + unsafe { self.ctx.primops.get_unchecked(id.raw()).0 } + } + + fn with_with_env( + &mut self, + namespace: Rc>, + f: impl FnOnce(&mut Self) -> T, + ) -> T { + self.with_scopes.push(namespace); + let res = f(self); + self.with_scopes.pop(); + res + } +} + +impl JITContext for EvalCtx<'_, '_> { + fn enter_with(&mut self, namespace: Rc>) { + self.with_scopes.push(namespace); + } + + fn exit_with(&mut self) { + self.with_scopes.pop(); + } +} diff --git a/evaluator/nixjit_context/src/lib.rs b/evaluator/nixjit_context/src/lib.rs index 5674c52..a9bb79e 100644 --- a/evaluator/nixjit_context/src/lib.rs +++ b/evaluator/nixjit_context/src/lib.rs @@ -1,89 +1,49 @@ -//! The central evaluation context for the nixjit interpreter. -//! -//! This module defines the `Context` struct, which holds all the state -//! necessary for the evaluation of a Nix expression. It manages the -//! Intermediate Representations (IRs), scopes, evaluation stack, and -//! the Just-In-Time (JIT) compiler. -//! -//! The `Context` implements various traits (`DowngradeContext`, `ResolveContext`, etc.) -//! to provide the necessary services for each stage of the compilation and -//! evaluation pipeline. -use std::cell::{OnceCell, RefCell}; -use std::rc::Rc; +use std::{marker::PhantomPinned, ops::Deref}; -use derive_more::Unwrap; -use hashbrown::{HashMap, HashSet}; +use bumpalo::{Bump, boxed::Box}; +use hashbrown::HashMap; use itertools::Itertools; -use petgraph::graph::{DiGraph, NodeIndex}; +use petgraph::{ + dot::{Config, Dot}, + graphmap::DiGraphMap, +}; use nixjit_builtins::{ Builtins, BuiltinsContext, builtins::{CONSTS_LEN, GLOBAL_LEN, SCOPED_LEN}, }; use nixjit_error::{Error, Result}; -use nixjit_eval::{EvalContext, Evaluate, Value}; -use nixjit_hir::{Downgrade, DowngradeContext, Hir}; -use nixjit_ir::{ArgIdx, Const, ExprId, Param, PrimOp, PrimOpId}; -use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext}; +use nixjit_eval::{Args, EvalContext, Value}; +use nixjit_hir::{DowngradeContext, Hir}; +use nixjit_ir::{AttrSet, Const, ExprId, Param, PrimOpId, StackIdx}; +use nixjit_lir::{Lir, ResolveContext}; -use nixjit_jit::{JITCompiler, JITContext, JITFunc}; -use replace_with::replace_with_and_return; +use crate::downgrade::DowngradeCtx; +use crate::eval::EvalCtx; +use crate::resolve::ResolveCtx; -/// Represents a lexical scope during name resolution. -enum Scope { - /// A `with` expression scope. - With, - /// A `let` binding scope, mapping variable names to their expression IDs. - Let(HashMap), - /// A function argument scope. `Some` holds the name of the argument set if present. - Arg(Option), +mod downgrade; +mod eval; +mod resolve; + +#[derive(Debug)] +struct Pin<'bump, T> { + ptr: Box<'bump, T>, + _marker: PhantomPinned, } -/// Represents an expression at different stages of compilation. -#[derive(Debug, Unwrap)] -enum Ir { - /// An expression in the High-Level Intermediate Representation (HIR). - Hir(Hir), - /// An expression in the Low-Level Intermediate Representation (LIR). - Lir(Lir), +impl Deref for Pin<'_, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + self.ptr.as_ref() + } } -impl Ir { - unsafe fn unwrap_hir_ref_unchecked(&self) -> &Hir { - if let Self::Hir(hir) = self { - hir - } else { - unsafe { core::hint::unreachable_unchecked() } - } - } - - unsafe fn unwrap_hir_mut_unchecked(&mut self) -> &mut Hir { - #[cfg(debug_assertions)] - if let Self::Hir(hir) = self { - hir - } else { - unsafe { core::hint::unreachable_unchecked() } - } - #[cfg(not(debug_assertions))] - if let Self::Hir(hir) = self { - hir - } else { - unsafe { core::hint::unreachable_unchecked() } - } - } - - unsafe fn unwrap_lir_ref_unchecked(&self) -> &Lir { - #[cfg(debug_assertions)] - if let Self::Lir(lir) = self { - lir - } else { - unsafe { core::hint::unreachable_unchecked() } - } - #[cfg(not(debug_assertions))] - if let Self::Lir(lir) = self { - lir - } else { - panic!() +impl<'bump, T> Pin<'bump, T> { + fn new_in(x: T, bump: &'bump Bump) -> Self { + Self { + ptr: Box::new_in(x, bump), + _marker: PhantomPinned, } } } @@ -92,103 +52,111 @@ impl Ir { /// /// This struct orchestrates the entire Nix expression evaluation process, /// from parsing and semantic analysis to interpretation and JIT compilation. -pub struct Context { - /// Arena for all expressions, which can be either HIR or LIR. - /// `RefCell` is used for interior mutability to allow on-demand resolution. - irs: Vec>, - /// Tracks whether an `ExprId` has been resolved from HIR to LIR. - resolved: Vec, - /// The stack of lexical scopes used for name resolution. - scopes: Vec, - /// The number of arguments in the current function call scope. - args_count: usize, - /// A table of primitive operation implementations. - primops: Vec) -> Result>, +pub struct Context<'bump> { + hirs: Vec, + lirs: Vec>, /// Maps a function's body `ExprId` to its parameter definition. funcs: HashMap, + + repl_scope: HashMap, + global_scope: HashMap<&'static str, ExprId>, + /// A dependency graph between expressions. - graph: DiGraph, - /// Maps an `ExprId` to its corresponding `NodeIndex` in the dependency graph. - nodes: Vec, + graph: DiGraphMap, - /// The call stack for function evaluation, where each frame holds arguments. - stack: Vec>, - /// A stack of namespaces for `with` expressions during evaluation. - with_scopes: Vec>>, + /// A table of primitive operation implementations. + primops: [(usize, fn(&mut Self, Args) -> Result); GLOBAL_LEN + SCOPED_LEN], - /// The Just-In-Time (JIT) compiler. - jit: JITCompiler, - /// A cache for JIT-compiled functions, indexed by `ExprId`. - compiled: Vec>>, + bump: &'bump Bump, } -impl Default for Context { - fn default() -> Self { +impl<'bump> Context<'bump> { + pub fn new(bump: &'bump Bump) -> Self { let Builtins { consts, global, scoped, } = Builtins::new(); - let global_scope = Scope::Let( - consts - .iter() - .enumerate() - .map(|(id, (k, _))| (k.to_string(), unsafe { ExprId::from(id) })) - .chain(global.iter().enumerate().map(|(idx, (k, _, _))| { - (k.to_string(), unsafe { ExprId::from(idx + CONSTS_LEN) }) - })) - .chain(core::iter::once(("builtins".to_string(), unsafe { - ExprId::from(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN) - }))) - .collect(), - ); - let primops = global + let global_scope = consts .iter() - .map(|&(_, _, f)| f) - .chain(scoped.iter().map(|&(_, _, f)| f)) - .collect(); - let irs = consts - .into_iter() - .map(|(_, val)| Ir::Lir(Lir::Const(Const { val }))) + .enumerate() + .map(|(id, (k, _))| (*k, unsafe { ExprId::from_raw(id) })) .chain( global + .iter() + .enumerate() + .map(|(idx, (k, _, _))| (*k, unsafe { ExprId::from_raw(idx + CONSTS_LEN) })), + ) + .chain(core::iter::once(("builtins", unsafe { + ExprId::from_raw(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN) + }))) + .collect(); + let primops = global + .iter() + .map(|&(_, arity, f)| (arity, f)) + .chain(scoped.iter().map(|&(_, arity, f)| (arity, f))) + .collect_array() + .unwrap(); + let lirs = consts + .into_iter() + .map(|(_, val)| Lir::Const(Const { val })) + .chain((0..global.len()).map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx) }))) + .chain( + (0..scoped.len()) + .map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx + GLOBAL_LEN) })), + ) + .chain(core::iter::once(Lir::AttrSet(AttrSet { + stcs: consts .into_iter() .enumerate() - .map(|(idx, (name, arity, _))| { - Ir::Lir(Lir::PrimOp(PrimOp { - name, - arity, - id: unsafe { PrimOpId::from(idx) }, - })) - }), - ) - .map(RefCell::new) + .map(|(idx, (name, _))| (name.to_string(), unsafe { ExprId::from_raw(idx) })) + .chain(global.into_iter().enumerate().map(|(idx, (name, ..))| { + (name.to_string(), unsafe { + ExprId::from_raw(idx + CONSTS_LEN) + }) + })) + .chain(scoped.into_iter().enumerate().map(|(idx, (name, ..))| { + (name.to_string(), unsafe { + ExprId::from_raw(idx + CONSTS_LEN + GLOBAL_LEN) + }) + })) + .chain(core::iter::once(("builtins".to_string(), unsafe { + ExprId::from_raw(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN + 1) + }))) + .collect(), + ..AttrSet::default() + }))) + .chain(core::iter::once(Lir::Thunk(unsafe { + ExprId::from_raw(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN) + }))) + .map(|lir| Pin::new_in(lir, bump)) .collect(); Self { - irs, - resolved: Vec::new(), - scopes: vec![global_scope], - args_count: 0, - primops, + hirs: Vec::new(), + lirs, funcs: HashMap::new(), - graph: DiGraph::new(), - nodes: Vec::new(), - stack: Vec::new(), - with_scopes: Vec::new(), + global_scope, + repl_scope: HashMap::new(), + graph: DiGraphMap::new(), - jit: JITCompiler::new(), - compiled: Vec::new(), + primops, + + bump, } } -} -impl Context { - /// Creates a new, default `Context`. - pub fn new() -> Self { - Self::default() + pub fn downgrade_ctx<'a>(&'a mut self) -> DowngradeCtx<'a, 'bump> { + DowngradeCtx::new(self) } + pub fn resolve_ctx<'a>(&'a mut self) -> ResolveCtx<'a, 'bump> { + ResolveCtx::new(self) + } + + pub fn eval_ctx<'a>(&'a mut self) -> EvalCtx<'a, 'bump> { + EvalCtx::new(self) + } /// The main entry point for evaluating a Nix expression string. /// /// This function performs the following steps: @@ -196,227 +164,40 @@ impl Context { /// 2. Downgrades the AST to the High-Level IR (HIR). /// 3. Resolves the HIR to the Low-Level IR (LIR). /// 4. Evaluates the LIR to produce a final `Value`. - pub fn eval(mut self, expr: &str) -> Result { + pub fn eval(&mut self, expr: &str) -> Result { let root = rnix::Root::parse(expr); if !root.errors().is_empty() { - return Err(Error::ParseError( - root.errors().iter().map(|err| err.to_string()).join(";"), + return Err(Error::parse_error( + root.errors().iter().map(|err| err.to_string()).join("; "), )); } - let root = root.tree().expr().unwrap().downgrade(&mut self)?; - self.resolve(root)?; - Ok(EvalContext::eval(&mut self, root)?.to_public()) - } -} - -impl DowngradeContext for Context { - fn new_expr(&mut self, expr: Hir) -> ExprId { - let id = unsafe { ExprId::from(self.irs.len()) }; - self.irs.push(Ir::Hir(expr).into()); - self.nodes.push(self.graph.add_node(id)); - self.resolved.push(false); - self.compiled.push(OnceCell::new()); - id - } - fn with_expr(&self, id: ExprId, f: impl FnOnce(&Hir, &Self) -> T) -> T { - unsafe { - let idx = id.raw(); - f(&self.irs[idx].borrow().unwrap_hir_ref_unchecked(), self) - } - } - fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T { - unsafe { - let idx = id.raw(); - let self_mut = &mut *(self as *mut Self); - f( - &mut self - .irs - .get_unchecked_mut(idx) - .borrow_mut() - .unwrap_hir_mut_unchecked(), - self_mut, - ) - } - } -} - -impl ResolveContext for Context { - fn lookup(&self, name: &str) -> LookupResult { - let mut arg_idx = 0; - let mut has_with = false; - for scope in self.scopes.iter().rev() { - match scope { - Scope::Let(scope) => { - if let Some(&expr) = scope.get(name) { - return LookupResult::Expr(expr); - } - } - Scope::Arg(ident) => { - if ident.as_deref() == Some(name) { - return LookupResult::Arg(unsafe { ArgIdx::from(arg_idx) }); - } - arg_idx += 1; - } - Scope::With => has_with = true, - } - } - if has_with { - LookupResult::Unknown - } else { - LookupResult::NotFound + let root = self + .downgrade_ctx() + .downgrade_root(root.tree().expr().unwrap())?; + self.resolve_ctx().resolve_root(root)?; + println!( + "{:?}", + Dot::with_config(&self.graph, &[Config::EdgeNoLabel]) + ); + for (idx, ir) in self.lirs.iter().enumerate() { + println!("{:?} {:#?}", unsafe { ExprId::from_raw(idx) }, &ir); } + Ok(self.eval_ctx().eval_root(root)?.to_public()) } - fn new_dep(&mut self, expr: ExprId, dep: ExprId) { - unsafe { - let expr = expr.raw(); - let dep = dep.raw(); - let expr = *self.nodes.get_unchecked(expr); - let dep = *self.nodes.get_unchecked(dep); - self.graph.add_edge(expr, dep, ()); - } - } - - fn resolve(&mut self, expr: ExprId) -> Result<()> { - unsafe { - let idx = expr.raw(); - let self_mut = &mut *(self as *mut Self); - replace_with_and_return( - &mut *self.irs.get_unchecked(idx).borrow_mut(), - || { - Ir::Hir(Hir::Const(Const { - val: nixjit_value::Const::Null, - })) - }, - |ir| { - let Ir::Hir(hir) = ir else { - return (Ok(()), ir); - }; - match hir.resolve(self_mut) { - Ok(lir) => (Ok(()), Ir::Lir(lir)), - Err(err) => ( - Err(err), - Ir::Hir(Hir::Const(Const { - val: nixjit_value::Const::Null, - })), - ), - } - }, - )?; + pub fn add_binding(&mut self, ident: &str, expr: &str) -> Result<()> { + let root = rnix::Root::parse(expr); + if !root.errors().is_empty() { + return Err(Error::parse_error( + root.errors().iter().map(|err| err.to_string()).join("; "), + )); } + let root_expr = root.tree().expr().unwrap(); + let expr_id = self.downgrade_ctx().downgrade_root(root_expr)?; + self.resolve_ctx().resolve_root(expr_id)?; + self.repl_scope.insert(ident.to_string(), expr_id); Ok(()) } - - fn new_func(&mut self, body: ExprId, param: Param) { - self.funcs.insert(body, param); - } - - fn with_let_env<'a, T>( - &mut self, - bindings: impl Iterator, - f: impl FnOnce(&mut Self) -> T, - ) -> T { - let mut scope = HashMap::new(); - for (name, expr) in bindings { - scope.insert(name.clone(), *expr); - } - self.scopes.push(Scope::Let(scope)); - let res = f(self); - self.scopes.pop(); - res - } - - fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) { - self.scopes.push(Scope::With); - let res = f(self); - self.scopes.pop(); - (true, res) - } - - fn with_param_env(&mut self, ident: Option, f: impl FnOnce(&mut Self) -> T) -> T { - self.scopes.push(Scope::Arg(ident)); - self.args_count += 1; - let res = f(self); - self.args_count -= 1; - self.scopes.pop(); - res - } } -impl EvalContext for Context { - fn eval(&mut self, expr: ExprId) -> Result { - let idx = unsafe { expr.raw() }; - let lir = unsafe { - &*(self - .irs - .get_unchecked(idx) - .borrow() - .unwrap_lir_ref_unchecked() as *const Lir) - }; - println!("{:#?}", self.irs); - lir.eval(self) - } - - fn pop_frame(&mut self) -> Vec { - self.stack.pop().unwrap() - } - - fn lookup_stack<'a>(&'a self, offoset: usize) -> &'a Value { - todo!() - } - - fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a nixjit_eval::Value> { - for scope in self.with_scopes.iter().rev() { - if let Some(val) = scope.get(ident) { - return Some(val); - } - } - None - } - - fn lookup_arg<'a>(&'a self, idx: ArgIdx) -> &'a Value { - unsafe { - let values = self.stack.last().unwrap_unchecked(); - dbg!(values, idx); - &values[values.len() - idx.raw() - 1] - } - } - - fn with_with_env( - &mut self, - namespace: std::rc::Rc>, - f: impl FnOnce(&mut Self) -> T, - ) -> T { - self.with_scopes.push(namespace); - let res = f(self); - self.with_scopes.pop(); - res - } - - fn with_args_env( - &mut self, - args: Vec, - f: impl FnOnce(&mut Self) -> T, - ) -> (Vec, T) { - self.stack.push(args); - let res = f(self); - let frame = self.stack.pop().unwrap(); - (frame, res) - } - - fn call_primop(&mut self, id: nixjit_ir::PrimOpId, args: Vec) -> Result { - unsafe { (self.primops.get_unchecked(id.raw()))(self, args) } - } -} - -impl JITContext for Context { - fn enter_with(&mut self, namespace: std::rc::Rc>) { - self.with_scopes.push(namespace); - } - - fn exit_with(&mut self) { - self.with_scopes.pop(); - } -} - -impl BuiltinsContext for Context {} +impl BuiltinsContext for Context<'_> {} diff --git a/evaluator/nixjit_context/src/resolve.rs b/evaluator/nixjit_context/src/resolve.rs new file mode 100644 index 0000000..3e107a2 --- /dev/null +++ b/evaluator/nixjit_context/src/resolve.rs @@ -0,0 +1,263 @@ +use std::cell::RefCell; +use std::pin::Pin; + +use bumpalo::boxed::Box; +use derive_more::Unwrap; +use hashbrown::HashMap; + +use nixjit_error::Result; +use nixjit_hir::Hir; +use nixjit_ir::{Const, ExprId, Param, StackIdx}; +use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext}; +use replace_with::replace_with_and_return; + +use super::Context; + +#[derive(Clone)] +enum Scope<'ctx> { + /// A `let` binding scope, mapping variable names to their expression IDs. + Let(HashMap), + /// A function argument scope. `Some` holds the name of the argument set if present. + Arg(Option), + Builtins(&'ctx HashMap<&'static str, ExprId>), + Repl(&'ctx HashMap) +} + +/// Represents an expression at different stages of compilation. +#[derive(Debug, Unwrap)] +enum Ir { + /// An expression in the High-Level Intermediate Representation (HIR). + Hir(Hir), + /// An expression in the Low-Level Intermediate Representation (LIR). + Lir(Lir), +} + +impl Ir { + unsafe fn unwrap_hir_unchecked(self) -> Hir { + if let Self::Hir(hir) = self { + hir + } else { + unsafe { core::hint::unreachable_unchecked() } + } + } +} + +pub struct ResolveCtx<'ctx, 'bump> { + ctx: &'ctx mut Context<'bump>, + irs: Vec>>>, + scopes: Vec>, + has_with: bool, + with_used: bool, + closures: Vec<(ExprId, Option, usize)>, + current_expr: Option, +} + +impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { + pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { + let ctx_mut = unsafe { &mut *(ctx as *mut Context) }; + Self { + scopes: vec![ + Scope::Builtins(&ctx.global_scope), + Scope::Repl(&ctx.repl_scope) + ], + has_with: false, + with_used: false, + irs: core::mem::take(&mut ctx.hirs) + .into_iter() + .map(|hir| Ir::Hir(hir).into()) + .map(|ir| Box::new_in(ir, ctx.bump)) + .map(Pin::new) + .collect(), + ctx: ctx_mut, + closures: Vec::new(), + current_expr: None, + } + } + + fn get_ir(&self, id: ExprId) -> &RefCell { + let idx = unsafe { id.raw() } - self.ctx.lirs.len(); + if cfg!(debug_assertions) { + self.irs.get(idx).unwrap() + } else { + unsafe { self.irs.get_unchecked(idx) } + } + } + + fn get_ir_mut(&mut self, id: ExprId) -> &mut RefCell { + let idx = unsafe { id.raw() } - self.ctx.lirs.len(); + if cfg!(debug_assertions) { + self.irs.get_mut(idx).unwrap() + } else { + unsafe { self.irs.get_unchecked_mut(idx) } + } + } + + fn add_dep(&mut self, from: ExprId, to: ExprId, count: &mut usize) -> StackIdx { + if let Some(&idx) = self.ctx.graph.edge_weight(from, to) { + idx + } else { + *count += 1; + let idx = unsafe { StackIdx::from_raw(*count - 1) }; + assert_ne!(from, to); + self.ctx.graph.add_edge(from, to, idx); + idx + } + } + + fn new_lir(&mut self, lir: Lir) -> ExprId { + self.irs.push(Pin::new(Box::new_in( + RefCell::new(Ir::Lir(lir)), + self.ctx.bump, + ))); + unsafe { ExprId::from_raw(self.ctx.lirs.len() + self.irs.len() - 1) } + } +} + +impl ResolveContext for ResolveCtx<'_, '_> { + fn resolve(&mut self, expr: ExprId) -> Result<()> { + let prev_expr = self.current_expr.replace(expr); + let result = unsafe { + let ctx = &mut *(self as *mut Self); + let ir = &mut self.get_ir_mut(expr); + if !matches!(ir.try_borrow().as_deref(), Ok(Ir::Hir(_))) { + return Ok(()); + } + replace_with_and_return( + &mut *ir.borrow_mut(), + || { + Ir::Hir(Hir::Const(Const { + val: nixjit_value::Const::Null, + })) + }, + |ir| match ir.unwrap_hir_unchecked().resolve(ctx) { + Ok(lir) => (Ok(()), Ir::Lir(lir)), + Err(err) => ( + Err(err), + Ir::Hir(Hir::Const(Const { + val: nixjit_value::Const::Null, + })), + ), + }, + ) + }; + self.current_expr = prev_expr; + result + } + + fn resolve_root(mut self, expr: ExprId) -> Result<()> { + self.closures.push((expr, None, 0)); + let ret = self.resolve(expr); + if ret.is_ok() { + self.ctx.lirs.extend( + self.irs + .into_iter() + .map(|pin| unsafe { core::mem::transmute::, Box<_>>(pin) }) + .map(Box::into_inner) + .map(RefCell::into_inner) + .map(Ir::unwrap_lir) + .map(|lir| crate::Pin::new_in(lir, self.ctx.bump)) + ); + } + ret + } + + fn lookup(&mut self, name: &str) -> LookupResult { + let mut closure_depth = 0; + // Then search from outer to inner scopes for dependencies + for scope in self.scopes.iter().rev() { + match scope { + Scope::Builtins(scope) => { + if let Some(&expr) = scope.get(&name) { + return LookupResult::Expr(expr); + } + } + Scope::Let(scope) | &Scope::Repl(scope) => { + if let Some(&dep) = scope.get(name) { + let (expr, _, deps) = unsafe { &mut *(self as *mut Self) } + .closures + .last_mut() + .unwrap(); + let idx = self.add_dep(*expr, dep, deps); + return LookupResult::Stack(idx); + } + } + Scope::Arg(ident) => { + if ident.as_deref() == Some(name) { + // This is an outer function's parameter, treat as dependency + // We need to find the corresponding parameter expression to create dependency + // For now, we need to handle this case by creating a dependency to the parameter + let mut iter = unsafe { &mut *(self as *mut Self) } + .closures + .iter_mut() + .rev() + .take(closure_depth + 1) + .rev(); + let Some((func, Some(arg), count)) = iter.next() else { + unreachable!() + }; + let mut cur = self.add_dep(*func, *arg, count); + for (func, _, count) in iter { + let idx = self.new_lir(Lir::StackRef(cur)); + cur = self.add_dep(*func, idx, count); + } + return LookupResult::Stack(cur); + } + closure_depth += 1; + } + } + } + if self.has_with { + self.with_used = true; + LookupResult::Unknown + } else { + LookupResult::NotFound + } + } + + fn lookup_arg(&mut self) -> StackIdx { + let Some((func, Some(arg), count)) = unsafe { &mut *(self as *mut Self) }.closures.last_mut() else { + unreachable!() + }; + self.add_dep(*func, *arg, count) + } + + fn new_func(&mut self, body: ExprId, param: Param) { + self.ctx.funcs.insert(body, param); + } + + fn with_let_env( + &mut self, + bindings: HashMap, + f: impl FnOnce(&mut Self) -> T, + ) -> T { + self.scopes.push(Scope::Let(bindings)); + let res = f(self); + self.scopes.pop(); + res + } + + fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) { + let has_with = self.has_with; + let with_used = self.with_used; + self.has_with = true; + self.with_used = false; + let res = f(self); + self.has_with = has_with; + (core::mem::replace(&mut self.with_used, with_used), res) + } + + fn with_param_env( + &mut self, + func: ExprId, + ident: Option, + f: impl FnOnce(&mut Self) -> T, + ) -> T { + let arg = self.new_lir(Lir::Arg(nixjit_ir::Arg)); + self.closures.push((func, Some(arg), 0)); + self.scopes.push(Scope::Arg(ident)); + let res = f(self); + self.scopes.pop(); + self.closures.pop(); + res + } +} diff --git a/evaluator/nixjit_error/Cargo.toml b/evaluator/nixjit_error/Cargo.toml index 09a586e..2e13152 100644 --- a/evaluator/nixjit_error/Cargo.toml +++ b/evaluator/nixjit_error/Cargo.toml @@ -5,3 +5,4 @@ edition = "2024" [dependencies] thiserror = "2.0" +rnix = "0.12" diff --git a/evaluator/nixjit_error/src/lib.rs b/evaluator/nixjit_error/src/lib.rs index 8f18e24..2cd1c36 100644 --- a/evaluator/nixjit_error/src/lib.rs +++ b/evaluator/nixjit_error/src/lib.rs @@ -3,6 +3,7 @@ //! handling here, we ensure a consistent approach to reporting failures across //! different stages of processing, from parsing to final evaluation. +use std::rc::Rc; use thiserror::Error; /// A specialized `Result` type used for all fallible operations within the @@ -12,7 +13,7 @@ pub type Result = core::result::Result; /// The primary error enum, encompassing all potential failures that can occur /// during the lifecycle of a Nix expression's evaluation. #[derive(Error, Debug)] -pub enum Error { +pub enum ErrorKind { /// An error occurred during the initial parsing phase. This typically /// indicates a syntax error in the Nix source code, as detected by the /// `rnix` parser. @@ -47,3 +48,112 @@ pub enum Error { #[error("an unknown or unexpected error occurred")] Unknown, } + +#[derive(Debug)] +pub struct Error { + pub kind: ErrorKind, + pub span: Option, + pub source: Option>, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Basic display + write!(f, "{}", self.kind)?; + + // If we have source and span, print context + if let (Some(source), Some(span)) = (&self.source, self.span) { + let start_byte = usize::from(span.start()); + let end_byte = usize::from(span.end()); + + if start_byte > source.len() || end_byte > source.len() { + return Ok(()); // Span is out of bounds + } + + let mut start_line = 1; + let mut start_col = 1usize; + let mut line_start_byte = 0; + for (i, c) in source.char_indices() { + if i >= start_byte { + break; + } + if c == '\n' { + start_line += 1; + start_col = 1; + line_start_byte = i + 1; + } else { + start_col += 1; + } + } + + let line_end_byte = source[line_start_byte..] + .find('\n') + .map(|i| line_start_byte + i) + .unwrap_or(source.len()); + + let line_str = &source[line_start_byte..line_end_byte]; + + let underline_len = if end_byte > start_byte { + end_byte - start_byte + } else { + 1 + }; + + write!(f, "\n --> {}:{}", start_line, start_col)?; + write!(f, "\n |\n")?; + write!(f, "{:4} | {}\n", start_line, line_str)?; + write!( + f, + " | {}{}", + " ".repeat(start_col.saturating_sub(1)), + "^".repeat(underline_len) + )?; + } + Ok(()) + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&self.kind) + } +} + +impl Error { + pub fn new(kind: ErrorKind) -> Self { + Self { + kind, + span: None, + source: None, + } + } + + pub fn with_span(mut self, span: rnix::TextRange) -> Self { + self.span = Some(span); + self + } + + pub fn with_source(mut self, source: Rc) -> Self { + self.source = Some(source); + self + } + + pub fn parse_error(msg: String) -> Self { + Self::new(ErrorKind::ParseError(msg)) + } + pub fn downgrade_error(msg: String) -> Self { + Self::new(ErrorKind::DowngradeError(msg)) + } + pub fn resolution_error(msg: String) -> Self { + Self::new(ErrorKind::ResolutionError(msg)) + } + pub fn eval_error(msg: String) -> Self { + Self::new(ErrorKind::EvalError(msg)) + } + pub fn catchable(msg: String) -> Self { + Self::new(ErrorKind::Catchable(msg)) + } + pub fn unknown() -> Self { + Self::new(ErrorKind::Unknown) + } +} diff --git a/evaluator/nixjit_eval/Cargo.toml b/evaluator/nixjit_eval/Cargo.toml index a3e7e74..3c30365 100644 --- a/evaluator/nixjit_eval/Cargo.toml +++ b/evaluator/nixjit_eval/Cargo.toml @@ -8,6 +8,7 @@ derive_more = { version = "2.0", features = ["full"] } hashbrown = "0.15" itertools = "0.14" replace_with = "0.1" +smallvec = { version = "1.15", features = ["union"] } nixjit_error = { path = "../nixjit_error" } nixjit_ir = { path = "../nixjit_ir" } diff --git a/evaluator/nixjit_eval/src/lib.rs b/evaluator/nixjit_eval/src/lib.rs index c138944..cc58082 100644 --- a/evaluator/nixjit_eval/src/lib.rs +++ b/evaluator/nixjit_eval/src/lib.rs @@ -12,7 +12,7 @@ use std::rc::Rc; use hashbrown::HashMap; use nixjit_error::{Error, Result}; -use nixjit_ir::{self as ir, ArgIdx, ExprId, PrimOpId}; +use nixjit_ir::{self as ir, ExprId, PrimOpId, StackIdx}; use nixjit_lir as lir; use nixjit_value::{Const, format_symbol}; @@ -21,10 +21,14 @@ pub use crate::value::*; mod value; /// A trait defining the context in which LIR expressions are evaluated. -pub trait EvalContext: Sized { +pub trait EvalContext { + fn eval_root(self, expr: ExprId) -> Result; + + /// Evaluates an expression by its ID. fn eval(&mut self, expr: ExprId) -> Result; + fn call(&mut self, func: ExprId, arg: Option, frame: StackFrame) -> Result; /// Enters a `with` scope for the duration of a closure's execution. fn with_with_env( &mut self, @@ -32,27 +36,18 @@ pub trait EvalContext: Sized { f: impl FnOnce(&mut Self) -> T, ) -> T; - /// Pushes a new set of arguments onto the stack for a function call. - fn with_args_env( - &mut self, - args: Vec, - f: impl FnOnce(&mut Self) -> T, - ) -> (Vec, T); - /// Looks up a stack slot on the current stack frame. - fn lookup_stack<'a>(&'a self, idx: usize) -> &'a Value; + fn lookup_stack(&self, idx: StackIdx) -> &Value; + + fn capture_stack(&self) -> &StackFrame; /// Looks up an identifier in the current `with` scope chain. fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a Value>; - /// Looks up a function argument by its index on the current stack frame. - fn lookup_arg<'a>(&'a self, idx: ArgIdx) -> &'a Value; - - /// Pops the current stack frame, returning the arguments. - fn pop_frame(&mut self) -> Vec; - /// Calls a primitive operation (builtin) by its ID. - fn call_primop(&mut self, id: PrimOpId, args: Vec) -> Result; + fn call_primop(&mut self, id: PrimOpId, args: Args) -> Result; + + fn get_primop_arity(&self, id: PrimOpId) -> usize; } /// A trait for types that can be evaluated within an `EvalContext`. @@ -88,10 +83,12 @@ impl Evaluate for lir::Lir { Str(x) => x.eval(ctx), Var(x) => x.eval(ctx), Path(x) => x.eval(ctx), + &StackRef(idx) => Ok(ctx.lookup_stack(idx).clone()), &ExprRef(expr) => ctx.eval(expr), - &FuncRef(func) => Ok(Value::Func(func)), - &ArgRef(idx) => Ok(ctx.lookup_arg(idx).clone()), + &FuncRef(body) => Ok(Value::Closure(Closure::new(body, ctx.capture_stack().clone()).into())), + &Arg(_) => unreachable!(), &PrimOp(primop) => Ok(Value::PrimOp(primop)), + &Thunk(id) => Ok(Value::Thunk(id)), } } } @@ -109,10 +106,8 @@ impl Evaluate for ir::AttrSet { .collect::>()?, ); for (k, v) in self.dyns.iter() { - let mut k = k.eval(ctx)?; - k.coerce_to_string()?; - let v_eval_result = v.eval(ctx)?; - attrs.push_attr(k.unwrap_string(), v_eval_result)?; + let v = v.eval(ctx)?; + attrs.push_attr(k.eval(ctx)?.force_string_no_ctx()?, v)?; } let result = Value::AttrSet(attrs.into()); Ok(result) @@ -137,11 +132,9 @@ impl Evaluate for ir::HasAttr { fn eval(&self, ctx: &mut Ctx) -> Result { use ir::Attr::*; let mut val = self.lhs.eval(ctx)?; - val.has_attr(self.rhs.iter().map(|attr| { - match attr { - Str(ident) => Ok(Value::String(ident.clone())), - Dynamic(expr) => expr.eval(ctx) - } + val.has_attr(self.rhs.iter().map(|attr| match attr { + Str(ident) => Ok(Value::String(ident.clone())), + Dynamic(expr) => expr.eval(ctx), }))?; Ok(val) } @@ -155,7 +148,7 @@ impl Evaluate for ir::BinOp { if matches!((&self.kind, &lhs), (And, Value::Bool(false))) { return Ok(Value::Bool(false)); } else if matches!((&self.kind, &lhs), (Or, Value::Bool(true))) { - return Ok(Value::Bool(true)) + return Ok(Value::Bool(true)); } let mut rhs = self.rhs.eval(ctx)?; match self.kind { @@ -193,9 +186,9 @@ impl Evaluate for ir::BinOp { } Con => lhs.concat(rhs)?, Upd => lhs.update(rhs)?, - PipeL => lhs.call(core::iter::once(Ok(rhs)), ctx)?, + PipeL => lhs.call(rhs, ctx)?, PipeR => { - rhs.call(core::iter::once(Ok(lhs)), ctx)?; + rhs.call(lhs, ctx)?; lhs = rhs; } } @@ -226,32 +219,20 @@ impl Evaluate for ir::Select { fn eval(&self, ctx: &mut Ctx) -> Result { use ir::Attr::*; let mut val = self.expr.eval(ctx)?; - if let Some(default) = &self.default { - let default = default.eval(ctx)?; - val.select_with_default( - self.attrpath.iter().map(|attr| { - Ok(match attr { - Str(ident) => ident.clone(), - Dynamic(expr) => { - let mut val = expr.eval(ctx)?; - val.coerce_to_string()?; - val.unwrap_string() - } - }) - }), - default, - )?; - } else { - val.select(self.attrpath.iter().map(|attr| { - Ok(match attr { - Str(ident) => ident.clone(), - Dynamic(expr) => { - let mut val = expr.eval(ctx)?; - val.coerce_to_string()?; - val.unwrap_string() - } - }) - }))?; + for attr in self.attrpath.iter() { + let name_val; + let name = match attr { + Str(name) => name, + Dynamic(expr) => { + name_val = expr.eval(ctx)?; + &*name_val.force_string_no_ctx()? + } + }; + if let Some(default) = self.default { + val.select_or(name, default, ctx) + } else { + val.select(name, ctx) + }? } Ok(val) } @@ -260,9 +241,9 @@ impl Evaluate for ir::Select { impl Evaluate for ir::If { /// Evaluates an `If` by evaluating the condition and then either the consequence or the alternative. fn eval(&self, ctx: &mut Ctx) -> Result { - let cond = self.cond.eval(ctx)?; - let cond = cond.as_ref().try_unwrap_bool().map_err(|_| { - Error::EvalError(format!( + let cond = &self.cond.eval(ctx)?; + let &cond = cond.try_into().map_err(|_| { + Error::eval_error(format!( "if-condition must be a boolean, but got {}", cond.typename() )) @@ -277,11 +258,9 @@ impl Evaluate for ir::If { } impl Evaluate for ir::Call { - /// Evaluates a `Call` by evaluating the function and its arguments, then performing the call. fn eval(&self, ctx: &mut Ctx) -> Result { let mut func = self.func.eval(ctx)?; - let ctx_mut = unsafe { &mut *(ctx as *mut Ctx) }; - func.call(self.args.iter().map(|arg| arg.eval(ctx)), ctx_mut)?; + func.call(self.arg.eval(ctx)?, ctx)?; Ok(func) } } @@ -296,7 +275,7 @@ impl Evaluate for ir::With { namespace .try_unwrap_attr_set() .map_err(|_| { - Error::EvalError(format!("'with' expects a set, but got {}", typename)) + Error::eval_error(format!("'with' expects a set, but got {}", typename)) })? .into_inner(), |ctx| self.expr.eval(ctx), @@ -308,9 +287,9 @@ impl Evaluate for ir::Assert { /// Evaluates an `Assert` by evaluating the condition. If true, it evaluates and /// returns the body; otherwise, it returns an error. fn eval(&self, ctx: &mut Ctx) -> Result { - let cond = self.assertion.eval(ctx)?; - let cond = cond.as_ref().try_unwrap_bool().map_err(|_| { - Error::EvalError(format!( + let cond = &self.assertion.eval(ctx)?; + let &cond = cond.try_into().map_err(|_| { + Error::eval_error(format!( "assertion condition must be a boolean, but got {}", cond.typename() )) @@ -318,7 +297,7 @@ impl Evaluate for ir::Assert { if cond { self.expr.eval(ctx) } else { - Err(Error::Catchable("assertion failed".into())) + Err(Error::catchable("assertion failed".into())) } } } @@ -329,7 +308,7 @@ impl Evaluate for ir::ConcatStrings { fn eval(&self, ctx: &mut Ctx) -> Result { let mut buf = String::new(); for part in self.parts.iter() { - buf.push_str(part.eval(ctx)?.coerce_to_string()?.as_ref().unwrap_string()); + buf.push_str(&part.eval(ctx)?.force_string_no_ctx()?); } Ok(Value::String(buf)) } @@ -361,7 +340,7 @@ impl Evaluate for ir::Var { fn eval(&self, ctx: &mut Ctx) -> Result { ctx.lookup_with(&self.sym) .ok_or_else(|| { - Error::EvalError(format!("undefined variable '{}'", format_symbol(&self.sym))) + Error::eval_error(format!("undefined variable '{}'", format_symbol(&self.sym))) }) .map(|val| val.clone()) } diff --git a/evaluator/nixjit_eval/src/value/attrset.rs b/evaluator/nixjit_eval/src/value/attrset.rs index 980fd0c..9deef64 100644 --- a/evaluator/nixjit_eval/src/value/attrset.rs +++ b/evaluator/nixjit_eval/src/value/attrset.rs @@ -5,14 +5,16 @@ use std::fmt::Debug; use std::rc::Rc; use derive_more::Constructor; -use hashbrown::hash_map::Entry; use hashbrown::HashMap; +use hashbrown::hash_map::Entry; use itertools::Itertools; use nixjit_error::{Error, Result}; -use nixjit_value::Symbol; +use nixjit_ir::ExprId; use nixjit_value::{self as p, format_symbol}; +use crate::EvalContext; + use super::Value; /// A wrapper around a `HashMap` representing a Nix attribute set. @@ -20,7 +22,7 @@ use super::Value; /// It uses `#[repr(transparent)]` to ensure it has the same memory layout /// as `HashMap`. #[repr(transparent)] -#[derive(Clone, Constructor, PartialEq)] +#[derive(Clone, Constructor)] pub struct AttrSet { data: HashMap, } @@ -31,9 +33,9 @@ impl Debug for AttrSet { write!(f, "{{ ")?; for (k, v) in self.data.iter() { match v { - List(_) => write!(f, "{k:?} = [ ... ]; ")?, - AttrSet(_) => write!(f, "{k:?} = {{ ... }}; ")?, - v => write!(f, "{k:?} = {v:?}; ")?, + List(_) => write!(f, "{} = [ ... ]; ", format_symbol(k))?, + AttrSet(_) => write!(f, "{} = {{ ... }}; ", format_symbol(k))?, + v => write!(f, "{} = {v:?}; ", format_symbol(k))?, } } write!(f, "}}") @@ -69,7 +71,7 @@ impl AttrSet { /// Inserts an attribute, returns an error if the attribute is already defined. pub fn push_attr(&mut self, sym: String, val: Value) -> Result<()> { match self.data.entry(sym) { - Entry::Occupied(occupied) => Err(Error::EvalError(format!( + Entry::Occupied(occupied) => Err(Error::eval_error(format!( "attribute '{}' already defined", format_symbol(occupied.key()) ))), @@ -80,30 +82,32 @@ impl AttrSet { } } - /// Performs a deep selection of an attribute from a nested set. - /// - /// It traverses the attribute path and returns the final value, or an error - /// if any intermediate attribute does not exist or is not a set. - pub fn select( + pub fn select(&self, name: &str, ctx: &mut impl EvalContext) -> Result { + self.data + .get(name) + .cloned() + .map(|attr| match attr { + Value::Thunk(id) => ctx.eval(id), + val => Ok(val), + }) + .ok_or_else(|| { + Error::eval_error(format!("attribute '{}' not found", format_symbol(name))) + })? + } + + pub fn select_or( &self, - mut path: impl DoubleEndedIterator>, + name: &str, + default: ExprId, + ctx: &mut impl EvalContext, ) -> Result { - let mut data = &self.data; - let last = path.nth_back(0).unwrap(); - for item in path { - let item = item?; - let Some(Value::AttrSet(attrs)) = data.get(&item) else { - return Err(Error::EvalError(format!( - "attribute '{}' not found", - format_symbol(item) - ))); - }; - data = attrs.as_inner(); - } - let last = last?; - data.get(&last).cloned().ok_or_else(|| { - Error::EvalError(format!("attribute '{}' not found", Symbol::from(last))) - }) + self.data + .get(name) + .map(|attr| match attr { + &Value::Thunk(id) => ctx.eval(id), + val => Ok(val.clone()), + }) + .unwrap_or_else(|| ctx.eval(default)) } /// Checks if an attribute path exists within the set. @@ -114,16 +118,14 @@ impl AttrSet { let mut data = &self.data; let last = path.nth_back(0).unwrap(); for item in path { - let Some(Value::AttrSet(attrs)) = - data.get(item.unwrap().coerce_to_string()?.as_ref().unwrap_string()) + let Some(Value::AttrSet(attrs)) = data.get(&item.unwrap().force_string_no_ctx()?) else { return Ok(Value::Bool(false)); }; data = attrs.as_inner(); } Ok(Value::Bool( - data.get(last.unwrap().coerce_to_string()?.as_ref().unwrap_string()) - .is_some(), + data.get(&last.unwrap().force_string_no_ctx()?).is_some(), )) } diff --git a/evaluator/nixjit_eval/src/value/closure.rs b/evaluator/nixjit_eval/src/value/closure.rs new file mode 100644 index 0000000..bafbdd7 --- /dev/null +++ b/evaluator/nixjit_eval/src/value/closure.rs @@ -0,0 +1,25 @@ +//! Defines the runtime representation of a partially applied function. +use std::rc::Rc; + +use derive_more::Constructor; + +use nixjit_error::Result; +use nixjit_ir::ExprId; + +use super::Value; +use crate::EvalContext; + +pub type StackFrame = smallvec::SmallVec<[Value; 5]>; + +#[derive(Debug, Clone, Constructor)] +pub struct Closure { + pub body: ExprId, + pub frame: StackFrame, +} + +impl Closure { + pub fn call(self: Rc, arg: Option, ctx: &mut Ctx) -> Result { + let Self { body: func, frame } = Rc::unwrap_or_clone(self); + ctx.call(func, arg, frame) + } +} diff --git a/evaluator/nixjit_eval/src/value/func.rs b/evaluator/nixjit_eval/src/value/func.rs deleted file mode 100644 index 3b2e051..0000000 --- a/evaluator/nixjit_eval/src/value/func.rs +++ /dev/null @@ -1,69 +0,0 @@ -//! Defines the runtime representation of a partially applied function. -use std::rc::Rc; - -use derive_more::Constructor; - -use nixjit_error::Result; -use nixjit_ir::ExprId; - -use super::Value; -use crate::EvalContext; - -/// Represents a partially applied user-defined function. -/// -/// This struct captures the state of a function that has received some, but not -/// all, of its expected arguments. -#[derive(Debug, Clone, Constructor)] -pub struct FuncApp { - /// The expression ID of the function body to be executed. - pub body: ExprId, - /// The arguments that have already been applied to the function. - pub args: Vec, - /// The lexical scope (stack frame) captured at the time of the initial call. - pub frame: Vec, -} - -impl FuncApp { - /// Applies more arguments to a partially applied function. - /// - /// It takes an iterator of new arguments, appends them to the existing ones, - /// and re-evaluates the function body within its captured environment. - pub fn call( - self: &mut Rc, - mut iter: impl Iterator> + ExactSizeIterator, - ctx: &mut Ctx, - ) -> Result { - let FuncApp { - body: expr, - args, - frame, - } = Rc::make_mut(self); - let mut val; - let mut args = core::mem::take(args); - args.push(iter.next().unwrap()?); - let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(*expr)); - args = ret_args; - val = ret?; - loop { - if !matches!(val, Value::Func(_) | Value::FuncApp(_)) { - break; - } - let Some(arg) = iter.next() else { - break; - }; - args.push(arg?); - if let Value::Func(expr) = val { - let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(expr)); - args = ret_args; - val = ret?; - } else if let Value::FuncApp(func) = val { - let mut func = Rc::unwrap_or_clone(func); - func.args.push(args.pop().unwrap()); - let (ret_args, ret) = ctx.with_args_env(func.args, |ctx| ctx.eval(func.body)); - args = ret_args; - val = ret?; - } - } - Ok(val) - } -} diff --git a/evaluator/nixjit_eval/src/value/list.rs b/evaluator/nixjit_eval/src/value/list.rs index 88e7316..f69aa16 100644 --- a/evaluator/nixjit_eval/src/value/list.rs +++ b/evaluator/nixjit_eval/src/value/list.rs @@ -3,11 +3,12 @@ use std::fmt::Debug; use std::ops::Deref; -use hashbrown::HashSet; - +use nixjit_error::{Error, Result}; use nixjit_value::List as PubList; use nixjit_value::Value as PubValue; +use crate::EvalContext; + use super::Value; /// A wrapper around a `Vec` representing a Nix list. @@ -65,6 +66,21 @@ impl List { } } + pub fn elem_at(&self, idx: usize, ctx: &mut impl EvalContext) -> Result { + self.data + .get(idx) + .map(|elem| match elem { + &Value::Thunk(id) => ctx.eval(id), + val => Ok(val.clone()), + }) + .ok_or_else(|| { + Error::eval_error(format!( + "'builtins.elemAt' called with index {idx} on a list of size {}", + self.len() + )) + })? + } + /// Consumes the `List` and returns the inner `Vec`. pub fn into_inner(self) -> Vec { self.data diff --git a/evaluator/nixjit_eval/src/value/mod.rs b/evaluator/nixjit_eval/src/value/mod.rs index e382aa3..968f1a3 100644 --- a/evaluator/nixjit_eval/src/value/mod.rs +++ b/evaluator/nixjit_eval/src/value/mod.rs @@ -4,33 +4,30 @@ //! interpreter's runtime. It represents all possible data types that can exist //! during the evaluation of a Nix expression. This is an internal, mutable //! representation, distinct from the public-facing `nixjit_value::Value`. -//! -//! The module also provides `ValueAsRef` for non-owning references and -//! implementations for various operations like arithmetic, comparison, and -//! function calls. use std::fmt::Debug; -use std::hash::Hash; use std::rc::Rc; -use derive_more::TryUnwrap; -use derive_more::{IsVariant, Unwrap}; -use nixjit_ir::{ExprId, PrimOp}; +use derive_more::{IsVariant, TryInto, TryUnwrap, Unwrap}; +use nixjit_ir::ExprId; +use nixjit_ir::PrimOpId; use nixjit_error::{Error, Result}; use nixjit_value::Const; use nixjit_value::Value as PubValue; +use replace_with::replace_with_and_return; +use smallvec::smallvec; use crate::EvalContext; mod attrset; -mod func; +mod closure; mod list; mod primop; mod string; -pub use attrset::*; -pub use func::*; +pub use attrset::AttrSet; +pub use closure::*; pub use list::List; pub use primop::*; @@ -40,20 +37,22 @@ pub use primop::*; /// JIT-compiled code. It uses `#[repr(C, u64)]` to ensure a predictable layout, /// with the discriminant serving as a type tag. #[repr(C, u64)] -#[derive(IsVariant, Clone, TryUnwrap, Unwrap)] +#[derive(IsVariant, Clone, Unwrap, TryUnwrap, TryInto)] +#[try_into(owned, ref, ref_mut)] pub enum Value { - Int(i64), - Float(f64), - Bool(bool), - String(String), - Null, - Thunk(ExprId), - AttrSet(Rc), - List(Rc), - PrimOp(PrimOp), - PrimOpApp(Rc), - Func(ExprId), - FuncApp(Rc), + Int(i64) = Self::INT, + Float(f64) = Self::FLOAT, + Bool(bool) = Self::BOOL, + String(String) = Self::STRING, + Null = Self::NULL, + Thunk(ExprId) = Self::THUNK, + ClosureThunk(Rc) = Self::CLOSURE_THUNK, + AttrSet(Rc) = Self::ATTRSET, + List(Rc) = Self::LIST, + PrimOp(PrimOpId) = Self::PRIMOP, + PrimOpApp(Rc) = Self::PRIMOP_APP, + Closure(Rc) = Self::CLOSURE, + Blackhole, } impl Debug for Value { @@ -64,26 +63,15 @@ impl Debug for Value { Float(x) => write!(f, "{x}"), Bool(x) => write!(f, "{x}"), Null => write!(f, "null"), - String(x) => write!(f, "{x}"), + String(x) => write!(f, "{x:?}"), AttrSet(x) => write!(f, "{x:?}"), List(x) => write!(f, "{x:?}"), Thunk(thunk) => write!(f, ""), - Func(func) => write!(f, ""), - FuncApp(func) => write!(f, "", func.body), - PrimOp(primop) => write!(f, "", primop.name), - PrimOpApp(primop) => write!(f, "", primop.name), - } - } -} - -impl Hash for Value { - fn hash(&self, state: &mut H) { - use Value::*; - std::mem::discriminant(self).hash(state); - match self { - AttrSet(x) => Rc::as_ptr(x).hash(state), - List(x) => x.as_ptr().hash(state), - _ => 0.hash(state), + ClosureThunk(_) => write!(f, ""), + Closure(func) => write!(f, "", func.body), + PrimOp(_) => write!(f, ""), + PrimOpApp(_) => write!(f, ""), + Blackhole => write!(f, ""), } } } @@ -95,13 +83,12 @@ impl Value { pub const STRING: u64 = 3; pub const NULL: u64 = 4; pub const THUNK: u64 = 5; - pub const ATTRSET: u64 = 6; - pub const LIST: u64 = 7; - pub const CATCHABLE: u64 = 8; + pub const CLOSURE_THUNK: u64 = 6; + pub const ATTRSET: u64 = 7; + pub const LIST: u64 = 8; pub const PRIMOP: u64 = 9; - pub const PARTIAL_PRIMOP: u64 = 10; - pub const FUNC: u64 = 11; - pub const PARTIAL_FUNC: u64 = 12; + pub const PRIMOP_APP: u64 = 10; + pub const CLOSURE: u64 = 11; fn eq_impl(&self, other: &Self) -> bool { use Value::*; @@ -120,59 +107,6 @@ impl Value { } } -impl PartialEq for Value { - fn eq(&self, other: &Self) -> bool { - use Value::*; - match (self, other) { - (AttrSet(a), AttrSet(b)) => Rc::as_ptr(a).eq(&Rc::as_ptr(b)), - (List(a), List(b)) => a.as_ptr().eq(&b.as_ptr()), - _ => false, - } - } -} - -impl Eq for Value {} - -/// A non-owning reference to a `Value`. -/// -/// This is used to avoid unnecessary cloning when inspecting values. -#[derive(IsVariant, TryUnwrap, Unwrap, Clone)] -pub enum ValueAsRef<'v> { - Int(i64), - Float(f64), - Bool(bool), - String(&'v String), - Null, - Thunk(&'v ExprId), - AttrSet(&'v AttrSet), - List(&'v List), - PrimOp(&'v PrimOp), - PartialPrimOp(&'v PrimOpApp), - Func(&'v ExprId), - PartialFunc(&'v FuncApp), -} - -impl Value { - /// Returns a `ValueAsRef`, providing a non-owning view of the value. - pub fn as_ref(&self) -> ValueAsRef<'_> { - use Value::*; - use ValueAsRef as R; - match self { - Int(x) => R::Int(*x), - Float(x) => R::Float(*x), - Bool(x) => R::Bool(*x), - String(x) => R::String(x), - Null => R::Null, - Thunk(x) => R::Thunk(x), - AttrSet(x) => R::AttrSet(x), - List(x) => R::List(x), - PrimOp(x) => R::PrimOp(x), - PrimOpApp(x) => R::PartialPrimOp(x), - Func(x) => R::Func(x), - FuncApp(x) => R::PartialFunc(x), - } - } -} impl Value { /// Returns the name of the value's type. pub fn typename(&self) -> &'static str { @@ -184,82 +118,72 @@ impl Value { String(_) => "string", Null => "null", Thunk(_) => "thunk", + ClosureThunk(_) => "thunk", AttrSet(_) => "set", List(_) => "list", PrimOp(_) => "lambda", PrimOpApp(_) => "lambda", - Func(_) => "lambda", - FuncApp(..) => "lambda", + Closure(..) => "lambda", + Blackhole => unreachable!(), } } + pub fn force(&mut self, ctx: &mut impl EvalContext) -> Result<()> { + let map = |result| match result { + Ok(ok) => (Ok(()), ok), + Err(err) => (Err(err), Value::Null), + }; + replace_with_and_return( + self, + || Value::Null, + |val| match val { + Value::Thunk(id) => map(ctx.eval(id)), + Value::ClosureThunk(thunk) => map(thunk.call(None, ctx)), + val => (Ok(()), val), + }, + ) + } + /// Performs a function call on the `Value`. /// /// This method handles calling functions, primops, and their partially /// applied variants. It manages argument application and delegates to the /// `EvalContext` for the actual execution. - pub fn call( - &mut self, - mut iter: impl Iterator> + ExactSizeIterator, - ctx: &mut Ctx, - ) -> Result<()> { + pub fn call(&mut self, arg: Value, ctx: &mut Ctx) -> Result<()> { use Value::*; - *self = match self { - &mut PrimOp(primop) => { - if iter.len() > primop.arity { - let mut args = iter.collect::>>()?; - let leftover = args.split_off(primop.arity); - let mut ret = ctx.call_primop(primop.id, args)?; - ret.call(leftover.into_iter().map(Ok), ctx)?; - Ok(ret) - } else if primop.arity > iter.len() { - Ok(Value::PrimOpApp(Rc::new(self::PrimOpApp::new( - primop.name, - primop.arity - iter.len(), - primop.id, - iter.collect::>()?, - )))) - } else { - ctx.call_primop(primop.id, iter.collect::>()?) - } - } - &mut Func(expr) => { - let mut val; - let mut args = Vec::with_capacity(iter.len()); - args.push(iter.next().unwrap()?); - let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(expr)); - args = ret_args; - val = ret?; - loop { - if !matches!(val, Value::Func(_) | Value::FuncApp(_)) { - break; - } - let Some(arg) = iter.next() else { - break; - }; - args.push(arg?); - if let Value::Func(expr) = val { - let (ret_args, ret) = ctx.with_args_env(args, |ctx| ctx.eval(expr)); - args = ret_args; - val = ret?; - } else if let Value::FuncApp(func) = val { - let mut func = Rc::unwrap_or_clone(func); - func.args.push(args.pop().unwrap()); - let (ret_args, ret) = - ctx.with_args_env(func.args, |ctx| ctx.eval(func.body)); - args = ret_args; - val = ret?; + let map = |result| match result { + Ok(ok) => (Ok(()), ok), + Err(err) => (Err(err), Null), + }; + replace_with_and_return( + self, + || Null, + |func| match func { + PrimOp(id) => { + let arity = ctx.get_primop_arity(id); + if arity == 1 { + map(ctx.call_primop(id, smallvec![arg])) + } else { + ( + Ok(()), + Value::PrimOpApp(Rc::new(self::PrimOpApp::new( + arity - 1, + id, + smallvec![arg], + ))), + ) } } - Ok(val) - } - PrimOpApp(func) => func.call(iter.collect::>()?, ctx), - FuncApp(func) => func.call(iter, ctx), - _ => Err(Error::EvalError( - "attempt to call something which is not a function but ...".to_string(), - )), - }?; - Ok(()) + PrimOpApp(func) => map(func.call(arg, ctx)), + Closure(func) => map(func.call(Some(arg), ctx)), + _ => ( + Err(Error::eval_error( + "attempt to call something which is not a function but ...".to_string(), + )), + Null, + ), + }, + ) } pub fn not(&mut self) -> Result<()> { @@ -269,7 +193,10 @@ impl Value { *self = Bool(!bool); Ok(()) } - _ => Err(Error::EvalError(format!("expected a boolean but found {}", self.typename()))), + _ => Err(Error::eval_error(format!( + "expected a boolean but found {}", + self.typename() + ))), } } @@ -280,7 +207,10 @@ impl Value { *self = Bool(a && b); Ok(()) } - _ => Err(Error::EvalError(format!("expected a boolean but found {}", self.typename()))), + _ => Err(Error::eval_error(format!( + "expected a boolean but found {}", + self.typename() + ))), } } @@ -291,7 +221,10 @@ impl Value { *self = Bool(a || b); Ok(()) } - _ => Err(Error::EvalError(format!("expected a boolean but found {}", self.typename()))), + _ => Err(Error::eval_error(format!( + "expected a boolean but found {}", + self.typename() + ))), } } @@ -310,7 +243,13 @@ impl Value { (Float(a), Int(b)) => *a < b as f64, (Float(a), Float(b)) => *a < b, (String(a), String(b)) => a.as_str() < b.as_str(), - (a, b) => return Err(Error::EvalError(format!("cannot compare {} with {}", a.typename(), b.typename()))), + (a, b) => { + return Err(Error::eval_error(format!( + "cannot compare {} with {}", + a.typename(), + b.typename() + ))); + } }); Ok(()) } @@ -320,7 +259,12 @@ impl Value { *self = match &*self { Int(int) => Int(-int), Float(float) => Float(-float), - _ => return Err(Error::EvalError(format!("expected an integer but found {}", self.typename()))) + _ => { + return Err(Error::eval_error(format!( + "expected an integer but found {}", + self.typename() + ))); + } }; Ok(()) } @@ -336,7 +280,13 @@ impl Value { (&mut Int(a), Float(b)) => Float(*a as f64 + b), (&mut Float(a), Int(b)) => Float(*a + b as f64), (&mut Float(a), Float(b)) => Float(*a + b), - (a, b) => return Err(Error::EvalError(format!("cannot add {} to {}", a.typename(), b.typename()))) + (a, b) => { + return Err(Error::eval_error(format!( + "cannot add {} to {}", + a.typename(), + b.typename() + ))); + } }; Ok(()) } @@ -348,7 +298,13 @@ impl Value { (Int(a), Float(b)) => Float(*a as f64 * b), (Float(a), Int(b)) => Float(a * b as f64), (Float(a), Float(b)) => Float(a * b), - (a, b) => return Err(Error::EvalError(format!("cannot multiply {} with {}", a.typename(), b.typename()))) + (a, b) => { + return Err(Error::eval_error(format!( + "cannot multiply {} with {}", + a.typename(), + b.typename() + ))); + } }; Ok(()) } @@ -356,15 +312,21 @@ impl Value { pub fn div(&mut self, other: Self) -> Result<()> { use Value::*; *self = match (&*self, other) { - (_, Int(0)) => return Err(Error::EvalError("division by zero".to_string())), + (_, Int(0)) => return Err(Error::eval_error("division by zero".to_string())), (_, Float(0.)) => { - return Err(Error::EvalError("division by zero".to_string())); + return Err(Error::eval_error("division by zero".to_string())); } (Int(a), Int(b)) => Int(a / b), (Int(a), Float(b)) => Float(*a as f64 / b), (Float(a), Int(b)) => Float(a / b as f64), (Float(a), Float(b)) => Float(a / b), - (a, b) => return Err(Error::EvalError(format!("cannot divide {} with {}", a.typename(), b.typename()))) + (a, b) => { + return Err(Error::eval_error(format!( + "cannot divide {} with {}", + a.typename(), + b.typename() + ))); + } }; Ok(()) } @@ -376,8 +338,15 @@ impl Value { Rc::make_mut(a).concat(&b); Ok(()) } - (List(_), b) => Err(Error::EvalError(format!("expected a list but found {}", b.typename()))), - (a, _) => Err(Error::EvalError(format!("expected a list but found {}", a.typename()))), } + (List(_), b) => Err(Error::eval_error(format!( + "expected a list but found {}", + b.typename() + ))), + (a, _) => Err(Error::eval_error(format!( + "expected a list but found {}", + a.typename() + ))), + } } pub fn update(mut self: &mut Self, other: Self) -> Result<()> { @@ -387,20 +356,23 @@ impl Value { Rc::make_mut(a).update(&b); Ok(()) } - (AttrSet(_), other) => Err(Error::EvalError(format!("expected a set but found {}", other.typename()))), - _ => Err(Error::EvalError(format!("expected a set but found {}", self.typename()))), + (AttrSet(_), other) => Err(Error::eval_error(format!( + "expected a set but found {}", + other.typename() + ))), + _ => Err(Error::eval_error(format!( + "expected a set but found {}", + self.typename() + ))), } } - pub fn select( - &mut self, - path: impl DoubleEndedIterator>, - ) -> Result<()> { + pub fn select(&mut self, name: &str, ctx: &mut impl EvalContext) -> Result<()> { use Value::*; let val = match self { - AttrSet(attrs) => attrs.select(path), - _ => Err(Error::EvalError(format!( - "can not select from {:?}", + AttrSet(attrs) => attrs.select(name, ctx), + _ => Err(Error::eval_error(format!( + "expected a set but found {}", self.typename() ))), }?; @@ -408,17 +380,18 @@ impl Value { Ok(()) } - pub fn select_with_default( + pub fn select_or( &mut self, - path: impl DoubleEndedIterator>, - default: Self, + name: &str, + default: ExprId, + ctx: &mut Ctx, ) -> Result<()> { use Value::*; let val = match self { - AttrSet(attrs) => attrs.select(path).unwrap_or(default), + AttrSet(attrs) => attrs.select_or(name, default, ctx)?, _ => { - return Err(Error::EvalError(format!( - "can not select from {:?}", + return Err(Error::eval_error(format!( + "expected a set but found {}", self.typename() ))); } @@ -427,10 +400,7 @@ impl Value { Ok(()) } - pub fn has_attr( - &mut self, - path: impl DoubleEndedIterator>, - ) -> Result<()> { + pub fn has_attr(&mut self, path: impl DoubleEndedIterator>) -> Result<()> { use Value::*; if let AttrSet(attrs) = self { let val = attrs.has_attr(path)?; @@ -441,40 +411,36 @@ impl Value { Ok(()) } - pub fn coerce_to_string(&mut self) -> Result<&mut Self> { + pub fn force_string_no_ctx(self) -> Result { use Value::*; - if let String(_) = self { - Ok(self) + if let String(string) = self { + Ok(string) } else { - Err(Error::EvalError(format!("cannot coerce {} to string", self.typename()))) + Err(Error::eval_error(format!( + "cannot coerce {} to string", + self.typename() + ))) } } /// Converts the internal `Value` to its public-facing, serializable /// representation from the `nixjit_value` crate. - /// - /// The `seen` set is used to detect and handle cycles in data structures - /// like attribute sets and lists, replacing subsequent encounters with - /// `PubValue::Repeated`. pub fn to_public(self) -> PubValue { use Value::*; match self { - AttrSet(attrs) => { - Rc::unwrap_or_clone(attrs).to_public() - } - List(list) => { - Rc::unwrap_or_clone(list.clone()).to_public() - } + AttrSet(attrs) => Rc::unwrap_or_clone(attrs).to_public(), + List(list) => Rc::unwrap_or_clone(list.clone()).to_public(), Int(x) => PubValue::Const(Const::Int(x)), Float(x) => PubValue::Const(Const::Float(x)), Bool(x) => PubValue::Const(Const::Bool(x)), String(x) => PubValue::String(x), Null => PubValue::Const(Const::Null), Thunk(_) => PubValue::Thunk, - PrimOp(primop) => PubValue::PrimOp(primop.name), - PrimOpApp(primop) => PubValue::PrimOpApp(primop.name), - Func(_) => PubValue::Func, - FuncApp(..) => PubValue::Func, + ClosureThunk(_) => PubValue::Thunk, + PrimOp(_) => PubValue::PrimOp, + PrimOpApp(_) => PubValue::PrimOpApp, + Closure(..) => PubValue::Func, + Blackhole => unreachable!(), } } } diff --git a/evaluator/nixjit_eval/src/value/primop.rs b/evaluator/nixjit_eval/src/value/primop.rs index d51edcd..2daaa6a 100644 --- a/evaluator/nixjit_eval/src/value/primop.rs +++ b/evaluator/nixjit_eval/src/value/primop.rs @@ -10,20 +10,20 @@ use nixjit_ir::PrimOpId; use super::Value; use crate::EvalContext; +pub type Args = smallvec::SmallVec<[Value; 2]>; + /// Represents a partially applied primitive operation (builtin function). /// /// This struct holds the state of a primop that has received some, but not /// all, of its required arguments. #[derive(Debug, Clone, Constructor)] pub struct PrimOpApp { - /// The name of the primitive operation. - pub name: &'static str, /// The number of remaining arguments the primop expects. arity: usize, /// The unique ID of the primop. id: PrimOpId, /// The arguments that have already been applied. - args: Vec, + args: Args, } impl PrimOpApp { @@ -32,27 +32,15 @@ impl PrimOpApp { /// If enough arguments are provided to satisfy the primop's arity, it is /// executed. Otherwise, it returns a new `PrimOpApp` with the combined /// arguments. - pub fn call( - self: &mut Rc, - mut args: Vec, - ctx: &mut impl EvalContext, - ) -> Result { - if self.arity < args.len() { - let leftover = args.split_off(self.arity); - for arg in self.args.iter().rev().cloned() { - args.insert(0, arg); - } - let mut ret = ctx.call_primop(self.id, args)?; - ret.call(leftover.into_iter().map(Ok), ctx)?; - return Ok(ret); - } - let self_mut = Rc::make_mut(self); - self_mut.arity -= args.len(); - self_mut.args.extend(args); - if self_mut.arity == 0 { - ctx.call_primop(self_mut.id, std::mem::take(&mut self_mut.args)) + pub fn call(self: Rc, arg: Value, ctx: &mut impl EvalContext) -> Result { + let mut primop = Rc::unwrap_or_clone(self); + if primop.arity == 1 { + primop.args.push(arg); + ctx.call_primop(primop.id, primop.args) } else { - Ok(Value::PrimOpApp(self.clone())) + primop.args.push(arg); + primop.arity -= 1; + Ok(Value::PrimOpApp(primop.into())) } } } diff --git a/evaluator/nixjit_hir/Cargo.toml b/evaluator/nixjit_hir/Cargo.toml index 980b8c7..524f10a 100644 --- a/evaluator/nixjit_hir/Cargo.toml +++ b/evaluator/nixjit_hir/Cargo.toml @@ -7,7 +7,6 @@ edition = "2024" [dependencies] derive_more = { version = "2.0", features = ["full"] } hashbrown = "0.15" -itertools = "0.14" rnix = "0.12" nixjit_error = { path = "../nixjit_error" } diff --git a/evaluator/nixjit_hir/src/downgrade.rs b/evaluator/nixjit_hir/src/downgrade.rs index 797761d..eafa89a 100644 --- a/evaluator/nixjit_hir/src/downgrade.rs +++ b/evaluator/nixjit_hir/src/downgrade.rs @@ -10,7 +10,6 @@ use rnix::ast::{self, Expr}; use nixjit_error::{Error, Result}; -use nixjit_ir as ir; use super::*; @@ -35,7 +34,7 @@ impl Downgrade for Expr { // Dispatch to the specific implementation for each expression type. Apply(apply) => apply.downgrade(ctx), Assert(assert) => assert.downgrade(ctx), - Error(error) => Err(self::Error::DowngradeError(error.to_string())), + Error(error) => Err(self::Error::downgrade_error(error.to_string())), IfElse(ifelse) => ifelse.downgrade(ctx), Select(select) => select.downgrade(ctx), Str(str) => str.downgrade(ctx), @@ -159,19 +158,24 @@ impl Downgrade for ast::Ident { impl Downgrade for ast::AttrSet { fn downgrade(self, ctx: &mut Ctx) -> Result { let rec = self.rec_token().is_some(); - let mut attrs = downgrade_attrs(self, ctx)?; - attrs.rec = rec; - Ok(ctx.new_expr(attrs.to_hir())) + let attrs = downgrade_attrs(self, ctx)?; + let bindings = attrs.stcs.clone(); + let body = ctx.new_expr(attrs.to_hir()); + if rec { + Ok(ctx.new_expr(Let { bindings, body }.to_hir())) + } else { + Ok(body) + } } } /// Downgrades a list. impl Downgrade for ast::List { fn downgrade(self, ctx: &mut Ctx) -> Result { - let mut items = Vec::with_capacity(self.items().size_hint().0); - for item in self.items() { - items.push(item.downgrade(ctx)?) - } + let items = self + .items() + .map(|item| maybe_thunk(item, ctx)) + .collect::>()?; Ok(ctx.new_expr(List { items }.to_hir())) } } @@ -229,9 +233,10 @@ impl Downgrade for ast::Select { /// The body of the `let` is accessed via `let.body`. impl Downgrade for ast::LegacyLet { fn downgrade(self, ctx: &mut Ctx) -> Result { - let mut attrs = downgrade_attrs(self, ctx)?; - attrs.rec = true; - let expr = ctx.new_expr(attrs.to_hir()); + let attrs = downgrade_attrs(self, ctx)?; + let bindings = attrs.stcs.clone(); + let body = ctx.new_expr(attrs.to_hir()); + let expr = ctx.new_expr(Let { bindings, body }.to_hir()); // The result of a `legacy let` is the `body` attribute of the resulting set. let attrpath = vec![Attr::Str("body".into())]; Ok(ctx.new_expr( @@ -303,12 +308,12 @@ impl Downgrade for ast::Lambda { // Desugar pattern matching in function arguments into a `let` expression. // For example, `({ a, b ? 2 }): a + b` is desugared into: // `arg: let a = arg.a; b = arg.b or 2; in a + b` + let arg = ctx.new_expr(Hir::Arg(Arg)); let mut bindings: HashMap<_, _> = formals .into_iter() .map(|(k, default)| { // For each formal parameter, create a `Select` expression to extract it from the argument set. // `Arg` represents the raw argument (the attribute set) passed to the function. - let arg = ctx.new_expr(Hir::Arg(Arg)); ( k.clone(), ctx.new_expr( @@ -324,10 +329,7 @@ impl Downgrade for ast::Lambda { .collect(); // If there's an alias (`... }@alias`), bind the alias name to the raw argument set. if let Some(alias) = alias { - bindings.insert( - alias.clone(), - ctx.new_expr(Var { sym: alias.clone() }.to_hir()), - ); + bindings.insert(alias.clone(), arg); } // Wrap the original function body in the new `let` expression. let let_ = Let { bindings, body }; @@ -346,21 +348,12 @@ impl Downgrade for ast::Lambda { } /// Downgrades a function application. -/// The `rnix` AST represents chained function calls as nested `Apply` nodes, -/// e.g., `f a b` is parsed as `(f a) b`. This implementation unnests these -/// calls into a single `Call` HIR node with a list of arguments. +/// In Nix, function application is left-associative, so `f a b` should be parsed as `((f a) b)`. +/// Each Apply node represents a single function call with one argument. impl Downgrade for ast::Apply { fn downgrade(self, ctx: &mut Ctx) -> Result { - let mut args = vec![self.argument().unwrap().downgrade(ctx)?]; - let mut func = self.lambda().unwrap(); - // Traverse the chain of nested `Apply` nodes to collect all arguments. - while let ast::Expr::Apply(call) = func { - func = call.lambda().unwrap(); - args.push(call.argument().unwrap().downgrade(ctx)?); - } - let func = func.downgrade(ctx)?; - // The arguments were collected in reverse order, so fix that. - args.reverse(); - Ok(ctx.new_expr(Call { func, args }.to_hir())) + let func = self.lambda().unwrap().downgrade(ctx)?; + let arg = maybe_thunk(self.argument().unwrap(), ctx)?; + Ok(ctx.new_expr(Call { func, arg }.to_hir())) } } diff --git a/evaluator/nixjit_hir/src/lib.rs b/evaluator/nixjit_hir/src/lib.rs index 586053a..5488181 100644 --- a/evaluator/nixjit_hir/src/lib.rs +++ b/evaluator/nixjit_hir/src/lib.rs @@ -37,11 +37,10 @@ pub trait DowngradeContext { /// Allocates a new HIR expression in the context and returns its ID. fn new_expr(&mut self, expr: Hir) -> ExprId; - /// Provides temporary access to an immutable expression for inspection or use. - fn with_expr(&self, id: ExprId, f: impl FnOnce(&Hir, &Self) -> T) -> T; - /// Provides temporary mutable access to an expression. fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T; + + fn downgrade_root(self, expr: rnix::ast::Expr) -> Result; } // The `ir!` macro generates the `Hir` enum and related structs and traits. @@ -85,6 +84,7 @@ ir! { Let { pub bindings: HashMap, pub body: ExprId }, // Represents a function argument lookup within the body of a function. Arg, + Thunk(ExprId) } /// A placeholder struct for the `Arg` HIR variant. It signifies that at this point @@ -135,7 +135,7 @@ impl Attrs for AttrSet { .try_unwrap_attr_set() .map_err(|_| { // This path segment exists but is not an attrset, which is an error. - Error::DowngradeError(format!( + Error::downgrade_error(format!( "attribute '{}' already defined but is not an attribute set", format_symbol(ident) )) @@ -165,7 +165,7 @@ impl Attrs for AttrSet { match name { Attr::Str(ident) => { if self.stcs.insert(ident.clone(), value).is_some() { - return Err(Error::DowngradeError(format!( + return Err(Error::downgrade_error(format!( "attribute '{}' already defined", format_symbol(ident) ))); diff --git a/evaluator/nixjit_hir/src/utils.rs b/evaluator/nixjit_hir/src/utils.rs index 19fe234..a1e99a5 100644 --- a/evaluator/nixjit_hir/src/utils.rs +++ b/evaluator/nixjit_hir/src/utils.rs @@ -12,9 +12,49 @@ use rnix::ast; use nixjit_error::{Error, Result}; use nixjit_ir::{Attr, AttrSet, ConcatStrings, ExprId, Select, Str, Var}; +use crate::Hir; + use super::downgrade::Downgrade; use super::{Attrs, DowngradeContext, Param, ToHir}; +pub fn maybe_thunk(mut expr: ast::Expr, ctx: &mut impl DowngradeContext) -> Result { + use ast::Expr::*; + let expr = loop { + expr = match expr { + Paren(paren) => paren.expr().unwrap(), + Root(root) => root.expr().unwrap(), + expr => break expr, + } + }; + match expr { + Error(error) => return Err(self::Error::downgrade_error(error.to_string())), + Ident(ident) => return ident.downgrade(ctx), + Literal(lit) => return lit.downgrade(ctx), + Str(str) => return str.downgrade(ctx), + Path(path) => return path.downgrade(ctx), + + _ => (), + } + let id = match expr { + Apply(apply) => apply.downgrade(ctx), + Assert(assert) => assert.downgrade(ctx), + IfElse(ifelse) => ifelse.downgrade(ctx), + Select(select) => select.downgrade(ctx), + Lambda(lambda) => lambda.downgrade(ctx), + LegacyLet(let_) => let_.downgrade(ctx), + LetIn(letin) => letin.downgrade(ctx), + List(list) => list.downgrade(ctx), + BinOp(op) => op.downgrade(ctx), + AttrSet(attrs) => attrs.downgrade(ctx), + UnaryOp(op) => op.downgrade(ctx), + With(with) => with.downgrade(ctx), + HasAttr(has) => has.downgrade(ctx), + + _ => unreachable!(), + }?; + Ok(ctx.new_expr(Hir::Thunk(id))) +} + /// Downgrades a function parameter from the AST. pub fn downgrade_param(param: ast::Param, ctx: &mut impl DowngradeContext) -> Result { match param { @@ -63,7 +103,6 @@ pub fn downgrade_attrs( let mut attrs = AttrSet { stcs: HashMap::new(), dyns: Vec::new(), - rec: false, }; for entry in entries { @@ -87,7 +126,6 @@ pub fn downgrade_static_attrs( let mut attrs = AttrSet { stcs: HashMap::new(), dyns: Vec::new(), - rec: false, }; for entry in entries { @@ -121,7 +159,7 @@ pub fn downgrade_inherit( Attr::Str(ident) => ident, _ => { // `inherit` does not allow dynamic attributes. - return Err(Error::DowngradeError( + return Err(Error::downgrade_error( "dynamic attributes not allowed in inherit".to_string(), )); } @@ -140,11 +178,13 @@ pub fn downgrade_inherit( }, ); match stcs.entry(ident) { - Entry::Occupied(occupied) => return Err(Error::EvalError(format!( - "attribute '{}' already defined", - format_symbol(occupied.key()) - ))), - Entry::Vacant(vacant) => vacant.insert(ctx.new_expr(expr)) + Entry::Occupied(occupied) => { + return Err(Error::eval_error(format!( + "attribute '{}' already defined", + format_symbol(occupied.key()) + ))); + } + Entry::Vacant(vacant) => vacant.insert(ctx.new_expr(expr)), }; } Ok(()) @@ -205,7 +245,7 @@ pub fn downgrade_attrpathvalue( ctx: &mut impl DowngradeContext, ) -> Result<()> { let path = downgrade_attrpath(value.attrpath().unwrap(), ctx)?; - let value = value.value().unwrap().downgrade(ctx)?; + let value = maybe_thunk(value.value().unwrap(), ctx)?; attrs.insert(path, value, ctx) } @@ -218,7 +258,7 @@ pub fn downgrade_static_attrpathvalue( ) -> Result<()> { let path = downgrade_attrpath(value.attrpath().unwrap(), ctx)?; if path.iter().any(|attr| matches!(attr, Attr::Dynamic(_))) { - return Err(Error::DowngradeError( + return Err(Error::downgrade_error( "dynamic attributes not allowed in let bindings".to_string(), )); } diff --git a/evaluator/nixjit_ir/Cargo.toml b/evaluator/nixjit_ir/Cargo.toml index 0d3d66e..97f6151 100644 --- a/evaluator/nixjit_ir/Cargo.toml +++ b/evaluator/nixjit_ir/Cargo.toml @@ -9,5 +9,4 @@ derive_more = { version = "2.0", features = ["full"] } hashbrown = "0.15" rnix = "0.12" -nixjit_error = { path = "../nixjit_error" } nixjit_value = { path = "../nixjit_value" } diff --git a/evaluator/nixjit_ir/src/lib.rs b/evaluator/nixjit_ir/src/lib.rs index c785d61..1d9e068 100644 --- a/evaluator/nixjit_ir/src/lib.rs +++ b/evaluator/nixjit_ir/src/lib.rs @@ -3,7 +3,7 @@ //! The IR provides a simplified, language-agnostic representation of Nix expressions, //! serving as a bridge between the high-level representation (HIR) and the low-level //! representation (LIR). It defines the fundamental building blocks like expression IDs, -//! argument indexes, and structures for various expression types (e.g., binary operations, +//! argument indices, and structures for various expression types (e.g., binary operations, //! attribute sets, function calls). //! //! These structures are designed to be generic and reusable across different stages of @@ -18,9 +18,9 @@ use nixjit_value::Const as PubConst; /// A type-safe wrapper for an index into an expression table. /// -/// Using a newtype wrapper like this prevents accidentally mixing up different kinds of indices. +/// Using a newtype wrapper to prevent accidentally mixing up different kinds of indices. #[repr(transparent)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ExprId(usize); impl ExprId { @@ -38,7 +38,7 @@ impl ExprId { /// # Safety /// The caller must ensure that the provided index is valid for the expression table. #[inline(always)] - pub unsafe fn from(id: usize) -> Self { + pub unsafe fn from_raw(id: usize) -> Self { Self(id) } } @@ -63,17 +63,17 @@ impl PrimOpId { /// # Safety /// The caller must ensure that the provided index is valid. #[inline(always)] - pub unsafe fn from(id: usize) -> Self { + pub unsafe fn from_raw(id: usize) -> Self { Self(id) } } -/// A type-safe wrapper for an index into a function's argument list. +/// A type-safe wrapper for an index into a function's dependency stack. #[repr(transparent)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ArgIdx(usize); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct StackIdx(usize); -impl ArgIdx { +impl StackIdx { /// Returns the raw `usize` index. /// /// # Safety @@ -83,16 +83,19 @@ impl ArgIdx { self.0 } - /// Creates an `ArgIdx` from a raw `usize` index. + /// Creates an `StackIdx` from a raw `usize` index. /// /// # Safety /// The caller must ensure that the provided index is valid. #[inline(always)] - pub unsafe fn from(idx: usize) -> Self { + pub unsafe fn from_raw(idx: usize) -> Self { Self(idx) } } +#[derive(Clone, Copy, Debug)] +pub struct Arg; + /// Represents a Nix attribute set. #[derive(Debug, Default)] pub struct AttrSet { @@ -100,8 +103,6 @@ pub struct AttrSet { pub stcs: HashMap, /// Dynamically computed attributes, where both the key and value are expressions. pub dyns: Vec<(ExprId, ExprId)>, - /// `true` if this is a recursive attribute set (`rec { ... }`). - pub rec: bool, } /// Represents a key in an attribute path. @@ -265,16 +266,7 @@ pub struct Param { pub struct Call { /// The expression that evaluates to the function to be called. pub func: ExprId, - /// The list of arguments to pass to the function. - pub args: Vec, -} - -// Represents a primitive operation (builtin function) -#[derive(Debug, Clone, Copy)] -pub struct PrimOp { - pub name: &'static str, - pub id: PrimOpId, - pub arity: usize, + pub arg: ExprId, } /// Represents a `with` expression. diff --git a/evaluator/nixjit_jit/src/compile.rs b/evaluator/nixjit_jit/src/compile.rs index b21d6e2..202b424 100644 --- a/evaluator/nixjit_jit/src/compile.rs +++ b/evaluator/nixjit_jit/src/compile.rs @@ -19,22 +19,22 @@ pub trait JITCompile { /// /// # Arguments /// * `ctx` - The compilation context - /// * `engine` - The evaluation context value + /// * `rt_ctx` - The evaluation context value /// * `env` - The environment value /// /// # Returns /// A stack slot containing the compiled result - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot; + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot; } impl JITCompile for ExprId { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } impl JITCompile for Lir { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } @@ -43,10 +43,10 @@ impl JITCompile for AttrSet { /// Compiles an attribute set to Cranelift IR. /// /// This creates a new attribute set and compiles all static attributes into it. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { let attrs = ctx.create_attrs(); for (k, v) in self.stcs.iter() { - let v = v.compile(ctx, engine, env); + let v = v.compile(ctx, rt_ctx); ctx.push_attr(attrs, k, v); } ctx.finalize_attrs(attrs) @@ -57,10 +57,10 @@ impl JITCompile for List { /// Compiles a list to Cranelift IR. /// /// This creates a new list by compiling all items and storing them in an array. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { let array = ctx.alloc_array(self.items.len()); for (i, item) in self.items.iter().enumerate() { - let item = item.compile(ctx, engine, env); + let item = item.compile(ctx, rt_ctx); let tag = ctx.builder.ins().stack_load(types::I64, item, 0); let val0 = ctx.builder.ins().stack_load(types::I64, item, 8); let val1 = ctx.builder.ins().stack_load(types::I64, item, 16); @@ -83,7 +83,7 @@ impl JITCompile for List { } impl JITCompile for HasAttr { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } @@ -94,10 +94,10 @@ impl JITCompile for BinOp { /// This implementation handles various binary operations like addition, subtraction, /// division, logical AND/OR, and equality checks. It generates code that checks /// the types of operands and performs the appropriate operation. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { use BinOpKind::*; - let lhs = self.lhs.compile(ctx, engine, env); - let rhs = self.rhs.compile(ctx, engine, env); + let lhs = self.lhs.compile(ctx, rt_ctx); + let rhs = self.rhs.compile(ctx, rt_ctx); let lhs_tag = ctx.get_tag(lhs); let rhs_tag = ctx.get_tag(rhs); let eq = ctx.builder.ins().icmp(IntCC::Equal, lhs_tag, rhs_tag); @@ -349,7 +349,7 @@ impl JITCompile for BinOp { } impl JITCompile for UnOp { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } @@ -358,11 +358,11 @@ impl JITCompile for Attr { /// Compiles an attribute key to Cranelift IR. /// /// An attribute can be either a static string or a dynamic expression. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { use Attr::*; match self { Str(string) => ctx.create_string(string), - Dynamic(ir) => ir.compile(ctx, engine, env), + Dynamic(ir) => ir.compile(ctx, rt_ctx), } } } @@ -372,11 +372,11 @@ impl JITCompile for Select { /// /// This compiles the expression to select from, builds the attribute path, /// and calls the select helper function. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { - let val = self.expr.compile(ctx, engine, env); + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { + let val = self.expr.compile(ctx, rt_ctx); let attrpath = ctx.alloc_array(self.attrpath.len()); for (i, attr) in self.attrpath.iter().enumerate() { - let arg = attr.compile(ctx, engine, env); + let arg = attr.compile(ctx, rt_ctx); let tag = ctx.builder.ins().stack_load(types::I64, arg, 0); let val0 = ctx.builder.ins().stack_load(types::I64, arg, 8); let val1 = ctx.builder.ins().stack_load(types::I64, arg, 16); @@ -394,7 +394,7 @@ impl JITCompile for Select { .ins() .store(MemFlags::new(), val2, attrpath, i as i32 * 32 + 24); } - ctx.select(val, attrpath, self.attrpath.len(), engine, env); + ctx.select(val, attrpath, self.attrpath.len(), rt_ctx); val } } @@ -404,8 +404,8 @@ impl JITCompile for If { /// /// This generates code that evaluates the condition, checks that it's a boolean, /// and then jumps to the appropriate branch (true or false). - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { - let cond = self.cond.compile(ctx, engine, env); + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { + let cond = self.cond.compile(ctx, rt_ctx); let cond_type = ctx.builder.ins().stack_load(types::I64, cond, 0); let cond_value = ctx.builder.ins().stack_load(types::I64, cond, 8); @@ -430,7 +430,7 @@ impl JITCompile for If { .brif(cond_value, true_block, [], false_block, []); ctx.builder.switch_to_block(true_block); - let ret = self.consq.compile(ctx, engine, env); + let ret = self.consq.compile(ctx, rt_ctx); let tag = ctx.builder.ins().stack_load(types::I64, ret, 0); let val0 = ctx.builder.ins().stack_load(types::I64, ret, 8); let val1 = ctx.builder.ins().stack_load(types::I64, ret, 16); @@ -442,7 +442,7 @@ impl JITCompile for If { ctx.builder.ins().jump(exit_block, []); ctx.builder.switch_to_block(false_block); - let ret = self.alter.compile(ctx, engine, env); + let ret = self.alter.compile(ctx, rt_ctx); let tag = ctx.builder.ins().stack_load(types::I64, ret, 0); let val0 = ctx.builder.ins().stack_load(types::I64, ret, 8); let val1 = ctx.builder.ins().stack_load(types::I64, ret, 16); @@ -463,32 +463,10 @@ impl JITCompile for If { impl JITCompile for Call { /// Compiles a function call to Cranelift IR. - /// - /// This compiles the function expression and all arguments, builds an argument array, - /// and calls the call helper function. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { - let func = self.func.compile(ctx, engine, env); - let args = ctx.alloc_array(self.args.len()); - for (i, arg) in self.args.iter().enumerate() { - let arg = arg.compile(ctx, engine, env); - let tag = ctx.builder.ins().stack_load(types::I64, arg, 0); - let val0 = ctx.builder.ins().stack_load(types::I64, arg, 8); - let val1 = ctx.builder.ins().stack_load(types::I64, arg, 16); - let val2 = ctx.builder.ins().stack_load(types::I64, arg, 24); - ctx.builder - .ins() - .store(MemFlags::new(), tag, args, i as i32 * 32); - ctx.builder - .ins() - .store(MemFlags::new(), val0, args, i as i32 * 32 + 8); - ctx.builder - .ins() - .store(MemFlags::new(), val1, args, i as i32 * 32 + 16); - ctx.builder - .ins() - .store(MemFlags::new(), val2, args, i as i32 * 32 + 24); - } - ctx.call(func, args, self.args.len(), engine, env); + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { + let func = self.func.compile(ctx, rt_ctx); + let arg = self.arg.compile(ctx, rt_ctx); + ctx.call(func, arg, rt_ctx); func } } @@ -498,24 +476,24 @@ impl JITCompile for With { /// /// This enters a new `with` scope with the compiled namespace, compiles the body expression, /// and then exits the `with` scope. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { - let namespace = self.namespace.compile(ctx, engine, env); - ctx.enter_with(env, namespace); - let ret = self.expr.compile(ctx, engine, env); - ctx.exit_with(env); + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { + let namespace = self.namespace.compile(ctx, rt_ctx); + ctx.enter_with(rt_ctx, namespace); + let ret = self.expr.compile(ctx, rt_ctx); + ctx.exit_with(rt_ctx); ctx.free_slot(namespace); ret } } impl JITCompile for Assert { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } impl JITCompile for ConcatStrings { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } @@ -525,7 +503,7 @@ impl JITCompile for Const { /// /// This handles boolean, integer, float, and null constants by storing /// their values and type tags in a stack slot. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { use nixjit_value::Const::*; let slot = ctx.alloca(); match self.val { @@ -560,7 +538,7 @@ impl JITCompile for Str { /// Compiles a string literal to Cranelift IR. /// /// This creates a string value from the string literal. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { ctx.create_string(&self.val) } } @@ -569,13 +547,13 @@ impl JITCompile for Var { /// Compiles a variable lookup to Cranelift IR. /// /// This looks up a variable by its symbol in the current environment. - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { - ctx.lookup(env, &self.sym) + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { + ctx.lookup(rt_ctx, &self.sym) } } impl JITCompile for Path { - fn compile(&self, ctx: &mut Context, engine: ir::Value, env: ir::Value) -> StackSlot { + fn compile(&self, ctx: &mut Context, rt_ctx: ir::Value) -> StackSlot { todo!() } } diff --git a/evaluator/nixjit_jit/src/helpers.rs b/evaluator/nixjit_jit/src/helpers.rs index 9129f85..994a9a0 100644 --- a/evaluator/nixjit_jit/src/helpers.rs +++ b/evaluator/nixjit_jit/src/helpers.rs @@ -12,24 +12,18 @@ use std::ptr::NonNull; use hashbrown::HashMap; use nixjit_eval::{AttrSet, EvalContext, List, Value}; -use nixjit_ir::ArgIdx; +use nixjit_ir::ExprId; +use nixjit_ir::StackIdx; use super::JITContext; /// Helper function to call a function with arguments. -/// -/// This function is called from JIT-compiled code to perform function calls. -/// It takes a function value and an array of arguments, and executes the call. pub extern "C" fn helper_call( func: &mut Value, - args_ptr: *mut Value, - args_len: usize, + arg: NonNull, ctx: &mut Ctx, ) { - // TODO: Error Handling - let args = core::ptr::slice_from_raw_parts_mut(args_ptr, args_len); - let args = unsafe { Box::from_raw(args) }; - func.call(args.into_iter().map(Ok), ctx).unwrap(); + func.call(unsafe { arg.read() }, ctx).unwrap(); } /// Helper function to look up a value in the evaluation stack. @@ -37,21 +31,18 @@ pub extern "C" fn helper_call( /// This function is called from JIT-compiled code to access values in the evaluation stack. pub extern "C" fn helper_lookup_stack( ctx: &Ctx, - offset: usize, + idx: StackIdx, ret: &mut MaybeUninit, ) { - ret.write(ctx.lookup_stack(offset).clone()); + ret.write(ctx.lookup_stack(idx).clone()); } /// Helper function to look up a function argument. /// /// This function is called from JIT-compiled code to access function arguments. -pub extern "C" fn helper_lookup_arg( - ctx: &Ctx, - idx: ArgIdx, - ret: &mut MaybeUninit, -) { - ret.write(ctx.lookup_arg(idx).clone()); +pub extern "C" fn helper_lookup_arg(ctx: &mut Ctx, ret: &mut MaybeUninit) { + todo!() + // ret.write(ctx.lookup_arg().unwrap().clone()); } /// Helper function to look up a variable by name. @@ -84,14 +75,14 @@ pub extern "C" fn helper_select( val: &mut Value, path_ptr: *mut Value, path_len: usize, + ctx: &mut Ctx, ) { let path = core::ptr::slice_from_raw_parts_mut(path_ptr, path_len); let path = unsafe { Box::from_raw(path) }; - val.select(path.into_iter().map(|mut val| { - val.coerce_to_string().unwrap(); - Ok(val.unwrap_string()) - })) - .unwrap(); + for attr in path { + val.select(&attr.force_string_no_ctx().unwrap(), ctx) + .unwrap(); + } } /// Helper function to perform attribute selection with a default value. @@ -102,25 +93,22 @@ pub extern "C" fn helper_select_with_default( val: &mut Value, path_ptr: *mut Value, path_len: usize, - default: NonNull, + default: ExprId, + ctx: &mut Ctx, ) { let path = core::ptr::slice_from_raw_parts_mut(path_ptr, path_len); let path = unsafe { Box::from_raw(path) }; - val.select_with_default( - path.into_iter().map(|mut val| { - val.coerce_to_string().unwrap(); - Ok(val.unwrap_string()) - }), - unsafe { default.read() }, - ) - .unwrap(); + for attr in path { + val.select_or(&attr.force_string_no_ctx().unwrap(), default, ctx) + .unwrap(); + } } /// Helper function to check equality between two values. /// /// This function is called from JIT-compiled code to perform equality comparisons. -pub extern "C" fn helper_eq(lhs: &mut Value, rhs: &Value) { - lhs.eq(unsafe { core::ptr::read(rhs) }); +pub extern "C" fn helper_eq(lhs: &mut Value, rhs: NonNull) { + lhs.eq(unsafe { rhs.read() }); } /// Helper function to create a string value. diff --git a/evaluator/nixjit_jit/src/lib.rs b/evaluator/nixjit_jit/src/lib.rs index cd59316..7d452c8 100644 --- a/evaluator/nixjit_jit/src/lib.rs +++ b/evaluator/nixjit_jit/src/lib.rs @@ -51,6 +51,7 @@ type F = unsafe extern "C" fn(*const Ctx, *mut Value); /// This struct holds a function pointer to the compiled code and /// a set of strings that were used during compilation, which need /// to be kept alive for the function to work correctly. +#[derive(Debug)] pub struct JITFunc { func: F, strings: HashSet, @@ -209,7 +210,7 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { slot } - fn enter_with(&mut self, env: ir::Value, namespace: StackSlot) { + fn enter_with(&mut self, rt_ctx: ir::Value, namespace: StackSlot) { let ptr = self .builder .ins() @@ -218,15 +219,15 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .compiler .module .declare_func_in_func(self.compiler.enter_with, self.builder.func); - self.builder.ins().call(enter_with, &[env, ptr]); + self.builder.ins().call(enter_with, &[rt_ctx, ptr]); } - fn exit_with(&mut self, env: ir::Value) { + fn exit_with(&mut self, rt_ctx: ir::Value) { let exit_with = self .compiler .module .declare_func_in_func(self.compiler.exit_with, self.builder.func); - self.builder.ins().call(exit_with, &[env]); + self.builder.ins().call(exit_with, &[rt_ctx]); } fn dbg(&mut self, slot: StackSlot) { @@ -241,18 +242,7 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { self.builder.ins().call(dbg, &[ptr]); } - fn call( - &mut self, - func: StackSlot, - args_ptr: ir::Value, - args_len: usize, - engine: ir::Value, - env: ir::Value, - ) { - let args_len = self - .builder - .ins() - .iconst(self.compiler.ptr_type, args_len as i64); + fn call(&mut self, func: StackSlot, arg: StackSlot, call_ctx: ir::Value) { let call = self .compiler .module @@ -261,12 +251,14 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .builder .ins() .stack_addr(self.compiler.ptr_type, func, 0); - self.builder + let arg = self + .builder .ins() - .call(call, &[func, args_ptr, args_len, engine, env]); + .stack_addr(self.compiler.ptr_type, arg, 0); + self.builder.ins().call(call, &[func, arg, call_ctx]); } - fn lookup(&mut self, env: ir::Value, sym: &str) -> StackSlot { + fn lookup(&mut self, rt_ctx: ir::Value, sym: &str) -> StackSlot { let sym = self.strings.get_or_insert_with(sym, |_| sym.to_owned()); let ptr = self .builder @@ -285,11 +277,11 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); - self.builder.ins().call(lookup, &[env, ptr, len, ret]); + self.builder.ins().call(lookup, &[rt_ctx, ptr, len, ret]); slot } - fn lookup_stack(&mut self, env: ir::Value, idx: usize) -> StackSlot { + fn lookup_stack(&mut self, ctx: ir::Value, idx: usize) -> StackSlot { let slot = self.alloca(); let lookup_stack = self .compiler @@ -303,11 +295,11 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); - self.builder.ins().call(lookup_stack, &[env, idx, ptr]); + self.builder.ins().call(lookup_stack, &[ctx, idx, ptr]); slot } - fn lookup_arg(&mut self, env: ir::Value, idx: usize) -> StackSlot { + fn lookup_arg(&mut self, ctx: ir::Value, idx: usize) -> StackSlot { let slot = self.alloca(); let lookup_arg = self .compiler @@ -321,18 +313,11 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); - self.builder.ins().call(lookup_arg, &[env, idx, ptr]); + self.builder.ins().call(lookup_arg, &[ctx, idx, ptr]); slot } - fn select( - &mut self, - slot: StackSlot, - path_ptr: ir::Value, - path_len: usize, - engine: ir::Value, - env: ir::Value, - ) { + fn select(&mut self, slot: StackSlot, path_ptr: ir::Value, path_len: usize, ctx: ir::Value) { let select = self .compiler .module @@ -347,7 +332,7 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .stack_addr(self.compiler.ptr_type, slot, 0); self.builder .ins() - .call(select, &[ptr, path_ptr, path_len, engine, env]); + .call(select, &[ptr, path_ptr, path_len, ctx]); } fn select_with_default( @@ -357,7 +342,7 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { path_len: usize, default: StackSlot, engine: ir::Value, - env: ir::Value, + rt_ctx: ir::Value, ) { let select_with_default = self .compiler @@ -377,7 +362,7 @@ impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { .stack_addr(self.compiler.ptr_type, default, 0); self.builder.ins().call( select_with_default, - &[ptr, path_ptr, path_len, default_ptr, engine, env], + &[ptr, path_ptr, path_len, default_ptr, engine, rt_ctx], ); } @@ -489,25 +474,22 @@ impl JITCompiler { let ptr_type = module.target_config().pointer_type(); let value_type = types::I128; - // fn(*const Context, *const Env, *mut Value) let mut func_sig = module.make_signature(); func_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, - }; 3], + }; 2], ); - // fn(func: &mut Value, args_ptr: *mut Value, args_len: usize, engine: &mut Context, env: - // &mut Env) let mut call_sig = module.make_signature(); call_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, - }; 5], + }; 3], ); let call = module .declare_function("helper_call", Linkage::Import, &call_sig) @@ -525,7 +507,6 @@ impl JITCompiler { .declare_function("helper_lookup_stack", Linkage::Import, &lookup_stack_sig) .unwrap(); - // fn(env: &Env, level: usize, ret: &mut MaybeUninit) let mut lookup_arg_sig = module.make_signature(); lookup_arg_sig.params.extend( [AbiParam { @@ -538,7 +519,6 @@ impl JITCompiler { .declare_function("helper_lookup_arg", Linkage::Import, &lookup_arg_sig) .unwrap(); - // fn(env: &Env, sym_ptr: *const u8, sym_len: usize, ret: &mut MaybeUninit) let mut lookup_sig = module.make_signature(); lookup_sig.params.extend( [AbiParam { @@ -551,20 +531,18 @@ impl JITCompiler { .declare_function("helper_lookup", Linkage::Import, &lookup_sig) .unwrap(); - // fn(val: &mut Value, path_ptr: *mut Value, path_len: usize, engine: &mut Context, env: &mut Env) let mut select_sig = module.make_signature(); select_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, - }; 5], + }; 4], ); let select = module .declare_function("helper_select", Linkage::Import, &select_sig) .unwrap(); - // fn(val: &mut Value, path_ptr: *mut Value, path_len: usize, default: NonNull, engine: &mut Context, env: &mut Env) let mut select_with_default_sig = module.make_signature(); select_with_default_sig.params.extend( [AbiParam { @@ -755,10 +733,9 @@ impl JITCompiler { ctx.builder.switch_to_block(entry); let params = ctx.builder.block_params(entry); - let engine = params[0]; - let env = params[1]; - let ret = params[2]; - let res = ir.compile(&mut ctx, engine, env); + let rt_ctx = params[0]; + let ret = params[1]; + let res = ir.compile(&mut ctx, rt_ctx); let tag = ctx.builder.ins().stack_load(types::I64, res, 0); let val0 = ctx.builder.ins().stack_load(types::I64, res, 8); diff --git a/evaluator/nixjit_lir/src/lib.rs b/evaluator/nixjit_lir/src/lib.rs index d818b12..0fe8279 100644 --- a/evaluator/nixjit_lir/src/lib.rs +++ b/evaluator/nixjit_lir/src/lib.rs @@ -13,6 +13,7 @@ use derive_more::{IsVariant, TryUnwrap, Unwrap}; +use hashbrown::HashMap; use nixjit_error::{Error, Result}; use nixjit_hir as hir; use nixjit_ir::*; @@ -38,19 +39,20 @@ ir! { Str, Var, Path, - PrimOp, + Arg, + PrimOp(PrimOpId), + StackRef(StackIdx), ExprRef(ExprId), FuncRef(ExprId), - ArgRef(ArgIdx), + Thunk(ExprId), } /// Represents the result of a variable lookup within the `ResolveContext`. #[derive(Debug)] pub enum LookupResult { + Stack(StackIdx), /// The variable was found and resolved to a specific expression. Expr(ExprId), - /// The variable was found and resolved to a function argument. - Arg(ArgIdx), /// The variable could not be resolved statically, likely due to a `with` expression. /// The lookup must be performed dynamically at evaluation time. Unknown, @@ -63,30 +65,36 @@ pub enum LookupResult { /// This trait abstracts the environment in which expressions are resolved, managing /// scopes, dependencies, and the resolution of expressions themselves. pub trait ResolveContext { - /// Records a dependency of one expression on another. - fn new_dep(&mut self, expr: ExprId, dep: ExprId); - /// Creates a new function, associating a parameter specification with a body expression. fn new_func(&mut self, body: ExprId, param: Param); /// Triggers the resolution of a given expression. fn resolve(&mut self, expr: ExprId) -> Result<()>; + fn resolve_root(self, expr: ExprId) -> Result<()>; + /// Looks up a variable by name in the current scope. - fn lookup(&self, name: &str) -> LookupResult; + fn lookup(&mut self, name: &str) -> LookupResult; + + fn lookup_arg(&mut self) -> StackIdx; /// Enters a `with` scope for the duration of a closure. fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T); /// Enters a `let` scope with a given set of bindings for the duration of a closure. - fn with_let_env<'a, T>( + fn with_let_env( &mut self, - bindings: impl Iterator, + bindings: HashMap, f: impl FnOnce(&mut Self) -> T, ) -> T; /// Enters a function parameter scope for the duration of a closure. - fn with_param_env(&mut self, ident: Option, f: impl FnOnce(&mut Self) -> T) -> T; + fn with_param_env( + &mut self, + func: ExprId, + ident: Option, + f: impl FnOnce(&mut Self) -> T, + ) -> T; } /// A trait for converting (resolving) an HIR node into an LIR expression. @@ -117,10 +125,11 @@ impl Resolve for hir::Hir { Var(x) => x.resolve(ctx), Path(x) => x.resolve(ctx), Let(x) => x.resolve(ctx), - // The `Arg` in HIR is a placeholder. During resolution, it's replaced by - // a reference to the *current* function's argument. We assume index 0 - // here, as the context manages the actual argument index. - Arg(_) => unsafe { Ok(Lir::ArgRef(ArgIdx::from(0))) }, + Thunk(x) => { + ctx.resolve(x)?; + Ok(Lir::Thunk(x)) + } + Arg(_) => Ok(Lir::StackRef(ctx.lookup_arg())) } } } @@ -128,28 +137,14 @@ impl Resolve for hir::Hir { /// Resolves an `AttrSet` by resolving all key and value expressions. impl Resolve for AttrSet { fn resolve(self, ctx: &mut Ctx) -> Result { - if self.rec { - ctx.with_let_env(self.stcs.iter(), |ctx| { - for &id in self.stcs.values() { - ctx.resolve(id)?; - } - for &(k, v) in self.dyns.iter() { - ctx.resolve(k)?; - ctx.resolve(v)?; - } - Ok(()) - })?; - Ok(self.to_lir()) - } else { - for (_, &v) in self.stcs.iter() { - ctx.resolve(v)?; - } - for &(k, v) in self.dyns.iter() { - ctx.resolve(k)?; - ctx.resolve(v)?; - } - Ok(self.to_lir()) + for (_, &v) in self.stcs.iter() { + ctx.resolve(v)?; } + for &(k, v) in self.dyns.iter() { + ctx.resolve(k)?; + ctx.resolve(v)?; + } + Ok(self.to_lir()) } } @@ -224,7 +219,7 @@ impl Resolve for If { /// It then registers the function with the context. impl Resolve for Func { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.with_param_env(self.param.ident.clone(), |ctx| ctx.resolve(self.body))?; + ctx.with_param_env(self.body, self.param.ident.clone(), |ctx| ctx.resolve(self.body))?; ctx.new_func(self.body, self.param); Ok(Lir::FuncRef(self.body)) } @@ -234,9 +229,7 @@ impl Resolve for Func { impl Resolve for Call { fn resolve(self, ctx: &mut Ctx) -> Result { ctx.resolve(self.func)?; - for &arg in self.args.iter() { - ctx.resolve(arg)?; - } + ctx.resolve(self.arg)?; Ok(self.to_lir()) } } @@ -282,10 +275,10 @@ impl Resolve for Var { fn resolve(self, ctx: &mut Ctx) -> Result { use LookupResult::*; match ctx.lookup(&self.sym) { + Stack(idx) => Ok(Lir::StackRef(idx)), Expr(expr) => Ok(Lir::ExprRef(expr)), - Arg(arg) => Ok(Lir::ArgRef(arg)), Unknown => Ok(self.to_lir()), - NotFound => Err(Error::ResolutionError(format!( + NotFound => Err(Error::resolution_error(format!( "undefined variable '{}'", format_symbol(&self.sym) ))), @@ -305,7 +298,7 @@ impl Resolve for Path { /// the bindings and the body, and then returning a reference to the body. impl Resolve for hir::Let { fn resolve(self, ctx: &mut Ctx) -> Result { - ctx.with_let_env(self.bindings.iter(), |ctx| { + ctx.with_let_env(self.bindings.clone(), |ctx| { for &id in self.bindings.values() { ctx.resolve(id)?; } diff --git a/evaluator/nixjit_macros/src/builtins.rs b/evaluator/nixjit_macros/src/builtins.rs index 0d694af..3c1086b 100644 --- a/evaluator/nixjit_macros/src/builtins.rs +++ b/evaluator/nixjit_macros/src/builtins.rs @@ -39,7 +39,7 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { } }; - let mut pub_item_mod: Vec = Vec::new(); + let mut pub_item_mod = Vec::new(); let mut consts = Vec::new(); let mut global = Vec::new(); let mut scoped = Vec::new(); @@ -110,9 +110,9 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { /// Constant values available in the global scope. pub consts: [(&'static str, ::nixjit_value::Const); #mod_name::CONSTS_LEN], /// Global functions available in the global scope. - pub global: [(&'static str, usize, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::GLOBAL_LEN], + pub global: [(&'static str, usize, fn(&mut Ctx, ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::GLOBAL_LEN], /// Scoped functions, typically available under the `builtins` attribute set. - pub scoped: [(&'static str, usize, fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::SCOPED_LEN], + pub scoped: [(&'static str, usize, fn(&mut Ctx, ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::SCOPED_LEN], } impl Builtins { @@ -169,7 +169,6 @@ fn generate_primop_wrapper( let arg_pats: Vec<_> = user_args.rev().collect(); let arg_count = arg_pats.len(); - // Generate code to unpack and convert arguments from the `Vec`. let arg_unpacks = arg_pats.iter().enumerate().map(|(i, arg)| { let arg_name = format_ident!("_arg{}", i, span = Span::call_site()); let arg_ty = match &arg { @@ -178,8 +177,8 @@ fn generate_primop_wrapper( }; quote! { - let #arg_name: #arg_ty = args.pop().ok_or_else(|| ::nixjit_error::Error::EvalError("Not enough arguments provided".to_string()))? - .try_into().map_err(|e| ::nixjit_error::Error::EvalError(format!("Argument type conversion failed: {}", e)))?; + let #arg_name: #arg_ty = args.pop().ok_or_else(|| ::nixjit_error::Error::eval_error("Not enough arguments provided".to_string()))? + .try_into().map_err(|e| ::nixjit_error::Error::eval_error(format!("Argument type conversion failed: {}", e)))?; } }); @@ -222,16 +221,16 @@ fn generate_primop_wrapper( }; let arity = arg_names.len(); - let fn_type = quote! { fn(&mut Ctx, Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> }; + let fn_type = quote! { fn(&mut Ctx, ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value> }; // The primop metadata tuple: (name, arity, wrapper_function_pointer) let primop = quote! { (#name_str, #arity, #mod_name::#wrapper_name as #fn_type) }; // The generated wrapper function. let wrapper = quote! { - pub fn #wrapper_name(ctx: &mut Ctx, mut args: Vec<::nixjit_eval::Value>) -> ::nixjit_error::Result<::nixjit_eval::Value> { + pub fn #wrapper_name(ctx: &mut Ctx, mut args: ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value> { if args.len() != #arg_count { - return Err(::nixjit_error::Error::EvalError(format!("Function '{}' expects {} arguments, but received {}", #name_str, #arg_count, args.len()))); + return Err(::nixjit_error::Error::eval_error(format!("Function '{}' expects {} arguments, but received {}", #name_str, #arg_count, args.len()))); } #(#arg_unpacks)* diff --git a/evaluator/nixjit_value/src/lib.rs b/evaluator/nixjit_value/src/lib.rs index 913f721..7957093 100644 --- a/evaluator/nixjit_value/src/lib.rs +++ b/evaluator/nixjit_value/src/lib.rs @@ -191,9 +191,9 @@ pub enum Value { /// A function (lambda). Func, /// A primitive (built-in) operation. - PrimOp(&'static str), + PrimOp, /// A partially applied primitive operation. - PrimOpApp(&'static str), + PrimOpApp, /// A marker for a value that has been seen before during serialization, to break cycles. /// This is used to prevent infinite recursion when printing or serializing cyclic data structures. Repeated, @@ -209,8 +209,8 @@ impl Display for Value { List(x) => write!(f, "{x}"), Thunk => write!(f, ""), Func => write!(f, ""), - PrimOp(x) => write!(f, ""), - PrimOpApp(x) => write!(f, ""), + PrimOp => write!(f, ""), + PrimOpApp => write!(f, ""), Repeated => write!(f, ""), } } diff --git a/flake.lock b/flake.lock index 48a213a..6948733 100644 --- a/flake.lock +++ b/flake.lock @@ -8,11 +8,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1754376329, - "narHash": "sha256-Uz90O6qpmXQoNV57bf78yNd+nTxOoV5sjF1MibSdqWg=", + "lastModified": 1754894611, + "narHash": "sha256-TEyTVDhzFyfvPahhi1iAmkopt6fMiTlmn6f278lTdDs=", "owner": "nix-community", "repo": "fenix", - "rev": "ee7cae7d4cd68f7f2e78493a1f62212640db223c", + "rev": "a01861ebeb4d9c504845e7fb81509b82333ca0aa", "type": "github" }, "original": { @@ -23,11 +23,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1754214453, - "narHash": "sha256-Q/I2xJn/j1wpkGhWkQnm20nShYnG7TI99foDBpXm1SY=", + "lastModified": 1754725699, + "narHash": "sha256-iAcj9T/Y+3DBy2J0N+yF9XQQQ8IEb5swLFzs23CdP88=", "owner": "nixos", "repo": "nixpkgs", - "rev": "5b09dc45f24cf32316283e62aec81ffee3c3e376", + "rev": "85dbfc7aaf52ecb755f87e577ddbe6dbbdbc1054", "type": "github" }, "original": { @@ -46,11 +46,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1754320527, - "narHash": "sha256-5/EHPlvDFb1MPVcnUwpAcc9sHejhDhAj4uloUU4rthk=", + "lastModified": 1754834452, + "narHash": "sha256-otzv/l7c1rL+eH1cuJnUZVp4DR2dMdEIfhtLxTelIBY=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "978393bae86212f867e0c43872989e1658f7690f", + "rev": "4e147e787987fdb1baf081bd5c60bedfb0aabe16", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index f083931..e43ffbe 100644 --- a/flake.nix +++ b/flake.nix @@ -23,7 +23,7 @@ "rust-analyzer" "miri" ]) - gdb + lldb valgrind gemini-cli claude-code