From ac924d91a3460fb4302a5c7214a8049b2ecf8a55 Mon Sep 17 00:00:00 2001 From: Sean Griffin Date: Tue, 26 May 2026 13:11:47 -0600 Subject: [PATCH] Receive a select statement instead of repeatedly checking we got one We have code taking a full top level ParseResult when it only operates on select statements. Nothing in the code needs the top level struct, we can just take the thing we need directly instead of repeatedly checking if that's what we got. --- .../rewrite/statement/aggregate/engine.rs | 73 ++++++------------- .../parser/rewrite/statement/aggregate/mod.rs | 10 +-- 2 files changed, 29 insertions(+), 54 deletions(-) diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs index e8e57a695..c7af4978e 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs @@ -1,6 +1,6 @@ use crate::frontend::router::parser::aggregate::{Aggregate, AggregateFunction}; use pg_query::protobuf::{ - a_const::Val, AConst, FuncCall, Integer, Node, ParseResult, ResTarget, String as PgString, + a_const::Val, AConst, FuncCall, Integer, Node, ResTarget, SelectStmt, String as PgString, }; use pg_query::NodeEnum; @@ -14,23 +14,11 @@ pub struct AggregatesRewrite; impl AggregatesRewrite { /// Rewrite a SELECT query in-place, adding helper aggregates when necessary. - pub fn rewrite_select(&self, ast: &mut ParseResult, aggregate: &Aggregate) -> RewriteOutput { + pub fn rewrite_select(&self, ast: &mut SelectStmt, aggregate: &Aggregate) -> RewriteOutput { self.rewrite_parsed(ast, aggregate) } - fn rewrite_parsed(&self, parsed: &mut ParseResult, aggregate: &Aggregate) -> RewriteOutput { - let Some(raw_stmt) = parsed.stmts.first() else { - return RewriteOutput::default(); - }; - - let Some(stmt) = raw_stmt.stmt.as_ref() else { - return RewriteOutput::default(); - }; - - let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_ref() else { - return RewriteOutput::default(); - }; - + fn rewrite_parsed(&self, select: &mut SelectStmt, aggregate: &Aggregate) -> RewriteOutput { let mut plan = AggregateRewritePlan::new(); let mut helper_nodes: Vec = Vec::new(); let mut planned_aliases: Vec = Vec::new(); @@ -133,18 +121,6 @@ impl AggregatesRewrite { return RewriteOutput::default(); } - let Some(raw_stmt) = parsed.stmts.first_mut() else { - return RewriteOutput::default(); - }; - - let Some(stmt) = raw_stmt.stmt.as_mut() else { - return RewriteOutput::default(); - }; - - let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_mut() else { - return RewriteOutput::default(); - }; - select.target_list.extend(helper_nodes); RewriteOutput::new(plan) @@ -308,39 +284,38 @@ struct HelperSpec { mod tests { use super::*; use crate::frontend::router::parser::aggregate::Aggregate; + use pg_query::protobuf::ParseResult; - fn select(ast: &ParseResult) -> &pg_query::protobuf::SelectStmt { + fn select(ast: &mut ParseResult) -> &mut pg_query::protobuf::SelectStmt { match ast .stmts - .first() - .and_then(|stmt| stmt.stmt.as_ref()) - .and_then(|stmt| stmt.node.as_ref()) + .first_mut() + .and_then(|stmt| stmt.stmt.as_mut()) + .and_then(|stmt| stmt.node.as_mut()) { - Some(NodeEnum::SelectStmt(select)) => select, + Some(NodeEnum::SelectStmt(select)) => &mut *select, _ => panic!("not a select"), } } fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) { let mut parsed = pg_query::parse(sql).unwrap().protobuf; - let aggregate = { - let stmt = select(&parsed); - Aggregate::parse(stmt) - }; - let output = AggregatesRewrite.rewrite_select(&mut parsed, &aggregate); + let stmt = select(&mut parsed); + let aggregate = Aggregate::parse(stmt); + let output = AggregatesRewrite.rewrite_select(stmt, &aggregate); (parsed, output) } #[test] fn rewrite_engine_noop() { - let (ast, output) = rewrite("SELECT COUNT(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT COUNT(price) FROM menu"); assert!(output.plan.is_noop()); - assert_eq!(select(&ast).target_list.len(), 1); + assert_eq!(select(&mut ast).target_list.len(), 1); } #[test] fn rewrite_engine_adds_helper() { - let (ast, output) = rewrite("SELECT AVG(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT AVG(price) FROM menu"); assert!(!output.plan.is_noop()); assert_eq!(output.plan.drop_columns(), &[1]); assert_eq!(output.plan.helpers().len(), 1); @@ -351,7 +326,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&ast)); + let aggregate = Aggregate::parse(select(&mut ast)); assert_eq!(aggregate.targets().len(), 2); assert!(aggregate .targets() @@ -362,14 +337,14 @@ mod tests { #[test] fn rewrite_engine_skips_when_count_exists() { let sql = "SELECT COUNT(price), AVG(price) FROM menu"; - let (ast, output) = rewrite(sql); + let (mut ast, output) = rewrite(sql); assert!(output.plan.is_noop()); - assert_eq!(select(&ast).target_list.len(), 2); + assert_eq!(select(&mut ast).target_list.len(), 2); } #[test] fn rewrite_engine_handles_mismatched_pair() { - let (ast, output) = rewrite("SELECT COUNT(price::numeric), AVG(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT COUNT(price::numeric), AVG(price) FROM menu"); assert_eq!(output.plan.drop_columns(), &[2]); assert_eq!(output.plan.helpers().len(), 1); let helper = &output.plan.helpers()[0]; @@ -378,7 +353,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&ast)); + let aggregate = Aggregate::parse(select(&mut ast)); assert_eq!(aggregate.targets().len(), 3); assert!( aggregate @@ -392,7 +367,7 @@ mod tests { #[test] fn rewrite_engine_multiple_avg_helpers() { - let (ast, output) = rewrite("SELECT AVG(price), AVG(discount) FROM menu"); + let (mut ast, output) = rewrite("SELECT AVG(price), AVG(discount) FROM menu"); assert_eq!(output.plan.drop_columns(), &[2, 3]); assert_eq!(output.plan.helpers().len(), 2); @@ -406,7 +381,7 @@ mod tests { assert_eq!(helper_discount.helper_column, 3); assert!(matches!(helper_discount.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&ast)); + let aggregate = Aggregate::parse(select(&mut ast)); assert_eq!(aggregate.targets().len(), 4); assert_eq!( aggregate @@ -420,7 +395,7 @@ mod tests { #[test] fn rewrite_engine_stddev_helpers() { - let (ast, output) = rewrite("SELECT STDDEV(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT STDDEV(price) FROM menu"); assert!(!output.plan.is_noop()); assert_eq!(output.plan.drop_columns(), &[1, 2, 3]); assert_eq!(output.plan.helpers().len(), 3); @@ -440,6 +415,6 @@ mod tests { assert!(kinds.contains(&HelperKind::SumSquares)); // Expect original STDDEV plus three helpers. - assert_eq!(select(&ast).target_list.len(), 4); + assert_eq!(select(&mut ast).target_list.len(), 4); } } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index 0676031f0..0dca2c6ee 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -15,24 +15,24 @@ impl StatementRewrite<'_> { return Ok(()); } - let Some(raw_stmt) = self.stmt.stmts.first() else { + let Some(raw_stmt) = self.stmt.stmts.first_mut() else { return Ok(()); }; - let Some(stmt) = raw_stmt.stmt.as_ref() else { + let Some(stmt) = raw_stmt.stmt.as_mut() else { return Ok(()); }; - let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_ref() else { + let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_mut() else { return Ok(()); }; - let aggregate = Aggregate::parse(select); + let aggregate = Aggregate::parse(&select); if aggregate.is_empty() { return Ok(()); } - let output = AggregatesRewrite.rewrite_select(self.stmt, &aggregate); + let output = AggregatesRewrite.rewrite_select(select, &aggregate); if output.plan.is_noop() { return Ok(()); }