diff --git a/.gitignore b/.gitignore index b8a4d91..fcb47cb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ target/ /.direnv/ + +.env diff --git a/src/compile.rs b/src/compile.rs index 203a65d..1dd0df7 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -101,16 +101,18 @@ impl Compile for ir::Attrs { cap: self.stcs.len() + self.dyns.len(), }); for stc in self.stcs { + let thunk = stc.1.is_thunk(); stc.1.compile(comp); - if !self.rec { + if thunk && !self.rec { comp.push(OpCode::CaptureEnv); } comp.push(OpCode::PushStaticAttr { name: stc.0 }); } for dynamic in self.dyns { + let thunk = dynamic.1.is_thunk(); dynamic.0.compile(comp); dynamic.1.compile(comp); - if !self.rec { + if thunk && !self.rec { comp.push(OpCode::CaptureEnv); } comp.push(OpCode::PushDynamicAttr) diff --git a/src/ir.rs b/src/ir.rs index fab1c5a..d50b18a 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1,4 +1,5 @@ use hashbrown::HashMap; +use derive_more::IsVariant; use ecow::EcoString; use rnix::ast::{self, Expr}; @@ -38,7 +39,7 @@ macro_rules! ir { ) ,*$(,)? ) => { - #[derive(Clone, Debug)] + #[derive(Clone, Debug, IsVariant)] pub enum Ir { $( $ty($ty), diff --git a/src/jit/helpers.rs b/src/jit/helpers.rs new file mode 100644 index 0000000..9666fdb --- /dev/null +++ b/src/jit/helpers.rs @@ -0,0 +1,200 @@ +use inkwell::AddressSpace; +use inkwell::context::Context; +use inkwell::execution_engine::ExecutionEngine; +use inkwell::module::{Linkage, Module}; +use inkwell::types::{FloatType, FunctionType, IntType, PointerType, StructType}; +use inkwell::values::{BasicValueEnum, FunctionValue}; + +use crate::jit::JITValueData; +use crate::ty::internal::Thunk; +use crate::vm::{VM, Env}; + +use super::{JITValue, ValueTag, JITFunc}; + +pub struct Helpers<'ctx> { + pub int_type: IntType<'ctx>, + pub float_type: FloatType<'ctx>, + pub bool_type: IntType<'ctx>, + pub ptr_int_type: IntType<'ctx>, + pub ptr_type: PointerType<'ctx>, + pub value_type: StructType<'ctx>, + pub func_type: FunctionType<'ctx>, + + pub debug: FunctionValue<'ctx>, + pub captured_env: FunctionValue<'ctx>, + pub neg: FunctionValue<'ctx>, + pub add: FunctionValue<'ctx>, + pub call: FunctionValue<'ctx>, +} + +impl<'ctx> Helpers<'ctx> { + pub fn new( + context: &'ctx Context, + module: &Module<'ctx>, + execution_engine: &ExecutionEngine<'ctx>, + ) -> Self { + let int_type = context.i64_type(); + let float_type = context.f64_type(); + let bool_type = context.bool_type(); + let ptr_int_type = context.ptr_sized_int_type(execution_engine.get_target_data(), None); + let ptr_type = context.ptr_type(AddressSpace::default()); + let value_type = context.struct_type(&[int_type.into(), int_type.into()], false); + let func_type = value_type.fn_type( + &[ptr_type.into(), ptr_type.into(), value_type.into()], + false, + ); + let debug = module.add_function( + "debug", + context + .void_type() + .fn_type(&[value_type.into()], false), + None + ); + let captured_env = module.add_function( + "capture_env", + context + .void_type() + .fn_type(&[value_type.into(), ptr_type.into()], false), + None + ); + let neg = module.add_function( + "neg", + value_type.fn_type(&[value_type.into(), ptr_type.into()], false), + None + ); + let add = module.add_function( + "add", + value_type.fn_type(&[value_type.into(), value_type.into()], false), + None + ); + // Assuming a single argument for now based on the test case + let call = module.add_function( + "call", + value_type.fn_type(&[ptr_type.into(), ptr_type.into(), ptr_type.into(), value_type.into()], false), + None + ); + + + execution_engine.add_global_mapping(&debug, helper_debug as _); + execution_engine.add_global_mapping(&captured_env, helper_capture_env as _); + execution_engine.add_global_mapping(&neg, helper_neg as _); + execution_engine.add_global_mapping(&add, helper_add as _); + execution_engine.add_global_mapping(&call, helper_call as _); + + + Helpers { + int_type, + float_type, + bool_type, + ptr_int_type, + ptr_type, + value_type, + func_type, + + debug, + captured_env, + neg, + add, + call, + } + } + + pub fn new_int(&self, int: i64) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.int_type.const_int(ValueTag::Int as _, false).into(), + self.int_type.const_int(int as _, false).into(), + ]) + .into() + } + + pub fn new_float(&self, float: f64) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.int_type.const_int(ValueTag::Float as _, false).into(), + self.float_type.const_float(float).into(), + ]) + .into() + } + + pub fn new_bool(&self, bool: bool) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.int_type.const_int(ValueTag::Bool as _, false).into(), + self.bool_type.const_int(bool as _, false).into(), + ]) + .into() + } + + pub fn new_null(&self) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.int_type.const_int(ValueTag::Null as _, false).into(), + self.int_type.const_zero().into(), + ]) + .into() + } + + pub fn const_string(&self, string: *const u8) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.int_type.const_int(ValueTag::String as _, false).into(), + self.ptr_int_type.const_int(string as _, false).into(), + ]) + .into() + } + + pub fn new_thunk(&self, thunk: *const Thunk) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.int_type.const_int(ValueTag::Thunk as _, false).into(), + self.ptr_int_type.const_int(thunk as _, false).into(), + ]) + .into() + } +} + +#[unsafe(no_mangle)] +extern "C" fn helper_debug(value: JITValue) { + dbg!(value.tag); +} + +#[unsafe(no_mangle)] +extern "C" fn helper_capture_env(thunk: JITValue, env: *const Env) { + let thunk: &Thunk = unsafe { std::mem::transmute(thunk.data.ptr.as_ref().unwrap()) }; + thunk.capture(unsafe { env.as_ref().unwrap() }.clone()); +} + +#[unsafe(no_mangle)] +extern "C" fn helper_neg(rhs: JITValue, _env: *const Env) -> JITValue { + use ValueTag::*; + match rhs.tag { + Int => JITValue { + tag: Int, + data: JITValueData { + int: -unsafe { rhs.data.int }, + }, + }, + _ => todo!(), + } +} + +#[unsafe(no_mangle)] +extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue { + use ValueTag::*; + match (lhs.tag, rhs.tag) { + (Int, Int) => JITValue { + tag: Int, + data: JITValueData { + int: unsafe { lhs.data.int + rhs.data.int }, + }, + }, + _ => todo!("Addition not implemented for {:?} and {:?}", lhs.tag, rhs.tag), + } +} + +#[unsafe(no_mangle)] +extern "C" fn helper_call(vm: *const VM<'_>, env: *const Env<'_>, func_ptr: *const (), arg: JITValue) -> JITValue { + let func: JITFunc = unsafe { std::mem::transmute(func_ptr) }; + func(vm, env, arg) +} diff --git a/src/jit/mod.rs b/src/jit/mod.rs index cab26d9..40c7a5b 100644 --- a/src/jit/mod.rs +++ b/src/jit/mod.rs @@ -4,16 +4,19 @@ use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::execution_engine::ExecutionEngine; use inkwell::module::Module; -use inkwell::types::{FunctionType, StructType}; -use inkwell::values::BasicValueEnum; -use inkwell::{AddressSpace, OptimizationLevel}; +use inkwell::values::{BasicValueEnum, PointerValue}; +use inkwell::OptimizationLevel; -use crate::bytecode::OpCode; +use crate::bytecode::{Func, OpCode, UnOp}; use crate::error::*; use crate::stack::Stack; use crate::ty::common::Const; -use crate::ty::internal::{Func, Thunk, Value}; -use crate::vm::VM; +use crate::ty::internal::{Thunk, Value}; +use crate::vm::{Env, VM}; + +mod helpers; + +use helpers::Helpers; #[cfg(test)] mod test; @@ -21,6 +24,7 @@ mod test; const STACK_SIZE: usize = 8 * 1024 / size_of::(); #[repr(u64)] +#[derive(Debug, Clone, Copy)] pub enum ValueTag { Int, Float, @@ -31,6 +35,7 @@ pub enum ValueTag { Function, Thunk, Path, + Null, } #[repr(C)] @@ -44,9 +49,30 @@ pub union JITValueData { int: i64, float: f64, boolean: bool, + ptr: *const () } -pub type JITFunc = fn(usize, usize, JITValue) -> JITValue; +impl<'vm> Into> for JITValue { + fn into(self) -> Value<'vm> { + use ValueTag::*; + match self.tag { + Int => Value::Const(Const::Int(unsafe { self.data.int })), + Null => Value::Const(Const::Null), + _ => todo!("not implemented for {:?}", self.tag) + } + } +} + +impl From> for JITValue { + fn from(value: Value<'_>) -> Self { + match value { + Value::Const(Const::Int(int)) => JITValue { tag: ValueTag::Int, data: JITValueData { int } }, + _ => todo!() + } + } +} + +pub type JITFunc<'vm> = fn(*const VM<'_>, *const Env<'vm>, JITValue) -> JITValue; pub struct JITContext<'ctx> { context: &'ctx Context, @@ -54,127 +80,53 @@ pub struct JITContext<'ctx> { builder: Builder<'ctx>, execution_engine: ExecutionEngine<'ctx>, - value_type: StructType<'ctx>, - func_type: FunctionType<'ctx>, + helpers: Helpers<'ctx> } impl<'vm, 'ctx: 'vm> JITContext<'ctx> { pub fn new(context: &'ctx Context) -> Self { let module = context.create_module("nixjit"); + let execution_engine = module + .create_jit_execution_engine(OptimizationLevel::Default) + .unwrap(); - let int_type = context.i64_type(); - let pointer_type = context.ptr_type(AddressSpace::default()); - let value_type = context.struct_type(&[int_type.into(), int_type.into()], false); - let func_type = value_type.fn_type( - &[pointer_type.into(), pointer_type.into(), value_type.into()], - false, - ); + + let helpers = Helpers::new(context, &module, &execution_engine); JITContext { - execution_engine: module - .create_jit_execution_engine(OptimizationLevel::Default) - .unwrap(), + execution_engine, builder: context.create_builder(), context, module, - value_type, - func_type, + + helpers } } - fn new_int(&self, int: i64) -> BasicValueEnum<'ctx> { - self.value_type - .const_named_struct(&[ - self.context - .i64_type() - .const_int(ValueTag::Int as _, false) - .into(), - self.context.i64_type().const_int(int as _, false).into(), - ]) - .into() + pub fn new_ptr(&self, ptr: *const T) -> PointerValue<'ctx> { + self.builder.build_int_to_ptr(self.helpers.int_type.const_int(ptr as _, false), self.helpers.ptr_type, "ptrconv").unwrap() } - fn new_float(&self, float: f64) -> BasicValueEnum<'ctx> { - self.value_type - .const_named_struct(&[ - self.context - .i64_type() - .const_int(ValueTag::Float as _, false) - .into(), - self.context.f64_type().const_float(float).into(), - ]) - .into() - } - - fn new_bool(&self, bool: bool) -> BasicValueEnum<'ctx> { - self.value_type - .const_named_struct(&[ - self.context - .i64_type() - .const_int(ValueTag::Bool as _, false) - .into(), - self.context.bool_type().const_int(bool as _, false).into(), - ]) - .into() - } - - fn new_null(&self) -> BasicValueEnum<'ctx> { - self.value_type - .const_named_struct(&[ - self.context - .i64_type() - .const_int(ValueTag::Float as _, false) - .into(), - self.context.i64_type().const_zero().into(), - ]) - .into() - } - - fn const_string(&self, string: *const u8) -> BasicValueEnum<'ctx> { - self.value_type - .const_named_struct(&[ - self.context - .i64_type() - .const_int(ValueTag::Float as _, false) - .into(), - self.context - .ptr_sized_int_type(self.execution_engine.get_target_data(), None) - .const_int(string as _, false) - .into(), - ]) - .into() - } - - fn new_thunk(&self, thunk: *const Thunk) -> BasicValueEnum<'ctx> { - self.value_type - .const_named_struct(&[ - self.context - .i64_type() - .const_int(ValueTag::Thunk as _, false) - .into(), - self.context - .ptr_sized_int_type(self.execution_engine.get_target_data(), None) - .const_int(thunk as _, false) - .into(), - ]) - .into() - } - - pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result { + pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result<&'vm JITFunc> { let mut stack = Stack::<_, STACK_SIZE>::new(); - let mut iter = func.func.opcodes.iter().copied(); - let func_ = self.module.add_function("fn", self.func_type, None); + let mut iter = func.opcodes.iter().copied(); + let func_ = self.module.add_function("nixjit_function", self.helpers.func_type, None); let entry = self.context.append_basic_block(func_, "entry"); self.builder.position_at_end(entry); while let Some(opcode) = iter.next() { self.single_op(opcode, vm, &mut stack)?; } assert_eq!(stack.len(), 1); - self.builder.build_return(Some(&stack.pop()))?; + let value = stack.pop(); + self.builder.build_direct_call(self.helpers.debug, &[value.into()], "call_debug").unwrap(); + self.builder.build_return(Some(&value))?; if func_.verify(false) { + func_.print_to_stderr(); unsafe { - Ok(std::mem::transmute(self.execution_engine.get_function_address(func_.get_name().to_str().unwrap()).unwrap())) + let name = func_.get_name().to_str().unwrap(); + let addr = self.execution_engine.get_function_address(name).unwrap(); + Ok(std::mem::transmute(addr)) } } else { todo!() @@ -191,15 +143,57 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> { OpCode::Const { idx } => { use Const::*; match vm.get_const(idx) { - Int(int) => stack.push(self.new_int(int))?, - Float(float) => stack.push(self.new_float(float))?, - Bool(bool) => stack.push(self.new_bool(bool))?, - String(string) => stack.push(self.const_string(string.as_ptr()))?, - Null => stack.push(self.new_null())?, + Int(int) => stack.push(self.helpers.new_int(int))?, + Float(float) => stack.push(self.helpers.new_float(float))?, + Bool(bool) => stack.push(self.helpers.new_bool(bool))?, + String(string) => stack.push(self.helpers.const_string(string.as_ptr()))?, + Null => stack.push(self.helpers.new_null())?, } } - OpCode::LoadThunk { idx } => stack.push(self.new_thunk(Rc::new(Thunk::new(vm.get_thunk(idx))).as_ref() as _))?, - _ => todo!() + OpCode::LoadThunk { idx } => stack.push(self.helpers.new_thunk(Rc::into_raw(Rc::new(Thunk::new(vm.get_thunk(idx))))))?, + /* OpCode::CaptureEnv => { + let thunk = *stack.tos()?; + self.builder.build_direct_call(self.helpers.captured_env, &[thunk.into(), self.new_ptr(env).into()], "call_capture_env")?; + } */ + OpCode::UnOp { op } => { + use UnOp::*; + let rhs = stack.pop(); + stack.push(match op { + Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), self.new_ptr(std::ptr::null::()).into()], "call_neg")?.try_as_basic_value().left().unwrap(), + _ => todo!() + })? + } + /* OpCode::Func { idx } => { + let func = vm.get_func(idx); + let jit_func_ptr = self.compile_function(&Func::new(func, unsafe { env.as_ref() }.unwrap().clone(), OnceCell::new(), Cell::new(1)), vm)?; + let jit_value = self.helpers.value_type.const_named_struct(&[ + self.helpers.int_type.const_int(ValueTag::Function as _, false).into(), + self.helpers.ptr_int_type.const_int(jit_func_ptr as *const _ as _, false).into(), + ]).into(); + stack.push(jit_value)?; + } */ + OpCode::Call { arity } => { + // Assuming arity is 1 for the test case + assert_eq!(arity, 1); + let arg = stack.pop(); + let func_value = stack.pop(); + let func_ptr = self.builder.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")?.into_pointer_value(); + let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), self.new_ptr(std::ptr::null::()).into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap(); + stack.push(result)?; + } + OpCode::BinOp { op } => { + use crate::bytecode::BinOp; + match op { + BinOp::Add => { + let rhs = stack.pop(); + let lhs = stack.pop(); + let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap(); + stack.push(result)?; + } + _ => todo!("BinOp::{:?} not implemented in JIT", op), + } + } + _ => todo!("{opcode:?} not implemented") } Ok(()) } diff --git a/src/jit/test.rs b/src/jit/test.rs index 9c8d710..1773c1a 100644 --- a/src/jit/test.rs +++ b/src/jit/test.rs @@ -92,5 +92,11 @@ macro_rules! attrs { #[test] fn test_jit_const() { - test_expr("let f = _: 1; in (f 1) + (f 1)", int!(2)); + // test_expr("let f = _: 1; in (f 1) + (f 1)", int!(2)); + test_expr("let f = _: 1; in (f 1) == (f 1)", boolean!(true)); +} + +#[test] +fn test_arith() { + test_expr("let f = _: -(-1); in (f 1) + (f 1)", int!(2)); } diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index 580f32b..c996867 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -44,17 +44,19 @@ impl From for Param { pub struct Func<'vm> { pub func: &'vm BFunc, pub env: Env<'vm>, - pub compiled: OnceCell, + pub compiled: OnceCell<&'vm JITFunc<'vm>>, pub count: Cell } -impl<'vm> Func<'vm> { - pub fn call(&self, vm: &'vm VM<'_>, arg: Value<'vm>) -> Result> { +impl<'vm, 'jit: 'vm> Func<'vm> { + pub fn call(&self, vm: &'vm VM<'jit>, arg: Value<'vm>) -> Result> { use Param::*; let count = self.count.get(); if count >= 1 { - let compiled = self.compiled.get_or_init(|| vm.compile_func(self)); + let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func)); + let ret = compiled(vm as _, &self.env as _, arg.into()); + return Ok(ret.into()) } self.count.replace(count + 1); let mut env = self.env.clone(); diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 49d942c..b666f0d 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -110,10 +110,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { OpCode::LoadThunk { idx } => { stack.push(Value::Thunk(Thunk::new(self.get_thunk(idx)).into()))? } - OpCode::CaptureEnv => match stack.tos()? { - Value::Thunk(thunk) => thunk.capture(env.clone()), - _ => (), - }, + OpCode::CaptureEnv => stack.tos().unwrap().as_ref().unwrap_thunk().capture(env.clone()), OpCode::ForceValue => { stack.tos_mut()?.force(self)?; } @@ -267,7 +264,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { Ok(0) } - pub fn compile_func(&'vm self, func: &Func<'vm>) -> JITFunc { + pub fn compile_func(&'vm self, func: &'vm F) -> &'vm JITFunc<'vm> { self.jit.compile_function(func, self).unwrap() } }