Files
nixjit/evaluator/nixjit_context/src/lib.rs
2025-08-15 23:14:21 +08:00

393 lines
11 KiB
Rust

use std::cell::{OnceCell, RefCell};
use std::rc::Rc;
use derive_more::Unwrap;
use hashbrown::{HashMap, HashSet};
use itertools::Itertools;
use petgraph::graph::{DiGraph, NodeIndex};
use nixjit_builtins::{
Builtins, BuiltinsContext,
builtins::{CONSTS_LEN, GLOBAL_LEN, SCOPED_LEN},
};
use nixjit_error::{Error, Result};
use nixjit_eval::{EvalContext, Evaluate, Value};
use nixjit_hir::{Downgrade, DowngradeContext, Hir};
use nixjit_ir::{ArgIdx, Const, ExprId, Param, PrimOp, PrimOpId};
use nixjit_lir::{Lir, LookupResult, Resolve, ResolveContext};
use nixjit_jit::{JITCompiler, JITContext, JITFunc};
use replace_with::replace_with_and_return;
/// One level of the lexical environment tracked while resolving names.
///
/// `Context::scopes` holds these as a stack: the bottom entry is the global
/// builtin scope created in `Default`, and the `with_*_env` methods of
/// `ResolveContext` push/pop the others around the closure they run.
enum Scope {
/// A `with` expression. Its namespace only exists at evaluation time
/// (see `EvalContext::lookup_with`), so the resolver cannot bind names to it statically.
With,
/// Names bound by a `let`-style construct (or the builtin global scope),
/// mapped to the expression ids that define them.
Let(HashMap<String, ExprId>),
/// A function parameter. `None` when the parameter has no identifier —
/// presumably a pattern-only lambda; TODO confirm against the HIR lowering.
Arg(Option<String>),
}
/// Contents of one expression slot in `Context::irs`.
///
/// Expressions created during downgrading start as `Hir`; `ResolveContext::resolve`
/// replaces the slot with the lowered `Lir`. Builtins are stored as `Lir` from the start.
#[derive(Debug, Unwrap)]
enum Ir {
/// High-level IR, not yet resolved.
Hir(Hir),
/// Resolved, low-level IR ready for evaluation.
Lir(Lir),
}
impl Ir {
    /// Borrows the HIR held by this slot.
    ///
    /// # Safety
    ///
    /// `self` must be [`Ir::Hir`]. Debug builds panic when it is not; release
    /// builds treat a violation as undefined behaviour.
    unsafe fn unwrap_hir_ref_unchecked(&self) -> &Hir {
        match self {
            Self::Hir(hir) => hir,
            // SAFETY: the caller guarantees this arm is never taken.
            Self::Lir(_) => unsafe { Self::wrong_variant("Hir") },
        }
    }

    /// Mutably borrows the HIR held by this slot.
    ///
    /// # Safety
    ///
    /// `self` must be [`Ir::Hir`]. Debug builds panic when it is not; release
    /// builds treat a violation as undefined behaviour.
    unsafe fn unwrap_hir_mut_unchecked(&mut self) -> &mut Hir {
        match self {
            Self::Hir(hir) => hir,
            // SAFETY: the caller guarantees this arm is never taken.
            Self::Lir(_) => unsafe { Self::wrong_variant("Hir") },
        }
    }

    /// Consumes the slot and returns its HIR.
    ///
    /// # Safety
    ///
    /// `self` must be [`Ir::Hir`]. Debug builds panic when it is not; release
    /// builds treat a violation as undefined behaviour.
    unsafe fn unwrap_hir_unchecked(self) -> Hir {
        match self {
            Self::Hir(hir) => hir,
            // SAFETY: the caller guarantees this arm is never taken.
            Self::Lir(_) => unsafe { Self::wrong_variant("Hir") },
        }
    }

    /// Borrows the LIR held by this slot.
    ///
    /// # Safety
    ///
    /// `self` must be [`Ir::Lir`]. Debug builds panic when it is not; release
    /// builds treat a violation as undefined behaviour.
    unsafe fn unwrap_lir_ref_unchecked(&self) -> &Lir {
        match self {
            Self::Lir(lir) => lir,
            // SAFETY: the caller guarantees this arm is never taken.
            Self::Hir(_) => unsafe { Self::wrong_variant("Lir") },
        }
    }

    /// Shared failure path for the `*_unchecked` accessors above.
    ///
    /// In debug builds it panics, naming the variant the caller expected. This is the
    /// safety net that lets release builds skip the check. In release builds it is
    /// `unreachable_unchecked`, so the optimiser can drop the branch entirely.
    ///
    /// # Safety
    ///
    /// Must only be reachable when a caller's "this slot holds variant X"
    /// precondition was violated.
    #[track_caller]
    unsafe fn wrong_variant(expected: &str) -> ! {
        if cfg!(debug_assertions) {
            panic!("expected Ir::{expected}, found the other variant")
        } else {
            // SAFETY: forwarded from the caller's contract.
            unsafe { core::hint::unreachable_unchecked() }
        }
    }
}
/// Shared state for downgrading, resolving and evaluating one Nix expression.
///
/// Expressions are addressed by `ExprId`, whose raw value indexes `irs`;
/// `new_expr` also appends one entry to `nodes`, `resolved` and `compiled` for each
/// new expression.
pub struct Context {
/// Expression slots, indexed by the raw `ExprId`: HIR until resolved, LIR afterwards.
irs: Vec<RefCell<Ir>>,
/// Per-expression flag pushed as `false` by `new_expr`.
/// NOTE(review): not read or set anywhere else in this file — confirm intended use.
resolved: Vec<bool>,
/// Lexical scope stack used during resolution; the bottom entry is the builtin scope.
scopes: Vec<Scope>,
/// Number of enclosing parameter scopes during resolution (maintained by `with_param_env`).
args_count: usize,
/// Global primops followed by scoped primops; indexed by the raw `PrimOpId` in `call_primop`.
primops: Vec<fn(&mut Context, Vec<Value>) -> Result<Value>>,
/// Parameter description for each function, keyed by the function body's id (`new_func`).
funcs: HashMap<ExprId, Param>,
/// Dependency graph between expressions; `new_dep` adds an edge `expr -> dep`.
graph: DiGraph<ExprId, ()>,
/// Graph node for each expression, indexed by the raw `ExprId` in `new_dep`.
nodes: Vec<NodeIndex>,
/// Argument frames for active calls, newest last (`with_args_env` pushes/pops).
stack: Vec<Vec<Value>>,
/// Namespaces of enclosing `with` expressions during evaluation, innermost last.
with_scopes: Vec<Rc<HashMap<String, Value>>>,
/// JIT compiler. NOTE(review): not used in this file.
jit: JITCompiler<Self>,
/// Per-expression slot for a compiled JIT function, presumably filled lazily
/// (one `OnceCell` pushed by `new_expr`) — TODO confirm.
compiled: Vec<OnceCell<JITFunc<Self>>>,
}
impl Default for Context {
fn default() -> Self {
let Builtins {
consts,
global,
scoped,
} = Builtins::new();
let global_scope = Scope::Let(
consts
.iter()
.enumerate()
.map(|(id, (k, _))| (k.to_string(), unsafe { ExprId::from(id) }))
.chain(global.iter().enumerate().map(|(idx, (k, _, _))| {
(k.to_string(), unsafe {
ExprId::from(idx + CONSTS_LEN)
})
}))
.chain(core::iter::once(("builtins".to_string(), unsafe {
ExprId::from(CONSTS_LEN + GLOBAL_LEN + SCOPED_LEN)
})))
.collect(),
);
let primops = global
.iter()
.map(|&(_, _, f)| f)
.chain(scoped.iter().map(|&(_, _, f)| f))
.collect();
let irs = consts
.into_iter()
.map(|(_, val)| Ir::Lir(Lir::Const(Const { val })))
.chain(
global
.into_iter()
.enumerate()
.map(|(idx, (name, arity, _))| {
Ir::Lir(Lir::PrimOp(PrimOp {
name,
arity,
id: unsafe { PrimOpId::from(idx) },
}))
}),
)
.map(RefCell::new)
.collect();
Self {
irs,
resolved: Vec::new(),
scopes: vec![global_scope],
args_count: 0,
primops,
funcs: HashMap::new(),
graph: DiGraph::new(),
nodes: Vec::new(),
stack: Vec::new(),
with_scopes: Vec::new(),
jit: JITCompiler::new(),
compiled: Vec::new(),
}
}
}
impl Context {
pub fn new() -> Self {
Self::default()
}
pub fn eval(mut self, expr: &str) -> Result<nixjit_value::Value> {
let root = rnix::Root::parse(expr);
if !root.errors().is_empty() {
return Err(Error::ParseError(
root.errors().iter().map(|err| err.to_string()).join(";"),
));
}
let root = root.tree().expr().unwrap().downgrade(&mut self)?;
self.resolve(&root)?;
Ok(EvalContext::eval(&mut self, &root)?.to_public(&mut HashSet::new()))
}
}
impl DowngradeContext for Context {
fn new_expr(&mut self, expr: Hir) -> ExprId {
let id = unsafe { ExprId::from(self.irs.len()) };
self.irs.push(Ir::Hir(expr).into());
self.nodes.push(self.graph.add_node(unsafe { id.clone() }));
self.resolved.push(false);
self.compiled.push(OnceCell::new());
id
}
fn with_expr<T>(&self, id: ExprId, f: impl FnOnce(&Hir, &Self) -> T) -> T {
unsafe {
let idx = id.raw();
f(&self.irs[idx].borrow().unwrap_hir_ref_unchecked(), self)
}
}
fn with_expr_mut<T>(&mut self, id: &ExprId, f: impl FnOnce(&mut Hir, &mut Self) -> T) -> T {
unsafe {
let idx = id.clone().raw();
let self_mut = &mut *(self as *mut Self);
f(
&mut self
.irs
.get_unchecked_mut(idx)
.borrow_mut()
.unwrap_hir_mut_unchecked(),
self_mut,
)
}
}
}
impl ResolveContext for Context {
fn lookup(&self, name: &str) -> LookupResult {
let mut arg_idx = 0;
for scope in self.scopes.iter().rev() {
match scope {
Scope::Let(scope) => {
if let Some(expr) = scope.get(name) {
return LookupResult::Expr(unsafe { expr.clone() });
}
}
Scope::Arg(ident) => {
if ident.as_deref() == Some(name) {
return LookupResult::Arg(unsafe { ArgIdx::from(arg_idx) });
}
arg_idx += 1;
}
Scope::With => return LookupResult::Unknown,
}
}
LookupResult::NotFound
}
fn new_dep(&mut self, expr: &ExprId, dep: ExprId) {
unsafe {
let expr = expr.clone().raw();
let dep = dep.raw();
let expr = *self.nodes.get_unchecked(expr);
let dep = *self.nodes.get_unchecked(dep);
self.graph.add_edge(expr, dep, ());
}
}
fn resolve(&mut self, expr: &ExprId) -> Result<()> {
unsafe {
let idx = expr.clone().raw();
let self_mut = &mut *(self as *mut Self);
replace_with_and_return(
&mut *self.irs.get_unchecked(idx).borrow_mut(),
|| {
Ir::Hir(Hir::Const(Const {
val: nixjit_value::Const::Null,
}))
},
|ir| {
let hir = ir.unwrap_hir_unchecked();
match hir.resolve(self_mut) {
Ok(lir) => (Ok(()), Ir::Lir(lir)),
Err(err) => (
Err(err),
Ir::Hir(Hir::Const(Const {
val: nixjit_value::Const::Null,
})),
),
}
},
)?;
}
Ok(())
}
fn new_func(&mut self, body: &ExprId, param: Param) {
self.funcs.insert(unsafe { body.clone() }, param);
}
fn with_let_env<'a, T>(
&mut self,
bindings: impl Iterator<Item = (&'a String, &'a ExprId)>,
f: impl FnOnce(&mut Self) -> T,
) -> T {
let mut scope = HashMap::new();
for (name, expr) in bindings {
scope.insert(name.clone(), unsafe { expr.clone() });
}
self.scopes.push(Scope::Let(scope));
let res = f(self);
self.scopes.pop();
res
}
fn with_with_env<T>(&mut self, f: impl FnOnce(&mut Self) -> T) -> (bool, T) {
self.scopes.push(Scope::With);
let res = f(self);
self.scopes.pop();
(true, res)
}
fn with_param_env<T>(&mut self, ident: Option<String>, f: impl FnOnce(&mut Self) -> T) -> T {
self.scopes.push(Scope::Arg(ident));
self.args_count += 1;
let res = f(self);
self.args_count -= 1;
self.scopes.pop();
res
}
}
impl EvalContext for Context {
fn eval(&mut self, expr: &ExprId) -> Result<nixjit_eval::Value> {
let idx = unsafe { expr.clone().raw() };
let lir = unsafe {
&*(self
.irs
.get_unchecked(idx)
.borrow()
.unwrap_lir_ref_unchecked() as *const Lir)
};
println!("{:#?}", self.irs);
lir.eval(self)
}
fn pop_frame(&mut self) -> Vec<nixjit_eval::Value> {
self.stack.pop().unwrap()
}
fn lookup_with<'a>(&'a self, ident: &str) -> Option<&'a nixjit_eval::Value> {
for scope in self.with_scopes.iter().rev() {
if let Some(val) = scope.get(ident) {
return Some(val);
}
}
None
}
fn lookup_arg<'a>(&'a self, idx: ArgIdx) -> &'a Value {
unsafe {
let values = self.stack.last().unwrap_unchecked();
dbg!(values, idx);
&values[values.len() - idx.raw() - 1]
}
}
fn with_with_env<T>(
&mut self,
namespace: std::rc::Rc<HashMap<String, nixjit_eval::Value>>,
f: impl FnOnce(&mut Self) -> T,
) -> T {
self.with_scopes.push(namespace);
let res = f(self);
self.with_scopes.pop();
res
}
fn with_args_env<T>(
&mut self,
args: Vec<nixjit_eval::Value>,
f: impl FnOnce(&mut Self) -> T,
) -> (Vec<Value>, T) {
self.stack.push(args);
let res = f(self);
let frame = self.stack.pop().unwrap();
(frame, res)
}
fn call_primop(&mut self, id: nixjit_ir::PrimOpId, args: Vec<Value>) -> Result<Value> {
unsafe {
(self.primops.get_unchecked(id.raw()))(self, args)
}
}
}
impl JITContext for Context {
    /// Fetches the value `offset` slots back from the newest entry of the current
    /// argument frame (`offset == 0` is the frame's last value).
    ///
    /// Panics if there is no frame or `offset` exceeds the frame.
    fn lookup_arg(&self, offset: usize) -> &nixjit_eval::Value {
        let frame = self.stack.last().unwrap();
        let newest = frame.len() - 1;
        &frame[newest - offset]
    }

    /// Not implemented yet; always panics.
    fn lookup_stack(&self, _offset: usize) -> &nixjit_eval::Value {
        todo!()
    }

    /// Makes `namespace` the innermost `with` scope for JIT-compiled code.
    fn enter_with(&mut self, namespace: std::rc::Rc<HashMap<String, nixjit_eval::Value>>) {
        self.with_scopes.push(namespace)
    }

    /// Drops the innermost `with` scope, if any.
    fn exit_with(&mut self) {
        let _ = self.with_scopes.pop();
    }
}
// `Context` relies entirely on the trait's provided items; no methods are overridden here.
impl BuiltinsContext for Context {}