use std::ops::Deref; use gc_arena::{Collect, Gc, Mutation}; use inkwell::OptimizationLevel; use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::execution_engine::{ExecutionEngine, JitFunction}; use inkwell::module::Module; use inkwell::values::{BasicValueEnum, FunctionValue, PointerValue}; use crate::bytecode::{Func, OpCode, UnOp}; use crate::env::VmEnv; use crate::error::*; use crate::stack::Stack; use crate::ty::common::Const; use crate::ty::internal::{Thunk, Value}; use crate::vm::VM; mod helpers; use helpers::Helpers; #[cfg(test)] mod test; const STACK_SIZE: usize = 8 * 1024 / size_of::(); #[repr(u64)] #[derive(Debug, Clone, Copy)] pub enum ValueTag { Int, Float, String, Bool, AttrSet, List, Function, Thunk, Path, Null, } #[repr(C)] #[derive(Clone, Copy)] pub struct JITValue { tag: ValueTag, data: JITValueData, } #[repr(C)] #[derive(Clone, Copy)] pub union JITValueData { int: i64, float: f64, bool: bool, ptr: *const (), } impl<'gc> From for Value<'gc> { fn from(value: JITValue) -> Self { use ValueTag::*; match value.tag { Int => Value::Int(unsafe { value.data.int }), Null => Value::Null, Function => Value::Func(unsafe { Gc::from_ptr(value.data.ptr as *const _) }), Thunk => Value::Thunk(self::Thunk { thunk: unsafe { Gc::from_ptr(value.data.ptr as *const _) }, }), _ => todo!("not implemented for {:?}", value.tag), } } } impl From<&Value<'_>> for JITValue { fn from(value: &Value<'_>) -> Self { match *value { Value::Int(int) => JITValue { tag: ValueTag::Int, data: JITValueData { int }, }, Value::Func(func) => JITValue { tag: ValueTag::Function, data: JITValueData { ptr: Gc::as_ptr(func) as *const _, }, }, Value::Thunk(ref thunk) => JITValue { tag: ValueTag::Thunk, data: JITValueData { ptr: Gc::as_ptr(thunk.thunk) as *const _, }, }, _ => todo!(), } } } impl From> for JITValue { fn from(value: Value) -> Self { match value { Value::Int(int) => JITValue { tag: ValueTag::Int, data: JITValueData { int }, }, Value::Func(func) => JITValue { tag: ValueTag::Function, data: JITValueData { ptr: Gc::as_ptr(func) as *const _, }, }, Value::Thunk(thunk) => JITValue { tag: ValueTag::Thunk, data: JITValueData { ptr: Gc::as_ptr(thunk.thunk) as *const _, }, }, _ => todo!(), } } } pub struct JITFunc<'gc>( JitFunction<'gc, unsafe extern "C" fn(*const VmEnv<'gc>, *const Mutation<'gc>) -> JITValue>, ); unsafe impl<'gc> Collect<'gc> for JITFunc<'gc> { fn trace>(&self, _: &mut T) {} const NEEDS_TRACE: bool = false; } impl<'gc> From< JitFunction<'gc, unsafe extern "C" fn(*const VmEnv<'gc>, *const Mutation<'gc>) -> JITValue>, > for JITFunc<'gc> { fn from( value: JitFunction< 'gc, unsafe extern "C" fn(*const VmEnv<'gc>, *const Mutation<'gc>) -> JITValue, >, ) -> Self { Self(value) } } impl<'gc> Deref for JITFunc<'gc> { type Target = JitFunction<'gc, unsafe extern "C" fn(*const VmEnv<'gc>, *const Mutation<'gc>) -> JITValue>; fn deref(&self) -> &Self::Target { &self.0 } } pub struct JITContext<'gc> { context: &'gc Context, module: Module<'gc>, builder: Builder<'gc>, execution_engine: ExecutionEngine<'gc>, helpers: Helpers<'gc>, } unsafe impl<'gc> Collect<'gc> for JITContext<'gc> { fn trace>(&self, _: &mut T) {} const NEEDS_TRACE: bool = false; } impl<'gc> JITContext<'gc> { pub fn new(context: &'gc Context) -> Self { // force linker to link JIT engine unsafe { inkwell::llvm_sys::execution_engine::LLVMLinkInMCJIT(); } let module = context.create_module("nixjit"); let execution_engine = module .create_jit_execution_engine(OptimizationLevel::Aggressive) .unwrap(); let helpers = Helpers::new(context, &module, &execution_engine); JITContext { execution_engine, builder: context.create_builder(), context, module, helpers, } } pub fn new_ptr(&self, ptr: *const T) -> PointerValue<'gc> { self.builder .build_int_to_ptr( self.helpers.int_type.const_int(ptr as _, false), self.helpers.ptr_type, "ptrconv", ) .unwrap() } pub fn compile_function(&self, func: &'gc Func, vm: &'gc VM<'gc>) -> Result> { let mut stack = Stack::<_, STACK_SIZE>::new(); let mut iter = func.opcodes.iter().copied(); let func_ = self .module .add_function("nixjit_function", self.helpers.func_type, None); let env = func_.get_nth_param(0).unwrap().into_pointer_value(); let mc = func_.get_nth_param(1).unwrap().into_pointer_value(); let entry = self.context.append_basic_block(func_, "entry"); self.builder.position_at_end(entry); self.build_expr( &mut iter, vm, env, mc, &mut stack, func_, func.opcodes.len(), )?; assert_eq!(stack.len(), 1); let value = stack.pop(); let exit = self.context.append_basic_block(func_, "exit"); self.builder.build_unconditional_branch(exit)?; self.builder.position_at_end(exit); self.builder.build_return(Some(&value))?; if func_.verify(true) { unsafe { let name = func_.get_name().to_str().unwrap(); let func = self.execution_engine.get_function(name).unwrap(); Ok(func.into()) } } else { todo!() } } fn build_expr( &self, iter: &mut impl Iterator, vm: &'gc VM<'gc>, env: PointerValue<'gc>, mc: PointerValue<'gc>, stack: &mut Stack, CAP>, func: FunctionValue<'gc>, mut length: usize, ) -> Result { while length > 1 { let opcode = iter.next().unwrap(); let br = self.single_op(opcode, vm, env, mc, stack)?; length -= 1; if br > 0 { let consq = self.context.append_basic_block(func, "consq"); let alter = self.context.append_basic_block(func, "alter"); let cont = self.context.append_basic_block(func, "cont"); let cond = self .builder .build_alloca(self.helpers.value_type, "cond_alloca")?; let result = self .builder .build_alloca(self.helpers.value_type, "result_alloca")?; self.builder.build_store(cond, stack.pop())?; self.builder.build_conditional_branch( self.builder .build_load( self.context.bool_type(), self.builder.build_struct_gep( self.helpers.value_type, cond, 1, "gep_cond", )?, "load_cond", )? .into_int_value(), consq, alter, )?; length -= br; self.builder.position_at_end(consq); let br = self.build_expr(iter, vm, env, mc, stack, func, br)?; self.builder.build_store(result, stack.pop())?; self.builder.build_unconditional_branch(cont)?; length -= br; self.builder.position_at_end(alter); self.build_expr(iter, vm, env, mc, stack, func, br)?; self.builder.build_store(result, stack.pop())?; self.builder.build_unconditional_branch(cont)?; self.builder.position_at_end(cont); stack.push(self.builder.build_load( self.helpers.value_type, result, "load_result", )?)?; } } if length > 0 { self.single_op(iter.next().unwrap(), vm, env, mc, stack) } else { Ok(0) } } #[inline(always)] fn single_op( &self, opcode: OpCode, vm: &'gc VM<'_>, env: PointerValue<'gc>, mc: PointerValue<'gc>, stack: &mut Stack, CAP>, ) -> Result { match opcode { OpCode::Const { idx } => { use Const::*; match vm.get_const(idx) { Int(int) => stack.push(self.helpers.new_int(int))?, Float(float) => stack.push(self.helpers.new_float(float))?, Bool(bool) => stack.push(self.helpers.new_bool(bool))?, String(string) => stack.push(self.helpers.const_string(string.as_ptr()))?, Null => stack.push(self.helpers.new_null())?, } } OpCode::LoadThunk { idx } => stack.push( self.builder .build_direct_call( self.helpers.new_thunk, &[ self.new_ptr(vm.get_thunk(idx) as *const _).into(), mc.into(), ], "call_capture_env", )? .try_as_basic_value() .unwrap_left() .into(), )?, OpCode::CaptureEnv => { let thunk = *stack.tos(); self.builder.build_direct_call( self.helpers.capture_env, &[thunk.into(), env.into()], "call_capture_env", )?; } OpCode::UnOp { op } => { use UnOp::*; let rhs = stack.pop(); stack.push(match op { Neg => self .builder .build_direct_call(self.helpers.neg, &[rhs.into(), env.into()], "call_neg")? .try_as_basic_value() .left() .unwrap(), Not => self .builder .build_direct_call(self.helpers.not, &[rhs.into(), env.into()], "call_neg")? .try_as_basic_value() .left() .unwrap(), })? } OpCode::BinOp { op } => { use crate::bytecode::BinOp; let rhs = stack.pop(); let lhs = stack.pop(); match op { BinOp::Add => { let result = self .builder .build_direct_call( self.helpers.add, &[lhs.into(), rhs.into()], "call_add", )? .try_as_basic_value() .left() .unwrap(); stack.push(result)?; } BinOp::Sub => { let result = self .builder .build_direct_call( self.helpers.sub, &[lhs.into(), rhs.into()], "call_add", )? .try_as_basic_value() .left() .unwrap(); stack.push(result)?; } BinOp::Eq => { let result = self .builder .build_direct_call( self.helpers.eq, &[lhs.into(), rhs.into()], "call_eq", )? .try_as_basic_value() .left() .unwrap(); stack.push(result)?; } BinOp::Or => { let result = self .builder .build_direct_call( self.helpers.or, &[lhs.into(), rhs.into()], "call_or", )? .try_as_basic_value() .left() .unwrap(); stack.push(result)?; } _ => todo!("BinOp::{:?} not implemented in JIT", op), } } OpCode::LookUp { sym } => stack.push( self.builder .build_direct_call( self.helpers.lookup, &[ self.helpers .ptr_int_type .const_int(sym as u64, false) .into(), env.into(), ], "call_lookup", ) .unwrap() .try_as_basic_value() .left() .unwrap(), )?, OpCode::Call => { let arg = stack.pop(); let func = self .builder .build_direct_call( self.helpers.force, &[stack.pop().into(), self.new_ptr(vm).into()], "force", )? .try_as_basic_value() .left() .unwrap(); let ret = self .builder .build_direct_call( self.helpers.call, &[func.into(), arg.into(), self.new_ptr(vm).into()], "call", )? .try_as_basic_value() .left() .unwrap(); stack.push(ret)?; } OpCode::JmpIfFalse { step } => return Ok(step), OpCode::Jmp { step } => return Ok(step), _ => todo!("{opcode:?} not implemented"), } Ok(0) } }