optimize: make all call single arg

to allow more aggressive optimization
This commit is contained in:
2025-05-23 09:21:40 +08:00
parent f380e5fd70
commit 53cbb37b00
8 changed files with 60 additions and 101 deletions

View File

@@ -17,8 +17,9 @@ pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> VmEnv<'jit, 'vm> {
first.ok() first.ok()
}), }),
PrimOp::new("sub", 2, |_, args| { PrimOp::new("sub", 2, |_, args| {
let [mut first, second]: [Value; 2] = args.try_into().unwrap(); let [mut first, mut second]: [Value; 2] = args.try_into().unwrap();
first.add(second.neg()); second.neg();
first.add(second);
first.ok() first.ok()
}), }),
PrimOp::new("mul", 2, |_, args| { PrimOp::new("mul", 2, |_, args| {

View File

@@ -22,9 +22,8 @@ pub enum OpCode {
/// force TOS to value /// force TOS to value
ForceValue, ForceValue,
/// [ .. func args @ .. ] consume (`arity` + 1) elements, call `func` with args` of length `arity` /// [ .. func arg ] consume 2 elements, call `func` with arg
/// Example: __add 1 2 => [ LookUp("__add") Const(1) Const(2) Call(2) ] Call,
Call { arity: usize },
/// make a function /// make a function
Func { idx: usize }, Func { idx: usize },

View File

@@ -237,12 +237,12 @@ impl Compile for ir::BinOp {
PipeL => { PipeL => {
self.lhs.compile(comp); self.lhs.compile(comp);
self.rhs.compile(comp); self.rhs.compile(comp);
comp.push(OpCode::Call { arity: 1 }); comp.push(OpCode::Call);
} }
PipeR => { PipeR => {
self.rhs.compile(comp); self.rhs.compile(comp);
self.lhs.compile(comp); self.lhs.compile(comp);
comp.push(OpCode::Call { arity: 1 }); comp.push(OpCode::Call);
} }
} }
} }
@@ -395,12 +395,11 @@ impl Compile for ir::LoadFunc {
impl Compile for ir::Call { impl Compile for ir::Call {
fn compile(self, comp: &mut Compiler) { fn compile(self, comp: &mut Compiler) {
let arity = self.args.len();
self.func.compile(comp); self.func.compile(comp);
self.args.into_iter().for_each(|arg| { self.args.into_iter().for_each(|arg| {
arg.compile(comp); arg.compile(comp);
comp.push(OpCode::Call);
}); });
comp.push(OpCode::Call { arity });
} }
} }

View File

@@ -95,8 +95,7 @@ impl<'ctx> Helpers<'ctx> {
value_type.fn_type( value_type.fn_type(
&[ &[
value_type.into(), value_type.into(),
ptr_type.into(), value_type.into(),
ptr_int_type.into(),
ptr_type.into(), ptr_type.into(),
], ],
false, false,
@@ -317,21 +316,16 @@ extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue {
extern "C" fn helper_call<'jit>( extern "C" fn helper_call<'jit>(
func: JITValue, func: JITValue,
args: *mut JITValue, arg: JITValue,
arity: usize,
vm: *const VM<'jit>, vm: *const VM<'jit>,
) -> JITValue { ) -> JITValue {
use ValueTag::*; use ValueTag::*;
let args = unsafe { Vec::from_raw_parts(args, arity, arity) }
.into_iter()
.map(Value::from)
.collect();
match func.tag { match func.tag {
Function => { Function => {
let func: Value = func.into(); let mut func: Value = func.into();
func.call(unsafe { vm.as_ref() }.unwrap(), args) func.call(unsafe { vm.as_ref() }.unwrap(), arg.into())
.unwrap() .unwrap();
.into() func.into()
} }
_ => todo!(), _ => todo!(),
} }

View File

@@ -392,23 +392,8 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
.left() .left()
.unwrap(), .unwrap(),
)?, )?,
OpCode::Call { arity } => { OpCode::Call => {
let args = self.builder.build_array_malloc( let arg = stack.pop();
self.helpers.value_type,
self.helpers.int_type.const_int(arity as u64, false),
"malloc_args",
)?;
for i in 0..arity {
let ptr = unsafe {
self.builder.build_in_bounds_gep(
self.helpers.value_type,
args,
&[self.helpers.int_type.const_int(i as u64, false)],
"gep_arg",
)?
};
self.builder.build_store(ptr, stack.pop())?;
}
let func = self let func = self
.builder .builder
.build_direct_call( .build_direct_call(
@@ -425,11 +410,7 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
self.helpers.call, self.helpers.call,
&[ &[
func.into(), func.into(),
args.into(), arg.into(),
self.helpers
.ptr_int_type
.const_int(arity as u64, false)
.into(),
self.new_ptr(vm).into(), self.new_ptr(vm).into(),
], ],
"call", "call",

View File

@@ -175,38 +175,27 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
} }
} }
pub fn call(&self, vm: &'vm VM<'jit>, args: Vec<Self>) -> Result<Self> { pub fn call(&mut self, vm: &'vm VM<'jit>, arg: Self) -> Result<()> {
use Value::*; use Value::*;
match self { if matches!(arg, Value::Catchable(_)) {
PrimOp(func) => func.call(vm, args), *self = arg;
PartialPrimOp(func) => func.call(vm, args), return Ok(());
func @ Value::Func(_) => {
let mut iter = args.into_iter();
let mut func = func.clone();
while let Some(arg) = iter.next() {
func = match func {
PrimOp(func) => {
return func.call(vm, [arg].into_iter().chain(iter).collect());
}
PartialPrimOp(func) => {
return func.call(vm, [arg].into_iter().chain(iter).collect());
}
Func(func) => func.call(vm, arg)?,
_ => todo!(),
}
}
func.ok()
}
x @ Catchable(_) => x.clone().ok(),
_ => todo!(),
} }
*self = match self {
PrimOp(func) => func.call(vm, arg),
PartialPrimOp(func) => func.call(vm, arg),
Value::Func(func) => func.call(vm, arg),
Catchable(_) => return Ok(()),
_ => todo!(),
}?;
Ok(())
} }
pub fn not(self) -> Self { pub fn not(&mut self) {
use Const::*; use Const::*;
match self { *self = match &*self {
VmConst(Bool(bool)) => VmConst(Bool(!bool)), VmConst(Bool(bool)) => VmConst(Bool(!bool)),
x @ Value::Catchable(_) => x, Value::Catchable(_) => return,
_ => todo!(), _ => todo!(),
} }
} }
@@ -257,12 +246,12 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
})) }))
} }
pub fn neg(self) -> Self { pub fn neg(&mut self) {
use Const::*; use Const::*;
match self { *self = match &*self {
VmConst(Int(int)) => VmConst(Int(-int)), VmConst(Int(int)) => VmConst(Int(-int)),
VmConst(Float(float)) => VmConst(Float(-float)), VmConst(Float(float)) => VmConst(Float(-float)),
x @ Value::Catchable(_) => x, Value::Catchable(_) => return,
_ => todo!(), _ => todo!(),
} }
} }

View File

@@ -21,22 +21,22 @@ impl PartialEq for PrimOp<'_, '_> {
} }
impl<'jit, 'vm> PrimOp<'jit, 'vm> { impl<'jit, 'vm> PrimOp<'jit, 'vm> {
pub fn call(&self, vm: &'vm VM<'jit>, args: Vec<Value<'jit, 'vm>>) -> Result<Value<'jit, 'vm>> { pub fn call(&self, vm: &'vm VM<'jit>, arg: Value<'jit, 'vm>) -> Result<Value<'jit, 'vm>> {
if (args.len()) < self.arity { let mut args = Vec::with_capacity(self.arity);
args.push(arg);
if self.arity > 1 {
Value::PartialPrimOp( Value::PartialPrimOp(
PartialPrimOp { PartialPrimOp {
name: self.name, name: self.name,
arity: self.arity - args.len(), arity: self.arity - 1,
args, args,
func: self.func, func: self.func,
} }
.into(), .into(),
) )
.ok() .ok()
} else if args.len() == self.arity {
(self.func)(vm, args)
} else { } else {
unimplemented!() (self.func)(vm, args)
} }
} }
} }
@@ -57,22 +57,18 @@ impl PartialEq for PartialPrimOp<'_, '_> {
impl<'jit: 'vm, 'vm> PartialPrimOp<'jit, 'vm> { impl<'jit: 'vm, 'vm> PartialPrimOp<'jit, 'vm> {
pub fn call( pub fn call(
self: &Rc<Self>, self: &mut Rc<Self>,
vm: &'vm VM<'jit>, vm: &'vm VM<'jit>,
args: Vec<Value<'jit, 'vm>>, arg: Value<'jit, 'vm>,
) -> Result<Value<'jit, 'vm>> { ) -> Result<Value<'jit, 'vm>> {
let len = args.len(); let self_mut = Rc::make_mut(self);
let mut self_clone = self.clone(); self_mut.args.push(arg);
let self_mut = Rc::make_mut(&mut self_clone); self_mut.arity -= 1;
self_mut.args.extend(args);
self_mut.arity -= len;
if self_mut.arity > 0 { if self_mut.arity > 0 {
Value::PartialPrimOp(self_clone).ok() Value::PartialPrimOp(self.clone()).ok()
} else if self_mut.arity == 0 { } else {
let args = std::mem::take(&mut self_mut.args); let args = std::mem::take(&mut self_mut.args);
(self.func)(vm, args) (self.func)(vm, args)
} else { }
unimplemented!()
}
} }
} }

View File

@@ -122,14 +122,11 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
return Ok(step); return Ok(step);
} }
} }
OpCode::Call { arity } => { OpCode::Call => {
let mut args = Vec::with_capacity(arity); let arg = stack.pop();
for _ in 0..arity { let func = stack.tos_mut();
args.insert(0, stack.pop());
}
let mut func = stack.pop();
func.force(self)?; func.force(self)?;
stack.push(func.call(self, args)?)?; func.call(self, arg)?;
} }
OpCode::Func { idx } => { OpCode::Func { idx } => {
let func = self.get_func(idx); let func = self.get_func(idx);
@@ -139,12 +136,12 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
} }
OpCode::UnOp { op } => { OpCode::UnOp { op } => {
use UnOp::*; use UnOp::*;
let mut value = stack.pop(); let value = stack.tos_mut();
value.force(self)?; value.force(self)?;
stack.push(match op { match op {
Neg => value.neg(), Neg => value.neg(),
Not => value.not(), Not => value.not(),
})?; }
} }
OpCode::BinOp { op } => { OpCode::BinOp { op } => {
use BinOp::*; use BinOp::*;
@@ -154,7 +151,10 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
rhs.force(self)?; rhs.force(self)?;
match op { match op {
Add => lhs.add(rhs), Add => lhs.add(rhs),
Sub => lhs.add(rhs.neg()), Sub => {
rhs.neg();
lhs.add(rhs);
},
Mul => lhs.mul(rhs), Mul => lhs.mul(rhs),
Div => lhs.div(rhs)?, Div => lhs.div(rhs)?,
And => lhs.and(rhs), And => lhs.and(rhs),