Skip to content

Commit f18cdcf

Browse files
authored
Merge pull request #21848 from asgerf/asgerf/swift-yeast
Unified: Add schema checking and corpus-style tests
2 parents 491c373 + 554bdf1 commit f18cdcf

36 files changed

Lines changed: 5445 additions & 4708 deletions

shared/tree-sitter-extractor/src/extractor/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ pub fn extract(
330330

331331
if let Some(yeast_runner) = yeast_runner {
332332
let ast = yeast_runner
333-
.run_from_tree(&tree)
333+
.run_from_tree(&tree, source)
334334
.unwrap_or_else(|e| panic!("Desugaring failed for {path_str}: {e}"));
335335
traverse_yeast(&ast, &mut visitor);
336336
} else {

shared/tree-sitter-extractor/src/generator/mod.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,19 @@ pub fn generate(
115115
&node_parent_table_name,
116116
)),
117117
ql::TopLevel::Class(ql_gen::create_token_class(&token_name, &tokeninfo_name)),
118-
ql::TopLevel::Class(ql_gen::create_reserved_word_class(&reserved_word_name)),
119118
];
119+
// Only emit the ReservedWord class when there are actually unnamed token
120+
// types in the schema (i.e., @{prefix}_reserved_word exists in the dbscheme).
121+
// When converting from a YEAST YAML schema that has no unnamed tokens, this
122+
// type is absent and referencing it would cause a QL compilation error.
123+
let has_reserved_words = nodes
124+
.values()
125+
.any(|n| n.dbscheme_name == reserved_word_name);
126+
if has_reserved_words {
127+
body.push(ql::TopLevel::Class(ql_gen::create_reserved_word_class(
128+
&reserved_word_name,
129+
)));
130+
}
120131

121132
// Overlay discard predicates
122133
body.push(ql::TopLevel::Predicate(

shared/yeast-macros/src/parse.rs

Lines changed: 104 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,65 @@ fn parse_query_node_inner(tokens: &mut Tokens) -> Result<TokenStream> {
113113
/// appear in any order; bare patterns are accumulated and emitted as a
114114
/// single `("child", ...)` entry.
115115
fn parse_query_fields(tokens: &mut Tokens) -> Result<Vec<TokenStream>> {
116-
let mut fields = Vec::new();
116+
// Accumulate per-field elems in declaration order; multiple uses of the
117+
// same field name extend the same list (so e.g. `cond: (foo) cond: (bar)`
118+
// matches a `cond` field whose first child is `foo` and second is `bar`).
119+
let mut field_order: Vec<String> = Vec::new();
120+
let mut field_elems: std::collections::HashMap<String, Vec<TokenStream>> =
121+
std::collections::HashMap::new();
117122
let mut bare_children: Vec<TokenStream> = Vec::new();
123+
let push_field_elem = |order: &mut Vec<String>,
124+
map: &mut std::collections::HashMap<String, Vec<TokenStream>>,
125+
name: String,
126+
elem: TokenStream| {
127+
if !map.contains_key(&name) {
128+
order.push(name.clone());
129+
map.insert(name, vec![elem]);
130+
} else {
131+
map.get_mut(&name).unwrap().push(elem);
132+
}
133+
};
118134
while tokens.peek().is_some() {
119135
if peek_is_field(tokens) {
120136
let field_name = expect_ident(tokens, "expected field name")?;
121137
let field_str = field_name.to_string();
122138

123139
expect_punct(tokens, ':', "expected `:` after field name")?;
124140

125-
let child = parse_query_node(tokens)?;
126-
fields.push(quote! {
127-
(#field_str, vec![yeast::query::QueryListElem::SingleNode(#child)])
128-
});
141+
// Parse the field's pattern. To support repetition like
142+
// `field: (kind)* @cap`, parse the atom first, then check for
143+
// a quantifier, and lastly handle a trailing `@capture`.
144+
let atom = parse_query_atom(tokens)?;
145+
if peek_is_repetition(tokens) {
146+
let rep = expect_repetition(tokens)?;
147+
let elem = quote! {
148+
yeast::query::QueryListElem::Repeated {
149+
children: vec![yeast::query::QueryListElem::SingleNode(#atom)],
150+
rep: #rep,
151+
}
152+
};
153+
let elem = maybe_wrap_list_capture(tokens, elem)?;
154+
push_field_elem(&mut field_order, &mut field_elems, field_str, elem);
155+
} else {
156+
let child = if peek_is_at(tokens) {
157+
tokens.next();
158+
let capture_name =
159+
expect_ident(tokens, "expected capture name after @")?;
160+
let name_str = capture_name.to_string();
161+
quote! {
162+
yeast::query::QueryNode::Capture {
163+
capture: #name_str,
164+
node: Box::new(#atom),
165+
}
166+
}
167+
} else {
168+
atom
169+
};
170+
let elem = quote! {
171+
yeast::query::QueryListElem::SingleNode(#child)
172+
};
173+
push_field_elem(&mut field_order, &mut field_elems, field_str, elem);
174+
}
129175
} else {
130176
// Bare patterns — accumulate into the implicit `child` field.
131177
// We don't break here, so we can interleave with named fields.
@@ -137,6 +183,13 @@ fn parse_query_fields(tokens: &mut Tokens) -> Result<Vec<TokenStream>> {
137183
bare_children.extend(elems);
138184
}
139185
}
186+
let mut fields: Vec<TokenStream> = Vec::new();
187+
for name in field_order {
188+
let elems = field_elems.remove(&name).unwrap();
189+
fields.push(quote! {
190+
(#name, vec![#(#elems),*])
191+
});
192+
}
140193
if !bare_children.is_empty() {
141194
fields.push(quote! {
142195
("child", vec![#(#bare_children),*])
@@ -299,7 +352,7 @@ fn parse_direct_node(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStream> {
299352
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Brace => {
300353
let group = expect_group(tokens, Delimiter::Brace)?;
301354
let expr = group.stream();
302-
Ok(quote! { #expr })
355+
Ok(quote! { ::std::convert::Into::<usize>::into(#expr) })
303356
}
304357
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => {
305358
let group = expect_group(tokens, Delimiter::Parenthesis)?;
@@ -329,12 +382,17 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
329382
return Ok(quote! { #ctx.literal(#kind_str, #lit) });
330383
}
331384

332-
// Check for (kind #{expr}) — computed literal, expr converted via .to_string()
385+
// Check for (kind #{expr}) — computed literal, expr converted via YeastDisplay
333386
if peek_is_hash(tokens) {
334387
tokens.next(); // consume #
335388
let group = expect_group(tokens, Delimiter::Brace)?;
336389
let expr = group.stream();
337-
return Ok(quote! { #ctx.literal(#kind_str, &(#expr).to_string()) });
390+
return Ok(quote! {
391+
{
392+
let __value = yeast::YeastDisplay::yeast_to_string(&(#expr), &*#ctx.ast);
393+
#ctx.literal(#kind_str, &__value)
394+
}
395+
});
338396
}
339397

340398
// Check for (kind $fresh)
@@ -374,15 +432,19 @@ fn parse_direct_node_inner(tokens: &mut Tokens, ctx: &Ident) -> Result<TokenStre
374432
inner.next(); // consume first .
375433
inner.next(); // consume second .
376434
let expr: proc_macro2::TokenStream = inner.collect();
377-
stmts.push(quote! { let #temp: Vec<usize> = #expr; });
435+
stmts.push(quote! {
436+
let #temp: Vec<usize> = (#expr).into_iter()
437+
.map(::std::convert::Into::<usize>::into)
438+
.collect();
439+
});
378440
field_args.push(quote! { (#field_str, #temp) });
379441
continue;
380442
}
381443
}
382444
}
383445

384446
let value = parse_direct_node(tokens, ctx)?;
385-
stmts.push(quote! { let #temp = #value; });
447+
stmts.push(quote! { let #temp: usize = #value; });
386448
field_args.push(quote! { (#field_str, vec![#temp]) });
387449
}
388450

@@ -427,10 +489,16 @@ fn parse_direct_list(tokens: &mut Tokens, ctx: &Ident) -> Result<Vec<TokenStream
427489
inner.next(); // consume first .
428490
inner.next(); // consume second .
429491
let expr: TokenStream = inner.collect();
430-
items.push(quote! { __nodes.extend(#expr); });
492+
items.push(quote! {
493+
__nodes.extend(
494+
(#expr).into_iter().map(::std::convert::Into::<usize>::into)
495+
);
496+
});
431497
} else {
432498
let expr = group.stream();
433-
items.push(quote! { __nodes.push(#expr); });
499+
items.push(quote! {
500+
__nodes.push(::std::convert::Into::<usize>::into(#expr));
501+
});
434502
}
435503
continue;
436504
}
@@ -580,13 +648,24 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
580648
let name_str = &cap.name;
581649
match cap.multiplicity {
582650
CaptureMultiplicity::Repeated => {
583-
quote! { let #name: Vec<usize> = __captures.get_all(#name_str); }
651+
quote! {
652+
let #name: Vec<yeast::NodeRef> = __captures.get_all(#name_str)
653+
.into_iter()
654+
.map(yeast::NodeRef)
655+
.collect();
656+
}
584657
}
585658
CaptureMultiplicity::Optional => {
586-
quote! { let #name: Option<usize> = __captures.get_opt(#name_str); }
659+
quote! {
660+
let #name: Option<yeast::NodeRef> =
661+
__captures.get_opt(#name_str).map(yeast::NodeRef);
662+
}
587663
}
588664
CaptureMultiplicity::Single => {
589-
quote! { let #name: usize = __captures.get_var(#name_str).unwrap(); }
665+
quote! {
666+
let #name: yeast::NodeRef =
667+
yeast::NodeRef(__captures.get_var(#name_str).unwrap());
668+
}
590669
}
591670
}
592671
})
@@ -613,19 +692,26 @@ pub fn parse_rule_top(input: TokenStream) -> Result<TokenStream> {
613692
CaptureMultiplicity::Repeated => quote! {
614693
let __field_id = #ctx_ident.ast.field_id_for_name(#name_str)
615694
.unwrap_or_else(|| panic!("field '{}' not found", #name_str));
616-
__fields.insert(__field_id, #name);
695+
__fields.insert(
696+
__field_id,
697+
#name.into_iter()
698+
.map(::std::convert::Into::<usize>::into)
699+
.collect(),
700+
);
617701
},
618702
CaptureMultiplicity::Optional => quote! {
619703
let __field_id = #ctx_ident.ast.field_id_for_name(#name_str)
620704
.unwrap_or_else(|| panic!("field '{}' not found", #name_str));
621705
if let Some(__id) = #name {
622-
__fields.entry(__field_id).or_insert_with(Vec::new).push(__id);
706+
__fields.entry(__field_id).or_insert_with(Vec::new)
707+
.push(::std::convert::Into::<usize>::into(__id));
623708
}
624709
},
625710
CaptureMultiplicity::Single => quote! {
626711
let __field_id = #ctx_ident.ast.field_id_for_name(#name_str)
627712
.unwrap_or_else(|| panic!("field '{}' not found", #name_str));
628-
__fields.entry(__field_id).or_insert_with(Vec::new).push(#name);
713+
__fields.entry(__field_id).or_insert_with(Vec::new)
714+
.push(::std::convert::Into::<usize>::into(#name));
629715
},
630716
}
631717
})

shared/yeast/doc/yeast.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,8 @@ to enable rewriting:
349349

350350
```rust
351351
let desugar = yeast::DesugaringConfig::new()
352-
.add_phase("cleanup", cleanup_rules())
353-
.add_phase("desugar", desugar_rules())
352+
.add_phase("cleanup", yeast::PhaseKind::Repeating, cleanup_rules())
353+
.add_phase("translate", yeast::PhaseKind::OneShot, translate_rules())
354354
.with_output_node_types_yaml(include_str!("output-node-types.yml"));
355355

356356
let lang = simple::LanguageSpec {
@@ -365,6 +365,15 @@ let lang = simple::LanguageSpec {
365365
A single-phase config is just `.add_phase(...)` called once. Phase names
366366
appear in error messages so you can tell which phase failed.
367367

368+
There are two kinds of phases:
369+
- **Repeating**:
370+
Each node is re-processed until none of the rules in the phase matches.
371+
When a node no longer matches any rules, its children are recursively processed. In practice this is used to desugar or simplify an AST, while staying mostly within the same schema.
372+
- **One-shot**:
373+
Each node is processed by the first matching rule, and the engine panics if no rule matches.
374+
Rules are then recursively applied to every captured node.
375+
In practice this is used when translating from one AST schema to another, where an exhaustive match is required.
376+
368377
The same YAML node-types is used for both the runtime yeast `Schema` (so
369378
rules can refer to output-only kinds and fields) and TRAP validation (it
370379
is converted to JSON internally).

shared/yeast/src/captures.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ impl Captures {
6161
}
6262
}
6363
}
64+
65+
/// Apply a fallible function to every captured id (across all keys),
66+
/// replacing each id with the result. Stops and returns the error on
67+
/// the first failure.
68+
pub fn try_map_all_captures<E>(
69+
&mut self,
70+
mut f: impl FnMut(Id) -> Result<Id, E>,
71+
) -> Result<(), E> {
72+
for ids in self.captures.values_mut() {
73+
for id in ids {
74+
*id = f(*id)?;
75+
}
76+
}
77+
Ok(())
78+
}
6479
pub fn map_captures_to(&mut self, from: &str, to: &'static str, f: &mut impl FnMut(Id) -> Id) {
6580
if let Some(from_ids) = self.captures.get(from) {
6681
let new_values = from_ids.iter().copied().map(f).collect();

0 commit comments

Comments
 (0)