feat: JIT (unusable, segfault)

This commit is contained in:
2025-05-18 15:01:19 +08:00
parent 29e959894d
commit f98d623c13
8 changed files with 326 additions and 122 deletions

View File

@@ -4,16 +4,19 @@ use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::execution_engine::ExecutionEngine;
use inkwell::module::Module;
use inkwell::types::{FunctionType, StructType};
use inkwell::values::BasicValueEnum;
use inkwell::{AddressSpace, OptimizationLevel};
use inkwell::values::{BasicValueEnum, PointerValue};
use inkwell::OptimizationLevel;
use crate::bytecode::OpCode;
use crate::bytecode::{Func, OpCode, UnOp};
use crate::error::*;
use crate::stack::Stack;
use crate::ty::common::Const;
use crate::ty::internal::{Func, Thunk, Value};
use crate::vm::VM;
use crate::ty::internal::{Thunk, Value};
use crate::vm::{Env, VM};
mod helpers;
use helpers::Helpers;
#[cfg(test)]
mod test;
@@ -21,6 +24,7 @@ mod test;
const STACK_SIZE: usize = 8 * 1024 / size_of::<Value>();
#[repr(u64)]
#[derive(Debug, Clone, Copy)]
pub enum ValueTag {
Int,
Float,
@@ -31,6 +35,7 @@ pub enum ValueTag {
Function,
Thunk,
Path,
Null,
}
#[repr(C)]
@@ -44,9 +49,30 @@ pub union JITValueData {
int: i64,
float: f64,
boolean: bool,
ptr: *const ()
}
pub type JITFunc = fn(usize, usize, JITValue) -> JITValue;
impl<'vm> Into<Value<'vm>> for JITValue {
fn into(self) -> Value<'vm> {
use ValueTag::*;
match self.tag {
Int => Value::Const(Const::Int(unsafe { self.data.int })),
Null => Value::Const(Const::Null),
_ => todo!("not implemented for {:?}", self.tag)
}
}
}
/// Converts an interpreter [`Value`] into the tagged representation exchanged
/// with JIT-compiled code.
///
/// # Panics
///
/// Panics (via `todo!`) for any value other than an integer or null constant.
impl From<Value<'_>> for JITValue {
    fn from(value: Value<'_>) -> Self {
        match value {
            Value::Const(Const::Int(int)) => JITValue {
                tag: ValueTag::Int,
                data: JITValueData { int },
            },
            // Mirrors the reverse conversion, which maps the `Null` tag to
            // `Const::Null` without reading the payload; the payload is zeroed
            // so the value is fully initialised.
            Value::Const(Const::Null) => JITValue {
                tag: ValueTag::Null,
                data: JITValueData { int: 0 },
            },
            _ => todo!("conversion into JITValue not implemented for this value"),
        }
    }
}
/// Function-pointer type for a JIT-compiled function: takes a raw pointer to
/// the [`VM`], a raw pointer to an [`Env`], and one [`JITValue`] argument, and
/// returns a [`JITValue`].
///
// NOTE(review): this is a plain Rust-ABI `fn` pointer, while `compile_function`
// obtains the address from the LLVM execution engine and transmutes it.
// Presumably the generated code uses the C calling convention, in which case
// this should be `extern "C" fn(..)` — confirm before relying on it.
pub type JITFunc<'vm> = fn(*const VM<'_>, *const Env<'vm>, JITValue) -> JITValue;
pub struct JITContext<'ctx> {
context: &'ctx Context,
@@ -54,127 +80,53 @@ pub struct JITContext<'ctx> {
builder: Builder<'ctx>,
execution_engine: ExecutionEngine<'ctx>,
value_type: StructType<'ctx>,
func_type: FunctionType<'ctx>,
helpers: Helpers<'ctx>
}
impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
pub fn new(context: &'ctx Context) -> Self {
let module = context.create_module("nixjit");
let execution_engine = module
.create_jit_execution_engine(OptimizationLevel::Default)
.unwrap();
let int_type = context.i64_type();
let pointer_type = context.ptr_type(AddressSpace::default());
let value_type = context.struct_type(&[int_type.into(), int_type.into()], false);
let func_type = value_type.fn_type(
&[pointer_type.into(), pointer_type.into(), value_type.into()],
false,
);
let helpers = Helpers::new(context, &module, &execution_engine);
JITContext {
execution_engine: module
.create_jit_execution_engine(OptimizationLevel::Default)
.unwrap(),
execution_engine,
builder: context.create_builder(),
context,
module,
value_type,
func_type,
helpers
}
}
fn new_int(&self, int: i64) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Int as _, false)
.into(),
self.context.i64_type().const_int(int as _, false).into(),
])
.into()
pub fn new_ptr<T>(&self, ptr: *const T) -> PointerValue<'ctx> {
self.builder.build_int_to_ptr(self.helpers.int_type.const_int(ptr as _, false), self.helpers.ptr_type, "ptrconv").unwrap()
}
fn new_float(&self, float: f64) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Float as _, false)
.into(),
self.context.f64_type().const_float(float).into(),
])
.into()
}
fn new_bool(&self, bool: bool) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Bool as _, false)
.into(),
self.context.bool_type().const_int(bool as _, false).into(),
])
.into()
}
fn new_null(&self) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Float as _, false)
.into(),
self.context.i64_type().const_zero().into(),
])
.into()
}
fn const_string(&self, string: *const u8) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Float as _, false)
.into(),
self.context
.ptr_sized_int_type(self.execution_engine.get_target_data(), None)
.const_int(string as _, false)
.into(),
])
.into()
}
fn new_thunk(&self, thunk: *const Thunk) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Thunk as _, false)
.into(),
self.context
.ptr_sized_int_type(self.execution_engine.get_target_data(), None)
.const_int(thunk as _, false)
.into(),
])
.into()
}
pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result<JITFunc> {
pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result<&'vm JITFunc> {
let mut stack = Stack::<_, STACK_SIZE>::new();
let mut iter = func.func.opcodes.iter().copied();
let func_ = self.module.add_function("fn", self.func_type, None);
let mut iter = func.opcodes.iter().copied();
let func_ = self.module.add_function("nixjit_function", self.helpers.func_type, None);
let entry = self.context.append_basic_block(func_, "entry");
self.builder.position_at_end(entry);
while let Some(opcode) = iter.next() {
self.single_op(opcode, vm, &mut stack)?;
}
assert_eq!(stack.len(), 1);
self.builder.build_return(Some(&stack.pop()))?;
let value = stack.pop();
self.builder.build_direct_call(self.helpers.debug, &[value.into()], "call_debug").unwrap();
self.builder.build_return(Some(&value))?;
if func_.verify(false) {
func_.print_to_stderr();
unsafe {
Ok(std::mem::transmute(self.execution_engine.get_function_address(func_.get_name().to_str().unwrap()).unwrap()))
let name = func_.get_name().to_str().unwrap();
let addr = self.execution_engine.get_function_address(name).unwrap();
Ok(std::mem::transmute(addr))
}
} else {
todo!()
@@ -191,15 +143,57 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
OpCode::Const { idx } => {
use Const::*;
match vm.get_const(idx) {
Int(int) => stack.push(self.new_int(int))?,
Float(float) => stack.push(self.new_float(float))?,
Bool(bool) => stack.push(self.new_bool(bool))?,
String(string) => stack.push(self.const_string(string.as_ptr()))?,
Null => stack.push(self.new_null())?,
Int(int) => stack.push(self.helpers.new_int(int))?,
Float(float) => stack.push(self.helpers.new_float(float))?,
Bool(bool) => stack.push(self.helpers.new_bool(bool))?,
String(string) => stack.push(self.helpers.const_string(string.as_ptr()))?,
Null => stack.push(self.helpers.new_null())?,
}
}
OpCode::LoadThunk { idx } => stack.push(self.new_thunk(Rc::new(Thunk::new(vm.get_thunk(idx))).as_ref() as _))?,
_ => todo!()
OpCode::LoadThunk { idx } => stack.push(self.helpers.new_thunk(Rc::into_raw(Rc::new(Thunk::new(vm.get_thunk(idx))))))?,
/* OpCode::CaptureEnv => {
let thunk = *stack.tos()?;
self.builder.build_direct_call(self.helpers.captured_env, &[thunk.into(), self.new_ptr(env).into()], "call_capture_env")?;
} */
OpCode::UnOp { op } => {
use UnOp::*;
let rhs = stack.pop();
stack.push(match op {
Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), self.new_ptr(std::ptr::null::<Env>()).into()], "call_neg")?.try_as_basic_value().left().unwrap(),
_ => todo!()
})?
}
/* OpCode::Func { idx } => {
let func = vm.get_func(idx);
let jit_func_ptr = self.compile_function(&Func::new(func, unsafe { env.as_ref() }.unwrap().clone(), OnceCell::new(), Cell::new(1)), vm)?;
let jit_value = self.helpers.value_type.const_named_struct(&[
self.helpers.int_type.const_int(ValueTag::Function as _, false).into(),
self.helpers.ptr_int_type.const_int(jit_func_ptr as *const _ as _, false).into(),
]).into();
stack.push(jit_value)?;
} */
OpCode::Call { arity } => {
// Assuming arity is 1 for the test case
assert_eq!(arity, 1);
let arg = stack.pop();
let func_value = stack.pop();
let func_ptr = self.builder.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")?.into_pointer_value();
let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), self.new_ptr(std::ptr::null::<Env>()).into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap();
stack.push(result)?;
}
OpCode::BinOp { op } => {
use crate::bytecode::BinOp;
match op {
BinOp::Add => {
let rhs = stack.pop();
let lhs = stack.pop();
let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap();
stack.push(result)?;
}
_ => todo!("BinOp::{:?} not implemented in JIT", op),
}
}
_ => todo!("{opcode:?} not implemented")
}
Ok(())
}