diff --git a/src/jit/helpers.rs b/src/jit/helpers.rs index eaa93c6..3e7c196 100644 --- a/src/jit/helpers.rs +++ b/src/jit/helpers.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use inkwell::AddressSpace; use inkwell::context::Context; use inkwell::execution_engine::ExecutionEngine; @@ -21,10 +23,12 @@ pub struct Helpers<'ctx> { pub func_type: FunctionType<'ctx>, pub debug: FunctionValue<'ctx>, - pub captured_env: FunctionValue<'ctx>, + pub capture_env: FunctionValue<'ctx>, pub neg: FunctionValue<'ctx>, pub add: FunctionValue<'ctx>, + pub sub: FunctionValue<'ctx>, pub call: FunctionValue<'ctx>, + pub lookup: FunctionValue<'ctx>, } impl<'ctx> Helpers<'ctx> { @@ -40,7 +44,7 @@ impl<'ctx> Helpers<'ctx> { let ptr_type = context.ptr_type(AddressSpace::default()); let value_type = context.struct_type(&[int_type.into(), int_type.into()], false); let func_type = value_type.fn_type( - &[ptr_type.into(), ptr_type.into(), value_type.into()], + &[ptr_type.into(), ptr_type.into()], false, ); let debug = module.add_function( @@ -50,7 +54,7 @@ impl<'ctx> Helpers<'ctx> { .fn_type(&[value_type.into()], false), None ); - let captured_env = module.add_function( + let capture_env = module.add_function( "capture_env", context .void_type() @@ -67,19 +71,31 @@ impl<'ctx> Helpers<'ctx> { value_type.fn_type(&[value_type.into(), value_type.into()], false), None ); + let sub = module.add_function( + "sub", + value_type.fn_type(&[value_type.into(), value_type.into()], false), + None + ); // Assuming a single argument for now based on the test case let call = module.add_function( "call", value_type.fn_type(&[ptr_type.into(), ptr_type.into(), ptr_type.into(), value_type.into()], false), None ); + let lookup = module.add_function( + "lookup", + value_type.fn_type(&[ptr_int_type.into(), ptr_type.into()], false), + None + ); execution_engine.add_global_mapping(&debug, helper_debug as _); - execution_engine.add_global_mapping(&captured_env, helper_capture_env as _); + execution_engine.add_global_mapping(&capture_env, helper_capture_env as _); execution_engine.add_global_mapping(&neg, helper_neg as _); execution_engine.add_global_mapping(&add, helper_add as _); + execution_engine.add_global_mapping(&sub, helper_sub as _); execution_engine.add_global_mapping(&call, helper_call as _); + execution_engine.add_global_mapping(&lookup, helper_lookup as _); Helpers { @@ -92,10 +108,12 @@ impl<'ctx> Helpers<'ctx> { func_type, debug, - captured_env, + capture_env, neg, add, + sub, call, + lookup, } } @@ -194,9 +212,31 @@ extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue { } #[unsafe(no_mangle)] -extern "C" fn helper_call<'jit, 'vm>(vm: *const VM<'jit>, env: *const Env<'jit, 'vm>, func_ptr: *const (), arg: JITValue) -> JITValue { - let func: JITFunc = unsafe { std::mem::transmute(func_ptr) }; - unsafe { - func(vm, env, arg) +extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue { + use ValueTag::*; + match (lhs.tag, rhs.tag) { + (Int, Int) => JITValue { + tag: Int, + data: JITValueData { + int: unsafe { lhs.data.int - rhs.data.int }, + }, + }, + _ => todo!("Substruction not implemented for {:?} and {:?}", lhs.tag, rhs.tag), } } + +#[unsafe(no_mangle)] +extern "C" fn helper_call<'jit, 'vm>(vm: *const VM<'jit>, env: *const Env<'jit, 'vm>, func_ptr: *const (), arg: JITValue) -> JITValue { + let func: JITFunc = unsafe { std::mem::transmute(func_ptr) }; + todo!(); + unsafe { + func(vm, env) + } +} + +#[unsafe(no_mangle)] +extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const Env<'jit, 'vm>) -> JITValue { + let env = unsafe { env.as_ref() }.unwrap(); + let val = env.lookup(sym); + val.unwrap().into() +} diff --git a/src/jit/mod.rs b/src/jit/mod.rs index ffd1c46..03f81c4 100644 --- a/src/jit/mod.rs +++ b/src/jit/mod.rs @@ -72,7 +72,7 @@ impl From> for JITValue { } } -pub type JITFunc<'jit, 'vm> = unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>, JITValue) -> JITValue; +pub type JITFunc<'jit, 'vm> = unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>) -> JITValue; pub struct JITContext<'ctx> { context: &'ctx Context, @@ -89,8 +89,6 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { let execution_engine = module .create_jit_execution_engine(OptimizationLevel::Default) .unwrap(); - - let helpers = Helpers::new(context, &module, &execution_engine); JITContext { @@ -99,7 +97,6 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { context, module, - helpers } } @@ -114,12 +111,12 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { let func_ = self.module.add_function("nixjit_function", self.helpers.func_type, None); let entry = self.context.append_basic_block(func_, "entry"); self.builder.position_at_end(entry); + let env = func_.get_nth_param(1).unwrap().into_pointer_value(); while let Some(opcode) = iter.next() { - self.single_op(opcode, vm, &mut stack)?; + self.single_op(opcode, vm, env, &mut stack)?; } assert_eq!(stack.len(), 1); let value = stack.pop(); - self.builder.build_direct_call(self.helpers.debug, &[value.into()], "call_debug").unwrap(); self.builder.build_return(Some(&value))?; if func_.verify(false) { func_.print_to_stderr(); @@ -137,6 +134,7 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { &self, opcode: OpCode, vm: &'vm VM<'_>, + env: PointerValue<'ctx>, stack: &mut Stack, CAP>, ) -> Result<()> { match opcode { @@ -151,48 +149,59 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { } } OpCode::LoadThunk { idx } => stack.push(self.helpers.new_thunk(Rc::into_raw(Rc::new(Thunk::new(vm.get_thunk(idx))))))?, - /* OpCode::CaptureEnv => { + OpCode::CaptureEnv => { let thunk = *stack.tos()?; - self.builder.build_direct_call(self.helpers.captured_env, &[thunk.into(), self.new_ptr(env).into()], "call_capture_env")?; - } */ + self.builder.build_direct_call(self.helpers.capture_env, &[thunk.into(), env.into()], "call_capture_env")?; + } OpCode::UnOp { op } => { use UnOp::*; let rhs = stack.pop(); stack.push(match op { - Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), self.new_ptr(std::ptr::null::()).into()], "call_neg")?.try_as_basic_value().left().unwrap(), + Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), env.into()], "call_neg")?.try_as_basic_value().left().unwrap(), _ => todo!() })? } - /* OpCode::Func { idx } => { + OpCode::Func { idx } => { let func = vm.get_func(idx); - let jit_func_ptr = self.compile_function(&Func::new(func, unsafe { env.as_ref() }.unwrap().clone(), OnceCell::new(), Cell::new(1)), vm)?; + let jit_func_ptr = self.compile_function(func, vm)?; let jit_value = self.helpers.value_type.const_named_struct(&[ self.helpers.int_type.const_int(ValueTag::Function as _, false).into(), - self.helpers.ptr_int_type.const_int(jit_func_ptr as *const _ as _, false).into(), + self.helpers.ptr_int_type.const_int(unsafe { jit_func_ptr.as_raw() } as _, false).into(), ]).into(); stack.push(jit_value)?; - } */ + } OpCode::Call { arity } => { // Assuming arity is 1 for the test case assert_eq!(arity, 1); let arg = stack.pop(); let func_value = stack.pop(); let func_ptr = self.builder.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")?.into_pointer_value(); - let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), self.new_ptr(std::ptr::null::()).into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap(); + let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), env.into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap(); stack.push(result)?; } OpCode::BinOp { op } => { use crate::bytecode::BinOp; + let rhs = stack.pop(); + let lhs = stack.pop(); match op { BinOp::Add => { - let rhs = stack.pop(); - let lhs = stack.pop(); + let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap(); + stack.push(result)?; + } + BinOp::Sub => { + let result = self.builder.build_direct_call(self.helpers.sub, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap(); + stack.push(result)?; + } + BinOp::Eq => { let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap(); stack.push(result)?; } _ => todo!("BinOp::{:?} not implemented in JIT", op), } } + OpCode::LookUp { sym } => { + stack.push(self.builder.build_direct_call(self.helpers.lookup, &[self.helpers.ptr_int_type.const_int(sym as u64, false).into(), env.into()], "call_lookup").unwrap().try_as_basic_value().left().unwrap())? + } _ => todo!("{opcode:?} not implemented") } Ok(()) diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index 9998b5a..a3835eb 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -53,17 +53,8 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { pub fn call(&self, vm: &'vm VM<'jit>, arg: Value<'jit, 'vm>) -> Result> { use Param::*; - let count = self.count.get(); - if count >= 1 { - let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func)); - let ret = unsafe { compiled.call(vm as _, &self.env as _, arg.into()) }; - return Ok(ret.into()) - } - self.count.replace(count + 1); - let mut env = self.env.clone(); - - match self.func.param.clone() { - Ident(ident) => env = env.enter([(ident.into(), arg)].into_iter()), + let env = match self.func.param.clone() { + Ident(ident) => self.env.clone().enter([(ident.into(), arg)].into_iter()), Formals { formals, ellipsis, @@ -94,10 +85,17 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> { if let Some(alias) = alias { new.push((alias.clone().into(), Value::AttrSet(arg))); } - env = env.enter(new.into_iter()); + self.env.clone().enter(new.into_iter()) } - } + }; + let count = self.count.get(); + self.count.replace(count + 1); + if count >= 1 { + let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func)); + let ret = unsafe { compiled.call(vm as *const VM, &env as *const Env) }; + return Ok(ret.into()) + } vm.eval(self.func.opcodes.iter().copied(), env) } } diff --git a/src/vm/env.rs b/src/vm/env.rs index 6f711b8..8995f3a 100644 --- a/src/vm/env.rs +++ b/src/vm/env.rs @@ -5,8 +5,8 @@ use crate::ty::internal::{AttrSet, Value}; #[derive(Debug, Default, Clone)] pub struct Env<'jit, 'vm> { - last: Option>>, - map: Rc>>, + pub map: Rc>>, + pub last: Option>>, } impl<'jit, 'vm> Env<'jit, 'vm> {