diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index e391317..850e57a 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -5,8 +5,8 @@ use rpds::HashTrieMap; use crate::ty::internal::{AttrSet, Const, PrimOp, Value}; use crate::vm::{Env, VM}; -pub fn env<'vm>(vm: &'vm VM) -> Rc> { - let env = Rc::new(Env::empty()); +pub fn env<'vm>(vm: &'vm VM) -> Env<'vm> { + let mut env = Env::empty(); env.insert(vm.new_sym("true"), Value::Const(Const::Bool(true))); env.insert(vm.new_sym("false"), Value::Const(Const::Bool(false))); diff --git a/src/ty/internal/attrset.rs b/src/ty/internal/attrset.rs index 121ac20..3ec082c 100644 --- a/src/ty/internal/attrset.rs +++ b/src/ty/internal/attrset.rs @@ -1,5 +1,4 @@ use std::collections::HashSet; -use std::rc::Rc; use derive_more::Constructor; use itertools::Itertools; @@ -43,7 +42,7 @@ impl<'vm> AttrSet<'vm> { self.data.get(&sym).is_some() } - pub fn capture(&mut self, env: &Rc>) { + pub fn capture(&mut self, env: &Env<'vm>) { self.data.iter().for_each(|(_, v)| match v.clone() { Value::Thunk(ref thunk) => { thunk.capture(env.clone()); diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index fa003f7..d219a04 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -1,6 +1,3 @@ -use std::collections::HashMap; -use std::rc::Rc; - use derive_more::Constructor; use itertools::Itertools; @@ -46,7 +43,7 @@ pub type JITFunc<'vm> = #[derive(Debug, Clone, Constructor)] pub struct Func<'vm> { pub func: &'vm BFunc, - pub env: Rc>, + pub env: Env<'vm>, pub compiled: Option>, } @@ -54,10 +51,10 @@ impl<'vm> Func<'vm> { pub fn call(&self, vm: &'vm VM<'_>, arg: Value<'vm>) -> Result> { use Param::*; - let env = Rc::new(self.env.as_ref().clone()); + let mut env = self.env.clone(); match self.func.param.clone() { - Ident(ident) => env.enter([(ident.into(), arg)].into_iter()), + Ident(ident) => env = env.enter([(ident.into(), arg)].into_iter()), Formals { formals, ellipsis, @@ -88,7 +85,7 @@ impl<'vm> Func<'vm> { if let Some(alias) = alias { new.push((alias.clone().into(), Value::AttrSet(arg))); } - env.enter(new.into_iter()); + env = env.enter(new.into_iter()); } } diff --git a/src/ty/internal/mod.rs b/src/ty/internal/mod.rs index 98fbe0c..784eb58 100644 --- a/src/ty/internal/mod.rs +++ b/src/ty/internal/mod.rs @@ -466,7 +466,7 @@ pub struct Thunk<'vm> { #[derive(Debug, IsVariant, Unwrap, Clone)] pub enum _Thunk<'vm> { - Code(&'vm OpCodes, OnceCell>>), + Code(&'vm OpCodes, OnceCell>), SuspendedFrom(*const Thunk<'vm>), Value(Value<'vm>), } @@ -478,7 +478,7 @@ impl<'vm> Thunk<'vm> { } } - pub fn capture(&self, env: Rc>) { + pub fn capture(&self, env: Env<'vm>) { if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() { envcell.get_or_init(|| env); } diff --git a/src/vm/env.rs b/src/vm/env.rs index 03f8857..5f241c6 100644 --- a/src/vm/env.rs +++ b/src/vm/env.rs @@ -1,28 +1,13 @@ -use std::cell::RefCell; use std::rc::Rc; use rpds::HashTrieMap; use crate::ty::internal::{AttrSet, Value}; -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct Env<'vm> { - last: RefCell>>>, - map: RefCell>>, -} - -impl Clone for Env<'_> { - fn clone(&self) -> Self { - Env { - last: RefCell::new( - self.last - .borrow() - .clone() - .map(|e| Rc::new(e.as_ref().clone())), - ), - map: RefCell::new(self.map.borrow().clone()), - } - } + last: Option>>, + map: HashTrieMap>, } impl<'vm> Env<'vm> { @@ -31,28 +16,30 @@ impl<'vm> Env<'vm> { } pub fn lookup(&self, symbol: usize) -> Option> { - self.map.borrow().get(&symbol).cloned() + self.map.get(&symbol).cloned() } - pub fn insert(&self, symbol: usize, value: Value<'vm>) { - self.map.borrow_mut().insert_mut(symbol, value); + pub fn insert(&mut self, symbol: usize, value: Value<'vm>) { + self.map.insert_mut(symbol, value); } - pub fn enter(&self, new: impl Iterator)>) { - let mut map = self.map.borrow().clone(); + pub fn enter(self, new: impl Iterator)>) -> Self { + let mut map = self.map.clone(); for (k, v) in new { map.insert_mut(k, v); } - let last = Env { - last: self.last.clone(), - map: self.map.clone(), - }; - *self.last.borrow_mut() = Some(Rc::new(last)); - *self.map.borrow_mut() = map; + let last = Some( + Env { + last: self.last, + map: self.map, + } + .into(), + ); + Env { last, map } } - pub fn enter_with(&self, new: Rc>) { - let mut map = self.map.borrow().clone(); + pub fn enter_with(self, new: Rc>) -> Self { + let mut map = self.map.clone(); for (k, v) in new.as_inner().iter() { let v = if let Value::Builtins = v { Value::AttrSet(new.clone()) @@ -61,18 +48,17 @@ impl<'vm> Env<'vm> { }; map.insert_mut(k.clone(), v); } - let last = Env { - last: self.last.clone(), - map: self.map.clone(), - }; - *self.last.borrow_mut() = Some(Rc::new(last)); - *self.map.borrow_mut() = map; + let last = Some( + Env { + last: self.last.clone(), + map: self.map.clone(), + } + .into(), + ); + Env { last, map } } - pub fn leave(&self) { - let last = (*self.last.borrow_mut()).take().unwrap(); - let _ = std::mem::replace(&mut *self.last.borrow_mut(), last.last.borrow().clone()); - let map = last.map.borrow().clone(); - *self.map.borrow_mut() = map; + pub fn leave(self) -> Self { + self.last.unwrap().as_ref().clone() } } diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 7699de2..ca718a4 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -1,7 +1,6 @@ use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::pin::Pin; -use std::rc::Rc; use crate::builtins::env; use crate::bytecode::{BinOp, Func as F, OpCode, OpCodes, Program, UnOp}; @@ -80,12 +79,12 @@ impl<'vm, 'jit: 'vm> VM<'jit> { pub fn eval( &'vm self, opcodes: impl Iterator, - env: Rc>, + mut env: Env<'vm>, ) -> Result> { let mut stack = Stack::<_, STACK_SIZE>::new(); let mut iter = opcodes.into_iter(); while let Some(opcode) = iter.next() { - let jmp = self.single_op(opcode, &mut stack, &env)?; + let jmp = self.single_op(opcode, &mut stack, &mut env)?; for _ in 0..jmp { iter.next().unwrap(); } @@ -101,7 +100,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { &'vm self, opcode: OpCode, stack: &'s mut Stack, CAP>, - env: &Rc>, + env: &mut Env<'vm>, ) -> Result { match opcode { OpCode::Illegal => panic!("illegal opcode"), @@ -187,7 +186,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { stack.push(Value::AttrSet(AttrSet::empty().into()))?; } OpCode::FinalizeRec => { - env.enter( + let env = env.clone().enter( stack .tos()? .clone() @@ -196,7 +195,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { .iter() .map(|(k, v)| (k.clone(), v.clone())), ); - stack.tos_mut()?.as_mut().unwrap_attr_set().capture(env); + stack.tos_mut()?.as_mut().unwrap_attr_set().capture(&env); } OpCode::PushStaticAttr { name } => { let val = stack.pop(); @@ -252,13 +251,11 @@ impl<'vm, 'jit: 'vm> VM<'jit> { })?)?; } OpCode::EnterEnv => match stack.pop() { - Value::AttrSet(attrs) => env.enter_with(attrs), + Value::AttrSet(attrs) => *env = env.clone().enter_with(attrs), _ => unreachable!(), }, - OpCode::LeaveEnv => { - env.leave(); - } + OpCode::LeaveEnv => *env = env.clone().leave(), OpCode::Assert => { if !stack.pop().unwrap_const().unwrap_bool() { todo!()