optimize: make all call single arg

to allow more aggressive optimization
This commit is contained in:
2025-05-23 09:21:40 +08:00
parent f380e5fd70
commit 53cbb37b00
8 changed files with 60 additions and 101 deletions

View File

@@ -17,8 +17,9 @@ pub fn env<'jit, 'vm>(vm: &'vm VM<'jit>) -> VmEnv<'jit, 'vm> {
first.ok()
}),
PrimOp::new("sub", 2, |_, args| {
let [mut first, second]: [Value; 2] = args.try_into().unwrap();
first.add(second.neg());
let [mut first, mut second]: [Value; 2] = args.try_into().unwrap();
second.neg();
first.add(second);
first.ok()
}),
PrimOp::new("mul", 2, |_, args| {

View File

@@ -22,9 +22,8 @@ pub enum OpCode {
/// force TOS to value
ForceValue,
/// [ .. func args @ .. ] consume (`arity` + 1) elements, call `func` with args` of length `arity`
/// Example: __add 1 2 => [ LookUp("__add") Const(1) Const(2) Call(2) ]
Call { arity: usize },
/// [ .. func arg ] consume 2 elements, call `func` with arg
Call,
/// make a function
Func { idx: usize },

View File

@@ -237,12 +237,12 @@ impl Compile for ir::BinOp {
PipeL => {
self.lhs.compile(comp);
self.rhs.compile(comp);
comp.push(OpCode::Call { arity: 1 });
comp.push(OpCode::Call);
}
PipeR => {
self.rhs.compile(comp);
self.lhs.compile(comp);
comp.push(OpCode::Call { arity: 1 });
comp.push(OpCode::Call);
}
}
}
@@ -395,12 +395,11 @@ impl Compile for ir::LoadFunc {
impl Compile for ir::Call {
fn compile(self, comp: &mut Compiler) {
let arity = self.args.len();
self.func.compile(comp);
self.args.into_iter().for_each(|arg| {
arg.compile(comp);
comp.push(OpCode::Call);
});
comp.push(OpCode::Call { arity });
}
}

View File

@@ -95,8 +95,7 @@ impl<'ctx> Helpers<'ctx> {
value_type.fn_type(
&[
value_type.into(),
ptr_type.into(),
ptr_int_type.into(),
value_type.into(),
ptr_type.into(),
],
false,
@@ -317,21 +316,16 @@ extern "C" fn helper_or(lhs: JITValue, rhs: JITValue) -> JITValue {
extern "C" fn helper_call<'jit>(
func: JITValue,
args: *mut JITValue,
arity: usize,
arg: JITValue,
vm: *const VM<'jit>,
) -> JITValue {
use ValueTag::*;
let args = unsafe { Vec::from_raw_parts(args, arity, arity) }
.into_iter()
.map(Value::from)
.collect();
match func.tag {
Function => {
let func: Value = func.into();
func.call(unsafe { vm.as_ref() }.unwrap(), args)
.unwrap()
.into()
let mut func: Value = func.into();
func.call(unsafe { vm.as_ref() }.unwrap(), arg.into())
.unwrap();
func.into()
}
_ => todo!(),
}

View File

@@ -392,23 +392,8 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
.left()
.unwrap(),
)?,
OpCode::Call { arity } => {
let args = self.builder.build_array_malloc(
self.helpers.value_type,
self.helpers.int_type.const_int(arity as u64, false),
"malloc_args",
)?;
for i in 0..arity {
let ptr = unsafe {
self.builder.build_in_bounds_gep(
self.helpers.value_type,
args,
&[self.helpers.int_type.const_int(i as u64, false)],
"gep_arg",
)?
};
self.builder.build_store(ptr, stack.pop())?;
}
OpCode::Call => {
let arg = stack.pop();
let func = self
.builder
.build_direct_call(
@@ -425,11 +410,7 @@ impl<'vm, 'ctx: 'vm> JITContext<'ctx> {
self.helpers.call,
&[
func.into(),
args.into(),
self.helpers
.ptr_int_type
.const_int(arity as u64, false)
.into(),
arg.into(),
self.new_ptr(vm).into(),
],
"call",

View File

@@ -175,38 +175,27 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
}
}
pub fn call(&self, vm: &'vm VM<'jit>, args: Vec<Self>) -> Result<Self> {
pub fn call(&mut self, vm: &'vm VM<'jit>, arg: Self) -> Result<()> {
use Value::*;
match self {
PrimOp(func) => func.call(vm, args),
PartialPrimOp(func) => func.call(vm, args),
func @ Value::Func(_) => {
let mut iter = args.into_iter();
let mut func = func.clone();
while let Some(arg) = iter.next() {
func = match func {
PrimOp(func) => {
return func.call(vm, [arg].into_iter().chain(iter).collect());
if matches!(arg, Value::Catchable(_)) {
*self = arg;
return Ok(());
}
PartialPrimOp(func) => {
return func.call(vm, [arg].into_iter().chain(iter).collect());
}
Func(func) => func.call(vm, arg)?,
*self = match self {
PrimOp(func) => func.call(vm, arg),
PartialPrimOp(func) => func.call(vm, arg),
Value::Func(func) => func.call(vm, arg),
Catchable(_) => return Ok(()),
_ => todo!(),
}
}
func.ok()
}
x @ Catchable(_) => x.clone().ok(),
_ => todo!(),
}
}?;
Ok(())
}
pub fn not(self) -> Self {
pub fn not(&mut self) {
use Const::*;
match self {
*self = match &*self {
VmConst(Bool(bool)) => VmConst(Bool(!bool)),
x @ Value::Catchable(_) => x,
Value::Catchable(_) => return,
_ => todo!(),
}
}
@@ -257,12 +246,12 @@ impl<'jit, 'vm> Value<'jit, 'vm> {
}))
}
pub fn neg(self) -> Self {
pub fn neg(&mut self) {
use Const::*;
match self {
*self = match &*self {
VmConst(Int(int)) => VmConst(Int(-int)),
VmConst(Float(float)) => VmConst(Float(-float)),
x @ Value::Catchable(_) => x,
Value::Catchable(_) => return,
_ => todo!(),
}
}

View File

@@ -21,22 +21,22 @@ impl PartialEq for PrimOp<'_, '_> {
}
impl<'jit, 'vm> PrimOp<'jit, 'vm> {
pub fn call(&self, vm: &'vm VM<'jit>, args: Vec<Value<'jit, 'vm>>) -> Result<Value<'jit, 'vm>> {
if (args.len()) < self.arity {
pub fn call(&self, vm: &'vm VM<'jit>, arg: Value<'jit, 'vm>) -> Result<Value<'jit, 'vm>> {
let mut args = Vec::with_capacity(self.arity);
args.push(arg);
if self.arity > 1 {
Value::PartialPrimOp(
PartialPrimOp {
name: self.name,
arity: self.arity - args.len(),
arity: self.arity - 1,
args,
func: self.func,
}
.into(),
)
.ok()
} else if args.len() == self.arity {
(self.func)(vm, args)
} else {
unimplemented!()
(self.func)(vm, args)
}
}
}
@@ -57,22 +57,18 @@ impl PartialEq for PartialPrimOp<'_, '_> {
impl<'jit: 'vm, 'vm> PartialPrimOp<'jit, 'vm> {
pub fn call(
self: &Rc<Self>,
self: &mut Rc<Self>,
vm: &'vm VM<'jit>,
args: Vec<Value<'jit, 'vm>>,
arg: Value<'jit, 'vm>,
) -> Result<Value<'jit, 'vm>> {
let len = args.len();
let mut self_clone = self.clone();
let self_mut = Rc::make_mut(&mut self_clone);
self_mut.args.extend(args);
self_mut.arity -= len;
let self_mut = Rc::make_mut(self);
self_mut.args.push(arg);
self_mut.arity -= 1;
if self_mut.arity > 0 {
Value::PartialPrimOp(self_clone).ok()
} else if self_mut.arity == 0 {
Value::PartialPrimOp(self.clone()).ok()
} else {
let args = std::mem::take(&mut self_mut.args);
(self.func)(vm, args)
} else {
unimplemented!()
}
}
}

View File

@@ -122,14 +122,11 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
return Ok(step);
}
}
OpCode::Call { arity } => {
let mut args = Vec::with_capacity(arity);
for _ in 0..arity {
args.insert(0, stack.pop());
}
let mut func = stack.pop();
OpCode::Call => {
let arg = stack.pop();
let func = stack.tos_mut();
func.force(self)?;
stack.push(func.call(self, args)?)?;
func.call(self, arg)?;
}
OpCode::Func { idx } => {
let func = self.get_func(idx);
@@ -139,12 +136,12 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
}
OpCode::UnOp { op } => {
use UnOp::*;
let mut value = stack.pop();
let value = stack.tos_mut();
value.force(self)?;
stack.push(match op {
match op {
Neg => value.neg(),
Not => value.not(),
})?;
}
}
OpCode::BinOp { op } => {
use BinOp::*;
@@ -154,7 +151,10 @@ impl<'vm, 'jit: 'vm> VM<'jit> {
rhs.force(self)?;
match op {
Add => lhs.add(rhs),
Sub => lhs.add(rhs.neg()),
Sub => {
rhs.neg();
lhs.add(rhs);
},
Mul => lhs.mul(rhs),
Div => lhs.div(rhs)?,
And => lhs.and(rhs),