From ef5d8c3b294d1cb9c129cb4e0ad0b68b8161e4a6 Mon Sep 17 00:00:00 2001 From: imxyy_soope_ Date: Thu, 22 Jan 2026 23:13:36 +0800 Subject: [PATCH] feat: builtins.functionArgs --- nix-js/runtime-ts/src/builtins/misc.ts | 19 +++++-- nix-js/runtime-ts/src/helpers.ts | 3 +- nix-js/runtime-ts/src/index.ts | 3 +- nix-js/runtime-ts/src/type-assert.ts | 7 --- nix-js/runtime-ts/src/types.ts | 38 ++++++++++++- nix-js/src/codegen.rs | 77 ++++++++++---------------- nix-js/src/context.rs | 10 ++-- nix-js/src/error.rs | 11 +++- nix-js/src/ir.rs | 10 ++-- nix-js/src/ir/downgrade.rs | 28 +++++----- nix-js/src/ir/utils.rs | 32 ++++++----- nix-js/src/value.rs | 21 +++---- nix-js/tests/builtins.rs | 41 +++++++++++++- 13 files changed, 181 insertions(+), 119 deletions(-) diff --git a/nix-js/runtime-ts/src/builtins/misc.ts b/nix-js/runtime-ts/src/builtins/misc.ts index 808913e..4ba6fac 100644 --- a/nix-js/runtime-ts/src/builtins/misc.ts +++ b/nix-js/runtime-ts/src/builtins/misc.ts @@ -4,8 +4,8 @@ import { force } from "../thunk"; import { CatchableError } from "../types"; -import type { NixBool, NixStrictValue, NixValue } from "../types"; -import { forceList, forceStringValue, forceAttrs, forceFunction } from "../type-assert"; +import type { NixAttrs, NixBool, NixStrictValue, NixValue } from "../types"; +import { forceList, forceAttrs, forceFunction, forceStringValue } from "../type-assert"; import * as context from "./context"; import { compareValues, op } from "../operators"; import { isBool, isFloat, isInt, isList, isString, typeOf } from "./type-check"; @@ -151,8 +151,19 @@ export const flakeRefToString = (attrs: NixValue): never => { throw new Error("Not implemented: flakeRefToString"); }; -export const functionArgs = (f: NixValue): never => { - throw new Error("Not implemented: functionArgs"); +export const functionArgs = (f: NixValue): NixAttrs => { + const func = forceFunction(f); + if (func.args) { + const ret: NixAttrs = {}; + for (const key of func.args!.required) { + ret[key] = false; + } + for (const key of func.args!.optional) { + ret[key] = true; + } + return ret; + } + return {}; }; const checkComparable = (value: NixStrictValue): void => { diff --git a/nix-js/runtime-ts/src/helpers.ts b/nix-js/runtime-ts/src/helpers.ts index b1768b3..b6e8807 100644 --- a/nix-js/runtime-ts/src/helpers.ts +++ b/nix-js/runtime-ts/src/helpers.ts @@ -346,6 +346,7 @@ export const call = (func: NixValue, arg: NixValue, span?: string): NixValue => function call_impl(func: NixValue, arg: NixValue): NixValue { const forcedFunc = force(func); if (typeof forcedFunc === "function") { + forcedFunc.args?.check(arg); return forcedFunc(arg); } if ( @@ -355,7 +356,7 @@ function call_impl(func: NixValue, arg: NixValue): NixValue { "__functor" in forcedFunc ) { const functor = forceFunction(forcedFunc.__functor); - return forceFunction(functor(forcedFunc))(arg); + return call(functor(forcedFunc), arg); } throw new Error(`attempt to call something which is not a function but ${typeOf(forcedFunc)}`); } diff --git a/nix-js/runtime-ts/src/index.ts b/nix-js/runtime-ts/src/index.ts index a574937..867a1c7 100644 --- a/nix-js/runtime-ts/src/index.ts +++ b/nix-js/runtime-ts/src/index.ts @@ -23,7 +23,7 @@ import { op } from "./operators"; import { builtins, PRIMOP_METADATA } from "./builtins"; import { coerceToString, StringCoercionMode } from "./builtins/conversion"; import { HAS_CONTEXT } from "./string-context"; -import { IS_PATH } from "./types"; +import { IS_PATH, mkFunction } from "./types"; import { forceBool } from "./type-assert"; export type NixRuntime = typeof Nix; @@ -52,6 +52,7 @@ export const Nix = { coerceToString, concatStringsWithContext, StringCoercionMode, + mkFunction, pushContext, popContext, diff --git a/nix-js/runtime-ts/src/type-assert.ts b/nix-js/runtime-ts/src/type-assert.ts index ba60c70..a02f9bb 100644 --- a/nix-js/runtime-ts/src/type-assert.ts +++ b/nix-js/runtime-ts/src/type-assert.ts @@ -85,13 +85,6 @@ export const forceString = (value: NixValue): NixString => { throw new TypeError(`Expected string, got ${typeOf(forced)}`); }; -/** - * Get the plain string value from any NixString - */ -export const nixStringValue = (s: NixString): string => { - return getStringValue(s); -}; - /** * Force a value and assert it's a boolean * @throws TypeError if value is not a boolean after forcing diff --git a/nix-js/runtime-ts/src/types.ts b/nix-js/runtime-ts/src/types.ts index 8e2abda..b6bf0ec 100644 --- a/nix-js/runtime-ts/src/types.ts +++ b/nix-js/runtime-ts/src/types.ts @@ -4,6 +4,8 @@ import { IS_THUNK } from "./thunk"; import { type StringWithContext, HAS_CONTEXT, isStringWithContext } from "./string-context"; +import { op } from "./operators"; +import { forceAttrs } from "./type-assert"; export { HAS_CONTEXT, isStringWithContext }; export type { StringWithContext }; @@ -30,7 +32,41 @@ export type NixNull = null; export type NixList = NixValue[]; // FIXME: reject contextful string export type NixAttrs = { [key: string]: NixValue }; -export type NixFunction = (arg: NixValue) => NixValue; +export type NixFunction = ((arg: NixValue) => NixValue) & { args?: NixArgs }; +export class NixArgs { + required: string[]; + optional: string[]; + allowed: Set; + ellipsis: boolean; + constructor(required: string[], optional: string[], ellipsis: boolean) { + this.required = required; + this.optional = optional; + this.ellipsis = ellipsis; + this.allowed = new Set(required.concat(optional)); + } + check(arg: NixValue) { + const attrs = forceAttrs(arg); + + for (const key of this.required) { + if (!Object.hasOwn(attrs, key)) { + throw new Error(`Function called without required argument '${key}'`); + } + } + + if (!this.ellipsis) { + for (const key in attrs) { + if (!this.allowed.has(key)) { + throw new Error(`Function called with unexpected argument '${key}'`); + } + } + } + } +} +export const mkFunction = (f: (arg: NixValue) => NixValue, required: string[], optional: string[], ellipsis: boolean): NixFunction => { + const func = f as NixFunction; + func.args = new NixArgs(required, optional, ellipsis); + return func +} /** * Interface for lazy thunk values diff --git a/nix-js/src/codegen.rs b/nix-js/src/codegen.rs index 62b5bd0..2ba1330 100644 --- a/nix-js/src/codegen.rs +++ b/nix-js/src/codegen.rs @@ -23,7 +23,10 @@ pub(crate) fn compile(expr: &Ir, ctx: &impl CodegenContext) -> String { let cur_dir = ctx.get_current_dir().display().to_string().escape_quote(); format!( "(()=>{{{}Nix.builtins.storeDir={};const currentDir={};return {}}})()", - debug_prefix, ctx.get_store_dir().escape_quote(), cur_dir, code + debug_prefix, + ctx.get_store_dir().escape_quote(), + cur_dir, + code ) } @@ -196,9 +199,18 @@ impl Compile for BinOp { Leq => with_ctx("<=", format!("Nix.op.lte({},{})", lhs, rhs)), Geq => with_ctx(">=", format!("Nix.op.gte({},{})", lhs, rhs)), // Short-circuit operators: use JavaScript native && and || - And => with_ctx("&&", format!("Nix.forceBool({})&&Nix.forceBool({})", lhs, rhs)), - Or => with_ctx("||", format!("Nix.forceBool({})||Nix.forceBool({})", lhs, rhs)), - Impl => with_ctx("->", format!("(!Nix.forceBool({})||Nix.forceBool({}))", lhs, rhs)), + And => with_ctx( + "&&", + format!("Nix.forceBool({})&&Nix.forceBool({})", lhs, rhs), + ), + Or => with_ctx( + "||", + format!("Nix.forceBool({})||Nix.forceBool({})", lhs, rhs), + ), + Impl => with_ctx( + "->", + format!("(!Nix.forceBool({})||Nix.forceBool({}))", lhs, rhs), + ), Con => with_ctx("++", format!("Nix.op.concat({},{})", lhs, rhs)), Upd => with_ctx("//", format!("Nix.op.update({},{})", lhs, rhs)), PipeL => format!("Nix.call({},{})", rhs, lhs), @@ -223,56 +235,23 @@ impl Compile for Func { let id = ctx.get_ir(self.arg).as_ref().unwrap_arg().inner.0; let body = ctx.get_ir(self.body).compile(ctx); - // Generate parameter validation code - let param_check = self.generate_param_check(ctx); - - if param_check.is_empty() { - // Simple function without parameter validation + if let Some(Param { + required, + optional, + ellipsis, + }) = &self.param + { + let mut required = required.iter().map(|&sym| ctx.get_sym(sym).escape_quote()); + let required = format!("[{}]", required.join(",")); + let mut optional = optional.iter().map(|&sym| ctx.get_sym(sym).escape_quote()); + let optional = format!("[{}]", optional.join(",")); + format!("Nix.mkFunction(arg{id}=>({body}),{required},{optional},{ellipsis})") + } else { format!("arg{id}=>({body})") - } else { - // Function with parameter validation (use block statement, not object literal) - format!("arg{id}=>{{{}return {}}}", param_check, body) } } } -impl Func { - fn generate_param_check(&self, ctx: &Ctx) -> String { - let has_checks = self.param.required.is_some() || self.param.allowed.is_some(); - - if !has_checks { - return String::new(); - } - - let id = ctx.get_ir(self.arg).as_ref().unwrap_arg().inner.0; - - // Build required parameter array - let required = if let Some(req) = &self.param.required { - let keys: Vec<_> = req - .iter() - .map(|&sym| ctx.get_sym(sym).escape_quote()) - .collect(); - format!("[{}]", keys.join(",")) - } else { - "null".to_string() - }; - - // Build allowed parameter array - let allowed = if let Some(allow) = &self.param.allowed { - let keys: Vec<_> = allow - .iter() - .map(|&sym| ctx.get_sym(sym).escape_quote()) - .collect(); - format!("[{}]", keys.join(",")) - } else { - "null".to_string() - }; - - // Call Nix.validateParams and store the result - format!("Nix.validateParams(arg{},{},{});", id, required, allowed) - } -} - impl Compile for Call { fn compile(&self, ctx: &Ctx) -> String { let func = ctx.get_ir(self.func).compile(ctx); diff --git a/nix-js/src/context.rs b/nix-js/src/context.rs index 7a4a6cc..c687d96 100644 --- a/nix-js/src/context.rs +++ b/nix-js/src/context.rs @@ -73,10 +73,7 @@ impl Context { let ctx = Ctx::new()?; let runtime = Runtime::new()?; - Ok(Self { - ctx, - runtime, - }) + Ok(Self { ctx, runtime }) } pub fn eval_code(&mut self, source: Source) -> Result { @@ -85,7 +82,10 @@ impl Context { tracing::debug!("Compiling code"); let code = self.compile_code(source)?; - self.runtime.op_state().borrow_mut().put(self.ctx.store.clone()); + self.runtime + .op_state() + .borrow_mut() + .put(self.ctx.store.clone()); tracing::debug!("Executing JavaScript"); self.runtime diff --git a/nix-js/src/error.rs b/nix-js/src/error.rs index 573a7fc..0060705 100644 --- a/nix-js/src/error.rs +++ b/nix-js/src/error.rs @@ -68,7 +68,10 @@ impl Source { use SourceType::*; match &self.ty { Eval(dir) | Repl(dir) => dir.as_ref(), - File(file) => file.as_path().parent().expect("source file must have a parent dir"), + File(file) => file + .as_path() + .parent() + .expect("source file must have a parent dir"), } } } @@ -233,7 +236,11 @@ pub(crate) fn parse_nix_stack(stack: &str, ctx: &impl RuntimeContext) -> Vec }, Path { pub expr: ExprId }, - Func { pub body: ExprId, pub param: Param, pub arg: ExprId }, + Func { pub body: ExprId, pub param: Option, pub arg: ExprId }, Let { pub binding_sccs: SccInfo, pub body: ExprId }, Arg(ArgId), ExprRef(ExprId), @@ -296,9 +296,7 @@ impl From for UnOpKind { /// Describes the parameters of a function. #[derive(Debug)] pub struct Param { - /// The set of required parameter names for a pattern-matching function. - pub required: Option>, - /// The set of all allowed parameter names for a non-ellipsis pattern-matching function. - /// If `None`, any attribute is allowed (ellipsis `...` is present). - pub allowed: Option>, + pub required: Vec, + pub optional: Vec, + pub ellipsis: bool, } diff --git a/nix-js/src/ir/downgrade.rs b/nix-js/src/ir/downgrade.rs index 6027e1a..10c041a 100644 --- a/nix-js/src/ir/downgrade.rs +++ b/nix-js/src/ir/downgrade.rs @@ -363,20 +363,18 @@ impl Downgrade for ast::With { /// This involves desugaring pattern-matching arguments into `let` bindings. impl Downgrade for ast::Lambda { fn downgrade(self, ctx: &mut Ctx) -> Result { - let param = self.param().unwrap(); - let arg = ctx.new_arg(param.syntax().text_range()); + let raw_param = self.param().unwrap(); + let arg = ctx.new_arg(raw_param.syntax().text_range()); - let required; - let allowed; + let param; let body; let span = self.body().unwrap().syntax().text_range(); - match param { + match raw_param { ast::Param::IdentParam(id) => { // Simple case: `x: body` let param_sym = ctx.new_sym(id.to_string()); - required = None; - allowed = None; + param = None; // Downgrade body in Param scope body = ctx @@ -387,25 +385,28 @@ impl Downgrade for ast::Lambda { .pat_bind() .map(|alias| ctx.new_sym(alias.ident().unwrap().to_string())); - let has_ellipsis = pattern.ellipsis_token().is_some(); + let ellipsis = pattern.ellipsis_token().is_some(); let pat_entries = pattern.pat_entries(); let PatternBindings { body: inner_body, scc_info, - required_params, - allowed_params, + required, + optional, } = downgrade_pattern_bindings( pat_entries, alias, arg, - has_ellipsis, + ellipsis, ctx, |ctx, _| self.body().unwrap().downgrade(ctx), )?; - required = Some(required_params); - allowed = allowed_params; + param = Some(Param { + required, + optional, + ellipsis, + }); body = ctx.new_expr( Let { @@ -418,7 +419,6 @@ impl Downgrade for ast::Lambda { } } - let param = Param { required, allowed }; let span = self.syntax().text_range(); // The function's body and parameters are now stored directly in the `Func` node. Ok(ctx.new_expr( diff --git a/nix-js/src/ir/utils.rs b/nix-js/src/ir/utils.rs index 314aeec..3ca0b68 100644 --- a/nix-js/src/ir/utils.rs +++ b/nix-js/src/ir/utils.rs @@ -3,6 +3,7 @@ use hashbrown::hash_map::Entry; use hashbrown::{HashMap, HashSet}; +use itertools::Itertools as _; use rnix::ast; use rowan::ast::AstNode; @@ -257,8 +258,8 @@ pub fn downgrade_static_attrpathvalue( pub struct PatternBindings { pub body: ExprId, pub scc_info: SccInfo, - pub required_params: Vec, - pub allowed_params: Option>, + pub required: Vec, + pub optional: Vec, } /// Helper function for Lambda pattern parameters with SCC analysis. @@ -310,17 +311,18 @@ where binding_keys.push(alias_sym); } - let required: Vec = param_syms - .iter() - .zip(param_defaults.iter()) - .filter_map(|(&sym, default)| if default.is_none() { Some(sym) } else { None }) - .collect(); - - let allowed: Option> = if has_ellipsis { - None - } else { - Some(param_syms.iter().copied().collect()) - }; + let (required, optional) = + param_syms + .iter() + .zip(param_defaults.iter()) + .partition_map(|(&sym, default)| { + use itertools::Either::*; + if default.is_none() { + Left(sym) + } else { + Right(sym) + } + }); // Get the owner from outer tracker's current_binding let owner = ctx.get_current_binding(); @@ -371,8 +373,8 @@ where Ok(PatternBindings { body, scc_info, - required_params: required, - allowed_params: allowed, + required, + optional, }) } diff --git a/nix-js/src/value.rs b/nix-js/src/value.rs index 6ee08fe..8fc7e1d 100644 --- a/nix-js/src/value.rs +++ b/nix-js/src/value.rs @@ -70,7 +70,7 @@ impl Symbol { } /// Represents a Nix attribute set, which is a map from symbols to values. -#[derive(Constructor, Clone, PartialEq)] +#[derive(Constructor, Default, Clone, PartialEq)] pub struct AttrSet { data: BTreeMap, } @@ -118,26 +118,21 @@ impl Debug for AttrSet { impl Display for AttrSet { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { use Value::*; - write!(f, "{{ ")?; - let mut first = true; + write!(f, "{{")?; for (k, v) in self.data.iter() { - if !first { - write!(f, "; ")?; - } - write!(f, "{k} = ")?; + write!(f, " {k} = ")?; match v { - AttrSet(_) => write!(f, "{{ ... }}"), - List(_) => write!(f, "[ ... ]"), - v => write!(f, "{v}"), - }?; - first = false; + List(_) => write!(f, "[ ... ];")?, + AttrSet(_) => write!(f, "{{ ... }};")?, + v => write!(f, "{v};")?, + } } write!(f, " }}") } } /// Represents a Nix list, which is a vector of values. -#[derive(Constructor, Clone, Debug, PartialEq)] +#[derive(Constructor, Default, Clone, Debug, PartialEq)] pub struct List { data: Vec, } diff --git a/nix-js/tests/builtins.rs b/nix-js/tests/builtins.rs index 49723a2..6016762 100644 --- a/nix-js/tests/builtins.rs +++ b/nix-js/tests/builtins.rs @@ -1,6 +1,8 @@ mod utils; -use nix_js::value::{List, Value}; +use std::collections::BTreeMap; + +use nix_js::value::{AttrSet, List, Value}; use utils::eval; #[test] @@ -276,3 +278,40 @@ fn builtins_generic_closure() { Value::Int(1), ); } + +#[test] +fn builtins_function_args() { + assert_eq!( + eval("builtins.functionArgs (x: 1)"), + Value::AttrSet(AttrSet::default()) + ); + assert_eq!( + eval("builtins.functionArgs ({}: 1)"), + Value::AttrSet(AttrSet::default()) + ); + assert_eq!( + eval("builtins.functionArgs ({...}: 1)"), + Value::AttrSet(AttrSet::default()) + ); + assert_eq!( + eval("builtins.functionArgs ({a}: 1)"), + Value::AttrSet(AttrSet::new(BTreeMap::from([( + "a".into(), + Value::Bool(false) + )]))) + ); + assert_eq!( + eval("builtins.functionArgs ({a, b ? 1}: 1)"), + Value::AttrSet(AttrSet::new(BTreeMap::from([ + ("a".into(), Value::Bool(false)), + ("b".into(), Value::Bool(true)) + ]))) + ); + assert_eq!( + eval("builtins.functionArgs ({a, b ? 1, ...}: 1)"), + Value::AttrSet(AttrSet::new(BTreeMap::from([ + ("a".into(), Value::Bool(false)), + ("b".into(), Value::Bool(true)) + ]))) + ); +}