feat(jit): lookup
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
use inkwell::AddressSpace;
|
use inkwell::AddressSpace;
|
||||||
use inkwell::context::Context;
|
use inkwell::context::Context;
|
||||||
use inkwell::execution_engine::ExecutionEngine;
|
use inkwell::execution_engine::ExecutionEngine;
|
||||||
@@ -21,10 +23,12 @@ pub struct Helpers<'ctx> {
|
|||||||
pub func_type: FunctionType<'ctx>,
|
pub func_type: FunctionType<'ctx>,
|
||||||
|
|
||||||
pub debug: FunctionValue<'ctx>,
|
pub debug: FunctionValue<'ctx>,
|
||||||
pub captured_env: FunctionValue<'ctx>,
|
pub capture_env: FunctionValue<'ctx>,
|
||||||
pub neg: FunctionValue<'ctx>,
|
pub neg: FunctionValue<'ctx>,
|
||||||
pub add: FunctionValue<'ctx>,
|
pub add: FunctionValue<'ctx>,
|
||||||
|
pub sub: FunctionValue<'ctx>,
|
||||||
pub call: FunctionValue<'ctx>,
|
pub call: FunctionValue<'ctx>,
|
||||||
|
pub lookup: FunctionValue<'ctx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> Helpers<'ctx> {
|
impl<'ctx> Helpers<'ctx> {
|
||||||
@@ -40,7 +44,7 @@ impl<'ctx> Helpers<'ctx> {
|
|||||||
let ptr_type = context.ptr_type(AddressSpace::default());
|
let ptr_type = context.ptr_type(AddressSpace::default());
|
||||||
let value_type = context.struct_type(&[int_type.into(), int_type.into()], false);
|
let value_type = context.struct_type(&[int_type.into(), int_type.into()], false);
|
||||||
let func_type = value_type.fn_type(
|
let func_type = value_type.fn_type(
|
||||||
&[ptr_type.into(), ptr_type.into(), value_type.into()],
|
&[ptr_type.into(), ptr_type.into()],
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
let debug = module.add_function(
|
let debug = module.add_function(
|
||||||
@@ -50,7 +54,7 @@ impl<'ctx> Helpers<'ctx> {
|
|||||||
.fn_type(&[value_type.into()], false),
|
.fn_type(&[value_type.into()], false),
|
||||||
None
|
None
|
||||||
);
|
);
|
||||||
let captured_env = module.add_function(
|
let capture_env = module.add_function(
|
||||||
"capture_env",
|
"capture_env",
|
||||||
context
|
context
|
||||||
.void_type()
|
.void_type()
|
||||||
@@ -67,19 +71,31 @@ impl<'ctx> Helpers<'ctx> {
|
|||||||
value_type.fn_type(&[value_type.into(), value_type.into()], false),
|
value_type.fn_type(&[value_type.into(), value_type.into()], false),
|
||||||
None
|
None
|
||||||
);
|
);
|
||||||
|
let sub = module.add_function(
|
||||||
|
"sub",
|
||||||
|
value_type.fn_type(&[value_type.into(), value_type.into()], false),
|
||||||
|
None
|
||||||
|
);
|
||||||
// Assuming a single argument for now based on the test case
|
// Assuming a single argument for now based on the test case
|
||||||
let call = module.add_function(
|
let call = module.add_function(
|
||||||
"call",
|
"call",
|
||||||
value_type.fn_type(&[ptr_type.into(), ptr_type.into(), ptr_type.into(), value_type.into()], false),
|
value_type.fn_type(&[ptr_type.into(), ptr_type.into(), ptr_type.into(), value_type.into()], false),
|
||||||
None
|
None
|
||||||
);
|
);
|
||||||
|
let lookup = module.add_function(
|
||||||
|
"lookup",
|
||||||
|
value_type.fn_type(&[ptr_int_type.into(), ptr_type.into()], false),
|
||||||
|
None
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
execution_engine.add_global_mapping(&debug, helper_debug as _);
|
execution_engine.add_global_mapping(&debug, helper_debug as _);
|
||||||
execution_engine.add_global_mapping(&captured_env, helper_capture_env as _);
|
execution_engine.add_global_mapping(&capture_env, helper_capture_env as _);
|
||||||
execution_engine.add_global_mapping(&neg, helper_neg as _);
|
execution_engine.add_global_mapping(&neg, helper_neg as _);
|
||||||
execution_engine.add_global_mapping(&add, helper_add as _);
|
execution_engine.add_global_mapping(&add, helper_add as _);
|
||||||
|
execution_engine.add_global_mapping(&sub, helper_sub as _);
|
||||||
execution_engine.add_global_mapping(&call, helper_call as _);
|
execution_engine.add_global_mapping(&call, helper_call as _);
|
||||||
|
execution_engine.add_global_mapping(&lookup, helper_lookup as _);
|
||||||
|
|
||||||
|
|
||||||
Helpers {
|
Helpers {
|
||||||
@@ -92,10 +108,12 @@ impl<'ctx> Helpers<'ctx> {
|
|||||||
func_type,
|
func_type,
|
||||||
|
|
||||||
debug,
|
debug,
|
||||||
captured_env,
|
capture_env,
|
||||||
neg,
|
neg,
|
||||||
add,
|
add,
|
||||||
|
sub,
|
||||||
call,
|
call,
|
||||||
|
lookup,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,9 +212,31 @@ extern "C" fn helper_add(lhs: JITValue, rhs: JITValue) -> JITValue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[unsafe(no_mangle)]
|
#[unsafe(no_mangle)]
|
||||||
extern "C" fn helper_call<'jit, 'vm>(vm: *const VM<'jit>, env: *const Env<'jit, 'vm>, func_ptr: *const (), arg: JITValue) -> JITValue {
|
extern "C" fn helper_sub(lhs: JITValue, rhs: JITValue) -> JITValue {
|
||||||
let func: JITFunc = unsafe { std::mem::transmute(func_ptr) };
|
use ValueTag::*;
|
||||||
unsafe {
|
match (lhs.tag, rhs.tag) {
|
||||||
func(vm, env, arg)
|
(Int, Int) => JITValue {
|
||||||
|
tag: Int,
|
||||||
|
data: JITValueData {
|
||||||
|
int: unsafe { lhs.data.int - rhs.data.int },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
_ => todo!("Substruction not implemented for {:?} and {:?}", lhs.tag, rhs.tag),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[unsafe(no_mangle)]
|
||||||
|
extern "C" fn helper_call<'jit, 'vm>(vm: *const VM<'jit>, env: *const Env<'jit, 'vm>, func_ptr: *const (), arg: JITValue) -> JITValue {
|
||||||
|
let func: JITFunc = unsafe { std::mem::transmute(func_ptr) };
|
||||||
|
todo!();
|
||||||
|
unsafe {
|
||||||
|
func(vm, env)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[unsafe(no_mangle)]
|
||||||
|
extern "C" fn helper_lookup<'jit, 'vm>(sym: usize, env: *const Env<'jit, 'vm>) -> JITValue {
|
||||||
|
let env = unsafe { env.as_ref() }.unwrap();
|
||||||
|
let val = env.lookup(sym);
|
||||||
|
val.unwrap().into()
|
||||||
|
}
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ impl From<Value<'_, '_>> for JITValue {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type JITFunc<'jit, 'vm> = unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>, JITValue) -> JITValue;
|
pub type JITFunc<'jit, 'vm> = unsafe extern "C" fn(*const VM<'jit>, *const Env<'jit, 'vm>) -> JITValue;
|
||||||
|
|
||||||
pub struct JITContext<'ctx> {
|
pub struct JITContext<'ctx> {
|
||||||
context: &'ctx Context,
|
context: &'ctx Context,
|
||||||
@@ -89,8 +89,6 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
|
|||||||
let execution_engine = module
|
let execution_engine = module
|
||||||
.create_jit_execution_engine(OptimizationLevel::Default)
|
.create_jit_execution_engine(OptimizationLevel::Default)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
||||||
let helpers = Helpers::new(context, &module, &execution_engine);
|
let helpers = Helpers::new(context, &module, &execution_engine);
|
||||||
|
|
||||||
JITContext {
|
JITContext {
|
||||||
@@ -99,7 +97,6 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
|
|||||||
context,
|
context,
|
||||||
module,
|
module,
|
||||||
|
|
||||||
|
|
||||||
helpers
|
helpers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -114,12 +111,12 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
|
|||||||
let func_ = self.module.add_function("nixjit_function", self.helpers.func_type, None);
|
let func_ = self.module.add_function("nixjit_function", self.helpers.func_type, None);
|
||||||
let entry = self.context.append_basic_block(func_, "entry");
|
let entry = self.context.append_basic_block(func_, "entry");
|
||||||
self.builder.position_at_end(entry);
|
self.builder.position_at_end(entry);
|
||||||
|
let env = func_.get_nth_param(1).unwrap().into_pointer_value();
|
||||||
while let Some(opcode) = iter.next() {
|
while let Some(opcode) = iter.next() {
|
||||||
self.single_op(opcode, vm, &mut stack)?;
|
self.single_op(opcode, vm, env, &mut stack)?;
|
||||||
}
|
}
|
||||||
assert_eq!(stack.len(), 1);
|
assert_eq!(stack.len(), 1);
|
||||||
let value = stack.pop();
|
let value = stack.pop();
|
||||||
self.builder.build_direct_call(self.helpers.debug, &[value.into()], "call_debug").unwrap();
|
|
||||||
self.builder.build_return(Some(&value))?;
|
self.builder.build_return(Some(&value))?;
|
||||||
if func_.verify(false) {
|
if func_.verify(false) {
|
||||||
func_.print_to_stderr();
|
func_.print_to_stderr();
|
||||||
@@ -137,6 +134,7 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
|
|||||||
&self,
|
&self,
|
||||||
opcode: OpCode,
|
opcode: OpCode,
|
||||||
vm: &'vm VM<'_>,
|
vm: &'vm VM<'_>,
|
||||||
|
env: PointerValue<'ctx>,
|
||||||
stack: &mut Stack<BasicValueEnum<'ctx>, CAP>,
|
stack: &mut Stack<BasicValueEnum<'ctx>, CAP>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
match opcode {
|
match opcode {
|
||||||
@@ -151,48 +149,59 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
OpCode::LoadThunk { idx } => stack.push(self.helpers.new_thunk(Rc::into_raw(Rc::new(Thunk::new(vm.get_thunk(idx))))))?,
|
OpCode::LoadThunk { idx } => stack.push(self.helpers.new_thunk(Rc::into_raw(Rc::new(Thunk::new(vm.get_thunk(idx))))))?,
|
||||||
/* OpCode::CaptureEnv => {
|
OpCode::CaptureEnv => {
|
||||||
let thunk = *stack.tos()?;
|
let thunk = *stack.tos()?;
|
||||||
self.builder.build_direct_call(self.helpers.captured_env, &[thunk.into(), self.new_ptr(env).into()], "call_capture_env")?;
|
self.builder.build_direct_call(self.helpers.capture_env, &[thunk.into(), env.into()], "call_capture_env")?;
|
||||||
} */
|
}
|
||||||
OpCode::UnOp { op } => {
|
OpCode::UnOp { op } => {
|
||||||
use UnOp::*;
|
use UnOp::*;
|
||||||
let rhs = stack.pop();
|
let rhs = stack.pop();
|
||||||
stack.push(match op {
|
stack.push(match op {
|
||||||
Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), self.new_ptr(std::ptr::null::<Env>()).into()], "call_neg")?.try_as_basic_value().left().unwrap(),
|
Neg => self.builder.build_direct_call(self.helpers.neg, &[rhs.into(), env.into()], "call_neg")?.try_as_basic_value().left().unwrap(),
|
||||||
_ => todo!()
|
_ => todo!()
|
||||||
})?
|
})?
|
||||||
}
|
}
|
||||||
/* OpCode::Func { idx } => {
|
OpCode::Func { idx } => {
|
||||||
let func = vm.get_func(idx);
|
let func = vm.get_func(idx);
|
||||||
let jit_func_ptr = self.compile_function(&Func::new(func, unsafe { env.as_ref() }.unwrap().clone(), OnceCell::new(), Cell::new(1)), vm)?;
|
let jit_func_ptr = self.compile_function(func, vm)?;
|
||||||
let jit_value = self.helpers.value_type.const_named_struct(&[
|
let jit_value = self.helpers.value_type.const_named_struct(&[
|
||||||
self.helpers.int_type.const_int(ValueTag::Function as _, false).into(),
|
self.helpers.int_type.const_int(ValueTag::Function as _, false).into(),
|
||||||
self.helpers.ptr_int_type.const_int(jit_func_ptr as *const _ as _, false).into(),
|
self.helpers.ptr_int_type.const_int(unsafe { jit_func_ptr.as_raw() } as _, false).into(),
|
||||||
]).into();
|
]).into();
|
||||||
stack.push(jit_value)?;
|
stack.push(jit_value)?;
|
||||||
} */
|
}
|
||||||
OpCode::Call { arity } => {
|
OpCode::Call { arity } => {
|
||||||
// Assuming arity is 1 for the test case
|
// Assuming arity is 1 for the test case
|
||||||
assert_eq!(arity, 1);
|
assert_eq!(arity, 1);
|
||||||
let arg = stack.pop();
|
let arg = stack.pop();
|
||||||
let func_value = stack.pop();
|
let func_value = stack.pop();
|
||||||
let func_ptr = self.builder.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")?.into_pointer_value();
|
let func_ptr = self.builder.build_extract_value(func_value.into_struct_value(), 1, "func_ptr")?.into_pointer_value();
|
||||||
let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), self.new_ptr(std::ptr::null::<Env>()).into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap();
|
let result = self.builder.build_direct_call(self.helpers.call, &[self.new_ptr(vm).into(), env.into(), func_ptr.into(), arg.into()], "call_func")?.try_as_basic_value().left().unwrap();
|
||||||
stack.push(result)?;
|
stack.push(result)?;
|
||||||
}
|
}
|
||||||
OpCode::BinOp { op } => {
|
OpCode::BinOp { op } => {
|
||||||
use crate::bytecode::BinOp;
|
use crate::bytecode::BinOp;
|
||||||
match op {
|
|
||||||
BinOp::Add => {
|
|
||||||
let rhs = stack.pop();
|
let rhs = stack.pop();
|
||||||
let lhs = stack.pop();
|
let lhs = stack.pop();
|
||||||
|
match op {
|
||||||
|
BinOp::Add => {
|
||||||
|
let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap();
|
||||||
|
stack.push(result)?;
|
||||||
|
}
|
||||||
|
BinOp::Sub => {
|
||||||
|
let result = self.builder.build_direct_call(self.helpers.sub, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap();
|
||||||
|
stack.push(result)?;
|
||||||
|
}
|
||||||
|
BinOp::Eq => {
|
||||||
let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap();
|
let result = self.builder.build_direct_call(self.helpers.add, &[lhs.into(), rhs.into()], "call_add")?.try_as_basic_value().left().unwrap();
|
||||||
stack.push(result)?;
|
stack.push(result)?;
|
||||||
}
|
}
|
||||||
_ => todo!("BinOp::{:?} not implemented in JIT", op),
|
_ => todo!("BinOp::{:?} not implemented in JIT", op),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
OpCode::LookUp { sym } => {
|
||||||
|
stack.push(self.builder.build_direct_call(self.helpers.lookup, &[self.helpers.ptr_int_type.const_int(sym as u64, false).into(), env.into()], "call_lookup").unwrap().try_as_basic_value().left().unwrap())?
|
||||||
|
}
|
||||||
_ => todo!("{opcode:?} not implemented")
|
_ => todo!("{opcode:?} not implemented")
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -53,17 +53,8 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> {
|
|||||||
pub fn call(&self, vm: &'vm VM<'jit>, arg: Value<'jit, 'vm>) -> Result<Value<'jit, 'vm>> {
|
pub fn call(&self, vm: &'vm VM<'jit>, arg: Value<'jit, 'vm>) -> Result<Value<'jit, 'vm>> {
|
||||||
use Param::*;
|
use Param::*;
|
||||||
|
|
||||||
let count = self.count.get();
|
let env = match self.func.param.clone() {
|
||||||
if count >= 1 {
|
Ident(ident) => self.env.clone().enter([(ident.into(), arg)].into_iter()),
|
||||||
let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func));
|
|
||||||
let ret = unsafe { compiled.call(vm as _, &self.env as _, arg.into()) };
|
|
||||||
return Ok(ret.into())
|
|
||||||
}
|
|
||||||
self.count.replace(count + 1);
|
|
||||||
let mut env = self.env.clone();
|
|
||||||
|
|
||||||
match self.func.param.clone() {
|
|
||||||
Ident(ident) => env = env.enter([(ident.into(), arg)].into_iter()),
|
|
||||||
Formals {
|
Formals {
|
||||||
formals,
|
formals,
|
||||||
ellipsis,
|
ellipsis,
|
||||||
@@ -94,10 +85,17 @@ impl<'vm, 'jit: 'vm> Func<'jit, 'vm> {
|
|||||||
if let Some(alias) = alias {
|
if let Some(alias) = alias {
|
||||||
new.push((alias.clone().into(), Value::AttrSet(arg)));
|
new.push((alias.clone().into(), Value::AttrSet(arg)));
|
||||||
}
|
}
|
||||||
env = env.enter(new.into_iter());
|
self.env.clone().enter(new.into_iter())
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let count = self.count.get();
|
||||||
|
self.count.replace(count + 1);
|
||||||
|
if count >= 1 {
|
||||||
|
let compiled = self.compiled.get_or_init(|| vm.compile_func(self.func));
|
||||||
|
let ret = unsafe { compiled.call(vm as *const VM, &env as *const Env) };
|
||||||
|
return Ok(ret.into())
|
||||||
|
}
|
||||||
vm.eval(self.func.opcodes.iter().copied(), env)
|
vm.eval(self.func.opcodes.iter().copied(), env)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ use crate::ty::internal::{AttrSet, Value};
|
|||||||
|
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Default, Clone)]
|
||||||
pub struct Env<'jit, 'vm> {
|
pub struct Env<'jit, 'vm> {
|
||||||
last: Option<Rc<Env<'jit, 'vm>>>,
|
pub map: Rc<HashMap<usize, Value<'jit, 'vm>>>,
|
||||||
map: Rc<HashMap<usize, Value<'jit, 'vm>>>,
|
pub last: Option<Rc<Env<'jit, 'vm>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'jit, 'vm> Env<'jit, 'vm> {
|
impl<'jit, 'vm> Env<'jit, 'vm> {
|
||||||
|
|||||||
Reference in New Issue
Block a user