From 95ebddf27202ca95348a6d3565c6ba28a52cf452 Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Sat, 17 May 2025 20:54:36 +0800 Subject: [PATCH] feat: JIT (WIP) --- src/builtins/mod.rs | 3 +- src/bytecode.rs | 3 +- src/error.rs | 2 + src/ir.rs | 10 +- src/{vm => }/jit.rs | 47 +++++++--- src/lib.rs | 1 + src/stack.rs | 16 +++- src/ty/common.rs | 93 ++++++++++++++++++- src/ty/internal/cnst.rs | 79 ---------------- src/ty/internal/func.rs | 11 ++- src/ty/internal/mod.rs | 10 +- src/ty/{public/mod.rs => public.rs} | 5 +- src/ty/public/cnst.rs | 137 ---------------------------- src/vm/mod.rs | 30 +++--- src/vm/test.rs | 2 +- 15 files changed, 175 insertions(+), 274 deletions(-) rename src/{vm => }/jit.rs (60%) delete mode 100644 src/ty/internal/cnst.rs rename src/ty/{public/mod.rs => public.rs} (98%) delete mode 100644 src/ty/public/cnst.rs diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index d94a46c..fe16268 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -1,7 +1,8 @@ use hashbrown::HashMap; use std::rc::Rc; -use crate::ty::internal::{AttrSet, Const, PrimOp, Value}; +use crate::ty::internal::{AttrSet, PrimOp, Value}; +use crate::ty::common::Const; use crate::vm::{Env, VM}; pub fn env<'vm>(vm: &'vm VM) -> Env<'vm> { diff --git a/src/bytecode.rs b/src/bytecode.rs index 90b45db..c259492 100644 --- a/src/bytecode.rs +++ b/src/bytecode.rs @@ -2,7 +2,8 @@ use hashbrown::HashMap; use ecow::EcoString; -use crate::ty::internal::{Const, Param}; +use crate::ty::internal::Param; +use crate::ty::common::Const; type Slice = Box<[T]>; diff --git a/src/error.rs b/src/error.rs index 7e7b898..3970873 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,6 +10,8 @@ pub enum Error { DowngradeError(String), #[error("error occurred during evaluation stage: {0}")] EvalError(String), + #[error("error occurred during JIT compile stage: {0}")] + CompileError(#[from] inkwell::builder::BuilderError), #[error("unknown error")] Unknown, } diff --git a/src/ir.rs b/src/ir.rs index 66b9995..fab1c5a 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -5,7 +5,7 @@ use rnix::ast::{self, Expr}; use crate::compile::*; use crate::error::*; -use crate::ty::internal as i; +use crate::ty::common as c; pub fn downgrade(expr: Expr) -> Result { let mut ctx = DowngradeContext::new(); @@ -126,15 +126,15 @@ pub struct DynamicAttrPair(pub Ir, pub Ir); pub struct DowngradeContext { thunks: Vec, funcs: Vec, - consts: Vec, - constmap: HashMap, + consts: Vec, + constmap: HashMap, symbols: Vec, symmap: HashMap, } pub struct Downgraded { pub top_level: Ir, - pub consts: Box<[i::Const]>, + pub consts: Box<[c::Const]>, pub symbols: Vec, pub symmap: HashMap, pub thunks: Box<[Ir]>, @@ -158,7 +158,7 @@ impl DowngradeContext { LoadFunc { idx } } - fn new_const(&mut self, cnst: i::Const) -> Const { + fn new_const(&mut self, cnst: c::Const) -> Const { if let Some(&idx) = self.constmap.get(&cnst) { Const { idx } } else { diff --git a/src/vm/jit.rs b/src/jit.rs similarity index 60% rename from src/vm/jit.rs rename to src/jit.rs index 4db9212..8e6c077 100644 --- a/src/vm/jit.rs +++ b/src/jit.rs @@ -1,5 +1,3 @@ -use std::pin::Pin; - use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::execution_engine::ExecutionEngine; @@ -8,9 +6,12 @@ use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, StructType}; use inkwell::values::{BasicValueEnum, FunctionValue, IntValue}; use inkwell::{AddressSpace, OptimizationLevel}; +use crate::bytecode::OpCode; use crate::stack::Stack; +use crate::ty::internal::{Func, Value}; +use crate::error::*; -use super::STACK_SIZE; +const STACK_SIZE: usize = 8 * 1024 / size_of::(); #[repr(usize)] pub enum ValueTag { @@ -27,25 +28,31 @@ pub enum ValueTag { #[repr(C)] pub struct JITValue { tag: ValueTag, - data: u64, + data: JITValueData } +#[repr(C)] +pub union JITValueData { + int: i64, + float: f64, + boolean: bool, +} + +pub type JITFunc = fn(usize, usize, JITValue) -> JITValue; + pub struct JITContext<'ctx> { context: &'ctx Context, module: Module<'ctx>, builder: Builder<'ctx>, execution_engine: ExecutionEngine<'ctx>, - stack: Stack, STACK_SIZE>, - cur_func: Option>, value_type: StructType<'ctx>, func_type: FunctionType<'ctx>, } impl<'ctx> JITContext<'ctx> { - pub fn new(context: &'ctx Context) -> Pin> { + pub fn new(context: &'ctx Context) -> Self { let module = context.create_module("nixjit"); - let stack = Stack::new(); let int_type = context.i64_type(); let pointer_type = context.ptr_type(AddressSpace::default()); @@ -55,24 +62,38 @@ impl<'ctx> JITContext<'ctx> { false, ); - Pin::new(Box::new(JITContext { + JITContext { execution_engine: module .create_jit_execution_engine(OptimizationLevel::Default) .unwrap(), builder: context.create_builder(), context, module, - stack, - cur_func: None, value_type, func_type, - })) + } } fn new_int(&self, int: i64) -> IntValue { self.context.i64_type().const_int(int as u64, false) } - pub fn start_trace(&mut self) {} + fn new_bool(&self, b: bool) -> IntValue { + self.context.bool_type().const_int(b as u64, false) + } + + pub fn compile_function(&self, func: Func) -> Result<()> { + let mut stack = Stack::<_, STACK_SIZE>::new(); + let mut iter = func.func.opcodes.iter().copied(); + let func_ = self.module.add_function("fn", self.func_type, None); + while let Some(opcode) = iter.next() { + self.single_op(opcode, &func_, &mut stack)?; + } + Ok(()) + } + + fn single_op(&self, opcode: OpCode, func: &FunctionValue, stack: &mut Stack) -> Result<()> { + todo!() + } } diff --git a/src/lib.rs b/src/lib.rs index e03e267..f8ea608 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod builtins; mod bytecode; mod stack; mod ty; +mod jit; pub mod compile; pub mod error; diff --git a/src/stack.rs b/src/stack.rs index 34907e5..675321b 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -3,11 +3,6 @@ use std::ops::Deref; use crate::error::*; -pub struct Stack { - items: [MaybeUninit; CAP], - top: usize, -} - macro_rules! into { ($e:expr) => { // SAFETY: This macro is used to transmute `MaybeUninit>` to `Value<'vm>` @@ -18,6 +13,17 @@ macro_rules! into { }; } +pub struct Stack { + items: [MaybeUninit; CAP], + top: usize, +} + +impl Default for Stack { + fn default() -> Self { + Self::new() + } +} + impl Stack { pub fn new() -> Self { Stack { diff --git a/src/ty/common.rs b/src/ty/common.rs index c958fe7..ced4f3d 100644 --- a/src/ty/common.rs +++ b/src/ty/common.rs @@ -1,6 +1,8 @@ +use std::hash::Hash; use std::fmt::{Display, Formatter, Result as FmtResult}; -use derive_more::Constructor; +use derive_more::{Constructor, IsVariant, Unwrap}; +use ecow::EcoString; #[derive(Clone, Debug, PartialEq, Constructor, Hash)] pub struct Catchable { @@ -12,3 +14,92 @@ impl Display for Catchable { write!(f, "", self.msg) } } + +#[derive(Debug, Clone, IsVariant, Unwrap)] +pub enum Const { + Bool(bool), + Int(i64), + Float(f64), + String(EcoString), + Null, +} + +impl Hash for Const { + fn hash(&self, state: &mut H) { + use Const::*; + std::mem::discriminant(self).hash(state); + match self { + Int(x) => x.hash(state), + Float(x) => x.to_bits().hash(state), + Bool(x) => x.hash(state), + String(x) => x.hash(state), + Null => (), + } + } +} + +impl Display for Const { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + use Const::*; + match self { + Int(x) => write!(f, "{x}"), + Float(x) => write!(f, "{x}"), + Bool(x) => write!(f, "{x}"), + String(x) => write!(f, "{x:?}"), + Null => write!(f, "null"), + } + } +} + +impl From for Const { + fn from(value: bool) -> Self { + Const::Bool(value) + } +} + +impl From for Const { + fn from(value: i64) -> Self { + Const::Int(value) + } +} + +impl From for Const { + fn from(value: f64) -> Self { + Const::Float(value) + } +} + +impl From for Const { + fn from(value: EcoString) -> Self { + Const::String(value) + } +} + +impl From for Const { + fn from(value: String) -> Self { + Const::String(value.into()) + } +} + +impl From<&str> for Const { + fn from(value: &str) -> Self { + Const::String(value.into()) + } +} + +impl PartialEq for Const { + fn eq(&self, other: &Self) -> bool { + use Const::*; + match (self, other) { + (Bool(a), Bool(b)) => a == b, + (Int(a), Int(b)) => a == b, + (Float(a), Float(b)) => a == b, + (Int(a), Float(b)) => *a as f64 == *b, + (Float(a), Int(b)) => *b as f64 == *a, + (String(a), String(b)) => a == b, + _ => false, + } + } +} + +impl Eq for Const {} diff --git a/src/ty/internal/cnst.rs b/src/ty/internal/cnst.rs deleted file mode 100644 index e2ff5ac..0000000 --- a/src/ty/internal/cnst.rs +++ /dev/null @@ -1,79 +0,0 @@ -use std::hash::Hash; - -use derive_more::{IsVariant, Unwrap}; -use ecow::EcoString; - -#[derive(Debug, Clone, IsVariant, Unwrap)] -pub enum Const { - Bool(bool), - Int(i64), - Float(f64), - String(EcoString), - Null, -} - -impl Hash for Const { - fn hash(&self, state: &mut H) { - use Const::*; - match self { - Int(x) => x.hash(state), - Float(x) => x.to_bits().hash(state), - Bool(x) => x.hash(state), - String(x) => x.hash(state), - x @ Null => x.hash(state), - } - } -} - -impl From for Const { - fn from(value: bool) -> Self { - Const::Bool(value) - } -} - -impl From for Const { - fn from(value: i64) -> Self { - Const::Int(value) - } -} - -impl From for Const { - fn from(value: f64) -> Self { - Const::Float(value) - } -} - -impl From for Const { - fn from(value: EcoString) -> Self { - Const::String(value) - } -} - -impl From for Const { - fn from(value: String) -> Self { - Const::String(value.into()) - } -} - -impl From<&str> for Const { - fn from(value: &str) -> Self { - Const::String(value.into()) - } -} - -impl PartialEq for Const { - fn eq(&self, other: &Self) -> bool { - use Const::*; - match (self, other) { - (Bool(a), Bool(b)) => a == b, - (Int(a), Int(b)) => a == b, - (Float(a), Float(b)) => a == b, - (Int(a), Float(b)) => *a as f64 == *b, - (Float(a), Int(b)) => *b as f64 == *a, - (String(a), String(b)) => a == b, - _ => false, - } - } -} - -impl Eq for Const {} diff --git a/src/ty/internal/func.rs b/src/ty/internal/func.rs index d219a04..6ddfcf9 100644 --- a/src/ty/internal/func.rs +++ b/src/ty/internal/func.rs @@ -1,9 +1,12 @@ -use derive_more::Constructor; +use std::cell::{Cell, OnceCell}; + use itertools::Itertools; +use derive_more::Constructor; use crate::bytecode::Func as BFunc; use crate::error::Result; use crate::ir; +use crate::jit::JITFunc; use crate::ty::internal::{Thunk, Value}; use crate::vm::{Env, VM}; @@ -37,14 +40,12 @@ impl From for Param { } } -pub type JITFunc<'vm> = - unsafe extern "C" fn(vm: *mut VM<'_>, *mut Env<'vm>, *mut Value<'vm>) -> Value<'vm>; - #[derive(Debug, Clone, Constructor)] pub struct Func<'vm> { pub func: &'vm BFunc, pub env: Env<'vm>, - pub compiled: Option>, + pub compiled: OnceCell, + pub count: Cell } impl<'vm> Func<'vm> { diff --git a/src/ty/internal/mod.rs b/src/ty/internal/mod.rs index ada48ca..1744d34 100644 --- a/src/ty/internal/mod.rs +++ b/src/ty/internal/mod.rs @@ -6,7 +6,7 @@ use std::rc::Rc; use derive_more::{IsVariant, Unwrap}; -use super::common as c; +use super::common::*; use super::public as p; use crate::bytecode::OpCodes; @@ -14,14 +14,12 @@ use crate::error::*; use crate::vm::{Env, VM}; mod attrset; -mod cnst; mod func; mod list; mod primop; mod string; pub use attrset::*; -pub use cnst::Const; pub use func::*; pub use list::List; pub use primop::*; @@ -33,7 +31,7 @@ pub enum Value<'vm> { ThunkRef(&'vm Thunk<'vm>), AttrSet(Rc>), List(Rc>), - Catchable(c::Catchable), + Catchable(Catchable), PrimOp(Rc>), PartialPrimOp(Rc>), Func(Rc>), @@ -81,7 +79,7 @@ pub enum ValueAsRef<'v, 'vm: 'v> { Thunk(&'v Thunk<'vm>), AttrSet(&'v AttrSet<'vm>), List(&'v List<'vm>), - Catchable(&'v c::Catchable), + Catchable(&'v Catchable), PrimOp(&'v PrimOp<'vm>), PartialPrimOp(&'v PartialPrimOp<'vm>), Func(&'v Func<'vm>), @@ -93,7 +91,7 @@ pub enum ValueAsMut<'v, 'vm: 'v> { Thunk(&'v Thunk<'vm>), AttrSet(&'v mut AttrSet<'vm>), List(&'v mut List<'vm>), - Catchable(&'v mut c::Catchable), + Catchable(&'v mut Catchable), PrimOp(&'v mut PrimOp<'vm>), PartialPrimOp(&'v mut PartialPrimOp<'vm>), Func(&'v Func<'vm>), diff --git a/src/ty/public/mod.rs b/src/ty/public.rs similarity index 98% rename from src/ty/public/mod.rs rename to src/ty/public.rs index 5f31f08..35846ff 100644 --- a/src/ty/public/mod.rs +++ b/src/ty/public.rs @@ -10,10 +10,6 @@ use rpds::VectorSync; use super::common::*; -mod cnst; - -pub use cnst::Const; - #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Constructor)] pub struct Symbol(EcoString); @@ -132,3 +128,4 @@ impl Display for Value { } } } + diff --git a/src/ty/public/cnst.rs b/src/ty/public/cnst.rs deleted file mode 100644 index d5ceca6..0000000 --- a/src/ty/public/cnst.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::fmt::{Display, Formatter, Result as FmtResult}; - -use derive_more::{IsVariant, Unwrap}; -use ecow::EcoString; - -use crate::error::Error; - -use super::super::internal as i; - -#[derive(Debug, Clone, IsVariant, Unwrap)] -pub enum Const { - Bool(bool), - Int(i64), - Float(f64), - String(EcoString), - Null, -} - -impl Display for Const { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - use Const::*; - match self { - Bool(b) => write!(f, "{b}"), - Int(i) => write!(f, "{i}"), - Float(float) => write!(f, "{float}"), - String(s) => write!(f, "{s}"), - Null => write!(f, "null"), - } - } -} - -impl From for Const { - fn from(value: i::Const) -> Self { - use i::Const::*; - match value { - Bool(bool) => Const::Bool(bool), - Int(int) => Const::Int(int), - Float(float) => Const::Float(float), - String(string) => Const::String(string), - Null => Const::Null, - } - } -} - -impl From for Const { - fn from(value: bool) -> Self { - Const::Bool(value) - } -} - -impl From for Const { - fn from(value: i64) -> Self { - Const::Int(value) - } -} - -impl From for Const { - fn from(value: f64) -> Self { - Const::Float(value) - } -} - -impl From for Const { - fn from(value: EcoString) -> Self { - Const::String(value) - } -} - -impl From for Const { - fn from(value: String) -> Self { - Const::String(value.into()) - } -} - -impl From<&str> for Const { - fn from(value: &str) -> Self { - Const::String(value.into()) - } -} - -impl<'a> TryFrom<&'a Const> for &'a bool { - type Error = Error; - - fn try_from(value: &'a Const) -> Result { - match value { - Const::Bool(b) => Ok(b), - _ => panic!(), - } - } -} -impl<'a> TryFrom<&'a Const> for &'a i64 { - type Error = Error; - - fn try_from(value: &'a Const) -> Result { - match value { - Const::Int(int) => Ok(int), - _ => panic!(), - } - } -} - -impl<'a> TryFrom<&'a Const> for &'a f64 { - type Error = Error; - - fn try_from(value: &'a Const) -> Result { - match value { - Const::Float(float) => Ok(float), - _ => panic!(), - } - } -} - -impl<'a> TryFrom<&'a Const> for &'a str { - type Error = Error; - - fn try_from(value: &'a Const) -> Result { - match value { - Const::String(string) => Ok(string), - _ => panic!(), - } - } -} - -impl PartialEq for Const { - fn eq(&self, other: &Self) -> bool { - use Const::*; - match (self, other) { - (Bool(a), Bool(b)) => a == b, - (Int(a), Int(b)) => a == b, - (Float(a), Float(b)) => a == b, - (String(a), String(b)) => a == b, - _ => false, - } - } -} - -impl Eq for Const {} diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 7b18ca4..9d7416b 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -1,37 +1,35 @@ use hashbrown::{HashMap, HashSet}; -use std::cell::RefCell; -use std::pin::Pin; +use std::cell::{Cell, OnceCell, RefCell}; use crate::builtins::env; use crate::bytecode::{BinOp, Func as F, OpCode, OpCodes, Program, UnOp}; use crate::error::*; use crate::ty::internal::*; +use crate::ty::common::Const; use crate::ty::public::{self as p, Symbol}; - use crate::stack::Stack; +use crate::jit::JITContext; use derive_more::Constructor; use ecow::EcoString; pub use env::Env; -pub use jit::JITContext; mod env; -mod jit; #[cfg(test)] mod test; -pub const STACK_SIZE: usize = 8 * 1024 / size_of::(); +const STACK_SIZE: usize = 8 * 1024 / size_of::(); -pub fn run(prog: Program, jit: Pin>>) -> Result { - let vm = VM::new( - prog.thunks, - prog.funcs, - RefCell::new(prog.symbols), - RefCell::new(prog.symmap), - prog.consts, +pub fn run(prog: Program, jit: JITContext<'_>) -> Result { + let vm = VM { + thunks: prog.thunks, + funcs: prog.funcs, + symbols: RefCell::new(prog.symbols), + symmap: RefCell::new(prog.symmap), + consts: prog.consts, jit, - ); + }; let env = env(&vm); let mut seen = HashSet::new(); let value = vm @@ -47,7 +45,7 @@ pub struct VM<'jit> { symbols: RefCell>, symmap: RefCell>, consts: Box<[Const]>, - jit: Pin>>, + jit: JITContext<'jit>, } impl<'vm, 'jit: 'vm> VM<'jit> { @@ -137,7 +135,7 @@ impl<'vm, 'jit: 'vm> VM<'jit> { } OpCode::Func { idx } => { let func = self.get_func(idx); - stack.push(Value::Func(Func::new(func, env.clone(), None).into()))?; + stack.push(Value::Func(Func::new(func, env.clone(), OnceCell::new(), Cell::new(0)).into()))?; } OpCode::UnOp { op } => { use UnOp::*; diff --git a/src/vm/test.rs b/src/vm/test.rs index 847714a..17631af 100644 --- a/src/vm/test.rs +++ b/src/vm/test.rs @@ -10,8 +10,8 @@ use rpds::vector_sync; use crate::compile::compile; use crate::ir::downgrade; -use crate::ty::public::Symbol; use crate::ty::public::*; +use crate::ty::common::Const; use crate::vm::JITContext; use super::run;