feat: builtins.functionArgs

This commit is contained in:
2026-01-22 23:13:36 +08:00
parent 56a8ba9475
commit ef5d8c3b29
13 changed files with 181 additions and 119 deletions

View File

@@ -4,8 +4,8 @@
import { force } from "../thunk"; import { force } from "../thunk";
import { CatchableError } from "../types"; import { CatchableError } from "../types";
import type { NixBool, NixStrictValue, NixValue } from "../types"; import type { NixAttrs, NixBool, NixStrictValue, NixValue } from "../types";
import { forceList, forceStringValue, forceAttrs, forceFunction } from "../type-assert"; import { forceList, forceAttrs, forceFunction, forceStringValue } from "../type-assert";
import * as context from "./context"; import * as context from "./context";
import { compareValues, op } from "../operators"; import { compareValues, op } from "../operators";
import { isBool, isFloat, isInt, isList, isString, typeOf } from "./type-check"; import { isBool, isFloat, isInt, isList, isString, typeOf } from "./type-check";
@@ -151,8 +151,19 @@ export const flakeRefToString = (attrs: NixValue): never => {
throw new Error("Not implemented: flakeRefToString"); throw new Error("Not implemented: flakeRefToString");
}; };
export const functionArgs = (f: NixValue): never => { export const functionArgs = (f: NixValue): NixAttrs => {
throw new Error("Not implemented: functionArgs"); const func = forceFunction(f);
if (func.args) {
const ret: NixAttrs = {};
for (const key of func.args!.required) {
ret[key] = false;
}
for (const key of func.args!.optional) {
ret[key] = true;
}
return ret;
}
return {};
}; };
const checkComparable = (value: NixStrictValue): void => { const checkComparable = (value: NixStrictValue): void => {

View File

@@ -346,6 +346,7 @@ export const call = (func: NixValue, arg: NixValue, span?: string): NixValue =>
function call_impl(func: NixValue, arg: NixValue): NixValue { function call_impl(func: NixValue, arg: NixValue): NixValue {
const forcedFunc = force(func); const forcedFunc = force(func);
if (typeof forcedFunc === "function") { if (typeof forcedFunc === "function") {
forcedFunc.args?.check(arg);
return forcedFunc(arg); return forcedFunc(arg);
} }
if ( if (
@@ -355,7 +356,7 @@ function call_impl(func: NixValue, arg: NixValue): NixValue {
"__functor" in forcedFunc "__functor" in forcedFunc
) { ) {
const functor = forceFunction(forcedFunc.__functor); const functor = forceFunction(forcedFunc.__functor);
return forceFunction(functor(forcedFunc))(arg); return call(functor(forcedFunc), arg);
} }
throw new Error(`attempt to call something which is not a function but ${typeOf(forcedFunc)}`); throw new Error(`attempt to call something which is not a function but ${typeOf(forcedFunc)}`);
} }

View File

@@ -23,7 +23,7 @@ import { op } from "./operators";
import { builtins, PRIMOP_METADATA } from "./builtins"; import { builtins, PRIMOP_METADATA } from "./builtins";
import { coerceToString, StringCoercionMode } from "./builtins/conversion"; import { coerceToString, StringCoercionMode } from "./builtins/conversion";
import { HAS_CONTEXT } from "./string-context"; import { HAS_CONTEXT } from "./string-context";
import { IS_PATH } from "./types"; import { IS_PATH, mkFunction } from "./types";
import { forceBool } from "./type-assert"; import { forceBool } from "./type-assert";
export type NixRuntime = typeof Nix; export type NixRuntime = typeof Nix;
@@ -52,6 +52,7 @@ export const Nix = {
coerceToString, coerceToString,
concatStringsWithContext, concatStringsWithContext,
StringCoercionMode, StringCoercionMode,
mkFunction,
pushContext, pushContext,
popContext, popContext,

View File

@@ -85,13 +85,6 @@ export const forceString = (value: NixValue): NixString => {
throw new TypeError(`Expected string, got ${typeOf(forced)}`); throw new TypeError(`Expected string, got ${typeOf(forced)}`);
}; };
/**
* Get the plain string value from any NixString
*/
export const nixStringValue = (s: NixString): string => {
return getStringValue(s);
};
/** /**
* Force a value and assert it's a boolean * Force a value and assert it's a boolean
* @throws TypeError if value is not a boolean after forcing * @throws TypeError if value is not a boolean after forcing

View File

@@ -4,6 +4,8 @@
import { IS_THUNK } from "./thunk"; import { IS_THUNK } from "./thunk";
import { type StringWithContext, HAS_CONTEXT, isStringWithContext } from "./string-context"; import { type StringWithContext, HAS_CONTEXT, isStringWithContext } from "./string-context";
import { op } from "./operators";
import { forceAttrs } from "./type-assert";
export { HAS_CONTEXT, isStringWithContext }; export { HAS_CONTEXT, isStringWithContext };
export type { StringWithContext }; export type { StringWithContext };
@@ -30,7 +32,41 @@ export type NixNull = null;
export type NixList = NixValue[]; export type NixList = NixValue[];
// FIXME: reject contextful string // FIXME: reject contextful string
export type NixAttrs = { [key: string]: NixValue }; export type NixAttrs = { [key: string]: NixValue };
export type NixFunction = (arg: NixValue) => NixValue; export type NixFunction = ((arg: NixValue) => NixValue) & { args?: NixArgs };
export class NixArgs {
required: string[];
optional: string[];
allowed: Set<string>;
ellipsis: boolean;
constructor(required: string[], optional: string[], ellipsis: boolean) {
this.required = required;
this.optional = optional;
this.ellipsis = ellipsis;
this.allowed = new Set(required.concat(optional));
}
check(arg: NixValue) {
const attrs = forceAttrs(arg);
for (const key of this.required) {
if (!Object.hasOwn(attrs, key)) {
throw new Error(`Function called without required argument '${key}'`);
}
}
if (!this.ellipsis) {
for (const key in attrs) {
if (!this.allowed.has(key)) {
throw new Error(`Function called with unexpected argument '${key}'`);
}
}
}
}
}
export const mkFunction = (f: (arg: NixValue) => NixValue, required: string[], optional: string[], ellipsis: boolean): NixFunction => {
const func = f as NixFunction;
func.args = new NixArgs(required, optional, ellipsis);
return func
}
/** /**
* Interface for lazy thunk values * Interface for lazy thunk values

View File

@@ -23,7 +23,10 @@ pub(crate) fn compile(expr: &Ir, ctx: &impl CodegenContext) -> String {
let cur_dir = ctx.get_current_dir().display().to_string().escape_quote(); let cur_dir = ctx.get_current_dir().display().to_string().escape_quote();
format!( format!(
"(()=>{{{}Nix.builtins.storeDir={};const currentDir={};return {}}})()", "(()=>{{{}Nix.builtins.storeDir={};const currentDir={};return {}}})()",
debug_prefix, ctx.get_store_dir().escape_quote(), cur_dir, code debug_prefix,
ctx.get_store_dir().escape_quote(),
cur_dir,
code
) )
} }
@@ -196,9 +199,18 @@ impl<Ctx: CodegenContext> Compile<Ctx> for BinOp {
Leq => with_ctx("<=", format!("Nix.op.lte({},{})", lhs, rhs)), Leq => with_ctx("<=", format!("Nix.op.lte({},{})", lhs, rhs)),
Geq => with_ctx(">=", format!("Nix.op.gte({},{})", lhs, rhs)), Geq => with_ctx(">=", format!("Nix.op.gte({},{})", lhs, rhs)),
// Short-circuit operators: use JavaScript native && and || // Short-circuit operators: use JavaScript native && and ||
And => with_ctx("&&", format!("Nix.forceBool({})&&Nix.forceBool({})", lhs, rhs)), And => with_ctx(
Or => with_ctx("||", format!("Nix.forceBool({})||Nix.forceBool({})", lhs, rhs)), "&&",
Impl => with_ctx("->", format!("(!Nix.forceBool({})||Nix.forceBool({}))", lhs, rhs)), format!("Nix.forceBool({})&&Nix.forceBool({})", lhs, rhs),
),
Or => with_ctx(
"||",
format!("Nix.forceBool({})||Nix.forceBool({})", lhs, rhs),
),
Impl => with_ctx(
"->",
format!("(!Nix.forceBool({})||Nix.forceBool({}))", lhs, rhs),
),
Con => with_ctx("++", format!("Nix.op.concat({},{})", lhs, rhs)), Con => with_ctx("++", format!("Nix.op.concat({},{})", lhs, rhs)),
Upd => with_ctx("//", format!("Nix.op.update({},{})", lhs, rhs)), Upd => with_ctx("//", format!("Nix.op.update({},{})", lhs, rhs)),
PipeL => format!("Nix.call({},{})", rhs, lhs), PipeL => format!("Nix.call({},{})", rhs, lhs),
@@ -223,56 +235,23 @@ impl<Ctx: CodegenContext> Compile<Ctx> for Func {
let id = ctx.get_ir(self.arg).as_ref().unwrap_arg().inner.0; let id = ctx.get_ir(self.arg).as_ref().unwrap_arg().inner.0;
let body = ctx.get_ir(self.body).compile(ctx); let body = ctx.get_ir(self.body).compile(ctx);
// Generate parameter validation code if let Some(Param {
let param_check = self.generate_param_check(ctx); required,
optional,
if param_check.is_empty() { ellipsis,
// Simple function without parameter validation }) = &self.param
{
let mut required = required.iter().map(|&sym| ctx.get_sym(sym).escape_quote());
let required = format!("[{}]", required.join(","));
let mut optional = optional.iter().map(|&sym| ctx.get_sym(sym).escape_quote());
let optional = format!("[{}]", optional.join(","));
format!("Nix.mkFunction(arg{id}=>({body}),{required},{optional},{ellipsis})")
} else {
format!("arg{id}=>({body})") format!("arg{id}=>({body})")
} else {
// Function with parameter validation (use block statement, not object literal)
format!("arg{id}=>{{{}return {}}}", param_check, body)
} }
} }
} }
impl Func {
fn generate_param_check<Ctx: CodegenContext>(&self, ctx: &Ctx) -> String {
let has_checks = self.param.required.is_some() || self.param.allowed.is_some();
if !has_checks {
return String::new();
}
let id = ctx.get_ir(self.arg).as_ref().unwrap_arg().inner.0;
// Build required parameter array
let required = if let Some(req) = &self.param.required {
let keys: Vec<_> = req
.iter()
.map(|&sym| ctx.get_sym(sym).escape_quote())
.collect();
format!("[{}]", keys.join(","))
} else {
"null".to_string()
};
// Build allowed parameter array
let allowed = if let Some(allow) = &self.param.allowed {
let keys: Vec<_> = allow
.iter()
.map(|&sym| ctx.get_sym(sym).escape_quote())
.collect();
format!("[{}]", keys.join(","))
} else {
"null".to_string()
};
// Call Nix.validateParams and store the result
format!("Nix.validateParams(arg{},{},{});", id, required, allowed)
}
}
impl<Ctx: CodegenContext> Compile<Ctx> for Call { impl<Ctx: CodegenContext> Compile<Ctx> for Call {
fn compile(&self, ctx: &Ctx) -> String { fn compile(&self, ctx: &Ctx) -> String {
let func = ctx.get_ir(self.func).compile(ctx); let func = ctx.get_ir(self.func).compile(ctx);

View File

@@ -73,10 +73,7 @@ impl Context {
let ctx = Ctx::new()?; let ctx = Ctx::new()?;
let runtime = Runtime::new()?; let runtime = Runtime::new()?;
Ok(Self { Ok(Self { ctx, runtime })
ctx,
runtime,
})
} }
pub fn eval_code(&mut self, source: Source) -> Result<Value> { pub fn eval_code(&mut self, source: Source) -> Result<Value> {
@@ -85,7 +82,10 @@ impl Context {
tracing::debug!("Compiling code"); tracing::debug!("Compiling code");
let code = self.compile_code(source)?; let code = self.compile_code(source)?;
self.runtime.op_state().borrow_mut().put(self.ctx.store.clone()); self.runtime
.op_state()
.borrow_mut()
.put(self.ctx.store.clone());
tracing::debug!("Executing JavaScript"); tracing::debug!("Executing JavaScript");
self.runtime self.runtime

View File

@@ -68,7 +68,10 @@ impl Source {
use SourceType::*; use SourceType::*;
match &self.ty { match &self.ty {
Eval(dir) | Repl(dir) => dir.as_ref(), Eval(dir) | Repl(dir) => dir.as_ref(),
File(file) => file.as_path().parent().expect("source file must have a parent dir"), File(file) => file
.as_path()
.parent()
.expect("source file must have a parent dir"),
} }
} }
} }
@@ -233,7 +236,11 @@ pub(crate) fn parse_nix_stack(stack: &str, ctx: &impl RuntimeContext) -> Vec<Nix
} }
}; };
frames.push(NixStackFrame { span, message, source }); frames.push(NixStackFrame {
span,
message,
source,
});
} }
// Deduplicate consecutive identical frames // Deduplicate consecutive identical frames

View File

@@ -69,7 +69,7 @@ ir! {
Assert { pub assertion: ExprId, pub expr: ExprId, pub assertion_raw: String }, Assert { pub assertion: ExprId, pub expr: ExprId, pub assertion_raw: String },
ConcatStrings { pub parts: Vec<ExprId> }, ConcatStrings { pub parts: Vec<ExprId> },
Path { pub expr: ExprId }, Path { pub expr: ExprId },
Func { pub body: ExprId, pub param: Param, pub arg: ExprId }, Func { pub body: ExprId, pub param: Option<Param>, pub arg: ExprId },
Let { pub binding_sccs: SccInfo, pub body: ExprId }, Let { pub binding_sccs: SccInfo, pub body: ExprId },
Arg(ArgId), Arg(ArgId),
ExprRef(ExprId), ExprRef(ExprId),
@@ -296,9 +296,7 @@ impl From<ast::UnaryOpKind> for UnOpKind {
/// Describes the parameters of a function. /// Describes the parameters of a function.
#[derive(Debug)] #[derive(Debug)]
pub struct Param { pub struct Param {
/// The set of required parameter names for a pattern-matching function. pub required: Vec<SymId>,
pub required: Option<Vec<SymId>>, pub optional: Vec<SymId>,
/// The set of all allowed parameter names for a non-ellipsis pattern-matching function. pub ellipsis: bool,
/// If `None`, any attribute is allowed (ellipsis `...` is present).
pub allowed: Option<HashSet<SymId>>,
} }

View File

@@ -363,20 +363,18 @@ impl<Ctx: DowngradeContext> Downgrade<Ctx> for ast::With {
/// This involves desugaring pattern-matching arguments into `let` bindings. /// This involves desugaring pattern-matching arguments into `let` bindings.
impl<Ctx: DowngradeContext> Downgrade<Ctx> for ast::Lambda { impl<Ctx: DowngradeContext> Downgrade<Ctx> for ast::Lambda {
fn downgrade(self, ctx: &mut Ctx) -> Result<ExprId> { fn downgrade(self, ctx: &mut Ctx) -> Result<ExprId> {
let param = self.param().unwrap(); let raw_param = self.param().unwrap();
let arg = ctx.new_arg(param.syntax().text_range()); let arg = ctx.new_arg(raw_param.syntax().text_range());
let required; let param;
let allowed;
let body; let body;
let span = self.body().unwrap().syntax().text_range(); let span = self.body().unwrap().syntax().text_range();
match param { match raw_param {
ast::Param::IdentParam(id) => { ast::Param::IdentParam(id) => {
// Simple case: `x: body` // Simple case: `x: body`
let param_sym = ctx.new_sym(id.to_string()); let param_sym = ctx.new_sym(id.to_string());
required = None; param = None;
allowed = None;
// Downgrade body in Param scope // Downgrade body in Param scope
body = ctx body = ctx
@@ -387,25 +385,28 @@ impl<Ctx: DowngradeContext> Downgrade<Ctx> for ast::Lambda {
.pat_bind() .pat_bind()
.map(|alias| ctx.new_sym(alias.ident().unwrap().to_string())); .map(|alias| ctx.new_sym(alias.ident().unwrap().to_string()));
let has_ellipsis = pattern.ellipsis_token().is_some(); let ellipsis = pattern.ellipsis_token().is_some();
let pat_entries = pattern.pat_entries(); let pat_entries = pattern.pat_entries();
let PatternBindings { let PatternBindings {
body: inner_body, body: inner_body,
scc_info, scc_info,
required_params, required,
allowed_params, optional,
} = downgrade_pattern_bindings( } = downgrade_pattern_bindings(
pat_entries, pat_entries,
alias, alias,
arg, arg,
has_ellipsis, ellipsis,
ctx, ctx,
|ctx, _| self.body().unwrap().downgrade(ctx), |ctx, _| self.body().unwrap().downgrade(ctx),
)?; )?;
required = Some(required_params); param = Some(Param {
allowed = allowed_params; required,
optional,
ellipsis,
});
body = ctx.new_expr( body = ctx.new_expr(
Let { Let {
@@ -418,7 +419,6 @@ impl<Ctx: DowngradeContext> Downgrade<Ctx> for ast::Lambda {
} }
} }
let param = Param { required, allowed };
let span = self.syntax().text_range(); let span = self.syntax().text_range();
// The function's body and parameters are now stored directly in the `Func` node. // The function's body and parameters are now stored directly in the `Func` node.
Ok(ctx.new_expr( Ok(ctx.new_expr(

View File

@@ -3,6 +3,7 @@
use hashbrown::hash_map::Entry; use hashbrown::hash_map::Entry;
use hashbrown::{HashMap, HashSet}; use hashbrown::{HashMap, HashSet};
use itertools::Itertools as _;
use rnix::ast; use rnix::ast;
use rowan::ast::AstNode; use rowan::ast::AstNode;
@@ -257,8 +258,8 @@ pub fn downgrade_static_attrpathvalue(
pub struct PatternBindings { pub struct PatternBindings {
pub body: ExprId, pub body: ExprId,
pub scc_info: SccInfo, pub scc_info: SccInfo,
pub required_params: Vec<SymId>, pub required: Vec<SymId>,
pub allowed_params: Option<HashSet<SymId>>, pub optional: Vec<SymId>,
} }
/// Helper function for Lambda pattern parameters with SCC analysis. /// Helper function for Lambda pattern parameters with SCC analysis.
@@ -310,17 +311,18 @@ where
binding_keys.push(alias_sym); binding_keys.push(alias_sym);
} }
let required: Vec<SymId> = param_syms let (required, optional) =
param_syms
.iter() .iter()
.zip(param_defaults.iter()) .zip(param_defaults.iter())
.filter_map(|(&sym, default)| if default.is_none() { Some(sym) } else { None }) .partition_map(|(&sym, default)| {
.collect(); use itertools::Either::*;
if default.is_none() {
let allowed: Option<HashSet<SymId>> = if has_ellipsis { Left(sym)
None
} else { } else {
Some(param_syms.iter().copied().collect()) Right(sym)
}; }
});
// Get the owner from outer tracker's current_binding // Get the owner from outer tracker's current_binding
let owner = ctx.get_current_binding(); let owner = ctx.get_current_binding();
@@ -371,8 +373,8 @@ where
Ok(PatternBindings { Ok(PatternBindings {
body, body,
scc_info, scc_info,
required_params: required, required,
allowed_params: allowed, optional,
}) })
} }

View File

@@ -70,7 +70,7 @@ impl Symbol {
} }
/// Represents a Nix attribute set, which is a map from symbols to values. /// Represents a Nix attribute set, which is a map from symbols to values.
#[derive(Constructor, Clone, PartialEq)] #[derive(Constructor, Default, Clone, PartialEq)]
pub struct AttrSet { pub struct AttrSet {
data: BTreeMap<Symbol, Value>, data: BTreeMap<Symbol, Value>,
} }
@@ -119,25 +119,20 @@ impl Display for AttrSet {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
use Value::*; use Value::*;
write!(f, "{{")?; write!(f, "{{")?;
let mut first = true;
for (k, v) in self.data.iter() { for (k, v) in self.data.iter() {
if !first {
write!(f, "; ")?;
}
write!(f, " {k} = ")?; write!(f, " {k} = ")?;
match v { match v {
AttrSet(_) => write!(f, "{{ ... }}"), List(_) => write!(f, "[ ... ];")?,
List(_) => write!(f, "[ ... ]"), AttrSet(_) => write!(f, "{{ ... }};")?,
v => write!(f, "{v}"), v => write!(f, "{v};")?,
}?; }
first = false;
} }
write!(f, " }}") write!(f, " }}")
} }
} }
/// Represents a Nix list, which is a vector of values. /// Represents a Nix list, which is a vector of values.
#[derive(Constructor, Clone, Debug, PartialEq)] #[derive(Constructor, Default, Clone, Debug, PartialEq)]
pub struct List { pub struct List {
data: Vec<Value>, data: Vec<Value>,
} }

View File

@@ -1,6 +1,8 @@
mod utils; mod utils;
use nix_js::value::{List, Value}; use std::collections::BTreeMap;
use nix_js::value::{AttrSet, List, Value};
use utils::eval; use utils::eval;
#[test] #[test]
@@ -276,3 +278,40 @@ fn builtins_generic_closure() {
Value::Int(1), Value::Int(1),
); );
} }
#[test]
fn builtins_function_args() {
assert_eq!(
eval("builtins.functionArgs (x: 1)"),
Value::AttrSet(AttrSet::default())
);
assert_eq!(
eval("builtins.functionArgs ({}: 1)"),
Value::AttrSet(AttrSet::default())
);
assert_eq!(
eval("builtins.functionArgs ({...}: 1)"),
Value::AttrSet(AttrSet::default())
);
assert_eq!(
eval("builtins.functionArgs ({a}: 1)"),
Value::AttrSet(AttrSet::new(BTreeMap::from([(
"a".into(),
Value::Bool(false)
)])))
);
assert_eq!(
eval("builtins.functionArgs ({a, b ? 1}: 1)"),
Value::AttrSet(AttrSet::new(BTreeMap::from([
("a".into(), Value::Bool(false)),
("b".into(), Value::Bool(true))
])))
);
assert_eq!(
eval("builtins.functionArgs ({a, b ? 1, ...}: 1)"),
Value::AttrSet(AttrSet::new(BTreeMap::from([
("a".into(), Value::Bool(false)),
("b".into(), Value::Bool(true))
])))
);
}