From d0298ce2a6013b157eb408679b3798b3ed8f2d81 Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Tue, 20 May 2025 09:47:30 +0800 Subject: [PATCH] optimize(env): single arg --- src/builtins/mod.rs | 2 +- src/jit/helpers.rs | 8 +++++- src/jit/test.rs | 2 +- src/ty/internal/attrset.rs | 4 ++- src/ty/internal/func.rs | 24 ++++++++++------ src/ty/internal/mod.rs | 4 +-- src/vm/env.rs | 58 +++++++++++++++++++++++++++----------- src/vm/mod.rs | 12 ++++---- 8 files changed, 76 insertions(+), 38 deletions(-) diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index 2fce24a..58c2fbb 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -60,5 +60,5 @@ pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> LetEnv<'jit, 'vm> { let builtins = Value::AttrSet(attrs); env_map.insert(sym, builtins); - LetEnv::new(env_map.into()) + LetEnv::new(AttrSet::new(env_map).into()) } diff --git a/src/jit/helpers.rs b/src/jit/helpers.rs index 531d57d..e3fbf4b 100644 --- a/src/jit/helpers.rs +++ b/src/jit/helpers.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use inkwell::AddressSpace; use inkwell::context::Context; use inkwell::execution_engine::ExecutionEngine; @@ -205,7 +207,11 @@ extern "C" fn helper_debug(value: JITValue) { #[unsafe(no_mangle)] extern "C" fn helper_capture_env(thunk: JITValue, env: *const LetEnv) { let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) }; - thunk.capture(unsafe { env.as_ref().unwrap() }.clone()); + let env = unsafe { + Rc::from_raw(env) + }; + thunk.capture(env.clone()); + std::mem::forget(env); } #[unsafe(no_mangle)] diff --git a/src/jit/test.rs b/src/jit/test.rs index 34f6f2e..19ed5f0 100644 --- a/src/jit/test.rs +++ b/src/jit/test.rs @@ -32,7 +32,7 @@ fn test_expr(expr: &str, expected: Value) { ); let env = env(&vm); let value = vm - .eval(prog.top_level.into_iter(), env) + .eval(prog.top_level.into_iter(), env.into()) .unwrap() .to_public(&vm, &mut HashSet::new()); assert_eq!(value, expected); diff --git a/src/ty/internal/attrset.rs b/src/ty/internal/attrset.rs index 0687f01..5b8cdc2 100644 --- a/src/ty/internal/attrset.rs +++ b/src/ty/internal/attrset.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use hashbrown::{HashMap, HashSet}; use derive_more::Constructor; @@ -47,7 +49,7 @@ impl<'jit: 'vm, 'vm> AttrSet<'jit, 'vm> { self.data.get(&sym).is_some() } - pub fn capture(&mut self, env: &LetEnv<'jit, 'vm>) { + pub fn capture(&mut self, env: &Rc>) { self.data.iter().for_each(|(_, v)| match v.clone() { Value::Thunk(ref thunk) => { thunk.capture(env.clone()); diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index b4dbb96..2adf433 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -1,6 +1,8 @@ use std::cell::{Cell, OnceCell}; +use std::rc::Rc; use derive_more::Constructor; +use hashbrown::HashMap; use inkwell::execution_engine::JitFunction; use itertools::Itertools; @@ -8,7 +10,7 @@ use crate::bytecode::Func as BFunc; use crate::error::Result; use crate::ir; use crate::jit::JITFunc; -use crate::ty::internal::{Thunk, Value}; +use crate::ty::internal::{AttrSet, Thunk, Value}; use crate::vm::{LetEnv, VM}; #[derive(Debug, Clone)] @@ -44,7 +46,7 @@ impl From for Param { #[derive(Debug, Clone, Constructor)] pub struct Func<'jit: 'vm, 'vm> { pub func: &'vm BFunc, - pub env: LetEnv<'jit, 'vm>, + pub env: Rc>, pub compiled: OnceCell>>, pub count: Cell, } @@ -57,14 +59,14 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { Ident(ident) => self .env .clone() - .enter_let([(ident.into(), arg)].into_iter()), + .enter_arg(ident.into(), arg), Formals { formals, ellipsis, alias, } => { let arg = arg.unwrap_attr_set(); - let mut new = Vec::with_capacity(formals.len() + alias.iter().len()); + let mut new = HashMap::with_capacity(formals.len() + alias.iter().len()); if !ellipsis && arg .as_inner() @@ -83,20 +85,24 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { default.map(|idx| Value::Thunk(Thunk::new(vm.get_thunk(idx)).into())) }) .unwrap(); - new.push((formal, arg)); + new.insert(formal, arg); } if let Some(alias) = alias { - new.push((alias.clone().into(), Value::AttrSet(arg))); + new.insert(alias.clone().into(), Value::AttrSet(arg)); } - self.env.clone().enter_let(new.into_iter()) + self.env.clone().enter_attrs(AttrSet::new(new).into()) } - }; + }.into(); let count = self.count.get(); self.count.replace(count + 1); if count >= 1 { let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func)); - let ret = unsafe { compiled.call(vm as *const VM, &env as *const LetEnv) }; + let env = Rc::into_raw(env); + let ret = unsafe { compiled.call(vm as *const VM, env) }; + unsafe { + Rc::decrement_strong_count(env); + } return Ok(ret.into()); } vm.eval(self.func.opcodes.iter().copied(), env) diff --git a/src/ty/internal/mod.rs b/src/ty/internal/mod.rs index 2c6ff42..42f6851 100644 --- a/src/ty/internal/mod.rs +++ b/src/ty/internal/mod.rs @@ -485,7 +485,7 @@ pub struct Thunk<'jit, 'vm> { #[derive(Debug, IsVariant, Unwrap, Clone)] pub enum _Thunk<'jit, 'vm> { - Code(&'vm OpCodes, OnceCell>), + Code(&'vm OpCodes, OnceCell>>), SuspendedFrom(*const Thunk<'jit, 'vm>), Value(Value<'jit, 'vm>), } @@ -497,7 +497,7 @@ impl<'jit, 'vm> Thunk<'jit, 'vm> { } } - pub fn capture(&self, env: LetEnv<'jit, 'vm>) { + pub fn capture(&self, env: Rc>) { if let _Thunk::Code(_, envcell) = &*self.thunk.borrow() { envcell.get_or_init(|| env); } diff --git a/src/vm/env.rs b/src/vm/env.rs index 9df12c4..527b0a7 100644 --- a/src/vm/env.rs +++ b/src/vm/env.rs @@ -3,9 +3,9 @@ use std::rc::Rc; use crate::ty::internal::{AttrSet, Value}; -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] pub struct LetEnv<'jit, 'vm> { - map: Rc>>, + map: Env<'jit, 'vm>, last: Option>>, } @@ -15,25 +15,51 @@ pub struct WithEnv<'jit, 'vm> { last: Option>>, } +#[derive(Debug, Clone)] +enum Env<'jit, 'vm> { + Let(Rc>), + SingleArg(usize, Value<'jit, 'vm>), + MultiArg(Rc>), +} + +#[derive(Debug, Clone, Copy)] +pub enum Type { + Arg, + Let, + With +} + impl<'jit, 'vm> LetEnv<'jit, 'vm> { - pub fn new(map: Rc>>) -> Self { - Self { map, last: None } + pub fn new(map: Rc>) -> Self { + Self { map: Env::Let(map), last: None } } pub fn lookup(&self, symbol: usize) -> Option> { - if let Some(val) = self.map.get(&symbol).cloned() { - return Some(val); + use Env::*; + match &self.map { + Let(map) | MultiArg(map) => if let Some(val) = map.select(symbol) { + return Some(val) + } + SingleArg(sym, val) => if *sym == symbol { + return Some(val.clone()) + } } self.last.as_ref().map(|env| env.lookup(symbol)).flatten() } - pub fn enter_let(self, new: impl Iterator)>) -> Self { - let map = Rc::new(new.collect()); - let last = Some(self.into()); - LetEnv { last, map } + pub fn enter_arg(self: Rc, ident: usize, val: Value<'jit, 'vm>) -> Rc { + let last = Some(self); + let map = Env::SingleArg(ident, val); + LetEnv { last, map }.into() } - pub fn enter_with(self, new: Rc>) -> Self { + pub fn enter_attrs(self: Rc, map: Rc>) -> Rc { + let last = Some(self); + let map = Env::Let(map); + LetEnv { last, map }.into() + } + + pub fn enter_with(self: Rc, new: Rc>) -> Rc { let map = new .as_inner() .iter() @@ -47,14 +73,14 @@ impl<'jit, 'vm> LetEnv<'jit, 'vm> { }, ) }) - .collect::>() - .into(); + .collect::>(); + let map = Env::Let(AttrSet::new(map).into()); let last = Some(self.into()); - LetEnv { last, map } + LetEnv { last, map }.into() } - pub fn leave(self) -> Self { - self.last.unwrap().as_ref().clone() + pub fn leave(self: Rc) -> Rc { + self.last.clone().unwrap() } } diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 2fab1e0..a7fcd46 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -1,6 +1,7 @@ use hashbrown::{HashMap, HashSet}; use inkwell::execution_engine::JitFunction; use std::cell::{Cell, OnceCell, RefCell}; +use std::rc::Rc; use crate::builtins::env; use crate::bytecode::{BinOp, Func as F, OpCode, OpCodes, Program, UnOp}; @@ -34,7 +35,7 @@ pub fn run(prog: Program, jit: JITContext<'_>) -> Result { let env = env(&vm); let mut seen = HashSet::new(); let value = vm - .eval(prog.top_level.into_iter(), env)? + .eval(prog.top_level.into_iter(), env.into())? .to_public(&vm, &mut seen); Ok(value) } @@ -82,7 +83,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { pub fn eval( &'vm self, opcodes: impl Iterator, - mut env: LetEnv<'jit, 'vm>, + mut env: Rc>, ) -> Result> { let mut stack = Stack::<_, STACK_SIZE>::new(); let mut iter = opcodes.into_iter(); @@ -103,7 +104,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { &'vm self, opcode: OpCode, stack: &'s mut Stack, CAP>, - env: &mut LetEnv<'jit, 'vm>, + env: &mut Rc>, ) -> Result { match opcode { OpCode::Illegal => panic!("illegal opcode"), @@ -188,14 +189,11 @@ impl<'vm, 'jit: 'vm> VM<'jit> { stack.push(Value::AttrSet(AttrSet::with_capacity(cap).into()))?; } OpCode::FinalizeRec => { - let env = env.clone().enter_let( + let env = env.clone().enter_attrs( stack .tos()? .clone() .unwrap_attr_set() - .as_inner() - .iter() - .map(|(k, v)| (k.clone(), v.clone())), ); stack.tos_mut()?.as_mut().unwrap_attr_set().capture(&env); }