From 9a17990d5efbfb4b06508e95bdc44c066d87033c Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Sat, 16 May 2026 23:18:50 +0800 Subject: [PATCH] deep equal --- fix-builtins/src/lib.rs | 3 + fix-primops/src/eq.rs | 238 ++++++++++++++++++++++++++ fix-primops/src/lib.rs | 5 + fix-vm/src/instructions/arithmetic.rs | 75 +------- flake.lock | 36 ++-- 5 files changed, 266 insertions(+), 91 deletions(-) create mode 100644 fix-primops/src/eq.rs diff --git a/fix-builtins/src/lib.rs b/fix-builtins/src/lib.rs index 11cee31..51bd757 100644 --- a/fix-builtins/src/lib.rs +++ b/fix-builtins/src/lib.rs @@ -246,6 +246,9 @@ pub enum PrimOpPhase { ForceResultShallowLoop, ForceResultDeepFinish, + EqStep, + EqForce, + // TODO: split into separate enums CallPattern, CallFunctor1, diff --git a/fix-primops/src/eq.rs b/fix-primops/src/eq.rs new file mode 100644 index 0000000..dc6de28 --- /dev/null +++ b/fix-primops/src/eq.rs @@ -0,0 +1,238 @@ +use fix_abstract_vm::{ + AttrSet, BytecodeReader, CallFrame, List, Machine, MachineExt, NixNum, Null, Path, Step, + StrictValue, Value, VmRuntimeCtx, VmRuntimeCtxExt, +}; +use fix_builtins::PrimOpPhase; +use gc_arena::{Gc, Mutation}; +use smallvec::SmallVec; + +pub fn start_eq<'gc, M: Machine<'gc>>( + m: &mut M, + ctx: &impl VmRuntimeCtx, + reader: &mut BytecodeReader<'_>, + mc: &Mutation<'gc>, + lhs: StrictValue<'gc>, + rhs: StrictValue<'gc>, + negate: bool, +) -> Step { + match shallow_eq(ctx, lhs, rhs) { + ShallowEq::True => { + m.push(Value::new_inline(!negate)); + Step::Continue(()) + } + ShallowEq::False => { + m.push(Value::new_inline(negate)); + Step::Continue(()) + } + ShallowEq::RecurseList(la, lb) => { + let lhs_init: SmallVec<[Value<'gc>; 4]> = la.inner.borrow().iter().copied().collect(); + let rhs_init: SmallVec<[Value<'gc>; 4]> = lb.inner.borrow().iter().copied().collect(); + enter_eq_machine(m, reader, mc, negate, lhs_init, rhs_init) + } + ShallowEq::RecurseAttrs(a, b) => { + let lhs_init: SmallVec<[Value<'gc>; 4]> = + a.entries.iter().map(|&(_, v)| v).collect(); + let rhs_init: SmallVec<[Value<'gc>; 4]> = + b.entries.iter().map(|&(_, v)| v).collect(); + enter_eq_machine(m, reader, mc, negate, lhs_init, rhs_init) + } + } +} + +pub fn eq_step<'gc, M: Machine<'gc>>( + m: &mut M, + reader: &mut BytecodeReader<'_>, + mc: &Mutation<'gc>, +) -> Step { + let rhs_q = m + .peek(0) + .as_gc::>() + .expect("eq state corrupted: rhs_queue"); + let lhs_q = m + .peek(1) + .as_gc::>() + .expect("eq state corrupted: lhs_queue"); + let result = m + .peek(2) + .as_inline::() + .expect("eq state corrupted: result"); + + if !result || lhs_q.inner.borrow().is_empty() { + return finalize(m, reader); + } + + let lhs = lhs_q + .unlock(mc) + .borrow_mut() + .pop() + .expect("non-empty lhs queue"); + let rhs = rhs_q + .unlock(mc) + .borrow_mut() + .pop() + .expect("non-empty rhs queue"); + m.push(lhs); + m.push(rhs); + reader.set_pc(PrimOpPhase::EqForce.ip() as usize); + Step::Continue(()) +} + +pub fn eq_force<'gc, M: Machine<'gc>>( + m: &mut M, + ctx: &mut impl VmRuntimeCtx, + reader: &mut BytecodeReader<'_>, + mc: &Mutation<'gc>, +) -> Step { + let (lhs, rhs) = m.force_and_retry::<(StrictValue, StrictValue)>(reader, mc)?; + apply_pair(m, ctx, mc, lhs, rhs); + reader.set_pc(PrimOpPhase::EqStep.ip() as usize); + Step::Continue(()) +} + +fn finalize<'gc, M: Machine<'gc>>(m: &mut M, reader: &mut BytecodeReader<'_>) -> Step { + let _ = m.pop(); + let _ = m.pop(); + let result = m + .pop() + .as_inline::() + .expect("eq state corrupted: result"); + let negate = m + .pop() + .as_inline::() + .expect("eq state corrupted: negate"); + m.return_from_primop(Value::new_inline(result ^ negate), reader) +} + +fn apply_pair<'gc, M: Machine<'gc>>( + m: &mut M, + ctx: &impl VmRuntimeCtx, + mc: &Mutation<'gc>, + lhs: StrictValue<'gc>, + rhs: StrictValue<'gc>, +) { + match shallow_eq(ctx, lhs, rhs) { + ShallowEq::True => {} + ShallowEq::False => { + m.replace(2, Value::new_inline(false)); + } + ShallowEq::RecurseList(la, lb) => { + extend_queues(m, mc, la.inner.borrow().iter().copied(), lb.inner.borrow().iter().copied()); + } + ShallowEq::RecurseAttrs(a, b) => { + extend_queues( + m, + mc, + a.entries.iter().map(|&(_, v)| v), + b.entries.iter().map(|&(_, v)| v), + ); + } + } +} + +fn extend_queues<'gc, M, L, R>(m: &mut M, mc: &Mutation<'gc>, lhs_iter: L, rhs_iter: R) +where + M: Machine<'gc>, + L: IntoIterator>, + R: IntoIterator>, +{ + let rhs_q = m + .peek(0) + .as_gc::>() + .expect("eq state corrupted: rhs_queue"); + let lhs_q = m + .peek(1) + .as_gc::>() + .expect("eq state corrupted: lhs_queue"); + let mut lq = lhs_q.unlock(mc).borrow_mut(); + let mut rq = rhs_q.unlock(mc).borrow_mut(); + for (x, y) in lhs_iter.into_iter().zip(rhs_iter) { + lq.push(x); + rq.push(y); + } +} + +fn enter_eq_machine<'gc, M: Machine<'gc>>( + m: &mut M, + reader: &mut BytecodeReader<'_>, + mc: &Mutation<'gc>, + negate: bool, + lhs_init: SmallVec<[Value<'gc>; 4]>, + rhs_init: SmallVec<[Value<'gc>; 4]>, +) -> Step { + let resume_pc = reader.pc(); + m.push_call_frame(CallFrame { + pc: resume_pc, + thunk: None, + env: m.env(), + }); + m.inc_call_depth(); + m.push(Value::new_inline(negate)); + m.push(Value::new_inline(true)); + m.push(Value::new_gc(List::new(mc, lhs_init))); + m.push(Value::new_gc(List::new(mc, rhs_init))); + reader.set_pc(PrimOpPhase::EqStep.ip() as usize); + Step::Continue(()) +} + +enum ShallowEq<'gc> { + True, + False, + RecurseList(Gc<'gc, List<'gc>>, Gc<'gc, List<'gc>>), + RecurseAttrs(Gc<'gc, AttrSet<'gc>>, Gc<'gc, AttrSet<'gc>>), +} + +fn shallow_eq<'gc>( + ctx: &impl VmRuntimeCtx, + lhs: StrictValue<'gc>, + rhs: StrictValue<'gc>, +) -> ShallowEq<'gc> { + if let (Some(a), Some(b)) = (lhs.as_num(), rhs.as_num()) { + let eq = match (a, b) { + (NixNum::Int(a), NixNum::Int(b)) => a == b, + (NixNum::Float(a), NixNum::Float(b)) => a == b, + (NixNum::Int(a), NixNum::Float(b)) => a as f64 == b, + (NixNum::Float(a), NixNum::Int(b)) => a == b as f64, + }; + return bool_outcome(eq); + } + if let (Some(a), Some(b)) = (lhs.as_inline::(), rhs.as_inline::()) { + return bool_outcome(a == b); + } + if lhs.is::() && rhs.is::() { + return ShallowEq::True; + } + if let (Some(a), Some(b)) = (lhs.as_inline::(), rhs.as_inline::()) { + return bool_outcome(a.0 == b.0); + } + if let (Some(a), Some(b)) = (ctx.get_string(lhs), ctx.get_string(rhs)) { + return bool_outcome(a == b); + } + if let (Some(a), Some(b)) = (lhs.as_gc::>(), rhs.as_gc::>()) { + if a.inner.borrow().len() != b.inner.borrow().len() { + return ShallowEq::False; + } + return ShallowEq::RecurseList(a, b); + } + if let (Some(a), Some(b)) = (lhs.as_gc::>(), rhs.as_gc::>()) { + let ae = &a.entries; + let be = &b.entries; + if ae.len() != be.len() { + return ShallowEq::False; + } + for (l, r) in ae.iter().zip(be.iter()) { + if l.0 != r.0 { + return ShallowEq::False; + } + } + return ShallowEq::RecurseAttrs(a, b); + } + ShallowEq::False +} + +fn bool_outcome<'gc>(b: bool) -> ShallowEq<'gc> { + if b { + ShallowEq::True + } else { + ShallowEq::False + } +} diff --git a/fix-primops/src/lib.rs b/fix-primops/src/lib.rs index 46395a3..408b677 100644 --- a/fix-primops/src/lib.rs +++ b/fix-primops/src/lib.rs @@ -1,11 +1,13 @@ mod control; mod conv; +mod eq; mod io; mod list; mod path; pub use control::*; pub use conv::*; +pub use eq::*; use fix_abstract_vm::{BytecodeReader, Machine, Step, VmRuntimeCtx}; use fix_builtins::PrimOpPhase; use fix_error::Error; @@ -49,6 +51,9 @@ pub fn dispatch_primop<'gc, M: Machine<'gc>>( ForceResultShallowLoop => force_result_shallow_loop(m, reader, mc), ForceResultDeepFinish => force_result_deep_finish(m, ctx, reader, mc), + EqStep => eq_step(m, reader, mc), + EqForce => eq_force(m, ctx, reader, mc), + CallPattern => call_pattern(m, ctx, reader, mc), CallFunctor1 => call_functor_1(m, reader, mc), CallFunctor2 => call_functor_2(m, reader, mc), diff --git a/fix-vm/src/instructions/arithmetic.rs b/fix-vm/src/instructions/arithmetic.rs index 8e619eb..8db1466 100644 --- a/fix-vm/src/instructions/arithmetic.rs +++ b/fix-vm/src/instructions/arithmetic.rs @@ -105,12 +105,7 @@ impl<'gc> crate::Vm<'gc> { mc: &Mutation<'gc>, ) -> Step { let (lhs, rhs) = self.force_and_retry::<(StrictValue, StrictValue)>(reader, mc)?; - let eq = match self.values_equal(ctx, lhs, rhs) { - Ok(eq) => eq, - Err(e) => return self.finish_vm_err(e), - }; - self.push(Value::new_inline(eq)); - Step::Continue(()) + fix_primops::start_eq(self, ctx, reader, mc, lhs, rhs, false) } #[inline(always)] @@ -121,12 +116,7 @@ impl<'gc> crate::Vm<'gc> { mc: &Mutation<'gc>, ) -> Step { let (lhs, rhs) = self.force_and_retry::<(StrictValue, StrictValue)>(reader, mc)?; - let eq = match self.values_equal(ctx, lhs, rhs) { - Ok(eq) => eq, - Err(e) => return self.finish_vm_err(e), - }; - self.push(Value::new_inline(!eq)); - Step::Continue(()) + fix_primops::start_eq(self, ctx, reader, mc, lhs, rhs, true) } #[inline(always)] @@ -230,67 +220,6 @@ impl<'gc> crate::Vm<'gc> { Step::Continue(()) } - pub(crate) fn values_equal( - &mut self, - ctx: &impl VmRuntimeCtx, - lhs: StrictValue<'gc>, - rhs: StrictValue<'gc>, - ) -> crate::VmResult { - if let (Some(a), Some(b)) = (get_num(lhs), get_num(rhs)) { - return Ok(match (a, b) { - (NixNum::Int(a), NixNum::Int(b)) => a == b, - (NixNum::Float(a), NixNum::Float(b)) => a == b, - (NixNum::Int(a), NixNum::Float(b)) => a as f64 == b, - (NixNum::Float(a), NixNum::Int(b)) => a == b as f64, - }); - } - if let (Some(a), Some(b)) = (lhs.as_inline::(), rhs.as_inline::()) { - return Ok(a == b); - } - if lhs.is::() && rhs.is::() { - return Ok(true); - } - // Paths only equal paths (not strings, even if their text matches). - if let (Some(a), Some(b)) = (lhs.as_inline::(), rhs.as_inline::()) { - return Ok(a.0 == b.0); - } - if let (Some(a), Some(b)) = (ctx.get_string(lhs), ctx.get_string(rhs)) { - return Ok(a == b); - } - if let (Some(a), Some(b)) = (lhs.as_gc::(), rhs.as_gc::()) { - if a.inner.borrow().len() != b.inner.borrow().len() { - return Ok(false); - } - for (x, y) in a.inner.borrow().iter().zip(b.inner.borrow().iter()) { - let lx = x.restrict().expect("forced"); - let ly = y.restrict().expect("forced"); - if !self.values_equal(ctx, lx, ly)? { - return Ok(false); - } - } - return Ok(true); - } - if let (Some(a), Some(b)) = (lhs.as_gc::(), rhs.as_gc::()) { - let a = &a.entries; - let b = &b.entries; - if a.len() != b.len() { - return Ok(false); - } - for ((k1, v1), (k2, v2)) in a.iter().zip(b.iter()) { - if k1 != k2 { - return Ok(false); - } - let lv1 = v1.restrict().expect("forced"); - let lv2 = v2.restrict().expect("forced"); - if !self.values_equal(ctx, lv1, lv2)? { - return Ok(false); - } - } - return Ok(true); - } - Ok(false) - } - fn compare_values_inner( &mut self, ctx: &impl VmRuntimeCtx, diff --git a/flake.lock b/flake.lock index cebd8f7..ac9bc57 100644 --- a/flake.lock +++ b/flake.lock @@ -45,11 +45,11 @@ ] }, "locked": { - "lastModified": 1777369708, - "narHash": "sha256-1xW7cRZNsFNPQD+cE0fwnLVStnDth0HSoASEIFeT7uI=", + "lastModified": 1778445566, + "narHash": "sha256-oQvcadh2BCkrog+SGrG6YffKJrveYpjj3TdQJWaKhaM=", "owner": "nix-community", "repo": "bun2nix", - "rev": "e659e1cc4b8e1b21d0aa85f1c481f9db61ecfa98", + "rev": "2499dedd70744dba1815875b854818a3019e9e4c", "type": "github" }, "original": { @@ -67,11 +67,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1777796307, - "narHash": "sha256-L7xLjorTwVf2aLu5b0ZZY2D0RFXwD/a/a/fFFDikB2w=", + "lastModified": 1778919578, + "narHash": "sha256-+z+jgTly48gsAiX8rOe/vs8C/2G4vdCpcEtqMJUpFqw=", "owner": "nix-community", "repo": "fenix", - "rev": "0f9881f2344c0b1c100bd9e774555759b7da6fd5", + "rev": "ecd6d4ff22cfdb1339b2915455a2ff4dc85bf52e", "type": "github" }, "original": { @@ -102,11 +102,11 @@ ] }, "locked": { - "lastModified": 1777678872, - "narHash": "sha256-EPIFsulyon7Z1vLQq5Fk64GR8L7cQsT+IPhcsukVbgk=", + "lastModified": 1778716662, + "narHash": "sha256-m1Yf0wZ8j1OHjTc2UwHwyQRSnNeSgLJOd7q5Y45hzi4=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "5250617bffd85403b14dbf43c3870e7f255d2c16", + "rev": "f7c1a2d347e4c52d5fb8d10cb4d94b5884e546fb", "type": "github" }, "original": { @@ -127,11 +127,11 @@ "treefmt-nix": "treefmt-nix" }, "locked": { - "lastModified": 1777786380, - "narHash": "sha256-GGKC1WrEoTafJwIXn+fim6cZ/w1ZWVc+DUYdk2lvPvA=", + "lastModified": 1778929997, + "narHash": "sha256-iAfbBUHBbR0N4DFqFWr4Jtmpc1YcOK7kpVM4f0MK1V8=", "owner": "numtide", "repo": "llm-agents.nix", - "rev": "961b1096bc0b2ecc7096e360646cd2f29671c55e", + "rev": "0da8f0313c9f68c155e0932f880fc1913e7be846", "type": "github" }, "original": { @@ -142,11 +142,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1777578337, - "narHash": "sha256-Ad49moKWeXtKBJNy2ebiTQUEgdLyvGmTeykAQ9xM+Z4=", + "lastModified": 1778443072, + "narHash": "sha256-zi7/fsqM/kFdNuED//4WOCUtezGtKKqRNORjMvfwjnA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "15f4ee454b1dce334612fa6843b3e05cf546efab", + "rev": "da5ad661ba4e5ef59ba743f0d112cbc30e474f32", "type": "github" }, "original": { @@ -167,11 +167,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1777768857, - "narHash": "sha256-zfekJcaVctfAps1KDHwZpwkvAQn7GObRHh3Gl3xocGI=", + "lastModified": 1778854817, + "narHash": "sha256-iG+VuMy8W585geVVCUd7pR025WsY3ZkgSv5Yt5bxDmQ=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "1102c0b633599564919e36076d4362d7e68dbddc", + "rev": "1a68212c5683555ad80f0eab71db9715c6d52145", "type": "github" }, "original": {