diff --git a/TODO.md b/TODO.md index 234d694..dfcbefd 100644 --- a/TODO.md +++ b/TODO.md @@ -11,17 +11,145 @@ 阻拦变量解析的是 `with import` 而不是一般的 `with` 解析阶段可以解决除涉及非字符串常量拼接路径导入之外的问题 -# 工作流程 +静态分析图与求值期图分离,动态生成求值期图 + +依赖类型: +1. 强依赖 (x!) + - builtins.toString x => x! + - x + y => x! y! + - { x = 1; ${sym} = 2; } => sym! +2. 弱依赖 a.k.a. thunk (x?) + - builtins.seq x y => x! y? +3. 递归依赖 (x!!) + - builtins.deepSeq x y => x!! y? + +e.g. +- let a? = { inherit a?; }; in a! => a! + - a! => { a = } +- f: let x? = f! x?; in x! => f! x! +- let ret? = (self: { x? = 1; y? = self.x! + 1; }) ret?; in ret.y! => ret! x! y! + +工作流程: +1. string -> AST +2. AST -> HIR (alloc ExprId) +3. HIR -> LIR (resolve var, build graph) +4. LIR -> Value + ```mermaid flowchart TD - Start --> CtxNew[Context::new] - CtxNew --> CtxEval[Context::eval] - CtxEval --> DCtxNew[DowngradeCtx::new] - DCtxNew --> DCtxDowngradeRoot[DowngradeCtx::downgrade_root] - DCtxDowngradeRoot --> RCtxNew[ResolveCtx::new] - RCtxNew --> RCtxResolveRoot[ResolveCtx::resolve_root] - RCtxResolveRoot --> ECtxNew[EvalCtx::new] - ECtxNew --> ECtxEvalRoot[EvalCtx::eval_root] - ECtxEvalRoot --> ECtxEvalDeps[EvalCtx::eval_deps] - ECtxEvalDeps --> ECtxEval[EvalCtx::eval] + Start([Context::eval]) --> Eval[EvalContext::eval] + Eval --> AddNode[add node to graph] + AddNode --> CheckType{check expression type} + + CheckType -->|AttrSet| A_EvalKeys[eval keys] + A_EvalKeys --> A_Construct[construct attrset] + A_Construct --> A_Return[return attrset] + A_Return --> End + A_EvalKeys -.->|eval| Eval + + CheckType -->|List| L_Construct[construct list] + L_Construct --> L_Return[return list] + L_Return --> End + + CheckType -->|HasAttr| HA_ForceAttrSet[force attrset] + HA_ForceAttrSet --> HA_ForceAttrPath[force attrpath] + HA_ForceAttrPath --> HA_Return[return bool] + HA_Return --> End + HA_ForceAttrSet -.->|eval| Eval + HA_ForceAttrPath -.->|eval| Eval + + CheckType -->|BinOp| B_ForceLeft[force left operand] + B_ForceLeft --> B_ForceRight[force right operand] + B_ForceRight --> B_Apply[apply operator] + B_Apply --> B_Return[return value] + B_Return --> End + B_ForceLeft -.->|eval| Eval + B_ForceRight -.->|eval| Eval + + CheckType -->|UnOp| U_ForceOperand[force operand] + U_ForceOperand --> U_Apply[apply operator] + U_Apply --> U_Return[return value] + U_Return --> End + U_ForceOperand -.->|eval| Eval + + CheckType -->|Select| S_ForceAttrSet[force attrset] + S_ForceAttrSet --> S_ForceAttrPath[force attrpath] + S_ForceAttrPath --> S_CallEval[call eval on value] + S_CallEval --> S_Return[return value] + S_Return --> End + S_ForceAttrSet -.->|eval| Eval + S_ForceAttrPath -.->|eval| Eval + S_CallEval -.-> Eval + + CheckType -->|If| I_ForceCond[force condition] + I_ForceCond --> I_Cond{condition} + I_Cond -->|true| I_ForceConsq[force consequence] + I_Cond -->|false| I_ForceAlter[force alternative] + I_ForceConsq --> I_CallEvalConsq[call eval on consequence] + I_ForceAlter --> I_CallEvalAlter[call eval on alternative] + I_CallEvalConsq --> I_ReturnConsq[return value] + I_CallEvalAlter --> I_ReturnAlter[return value] + I_ReturnConsq --> End + I_ReturnAlter --> End + I_ForceCond -.->|eval| Eval + I_ForceConsq -.->|eval| Eval + I_ForceAlter -.->|eval| Eval + I_CallEvalConsq -.-> Eval + I_CallEvalAlter -.-> Eval + + CheckType -->|Call| C_RegArg[register argument] + C_RegArg --> C_ForceBody[force body] + C_ForceBody --> C_CallEval[call eval on body] + C_CallEval --> C_Return[return value] + C_Return --> End + C_ForceBody -.->|eval| Eval + C_CallEval -.-> Eval + + CheckType -->|With| W_EnterWith[enter with] + W_EnterWith --> W_ForceBody[force body] + W_ForceBody --> W_CallEval[call eval on body] + W_CallEval --> W_Return[return value] + W_Return --> End + W_ForceBody -.->|eval| Eval + W_CallEval -.-> Eval + + CheckType -->|Assert| As_ForceCond[force condition] + As_ForceCond --> As_Cond{condition} + As_Cond -->|true| As_ForceBody[force body] + As_Cond -->|false| As_Throw[throw Catchable] + As_ForceBody --> As_CallEval[call eval on body] + As_CallEval --> End + As_ForceCond -.->|eval| Eval + As_ForceBody -.->|eval| Eval + As_CallEval -.-> Eval + + CheckType -->|ConcatStrings| CS_ForceParts[force string parts] + CS_ForceParts --> CS_Construct[construct string] + CS_Construct --> End + CS_ForceParts -.->|eval| Eval + + CheckType -->|Const| Co_Return[return constant] + Co_Return --> End + + CheckType -->|Str| St_Return[return string] + St_Return --> End + + CheckType -->|Var| V_Lookup[lookup var] + V_Lookup --> V_Return[return value] + V_Return --> End + + CheckType -->|Arg| Ar_Lookup[lookup arg] + Ar_Lookup --> End + + CheckType -->|Func| F_Construct[construct function] + F_Construct --> End + + CheckType -->|StrictRef| SR_Eval[eval referenced expr] + SR_Eval --> Eval + + CheckType -->|LazyRef| LR_Resolve[resolve dynamic variable lookups of the referenced expr] + LR_Resolve --> LR_Contruct[construct lazy reference] + LR_Contruct --> End + + End([return result]) ``` diff --git a/evaluator/nixjit_context/src/downgrade.rs b/evaluator/nixjit_context/src/downgrade.rs index 3661546..fbae341 100644 --- a/evaluator/nixjit_context/src/downgrade.rs +++ b/evaluator/nixjit_context/src/downgrade.rs @@ -34,7 +34,7 @@ impl DowngradeCtx<'_, '_> { impl DowngradeContext for DowngradeCtx<'_, '_> { fn new_expr(&mut self, expr: Hir) -> ExprId { self.irs.push(expr.into()); - unsafe { ExprId::from_raw(self.ctx.lirs.len() + self.ctx.hirs.len() + self.irs.len() - 1) } + self.ctx.alloc_id() } fn with_expr_mut(&mut self, id: ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T { diff --git a/evaluator/nixjit_context/src/eval.rs b/evaluator/nixjit_context/src/eval.rs index 6b1ff93..3df93fd 100644 --- a/evaluator/nixjit_context/src/eval.rs +++ b/evaluator/nixjit_context/src/eval.rs @@ -1,3 +1,7 @@ +#[cfg(debug_assertions)] +use std::cell::OnceCell; +#[cfg(not(debug_assertions))] +use std::mem::MaybeUninit; use std::rc::Rc; use hashbrown::HashMap; @@ -8,12 +12,40 @@ use nixjit_eval::{Args, EvalContext, Evaluate, StackFrame, Value}; use nixjit_ir::ExprId; use nixjit_jit::JITContext; use nixjit_lir::Lir; -use petgraph::visit::{Topo, Walker}; +use petgraph::prelude::DiGraph; use super::Context; +#[derive(Default)] +struct ValueCache( + #[cfg(debug_assertions)] + Option, + #[cfg(not(debug_assertions))] + MaybeUninit +); + +impl ValueCache { + fn insert(&mut self, val: Value) { + #[cfg(debug_assertions)] + { + assert!(self.0.is_none()); + let _ = self.0.insert(val); + } + #[cfg(not(debug_assertions))] + self.0.write(val); + } +} + +impl Drop for ValueCache { + fn drop(&mut self) { + #[cfg(not(debug_assertions))] + self.0.assume_init_drop(); + } +} + pub struct EvalCtx<'ctx, 'bump> { ctx: &'ctx mut Context<'bump>, + graph: DiGraph, stack: Vec, with_scopes: Vec>>, } @@ -22,6 +54,7 @@ impl<'ctx, 'bump> EvalCtx<'ctx, 'bump> { pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { Self { ctx, + graph: DiGraph::new(), stack: Vec::new(), with_scopes: Vec::new(), } diff --git a/evaluator/nixjit_context/src/lib.rs b/evaluator/nixjit_context/src/lib.rs index 91b2afb..ddaf4f5 100644 --- a/evaluator/nixjit_context/src/lib.rs +++ b/evaluator/nixjit_context/src/lib.rs @@ -1,4 +1,7 @@ -use std::{marker::PhantomPinned, ops::{Deref, DerefMut}}; +use std::cell::Cell; +use std::marker::PhantomPinned; +use std::ops::{Deref, DerefMut}; +use std::ptr::NonNull; use bumpalo::{Bump, boxed::Box}; use hashbrown::HashMap; @@ -10,12 +13,12 @@ use petgraph::{ use nixjit_builtins::{ Builtins, BuiltinsContext, - builtins::{CONSTS_LEN, GLOBAL_LEN, SCOPED_LEN}, + builtins::{GLOBAL_LEN, SCOPED_LEN}, }; use nixjit_error::{Error, Result}; use nixjit_eval::{Args, EvalContext, Value}; use nixjit_hir::{DowngradeContext, Hir}; -use nixjit_ir::{AttrSet, Const, ExprId, Param, PrimOpId, StackIdx}; +use nixjit_ir::{AttrSet, ExprId, Param, PrimOpId, StackIdx}; use nixjit_lir::{Lir, ResolveContext}; use crate::downgrade::DowngradeCtx; @@ -65,13 +68,14 @@ impl<'bump, T> Pin<'bump, T> { /// This struct orchestrates the entire Nix expression evaluation process, /// from parsing and semantic analysis to interpretation and JIT compilation. pub struct Context<'bump> { + ir_count: usize, hirs: Vec, lirs: Vec>, /// Maps a function's body `ExprId` to its parameter definition. funcs: HashMap, - repl_scope: HashMap, - global_scope: HashMap<&'static str, ExprId>, + repl_scope: NonNull>, + global_scope: NonNull>, /// A dependency graph between expressions. graph: DiGraphMap, @@ -82,25 +86,24 @@ pub struct Context<'bump> { bump: &'bump Bump, } +impl Drop for Context<'_> { + fn drop(&mut self) { + unsafe { + self.repl_scope.drop_in_place(); + self.global_scope.drop_in_place(); + } + } +} + impl<'bump> Context<'bump> { pub fn new(bump: &'bump Bump) -> Self { - let Builtins { - consts, - global, - scoped, - } = Builtins::new(); - let global_scope = consts + let Builtins { global, scoped } = Builtins::new(); + let global_scope = global .iter() .enumerate() - .map(|(id, (k, _))| (*k, unsafe { ExprId::from_raw(id) })) - .chain( - global - .iter() - .enumerate() - .map(|(idx, (k, _, _))| (*k, unsafe { ExprId::from_raw(idx + CONSTS_LEN) })), - ) + .map(|(idx, (k, _, _))| (*k, unsafe { ExprId::from_raw(idx) })) .chain(core::iter::once(("builtins", unsafe { - ExprId::from_raw(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN) + ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN) }))) .collect(); let primops = global @@ -109,47 +112,41 @@ impl<'bump> Context<'bump> { .chain(scoped.iter().map(|&(_, arity, f)| (arity, f))) .collect_array() .unwrap(); - let lirs = consts - .into_iter() - .map(|(_, val)| Lir::Const(Const { val })) - .chain((0..global.len()).map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx) }))) + let lirs = (0..global.len()).map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx) })) .chain( (0..scoped.len()) .map(|idx| Lir::PrimOp(unsafe { PrimOpId::from_raw(idx + GLOBAL_LEN) })), ) .chain(core::iter::once(Lir::AttrSet(AttrSet { - stcs: consts - .into_iter() - .enumerate() - .map(|(idx, (name, _))| (name.to_string(), unsafe { ExprId::from_raw(idx) })) - .chain(global.into_iter().enumerate().map(|(idx, (name, ..))| { + stcs: global.into_iter().enumerate().map(|(idx, (name, ..))| { (name.to_string(), unsafe { - ExprId::from_raw(idx + CONSTS_LEN) + ExprId::from_raw(idx) }) - })) + }) .chain(scoped.into_iter().enumerate().map(|(idx, (name, ..))| { (name.to_string(), unsafe { - ExprId::from_raw(idx + CONSTS_LEN + GLOBAL_LEN) + ExprId::from_raw(idx + GLOBAL_LEN) }) })) .chain(core::iter::once(("builtins".to_string(), unsafe { - ExprId::from_raw(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN + 1) + ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN + 1) }))) .collect(), ..AttrSet::default() }))) .chain(core::iter::once(Lir::Thunk(unsafe { - ExprId::from_raw(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN) + ExprId::from_raw(GLOBAL_LEN + SCOPED_LEN) }))) .map(|lir| Pin::new_in(lir, bump)) .collect(); Self { + ir_count: 0, hirs: Vec::new(), lirs, funcs: HashMap::new(), - global_scope, - repl_scope: HashMap::new(), + global_scope: NonNull::from(bump.alloc(global_scope)), + repl_scope: NonNull::from(bump.alloc(HashMap::new())), graph: DiGraphMap::new(), primops, @@ -207,9 +204,29 @@ impl<'bump> Context<'bump> { let root_expr = root.tree().expr().unwrap(); let expr_id = self.downgrade_ctx().downgrade_root(root_expr)?; self.resolve_ctx().resolve_root(expr_id)?; - self.repl_scope.insert(ident.to_string(), expr_id); + unsafe { self.repl_scope.as_mut() }.insert(ident.to_string(), expr_id); Ok(()) } } +impl Context<'_> { + fn alloc_id(&mut self) -> ExprId { + self.ir_count += 1; + unsafe { ExprId::from_raw(self.ir_count - 1) } + } + + fn add_dep(&mut self, from: ExprId, to: ExprId, count: &Cell) -> StackIdx { + if let Some(&idx) = self.graph.edge_weight(from, to) { + idx + } else { + let idx = count.get(); + count.set(idx + 1); + let idx = unsafe { StackIdx::from_raw(idx) }; + assert_ne!(from, to); + self.graph.add_edge(from, to, idx); + idx + } + } +} + impl BuiltinsContext for Context<'_> {} diff --git a/evaluator/nixjit_context/src/resolve.rs b/evaluator/nixjit_context/src/resolve.rs index 0fce511..2f9c029 100644 --- a/evaluator/nixjit_context/src/resolve.rs +++ b/evaluator/nixjit_context/src/resolve.rs @@ -1,11 +1,11 @@ -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use derive_more::Unwrap; use hashbrown::HashMap; use nixjit_error::Result; use nixjit_hir::Hir; -use nixjit_ir::{Const, ExprId, Param, StackIdx}; +use nixjit_ir::{Const, ExprId, Param, PrimOpId, StackIdx}; use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext}; use replace_with::replace_with_and_return; @@ -30,33 +30,22 @@ enum Ir { Lir(Lir), } -impl Ir { - unsafe fn unwrap_hir_unchecked(self) -> Hir { - if let Self::Hir(hir) = self { - hir - } else { - unsafe { core::hint::unreachable_unchecked() } - } - } -} - pub struct ResolveCtx<'ctx, 'bump> { ctx: &'ctx mut Context<'bump>, irs: Vec>>, scopes: Vec>, has_with: bool, with_used: bool, - closures: Vec<(ExprId, Option, usize)>, + closures: Vec<(ExprId, Option, Cell)>, current_expr: Option, } impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { - let ctx_mut = unsafe { &mut *(ctx as *mut Context) }; Self { scopes: vec![ - Scope::Builtins(&ctx.global_scope), - Scope::Repl(&ctx.repl_scope), + Scope::Builtins(unsafe { ctx.global_scope.as_ref() }), + Scope::Repl(unsafe { ctx.repl_scope.as_ref() }), ], has_with: false, with_used: false, @@ -65,7 +54,7 @@ impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { .map(|hir| Ir::Hir(hir).into()) .map(|ir| Pin::new_in(ir, ctx.bump)) .collect(), - ctx: ctx_mut, + ctx, closures: Vec::new(), current_expr: None, } @@ -80,33 +69,10 @@ impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { } } - fn get_ir_mut(&mut self, id: ExprId) -> &mut RefCell { - let idx = unsafe { id.raw() } - self.ctx.lirs.len(); - if cfg!(debug_assertions) { - self.irs.get_mut(idx).unwrap() - } else { - unsafe { self.irs.get_unchecked_mut(idx) } - } - } - - fn add_dep(&mut self, from: ExprId, to: ExprId, count: &mut usize) -> StackIdx { - if let Some(&idx) = self.ctx.graph.edge_weight(from, to) { - idx - } else { - *count += 1; - let idx = unsafe { StackIdx::from_raw(*count - 1) }; - assert_ne!(from, to); - self.ctx.graph.add_edge(from, to, idx); - idx - } - } - fn new_lir(&mut self, lir: Lir) -> ExprId { - self.irs.push(Pin::new_in( - RefCell::new(Ir::Lir(lir)), - self.ctx.bump, - )); - unsafe { ExprId::from_raw(self.ctx.lirs.len() + self.irs.len() - 1) } + self.irs + .push(Pin::new_in(RefCell::new(Ir::Lir(lir)), self.ctx.bump)); + self.ctx.alloc_id() } } @@ -115,7 +81,7 @@ impl ResolveContext for ResolveCtx<'_, '_> { let prev_expr = self.current_expr.replace(expr); let result = unsafe { let ctx = &mut *(self as *mut Self); - let ir = &mut self.get_ir_mut(expr); + let ir = self.get_ir(expr); if !matches!(ir.try_borrow().as_deref(), Ok(Ir::Hir(_))) { return Ok(()); } @@ -126,7 +92,7 @@ impl ResolveContext for ResolveCtx<'_, '_> { val: nixjit_value::Const::Null, })) }, - |ir| match ir.unwrap_hir_unchecked().resolve(ctx) { + |ir| match ir.unwrap_hir().resolve(ctx) { Ok(lir) => (Ok(()), Ir::Lir(lir)), Err(err) => ( Err(err), @@ -141,12 +107,8 @@ impl ResolveContext for ResolveCtx<'_, '_> { result } - fn resolve_call(&mut self, func: ExprId, arg: ExprId) -> Result<()> { - todo!() - } - fn resolve_root(mut self, expr: ExprId) -> Result<()> { - self.closures.push((expr, None, 0)); + self.closures.push((expr, None, Cell::new(0))); let ret = self.resolve(expr); if ret.is_ok() { self.ctx.lirs.extend( @@ -167,17 +129,17 @@ impl ResolveContext for ResolveCtx<'_, '_> { for scope in self.scopes.iter().rev() { match scope { Scope::Builtins(scope) => { - if let Some(&expr) = scope.get(&name) { - return LookupResult::Expr(expr); + if let Some(&primop) = scope.get(&name) { + return LookupResult::PrimOp(primop); } } Scope::Let(scope) | &Scope::Repl(scope) => { if let Some(&dep) = scope.get(name) { - let (expr, _, deps) = unsafe { &mut *(self as *mut Self) } + let (expr, _, deps) = self .closures - .last_mut() + .last() .unwrap(); - let idx = self.add_dep(*expr, dep, deps); + let idx = self.ctx.add_dep(*expr, dep, deps); return LookupResult::Stack(idx); } } @@ -186,19 +148,18 @@ impl ResolveContext for ResolveCtx<'_, '_> { // This is an outer function's parameter, treat as dependency // We need to find the corresponding parameter expression to create dependency // For now, we need to handle this case by creating a dependency to the parameter - let mut iter = unsafe { &mut *(self as *mut Self) } - .closures - .iter_mut() - .rev() - .take(closure_depth + 1) - .rev(); + let mut iter = self.closures.iter().rev().take(closure_depth + 1).rev(); let Some((func, Some(arg), count)) = iter.next() else { unreachable!() }; - let mut cur = self.add_dep(*func, *arg, count); + let mut cur = self.ctx.add_dep(*func, *arg, count); for (func, _, count) in iter { - let idx = self.new_lir(Lir::StackRef(cur)); - cur = self.add_dep(*func, idx, count); + self.irs.push(Pin::new_in( + RefCell::new(Ir::Lir(Lir::StackRef(cur))), + self.ctx.bump, + )); + let idx = self.ctx.alloc_id(); + cur = self.ctx.add_dep(*func, idx, count); } return LookupResult::Stack(cur); } @@ -215,12 +176,10 @@ impl ResolveContext for ResolveCtx<'_, '_> { } fn lookup_arg(&mut self) -> StackIdx { - let Some((func, Some(arg), count)) = - unsafe { &mut *(self as *mut Self) }.closures.last_mut() - else { + let Some((func, Some(arg), count)) = self.closures.last() else { unreachable!() }; - self.add_dep(*func, *arg, count) + self.ctx.add_dep(*func, *arg, count) } fn new_func(&mut self, body: ExprId, param: Param) { @@ -255,7 +214,7 @@ impl ResolveContext for ResolveCtx<'_, '_> { f: impl FnOnce(&mut Self) -> T, ) -> T { let arg = self.new_lir(Lir::Arg(nixjit_ir::Arg)); - self.closures.push((func, Some(arg), 0)); + self.closures.push((func, Some(arg), Cell::new(0))); self.scopes.push(Scope::Arg(ident)); let res = f(self); self.scopes.pop(); diff --git a/evaluator/nixjit_eval/src/lib.rs b/evaluator/nixjit_eval/src/lib.rs index 9373dc2..f1849ef 100644 --- a/evaluator/nixjit_eval/src/lib.rs +++ b/evaluator/nixjit_eval/src/lib.rs @@ -330,13 +330,7 @@ impl Evaluate for ir::Str { impl Evaluate for ir::Const { /// Evaluates a `Const` literal into its corresponding `Value` variant. fn eval(&self, _: &mut Ctx) -> Result { - let result = match self.val { - Const::Null => Value::Null, - Const::Int(x) => Value::Int(x), - Const::Float(x) => Value::Float(x), - Const::Bool(x) => Value::Bool(x), - }; - Ok(result) + Ok(self.val.into()) } } diff --git a/evaluator/nixjit_eval/src/value/mod.rs b/evaluator/nixjit_eval/src/value/mod.rs index 968f1a3..e46e7e3 100644 --- a/evaluator/nixjit_eval/src/value/mod.rs +++ b/evaluator/nixjit_eval/src/value/mod.rs @@ -76,6 +76,17 @@ impl Debug for Value { } } +impl From for Value { + fn from(value: nixjit_value::Const) -> Self { + match value { + Const::Null => Value::Null, + Const::Int(x) => Value::Int(x), + Const::Float(x) => Value::Float(x), + Const::Bool(x) => Value::Bool(x), + } + } +} + impl Value { pub const INT: u64 = 0; pub const FLOAT: u64 = 1; diff --git a/evaluator/nixjit_lir/src/lib.rs b/evaluator/nixjit_lir/src/lib.rs index 809052a..3196b33 100644 --- a/evaluator/nixjit_lir/src/lib.rs +++ b/evaluator/nixjit_lir/src/lib.rs @@ -52,7 +52,7 @@ ir! { pub enum LookupResult { Stack(StackIdx), /// The variable was found and resolved to a specific expression. - Expr(ExprId), + PrimOp(ExprId), /// The variable could not be resolved statically, likely due to a `with` expression. /// The lookup must be performed dynamically at evaluation time. Unknown, @@ -71,8 +71,6 @@ pub trait ResolveContext { /// Triggers the resolution of a given expression. fn resolve(&mut self, expr: ExprId) -> Result<()>; - fn resolve_call(&mut self, func: ExprId, arg: ExprId) -> Result<()>; - fn resolve_root(self, expr: ExprId) -> Result<()>; /// Looks up a variable by name in the current scope. @@ -233,7 +231,6 @@ impl Resolve for Call { fn resolve(self, ctx: &mut Ctx) -> Result { ctx.resolve(self.func)?; ctx.resolve(self.arg)?; - ctx.resolve_call(self.func, self.arg)?; Ok(self.to_lir()) } } @@ -280,7 +277,7 @@ impl Resolve for Var { use LookupResult::*; match ctx.lookup(&self.sym) { Stack(idx) => Ok(Lir::StackRef(idx)), - Expr(expr) => Ok(Lir::ExprRef(expr)), + PrimOp(id) => Ok(Lir::ExprRef(id)), Unknown => Ok(self.to_lir()), NotFound => Err(Error::resolution_error(format!( "undefined variable '{}'", diff --git a/evaluator/nixjit_macros/src/builtins.rs b/evaluator/nixjit_macros/src/builtins.rs index 3c1086b..2ea18c9 100644 --- a/evaluator/nixjit_macros/src/builtins.rs +++ b/evaluator/nixjit_macros/src/builtins.rs @@ -18,7 +18,7 @@ use proc_macro::TokenStream; use proc_macro2::Span; use quote::{ToTokens, format_ident, quote}; use syn::{ - FnArg, Item, ItemFn, ItemMod, Pat, PatIdent, PatType, Type, Visibility, parse_macro_input, + parse_macro_input, FnArg, Item, ItemConst, ItemFn, ItemMod, Pat, PatIdent, PatType, Type, Visibility }; /// The implementation of the `#[builtins]` macro. @@ -40,7 +40,6 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { }; let mut pub_item_mod = Vec::new(); - let mut consts = Vec::new(); let mut global = Vec::new(); let mut scoped = Vec::new(); let mut wrappers = Vec::new(); @@ -49,20 +48,17 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { for item in &items { match item { Item::Const(item_const) => { - // Handle `const` definitions. These are exposed as constants in Nix. - let name_str = item_const - .ident - .to_string() - .from_case(Case::UpperSnake) - .to_case(Case::Camel); - let const_name = &item_const.ident; - consts.push(quote! { (#name_str, builtins::#const_name) }); - pub_item_mod.push( - quote! { - pub #item_const - } - .into(), - ); + let (primop, wrapper) = match generate_const_wrapper(item_const) { + Ok(result) => result, + Err(e) => return e.to_compile_error().into(), + }; + // Public functions are added to the global scope, private ones to a scoped set. + if matches!(item_const.vis, Visibility::Public(_)) { + global.push(primop); + } else { + scoped.push(primop); + } + wrappers.push(wrapper); } Item::Fn(item_fn) => { // Handle function definitions. These become primops. @@ -90,7 +86,6 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { } } - let consts_len = consts.len(); let global_len = global.len(); let scoped_len = scoped.len(); @@ -100,15 +95,12 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { #visibility mod #mod_name { #(#pub_item_mod)* #(#wrappers)* - pub const CONSTS_LEN: usize = #consts_len; pub const GLOBAL_LEN: usize = #global_len; pub const SCOPED_LEN: usize = #scoped_len; } /// A struct containing all the built-in constants and functions. pub struct Builtins { - /// Constant values available in the global scope. - pub consts: [(&'static str, ::nixjit_value::Const); #mod_name::CONSTS_LEN], /// Global functions available in the global scope. pub global: [(&'static str, usize, fn(&mut Ctx, ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value>); #mod_name::GLOBAL_LEN], /// Scoped functions, typically available under the `builtins` attribute set. @@ -119,7 +111,6 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { /// Creates a new instance of the `Builtins` struct. pub fn new() -> Self { Self { - consts: [#(#consts,)*], global: [#(#global,)*], scoped: [#(#scoped,)*], } @@ -130,6 +121,33 @@ pub fn builtins_impl(input: TokenStream) -> TokenStream { output.into() } +fn generate_const_wrapper( + item_const: &ItemConst, +) -> syn::Result<(proc_macro2::TokenStream, proc_macro2::TokenStream)> { + let const_name = &item_const.ident; + let const_val = &item_const.expr; + let name_str = const_name + .to_string() + .from_case(Case::UpperSnake) + .to_case(Case::Camel); + let const_name = format_ident!("{name_str}"); + let wrapper_name = format_ident!("wrapper_{}", const_name); + let mod_name = format_ident!("builtins"); + + let fn_type = quote! { fn(&mut Ctx, ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value> }; + + // The primop metadata tuple: (name, arity, wrapper_function_pointer) + let primop = quote! { (#name_str, 0, #mod_name::#wrapper_name as #fn_type) }; + + // The generated wrapper function. + let wrapper = quote! { + pub fn #wrapper_name(ctx: &mut Ctx, mut args: ::nixjit_eval::Args) -> ::nixjit_error::Result<::nixjit_eval::Value> { + Ok(#const_val.into()) + } + }; + + Ok((primop, wrapper)) +} /// Generates the primop metadata and the wrapper function for a single user-defined function. fn generate_primop_wrapper( item_fn: &ItemFn, @@ -166,7 +184,7 @@ fn generate_primop_wrapper( }; // Collect the remaining arguments. - let arg_pats: Vec<_> = user_args.rev().collect(); + let arg_pats: Vec<_> = user_args.collect(); let arg_count = arg_pats.len(); let arg_unpacks = arg_pats.iter().enumerate().map(|(i, arg)| { @@ -177,7 +195,7 @@ fn generate_primop_wrapper( }; quote! { - let #arg_name: #arg_ty = args.pop().ok_or_else(|| ::nixjit_error::Error::eval_error("Not enough arguments provided".to_string()))? + let #arg_name: #arg_ty = args.next().ok_or_else(|| ::nixjit_error::Error::eval_error("Not enough arguments provided".to_string()))? .try_into().map_err(|e| ::nixjit_error::Error::eval_error(format!("Argument type conversion failed: {}", e)))?; } }); @@ -192,7 +210,6 @@ fn generate_primop_wrapper( } _ => unreachable!(), }) - .rev() .collect(); // Construct the argument list for the final call. @@ -232,6 +249,8 @@ fn generate_primop_wrapper( if args.len() != #arg_count { return Err(::nixjit_error::Error::eval_error(format!("Function '{}' expects {} arguments, but received {}", #name_str, #arg_count, args.len()))); } + + let mut args = args.into_iter(); #(#arg_unpacks)* #call_expr