From 29e959894d083f39449c227a28771d3498d00103 Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Sat, 17 May 2025 22:38:05 +0800 Subject: [PATCH] feat: JIT (WIP) --- src/bin/repl.rs | 3 +- src/jit.rs | 99 ------------------- src/jit/mod.rs | 206 ++++++++++++++++++++++++++++++++++++++++ src/jit/test.rs | 96 +++++++++++++++++++ src/lib.rs | 2 +- src/ty/internal/func.rs | 5 + src/vm/mod.rs | 12 ++- 7 files changed, 320 insertions(+), 103 deletions(-) delete mode 100644 src/jit.rs create mode 100644 src/jit/mod.rs create mode 100644 src/jit/test.rs diff --git a/src/bin/repl.rs b/src/bin/repl.rs index 6e3cf57..4f5edee 100644 --- a/src/bin/repl.rs +++ b/src/bin/repl.rs @@ -6,7 +6,8 @@ use rustyline::{DefaultEditor, Result}; use nixjit::compile::compile; use nixjit::error::Error; use nixjit::ir::downgrade; -use nixjit::vm::{JITContext, run}; +use nixjit::vm::run; +use nixjit::jit::JITContext; macro_rules! unwrap { ($e:expr) => { diff --git a/src/jit.rs b/src/jit.rs deleted file mode 100644 index 8e6c077..0000000 --- a/src/jit.rs +++ /dev/null @@ -1,99 +0,0 @@ -use inkwell::builder::Builder; -use inkwell::context::Context; -use inkwell::execution_engine::ExecutionEngine; -use inkwell::module::Module; -use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, StructType}; -use inkwell::values::{BasicValueEnum, FunctionValue, IntValue}; -use inkwell::{AddressSpace, OptimizationLevel}; - -use crate::bytecode::OpCode; -use crate::stack::Stack; -use crate::ty::internal::{Func, Value}; -use crate::error::*; - -const STACK_SIZE: usize = 8 * 1024 / size_of::(); - -#[repr(usize)] -pub enum ValueTag { - Int, - String, - Bool, - AttrSet, - List, - Function, - Thunk, - Path, -} - -#[repr(C)] -pub struct JITValue { - tag: ValueTag, - data: JITValueData -} - -#[repr(C)] -pub union JITValueData { - int: i64, - float: f64, - boolean: bool, -} - -pub type JITFunc = fn(usize, usize, JITValue) -> JITValue; - -pub struct JITContext<'ctx> { - context: &'ctx Context, - module: Module<'ctx>, - builder: Builder<'ctx>, - execution_engine: ExecutionEngine<'ctx>, - - value_type: StructType<'ctx>, - func_type: FunctionType<'ctx>, -} - -impl<'ctx> JITContext<'ctx> { - pub fn new(context: &'ctx Context) -> Self { - let module = context.create_module("nixjit"); - - let int_type = context.i64_type(); - let pointer_type = context.ptr_type(AddressSpace::default()); - let value_type = context.struct_type(&[int_type.into(), int_type.into()], false); - let func_type = value_type.fn_type( - &[pointer_type.into(), pointer_type.into(), value_type.into()], - false, - ); - - JITContext { - execution_engine: module - .create_jit_execution_engine(OptimizationLevel::Default) - .unwrap(), - builder: context.create_builder(), - context, - module, - - value_type, - func_type, - } - } - - fn new_int(&self, int: i64) -> IntValue { - self.context.i64_type().const_int(int as u64, false) - } - - fn new_bool(&self, b: bool) -> IntValue { - self.context.bool_type().const_int(b as u64, false) - } - - pub fn compile_function(&self, func: Func) -> Result<()> { - let mut stack = Stack::<_, STACK_SIZE>::new(); - let mut iter = func.func.opcodes.iter().copied(); - let func_ = self.module.add_function("fn", self.func_type, None); - while let Some(opcode) = iter.next() { - self.single_op(opcode, &func_, &mut stack)?; - } - Ok(()) - } - - fn single_op(&self, opcode: OpCode, func: &FunctionValue, stack: &mut Stack) -> Result<()> { - todo!() - } -} diff --git a/src/jit/mod.rs b/src/jit/mod.rs new file mode 100644 index 0000000..cab26d9 --- /dev/null +++ b/src/jit/mod.rs @@ -0,0 +1,206 @@ +use std::rc::Rc; + +use inkwell::builder::Builder; +use inkwell::context::Context; +use inkwell::execution_engine::ExecutionEngine; +use inkwell::module::Module; +use inkwell::types::{FunctionType, StructType}; +use inkwell::values::BasicValueEnum; +use inkwell::{AddressSpace, OptimizationLevel}; + +use crate::bytecode::OpCode; +use crate::error::*; +use crate::stack::Stack; +use crate::ty::common::Const; +use crate::ty::internal::{Func, Thunk, Value}; +use crate::vm::VM; + +#[cfg(test)] +mod test; + +const STACK_SIZE: usize = 8 * 1024 / size_of::(); + +#[repr(u64)] +pub enum ValueTag { + Int, + Float, + String, + Bool, + AttrSet, + List, + Function, + Thunk, + Path, +} + +#[repr(C)] +pub struct JITValue { + tag: ValueTag, + data: JITValueData, +} + +#[repr(C)] +pub union JITValueData { + int: i64, + float: f64, + boolean: bool, +} + +pub type JITFunc = fn(usize, usize, JITValue) -> JITValue; + +pub struct JITContext<'ctx> { + context: &'ctx Context, + module: Module<'ctx>, + builder: Builder<'ctx>, + execution_engine: ExecutionEngine<'ctx>, + + value_type: StructType<'ctx>, + func_type: FunctionType<'ctx>, +} + +impl<'vm, 'ctx: 'vm> JITContext<'ctx> { + pub fn new(context: &'ctx Context) -> Self { + let module = context.create_module("nixjit"); + + let int_type = context.i64_type(); + let pointer_type = context.ptr_type(AddressSpace::default()); + let value_type = context.struct_type(&[int_type.into(), int_type.into()], false); + let func_type = value_type.fn_type( + &[pointer_type.into(), pointer_type.into(), value_type.into()], + false, + ); + + JITContext { + execution_engine: module + .create_jit_execution_engine(OptimizationLevel::Default) + .unwrap(), + builder: context.create_builder(), + context, + module, + + value_type, + func_type, + } + } + + fn new_int(&self, int: i64) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.context + .i64_type() + .const_int(ValueTag::Int as _, false) + .into(), + self.context.i64_type().const_int(int as _, false).into(), + ]) + .into() + } + + fn new_float(&self, float: f64) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.context + .i64_type() + .const_int(ValueTag::Float as _, false) + .into(), + self.context.f64_type().const_float(float).into(), + ]) + .into() + } + + fn new_bool(&self, bool: bool) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.context + .i64_type() + .const_int(ValueTag::Bool as _, false) + .into(), + self.context.bool_type().const_int(bool as _, false).into(), + ]) + .into() + } + + fn new_null(&self) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.context + .i64_type() + .const_int(ValueTag::Float as _, false) + .into(), + self.context.i64_type().const_zero().into(), + ]) + .into() + } + + fn const_string(&self, string: *const u8) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.context + .i64_type() + .const_int(ValueTag::Float as _, false) + .into(), + self.context + .ptr_sized_int_type(self.execution_engine.get_target_data(), None) + .const_int(string as _, false) + .into(), + ]) + .into() + } + + fn new_thunk(&self, thunk: *const Thunk) -> BasicValueEnum<'ctx> { + self.value_type + .const_named_struct(&[ + self.context + .i64_type() + .const_int(ValueTag::Thunk as _, false) + .into(), + self.context + .ptr_sized_int_type(self.execution_engine.get_target_data(), None) + .const_int(thunk as _, false) + .into(), + ]) + .into() + } + + pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result { + let mut stack = Stack::<_, STACK_SIZE>::new(); + let mut iter = func.func.opcodes.iter().copied(); + let func_ = self.module.add_function("fn", self.func_type, None); + let entry = self.context.append_basic_block(func_, "entry"); + self.builder.position_at_end(entry); + while let Some(opcode) = iter.next() { + self.single_op(opcode, vm, &mut stack)?; + } + assert_eq!(stack.len(), 1); + self.builder.build_return(Some(&stack.pop()))?; + if func_.verify(false) { + unsafe { + Ok(std::mem::transmute(self.execution_engine.get_function_address(func_.get_name().to_str().unwrap()).unwrap())) + } + } else { + todo!() + } + } + + fn single_op( + &self, + opcode: OpCode, + vm: &'vm VM<'_>, + stack: &mut Stack, CAP>, + ) -> Result<()> { + match opcode { + OpCode::Const { idx } => { + use Const::*; + match vm.get_const(idx) { + Int(int) => stack.push(self.new_int(int))?, + Float(float) => stack.push(self.new_float(float))?, + Bool(bool) => stack.push(self.new_bool(bool))?, + String(string) => stack.push(self.const_string(string.as_ptr()))?, + Null => stack.push(self.new_null())?, + } + } + OpCode::LoadThunk { idx } => stack.push(self.new_thunk(Rc::new(Thunk::new(vm.get_thunk(idx))).as_ref() as _))?, + _ => todo!() + } + Ok(()) + } +} diff --git a/src/jit/test.rs b/src/jit/test.rs new file mode 100644 index 0000000..9c8d710 --- /dev/null +++ b/src/jit/test.rs @@ -0,0 +1,96 @@ +extern crate test; + +use hashbrown::{HashMap, HashSet}; + +use inkwell::context::Context; + +use ecow::EcoString; +use rpds::vector_sync; + +use crate::compile::compile; +use crate::ir::downgrade; +use crate::ty::public::*; +use crate::ty::common::Const; +use crate::jit::JITContext; +use crate::vm::VM; +use crate::builtins::env; + + +#[inline] +fn test_expr(expr: &str, expected: Value) { + let downgraded = downgrade(rnix::Root::parse(expr).tree().expr().unwrap()).unwrap(); + let prog = compile(downgraded); + dbg!(&prog); + let ctx = Context::create(); + let jit = JITContext::new(&ctx); + let vm = VM::new(prog.thunks, prog.funcs, prog.symbols.into(), prog.symmap.into(), prog.consts, jit); + let env = env(&vm); + let value = vm.eval(prog.top_level.into_iter(), env).unwrap().to_public(&vm, &mut HashSet::new()); + assert_eq!(value, expected); +} + +macro_rules! map { + ($($k:expr => $v:expr),*) => { + { + #[allow(unused_mut)] + let mut m = HashMap::new(); + $( + m.insert($k, $v); + )* + m + } + }; +} + +macro_rules! thunk { + () => { + Value::Thunk + }; +} + +macro_rules! int { + ($e:expr) => { + Value::Const(Const::Int($e)) + }; +} + +macro_rules! float { + ($e:expr) => { + Value::Const(Const::Float($e as f64)) + }; +} + +macro_rules! boolean { + ($e:expr) => { + Value::Const(Const::Bool($e)) + }; +} + +macro_rules! string { + ($e:expr) => { + Value::Const(Const::String(EcoString::from($e))) + }; +} + +macro_rules! symbol { + ($e:expr) => { + Symbol::from($e.to_string()) + }; +} + +macro_rules! list { + ($($x:tt)*) => ( + Value::List(List::new(vector_sync![$($x)*])) + ); +} + +macro_rules! attrs { + ($($x:tt)*) => ( + Value::AttrSet(AttrSet::new(map!{$($x)*})) + ) +} + +#[test] +fn test_jit_const() { + test_expr("let f = _: 1; in (f 1) + (f 1)", int!(2)); +} diff --git a/src/lib.rs b/src/lib.rs index f8ea608..1b27b13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,11 +5,11 @@ mod builtins; mod bytecode; mod stack; mod ty; -mod jit; pub mod compile; pub mod error; pub mod ir; +pub mod jit; pub mod vm; pub use ty::public::Value; diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index 6ddfcf9..580f32b 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -52,6 +52,11 @@ impl<'vm> Func<'vm> { pub fn call(&self, vm: &'vm VM<'_>, arg: Value<'vm>) -> Result> { use Param::*; + let count = self.count.get(); + if count >= 1 { + let compiled = self.compiled.get_or_init(|| vm.compile_func(self)); + } + self.count.replace(count + 1); let mut env = self.env.clone(); match self.func.param.clone() { diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 9d7416b..49d942c 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -8,7 +8,7 @@ use crate::ty::internal::*; use crate::ty::common::Const; use crate::ty::public::{self as p, Symbol}; use crate::stack::Stack; -use crate::jit::JITContext; +use crate::jit::{JITContext, JITFunc}; use derive_more::Constructor; use ecow::EcoString; @@ -74,6 +74,10 @@ impl<'vm, 'jit: 'vm> VM<'jit> { } } + pub fn get_const(&self, idx: usize) -> Const { + self.consts[idx].clone() + } + pub fn eval( &'vm self, opcodes: impl Iterator, @@ -102,7 +106,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { ) -> Result { match opcode { OpCode::Illegal => panic!("illegal opcode"), - OpCode::Const { idx } => stack.push(Value::Const(self.consts[idx].clone()))?, + OpCode::Const { idx } => stack.push(Value::Const(self.get_const(idx)))?, OpCode::LoadThunk { idx } => { stack.push(Value::Thunk(Thunk::new(self.get_thunk(idx)).into()))? } @@ -262,4 +266,8 @@ impl<'vm, 'jit: 'vm> VM<'jit> { } Ok(0) } + + pub fn compile_func(&'vm self, func: &Func<'vm>) -> JITFunc { + self.jit.compile_function(func, self).unwrap() + } }