feat(jit): fib!

This commit is contained in:
2025-05-19 19:29:25 +08:00
parent 6d26716412
commit 9e172bf013
10 changed files with 363 additions and 166 deletions

View File

@@ -3,12 +3,12 @@ use std::rc::Rc;
use crate::ty::common::Const; use crate::ty::common::Const;
use crate::ty::internal::{AttrSet, PrimOp, Value}; use crate::ty::internal::{AttrSet, PrimOp, Value};
use crate::vm::{Env, VM}; use crate::vm::{LetEnv, VM};
pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> Env<'jit, 'vm> { pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> LetEnv<'jit, 'vm> {
let mut env = Env::empty(); let mut env_map = HashMap::new();
env.insert(vm.new_sym("true"), Value::Const(Const::Bool(true))); env_map.insert(vm.new_sym("true"), Value::Const(Const::Bool(true)));
env.insert(vm.new_sym("false"), Value::Const(Const::Bool(false))); env_map.insert(vm.new_sym("false"), Value::Const(Const::Bool(false)));
let primops = [ let primops = [
PrimOp::new("add", 2, |_, args| { PrimOp::new("add", 2, |_, args| {
@@ -46,18 +46,19 @@ pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> Env<'jit, 'vm> {
let mut map = HashMap::new(); let mut map = HashMap::new();
for primop in primops { for primop in primops {
let primop = Rc::new(primop); let primop = Rc::new(primop);
env.insert( env_map.insert(
vm.new_sym(format!("__{}", primop.name)), vm.new_sym(format!("__{}", primop.name)),
Value::PrimOp(primop.clone()), Value::PrimOp(primop.clone()),
); );
map.insert(vm.new_sym(primop.name), Value::PrimOp(primop)); map.insert(vm.new_sym(primop.name), Value::PrimOp(primop));
} }
let sym = vm.new_sym("builtins");
let attrs = Rc::new_cyclic(|weak| { let attrs = Rc::new_cyclic(|weak| {
map.insert(vm.new_sym("builtins"), Value::Builtins(weak.clone())); map.insert(sym, Value::Builtins(weak.clone()));
AttrSet::from_inner(map) AttrSet::from_inner(map)
}); });
let builtins = Value::AttrSet(attrs); let builtins = Value::AttrSet(attrs);
env.insert(vm.new_sym("builtins"), builtins); env_map.insert(sym, builtins);
env LetEnv::new(env_map.into())
} }

View File

@@ -33,8 +33,6 @@ pub enum OpCode {
/// jump forward /// jump forward
Jmp { step: usize }, Jmp { step: usize },
/// [ .. cond ] consume 1 element, if `cond`` is true, then jump forward
JmpIfTrue { step: usize },
/// [ .. cond ] consume 1 element, if `cond` is false, then jump forward /// [ .. cond ] consume 1 element, if `cond` is false, then jump forward
JmpIfFalse { step: usize }, JmpIfFalse { step: usize },

View File

@@ -1,5 +1,3 @@
use std::rc::Rc;
use inkwell::AddressSpace; use inkwell::AddressSpace;
use inkwell::context::Context; use inkwell::context::Context;
use inkwell::execution_engine::ExecutionEngine; use inkwell::execution_engine::ExecutionEngine;
@@ -8,8 +6,8 @@ use inkwell::types::{FloatType, FunctionType, IntType, PointerType, StructType};
use inkwell::values::{BasicValueEnum, FunctionValue}; use inkwell::values::{BasicValueEnum, FunctionValue};
use crate::jit::JITValueData; use crate::jit::JITValueData;
use crate::ty::internal::Thunk; use crate::ty::internal::{Thunk, Value};
use crate::vm::{Env, VM}; use crate::vm::{LetEnv, VM};
use super::{JITFunc, JITValue, ValueTag}; use super::{JITFunc, JITValue, ValueTag};
@@ -25,10 +23,14 @@ pub struct Helpers<'ctx> {
pub debug: FunctionValue<'ctx>, pub debug: FunctionValue<'ctx>,
pub capture_env: FunctionValue<'ctx>, pub capture_env: FunctionValue<'ctx>,
pub neg: FunctionValue<'ctx>, pub neg: FunctionValue<'ctx>,
pub not: FunctionValue<'ctx>,
pub add: FunctionValue<'ctx>, pub add: FunctionValue<'ctx>,
pub sub: FunctionValue<'ctx>, pub sub: FunctionValue<'ctx>,
pub eq: FunctionValue<'ctx>,
pub or: FunctionValue<'ctx>,
pub call: FunctionValue<'ctx>, pub call: FunctionValue<'ctx>,
pub lookup: FunctionValue<'ctx>, pub lookup: FunctionValue<'ctx>,
pub force: FunctionValue<'ctx>,
} }
impl<'ctx> Helpers<'ctx> { impl<'ctx> Helpers<'ctx> {
@@ -61,6 +63,11 @@ impl<'ctx> Helpers<'ctx> {
value_type.fn_type(&[value_type.into(), ptr_type.into()], false), value_type.fn_type(&[value_type.into(), ptr_type.into()], false),
None, None,
); );
let not = module.add_function(
"not",
value_type.fn_type(&[value_type.into(), ptr_type.into()], false),
None,
);
let add = module.add_function( let add = module.add_function(
"add", "add",
value_type.fn_type(&[value_type.into(), value_type.into()], false), value_type.fn_type(&[value_type.into(), value_type.into()], false),
@@ -71,15 +78,23 @@ impl<'ctx> Helpers<'ctx> {
value_type.fn_type(&[value_type.into(), value_type.into()], false), value_type.fn_type(&[value_type.into(), value_type.into()], false),
None, None,
); );
// Assuming a single argument for now based on the test case let eq = module.add_function(
"eq",
value_type.fn_type(&[value_type.into(), value_type.into()], false),
None,
);
let or = module.add_function(
"or",
value_type.fn_type(&[value_type.into(), value_type.into()], false),
None,
);
let call = module.add_function( let call = module.add_function(
"call", "call",
value_type.fn_type( value_type.fn_type(
&[ &[
ptr_type.into(),
ptr_type.into(),
ptr_type.into(),
value_type.into(), value_type.into(),
value_type.into(),
ptr_type.into(),
], ],
false, false,
), ),
@@ -90,14 +105,23 @@ impl<'ctx> Helpers<'ctx> {
value_type.fn_type(&[ptr_int_type.into(), ptr_type.into()], false), value_type.fn_type(&[ptr_int_type.into(), ptr_type.into()], false),
None, None,
); );
let force = module.add_function(
"force",
value_type.fn_type(&[value_type.into(), ptr_type.into()], false),
None,
);
execution_engine.add_global_mapping(&debug, helper_debug as _); execution_engine.add_global_mapping(&debug, helper_debug as _);
execution_engine.add_global_mapping(&capture_env, helper_capture_env as _); execution_engine.add_global_mapping(&capture_env, helper_capture_env as _);
execution_engine.add_global_mapping(&neg, helper_neg as _); execution_engine.add_global_mapping(&neg, helper_neg as _);
execution_engine.add_global_mapping(&not, helper_not as _);
execution_engine.add_global_mapping(&add, helper_add as _); execution_engine.add_global_mapping(&add, helper_add as _);
execution_engine.add_global_mapping(&sub, helper_sub as _); execution_engine.add_global_mapping(&sub, helper_sub as _);
execution_engine.add_global_mapping(&eq, helper_eq as _);
execution_engine.add_global_mapping(&or, helper_or as _);
execution_engine.add_global_mapping(&call, helper_call as _); execution_engine.add_global_mapping(&call, helper_call as _);
execution_engine.add_global_mapping(&lookup, helper_lookup as _); execution_engine.add_global_mapping(&lookup, helper_lookup as _);
execution_engine.add_global_mapping(&force, helper_force as _);
Helpers { Helpers {
int_type, int_type,
@@ -111,10 +135,14 @@ impl<'ctx> Helpers<'ctx> {
debug, debug,
capture_env, capture_env,
neg, neg,
not,
add, add,
sub, sub,
eq,
or,
call, call,
lookup, lookup,
force
} }
} }
@@ -179,13 +207,13 @@ extern "C" fn helper_debug(value: JITValue) {
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn helper_capture_env(thunk: JITValue, env: *const Env) { extern "C" fn helper_capture_env(thunk: JITValue, env: *const LetEnv) {
let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) }; let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) };
thunk.capture(unsafe { env.as_ref().unwrap() }.clone()); thunk.capture(unsafe { env.as_ref().unwrap() }.clone());
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn helper_neg(rhs: JITValue, _env: *const Env) -> JITValue { extern "C" fn helper_neg(rhs: JITValue, _env: *const LetEnv) -> JITValue {
use ValueTag::*; use ValueTag::*;
match rhs.tag { match rhs.tag {
Int => JITValue { Int => JITValue {
@@ -198,6 +226,20 @@ extern "C" fn helper_neg(rhs: JITValue, _env: *const Env) -> JITValue {
} }
} }
#[unsafe(no_mangle)]
extern "C" fn helper_not(rhs: JITValue, _env: *const LetEnv) -> JITValue {
use ValueTag::*;
match rhs.tag {
Bool => JITValue {
tag: Bool,
data: JITValueData {
bool: !unsafe { rhs.data.bool },
},
},
_ => todo!(),
}
}
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue { extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue {
use ValueTag::*; use ValueTag::*;
@@ -235,20 +277,68 @@ extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue {
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn helper_call<'jit, 'vm>( extern "C" fn helper_eq(lhs: JITValue, rhs: JITValue) -> JITValue {
vm: *const VM<'jit>, use ValueTag::*;
env: *const Env<'jit, 'vm>, match (lhs.tag, rhs.tag) {
func_ptr: *const (), (Int, Int) => JITValue {
arg: JITValue, tag: Bool,
) -> JITValue { data: JITValueData {
let func: JITFunc = unsafe { std::mem::transmute(func_ptr) }; bool: unsafe { lhs.data.int == rhs.data.int }
todo!(); },
unsafe { func(vm, env) } },
_ => todo!(
"Equation not implemented for {:?} and {:?}",
lhs.tag,
rhs.tag
),
}
} }
#[unsafe(no_mangle)] #[unsafe(no_mangle)]
extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const Env<'jit, 'vm>) -> JITValue { extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue {
use ValueTag::*;
match (lhs.tag, rhs.tag) {
(Bool, Bool) => JITValue {
tag: Bool,
data: JITValueData {
bool: unsafe { lhs.data.bool || rhs.data.bool },
},
},
_ => todo!(
"Substruction not implemented for {:?} and {:?}",
lhs.tag,
rhs.tag
),
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_call<'jit, 'vm>(
func: JITValue,
arg: JITValue,
vm: *const VM<'jit>,
) -> JITValue {
use ValueTag::*;
match func.tag {
Function => {
let func: Value = func.into();
func.call(unsafe { vm.as_ref() }.unwrap(), vec![arg.into()]).unwrap().into()
}
_ => todo!(),
}
}
#[unsafe(no_mangle)]
extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const LetEnv<'jit, 'vm>) -> JITValue {
let env = unsafe { env.as_ref() }.unwrap(); let env = unsafe { env.as_ref() }.unwrap();
let val = env.lookup(sym); let val = env.lookup(sym);
dbg!(val.as_ref().unwrap().typename());
val.unwrap().into() val.unwrap().into()
} }
#[unsafe(no_mangle)]
extern "C" fn helper_force<'jit, 'vm>(thunk: JITValue, vm: *const VM<'jit>) -> JITValue {
let mut val = Value::from(thunk);
val.force(unsafe { vm.as_ref() }.unwrap()).unwrap();
val.into()
}

View File

@@ -1,18 +1,19 @@
use std::rc::Rc; use std::rc::Rc;
use inkwell::OptimizationLevel; use inkwell::OptimizationLevel;
use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder; use inkwell::builder::Builder;
use inkwell::context::Context; use inkwell::context::Context;
use inkwell::execution_engine::{ExecutionEngine, JitFunction}; use inkwell::execution_engine::{ExecutionEngine, JitFunction};
use inkwell::module::Module; use inkwell::module::Module;
use inkwell::values::{BasicValueEnum, PointerValue}; use inkwell::values::{BasicValueEnum, FunctionValue, PointerValue};
use crate::bytecode::{Func, OpCode, UnOp}; use crate::bytecode::{Func, OpCode, UnOp};
use crate::error::*; use crate::error::*;
use crate::stack::Stack; use crate::stack::Stack;
use crate::ty::common::Const; use crate::ty::common::Const;
use crate::ty::internal::{Thunk, Value}; use crate::ty::internal::{Thunk, Value};
use crate::vm::{Env, VM}; use crate::vm::{LetEnv, VM};
mod helpers; mod helpers;
@@ -48,17 +49,19 @@ pub struct JITValue {
pub union JITValueData { pub union JITValueData {
int: i64, int: i64,
float: f64, float: f64,
boolean: bool, bool: bool,
ptr: *const (), ptr: *const (),
} }
impl<'jit: 'vm, 'vm> Into<Value<'jit, 'vm>> for JITValue { impl<'jit: 'vm, 'vm> From<JITValue> for Value<'jit, 'vm> {
fn into(self) -> Value<'jit, 'vm> { fn from(value: JITValue) -> Self {
use ValueTag::*; use ValueTag::*;
match self.tag { match value.tag {
Int => Value::Const(Const::Int(unsafe { self.data.int })), Int => Value::Const(Const::Int(unsafe { value.data.int })),
Null => Value::Const(Const::Null), Null => Value::Const(Const::Null),
_ => todo!("not implemented for {:?}", self.tag), Function => Value::Func(unsafe { Rc::from_raw(value.data.ptr as *const _) }),
Thunk => Value::Thunk(unsafe { Rc::from_raw(value.data.ptr as *const _) }),
_ => todo!("not implemented for {:?}", value.tag),
} }
} }
} }
@@ -70,13 +73,25 @@ impl From<Value<'_, '_>> for JITValue {
tag: ValueTag::Int, tag: ValueTag::Int,
data: JITValueData { int }, data: JITValueData { int },
}, },
Value::Func(func) => JITValue {
tag: ValueTag::Function,
data: JITValueData {
ptr: Rc::into_raw(func) as *const _
}
},
Value::Thunk(thunk) => JITValue {
tag: ValueTag::Thunk,
data: JITValueData {
ptr: Rc::into_raw(thunk) as *const _
}
},
_ => todo!(), _ => todo!(),
} }
} }
} }
pub type JITFunc<'jit, 'vm> = pub type JITFunc<'jit, 'vm> =
unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>) -> JITValue; unsafe extern "C" fn(*const VM<'jit>, *const LetEnv<'jit, 'vm>) -> JITValue;
pub struct JITContext<'ctx> { pub struct JITContext<'ctx> {
context: &'ctx Context, context: &'ctx Context,
@@ -125,17 +140,25 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
let func_ = self let func_ = self
.module .module
.add_function("nixjit_function", self.helpers.func_type, None); .add_function("nixjit_function", self.helpers.func_type, None);
let entry = self.context.append_basic_block(func_, "entry");
self.builder.position_at_end(entry);
let env = func_.get_nth_param(1).unwrap().into_pointer_value(); let env = func_.get_nth_param(1).unwrap().into_pointer_value();
while let Some(opcode) = iter.next() { let entry = self.context.append_basic_block(func_, "entry");
self.single_op(opcode, vm, env, &mut stack)?; self.build_expr(
} &mut iter,
vm,
env,
&mut stack,
func_,
entry,
func.opcodes.len(),
)?;
assert_eq!(stack.len(), 1); assert_eq!(stack.len(), 1);
let value = stack.pop(); let value = stack.pop();
let exit = self.context.append_basic_block(func_, "exit");
self.builder.build_unconditional_branch(exit)?;
self.builder.position_at_end(exit);
self.builder.build_return(Some(&value))?; self.builder.build_return(Some(&value))?;
if func_.verify(false) { if func_.verify(true) {
func_.print_to_stderr();
unsafe { unsafe {
let name = func_.get_name().to_str().unwrap(); let name = func_.get_name().to_str().unwrap();
let addr = self.execution_engine.get_function(name).unwrap(); let addr = self.execution_engine.get_function(name).unwrap();
@@ -146,13 +169,82 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
} }
} }
fn build_expr<const CAP: usize>(
&self,
iter: &mut impl Iterator<Item = OpCode>,
vm: &'vm VM<'_>,
env: PointerValue<'ctx>,
stack: &mut Stack<BasicValueEnum<'ctx>, CAP>,
func: FunctionValue<'ctx>,
bb: BasicBlock<'ctx>,
mut length: usize,
) -> Result<usize> {
self.builder.position_at_end(bb);
while length > 1 {
let opcode = iter.next().unwrap();
let br = self.single_op(opcode, vm, env, stack)?;
length -= 1;
if br > 0 {
let consq = self.context.append_basic_block(func, "consq");
let alter = self.context.append_basic_block(func, "alter");
let cont = self.context.append_basic_block(func, "cont");
let cond = self
.builder
.build_alloca(self.helpers.value_type, "cond_alloca")?;
let result = self
.builder
.build_alloca(self.helpers.value_type, "result_alloca")?;
self.builder.build_store(cond, stack.pop())?;
self.builder.build_conditional_branch(
self.builder
.build_load(
self.context.bool_type(),
self.builder.build_struct_gep(
self.helpers.value_type,
cond,
1,
"gep_cond",
)?,
"load_cond",
)?
.into_int_value(),
consq,
alter,
)?;
length -= br;
let br = self.build_expr(iter, vm, env, stack, func, consq, br)?;
self.builder.build_store(result, stack.pop())?;
self.builder.build_unconditional_branch(cont)?;
length -= br;
self.build_expr(iter, vm, env, stack, func, alter, br)?;
self.builder.build_store(result, stack.pop())?;
self.builder.build_unconditional_branch(cont)?;
self.builder.position_at_end(cont);
stack.push(self.builder.build_load(
self.helpers.value_type,
result,
"load_result",
)?)?;
}
}
if length > 0 {
self.single_op(iter.next().unwrap(), vm, env, stack)
} else {
Ok(0)
}
}
#[inline(always)]
fn single_op<const CAP: usize>( fn single_op<const CAP: usize>(
&self, &self,
opcode: OpCode, opcode: OpCode,
vm: &'vm VM<'_>, vm: &'vm VM<'_>,
env: PointerValue<'ctx>, env: PointerValue<'ctx>,
stack: &mut Stack<BasicValueEnum<'ctx>, CAP>, stack: &mut Stack<BasicValueEnum<'ctx>, CAP>,
) -> Result<()> { ) -> Result<usize> {
match opcode { match opcode {
OpCode::Const { idx } => { OpCode::Const { idx } => {
use Const::*; use Const::*;
@@ -186,53 +278,13 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
.try_as_basic_value() .try_as_basic_value()
.left() .left()
.unwrap(), .unwrap(),
_ => todo!(), Not => self
})?
}
OpCode::Func { idx } => {
let func = vm.get_func(idx);
let jit_func_ptr = self.compile_function(func, vm)?;
let jit_value = self
.helpers
.value_type
.const_named_struct(&[
self.helpers
.int_type
.const_int(ValueTag::Function as _, false)
.into(),
self.helpers
.ptr_int_type
.const_int(unsafe { jit_func_ptr.as_raw() } as _, false)
.into(),
])
.into();
stack.push(jit_value)?;
}
OpCode::Call { arity } => {
// Assuming arity is 1 for the test case
assert_eq!(arity, 1);
let arg = stack.pop();
let func_value = stack.pop();
let func_ptr = self
.builder .builder
.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")? .build_direct_call(self.helpers.not, &[rhs.into(), env.into()], "call_neg")?
.into_pointer_value();
let result = self
.builder
.build_direct_call(
self.helpers.call,
&[
self.new_ptr(vm).into(),
env.into(),
func_ptr.into(),
arg.into(),
],
"call_func",
)?
.try_as_basic_value() .try_as_basic_value()
.left() .left()
.unwrap(); .unwrap(),
stack.push(result)?; })?
} }
OpCode::BinOp { op } => { OpCode::BinOp { op } => {
use crate::bytecode::BinOp; use crate::bytecode::BinOp;
@@ -269,9 +321,22 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
let result = self let result = self
.builder .builder
.build_direct_call( .build_direct_call(
self.helpers.add, self.helpers.eq,
&[lhs.into(), rhs.into()], &[lhs.into(), rhs.into()],
"call_add", "call_eq",
)?
.try_as_basic_value()
.left()
.unwrap();
stack.push(result)?;
}
BinOp::Or => {
let result = self
.builder
.build_direct_call(
self.helpers.or,
&[lhs.into(), rhs.into()],
"call_or",
)? )?
.try_as_basic_value() .try_as_basic_value()
.left() .left()
@@ -299,8 +364,29 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
.left() .left()
.unwrap(), .unwrap(),
)?, )?,
OpCode::Call { arity } => {
// TODO:
assert_eq!(arity, 1);
let mut args = Vec::with_capacity(arity);
for _ in 0..arity {
args.insert(0, stack.pop());
}
let func = self.builder
.build_direct_call(self.helpers.force, &[stack.pop().into(), self.new_ptr(vm).into()], "force")?
.try_as_basic_value()
.left()
.unwrap();
let ret = self.builder
.build_direct_call(self.helpers.call, &[func.into(), args[0].into(), self.new_ptr(vm).into()], "call")?
.try_as_basic_value()
.left()
.unwrap();
stack.push(ret)?;
}
OpCode::JmpIfFalse { step } => return Ok(step),
OpCode::Jmp { step } => return Ok(step),
_ => todo!("{opcode:?} not implemented"), _ => todo!("{opcode:?} not implemented"),
} }
Ok(()) Ok(0)
} }
} }

View File

@@ -45,8 +45,8 @@ impl<T, const CAP: usize> Stack<T, CAP> {
} }
pub fn pop(&mut self) -> T { pub fn pop(&mut self) -> T {
let item = self.items.get_mut(self.top - 1).unwrap();
self.top -= 1; self.top -= 1;
let item = self.items.get_mut(self.top).unwrap();
// SAFETY: `item` at `self.top` was previously written and is initialized. // SAFETY: `item` at `self.top` was previously written and is initialized.
// We replace it with `MaybeUninit::uninit()` and then `assume_init` // We replace it with `MaybeUninit::uninit()` and then `assume_init`

View File

@@ -4,24 +4,24 @@ use derive_more::Constructor;
use itertools::Itertools; use itertools::Itertools;
use crate::error::Result; use crate::error::Result;
use crate::vm::{Env, VM}; use crate::vm::{LetEnv, VM};
use super::super::public as p; use super::super::public as p;
use super::Value; use super::Value;
#[repr(C)] #[repr(C)]
#[derive(Debug, Constructor, Clone, PartialEq)] #[derive(Debug, Default, Constructor, Clone, PartialEq)]
pub struct AttrSet<'jit: 'vm, 'vm> { pub struct AttrSet<'jit: 'vm, 'vm> {
data: HashMap<usize, Value<'jit, 'vm>>, data: HashMap<usize, Value<'jit, 'vm>>,
} }
impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { impl<'jit, 'vm> From<HashMap<usize, Value<'jit, 'vm>>> for AttrSet<'jit, 'vm> {
pub fn empty() -> Self { fn from(data: HashMap<usize, Value<'jit, 'vm>>) -> Self {
AttrSet { Self { data }
data: HashMap::new(),
}
} }
}
impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> {
pub fn with_capacity(cap: usize) -> Self { pub fn with_capacity(cap: usize) -> Self {
AttrSet { AttrSet {
data: HashMap::with_capacity(cap), data: HashMap::with_capacity(cap),
@@ -47,7 +47,7 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> {
self.data.get(&sym).is_some() self.data.get(&sym).is_some()
} }
pub fn capture(&mut self, env: &Env<'jit, 'vm>) { pub fn capture(&mut self, env: &LetEnv<'jit, 'vm>) {
self.data.iter().for_each(|(_, v)| match v.clone() { self.data.iter().for_each(|(_, v)| match v.clone() {
Value::Thunk(ref thunk) => { Value::Thunk(ref thunk) => {
thunk.capture(env.clone()); thunk.capture(env.clone());

View File

@@ -9,7 +9,7 @@ use crate::error::Result;
use crate::ir; use crate::ir;
use crate::jit::JITFunc; use crate::jit::JITFunc;
use crate::ty::internal::{Thunk, Value}; use crate::ty::internal::{Thunk, Value};
use crate::vm::{Env, VM}; use crate::vm::{LetEnv, VM};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Param { pub enum Param {
@@ -44,7 +44,7 @@ impl From<ir::Param> for Param {
#[derive(Debug, Clone, Constructor)] #[derive(Debug, Clone, Constructor)]
pub struct Func<'jit: 'vm, 'vm> { pub struct Func<'jit: 'vm, 'vm> {
pub func: &'vm BFunc, pub func: &'vm BFunc,
pub env: Env<'jit, 'vm>, pub env: LetEnv<'jit, 'vm>,
pub compiled: OnceCell<JitFunction<'jit, JITFunc<'jit, 'vm>>>, pub compiled: OnceCell<JitFunction<'jit, JITFunc<'jit, 'vm>>>,
pub count: Cell<usize>, pub count: Cell<usize>,
} }
@@ -54,7 +54,7 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> {
use Param::*; use Param::*;
let env = match self.func.param.clone() { let env = match self.func.param.clone() {
Ident(ident) => self.env.clone().enter([(ident.into(), arg)].into_iter()), Ident(ident) => self.env.clone().enter_let([(ident.into(), arg)].into_iter()),
Formals { Formals {
formals, formals,
ellipsis, ellipsis,
@@ -85,7 +85,7 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> {
if let Some(alias) = alias { if let Some(alias) = alias {
new.push((alias.clone().into(), Value::AttrSet(arg))); new.push((alias.clone().into(), Value::AttrSet(arg)));
} }
self.env.clone().enter(new.into_iter()) self.env.clone().enter_let(new.into_iter())
} }
}; };
@@ -93,7 +93,7 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> {
self.count.replace(count + 1); self.count.replace(count + 1);
if count >= 1 { if count >= 1 {
let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func)); let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func));
let ret = unsafe { compiled.call(vm as *const VM, &env as *const Env) }; let ret = unsafe { compiled.call(vm as *const VM, &env as *const LetEnv) };
return Ok(ret.into()); return Ok(ret.into());
} }
vm.eval(self.func.opcodes.iter().copied(), env) vm.eval(self.func.opcodes.iter().copied(), env)

View File

@@ -11,7 +11,7 @@ use super::public as p;
use crate::bytecode::OpCodes; use crate::bytecode::OpCodes;
use crate::error::*; use crate::error::*;
use crate::vm::{Env, VM}; use crate::vm::{LetEnv, VM};
mod attrset; mod attrset;
mod func; mod func;
@@ -157,7 +157,11 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
pub fn typename(&self) -> &'static str { pub fn typename(&self) -> &'static str {
use Value::*; use Value::*;
match self { match self {
Const(_) => unreachable!(), Const(self::Const::Int(_)) => "int",
Const(self::Const::Float(_)) => "float",
Const(self::Const::Bool(_)) => "bool",
Const(self::Const::String(_)) => "string",
Const(self::Const::Null) => "null",
Thunk(_) => "thunk", Thunk(_) => "thunk",
ThunkRef(_) => "thunk", ThunkRef(_) => "thunk",
AttrSet(_) => "set", AttrSet(_) => "set",
@@ -481,7 +485,7 @@ pub struct Thunk<'jit, 'vm> {
#[derive(Debug, IsVariant, Unwrap, Clone)] #[derive(Debug, IsVariant, Unwrap, Clone)]
pub enum _Thunk<'jit, 'vm> { pub enum _Thunk<'jit, 'vm> {
Code(&'vm OpCodes, OnceCell<Env<'jit, 'vm>>), Code(&'vm OpCodes, OnceCell<LetEnv<'jit, 'vm>>),
SuspendedFrom(*const Thunk<'jit, 'vm>), SuspendedFrom(*const Thunk<'jit, 'vm>),
Value(Value<'jit, 'vm>), Value(Value<'jit, 'vm>),
} }
@@ -493,7 +497,7 @@ impl<'jit, 'vm> Thunk<'jit, 'vm> {
} }
} }
pub fn capture(&self, env: Env<'jit, 'vm>) { pub fn capture(&self, env: LetEnv<'jit, 'vm>) {
if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() { if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() {
envcell.get_or_init(|| env); envcell.get_or_init(|| env);
} }

View File

@@ -4,14 +4,20 @@ use std::rc::Rc;
use crate::ty::internal::{AttrSet, Value}; use crate::ty::internal::{AttrSet, Value};
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub struct Env<'jit, 'vm> { pub struct LetEnv<'jit, 'vm> {
pub map: Rc<HashMap<usize, Value<'jit, 'vm>>>, map: Rc<HashMap<usize, Value<'jit, 'vm>>>,
pub last: Option<Rc<Env<'jit, 'vm>>>, last: Option<Rc<LetEnv<'jit, 'vm>>>,
} }
impl<'jit, 'vm> Env<'jit, 'vm> { #[derive(Debug, Default, Clone)]
pub fn empty() -> Self { pub struct WithEnv<'jit, 'vm> {
Env::default() map: Rc<AttrSet<'jit, 'vm>>,
last: Option<Rc<WithEnv<'jit, 'vm>>>,
}
impl<'jit, 'vm> LetEnv<'jit, 'vm> {
pub fn new(map: Rc<HashMap<usize, Value<'jit, 'vm>>>) -> Self {
Self { map, last: None }
} }
pub fn lookup(&self, symbol: usize) -> Option<Value<'jit, 'vm>> { pub fn lookup(&self, symbol: usize) -> Option<Value<'jit, 'vm>> {
@@ -21,25 +27,15 @@ impl<'jit, 'vm> Env<'jit, 'vm> {
self.last.as_ref().map(|env| env.lookup(symbol)).flatten() self.last.as_ref().map(|env| env.lookup(symbol)).flatten()
} }
pub fn insert(&mut self, symbol: usize, value: Value<'jit, 'vm>) { pub fn enter_let(self, new: impl Iterator<Item = (usize, Value<'jit, 'vm>)>) -> Self {
Rc::make_mut(&mut self.map).insert(symbol, value);
}
pub fn enter(self, new: impl Iterator<Item = (usize, Value<'jit, 'vm>)>) -> Self {
let map = Rc::new(new.collect()); let map = Rc::new(new.collect());
let last = Some( let last = Some(self.into());
Env { LetEnv { last, map }
last: self.last,
map: self.map,
}
.into(),
);
Env { last, map }
} }
pub fn enter_with(self, new: Rc<AttrSet<'jit, 'vm>>) -> Self { pub fn enter_with(self, new: Rc<AttrSet<'jit, 'vm>>) -> Self {
let map = Rc::new( let map = new
new.as_inner() .as_inner()
.iter() .iter()
.map(|(&k, v)| { .map(|(&k, v)| {
( (
@@ -51,16 +47,43 @@ impl<'jit, 'vm> Env<'jit, 'vm> {
}, },
) )
}) })
.collect(), .collect::<HashMap<_, _>>()
); .into();
let last = Some( let last = Some(self.into());
Env { LetEnv { last, map }
last: self.last.clone(), }
map: self.map.clone(),
} pub fn leave(self) -> Self {
.into(), self.last.unwrap().as_ref().clone()
); }
Env { last, map } }
impl<'jit, 'vm> WithEnv<'jit, 'vm> {
pub fn lookup(&self, symbol: usize) -> Option<Value<'jit, 'vm>> {
if let Some(val) = self.map.select(symbol) {
return Some(val);
}
self.last.as_ref().map(|env| env.lookup(symbol)).flatten()
}
pub fn enter_with(self, new: Rc<AttrSet<'jit, 'vm>>) -> Self {
let map = Rc::new(new
.as_inner()
.iter()
.map(|(&k, v)| {
(
k,
if let Value::Builtins(weak) = v {
Value::AttrSet(weak.upgrade().unwrap())
} else {
v.clone()
},
)
})
.collect::<HashMap<_, _>>()
.into());
let last = Some(self.into());
WithEnv { last, map }
} }
pub fn leave(self) -> Self { pub fn leave(self) -> Self {

View File

@@ -13,7 +13,7 @@ use crate::ty::public::{self as p, Symbol};
use derive_more::Constructor; use derive_more::Constructor;
use ecow::EcoString; use ecow::EcoString;
pub use env::Env; pub use env::LetEnv;
mod env; mod env;
@@ -82,7 +82,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
pub fn eval( pub fn eval(
&'vm self, &'vm self,
opcodes: impl Iterator<Item = OpCode>, opcodes: impl Iterator<Item = OpCode>,
mut env: Env<'jit, 'vm>, mut env: LetEnv<'jit, 'vm>,
) -> Result<Value<'jit, 'vm>> { ) -> Result<Value<'jit, 'vm>> {
let mut stack = Stack::<_, STACK_SIZE>::new(); let mut stack = Stack::<_, STACK_SIZE>::new();
let mut iter = opcodes.into_iter(); let mut iter = opcodes.into_iter();
@@ -103,7 +103,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
&'vm self, &'vm self,
opcode: OpCode, opcode: OpCode,
stack: &'s mut Stack<Value<'jit, 'vm>, CAP>, stack: &'s mut Stack<Value<'jit, 'vm>, CAP>,
env: &mut Env<'jit, 'vm>, env: &mut LetEnv<'jit, 'vm>,
) -> Result<usize> { ) -> Result<usize> {
match opcode { match opcode {
OpCode::Illegal => panic!("illegal opcode"), OpCode::Illegal => panic!("illegal opcode"),
@@ -121,11 +121,6 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
stack.tos_mut()?.force(self)?; stack.tos_mut()?.force(self)?;
} }
OpCode::Jmp { step } => return Ok(step), OpCode::Jmp { step } => return Ok(step),
OpCode::JmpIfTrue { step } => {
if let Value::Const(Const::Bool(true)) = stack.pop() {
return Ok(step);
}
}
OpCode::JmpIfFalse { step } => { OpCode::JmpIfFalse { step } => {
if let Value::Const(Const::Bool(false)) = stack.pop() { if let Value::Const(Const::Bool(false)) = stack.pop() {
return Ok(step); return Ok(step);
@@ -193,7 +188,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
stack.push(Value::AttrSet(AttrSet::with_capacity(cap).into()))?; stack.push(Value::AttrSet(AttrSet::with_capacity(cap).into()))?;
} }
OpCode::FinalizeRec => { OpCode::FinalizeRec => {
let env = env.clone().enter( let env = env.clone().enter_let(
stack stack
.tos()? .tos()?
.clone() .clone()