//! The Just-In-Time (JIT) compilation module for nixjit. //! //! This module provides functionality to compile Low-Level IR (LIR) expressions //! into optimized machine code using Cranelift. The JIT compiler translates //! Nix expressions into efficient native code for faster evaluation. //! //! The main components are: //! - `JITCompiler`: The core compiler that manages the compilation process //! - `JITContext`: A trait that provides the execution context for JIT-compiled code //! - `Context`: An internal compilation context used during code generation use std::marker::PhantomData; use std::ops::Deref; use std::rc::Rc; use cranelift::codegen::ir::Function; use cranelift::codegen::ir::{self, ArgumentExtension, ArgumentPurpose, StackSlot}; use cranelift::prelude::*; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{FuncId, Linkage, Module}; use hashbrown::{HashMap, HashSet}; use nixjit_eval::{EvalContext, Value}; use nixjit_lir::Lir; mod compile; mod helpers; pub use compile::JITCompile; use helpers::*; /// A trait that provides the execution context for JIT-compiled code. /// /// This trait extends `EvalContext` with additional methods needed /// for JIT compilation, such as managing `with` expression scopes directly. pub trait JITContext: EvalContext { /// Enters a `with` expression scope with the given namespace. fn enter_with(&mut self, namespace: Rc>); /// Exits the current `with` expression scope. fn exit_with(&mut self); } /// Type alias for a JIT-compiled function. /// /// This represents a function pointer to JIT-compiled code that takes /// a context pointer and a mutable value pointer as arguments. type F = unsafe extern "C" fn(*const Ctx, *mut Value); /// A JIT-compiled function. /// /// This struct holds a function pointer to the compiled code and /// a set of strings that were used during compilation, which need /// to be kept alive for the function to work correctly. pub struct JITFunc { func: F, strings: HashSet, } impl Deref for JITFunc { type Target = F; fn deref(&self) -> &Self::Target { &self.func } } /// The internal compilation context used during code generation. /// /// This context holds references to the compiler, the Cranelift function builder, /// and manages resources like stack slots and string literals during compilation. struct Context<'comp, 'ctx, Ctx: JITContext> { /// Reference to the JIT compiler. pub compiler: &'comp mut JITCompiler, /// The Cranelift function builder used to generate IR. pub builder: FunctionBuilder<'ctx>, /// Stack slots available for reuse. free_slots: Vec, /// String literals used during compilation. strings: HashSet, } impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> { fn new(compiler: &'comp mut JITCompiler, builder: FunctionBuilder<'ctx>) -> Self { Self { compiler, builder, free_slots: Vec::new(), strings: HashSet::new(), } } fn alloca(&mut self) -> StackSlot { self.free_slots.pop().map_or_else( || { let slot = StackSlotData::new(StackSlotKind::ExplicitSlot, 32, 3); self.builder.create_sized_stack_slot(slot) }, |x| x, ) } fn free_slot(&mut self, slot: StackSlot) { self.free_slots.push(slot); } fn alloc_array(&mut self, len: usize) -> ir::Value { let len = self .builder .ins() .iconst(self.compiler.ptr_type, len as i64); let alloc_array = self .compiler .module .declare_func_in_func(self.compiler.alloc_array, self.builder.func); let inst = self.builder.ins().call(alloc_array, &[len]); self.builder.inst_results(inst)[0] } fn create_string(&mut self, string: &str) -> StackSlot { let string = self .strings .get_or_insert_with(string, |_| string.to_owned()); let ptr = self .builder .ins() .iconst(self.compiler.ptr_type, string.as_ptr() as i64); let len = self .builder .ins() .iconst(self.compiler.ptr_type, string.len() as i64); let create_string = self .compiler .module .declare_func_in_func(self.compiler.create_string, self.builder.func); let slot = self.alloca(); let ret = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(create_string, &[ptr, len, ret]); slot } fn create_list(&mut self, ptr: ir::Value, len: usize) -> StackSlot { let len = self .builder .ins() .iconst(self.compiler.ptr_type, len as i64); let create_list = self .compiler .module .declare_func_in_func(self.compiler.create_list, self.builder.func); let slot = self.alloca(); let ret = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(create_list, &[ptr, len, ret]); slot } fn create_attrs(&mut self) -> StackSlot { let create_attrs = self .compiler .module .declare_func_in_func(self.compiler.create_attrs, self.builder.func); let slot = StackSlotData::new(StackSlotKind::ExplicitSlot, 40, 3); let slot = self.builder.create_sized_stack_slot(slot); let ret = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(create_attrs, &[ret]); slot } fn push_attr(&mut self, attrs: StackSlot, sym: &str, val: StackSlot) { self.free_slot(attrs); self.free_slot(val); let attrs = self.builder.ins().stack_addr(types::I64, attrs, 0); let val = self.builder.ins().stack_addr(types::I64, val, 0); let sym = self.strings.get_or_insert_with(sym, |_| sym.to_owned()); let ptr = self .builder .ins() .iconst(self.compiler.ptr_type, sym.as_ptr() as i64); let len = self .builder .ins() .iconst(self.compiler.ptr_type, sym.len() as i64); let push_attr = self .compiler .module .declare_func_in_func(self.compiler.push_attr, self.builder.func); self.builder.ins().call(push_attr, &[attrs, ptr, len, val]); } fn finalize_attrs(&mut self, attrs: StackSlot) -> StackSlot { let attrs = self.builder.ins().stack_addr(types::I64, attrs, 0); let finalize_attrs = self .compiler .module .declare_func_in_func(self.compiler.finalize_attrs, self.builder.func); let slot = self.alloca(); let ret = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(finalize_attrs, &[attrs, ret]); slot } fn enter_with(&mut self, env: ir::Value, namespace: StackSlot) { let ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, namespace, 0); let enter_with = self .compiler .module .declare_func_in_func(self.compiler.enter_with, self.builder.func); self.builder.ins().call(enter_with, &[env, ptr]); } fn exit_with(&mut self, env: ir::Value) { let exit_with = self .compiler .module .declare_func_in_func(self.compiler.exit_with, self.builder.func); self.builder.ins().call(exit_with, &[env]); } fn dbg(&mut self, slot: StackSlot) { let ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); let dbg = self .compiler .module .declare_func_in_func(self.compiler.dbg, self.builder.func); self.builder.ins().call(dbg, &[ptr]); } fn call( &mut self, func: StackSlot, args_ptr: ir::Value, args_len: usize, engine: ir::Value, env: ir::Value, ) { let args_len = self .builder .ins() .iconst(self.compiler.ptr_type, args_len as i64); let call = self .compiler .module .declare_func_in_func(self.compiler.call, self.builder.func); let func = self .builder .ins() .stack_addr(self.compiler.ptr_type, func, 0); self.builder .ins() .call(call, &[func, args_ptr, args_len, engine, env]); } fn lookup(&mut self, env: ir::Value, sym: &str) -> StackSlot { let sym = self.strings.get_or_insert_with(sym, |_| sym.to_owned()); let ptr = self .builder .ins() .iconst(self.compiler.ptr_type, sym.as_ptr() as i64); let len = self .builder .ins() .iconst(self.compiler.ptr_type, sym.len() as i64); let lookup = self .compiler .module .declare_func_in_func(self.compiler.lookup, self.builder.func); let slot = self.alloca(); let ret = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(lookup, &[env, ptr, len, ret]); slot } fn lookup_stack(&mut self, env: ir::Value, idx: usize) -> StackSlot { let slot = self.alloca(); let lookup_stack = self .compiler .module .declare_func_in_func(self.compiler.lookup_stack, self.builder.func); let idx = self .builder .ins() .iconst(self.compiler.ptr_type, idx as i64); let ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(lookup_stack, &[env, idx, ptr]); slot } fn lookup_arg(&mut self, env: ir::Value, idx: usize) -> StackSlot { let slot = self.alloca(); let lookup_arg = self .compiler .module .declare_func_in_func(self.compiler.lookup_arg, self.builder.func); let idx = self .builder .ins() .iconst(self.compiler.ptr_type, idx as i64); let ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder.ins().call(lookup_arg, &[env, idx, ptr]); slot } fn select( &mut self, slot: StackSlot, path_ptr: ir::Value, path_len: usize, engine: ir::Value, env: ir::Value, ) { let select = self .compiler .module .declare_func_in_func(self.compiler.select, self.builder.func); let path_len = self .builder .ins() .iconst(self.compiler.ptr_type, path_len as i64); let ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); self.builder .ins() .call(select, &[ptr, path_ptr, path_len, engine, env]); } fn select_with_default( &mut self, slot: StackSlot, path_ptr: ir::Value, path_len: usize, default: StackSlot, engine: ir::Value, env: ir::Value, ) { let select_with_default = self .compiler .module .declare_func_in_func(self.compiler.select_with_default, self.builder.func); let path_len = self .builder .ins() .iconst(self.compiler.ptr_type, path_len as i64); let ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, slot, 0); let default_ptr = self .builder .ins() .stack_addr(self.compiler.ptr_type, default, 0); self.builder.ins().call( select_with_default, &[ptr, path_ptr, path_len, default_ptr, engine, env], ); } pub fn eq(&mut self, lhs: StackSlot, rhs: StackSlot) { let lhs = self .builder .ins() .stack_addr(self.compiler.ptr_type, lhs, 0); let rhs = self .builder .ins() .stack_addr(self.compiler.ptr_type, rhs, 0); let eq = self .compiler .module .declare_func_in_func(self.compiler.eq, self.builder.func); self.builder.ins().call(eq, &[lhs, rhs]); } pub fn get_tag(&mut self, slot: StackSlot) -> ir::Value { self.builder.ins().stack_load(types::I64, slot, 0) } pub fn get_small_value(&mut self, ty: Type, slot: StackSlot) -> ir::Value { self.builder.ins().stack_load(ty, slot, 8) } } /// The main JIT compiler that manages the compilation process. pub struct JITCompiler { ctx: codegen::Context, module: JITModule, builder_ctx: Option, _marker: PhantomData, int_type: Type, float_type: Type, bool_type: Type, ptr_type: Type, value_type: Type, func_sig: Signature, call: FuncId, lookup_stack: FuncId, lookup_arg: FuncId, lookup: FuncId, select: FuncId, select_with_default: FuncId, eq: FuncId, alloc_array: FuncId, create_string: FuncId, create_list: FuncId, create_attrs: FuncId, push_attr: FuncId, finalize_attrs: FuncId, enter_with: FuncId, exit_with: FuncId, dbg: FuncId, } impl Default for JITCompiler { fn default() -> Self { Self::new() } } impl JITCompiler { pub fn new() -> Self { let mut flag_builder = settings::builder(); flag_builder.set("use_colocated_libcalls", "false").unwrap(); flag_builder.set("is_pic", "false").unwrap(); let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| { panic!("host machine is not supported: {msg}"); }); let isa = isa_builder .finish(settings::Flags::new(flag_builder)) .unwrap(); let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names()); builder.symbol("helper_call", helper_call:: as _); builder.symbol("helper_lookup_stack", helper_lookup_stack:: as _); builder.symbol("helper_lookup_arg", helper_lookup_arg:: as _); builder.symbol("helper_lookup", helper_lookup:: as _); builder.symbol("helper_select", helper_select:: as _); builder.symbol( "helper_select_with_default", helper_select_with_default:: as _, ); builder.symbol("helper_eq", helper_eq:: as _); builder.symbol("helper_alloc_array", helper_alloc_array:: as _); builder.symbol("helper_create_string", helper_create_string:: as _); builder.symbol("helper_create_list", helper_create_list:: as _); builder.symbol("helper_create_attrs", helper_create_attrs:: as _); builder.symbol("helper_push_attr", helper_push_attr:: as _); builder.symbol("helper_finalize_attrs", helper_finalize_attrs:: as _); builder.symbol("helper_enter_with", helper_enter_with:: as _); builder.symbol("helper_exit_with", helper_exit_with:: as _); builder.symbol("helper_dbg", helper_dbg:: as _); let mut module = JITModule::new(builder); let ctx = module.make_context(); let int_type = types::I64; let float_type = types::F64; let bool_type = types::I8; let ptr_type = module.target_config().pointer_type(); let value_type = types::I128; // fn(*const Context, *const Env, *mut Value) let mut func_sig = module.make_signature(); func_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 3], ); // fn(func: &mut Value, args_ptr: *mut Value, args_len: usize, engine: &mut Context, env: // &mut Env) let mut call_sig = module.make_signature(); call_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 5], ); let call = module .declare_function("helper_call", Linkage::Import, &call_sig) .unwrap(); let mut lookup_stack_sig = module.make_signature(); lookup_stack_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 3], ); let lookup_stack = module .declare_function("helper_lookup_stack", Linkage::Import, &lookup_stack_sig) .unwrap(); // fn(env: &Env, level: usize, ret: &mut MaybeUninit) let mut lookup_arg_sig = module.make_signature(); lookup_arg_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 3], ); let lookup_arg = module .declare_function("helper_lookup_arg", Linkage::Import, &lookup_arg_sig) .unwrap(); // fn(env: &Env, sym_ptr: *const u8, sym_len: usize, ret: &mut MaybeUninit) let mut lookup_sig = module.make_signature(); lookup_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 4], ); let lookup = module .declare_function("helper_lookup", Linkage::Import, &lookup_sig) .unwrap(); // fn(val: &mut Value, path_ptr: *mut Value, path_len: usize, engine: &mut Context, env: &mut Env) let mut select_sig = module.make_signature(); select_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 5], ); let select = module .declare_function("helper_select", Linkage::Import, &select_sig) .unwrap(); // fn(val: &mut Value, path_ptr: *mut Value, path_len: usize, default: NonNull, engine: &mut Context, env: &mut Env) let mut select_with_default_sig = module.make_signature(); select_with_default_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 6], ); let select_with_default = module .declare_function( "helper_select_with_default", Linkage::Import, &select_with_default_sig, ) .unwrap(); let mut eq_sig = module.make_signature(); eq_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 2], ); let eq = module .declare_function("helper_eq", Linkage::Import, &eq_sig) .unwrap(); let mut alloc_array_sig = module.make_signature(); alloc_array_sig.params.push(AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }); alloc_array_sig.returns.push(AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }); let alloc_array = module .declare_function("helper_alloc_array", Linkage::Import, &alloc_array_sig) .unwrap(); let mut create_string_sig = module.make_signature(); create_string_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 3], ); let create_string = module .declare_function("helper_create_string", Linkage::Import, &create_string_sig) .unwrap(); let mut create_list_sig = module.make_signature(); create_list_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 3], ); let create_list = module .declare_function("helper_create_list", Linkage::Import, &create_list_sig) .unwrap(); let mut create_attrs_sig = module.make_signature(); create_attrs_sig.params.push(AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }); let create_attrs = module .declare_function("helper_create_attrs", Linkage::Import, &create_attrs_sig) .unwrap(); let mut push_attr_sig = module.make_signature(); push_attr_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 4], ); let push_attr = module .declare_function("helper_push_attr", Linkage::Import, &push_attr_sig) .unwrap(); let mut finalize_attrs_sig = module.make_signature(); finalize_attrs_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 2], ); let finalize_attrs = module .declare_function( "helper_finalize_attrs", Linkage::Import, &finalize_attrs_sig, ) .unwrap(); let mut enter_with_sig = module.make_signature(); enter_with_sig.params.extend( [AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }; 2], ); let enter_with = module .declare_function("helper_enter_with", Linkage::Import, &enter_with_sig) .unwrap(); let mut exit_with_sig = module.make_signature(); exit_with_sig.params.push(AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }); let exit_with = module .declare_function("helper_exit_with", Linkage::Import, &exit_with_sig) .unwrap(); let mut dbg_sig = module.make_signature(); dbg_sig.params.push(AbiParam { value_type: ptr_type, purpose: ArgumentPurpose::Normal, extension: ArgumentExtension::None, }); let dbg = module .declare_function("helper_dbg", Linkage::Import, &dbg_sig) .unwrap(); Self { builder_ctx: None, _marker: PhantomData, ctx, module, int_type, float_type, bool_type, ptr_type, value_type, func_sig, call, lookup_stack, lookup_arg, lookup, select, select_with_default, eq, alloc_array, create_string, create_list, create_attrs, push_attr, finalize_attrs, enter_with, exit_with, dbg, } } pub fn compile(&mut self, ir: &Lir, id: usize) -> JITFunc { let func_id = self .module .declare_function( format!("nixjit_thunk{id}").as_str(), Linkage::Local, &self.func_sig, ) .unwrap(); let mut func = Function::new(); func.signature = self.func_sig.clone(); let mut builder_ctx = self.builder_ctx.take().unwrap_or_default(); let mut ctx = Context::new(self, FunctionBuilder::new(&mut func, &mut builder_ctx)); let entry = ctx.builder.create_block(); ctx.builder.append_block_params_for_function_params(entry); ctx.builder.switch_to_block(entry); let params = ctx.builder.block_params(entry); let engine = params[0]; let env = params[1]; let ret = params[2]; let res = ir.compile(&mut ctx, engine, env); let tag = ctx.builder.ins().stack_load(types::I64, res, 0); let val0 = ctx.builder.ins().stack_load(types::I64, res, 8); let val1 = ctx.builder.ins().stack_load(types::I64, res, 16); let val2 = ctx.builder.ins().stack_load(types::I64, res, 24); ctx.builder.ins().store(MemFlags::new(), tag, ret, 0); ctx.builder.ins().store(MemFlags::new(), val0, ret, 8); ctx.builder.ins().store(MemFlags::new(), val1, ret, 16); ctx.builder.ins().store(MemFlags::new(), val2, ret, 24); ctx.builder.ins().return_(&[]); ctx.builder.seal_all_blocks(); ctx.builder.finalize(); let strings = ctx.strings; if cfg!(debug_assertions) { println!("{ir:#?}"); println!("{}", func.display()); } self.ctx.func = func; self.module.define_function(func_id, &mut self.ctx).unwrap(); self.module.finalize_definitions().unwrap(); self.ctx.clear(); let _ = self.builder_ctx.insert(builder_ctx); unsafe { JITFunc { func: std::mem::transmute::<*const u8, F>( self.module.get_finalized_function(func_id), ), strings, } } } }