feat: no clone in JIT

IMPORTANT: should not drop or create values in JIT anymore
This commit is contained in:
2025-05-21 20:48:56 +08:00
parent 177acfabcf
commit 2a19ddb279
11 changed files with 105 additions and 82 deletions

View File

@@ -5,10 +5,10 @@ use itertools::Itertools;
use nixjit::compile::compile;
use nixjit::error::Error;
use nixjit::error::Result;
use nixjit::ir::downgrade;
use nixjit::jit::JITContext;
use nixjit::vm::run;
use nixjit::error::Result;
fn main() -> Result<()> {
let mut args = std::env::args();
@@ -21,8 +21,8 @@ fn main() -> Result<()> {
let root = rnix::Root::parse(&expr);
if !root.errors().is_empty() {
return Err(Error::ParseError(
root.errors().iter().map(|err| err.to_string()).join(";")
))
root.errors().iter().map(|err| err.to_string()).join(";"),
));
}
let expr = root.tree().expr().unwrap();
let downgraded = downgrade(expr)?;

View File

@@ -3,7 +3,7 @@ use std::rc::Rc;
use crate::ty::common::Const;
use crate::ty::internal::{AttrSet, PrimOp, Value};
use crate::vm::{VmEnv, VM};
use crate::vm::{VM, VmEnv};
pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> VmEnv<'jit, 'vm> {
let mut env_map = HashMap::new();

View File

@@ -9,7 +9,7 @@ use inkwell::values::{BasicValueEnum, FunctionValue};
use crate::jit::JITValueData;
use crate::ty::internal::{Thunk, Value};
use crate::vm::{VmEnv, VM};
use crate::vm::{VM, VmEnv};
use super::{JITValue, ValueTag};
@@ -204,20 +204,17 @@ impl<'ctx> Helpers<'ctx> {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_debug(value: JITValue) {
dbg!(value.tag);
}
#[unsafe(no_mangle)]
extern "C" fn helper_capture_env(thunk: JITValue, env: *const VmEnv) {
let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) };
let thunk = unsafe { (thunk.data.ptr as *const Thunk).as_ref().unwrap() };
let env = unsafe { Rc::from_raw(env) };
thunk.capture(env.clone());
std::mem::forget(env);
}
#[unsafe(no_mangle)]
extern "C" fn helper_neg(rhs: JITValue, _env: *const VmEnv) -> JITValue {
use ValueTag::*;
match rhs.tag {
@@ -237,7 +234,6 @@ extern "C" fn helper_neg(rhs: JITValue, _env: *const VmEnv) -> JITValue {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_not(rhs: JITValue, _env: *const VmEnv) -> JITValue {
use ValueTag::*;
match rhs.tag {
@@ -251,7 +247,6 @@ extern "C" fn helper_not(rhs: JITValue, _env: *const VmEnv) -> JITValue {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue {
use ValueTag::*;
match (lhs.tag, rhs.tag) {
@@ -269,7 +264,6 @@ extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue {
use ValueTag::*;
match (lhs.tag, rhs.tag) {
@@ -287,7 +281,6 @@ extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_eq(lhs: JITValue, rhs: JITValue) -> JITValue {
use ValueTag::*;
match (lhs.tag, rhs.tag) {
@@ -305,7 +298,6 @@ extern "C" fn helper_eq(lhs: JITValue, rhs: JITValue) -> JITValue {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue {
use ValueTag::*;
match (lhs.tag, rhs.tag) {
@@ -323,7 +315,6 @@ extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue {
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_call<'jit>(
func: JITValue,
args: *mut JITValue,
@@ -346,14 +337,12 @@ extern "C" fn helper_call<'jit>(
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const VmEnv<'jit, 'vm>) -> JITValue {
let env = unsafe { env.as_ref() }.unwrap();
let val = env.lookup(&sym);
val.cloned().unwrap().into()
let val: JITValue = env.lookup(&sym).unwrap().into();
val
}
#[unsafe(no_mangle)]
extern "C" fn helper_force<'jit>(thunk: JITValue, vm: *const VM<'jit>) -> JITValue {
let mut val = Value::from(thunk);
val.force(unsafe { vm.as_ref() }.unwrap()).unwrap();

View File

@@ -12,7 +12,7 @@ use crate::error::*;
use crate::stack::Stack;
use crate::ty::common::Const;
use crate::ty::internal::{Thunk, Value};
use crate::vm::{VmEnv, VM};
use crate::vm::{VM, VmEnv};
mod helpers;
@@ -57,6 +57,12 @@ pub union JITValueData {
impl<'jit: 'vm, 'vm> From<JITValue> for Value<'jit, 'vm> {
fn from(value: JITValue) -> Self {
use ValueTag::*;
match value.tag {
List | AttrSet | String | Function | Thunk | Path => unsafe {
Rc::increment_strong_count(value.data.ptr);
},
_ => (),
}
match value.tag {
Int => Value::Const(Const::Int(unsafe { value.data.int })),
Null => Value::Const(Const::Null),
@@ -67,6 +73,30 @@ impl<'jit: 'vm, 'vm> From<JITValue> for Value<'jit, 'vm> {
}
}
impl From<&Value<'_, '_>> for JITValue {
fn from(value: &Value<'_, '_>) -> Self {
match value {
Value::Const(Const::Int(int)) => JITValue {
tag: ValueTag::Int,
data: JITValueData { int: *int },
},
Value::Func(func) => JITValue {
tag: ValueTag::Function,
data: JITValueData {
ptr: Rc::as_ptr(func) as *const _,
},
},
Value::Thunk(thunk) => JITValue {
tag: ValueTag::Thunk,
data: JITValueData {
ptr: Rc::as_ptr(thunk) as *const _,
},
},
_ => todo!(),
}
}
}
impl From<Value<'_, '_>> for JITValue {
fn from(value: Value<'_, '_>) -> Self {
match value {
@@ -144,14 +174,7 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
let env = func_.get_nth_param(1).unwrap().into_pointer_value();
let entry = self.context.append_basic_block(func_, "entry");
self.builder.position_at_end(entry);
self.build_expr(
&mut iter,
vm,
env,
&mut stack,
func_,
func.opcodes.len(),
)?;
self.build_expr(&mut iter, vm, env, &mut stack, func_, func.opcodes.len())?;
assert_eq!(stack.len(), 1);
let value = stack.pop();

View File

@@ -35,10 +35,7 @@ impl<T, const CAP: usize> Stack<T, CAP> {
pub fn push(&mut self, item: T) -> Result<()> {
self.items
.get_mut(self.top)
.map_or_else(
|| Err(Error::EvalError("stack overflow".to_string())),
Ok,
)?
.map_or_else(|| Err(Error::EvalError("stack overflow".to_string())), Ok)?
.write(item);
self.top += 1;
Ok(())

View File

@@ -6,7 +6,7 @@ use derive_more::Constructor;
use itertools::Itertools;
use crate::error::Result;
use crate::vm::{VmEnv, VM};
use crate::vm::{VM, VmEnv};
use super::super::public as p;
use super::Value;
@@ -42,9 +42,9 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> {
}
pub fn select(&self, sym: usize) -> Option<Value<'jit, 'vm>> {
self.data.get(&sym).cloned().map(|val| match val {
self.data.get(&sym).map(|val| match val {
Value::Builtins(x) => Value::AttrSet(x.upgrade().unwrap()),
val => val,
val => val.clone(),
})
}
@@ -53,8 +53,10 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> {
}
pub fn capture(&mut self, env: &Rc<VmEnv<'jit, 'vm>>) {
self.data.iter().for_each(|(_, v)| if let Value::Thunk(ref thunk) = v.clone() {
thunk.capture(env.clone());
self.data.iter().for_each(|(_, v)| {
if let Value::Thunk(ref thunk) = v.clone() {
thunk.capture(Rc::clone(env));
}
})
}
@@ -77,11 +79,7 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> {
}
pub fn force_deep(&mut self, vm: &'vm VM<'jit>) -> Result<()> {
let mut map: Vec<_> = self
.data
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect();
let mut map: Vec<_> = self.data.iter().map(|(k, v)| (*k, v.clone())).collect();
for (_, v) in map.iter_mut() {
v.force_deep(vm)?;
}

View File

@@ -11,7 +11,7 @@ use crate::error::Result;
use crate::ir;
use crate::jit::JITFunc;
use crate::ty::internal::{Thunk, Value};
use crate::vm::{VmEnv, VM};
use crate::vm::{VM, VmEnv};
#[derive(Debug, Clone)]
pub enum Param {

View File

@@ -11,7 +11,7 @@ use super::public as p;
use crate::bytecode::OpCodes;
use crate::error::*;
use crate::vm::{VmEnv, VM};
use crate::vm::{VM, VmEnv};
mod attrset;
mod func;
@@ -28,7 +28,6 @@ pub use primop::*;
pub enum Value<'jit: 'vm, 'vm> {
Const(Const),
Thunk(Rc<Thunk<'jit, 'vm>>),
ThunkRef(&'vm Thunk<'jit, 'vm>),
AttrSet(Rc<AttrSet<'jit, 'vm>>),
List(Rc<List<'jit, 'vm>>),
Catchable(Catchable),
@@ -45,7 +44,6 @@ impl Hash for Value<'_, '_> {
match self {
Const(x) => x.hash(state),
Thunk(x) => (x.as_ref() as *const self::Thunk).hash(state),
ThunkRef(x) => (*x as *const self::Thunk).hash(state),
AttrSet(x) => (x.as_ref() as *const self::AttrSet).hash(state),
List(x) => (x.as_ref() as *const self::List).hash(state),
Catchable(x) => x.hash(state),
@@ -119,7 +117,6 @@ impl<'v, 'vm: 'v, 'jit: 'vm> Value<'jit, 'vm> {
match self {
Const(x) => R::Const(x),
Thunk(x) => R::Thunk(x),
ThunkRef(x) => R::Thunk(x),
AttrSet(x) => R::AttrSet(x),
List(x) => R::List(x),
Catchable(x) => R::Catchable(x),
@@ -136,7 +133,6 @@ impl<'v, 'vm: 'v, 'jit: 'vm> Value<'jit, 'vm> {
match self {
Const(x) => M::Const(x),
Thunk(x) => M::Thunk(x),
ThunkRef(x) => M::Thunk(x),
AttrSet(x) => M::AttrSet(Rc::make_mut(x)),
List(x) => M::List(Rc::make_mut(x)),
Catchable(x) => M::Catchable(x),
@@ -163,7 +159,6 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
Const(self::Const::String(_)) => "string",
Const(self::Const::Null) => "null",
Thunk(_) => "thunk",
ThunkRef(_) => "thunk",
AttrSet(_) => "set",
List(_) => "list",
Catchable(_) => unreachable!(),
@@ -420,10 +415,7 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
pub fn force(&mut self, vm: &'vm VM<'jit>) -> Result<&mut Self> {
if let Value::Thunk(thunk) = self {
let value = thunk.force(vm)?;
*self = value
} else if let Value::ThunkRef(thunk) = self {
let value = thunk.force(vm)?;
let value = thunk.force(vm)?.clone();
*self = value
}
Ok(self)
@@ -432,12 +424,7 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
pub fn force_deep(&mut self, vm: &'vm VM<'jit>) -> Result<&mut Self> {
match self {
Value::Thunk(thunk) => {
let mut value = thunk.force(vm)?;
let _ = value.force_deep(vm)?;
*self = value;
}
Value::ThunkRef(thunk) => {
let mut value = thunk.force(vm)?;
let mut value = thunk.force(vm)?.clone();
let _ = value.force_deep(vm)?;
*self = value;
}
@@ -461,7 +448,6 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
Catchable(catchable) => Value::Catchable(catchable.clone()),
Const(cnst) => Value::Const(cnst.clone()),
Thunk(_) => Value::Thunk,
ThunkRef(_) => Value::Thunk,
PrimOp(primop) => Value::PrimOp(primop.name),
PartialPrimOp(primop) => Value::PartialPrimOp(primop.name),
Func(_) => Value::Func,
@@ -477,11 +463,17 @@ pub struct Thunk<'jit, 'vm> {
#[derive(Debug, IsVariant, Unwrap, Clone)]
pub enum _Thunk<'jit, 'vm> {
Code(&'vm OpCodes, OnceCell<Rc<VmEnv<'jit, 'vm>>>),
Code(&'vm OpCodes, OnceCell<EnvRef<'jit, 'vm>>),
SuspendedFrom(*const Thunk<'jit, 'vm>),
Value(Value<'jit, 'vm>),
}
#[derive(Debug, IsVariant, Unwrap, Clone)]
pub enum EnvRef<'jit, 'vm> {
Strong(Rc<VmEnv<'jit, 'vm>>),
Weak(Weak<VmEnv<'jit, 'vm>>),
}
impl<'jit, 'vm> Thunk<'jit, 'vm> {
pub fn new(opcodes: &'vm OpCodes) -> Self {
Thunk {
@@ -491,32 +483,48 @@ impl<'jit, 'vm> Thunk<'jit, 'vm> {
pub fn capture(&self, env: Rc<VmEnv<'jit, 'vm>>) {
if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() {
envcell.get_or_init(|| env);
envcell.get_or_init(|| EnvRef::Strong(env));
}
}
pub fn force(&self, vm: &'vm VM<'jit>) -> Result<Value<'jit, 'vm>> {
pub fn capture_weak(&self, env: Weak<VmEnv<'jit, 'vm>>) {
if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() {
envcell.get_or_init(|| EnvRef::Weak(env));
}
}
pub fn force(&self, vm: &'vm VM<'jit>) -> Result<&Value<'jit, 'vm>> {
use _Thunk::*;
match &*self.thunk.borrow() {
_Thunk::Value(value) => return Ok(value.clone()),
_Thunk::SuspendedFrom(from) => {
Value(_) => {
return Ok(match unsafe { &*(&*self.thunk.borrow() as *const _) } {
Value(value) => value,
_ => unreachable!(),
});
}
SuspendedFrom(from) => {
return Err(Error::EvalError(format!(
"thunk {:p} already suspended from {from:p} (infinite recursion encountered)",
self as *const Thunk
)));
}
_Thunk::Code(..) => (),
Code(..) => (),
}
let (opcodes, env) = std::mem::replace(
&mut *self.thunk.borrow_mut(),
_Thunk::SuspendedFrom(self as *const Thunk),
)
.unwrap_code();
let value = vm.eval(opcodes.iter().copied(), env.get().unwrap().clone())?;
let _ = std::mem::replace(
&mut *self.thunk.borrow_mut(),
_Thunk::Value(value.clone()),
);
Ok(value)
let env = match env.get().unwrap() {
EnvRef::Strong(env) => env.clone(),
EnvRef::Weak(env) => env.upgrade().unwrap(),
};
let value = vm.eval(opcodes.iter().copied(), env)?;
let _ = std::mem::replace(&mut *self.thunk.borrow_mut(), _Thunk::Value(value));
Ok(match unsafe { &*(&*self.thunk.borrow() as *const _) } {
Value(value) => value,
_ => unreachable!(),
})
}
pub fn value(&'vm self) -> Option<Value<'jit, 'vm>> {

View File

@@ -1,5 +1,5 @@
use std::{hash::Hash, rc::Rc};
use std::fmt::Debug;
use std::{hash::Hash, rc::Rc};
use hashbrown::HashMap;
@@ -7,21 +7,24 @@ use crate::ty::internal::{AttrSet, Value};
pub struct Env<K: Hash + Eq, V> {
map: Node<K, V>,
last: Option<Rc<Env<K, V>>>
last: Option<Rc<Env<K, V>>>,
}
impl<K: Clone + Hash + Eq, V: Clone> Clone for Env<K, V> {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
last: self.last.clone()
last: self.last.clone(),
}
}
}
impl<K: Debug + Hash + Eq, V: Debug> Debug for Env<K, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Env").field("map", &self.map).field("last", &self.last).finish()
f.debug_struct("Env")
.field("map", &self.map)
.field("last", &self.last)
.finish()
}
}
@@ -86,7 +89,8 @@ impl<K: Hash + Eq, V> Env<K, V> {
pub fn enter_with(self: Rc<Self>, map: Rc<HashMap<K, V>>) -> Rc<Self> {
let map = Node::Let(map);
let last = Some(self);Env { last, map }.into()
let last = Some(self);
Env { last, map }.into()
}
pub fn leave(self: Rc<Self>) -> Rc<Self> {

View File

@@ -243,9 +243,13 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
stack.tos_mut()?.force(self)?.has_attr(sym);
}
OpCode::LookUp { sym } => {
stack.push(env.lookup(&sym).ok_or_else(|| {
Error::EvalError(format!("{} not found", self.get_sym(sym)))
})?.clone())?;
stack.push(
env.lookup(&sym)
.ok_or_else(|| {
Error::EvalError(format!("{} not found", self.get_sym(sym)))
})?
.clone(),
)?;
}
OpCode::EnterEnv => match stack.pop() {
Value::AttrSet(attrs) => *env = env.clone().enter_with(attrs.into_inner()),

View File

@@ -228,8 +228,8 @@ fn test_fib() {
fn bench_fib(b: &mut Bencher) {
b.iter(|| {
test_expr(
"let fib = n: if n == 1 || n == 2 then 1 else (fib (n - 1)) + (fib (n - 2)); in fib 20",
int!(6765),
"let fib = n: if n == 1 || n == 2 then 1 else (fib (n - 1)) + (fib (n - 2)); in fib 30",
int!(832040),
);
black_box(())
})