// Shared context for the downgrade, resolve, evaluate and JIT stages.
//
// This file defines `Context` and implements `DowngradeContext`,
// `ResolveContext`, `EvalContext`, `JITContext` and `BuiltinsContext` for it.
//
// NOTE(review): several generic argument lists in this file look like they
// were stripped (for example `Vec>`, `HashMap` and `Option` without
// parameters, `Result` without a payload type). The code tokens below are left
// exactly as found; restore the type arguments from version control before
// attempting to build.
use std::cell::{OnceCell, RefCell};
use std::rc::Rc;

use derive_more::Unwrap;
use hashbrown::{HashMap, HashSet};
use itertools::Itertools;
use petgraph::graph::{DiGraph, NodeIndex};

use nixjit_builtins::{
    Builtins, BuiltinsContext,
    builtins::{CONSTS_LEN, GLOBAL_LEN, SCOPED_LEN},
};
use nixjit_error::{Error, Result};
use nixjit_eval::{EvalContext, Evaluate, Value};
use nixjit_hir::{Downgrade, DowngradeContext, Hir};
use nixjit_ir::{ArgIdx, Const, ExprId, Param, PrimOp, PrimOpId};
use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext};
use nixjit_jit::{JITCompiler, JITContext, JITFunc};
use replace_with::replace_with_and_return;

/// One frame of the name-resolution scope stack walked by
/// `ResolveContext::lookup`.
enum Scope {
    /// A `with` frame. `lookup` stops at this frame and reports
    /// `LookupResult::Unknown`.
    With,
    /// A frame of named bindings; `lookup` queries it with `get(name)` and
    /// returns the stored expression id on a hit.
    Let(HashMap),
    /// A function-parameter frame carrying an optional identifier. Every
    /// non-matching `Arg` frame passed during `lookup` bumps the argument
    /// index by one.
    Arg(Option),
}

/// Contents of one slot of `Context::irs`: either a not-yet-resolved `Hir`
/// node or a `Lir` node.
///
/// `unwrap_hir()` used further down is presumably generated by the `Unwrap`
/// derive.
#[derive(Debug, Unwrap)]
enum Ir {
    Hir(Hir),
    Lir(Lir),
}

impl Ir {
    /// Borrows the inner `Hir` without any checked fallback.
    ///
    /// # Safety
    ///
    /// `self` must be `Ir::Hir`; otherwise `unreachable_unchecked` is reached,
    /// which is undefined behaviour.
    unsafe fn unwrap_hir_ref_unchecked(&self) -> &Hir {
        if let Self::Hir(hir) = self {
            hir
        } else {
            unsafe { core::hint::unreachable_unchecked() }
        }
    }

    /// Mutably borrows the inner `Hir` without any checked fallback.
    ///
    /// # Safety
    ///
    /// `self` must be `Ir::Hir`; otherwise `unreachable_unchecked` is reached,
    /// which is undefined behaviour.
    unsafe fn unwrap_hir_mut_unchecked(&mut self) -> &mut Hir {
        // NOTE(review): the two cfg-gated branches below are token-for-token
        // identical, so the debug/release split currently has no effect.
        #[cfg(debug_assertions)]
        if let Self::Hir(hir) = self {
            hir
        } else {
            unsafe { core::hint::unreachable_unchecked() }
        }
        #[cfg(not(debug_assertions))]
        if let Self::Hir(hir) = self {
            hir
        } else {
            unsafe { core::hint::unreachable_unchecked() }
        }
    }

    /// Consumes `self` and returns the inner `Hir`.
    ///
    /// With debug assertions enabled this goes through `unwrap_hir()`; without
    /// them a non-`Hir` value reaches `unreachable_unchecked`.
    ///
    /// # Safety
    ///
    /// `self` must be `Ir::Hir`.
    unsafe fn unwrap_hir_unchecked(self) -> Hir {
        if cfg!(debug_assertions) {
            self.unwrap_hir()
        } else if let Self::Hir(hir) = self {
            hir
        } else {
            unsafe { core::hint::unreachable_unchecked() }
        }
    }

    /// Borrows the inner `Lir`.
    ///
    /// # Safety
    ///
    /// `self` must be `Ir::Lir`.
    unsafe fn unwrap_lir_ref_unchecked(&self) -> &Lir {
        // NOTE(review): this looks inverted relative to `unwrap_hir_unchecked`:
        // here the debug build reaches `unreachable_unchecked` (undefined
        // behaviour) on a mismatch while the release build panics. Confirm the
        // intent; the checked branch most likely belongs under
        // `debug_assertions`.
        #[cfg(debug_assertions)]
        if let Self::Lir(lir) = self {
            lir
        } else {
            unsafe { core::hint::unreachable_unchecked() }
        }
        #[cfg(not(debug_assertions))]
        if let Self::Lir(lir) = self {
            lir
        } else {
            panic!()
        }
    }
}

/// State shared by the downgrade, resolve, evaluate and JIT stages.
pub struct Context {
    // Expression table. `Default` seeds it with the builtin constants and
    // global primops; `new_expr` appends to it. It is indexed by
    // `ExprId::raw()`.
    irs: Vec>,
    // Receives one `false` per `new_expr` call; never read in this file.
    resolved: Vec,
    // Scope stack walked innermost-first by `ResolveContext::lookup`.
    scopes: Vec,
    // Incremented and decremented around `with_param_env`; never read in this
    // file.
    args_count: usize,
    // Builtin function pointers (global entries first, then scoped ones),
    // indexed by `PrimOpId::raw()` in `call_primop`.
    primops: Vec) -> Result>,
    // Filled by `new_func`: function body expression id -> its `Param`.
    funcs: HashMap,
    // Dependency graph: one node per `new_expr` call, one edge per `new_dep`.
    graph: DiGraph,
    // Graph node handle pushed by each `new_expr` call.
    // NOTE(review): `new_dep` indexes this by the raw expression id, but
    // `Default` seeds `irs` with builtins without pushing matching entries
    // here, so ids and positions may be offset from each other - confirm.
    nodes: Vec,
    // Argument frames: pushed by `with_args_env`, read by both `lookup_arg`
    // implementations.
    stack: Vec>,
    // `with` namespaces: searched innermost-first by `lookup_with`.
    with_scopes: Vec>>,
    // Built with `JITCompiler::new()`; not otherwise used in this file.
    jit: JITCompiler,
    // Receives one empty `OnceCell` per `new_expr` call; never read in this
    // file.
    compiled: Vec>>,
}

impl Default for Context {
    /// Builds a context whose only scope is the global scope derived from
    /// `Builtins::new()`.
    fn default() -> Self {
        let Builtins {
            consts,
            global,
            scoped,
        } = Builtins::new();
        // Global name table: constant `i` gets id `i`, global builtin `i`
        // gets id `i + CONSTS_LEN`, and the name "builtins" gets id
        // `CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN`.
        // NOTE(review): `irs` below is seeded from `consts` and `global` only,
        // so no slot for the "builtins" id is created in this function -
        // confirm it is created elsewhere before that id is dereferenced.
        let global_scope = Scope::Let(
            consts
                .iter()
                .enumerate()
                .map(|(id, (k, _))| (k.to_string(), unsafe { ExprId::from(id) }))
                .chain(global.iter().enumerate().map(|(idx, (k, _, _))| {
                    (k.to_string(), unsafe { ExprId::from(idx + CONSTS_LEN) })
                }))
                .chain(core::iter::once(("builtins".to_string(), unsafe {
                    ExprId::from(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN)
                })))
                .collect(),
        );
        // Function pointers of the global builtins followed by the scoped
        // ones, in iteration order.
        let primops = global
            .iter()
            .map(|&(_, _, f)| f)
            .chain(scoped.iter().map(|&(_, _, f)| f))
            .collect();
        // Seed the expression table: every constant as `Lir::Const`, then
        // every global builtin as `Lir::PrimOp` whose `PrimOpId` is its
        // position within `global`.
        let irs = consts
            .into_iter()
            .map(|(_, val)| Ir::Lir(Lir::Const(Const { val })))
            .chain(
                global
                    .into_iter()
                    .enumerate()
                    .map(|(idx, (name, arity, _))| {
                        Ir::Lir(Lir::PrimOp(PrimOp {
                            name,
                            arity,
                            id: unsafe { PrimOpId::from(idx) },
                        }))
                    }),
            )
            .map(RefCell::new)
            .collect();
        Self {
            irs,
            resolved: Vec::new(),
            scopes: vec![global_scope],
            args_count: 0,
            primops,
            funcs: HashMap::new(),
            graph: DiGraph::new(),
            nodes: Vec::new(),
            stack: Vec::new(),
            with_scopes: Vec::new(),
            jit: JITCompiler::new(),
            compiled: Vec::new(),
        }
    }
}

impl Context {
    /// Creates a context; identical to `Context::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Parses `expr`, downgrades and resolves the root expression, evaluates
    /// it, and converts the result with `to_public`.
    ///
    /// # Errors
    ///
    /// Returns `Error::ParseError` holding all parser messages joined with
    /// `";"` when parsing reports any error; errors from downgrading,
    /// resolving and evaluating are propagated with `?`.
    ///
    /// # Panics
    ///
    /// Panics through `unwrap()` if the parsed tree yields no root expression.
    pub fn eval(mut self, expr: &str) -> Result {
        let root = rnix::Root::parse(expr);
        if !root.errors().is_empty() {
            return Err(Error::ParseError(
                root.errors().iter().map(|err| err.to_string()).join(";"),
            ));
        }
        let root = root.tree().expr().unwrap().downgrade(&mut self)?;
        self.resolve(&root)?;
        Ok(EvalContext::eval(&mut self, &root)?.to_public(&mut HashSet::new()))
    }
}

impl DowngradeContext for Context {
    /// Appends `expr` to the expression table and returns its id, which is
    /// the table length before the push. Also adds a graph node for it and
    /// pushes matching `nodes`, `resolved` and `compiled` entries.
    fn new_expr(&mut self, expr: Hir) -> ExprId {
        let id = unsafe { ExprId::from(self.irs.len()) };
        self.irs.push(Ir::Hir(expr).into());
        self.nodes.push(self.graph.add_node(unsafe { id.clone() }));
        self.resolved.push(false);
        self.compiled.push(OnceCell::new());
        id
    }

    /// Runs `f` with shared access to the `Hir` stored under `id`.
    ///
    /// The slot index is bounds-checked, but the slot is assumed to hold
    /// `Ir::Hir` (see `unwrap_hir_ref_unchecked`).
    fn with_expr(&self, id: ExprId, f: impl FnOnce(&Hir, &Self) -> T) -> T {
        unsafe {
            let idx = id.raw();
            f(&self.irs[idx].borrow().unwrap_hir_ref_unchecked(), self)
        }
    }

    /// Runs `f` with mutable access to the `Hir` stored under `id` together
    /// with a mutable handle to the context.
    ///
    /// The index is not bounds-checked and the slot is assumed to hold
    /// `Ir::Hir`.
    fn with_expr_mut(&mut self, id: &ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T {
        unsafe {
            let idx = id.clone().raw();
            // NOTE(review): `self_mut` is a second `&mut Self` made from a raw
            // pointer; it aliases `self`, which is still used below to reach
            // the slot that stays mutably borrowed while `f` runs.
            let self_mut = &mut *(self as *mut Self);
            f(
                &mut self
                    .irs
                    .get_unchecked_mut(idx)
                    .borrow_mut()
                    .unwrap_hir_mut_unchecked(),
                self_mut,
            )
        }
    }
}

impl ResolveContext for Context {
    /// Looks `name` up through the scope stack, innermost frame first.
    ///
    /// A hit in a `Let` frame yields `LookupResult::Expr`; a matching `Arg`
    /// frame yields `LookupResult::Arg` with the number of non-matching `Arg`
    /// frames passed on the way; a `With` frame ends the search with
    /// `LookupResult::Unknown`; running out of frames yields
    /// `LookupResult::NotFound`.
    fn lookup(&self, name: &str) -> LookupResult {
        let mut arg_idx = 0;
        for scope in self.scopes.iter().rev() {
            match scope {
                Scope::Let(scope) => {
                    if let Some(expr) = scope.get(name) {
                        return LookupResult::Expr(unsafe { expr.clone() });
                    }
                }
                Scope::Arg(ident) => {
                    if ident.as_deref() == Some(name) {
                        return LookupResult::Arg(unsafe { ArgIdx::from(arg_idx) });
                    }
                    arg_idx += 1;
                }
                Scope::With => return LookupResult::Unknown,
            }
        }
        LookupResult::NotFound
    }

    /// Adds a graph edge from the node of `expr` to the node of `dep`.
    ///
    /// NOTE(review): both ids index `nodes` with `get_unchecked`; an id
    /// without a matching `nodes` entry (see the note on the `nodes` field) is
    /// an out-of-bounds read.
    fn new_dep(&mut self, expr: &ExprId, dep: ExprId) {
        unsafe {
            let expr = expr.clone().raw();
            let dep = dep.raw();
            let expr = *self.nodes.get_unchecked(expr);
            let dep = *self.nodes.get_unchecked(dep);
            self.graph.add_edge(expr, dep, ());
        }
    }

    /// Resolves the `Hir` stored under `expr` and stores the resulting `Lir`
    /// in the same slot.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `Hir::resolve`; in that case the slot
    /// is left holding a null-constant `Hir`.
    fn resolve(&mut self, expr: &ExprId) -> Result<()> {
        unsafe {
            let idx = expr.clone().raw();
            // Second mutable handle to the context, handed to `Hir::resolve`
            // while the slot itself is mutably borrowed through its `RefCell`.
            let self_mut = &mut *(self as *mut Self);
            replace_with_and_return(
                &mut *self.irs.get_unchecked(idx).borrow_mut(),
                // Fallback value for the slot - presumably used by
                // `replace_with` when the closure below panics; confirm
                // against the crate documentation.
                || {
                    Ir::Hir(Hir::Const(Const {
                        val: nixjit_value::Const::Null,
                    }))
                },
                // Takes the slot's value by ownership and returns the call
                // result paired with the value to store back.
                |ir| {
                    let hir = ir.unwrap_hir_unchecked();
                    match hir.resolve(self_mut) {
                        Ok(lir) => (Ok(()), Ir::Lir(lir)),
                        Err(err) => (
                            Err(err),
                            Ir::Hir(Hir::Const(Const {
                                val: nixjit_value::Const::Null,
                            })),
                        ),
                    }
                },
            )?;
        }
        Ok(())
    }

    /// Records `param` as the parameter of the function whose body is `body`.
    fn new_func(&mut self, body: &ExprId, param: Param) {
        self.funcs.insert(unsafe { body.clone() }, param);
    }

    /// Runs `f` with an extra `Let` frame holding clones of `bindings`, then
    /// pops that frame and returns the result of `f`.
    fn with_let_env<'a, T>(
        &mut self,
        bindings: impl Iterator,
        f: impl FnOnce(&mut Self) -> T,
    ) -> T {
        let mut scope = HashMap::new();
        for (name, expr) in bindings {
            scope.insert(name.clone(), unsafe { expr.clone() });
        }
        self.scopes.push(Scope::Let(scope));
        let res = f(self);
        self.scopes.pop();
        res
    }

    /// Runs `f` with an extra `With` frame pushed, then pops it.
    ///
    /// The first element of the returned pair is always `true` in this
    /// implementation.
    fn with_with_env(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) {
        self.scopes.push(Scope::With);
        let res = f(self);
        self.scopes.pop();
        (true, res)
    }

    /// Runs `f` with an extra `Arg` frame for `ident` pushed and `args_count`
    /// raised by one, then restores both.
    fn with_param_env(&mut self, ident: Option, f: impl FnOnce(&mut Self) -> T) -> T {
        self.scopes.push(Scope::Arg(ident));
        self.args_count += 1;
        let res = f(self);
        self.args_count -= 1;
        self.scopes.pop();
        res
    }
}

impl EvalContext for Context {
    /// Evaluates the `Lir` stored under `expr`.
    ///
    /// The index is not bounds-checked and the slot is assumed to hold
    /// `Ir::Lir`.
    fn eval(&mut self, expr: &ExprId) -> Result {
        let idx = unsafe { expr.clone().raw() };
        // The reference is taken through a `*const Lir` cast so that it no
        // longer borrows `self`, which lets `self` be passed mutably to
        // `lir.eval` below.
        // NOTE(review): nothing visible here keeps the slot from being
        // borrowed mutably or replaced while `lir` is in use.
        let lir = unsafe {
            &*(self
                .irs
                .get_unchecked(idx)
                .borrow()
                .unwrap_lir_ref_unchecked() as *const Lir)
        };
        // NOTE(review): this pretty-prints the whole expression table on
        // every evaluation; it looks like leftover debugging output.
        println!("{:#?}", self.irs);
        lir.eval(self)
    }

    /// Pops and returns the top argument frame.
    ///
    /// # Panics
    ///
    /// Panics if the frame stack is empty.
    fn pop_frame(&mut self) -> Vec {
        self.stack.pop().unwrap()
    }

    /// Searches the `with` namespaces, innermost first, for `ident` and
    /// returns the first match.
    fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a nixjit_eval::Value> {
        for scope in self.with_scopes.iter().rev() {
            if let Some(val) = scope.get(ident) {
                return Some(val);
            }
        }
        None
    }

    /// Returns the argument `idx` positions from the end of the top frame
    /// (`idx` 0 is the last element).
    ///
    /// The frame stack is assumed to be non-empty (`unwrap_unchecked`); the
    /// element access itself is bounds-checked.
    fn lookup_arg<'a>(&'a self, idx: ArgIdx) -> &'a Value {
        unsafe {
            let values = self.stack.last().unwrap_unchecked();
            // NOTE(review): `dbg!` writes to stderr on every argument lookup;
            // it looks like leftover debugging output.
            dbg!(values, idx);
            &values[values.len() - idx.raw() - 1]
        }
    }

    /// Runs `f` with `namespace` pushed onto the `with` stack, then pops it.
    fn with_with_env(
        &mut self,
        namespace: std::rc::Rc>,
        f: impl FnOnce(&mut Self) -> T,
    ) -> T {
        self.with_scopes.push(namespace);
        let res = f(self);
        self.with_scopes.pop();
        res
    }

    /// Runs `f` with `args` pushed as a new frame, then pops the top frame and
    /// returns it together with the result of `f`.
    ///
    /// # Panics
    ///
    /// Panics if the frame stack is empty after `f` returns.
    fn with_args_env(
        &mut self,
        args: Vec,
        f: impl FnOnce(&mut Self) -> T,
    ) -> (Vec, T) {
        self.stack.push(args);
        let res = f(self);
        let frame = self.stack.pop().unwrap();
        (frame, res)
    }

    /// Calls the builtin stored at position `id.raw()` of `primops` with
    /// `args`. The index is not bounds-checked.
    fn call_primop(&mut self, id: nixjit_ir::PrimOpId, args: Vec) -> Result {
        unsafe { (self.primops.get_unchecked(id.raw()))(self, args) }
    }
}

impl JITContext for Context {
    /// Returns the argument `offset` positions from the end of the top frame.
    ///
    /// # Panics
    ///
    /// Panics if the frame stack is empty or the computed index is out of
    /// range.
    fn lookup_arg(&self, offset: usize) -> &nixjit_eval::Value {
        let values = self.stack.last().unwrap();
        &values[values.len() - offset - 1]
    }

    /// Not implemented yet; calling it panics through `todo!()`.
    fn lookup_stack(&self, offset: usize) -> &nixjit_eval::Value {
        todo!()
    }

    /// Pushes `namespace` onto the `with` stack.
    fn enter_with(&mut self, namespace: std::rc::Rc>) {
        self.with_scopes.push(namespace);
    }

    /// Pops the innermost `with` namespace; a no-op on an empty stack.
    fn exit_with(&mut self) {
        self.with_scopes.pop();
    }
}

// Marker implementation: no methods are overridden here.
impl BuiltinsContext for Context {}