Files
nixjit/evaluator/nixjit_jit/src/lib.rs
2025-08-15 23:14:21 +08:00

796 lines
26 KiB
Rust

//! The Just-In-Time (JIT) compilation module for nixjit.
//!
//! This module provides functionality to compile Low-Level IR (LIR) expressions
//! into optimized machine code using Cranelift. The JIT compiler translates
//! Nix expressions into efficient native code for faster evaluation.
//!
//! The main components are:
//! - `JITCompiler`: The core compiler that manages the compilation process
//! - `JITContext`: A trait that provides the execution context for JIT-compiled code
//! - `Context`: An internal compilation context used during code generation
use std::marker::PhantomData;
use std::ops::Deref;
use std::rc::Rc;
use cranelift::codegen::ir::Function;
use cranelift::codegen::ir::{self, ArgumentExtension, ArgumentPurpose, StackSlot};
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module};
use hashbrown::{HashMap, HashSet};
use nixjit_eval::{EvalContext, Value};
use nixjit_lir::Lir;
mod compile;
mod helpers;
pub use compile::JITCompile;
use helpers::*;
/// A trait that provides the execution context for JIT-compiled code.
///
/// This trait extends `EvalContext` with additional methods needed
/// for JIT compilation, such as managing `with` expression scopes directly.
pub trait JITContext: EvalContext {
/// Enters a `with` expression scope with the given namespace.
///
/// `namespace` is a shared (`Rc`) map from attribute name to value.
fn enter_with(&mut self, namespace: Rc<HashMap<String, Value>>);
/// Exits the current `with` expression scope.
///
/// Presumably expected to be paired with an earlier `enter_with` — confirm with
/// the implementors of this trait.
fn exit_with(&mut self);
}
/// Type alias for a JIT-compiled function.
///
/// This represents a function pointer to JIT-compiled code that takes
/// a context pointer and a mutable value pointer as arguments.
///
/// NOTE(review): `JITCompiler::new` builds the signature of compiled functions with
/// three pointer-sized parameters, and `JITCompiler::compile` reads three entry-block
/// parameters (engine, env, ret), while this alias names only two. Confirm which
/// arity is intended before calling through this pointer.
type F<Ctx> = unsafe extern "C" fn(*const Ctx, *mut Value);
/// A JIT-compiled function.
///
/// This struct holds a function pointer to the compiled code and
/// a set of strings that were used during compilation, which need
/// to be kept alive for the function to work correctly.
pub struct JITFunc<Ctx: JITContext> {
/// Entry point of the compiled machine code.
func: F<Ctx>,
/// Strings interned while compiling. The generated code embeds the raw addresses
/// of their buffers as integer constants, so they must stay alive for as long as
/// `func` may be called.
strings: HashSet<String>,
}
/// Lets a `JITFunc` be used wherever the raw compiled function pointer is expected.
impl<Ctx: JITContext> Deref for JITFunc<Ctx> {
    type Target = F<Ctx>;

    /// Borrows the wrapped function pointer.
    fn deref(&self) -> &F<Ctx> {
        let Self { func, .. } = self;
        func
    }
}
/// The internal compilation context used during code generation.
///
/// This context holds references to the compiler, the Cranelift function builder,
/// and manages resources like stack slots and string literals during compilation.
struct Context<'comp, 'ctx, Ctx: JITContext> {
/// Reference to the JIT compiler.
///
/// Gives access to the module, the pointer type and the ids of the imported helpers.
pub compiler: &'comp mut JITCompiler<Ctx>,
/// The Cranelift function builder used to generate IR.
pub builder: FunctionBuilder<'ctx>,
/// Stack slots available for reuse.
///
/// `alloca` pops from this list (most recently freed first) and `free_slot` pushes to it.
free_slots: Vec<StackSlot>,
/// String literals used during compilation.
///
/// Their buffer addresses are emitted as constants into the generated code; the set is
/// moved into the resulting `JITFunc` at the end of `JITCompiler::compile`.
strings: HashSet<String>,
}
impl<'comp, 'ctx, Ctx: JITContext> Context<'comp, 'ctx, Ctx> {
fn new(compiler: &'comp mut JITCompiler<Ctx>, builder: FunctionBuilder<'ctx>) -> Self {
Self {
compiler,
builder,
free_slots: Vec::new(),
strings: HashSet::new(),
}
}
/// Hands out a stack slot able to hold one value.
///
/// A slot previously returned through `free_slot` is reused when one is available
/// (most recently freed first). Otherwise a new explicit slot of 32 bytes is created
/// with an alignment shift of 3.
fn alloca(&mut self) -> StackSlot {
    match self.free_slots.pop() {
        Some(recycled) => recycled,
        None => {
            let data = StackSlotData::new(StackSlotKind::ExplicitSlot, 32, 3);
            self.builder.create_sized_stack_slot(data)
        }
    }
}
/// Marks `slot` as reusable so that a later `alloca` may hand it out again.
///
/// This is compile-time bookkeeping only; no IR is emitted.
fn free_slot(&mut self, slot: StackSlot) {
    let recyclable = &mut self.free_slots;
    recyclable.push(slot);
}
/// Emits a call to the `alloc_array` helper.
///
/// `len` is passed as a pointer-sized constant. Returns the helper's single result
/// value (declared pointer-sized in `JITCompiler::new`).
fn alloc_array(&mut self, len: usize) -> ir::Value {
    let ptr_ty = self.compiler.ptr_type;
    let count = self.builder.ins().iconst(ptr_ty, len as i64);
    let helper_id = self.compiler.alloc_array;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let call_inst = self.builder.ins().call(callee, &[count]);
    let results = self.builder.inst_results(call_inst);
    results[0]
}
/// Emits a call to the `create_string` helper for the literal `string`.
///
/// The literal is interned in `self.strings`, and the address and length of the
/// interned buffer are baked into the IR as pointer-sized constants. The helper is
/// called with `(ptr, len, ret)`, where `ret` is the address of the returned slot.
fn create_string(&mut self, string: &str) -> StackSlot {
    let ptr_ty = self.compiler.ptr_type;
    // Read address and length up front so the borrow of `self.strings` ends here.
    let (addr, size) = {
        let interned = self
            .strings
            .get_or_insert_with(string, |text| text.to_owned());
        (interned.as_ptr() as i64, interned.len() as i64)
    };
    let ptr = self.builder.ins().iconst(ptr_ty, addr);
    let len = self.builder.ins().iconst(ptr_ty, size);
    let helper_id = self.compiler.create_string;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let out = self.alloca();
    let out_addr = self.builder.ins().stack_addr(ptr_ty, out, 0);
    self.builder.ins().call(callee, &[ptr, len, out_addr]);
    out
}
/// Emits a call to the `create_list` helper.
///
/// The helper receives `(ptr, len, ret)`: the caller-supplied element pointer, the
/// element count as a pointer-sized constant, and the address of the returned slot.
fn create_list(&mut self, ptr: ir::Value, len: usize) -> StackSlot {
    let ptr_ty = self.compiler.ptr_type;
    let count = self.builder.ins().iconst(ptr_ty, len as i64);
    let helper_id = self.compiler.create_list;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let out = self.alloca();
    let out_addr = self.builder.ins().stack_addr(ptr_ty, out, 0);
    self.builder.ins().call(callee, &[ptr, count, out_addr]);
    out
}
/// Emits a call to the `create_attrs` helper and returns the slot it fills.
///
/// Unlike the value slots handed out by `alloca`, this always creates a dedicated
/// 40-byte explicit slot (alignment shift 3) rather than taking one from the free list.
fn create_attrs(&mut self) -> StackSlot {
    let helper_id = self.compiler.create_attrs;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let data = StackSlotData::new(StackSlotKind::ExplicitSlot, 40, 3);
    let attrs = self.builder.create_sized_stack_slot(data);
    let attrs_addr = self
        .builder
        .ins()
        .stack_addr(self.compiler.ptr_type, attrs, 0);
    self.builder.ins().call(callee, &[attrs_addr]);
    attrs
}
fn push_attr(&mut self, attrs: StackSlot, sym: &str, val: StackSlot) {
self.free_slot(attrs);
self.free_slot(val);
let attrs = self.builder.ins().stack_addr(types::I64, attrs, 0);
let val = self.builder.ins().stack_addr(types::I64, val, 0);
let sym = self.strings.get_or_insert_with(sym, |_| sym.to_owned());
let ptr = self
.builder
.ins()
.iconst(self.compiler.ptr_type, sym.as_ptr() as i64);
let len = self
.builder
.ins()
.iconst(self.compiler.ptr_type, sym.len() as i64);
let push_attr = self
.compiler
.module
.declare_func_in_func(self.compiler.push_attr, self.builder.func);
self.builder.ins().call(push_attr, &[attrs, ptr, len, val]);
}
fn finalize_attrs(&mut self, attrs: StackSlot) -> StackSlot {
let attrs = self.builder.ins().stack_addr(types::I64, attrs, 0);
let finalize_attrs = self
.compiler
.module
.declare_func_in_func(self.compiler.finalize_attrs, self.builder.func);
let slot = self.alloca();
let ret = self
.builder
.ins()
.stack_addr(self.compiler.ptr_type, slot, 0);
self.builder.ins().call(finalize_attrs, &[attrs, ret]);
slot
}
/// Emits a call to the `enter_with` helper with `(env, namespace_addr)`, where
/// `namespace_addr` is the address of the `namespace` slot.
fn enter_with(&mut self, env: ir::Value, namespace: StackSlot) {
    let ptr_ty = self.compiler.ptr_type;
    let namespace_addr = self.builder.ins().stack_addr(ptr_ty, namespace, 0);
    let helper_id = self.compiler.enter_with;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    self.builder.ins().call(callee, &[env, namespace_addr]);
}
/// Emits a call to the `exit_with` helper, passing only `env`.
fn exit_with(&mut self, env: ir::Value) {
    let helper_id = self.compiler.exit_with;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    self.builder.ins().call(callee, &[env]);
}
/// Emits a call to the `dbg` helper with the address of `slot`.
fn dbg(&mut self, slot: StackSlot) {
    let ptr_ty = self.compiler.ptr_type;
    let slot_addr = self.builder.ins().stack_addr(ptr_ty, slot, 0);
    let helper_id = self.compiler.dbg;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    self.builder.ins().call(callee, &[slot_addr]);
}
/// Emits a call to the `call` helper with
/// `(func_addr, args_ptr, args_len, engine, env)`.
///
/// `func_addr` is the address of the `func` slot and `args_len` is passed as a
/// pointer-sized constant. The helper has no declared return value; presumably the
/// result is written back through `func` — confirm against the helper implementation.
fn call(
    &mut self,
    func: StackSlot,
    args_ptr: ir::Value,
    args_len: usize,
    engine: ir::Value,
    env: ir::Value,
) {
    let ptr_ty = self.compiler.ptr_type;
    let argc = self.builder.ins().iconst(ptr_ty, args_len as i64);
    let helper_id = self.compiler.call;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let func_addr = self.builder.ins().stack_addr(ptr_ty, func, 0);
    let operands = [func_addr, args_ptr, argc, engine, env];
    self.builder.ins().call(callee, &operands);
}
/// Emits a call to the `lookup` helper for the symbol `sym`.
///
/// The symbol is interned in `self.strings` and passed as a baked-in address/length
/// pair. The helper receives `(env, sym_ptr, sym_len, ret)`, where `ret` is the address
/// of the returned slot.
fn lookup(&mut self, env: ir::Value, sym: &str) -> StackSlot {
    let ptr_ty = self.compiler.ptr_type;
    // Read address and length up front so the borrow of `self.strings` ends here.
    let (sym_addr, sym_size) = {
        let interned = self.strings.get_or_insert_with(sym, |text| text.to_owned());
        (interned.as_ptr() as i64, interned.len() as i64)
    };
    let ptr = self.builder.ins().iconst(ptr_ty, sym_addr);
    let len = self.builder.ins().iconst(ptr_ty, sym_size);
    let helper_id = self.compiler.lookup;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let out = self.alloca();
    let out_addr = self.builder.ins().stack_addr(ptr_ty, out, 0);
    self.builder.ins().call(callee, &[env, ptr, len, out_addr]);
    out
}
/// Emits a call to the `lookup_stack` helper with `(env, idx, ret)`.
///
/// `idx` is passed as a pointer-sized constant and `ret` is the address of the
/// returned slot.
fn lookup_stack(&mut self, env: ir::Value, idx: usize) -> StackSlot {
    let ptr_ty = self.compiler.ptr_type;
    let out = self.alloca();
    let helper_id = self.compiler.lookup_stack;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let index = self.builder.ins().iconst(ptr_ty, idx as i64);
    let out_addr = self.builder.ins().stack_addr(ptr_ty, out, 0);
    self.builder.ins().call(callee, &[env, index, out_addr]);
    out
}
/// Emits a call to the `lookup_arg` helper with `(env, idx, ret)`.
///
/// `idx` is passed as a pointer-sized constant and `ret` is the address of the
/// returned slot.
fn lookup_arg(&mut self, env: ir::Value, idx: usize) -> StackSlot {
    let ptr_ty = self.compiler.ptr_type;
    let out = self.alloca();
    let helper_id = self.compiler.lookup_arg;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let index = self.builder.ins().iconst(ptr_ty, idx as i64);
    let out_addr = self.builder.ins().stack_addr(ptr_ty, out, 0);
    self.builder.ins().call(callee, &[env, index, out_addr]);
    out
}
/// Emits a call to the `select` helper with
/// `(slot_addr, path_ptr, path_len, engine, env)`.
///
/// `path_len` is passed as a pointer-sized constant. The helper has no declared return
/// value; presumably the selected value is written back through `slot` — confirm
/// against the helper implementation.
fn select(
    &mut self,
    slot: StackSlot,
    path_ptr: ir::Value,
    path_len: usize,
    engine: ir::Value,
    env: ir::Value,
) {
    let ptr_ty = self.compiler.ptr_type;
    let helper_id = self.compiler.select;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let segments = self.builder.ins().iconst(ptr_ty, path_len as i64);
    let slot_addr = self.builder.ins().stack_addr(ptr_ty, slot, 0);
    let operands = [slot_addr, path_ptr, segments, engine, env];
    self.builder.ins().call(callee, &operands);
}
/// Emits a call to the `select_with_default` helper with
/// `(slot_addr, path_ptr, path_len, default_addr, engine, env)`.
///
/// `path_len` is passed as a pointer-sized constant; `slot_addr` and `default_addr` are
/// the addresses of the respective stack slots. The helper has no declared return
/// value; presumably the outcome is written back through `slot` — confirm against the
/// helper implementation.
fn select_with_default(
    &mut self,
    slot: StackSlot,
    path_ptr: ir::Value,
    path_len: usize,
    default: StackSlot,
    engine: ir::Value,
    env: ir::Value,
) {
    let ptr_ty = self.compiler.ptr_type;
    let helper_id = self.compiler.select_with_default;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    let segments = self.builder.ins().iconst(ptr_ty, path_len as i64);
    let slot_addr = self.builder.ins().stack_addr(ptr_ty, slot, 0);
    let fallback_addr = self.builder.ins().stack_addr(ptr_ty, default, 0);
    let operands = [slot_addr, path_ptr, segments, fallback_addr, engine, env];
    self.builder.ins().call(callee, &operands);
}
/// Emits a call to the `eq` helper with the addresses of both operand slots.
///
/// The helper has no declared return value; presumably the comparison result is
/// written back through one of the operands — confirm against the helper implementation.
pub fn eq(&mut self, lhs: StackSlot, rhs: StackSlot) {
    let ptr_ty = self.compiler.ptr_type;
    let left_addr = self.builder.ins().stack_addr(ptr_ty, lhs, 0);
    let right_addr = self.builder.ins().stack_addr(ptr_ty, rhs, 0);
    let helper_id = self.compiler.eq;
    let callee = self
        .compiler
        .module
        .declare_func_in_func(helper_id, self.builder.func);
    self.builder.ins().call(callee, &[left_addr, right_addr]);
}
/// Loads the 64-bit word stored at the start of `slot`, which this module treats as
/// the value's tag.
pub fn get_tag(&mut self, slot: StackSlot) -> ir::Value {
    const TAG_OFFSET: i32 = 0;
    let tag = self
        .builder
        .ins()
        .stack_load(types::I64, slot, TAG_OFFSET);
    tag
}
/// Loads a value of type `ty` from byte offset 8 of `slot`, i.e. the word directly
/// after the one read by `get_tag`.
pub fn get_small_value(&mut self, ty: Type, slot: StackSlot) -> ir::Value {
    const PAYLOAD_OFFSET: i32 = 8;
    let payload = self.builder.ins().stack_load(ty, slot, PAYLOAD_OFFSET);
    payload
}
}
/// The main JIT compiler that manages the compilation process.
///
/// Owns the Cranelift JIT module, the reusable codegen state, the IR types used for
/// code generation, and the ids of every runtime helper imported into the module.
pub struct JITCompiler<Ctx: JITContext> {
/// Cranelift codegen context; receives each function before it is defined and is
/// cleared afterwards.
ctx: codegen::Context,
/// The JIT module in which compiled functions and imported helpers are declared.
module: JITModule,
/// Function-builder scratch state kept between `compile` calls (`None` until the
/// first compilation has finished).
builder_ctx: Option<FunctionBuilderContext>,
/// Ties the compiler to the context type the helpers were instantiated for.
_marker: PhantomData<Ctx>,
/// IR type used for integers (`I64`).
int_type: Type,
/// IR type used for floats (`F64`).
float_type: Type,
/// IR type used for booleans (`I8`).
bool_type: Type,
/// Pointer type of the target the module was created for.
ptr_type: Type,
/// Set to `I128` in `new`; its use is not visible in this file.
value_type: Type,
/// Signature shared by all compiled functions: three pointer-sized parameters and
/// no return value.
func_sig: Signature,
/// Id of the imported `helper_call`.
call: FuncId,
/// Id of the imported `helper_lookup_stack`.
lookup_stack: FuncId,
/// Id of the imported `helper_lookup_arg`.
lookup_arg: FuncId,
/// Id of the imported `helper_lookup`.
lookup: FuncId,
/// Id of the imported `helper_select`.
select: FuncId,
/// Id of the imported `helper_select_with_default`.
select_with_default: FuncId,
/// Id of the imported `helper_eq`.
eq: FuncId,
/// Id of the imported `helper_alloc_array`.
alloc_array: FuncId,
/// Id of the imported `helper_create_string`.
create_string: FuncId,
/// Id of the imported `helper_create_list`.
create_list: FuncId,
/// Id of the imported `helper_create_attrs`.
create_attrs: FuncId,
/// Id of the imported `helper_push_attr`.
push_attr: FuncId,
/// Id of the imported `helper_finalize_attrs`.
finalize_attrs: FuncId,
/// Id of the imported `helper_enter_with`.
enter_with: FuncId,
/// Id of the imported `helper_exit_with`.
exit_with: FuncId,
/// Id of the imported `helper_dbg`.
dbg: FuncId,
}
impl<Ctx: JITContext> Default for JITCompiler<Ctx> {
fn default() -> Self {
Self::new()
}
}
impl<Ctx: JITContext> JITCompiler<Ctx> {
pub fn new() -> Self {
let mut flag_builder = settings::builder();
flag_builder.set("use_colocated_libcalls", "false").unwrap();
flag_builder.set("is_pic", "false").unwrap();
let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
panic!("host machine is not supported: {msg}");
});
let isa = isa_builder
.finish(settings::Flags::new(flag_builder))
.unwrap();
let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
builder.symbol("helper_call", helper_call::<Ctx> as _);
builder.symbol("helper_lookup_stack", helper_lookup_stack::<Ctx> as _);
builder.symbol("helper_lookup_arg", helper_lookup_arg::<Ctx> as _);
builder.symbol("helper_lookup", helper_lookup::<Ctx> as _);
builder.symbol("helper_select", helper_select::<Ctx> as _);
builder.symbol(
"helper_select_with_default",
helper_select_with_default::<Ctx> as _,
);
builder.symbol("helper_eq", helper_eq::<Ctx> as _);
builder.symbol("helper_alloc_array", helper_alloc_array::<Ctx> as _);
builder.symbol("helper_create_string", helper_create_string::<Ctx> as _);
builder.symbol("helper_create_list", helper_create_list::<Ctx> as _);
builder.symbol("helper_create_attrs", helper_create_attrs::<Ctx> as _);
builder.symbol("helper_push_attr", helper_push_attr::<Ctx> as _);
builder.symbol("helper_finalize_attrs", helper_finalize_attrs::<Ctx> as _);
builder.symbol("helper_enter_with", helper_enter_with::<Ctx> as _);
builder.symbol("helper_exit_with", helper_exit_with::<Ctx> as _);
builder.symbol("helper_dbg", helper_dbg::<Ctx> as _);
let mut module = JITModule::new(builder);
let ctx = module.make_context();
let int_type = types::I64;
let float_type = types::F64;
let bool_type = types::I8;
let ptr_type = module.target_config().pointer_type();
let value_type = types::I128;
// fn(*const Context, *const Env, *mut Value)
let mut func_sig = module.make_signature();
func_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 3],
);
// fn(func: &mut Value, args_ptr: *mut Value, args_len: usize, engine: &mut Context, env:
// &mut Env)
let mut call_sig = module.make_signature();
call_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 5],
);
let call = module
.declare_function("helper_call", Linkage::Import, &call_sig)
.unwrap();
let mut lookup_stack_sig = module.make_signature();
lookup_stack_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 3],
);
let lookup_stack = module
.declare_function("helper_lookup_stack", Linkage::Import, &lookup_stack_sig)
.unwrap();
// fn(env: &Env, level: usize, ret: &mut MaybeUninit<Value>)
let mut lookup_arg_sig = module.make_signature();
lookup_arg_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 3],
);
let lookup_arg = module
.declare_function("helper_lookup_arg", Linkage::Import, &lookup_arg_sig)
.unwrap();
// fn(env: &Env, sym_ptr: *const u8, sym_len: usize, ret: &mut MaybeUninit<Value>)
let mut lookup_sig = module.make_signature();
lookup_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 4],
);
let lookup = module
.declare_function("helper_lookup", Linkage::Import, &lookup_sig)
.unwrap();
// fn(val: &mut Value, path_ptr: *mut Value, path_len: usize, engine: &mut Context, env: &mut Env)
let mut select_sig = module.make_signature();
select_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 5],
);
let select = module
.declare_function("helper_select", Linkage::Import, &select_sig)
.unwrap();
// fn(val: &mut Value, path_ptr: *mut Value, path_len: usize, default: NonNull<Value>, engine: &mut Context, env: &mut Env)
let mut select_with_default_sig = module.make_signature();
select_with_default_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 6],
);
let select_with_default = module
.declare_function(
"helper_select_with_default",
Linkage::Import,
&select_with_default_sig,
)
.unwrap();
let mut eq_sig = module.make_signature();
eq_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 2],
);
let eq = module
.declare_function("helper_eq", Linkage::Import, &eq_sig)
.unwrap();
let mut alloc_array_sig = module.make_signature();
alloc_array_sig.params.push(AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
});
alloc_array_sig.returns.push(AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
});
let alloc_array = module
.declare_function("helper_alloc_array", Linkage::Import, &alloc_array_sig)
.unwrap();
let mut create_string_sig = module.make_signature();
create_string_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 3],
);
let create_string = module
.declare_function("helper_create_string", Linkage::Import, &create_string_sig)
.unwrap();
let mut create_list_sig = module.make_signature();
create_list_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 3],
);
let create_list = module
.declare_function("helper_create_list", Linkage::Import, &create_list_sig)
.unwrap();
let mut create_attrs_sig = module.make_signature();
create_attrs_sig.params.push(AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
});
let create_attrs = module
.declare_function("helper_create_attrs", Linkage::Import, &create_attrs_sig)
.unwrap();
let mut push_attr_sig = module.make_signature();
push_attr_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 4],
);
let push_attr = module
.declare_function("helper_push_attr", Linkage::Import, &push_attr_sig)
.unwrap();
let mut finalize_attrs_sig = module.make_signature();
finalize_attrs_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 2],
);
let finalize_attrs = module
.declare_function(
"helper_finalize_attrs",
Linkage::Import,
&finalize_attrs_sig,
)
.unwrap();
let mut enter_with_sig = module.make_signature();
enter_with_sig.params.extend(
[AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
}; 2],
);
let enter_with = module
.declare_function("helper_enter_with", Linkage::Import, &enter_with_sig)
.unwrap();
let mut exit_with_sig = module.make_signature();
exit_with_sig.params.push(AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
});
let exit_with = module
.declare_function("helper_exit_with", Linkage::Import, &exit_with_sig)
.unwrap();
let mut dbg_sig = module.make_signature();
dbg_sig.params.push(AbiParam {
value_type: ptr_type,
purpose: ArgumentPurpose::Normal,
extension: ArgumentExtension::None,
});
let dbg = module
.declare_function("helper_dbg", Linkage::Import, &dbg_sig)
.unwrap();
Self {
builder_ctx: None,
_marker: PhantomData,
ctx,
module,
int_type,
float_type,
bool_type,
ptr_type,
value_type,
func_sig,
call,
lookup_stack,
lookup_arg,
lookup,
select,
select_with_default,
eq,
alloc_array,
create_string,
create_list,
create_attrs,
push_attr,
finalize_attrs,
enter_with,
exit_with,
dbg,
}
}
/// Compiles `ir` to machine code and returns a handle to it.
///
/// The function is declared as the local symbol `nixjit_thunk{id}` with the shared
/// function signature. Its body evaluates `ir` into a stack slot and then copies the
/// four 64-bit words of that slot (offsets 0, 8, 16, 24) to the pointer received as the
/// third parameter. In debug builds the LIR and the generated Cranelift IR are printed
/// to stdout.
///
/// The strings interned during compilation are moved into the returned `JITFunc`,
/// because the generated code refers to them by address.
///
/// # Panics
///
/// Panics if the function cannot be declared, defined, or finalized in the module.
pub fn compile(&mut self, ir: &Lir, id: usize) -> JITFunc<Ctx> {
    /// Byte offsets of the 64-bit words copied from the result slot to `ret`.
    const WORD_OFFSETS: [i32; 4] = [0, 8, 16, 24];

    let symbol = format!("nixjit_thunk{id}");
    let func_id = self
        .module
        .declare_function(&symbol, Linkage::Local, &self.func_sig)
        .unwrap();

    let mut func = Function::new();
    func.signature = self.func_sig.clone();
    // Reuse the builder scratch state left behind by a previous compilation, if any.
    let mut builder_ctx = self.builder_ctx.take().unwrap_or_default();
    let mut ctx = Context::new(self, FunctionBuilder::new(&mut func, &mut builder_ctx));

    let entry = ctx.builder.create_block();
    ctx.builder.append_block_params_for_function_params(entry);
    ctx.builder.switch_to_block(entry);
    let (engine, env, ret) = {
        let params = ctx.builder.block_params(entry);
        (params[0], params[1], params[2])
    };

    let res = ir.compile(&mut ctx, engine, env);

    // Load every word first, then store them, so the emitted IR keeps the
    // load-load-load-load / store-store-store-store order.
    let words = WORD_OFFSETS.map(|off| ctx.builder.ins().stack_load(types::I64, res, off));
    for (&word, &off) in words.iter().zip(WORD_OFFSETS.iter()) {
        ctx.builder.ins().store(MemFlags::new(), word, ret, off);
    }
    ctx.builder.ins().return_(&[]);
    ctx.builder.seal_all_blocks();
    ctx.builder.finalize();
    let strings = ctx.strings;

    if cfg!(debug_assertions) {
        println!("{ir:#?}");
        println!("{}", func.display());
    }

    self.ctx.func = func;
    self.module.define_function(func_id, &mut self.ctx).unwrap();
    self.module.finalize_definitions().unwrap();
    self.ctx.clear();
    self.builder_ctx = Some(builder_ctx);

    let code = self.module.get_finalized_function(func_id);
    // SAFETY: `code` is the address of the machine code that was just finalized for
    // `func_id`, so it is a valid function entry point.
    // NOTE(review): the code was compiled with `func_sig` (three pointer parameters),
    // while `F<Ctx>` names two — confirm callers pass everything the body reads.
    let func = unsafe { std::mem::transmute::<*const u8, F<Ctx>>(code) };
    JITFunc { func, strings }
}
}