feat: JIT (WIP)

This commit is contained in:
2025-05-17 22:38:05 +08:00
parent 95ebddf272
commit 29e959894d
7 changed files with 320 additions and 103 deletions

View File

@@ -6,7 +6,8 @@ use rustyline::{DefaultEditor, Result};
use nixjit::compile::compile;
use nixjit::error::Error;
use nixjit::ir::downgrade;
use nixjit::vm::{JITContext, run};
use nixjit::vm::run;
use nixjit::jit::JITContext;
macro_rules! unwrap {
($e:expr) => {

View File

@@ -1,99 +0,0 @@
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::execution_engine::ExecutionEngine;
use inkwell::module::Module;
use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, StructType};
use inkwell::values::{BasicValueEnum, FunctionValue, IntValue};
use inkwell::{AddressSpace, OptimizationLevel};
use crate::bytecode::OpCode;
use crate::stack::Stack;
use crate::ty::internal::{Func, Value};
use crate::error::*;
const STACK_SIZE: usize = 8 * 1024 / size_of::<Value>();
#[repr(usize)]
pub enum ValueTag {
Int,
String,
Bool,
AttrSet,
List,
Function,
Thunk,
Path,
}
#[repr(C)]
pub struct JITValue {
tag: ValueTag,
data: JITValueData
}
#[repr(C)]
pub union JITValueData {
int: i64,
float: f64,
boolean: bool,
}
pub type JITFunc = fn(usize, usize, JITValue) -> JITValue;
pub struct JITContext<'ctx> {
context: &'ctx Context,
module: Module<'ctx>,
builder: Builder<'ctx>,
execution_engine: ExecutionEngine<'ctx>,
value_type: StructType<'ctx>,
func_type: FunctionType<'ctx>,
}
impl<'ctx> JITContext<'ctx> {
pub fn new(context: &'ctx Context) -> Self {
let module = context.create_module("nixjit");
let int_type = context.i64_type();
let pointer_type = context.ptr_type(AddressSpace::default());
let value_type = context.struct_type(&[int_type.into(), int_type.into()], false);
let func_type = value_type.fn_type(
&[pointer_type.into(), pointer_type.into(), value_type.into()],
false,
);
JITContext {
execution_engine: module
.create_jit_execution_engine(OptimizationLevel::Default)
.unwrap(),
builder: context.create_builder(),
context,
module,
value_type,
func_type,
}
}
fn new_int(&self, int: i64) -> IntValue {
self.context.i64_type().const_int(int as u64, false)
}
fn new_bool(&self, b: bool) -> IntValue {
self.context.bool_type().const_int(b as u64, false)
}
pub fn compile_function(&self, func: Func) -> Result<()> {
let mut stack = Stack::<_, STACK_SIZE>::new();
let mut iter = func.func.opcodes.iter().copied();
let func_ = self.module.add_function("fn", self.func_type, None);
while let Some(opcode) = iter.next() {
self.single_op(opcode, &func_, &mut stack)?;
}
Ok(())
}
fn single_op<const CAP: usize>(&self, opcode: OpCode, func: &FunctionValue, stack: &mut Stack<BasicValueEnum, CAP>) -> Result<()> {
todo!()
}
}

206
src/jit/mod.rs Normal file
View File

@@ -0,0 +1,206 @@
use std::rc::Rc;
use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::execution_engine::ExecutionEngine;
use inkwell::module::Module;
use inkwell::types::{FunctionType, StructType};
use inkwell::values::BasicValueEnum;
use inkwell::{AddressSpace, OptimizationLevel};
use crate::bytecode::OpCode;
use crate::error::*;
use crate::stack::Stack;
use crate::ty::common::Const;
use crate::ty::internal::{Func, Thunk, Value};
use crate::vm::VM;
#[cfg(test)]
mod test;
const STACK_SIZE: usize = 8 * 1024 / size_of::<Value>();
#[repr(u64)]
pub enum ValueTag {
Int,
Float,
String,
Bool,
AttrSet,
List,
Function,
Thunk,
Path,
}
#[repr(C)]
pub struct JITValue {
tag: ValueTag,
data: JITValueData,
}
#[repr(C)]
pub union JITValueData {
int: i64,
float: f64,
boolean: bool,
}
pub type JITFunc = fn(usize, usize, JITValue) -> JITValue;
pub struct JITContext<'ctx> {
context: &'ctx Context,
module: Module<'ctx>,
builder: Builder<'ctx>,
execution_engine: ExecutionEngine<'ctx>,
value_type: StructType<'ctx>,
func_type: FunctionType<'ctx>,
}
impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
pub fn new(context: &'ctx Context) -> Self {
let module = context.create_module("nixjit");
let int_type = context.i64_type();
let pointer_type = context.ptr_type(AddressSpace::default());
let value_type = context.struct_type(&[int_type.into(), int_type.into()], false);
let func_type = value_type.fn_type(
&[pointer_type.into(), pointer_type.into(), value_type.into()],
false,
);
JITContext {
execution_engine: module
.create_jit_execution_engine(OptimizationLevel::Default)
.unwrap(),
builder: context.create_builder(),
context,
module,
value_type,
func_type,
}
}
fn new_int(&self, int: i64) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Int as _, false)
.into(),
self.context.i64_type().const_int(int as _, false).into(),
])
.into()
}
fn new_float(&self, float: f64) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Float as _, false)
.into(),
self.context.f64_type().const_float(float).into(),
])
.into()
}
fn new_bool(&self, bool: bool) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Bool as _, false)
.into(),
self.context.bool_type().const_int(bool as _, false).into(),
])
.into()
}
fn new_null(&self) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Float as _, false)
.into(),
self.context.i64_type().const_zero().into(),
])
.into()
}
fn const_string(&self, string: *const u8) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Float as _, false)
.into(),
self.context
.ptr_sized_int_type(self.execution_engine.get_target_data(), None)
.const_int(string as _, false)
.into(),
])
.into()
}
fn new_thunk(&self, thunk: *const Thunk) -> BasicValueEnum<'ctx> {
self.value_type
.const_named_struct(&[
self.context
.i64_type()
.const_int(ValueTag::Thunk as _, false)
.into(),
self.context
.ptr_sized_int_type(self.execution_engine.get_target_data(), None)
.const_int(thunk as _, false)
.into(),
])
.into()
}
pub fn compile_function(&self, func: &Func, vm: &'vm VM<'_>) -> Result<JITFunc> {
let mut stack = Stack::<_, STACK_SIZE>::new();
let mut iter = func.func.opcodes.iter().copied();
let func_ = self.module.add_function("fn", self.func_type, None);
let entry = self.context.append_basic_block(func_, "entry");
self.builder.position_at_end(entry);
while let Some(opcode) = iter.next() {
self.single_op(opcode, vm, &mut stack)?;
}
assert_eq!(stack.len(), 1);
self.builder.build_return(Some(&stack.pop()))?;
if func_.verify(false) {
unsafe {
Ok(std::mem::transmute(self.execution_engine.get_function_address(func_.get_name().to_str().unwrap()).unwrap()))
}
} else {
todo!()
}
}
fn single_op<const CAP: usize>(
&self,
opcode: OpCode,
vm: &'vm VM<'_>,
stack: &mut Stack<BasicValueEnum<'ctx>, CAP>,
) -> Result<()> {
match opcode {
OpCode::Const { idx } => {
use Const::*;
match vm.get_const(idx) {
Int(int) => stack.push(self.new_int(int))?,
Float(float) => stack.push(self.new_float(float))?,
Bool(bool) => stack.push(self.new_bool(bool))?,
String(string) => stack.push(self.const_string(string.as_ptr()))?,
Null => stack.push(self.new_null())?,
}
}
OpCode::LoadThunk { idx } => stack.push(self.new_thunk(Rc::new(Thunk::new(vm.get_thunk(idx))).as_ref() as _))?,
_ => todo!()
}
Ok(())
}
}

96
src/jit/test.rs Normal file
View File

@@ -0,0 +1,96 @@
extern crate test;
use hashbrown::{HashMap, HashSet};
use inkwell::context::Context;
use ecow::EcoString;
use rpds::vector_sync;
use crate::compile::compile;
use crate::ir::downgrade;
use crate::ty::public::*;
use crate::ty::common::Const;
use crate::jit::JITContext;
use crate::vm::VM;
use crate::builtins::env;
#[inline]
fn test_expr(expr: &str, expected: Value) {
let downgraded = downgrade(rnix::Root::parse(expr).tree().expr().unwrap()).unwrap();
let prog = compile(downgraded);
dbg!(&prog);
let ctx = Context::create();
let jit = JITContext::new(&ctx);
let vm = VM::new(prog.thunks, prog.funcs, prog.symbols.into(), prog.symmap.into(), prog.consts, jit);
let env = env(&vm);
let value = vm.eval(prog.top_level.into_iter(), env).unwrap().to_public(&vm, &mut HashSet::new());
assert_eq!(value, expected);
}
macro_rules! map {
($($k:expr => $v:expr),*) => {
{
#[allow(unused_mut)]
let mut m = HashMap::new();
$(
m.insert($k, $v);
)*
m
}
};
}
macro_rules! thunk {
() => {
Value::Thunk
};
}
macro_rules! int {
($e:expr) => {
Value::Const(Const::Int($e))
};
}
macro_rules! float {
($e:expr) => {
Value::Const(Const::Float($e as f64))
};
}
macro_rules! boolean {
($e:expr) => {
Value::Const(Const::Bool($e))
};
}
macro_rules! string {
($e:expr) => {
Value::Const(Const::String(EcoString::from($e)))
};
}
macro_rules! symbol {
($e:expr) => {
Symbol::from($e.to_string())
};
}
macro_rules! list {
($($x:tt)*) => (
Value::List(List::new(vector_sync![$($x)*]))
);
}
macro_rules! attrs {
($($x:tt)*) => (
Value::AttrSet(AttrSet::new(map!{$($x)*}))
)
}
#[test]
fn test_jit_const() {
test_expr("let f = _: 1; in (f 1) + (f 1)", int!(2));
}

View File

@@ -5,11 +5,11 @@ mod builtins;
mod bytecode;
mod stack;
mod ty;
mod jit;
pub mod compile;
pub mod error;
pub mod ir;
pub mod jit;
pub mod vm;
pub use ty::public::Value;

View File

@@ -52,6 +52,11 @@ impl<'vm> Func<'vm> {
pub fn call(&self, vm: &'vm VM<'_>, arg: Value<'vm>) -> Result<Value<'vm>> {
use Param::*;
let count = self.count.get();
if count >= 1 {
let compiled = self.compiled.get_or_init(|| vm.compile_func(self));
}
self.count.replace(count + 1);
let mut env = self.env.clone();
match self.func.param.clone() {

View File

@@ -8,7 +8,7 @@ use crate::ty::internal::*;
use crate::ty::common::Const;
use crate::ty::public::{self as p, Symbol};
use crate::stack::Stack;
use crate::jit::JITContext;
use crate::jit::{JITContext, JITFunc};
use derive_more::Constructor;
use ecow::EcoString;
@@ -74,6 +74,10 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
}
}
pub fn get_const(&self, idx: usize) -> Const {
self.consts[idx].clone()
}
pub fn eval(
&'vm self,
opcodes: impl Iterator<Item = OpCode>,
@@ -102,7 +106,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
) -> Result<usize> {
match opcode {
OpCode::Illegal => panic!("illegal opcode"),
OpCode::Const { idx } => stack.push(Value::Const(self.consts[idx].clone()))?,
OpCode::Const { idx } => stack.push(Value::Const(self.get_const(idx)))?,
OpCode::LoadThunk { idx } => {
stack.push(Value::Thunk(Thunk::new(self.get_thunk(idx)).into()))?
}
@@ -262,4 +266,8 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
}
Ok(0)
}
pub fn compile_func(&'vm self, func: &Func<'vm>) -> JITFunc {
self.jit.compile_function(func, self).unwrap()
}
}