use std::rc::Rc; use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::execution_engine::{ExecutionEngine, JitFunction}; use inkwell::module::Module; use inkwell::values::{BasicValueEnum, PointerValue}; use inkwell::OptimizationLevel; use crate::bytecode::{Func, OpCode, UnOp}; use crate::error::*; use crate::stack::Stack; use crate::ty::common::Const; use crate::ty::internal::{Thunk, Value}; use crate::vm::{Env, VM}; mod helpers; use helpers::Helpers; #[cfg(test)] mod test; const STACK_SIZE: usize = 8 * 1024 / size_of::(); #[repr(u64)] #[derive(Debug, Clone, Copy)] pub enum ValueTag { Int, Float, String, Bool, AttrSet, List, Function, Thunk, Path, Null, } #[repr(C)] pub struct JITValue { tag: ValueTag, data: JITValueData, } #[repr(C)] pub union JITValueData { int: i64, float: f64, boolean: bool, ptr: *const () } impl<'jit: 'vm, 'vm> Into> for JITValue { fn into(self) -> Value<'jit, 'vm> { use ValueTag::*; match self.tag { Int => Value::Const(Const::Int(unsafe { self.data.int })), Null => Value::Const(Const::Null), _ => todo!("not implemented for {:?}", self.tag) } } } impl From> for JITValue { fn from(value: Value<'_, '_>) -> Self { match value { Value::Const(Const::Int(int)) => JITValue { tag: ValueTag::Int, data: JITValueData { int } }, _ => todo!() } } } pub type JITFunc<'jit, 'vm> = unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>, JITValue) -> JITValue; pub struct JITContext<'ctx> { context: &'ctx Context, module: Module<'ctx>, builder: Builder<'ctx>, execution_engine: ExecutionEngine<'ctx>, helpers: Helpers<'ctx> } impl<'vm, 'ctx: 'vm> JITContext<'ctx> { pub fn new(context: &'ctx Context) -> Self { let module = context.create_module("nixjit"); let execution_engine = module .create_jit_execution_engine(OptimizationLevel::Default) .unwrap(); let helpers = Helpers::new(context, &module, &execution_engine); JITContext { execution_engine, builder: context.create_builder(), context, module, helpers } } pub fn new_ptr(&self, ptr: *const T) -> PointerValue<'ctx> { self.builder.build_int_to_ptr(self.helpers.int_type.const_int(ptr as _, false), self.helpers.ptr_type, "ptrconv").unwrap() } pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result>> { let mut stack = Stack::<_, STACK_SIZE>::new(); let mut iter = func.opcodes.iter().copied(); let func_ = self.module.add_function("nixjit_function", self.helpers.func_type, None); let entry = self.context.append_basic_block(func_, "entry"); self.builder.position_at_end(entry); while let Some(opcode) = iter.next() { self.single_op(opcode, vm, &mut stack)?; } assert_eq!(stack.len(), 1); let value = stack.pop(); self.builder.build_direct_call(self.helpers.debug, &[value.into()], "call_debug").unwrap(); self.builder.build_return(Some(&value))?; if func_.verify(false) { func_.print_to_stderr(); unsafe { let name = func_.get_name().to_str().unwrap(); let addr = self.execution_engine.get_function(name).unwrap(); Ok(addr) } } else { todo!() } } fn single_op( &self, opcode: OpCode, vm: &'vm VM<'_>, stack: &mut Stack, CAP>, ) -> Result<()> { match opcode { OpCode::Const { idx } => { use Const::*; match vm.get_const(idx) { Int(int) => stack.push(self.helpers.new_int(int))?, Float(float) => stack.push(self.helpers.new_float(float))?, Bool(bool) => stack.push(self.helpers.new_bool(bool))?, String(string) => stack.push(self.helpers.const_string(string.as_ptr()))?, Null => stack.push(self.helpers.new_null())?, } } OpCode::LoadThunk { idx } => stack.push(self.helpers.new_thunk(Rc::into_raw(Rc::new(Thunk::new(vm.get_thunk(idx))))))?, /* OpCode::CaptureEnv => { let thunk = *stack.tos()?; self.builder.build_direct_call(self.helpers.captured_env, &[thunk.into(), self.new_ptr(env).into()], "call_capture_env")?; } */ OpCode::UnOp { op } => { use UnOp::*; let rhs = stack.pop(); stack.push(match op { Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), self.new_ptr(std::ptr::null::()).into()], "call_neg")?.try_as_basic_value().left().unwrap(), _ => todo!() })? } /* OpCode::Func { idx } => { let func = vm.get_func(idx); let jit_func_ptr = self.compile_function(&Func::new(func, unsafe { env.as_ref() }.unwrap().clone(), OnceCell::new(), Cell::new(1)), vm)?; let jit_value = self.helpers.value_type.const_named_struct(&[ self.helpers.int_type.const_int(ValueTag::Function as _, false).into(), self.helpers.ptr_int_type.const_int(jit_func_ptr as *const _ as _, false).into(), ]).into(); stack.push(jit_value)?; } */ OpCode::Call { arity } => { // Assuming arity is 1 for the test case assert_eq!(arity, 1); let arg = stack.pop(); let func_value = stack.pop(); let func_ptr = self.builder.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")?.into_pointer_value(); let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), self.new_ptr(std::ptr::null::()).into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap(); stack.push(result)?; } OpCode::BinOp { op } => { use crate::bytecode::BinOp; match op { BinOp::Add => { let rhs = stack.pop(); let lhs = stack.pop(); let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap(); stack.push(result)?; } _ => todo!("BinOp::{:?} not implemented in JIT", op), } } _ => todo!("{opcode:?} not implemented") } Ok(()) } }