diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index 4bcf81e..2fce24a 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -3,12 +3,12 @@ use std::rc::Rc; use crate::ty::common::Const; use crate::ty::internal::{AttrSet, PrimOp, Value}; -use crate::vm::{Env, VM}; +use crate::vm::{LetEnv, VM}; -pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> Env<'jit, 'vm> { - let mut env = Env::empty(); - env.insert(vm.new_sym("true"), Value::Const(Const::Bool(true))); - env.insert(vm.new_sym("false"), Value::Const(Const::Bool(false))); +pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> LetEnv<'jit, 'vm> { + let mut env_map = HashMap::new(); + env_map.insert(vm.new_sym("true"), Value::Const(Const::Bool(true))); + env_map.insert(vm.new_sym("false"), Value::Const(Const::Bool(false))); let primops = [ PrimOp::new("add", 2, |_, args| { @@ -46,18 +46,19 @@ pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> Env<'jit, 'vm> { let mut map = HashMap::new(); for primop in primops { let primop = Rc::new(primop); - env.insert( + env_map.insert( vm.new_sym(format!("__{}", primop.name)), Value::PrimOp(primop.clone()), ); map.insert(vm.new_sym(primop.name), Value::PrimOp(primop)); } + let sym = vm.new_sym("builtins"); let attrs = Rc::new_cyclic(|weak| { - map.insert(vm.new_sym("builtins"), Value::Builtins(weak.clone())); + map.insert(sym, Value::Builtins(weak.clone())); AttrSet::from_inner(map) }); let builtins = Value::AttrSet(attrs); - env.insert(vm.new_sym("builtins"), builtins); - env + env_map.insert(sym, builtins); + LetEnv::new(env_map.into()) } diff --git a/src/bytecode.rs b/src/bytecode.rs index f156e22..537e461 100644 --- a/src/bytecode.rs +++ b/src/bytecode.rs @@ -33,8 +33,6 @@ pub enum OpCode { /// jump forward Jmp { step: usize }, - /// [ .. cond ] consume 1 element, if `cond`` is true, then jump forward - JmpIfTrue { step: usize }, /// [ .. cond ] consume 1 element, if `cond` is false, then jump forward JmpIfFalse { step: usize }, diff --git a/src/jit/helpers.rs b/src/jit/helpers.rs index 4ece819..10d2bf9 100644 --- a/src/jit/helpers.rs +++ b/src/jit/helpers.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use inkwell::AddressSpace; use inkwell::context::Context; use inkwell::execution_engine::ExecutionEngine; @@ -8,8 +6,8 @@ use inkwell::types::{FloatType, FunctionType, IntType, PointerType, StructType}; use inkwell::values::{BasicValueEnum, FunctionValue}; use crate::jit::JITValueData; -use crate::ty::internal::Thunk; -use crate::vm::{Env, VM}; +use crate::ty::internal::{Thunk, Value}; +use crate::vm::{LetEnv, VM}; use super::{JITFunc, JITValue, ValueTag}; @@ -25,10 +23,14 @@ pub struct Helpers<'ctx> { pub debug: FunctionValue<'ctx>, pub capture_env: FunctionValue<'ctx>, pub neg: FunctionValue<'ctx>, + pub not: FunctionValue<'ctx>, pub add: FunctionValue<'ctx>, pub sub: FunctionValue<'ctx>, + pub eq: FunctionValue<'ctx>, + pub or: FunctionValue<'ctx>, pub call: FunctionValue<'ctx>, pub lookup: FunctionValue<'ctx>, + pub force: FunctionValue<'ctx>, } impl<'ctx> Helpers<'ctx> { @@ -61,6 +63,11 @@ impl<'ctx> Helpers<'ctx> { value_type.fn_type(&[value_type.into(), ptr_type.into()], false), None, ); + let not = module.add_function( + "not", + value_type.fn_type(&[value_type.into(), ptr_type.into()], false), + None, + ); let add = module.add_function( "add", value_type.fn_type(&[value_type.into(), value_type.into()], false), @@ -71,15 +78,23 @@ impl<'ctx> Helpers<'ctx> { value_type.fn_type(&[value_type.into(), value_type.into()], false), None, ); - // Assuming a single argument for now based on the test case + let eq = module.add_function( + "eq", + value_type.fn_type(&[value_type.into(), value_type.into()], false), + None, + ); + let or = module.add_function( + "or", + value_type.fn_type(&[value_type.into(), value_type.into()], false), + None, + ); let call = module.add_function( "call", value_type.fn_type( &[ - ptr_type.into(), - ptr_type.into(), - ptr_type.into(), value_type.into(), + value_type.into(), + ptr_type.into(), ], false, ), @@ -90,14 +105,23 @@ impl<'ctx> Helpers<'ctx> { value_type.fn_type(&[ptr_int_type.into(), ptr_type.into()], false), None, ); + let force = module.add_function( + "force", + value_type.fn_type(&[value_type.into(), ptr_type.into()], false), + None, + ); execution_engine.add_global_mapping(&debug, helper_debug as _); execution_engine.add_global_mapping(&capture_env, helper_capture_env as _); execution_engine.add_global_mapping(&neg, helper_neg as _); + execution_engine.add_global_mapping(¬, helper_not as _); execution_engine.add_global_mapping(&add, helper_add as _); execution_engine.add_global_mapping(&sub, helper_sub as _); + execution_engine.add_global_mapping(&eq, helper_eq as _); + execution_engine.add_global_mapping(&or, helper_or as _); execution_engine.add_global_mapping(&call, helper_call as _); execution_engine.add_global_mapping(&lookup, helper_lookup as _); + execution_engine.add_global_mapping(&force, helper_force as _); Helpers { int_type, @@ -111,10 +135,14 @@ impl<'ctx> Helpers<'ctx> { debug, capture_env, neg, + not, add, sub, + eq, + or, call, lookup, + force } } @@ -179,13 +207,13 @@ extern "C" fn helper_debug(value: JITValue) { } #[unsafe(no_mangle)] -extern "C" fn helper_capture_env(thunk: JITValue, env: *const Env) { +extern "C" fn helper_capture_env(thunk: JITValue, env: *const LetEnv) { let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) }; thunk.capture(unsafe { env.as_ref().unwrap() }.clone()); } #[unsafe(no_mangle)] -extern "C" fn helper_neg(rhs: JITValue, _env: *const Env) -> JITValue { +extern "C" fn helper_neg(rhs: JITValue, _env: *const LetEnv) -> JITValue { use ValueTag::*; match rhs.tag { Int => JITValue { @@ -198,6 +226,20 @@ extern "C" fn helper_neg(rhs: JITValue, _env: *const Env) -> JITValue { } } +#[unsafe(no_mangle)] +extern "C" fn helper_not(rhs: JITValue, _env: *const LetEnv) -> JITValue { + use ValueTag::*; + match rhs.tag { + Bool => JITValue { + tag: Bool, + data: JITValueData { + bool: !unsafe { rhs.data.bool }, + }, + }, + _ => todo!(), + } +} + #[unsafe(no_mangle)] extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue { use ValueTag::*; @@ -235,20 +277,68 @@ extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue { } #[unsafe(no_mangle)] -extern "C" fn helper_call<'jit, 'vm>( - vm: *const VM<'jit>, - env: *const Env<'jit, 'vm>, - func_ptr: *const (), - arg: JITValue, -) -> JITValue { - let func: JITFunc = unsafe { std::mem::transmute(func_ptr) }; - todo!(); - unsafe { func(vm, env) } +extern "C" fn helper_eq(lhs: JITValue, rhs: JITValue) -> JITValue { + use ValueTag::*; + match (lhs.tag, rhs.tag) { + (Int, Int) => JITValue { + tag: Bool, + data: JITValueData { + bool: unsafe { lhs.data.int == rhs.data.int } + }, + }, + _ => todo!( + "Equation not implemented for {:?} and {:?}", + lhs.tag, + rhs.tag + ), + } } #[unsafe(no_mangle)] -extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const Env<'jit, 'vm>) -> JITValue { +extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue { + use ValueTag::*; + match (lhs.tag, rhs.tag) { + (Bool, Bool) => JITValue { + tag: Bool, + data: JITValueData { + bool: unsafe { lhs.data.bool || rhs.data.bool }, + }, + }, + _ => todo!( + "Substruction not implemented for {:?} and {:?}", + lhs.tag, + rhs.tag + ), + } +} + +#[unsafe(no_mangle)] +extern "C" fn helper_call<'jit, 'vm>( + func: JITValue, + arg: JITValue, + vm: *const VM<'jit>, +) -> JITValue { + use ValueTag::*; + match func.tag { + Function => { + let func: Value = func.into(); + func.call(unsafe { vm.as_ref() }.unwrap(), vec![arg.into()]).unwrap().into() + } + _ => todo!(), + } +} + +#[unsafe(no_mangle)] +extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const LetEnv<'jit, 'vm>) -> JITValue { let env = unsafe { env.as_ref() }.unwrap(); let val = env.lookup(sym); + dbg!(val.as_ref().unwrap().typename()); val.unwrap().into() } + +#[unsafe(no_mangle)] +extern "C" fn helper_force<'jit, 'vm>(thunk: JITValue, vm: *const VM<'jit>) -> JITValue { + let mut val = Value::from(thunk); + val.force(unsafe { vm.as_ref() }.unwrap()).unwrap(); + val.into() +} diff --git a/src/jit/mod.rs b/src/jit/mod.rs index 8fdf63d..4a1a000 100644 --- a/src/jit/mod.rs +++ b/src/jit/mod.rs @@ -1,18 +1,19 @@ use std::rc::Rc; use inkwell::OptimizationLevel; +use inkwell::basic_block::BasicBlock; use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::execution_engine::{ExecutionEngine, JitFunction}; use inkwell::module::Module; -use inkwell::values::{BasicValueEnum, PointerValue}; +use inkwell::values::{BasicValueEnum, FunctionValue, PointerValue}; use crate::bytecode::{Func, OpCode, UnOp}; use crate::error::*; use crate::stack::Stack; use crate::ty::common::Const; use crate::ty::internal::{Thunk, Value}; -use crate::vm::{Env, VM}; +use crate::vm::{LetEnv, VM}; mod helpers; @@ -48,17 +49,19 @@ pub struct JITValue { pub union JITValueData { int: i64, float: f64, - boolean: bool, + bool: bool, ptr: *const (), } -impl<'jit: 'vm, 'vm> Into> for JITValue { - fn into(self) -> Value<'jit, 'vm> { +impl<'jit: 'vm, 'vm> From for Value<'jit, 'vm> { + fn from(value: JITValue) -> Self { use ValueTag::*; - match self.tag { - Int => Value::Const(Const::Int(unsafe { self.data.int })), + match value.tag { + Int => Value::Const(Const::Int(unsafe { value.data.int })), Null => Value::Const(Const::Null), - _ => todo!("not implemented for {:?}", self.tag), + Function => Value::Func(unsafe { Rc::from_raw(value.data.ptr as *const _) }), + Thunk => Value::Thunk(unsafe { Rc::from_raw(value.data.ptr as *const _) }), + _ => todo!("not implemented for {:?}", value.tag), } } } @@ -70,13 +73,25 @@ impl From> for JITValue { tag: ValueTag::Int, data: JITValueData { int }, }, + Value::Func(func) => JITValue { + tag: ValueTag::Function, + data: JITValueData { + ptr: Rc::into_raw(func) as *const _ + } + }, + Value::Thunk(thunk) => JITValue { + tag: ValueTag::Thunk, + data: JITValueData { + ptr: Rc::into_raw(thunk) as *const _ + } + }, _ => todo!(), } } } pub type JITFunc<'jit, 'vm> = - unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>) -> JITValue; + unsafe extern "C" fn(*const VM<'jit>, *const LetEnv<'jit, 'vm>) -> JITValue; pub struct JITContext<'ctx> { context: &'ctx Context, @@ -125,17 +140,25 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { let func_ = self .module .add_function("nixjit_function", self.helpers.func_type, None); - let entry = self.context.append_basic_block(func_, "entry"); - self.builder.position_at_end(entry); let env = func_.get_nth_param(1).unwrap().into_pointer_value(); - while let Some(opcode) = iter.next() { - self.single_op(opcode, vm, env, &mut stack)?; - } + let entry = self.context.append_basic_block(func_, "entry"); + self.build_expr( + &mut iter, + vm, + env, + &mut stack, + func_, + entry, + func.opcodes.len(), + )?; + assert_eq!(stack.len(), 1); let value = stack.pop(); + let exit = self.context.append_basic_block(func_, "exit"); + self.builder.build_unconditional_branch(exit)?; + self.builder.position_at_end(exit); self.builder.build_return(Some(&value))?; - if func_.verify(false) { - func_.print_to_stderr(); + if func_.verify(true) { unsafe { let name = func_.get_name().to_str().unwrap(); let addr = self.execution_engine.get_function(name).unwrap(); @@ -146,13 +169,82 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { } } + fn build_expr( + &self, + iter: &mut impl Iterator, + vm: &'vm VM<'_>, + env: PointerValue<'ctx>, + stack: &mut Stack, CAP>, + func: FunctionValue<'ctx>, + bb: BasicBlock<'ctx>, + mut length: usize, + ) -> Result { + self.builder.position_at_end(bb); + while length > 1 { + let opcode = iter.next().unwrap(); + let br = self.single_op(opcode, vm, env, stack)?; + length -= 1; + if br > 0 { + let consq = self.context.append_basic_block(func, "consq"); + let alter = self.context.append_basic_block(func, "alter"); + let cont = self.context.append_basic_block(func, "cont"); + let cond = self + .builder + .build_alloca(self.helpers.value_type, "cond_alloca")?; + let result = self + .builder + .build_alloca(self.helpers.value_type, "result_alloca")?; + self.builder.build_store(cond, stack.pop())?; + self.builder.build_conditional_branch( + self.builder + .build_load( + self.context.bool_type(), + self.builder.build_struct_gep( + self.helpers.value_type, + cond, + 1, + "gep_cond", + )?, + "load_cond", + )? + .into_int_value(), + consq, + alter, + )?; + + length -= br; + let br = self.build_expr(iter, vm, env, stack, func, consq, br)?; + self.builder.build_store(result, stack.pop())?; + self.builder.build_unconditional_branch(cont)?; + + length -= br; + self.build_expr(iter, vm, env, stack, func, alter, br)?; + self.builder.build_store(result, stack.pop())?; + self.builder.build_unconditional_branch(cont)?; + + self.builder.position_at_end(cont); + stack.push(self.builder.build_load( + self.helpers.value_type, + result, + "load_result", + )?)?; + } + } + if length > 0 { + self.single_op(iter.next().unwrap(), vm, env, stack) + } else { + Ok(0) + } + } + + #[inline(always)] fn single_op( &self, opcode: OpCode, vm: &'vm VM<'_>, env: PointerValue<'ctx>, stack: &mut Stack, CAP>, - ) -> Result<()> { + ) -> Result { match opcode { OpCode::Const { idx } => { use Const::*; @@ -186,54 +278,14 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { .try_as_basic_value() .left() .unwrap(), - _ => todo!(), + Not => self + .builder + .build_direct_call(self.helpers.not, &[rhs.into(), env.into()], "call_neg")? + .try_as_basic_value() + .left() + .unwrap(), })? } - OpCode::Func { idx } => { - let func = vm.get_func(idx); - let jit_func_ptr = self.compile_function(func, vm)?; - let jit_value = self - .helpers - .value_type - .const_named_struct(&[ - self.helpers - .int_type - .const_int(ValueTag::Function as _, false) - .into(), - self.helpers - .ptr_int_type - .const_int(unsafe { jit_func_ptr.as_raw() } as _, false) - .into(), - ]) - .into(); - stack.push(jit_value)?; - } - OpCode::Call { arity } => { - // Assuming arity is 1 for the test case - assert_eq!(arity, 1); - let arg = stack.pop(); - let func_value = stack.pop(); - let func_ptr = self - .builder - .build_extract_value(func_value.into_struct_value(), 1, "func_ptr")? - .into_pointer_value(); - let result = self - .builder - .build_direct_call( - self.helpers.call, - &[ - self.new_ptr(vm).into(), - env.into(), - func_ptr.into(), - arg.into(), - ], - "call_func", - )? - .try_as_basic_value() - .left() - .unwrap(); - stack.push(result)?; - } OpCode::BinOp { op } => { use crate::bytecode::BinOp; let rhs = stack.pop(); @@ -269,9 +321,22 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { let result = self .builder .build_direct_call( - self.helpers.add, + self.helpers.eq, &[lhs.into(), rhs.into()], - "call_add", + "call_eq", + )? + .try_as_basic_value() + .left() + .unwrap(); + stack.push(result)?; + } + BinOp::Or => { + let result = self + .builder + .build_direct_call( + self.helpers.or, + &[lhs.into(), rhs.into()], + "call_or", )? .try_as_basic_value() .left() @@ -299,8 +364,29 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { .left() .unwrap(), )?, + OpCode::Call { arity } => { + // TODO: + assert_eq!(arity, 1); + let mut args = Vec::with_capacity(arity); + for _ in 0..arity { + args.insert(0, stack.pop()); + } + let func = self.builder + .build_direct_call(self.helpers.force, &[stack.pop().into(), self.new_ptr(vm).into()], "force")? + .try_as_basic_value() + .left() + .unwrap(); + let ret = self.builder + .build_direct_call(self.helpers.call, &[func.into(), args[0].into(), self.new_ptr(vm).into()], "call")? + .try_as_basic_value() + .left() + .unwrap(); + stack.push(ret)?; + } + OpCode::JmpIfFalse { step } => return Ok(step), + OpCode::Jmp { step } => return Ok(step), _ => todo!("{opcode:?} not implemented"), } - Ok(()) + Ok(0) } } diff --git a/src/stack.rs b/src/stack.rs index 675321b..d738a8f 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -45,8 +45,8 @@ impl Stack { } pub fn pop(&mut self) -> T { + let item = self.items.get_mut(self.top - 1).unwrap(); self.top -= 1; - let item = self.items.get_mut(self.top).unwrap(); // SAFETY: `item` at `self.top` was previously written and is initialized. // We replace it with `MaybeUninit::uninit()` and then `assume_init` diff --git a/src/ty/internal/attrset.rs b/src/ty/internal/attrset.rs index 9925727..0687f01 100644 --- a/src/ty/internal/attrset.rs +++ b/src/ty/internal/attrset.rs @@ -4,24 +4,24 @@ use derive_more::Constructor; use itertools::Itertools; use crate::error::Result; -use crate::vm::{Env, VM}; +use crate::vm::{LetEnv, VM}; use super::super::public as p; use super::Value; #[repr(C)] -#[derive(Debug, Constructor, Clone, PartialEq)] +#[derive(Debug, Default, Constructor, Clone, PartialEq)] pub struct AttrSet<'jit: 'vm, 'vm> { data: HashMap>, } -impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { - pub fn empty() -> Self { - AttrSet { - data: HashMap::new(), - } +impl<'jit, 'vm> From>> for AttrSet<'jit, 'vm> { + fn from(data: HashMap>) -> Self { + Self { data } } +} +impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { pub fn with_capacity(cap: usize) -> Self { AttrSet { data: HashMap::with_capacity(cap), @@ -47,7 +47,7 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { self.data.get(&sym).is_some() } - pub fn capture(&mut self, env: &Env<'jit, 'vm>) { + pub fn capture(&mut self, env: &LetEnv<'jit, 'vm>) { self.data.iter().for_each(|(_, v)| match v.clone() { Value::Thunk(ref thunk) => { thunk.capture(env.clone()); diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index ac1ec70..a651ec1 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -9,7 +9,7 @@ use crate::error::Result; use crate::ir; use crate::jit::JITFunc; use crate::ty::internal::{Thunk, Value}; -use crate::vm::{Env, VM}; +use crate::vm::{LetEnv, VM}; #[derive(Debug, Clone)] pub enum Param { @@ -44,7 +44,7 @@ impl From for Param { #[derive(Debug, Clone, Constructor)] pub struct Func<'jit: 'vm, 'vm> { pub func: &'vm BFunc, - pub env: Env<'jit, 'vm>, + pub env: LetEnv<'jit, 'vm>, pub compiled: OnceCell>>, pub count: Cell, } @@ -54,7 +54,7 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { use Param::*; let env = match self.func.param.clone() { - Ident(ident) => self.env.clone().enter([(ident.into(), arg)].into_iter()), + Ident(ident) => self.env.clone().enter_let([(ident.into(), arg)].into_iter()), Formals { formals, ellipsis, @@ -85,7 +85,7 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { if let Some(alias) = alias { new.push((alias.clone().into(), Value::AttrSet(arg))); } - self.env.clone().enter(new.into_iter()) + self.env.clone().enter_let(new.into_iter()) } }; @@ -93,7 +93,7 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { self.count.replace(count + 1); if count >= 1 { let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func)); - let ret = unsafe { compiled.call(vm as *const VM, &env as *const Env) }; + let ret = unsafe { compiled.call(vm as *const VM, &env as *const LetEnv) }; return Ok(ret.into()); } vm.eval(self.func.opcodes.iter().copied(), env) diff --git a/src/ty/internal/mod.rs b/src/ty/internal/mod.rs index 10b3826..2c6ff42 100644 --- a/src/ty/internal/mod.rs +++ b/src/ty/internal/mod.rs @@ -11,7 +11,7 @@ use super::public as p; use crate::bytecode::OpCodes; use crate::error::*; -use crate::vm::{Env, VM}; +use crate::vm::{LetEnv, VM}; mod attrset; mod func; @@ -157,7 +157,11 @@ impl<'jit, 'vm> Value<'jit, 'vm> { pub fn typename(&self) -> &'static str { use Value::*; match self { - Const(_) => unreachable!(), + Const(self::Const::Int(_)) => "int", + Const(self::Const::Float(_)) => "float", + Const(self::Const::Bool(_)) => "bool", + Const(self::Const::String(_)) => "string", + Const(self::Const::Null) => "null", Thunk(_) => "thunk", ThunkRef(_) => "thunk", AttrSet(_) => "set", @@ -481,7 +485,7 @@ pub struct Thunk<'jit, 'vm> { #[derive(Debug, IsVariant, Unwrap, Clone)] pub enum _Thunk<'jit, 'vm> { - Code(&'vm OpCodes, OnceCell>), + Code(&'vm OpCodes, OnceCell>), SuspendedFrom(*const Thunk<'jit, 'vm>), Value(Value<'jit, 'vm>), } @@ -493,7 +497,7 @@ impl<'jit, 'vm> Thunk<'jit, 'vm> { } } - pub fn capture(&self, env: Env<'jit, 'vm>) { + pub fn capture(&self, env: LetEnv<'jit, 'vm>) { if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() { envcell.get_or_init(|| env); } diff --git a/src/vm/env.rs b/src/vm/env.rs index b7b5c65..af24084 100644 --- a/src/vm/env.rs +++ b/src/vm/env.rs @@ -4,14 +4,20 @@ use std::rc::Rc; use crate::ty::internal::{AttrSet, Value}; #[derive(Debug, Default, Clone)] -pub struct Env<'jit, 'vm> { - pub map: Rc>>, - pub last: Option>>, +pub struct LetEnv<'jit, 'vm> { + map: Rc>>, + last: Option>>, } -impl<'jit, 'vm> Env<'jit, 'vm> { - pub fn empty() -> Self { - Env::default() +#[derive(Debug, Default, Clone)] +pub struct WithEnv<'jit, 'vm> { + map: Rc>, + last: Option>>, +} + +impl<'jit, 'vm> LetEnv<'jit, 'vm> { + pub fn new(map: Rc>>) -> Self { + Self { map, last: None } } pub fn lookup(&self, symbol: usize) -> Option> { @@ -21,46 +27,63 @@ impl<'jit, 'vm> Env<'jit, 'vm> { self.last.as_ref().map(|env| env.lookup(symbol)).flatten() } - pub fn insert(&mut self, symbol: usize, value: Value<'jit, 'vm>) { - Rc::make_mut(&mut self.map).insert(symbol, value); - } - - pub fn enter(self, new: impl Iterator)>) -> Self { + pub fn enter_let(self, new: impl Iterator)>) -> Self { let map = Rc::new(new.collect()); - let last = Some( - Env { - last: self.last, - map: self.map, - } - .into(), - ); - Env { last, map } + let last = Some(self.into()); + LetEnv { last, map } } pub fn enter_with(self, new: Rc>) -> Self { - let map = Rc::new( - new.as_inner() - .iter() - .map(|(&k, v)| { - ( - k, - if let Value::Builtins(weak) = v { - Value::AttrSet(weak.upgrade().unwrap()) - } else { - v.clone() - }, - ) - }) - .collect(), - ); - let last = Some( - Env { - last: self.last.clone(), - map: self.map.clone(), - } - .into(), - ); - Env { last, map } + let map = new + .as_inner() + .iter() + .map(|(&k, v)| { + ( + k, + if let Value::Builtins(weak) = v { + Value::AttrSet(weak.upgrade().unwrap()) + } else { + v.clone() + }, + ) + }) + .collect::>() + .into(); + let last = Some(self.into()); + LetEnv { last, map } + } + + pub fn leave(self) -> Self { + self.last.unwrap().as_ref().clone() + } +} + +impl<'jit, 'vm> WithEnv<'jit, 'vm> { + pub fn lookup(&self, symbol: usize) -> Option> { + if let Some(val) = self.map.select(symbol) { + return Some(val); + } + self.last.as_ref().map(|env| env.lookup(symbol)).flatten() + } + + pub fn enter_with(self, new: Rc>) -> Self { + let map = Rc::new(new + .as_inner() + .iter() + .map(|(&k, v)| { + ( + k, + if let Value::Builtins(weak) = v { + Value::AttrSet(weak.upgrade().unwrap()) + } else { + v.clone() + }, + ) + }) + .collect::>() + .into()); + let last = Some(self.into()); + WithEnv { last, map } } pub fn leave(self) -> Self { diff --git a/src/vm/mod.rs b/src/vm/mod.rs index c1f73f9..2fab1e0 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -13,7 +13,7 @@ use crate::ty::public::{self as p, Symbol}; use derive_more::Constructor; use ecow::EcoString; -pub use env::Env; +pub use env::LetEnv; mod env; @@ -82,7 +82,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { pub fn eval( &'vm self, opcodes: impl Iterator, - mut env: Env<'jit, 'vm>, + mut env: LetEnv<'jit, 'vm>, ) -> Result> { let mut stack = Stack::<_, STACK_SIZE>::new(); let mut iter = opcodes.into_iter(); @@ -103,7 +103,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { &'vm self, opcode: OpCode, stack: &'s mut Stack, CAP>, - env: &mut Env<'jit, 'vm>, + env: &mut LetEnv<'jit, 'vm>, ) -> Result { match opcode { OpCode::Illegal => panic!("illegal opcode"), @@ -121,11 +121,6 @@ impl<'vm, 'jit: 'vm> VM<'jit> { stack.tos_mut()?.force(self)?; } OpCode::Jmp { step } => return Ok(step), - OpCode::JmpIfTrue { step } => { - if let Value::Const(Const::Bool(true)) = stack.pop() { - return Ok(step); - } - } OpCode::JmpIfFalse { step } => { if let Value::Const(Const::Bool(false)) = stack.pop() { return Ok(step); @@ -193,7 +188,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { stack.push(Value::AttrSet(AttrSet::with_capacity(cap).into()))?; } OpCode::FinalizeRec => { - let env = env.clone().enter( + let env = env.clone().enter_let( stack .tos()? .clone()