From 2a19ddb279a2fb4565b3737cacdcf85e09a7fc5a Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Wed, 21 May 2025 20:48:56 +0800 Subject: [PATCH] feat: no clone in JIT IMPORTANT: should not drop or create values in JIT anymore --- src/bin/eval.rs | 6 ++-- src/builtins/mod.rs | 2 +- src/jit/helpers.rs | 19 +++-------- src/jit/mod.rs | 41 +++++++++++++++++------ src/stack.rs | 5 +-- src/ty/internal/attrset.rs | 18 +++++------ src/ty/internal/func.rs | 2 +- src/ty/internal/mod.rs | 66 +++++++++++++++++++++----------------- src/vm/env.rs | 14 +++++--- src/vm/mod.rs | 10 ++++-- src/vm/test.rs | 4 +-- 11 files changed, 105 insertions(+), 82 deletions(-) diff --git a/src/bin/eval.rs b/src/bin/eval.rs index 99462ef..bbc8db5 100644 --- a/src/bin/eval.rs +++ b/src/bin/eval.rs @@ -5,10 +5,10 @@ use itertools::Itertools; use nixjit::compile::compile; use nixjit::error::Error; +use nixjit::error::Result; use nixjit::ir::downgrade; use nixjit::jit::JITContext; use nixjit::vm::run; -use nixjit::error::Result; fn main() -> Result<()> { let mut args = std::env::args(); @@ -21,8 +21,8 @@ fn main() -> Result<()> { let root = rnix::Root::parse(&expr); if !root.errors().is_empty() { return Err(Error::ParseError( - root.errors().iter().map(|err| err.to_string()).join(";") - )) + root.errors().iter().map(|err| err.to_string()).join(";"), + )); } let expr = root.tree().expr().unwrap(); let downgraded = downgrade(expr)?; diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index 877f833..8b3c910 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use crate::ty::common::Const; use crate::ty::internal::{AttrSet, PrimOp, Value}; -use crate::vm::{VmEnv, VM}; +use crate::vm::{VM, VmEnv}; pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> VmEnv<'jit, 'vm> { let mut env_map = HashMap::new(); diff --git a/src/jit/helpers.rs b/src/jit/helpers.rs index 2dc10df..7526116 100644 --- a/src/jit/helpers.rs +++ b/src/jit/helpers.rs @@ -9,7 +9,7 @@ use inkwell::values::{BasicValueEnum, FunctionValue}; use crate::jit::JITValueData; use crate::ty::internal::{Thunk, Value}; -use crate::vm::{VmEnv, VM}; +use crate::vm::{VM, VmEnv}; use super::{JITValue, ValueTag}; @@ -204,20 +204,17 @@ impl<'ctx> Helpers<'ctx> { } } -#[unsafe(no_mangle)] extern "C" fn helper_debug(value: JITValue) { dbg!(value.tag); } -#[unsafe(no_mangle)] extern "C" fn helper_capture_env(thunk: JITValue, env: *const VmEnv) { - let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) }; + let thunk = unsafe { (thunk.data.ptr as *const Thunk).as_ref().unwrap() }; let env = unsafe { Rc::from_raw(env) }; thunk.capture(env.clone()); std::mem::forget(env); } -#[unsafe(no_mangle)] extern "C" fn helper_neg(rhs: JITValue, _env: *const VmEnv) -> JITValue { use ValueTag::*; match rhs.tag { @@ -237,7 +234,6 @@ extern "C" fn helper_neg(rhs: JITValue, _env: *const VmEnv) -> JITValue { } } -#[unsafe(no_mangle)] extern "C" fn helper_not(rhs: JITValue, _env: *const VmEnv) -> JITValue { use ValueTag::*; match rhs.tag { @@ -251,7 +247,6 @@ extern "C" fn helper_not(rhs: JITValue, _env: *const VmEnv) -> JITValue { } } -#[unsafe(no_mangle)] extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue { use ValueTag::*; match (lhs.tag, rhs.tag) { @@ -269,7 +264,6 @@ extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue { } } -#[unsafe(no_mangle)] extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue { use ValueTag::*; match (lhs.tag, rhs.tag) { @@ -287,7 +281,6 @@ extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue { } } -#[unsafe(no_mangle)] extern "C" fn helper_eq(lhs: JITValue, rhs: JITValue) -> JITValue { use ValueTag::*; match (lhs.tag, rhs.tag) { @@ -305,7 +298,6 @@ extern "C" fn helper_eq(lhs: JITValue, rhs: JITValue) -> JITValue { } } -#[unsafe(no_mangle)] extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue { use ValueTag::*; match (lhs.tag, rhs.tag) { @@ -323,7 +315,6 @@ extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue { } } -#[unsafe(no_mangle)] extern "C" fn helper_call<'jit>( func: JITValue, args: *mut JITValue, @@ -346,14 +337,12 @@ extern "C" fn helper_call<'jit>( } } -#[unsafe(no_mangle)] extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const VmEnv<'jit, 'vm>) -> JITValue { let env = unsafe { env.as_ref() }.unwrap(); - let val = env.lookup(&sym); - val.cloned().unwrap().into() + let val: JITValue = env.lookup(&sym).unwrap().into(); + val } -#[unsafe(no_mangle)] extern "C" fn helper_force<'jit>(thunk: JITValue, vm: *const VM<'jit>) -> JITValue { let mut val = Value::from(thunk); val.force(unsafe { vm.as_ref() }.unwrap()).unwrap(); diff --git a/src/jit/mod.rs b/src/jit/mod.rs index 99cf402..5730856 100644 --- a/src/jit/mod.rs +++ b/src/jit/mod.rs @@ -12,7 +12,7 @@ use crate::error::*; use crate::stack::Stack; use crate::ty::common::Const; use crate::ty::internal::{Thunk, Value}; -use crate::vm::{VmEnv, VM}; +use crate::vm::{VM, VmEnv}; mod helpers; @@ -57,6 +57,12 @@ pub union JITValueData { impl<'jit: 'vm, 'vm> From for Value<'jit, 'vm> { fn from(value: JITValue) -> Self { use ValueTag::*; + match value.tag { + List | AttrSet | String | Function | Thunk | Path => unsafe { + Rc::increment_strong_count(value.data.ptr); + }, + _ => (), + } match value.tag { Int => Value::Const(Const::Int(unsafe { value.data.int })), Null => Value::Const(Const::Null), @@ -67,6 +73,30 @@ impl<'jit: 'vm, 'vm> From for Value<'jit, 'vm> { } } +impl From<&Value<'_, '_>> for JITValue { + fn from(value: &Value<'_, '_>) -> Self { + match value { + Value::Const(Const::Int(int)) => JITValue { + tag: ValueTag::Int, + data: JITValueData { int: *int }, + }, + Value::Func(func) => JITValue { + tag: ValueTag::Function, + data: JITValueData { + ptr: Rc::as_ptr(func) as *const _, + }, + }, + Value::Thunk(thunk) => JITValue { + tag: ValueTag::Thunk, + data: JITValueData { + ptr: Rc::as_ptr(thunk) as *const _, + }, + }, + _ => todo!(), + } + } +} + impl From> for JITValue { fn from(value: Value<'_, '_>) -> Self { match value { @@ -144,14 +174,7 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { let env = func_.get_nth_param(1).unwrap().into_pointer_value(); let entry = self.context.append_basic_block(func_, "entry"); self.builder.position_at_end(entry); - self.build_expr( - &mut iter, - vm, - env, - &mut stack, - func_, - func.opcodes.len(), - )?; + self.build_expr(&mut iter, vm, env, &mut stack, func_, func.opcodes.len())?; assert_eq!(stack.len(), 1); let value = stack.pop(); diff --git a/src/stack.rs b/src/stack.rs index 6a7080f..2081006 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -35,10 +35,7 @@ impl Stack { pub fn push(&mut self, item: T) -> Result<()> { self.items .get_mut(self.top) - .map_or_else( - || Err(Error::EvalError("stack overflow".to_string())), - Ok, - )? + .map_or_else(|| Err(Error::EvalError("stack overflow".to_string())), Ok)? .write(item); self.top += 1; Ok(()) diff --git a/src/ty/internal/attrset.rs b/src/ty/internal/attrset.rs index b2a653c..687f172 100644 --- a/src/ty/internal/attrset.rs +++ b/src/ty/internal/attrset.rs @@ -6,7 +6,7 @@ use derive_more::Constructor; use itertools::Itertools; use crate::error::Result; -use crate::vm::{VmEnv, VM}; +use crate::vm::{VM, VmEnv}; use super::super::public as p; use super::Value; @@ -42,9 +42,9 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { } pub fn select(&self, sym: usize) -> Option> { - self.data.get(&sym).cloned().map(|val| match val { + self.data.get(&sym).map(|val| match val { Value::Builtins(x) => Value::AttrSet(x.upgrade().unwrap()), - val => val, + val => val.clone(), }) } @@ -53,8 +53,10 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { } pub fn capture(&mut self, env: &Rc>) { - self.data.iter().for_each(|(_, v)| if let Value::Thunk(ref thunk) = v.clone() { - thunk.capture(env.clone()); + self.data.iter().for_each(|(_, v)| { + if let Value::Thunk(ref thunk) = v.clone() { + thunk.capture(Rc::clone(env)); + } }) } @@ -77,11 +79,7 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { } pub fn force_deep(&mut self, vm: &'vm VM<'jit>) -> Result<()> { - let mut map: Vec<_> = self - .data - .iter() - .map(|(k, v)| (*k, v.clone())) - .collect(); + let mut map: Vec<_> = self.data.iter().map(|(k, v)| (*k, v.clone())).collect(); for (_, v) in map.iter_mut() { v.force_deep(vm)?; } diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index 1f88bd1..e9f909d 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -11,7 +11,7 @@ use crate::error::Result; use crate::ir; use crate::jit::JITFunc; use crate::ty::internal::{Thunk, Value}; -use crate::vm::{VmEnv, VM}; +use crate::vm::{VM, VmEnv}; #[derive(Debug, Clone)] pub enum Param { diff --git a/src/ty/internal/mod.rs b/src/ty/internal/mod.rs index 2b6ab5e..6415834 100644 --- a/src/ty/internal/mod.rs +++ b/src/ty/internal/mod.rs @@ -11,7 +11,7 @@ use super::public as p; use crate::bytecode::OpCodes; use crate::error::*; -use crate::vm::{VmEnv, VM}; +use crate::vm::{VM, VmEnv}; mod attrset; mod func; @@ -28,7 +28,6 @@ pub use primop::*; pub enum Value<'jit: 'vm, 'vm> { Const(Const), Thunk(Rc>), - ThunkRef(&'vm Thunk<'jit, 'vm>), AttrSet(Rc>), List(Rc>), Catchable(Catchable), @@ -45,7 +44,6 @@ impl Hash for Value<'_, '_> { match self { Const(x) => x.hash(state), Thunk(x) => (x.as_ref() as *const self::Thunk).hash(state), - ThunkRef(x) => (*x as *const self::Thunk).hash(state), AttrSet(x) => (x.as_ref() as *const self::AttrSet).hash(state), List(x) => (x.as_ref() as *const self::List).hash(state), Catchable(x) => x.hash(state), @@ -119,7 +117,6 @@ impl<'v, 'vm: 'v, 'jit: 'vm> Value<'jit, 'vm> { match self { Const(x) => R::Const(x), Thunk(x) => R::Thunk(x), - ThunkRef(x) => R::Thunk(x), AttrSet(x) => R::AttrSet(x), List(x) => R::List(x), Catchable(x) => R::Catchable(x), @@ -136,7 +133,6 @@ impl<'v, 'vm: 'v, 'jit: 'vm> Value<'jit, 'vm> { match self { Const(x) => M::Const(x), Thunk(x) => M::Thunk(x), - ThunkRef(x) => M::Thunk(x), AttrSet(x) => M::AttrSet(Rc::make_mut(x)), List(x) => M::List(Rc::make_mut(x)), Catchable(x) => M::Catchable(x), @@ -163,7 +159,6 @@ impl<'jit, 'vm> Value<'jit, 'vm> { Const(self::Const::String(_)) => "string", Const(self::Const::Null) => "null", Thunk(_) => "thunk", - ThunkRef(_) => "thunk", AttrSet(_) => "set", List(_) => "list", Catchable(_) => unreachable!(), @@ -420,10 +415,7 @@ impl<'jit, 'vm> Value<'jit, 'vm> { pub fn force(&mut self, vm: &'vm VM<'jit>) -> Result<&mut Self> { if let Value::Thunk(thunk) = self { - let value = thunk.force(vm)?; - *self = value - } else if let Value::ThunkRef(thunk) = self { - let value = thunk.force(vm)?; + let value = thunk.force(vm)?.clone(); *self = value } Ok(self) @@ -432,12 +424,7 @@ impl<'jit, 'vm> Value<'jit, 'vm> { pub fn force_deep(&mut self, vm: &'vm VM<'jit>) -> Result<&mut Self> { match self { Value::Thunk(thunk) => { - let mut value = thunk.force(vm)?; - let _ = value.force_deep(vm)?; - *self = value; - } - Value::ThunkRef(thunk) => { - let mut value = thunk.force(vm)?; + let mut value = thunk.force(vm)?.clone(); let _ = value.force_deep(vm)?; *self = value; } @@ -461,7 +448,6 @@ impl<'jit, 'vm> Value<'jit, 'vm> { Catchable(catchable) => Value::Catchable(catchable.clone()), Const(cnst) => Value::Const(cnst.clone()), Thunk(_) => Value::Thunk, - ThunkRef(_) => Value::Thunk, PrimOp(primop) => Value::PrimOp(primop.name), PartialPrimOp(primop) => Value::PartialPrimOp(primop.name), Func(_) => Value::Func, @@ -477,11 +463,17 @@ pub struct Thunk<'jit, 'vm> { #[derive(Debug, IsVariant, Unwrap, Clone)] pub enum _Thunk<'jit, 'vm> { - Code(&'vm OpCodes, OnceCell>>), + Code(&'vm OpCodes, OnceCell>), SuspendedFrom(*const Thunk<'jit, 'vm>), Value(Value<'jit, 'vm>), } +#[derive(Debug, IsVariant, Unwrap, Clone)] +pub enum EnvRef<'jit, 'vm> { + Strong(Rc>), + Weak(Weak>), +} + impl<'jit, 'vm> Thunk<'jit, 'vm> { pub fn new(opcodes: &'vm OpCodes) -> Self { Thunk { @@ -491,32 +483,48 @@ impl<'jit, 'vm> Thunk<'jit, 'vm> { pub fn capture(&self, env: Rc>) { if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() { - envcell.get_or_init(|| env); + envcell.get_or_init(|| EnvRef::Strong(env)); } } - pub fn force(&self, vm: &'vm VM<'jit>) -> Result> { + pub fn capture_weak(&self, env: Weak>) { + if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() { + envcell.get_or_init(|| EnvRef::Weak(env)); + } + } + + pub fn force(&self, vm: &'vm VM<'jit>) -> Result<&Value<'jit, 'vm>> { + use _Thunk::*; match &*self.thunk.borrow() { - _Thunk::Value(value) => return Ok(value.clone()), - _Thunk::SuspendedFrom(from) => { + Value(_) => { + return Ok(match unsafe { &*(&*self.thunk.borrow() as *const _) } { + Value(value) => value, + _ => unreachable!(), + }); + } + SuspendedFrom(from) => { return Err(Error::EvalError(format!( "thunk {:p} already suspended from {from:p} (infinite recursion encountered)", self as *const Thunk ))); } - _Thunk::Code(..) => (), + Code(..) => (), } let (opcodes, env) = std::mem::replace( &mut *self.thunk.borrow_mut(), _Thunk::SuspendedFrom(self as *const Thunk), ) .unwrap_code(); - let value = vm.eval(opcodes.iter().copied(), env.get().unwrap().clone())?; - let _ = std::mem::replace( - &mut *self.thunk.borrow_mut(), - _Thunk::Value(value.clone()), - ); - Ok(value) + let env = match env.get().unwrap() { + EnvRef::Strong(env) => env.clone(), + EnvRef::Weak(env) => env.upgrade().unwrap(), + }; + let value = vm.eval(opcodes.iter().copied(), env)?; + let _ = std::mem::replace(&mut *self.thunk.borrow_mut(), _Thunk::Value(value)); + Ok(match unsafe { &*(&*self.thunk.borrow() as *const _) } { + Value(value) => value, + _ => unreachable!(), + }) } pub fn value(&'vm self) -> Option> { diff --git a/src/vm/env.rs b/src/vm/env.rs index b49b765..a15009b 100644 --- a/src/vm/env.rs +++ b/src/vm/env.rs @@ -1,5 +1,5 @@ -use std::{hash::Hash, rc::Rc}; use std::fmt::Debug; +use std::{hash::Hash, rc::Rc}; use hashbrown::HashMap; @@ -7,21 +7,24 @@ use crate::ty::internal::{AttrSet, Value}; pub struct Env { map: Node, - last: Option>> + last: Option>>, } impl Clone for Env { fn clone(&self) -> Self { Self { map: self.map.clone(), - last: self.last.clone() + last: self.last.clone(), } } } impl Debug for Env { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Env").field("map", &self.map).field("last", &self.last).finish() + f.debug_struct("Env") + .field("map", &self.map) + .field("last", &self.last) + .finish() } } @@ -86,7 +89,8 @@ impl Env { pub fn enter_with(self: Rc, map: Rc>) -> Rc { let map = Node::Let(map); - let last = Some(self);Env { last, map }.into() + let last = Some(self); + Env { last, map }.into() } pub fn leave(self: Rc) -> Rc { diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 0b88556..68354d2 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -243,9 +243,13 @@ impl<'vm, 'jit: 'vm> VM<'jit> { stack.tos_mut()?.force(self)?.has_attr(sym); } OpCode::LookUp { sym } => { - stack.push(env.lookup(&sym).ok_or_else(|| { - Error::EvalError(format!("{} not found", self.get_sym(sym))) - })?.clone())?; + stack.push( + env.lookup(&sym) + .ok_or_else(|| { + Error::EvalError(format!("{} not found", self.get_sym(sym))) + })? + .clone(), + )?; } OpCode::EnterEnv => match stack.pop() { Value::AttrSet(attrs) => *env = env.clone().enter_with(attrs.into_inner()), diff --git a/src/vm/test.rs b/src/vm/test.rs index 5a8e741..66edca7 100644 --- a/src/vm/test.rs +++ b/src/vm/test.rs @@ -228,8 +228,8 @@ fn test_fib() { fn bench_fib(b: &mut Bencher) { b.iter(|| { test_expr( - "let fib = n: if n == 1 || n == 2 then 1 else (fib (n - 1)) + (fib (n - 2)); in fib 20", - int!(6765), + "let fib = n: if n == 1 || n == 2 then 1 else (fib (n - 1)) + (fib (n - 2)); in fib 30", + int!(832040), ); black_box(()) })