use std::cell::RefCell; use derive_more::Unwrap; use hashbrown::HashMap; use nixjit_error::Result; use nixjit_hir::Hir; use nixjit_ir::{Const, ExprId, Param, StackIdx}; use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext}; use replace_with::replace_with_and_return; use super::{Context, Pin}; #[derive(Clone)] enum Scope<'ctx> { /// A `let` binding scope, mapping variable names to their expression IDs. Let(HashMap), /// A function argument scope. `Some` holds the name of the argument set if present. Arg(Option), Builtins(&'ctx HashMap<&'static str, ExprId>), Repl(&'ctx HashMap), } /// Represents an expression at different stages of compilation. #[derive(Debug, Unwrap)] enum Ir { /// An expression in the High-Level Intermediate Representation (HIR). Hir(Hir), /// An expression in the Low-Level Intermediate Representation (LIR). Lir(Lir), } impl Ir { unsafe fn unwrap_hir_unchecked(self) -> Hir { if let Self::Hir(hir) = self { hir } else { unsafe { core::hint::unreachable_unchecked() } } } } pub struct ResolveCtx<'ctx, 'bump> { ctx: &'ctx mut Context<'bump>, irs: Vec>>, scopes: Vec>, has_with: bool, with_used: bool, closures: Vec<(ExprId, Option, usize)>, current_expr: Option, } impl<'ctx, 'bump> ResolveCtx<'ctx, 'bump> { pub fn new(ctx: &'ctx mut Context<'bump>) -> Self { let ctx_mut = unsafe { &mut *(ctx as *mut Context) }; Self { scopes: vec![ Scope::Builtins(&ctx.global_scope), Scope::Repl(&ctx.repl_scope), ], has_with: false, with_used: false, irs: core::mem::take(&mut ctx.hirs) .into_iter() .map(|hir| Ir::Hir(hir).into()) .map(|ir| Pin::new_in(ir, ctx.bump)) .collect(), ctx: ctx_mut, closures: Vec::new(), current_expr: None, } } fn get_ir(&self, id: ExprId) -> &RefCell { let idx = unsafe { id.raw() } - self.ctx.lirs.len(); if cfg!(debug_assertions) { self.irs.get(idx).unwrap() } else { unsafe { self.irs.get_unchecked(idx) } } } fn get_ir_mut(&mut self, id: ExprId) -> &mut RefCell { let idx = unsafe { id.raw() } - self.ctx.lirs.len(); if cfg!(debug_assertions) { self.irs.get_mut(idx).unwrap() } else { unsafe { self.irs.get_unchecked_mut(idx) } } } fn add_dep(&mut self, from: ExprId, to: ExprId, count: &mut usize) -> StackIdx { if let Some(&idx) = self.ctx.graph.edge_weight(from, to) { idx } else { *count += 1; let idx = unsafe { StackIdx::from_raw(*count - 1) }; assert_ne!(from, to); self.ctx.graph.add_edge(from, to, idx); idx } } fn new_lir(&mut self, lir: Lir) -> ExprId { self.irs.push(Pin::new_in( RefCell::new(Ir::Lir(lir)), self.ctx.bump, )); unsafe { ExprId::from_raw(self.ctx.lirs.len() + self.irs.len() - 1) } } } impl ResolveContext for ResolveCtx<'_, '_> { fn resolve(&mut self, expr: ExprId) -> Result<()> { let prev_expr = self.current_expr.replace(expr); let result = unsafe { let ctx = &mut *(self as *mut Self); let ir = &mut self.get_ir_mut(expr); if !matches!(ir.try_borrow().as_deref(), Ok(Ir::Hir(_))) { return Ok(()); } replace_with_and_return( &mut *ir.borrow_mut(), || { Ir::Hir(Hir::Const(Const { val: nixjit_value::Const::Null, })) }, |ir| match ir.unwrap_hir_unchecked().resolve(ctx) { Ok(lir) => (Ok(()), Ir::Lir(lir)), Err(err) => ( Err(err), Ir::Hir(Hir::Const(Const { val: nixjit_value::Const::Null, })), ), }, ) }; self.current_expr = prev_expr; result } fn resolve_call(&mut self, func: ExprId, arg: ExprId) -> Result<()> { todo!() } fn resolve_root(mut self, expr: ExprId) -> Result<()> { self.closures.push((expr, None, 0)); let ret = self.resolve(expr); if ret.is_ok() { self.ctx.lirs.extend( self.irs .into_iter() .map(Pin::into_inner) .map(RefCell::into_inner) .map(Ir::unwrap_lir) .map(|lir| crate::Pin::new_in(lir, self.ctx.bump)), ); } ret } fn lookup(&mut self, name: &str) -> LookupResult { let mut closure_depth = 0; // Then search from outer to inner scopes for dependencies for scope in self.scopes.iter().rev() { match scope { Scope::Builtins(scope) => { if let Some(&expr) = scope.get(&name) { return LookupResult::Expr(expr); } } Scope::Let(scope) | &Scope::Repl(scope) => { if let Some(&dep) = scope.get(name) { let (expr, _, deps) = unsafe { &mut *(self as *mut Self) } .closures .last_mut() .unwrap(); let idx = self.add_dep(*expr, dep, deps); return LookupResult::Stack(idx); } } Scope::Arg(ident) => { if ident.as_deref() == Some(name) { // This is an outer function's parameter, treat as dependency // We need to find the corresponding parameter expression to create dependency // For now, we need to handle this case by creating a dependency to the parameter let mut iter = unsafe { &mut *(self as *mut Self) } .closures .iter_mut() .rev() .take(closure_depth + 1) .rev(); let Some((func, Some(arg), count)) = iter.next() else { unreachable!() }; let mut cur = self.add_dep(*func, *arg, count); for (func, _, count) in iter { let idx = self.new_lir(Lir::StackRef(cur)); cur = self.add_dep(*func, idx, count); } return LookupResult::Stack(cur); } closure_depth += 1; } } } if self.has_with { self.with_used = true; LookupResult::Unknown } else { LookupResult::NotFound } } fn lookup_arg(&mut self) -> StackIdx { let Some((func, Some(arg), count)) = unsafe { &mut *(self as *mut Self) }.closures.last_mut() else { unreachable!() }; self.add_dep(*func, *arg, count) } fn new_func(&mut self, body: ExprId, param: Param) { self.ctx.funcs.insert(body, param); } fn with_let_env( &mut self, bindings: HashMap, f: impl FnOnce(&mut Self) -> T, ) -> T { self.scopes.push(Scope::Let(bindings)); let res = f(self); self.scopes.pop(); res } fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) { let has_with = self.has_with; let with_used = self.with_used; self.has_with = true; self.with_used = false; let res = f(self); self.has_with = has_with; (core::mem::replace(&mut self.with_used, with_used), res) } fn with_closure_env( &mut self, func: ExprId, ident: Option, f: impl FnOnce(&mut Self) -> T, ) -> T { let arg = self.new_lir(Lir::Arg(nixjit_ir::Arg)); self.closures.push((func, Some(arg), 0)); self.scopes.push(Scope::Arg(ident)); let res = f(self); self.scopes.pop(); self.closures.pop(); res } }