diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs index e8e57a695..c7af4978e 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/engine.rs @@ -1,6 +1,6 @@ use crate::frontend::router::parser::aggregate::{Aggregate, AggregateFunction}; use pg_query::protobuf::{ - a_const::Val, AConst, FuncCall, Integer, Node, ParseResult, ResTarget, String as PgString, + a_const::Val, AConst, FuncCall, Integer, Node, ResTarget, SelectStmt, String as PgString, }; use pg_query::NodeEnum; @@ -14,23 +14,11 @@ pub struct AggregatesRewrite; impl AggregatesRewrite { /// Rewrite a SELECT query in-place, adding helper aggregates when necessary. - pub fn rewrite_select(&self, ast: &mut ParseResult, aggregate: &Aggregate) -> RewriteOutput { + pub fn rewrite_select(&self, ast: &mut SelectStmt, aggregate: &Aggregate) -> RewriteOutput { self.rewrite_parsed(ast, aggregate) } - fn rewrite_parsed(&self, parsed: &mut ParseResult, aggregate: &Aggregate) -> RewriteOutput { - let Some(raw_stmt) = parsed.stmts.first() else { - return RewriteOutput::default(); - }; - - let Some(stmt) = raw_stmt.stmt.as_ref() else { - return RewriteOutput::default(); - }; - - let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_ref() else { - return RewriteOutput::default(); - }; - + fn rewrite_parsed(&self, select: &mut SelectStmt, aggregate: &Aggregate) -> RewriteOutput { let mut plan = AggregateRewritePlan::new(); let mut helper_nodes: Vec = Vec::new(); let mut planned_aliases: Vec = Vec::new(); @@ -133,18 +121,6 @@ impl AggregatesRewrite { return RewriteOutput::default(); } - let Some(raw_stmt) = parsed.stmts.first_mut() else { - return RewriteOutput::default(); - }; - - let Some(stmt) = raw_stmt.stmt.as_mut() else { - return RewriteOutput::default(); - }; - - let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_mut() else { - return RewriteOutput::default(); - }; - select.target_list.extend(helper_nodes); RewriteOutput::new(plan) @@ -308,39 +284,38 @@ struct HelperSpec { mod tests { use super::*; use crate::frontend::router::parser::aggregate::Aggregate; + use pg_query::protobuf::ParseResult; - fn select(ast: &ParseResult) -> &pg_query::protobuf::SelectStmt { + fn select(ast: &mut ParseResult) -> &mut pg_query::protobuf::SelectStmt { match ast .stmts - .first() - .and_then(|stmt| stmt.stmt.as_ref()) - .and_then(|stmt| stmt.node.as_ref()) + .first_mut() + .and_then(|stmt| stmt.stmt.as_mut()) + .and_then(|stmt| stmt.node.as_mut()) { - Some(NodeEnum::SelectStmt(select)) => select, + Some(NodeEnum::SelectStmt(select)) => &mut *select, _ => panic!("not a select"), } } fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) { let mut parsed = pg_query::parse(sql).unwrap().protobuf; - let aggregate = { - let stmt = select(&parsed); - Aggregate::parse(stmt) - }; - let output = AggregatesRewrite.rewrite_select(&mut parsed, &aggregate); + let stmt = select(&mut parsed); + let aggregate = Aggregate::parse(stmt); + let output = AggregatesRewrite.rewrite_select(stmt, &aggregate); (parsed, output) } #[test] fn rewrite_engine_noop() { - let (ast, output) = rewrite("SELECT COUNT(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT COUNT(price) FROM menu"); assert!(output.plan.is_noop()); - assert_eq!(select(&ast).target_list.len(), 1); + assert_eq!(select(&mut ast).target_list.len(), 1); } #[test] fn rewrite_engine_adds_helper() { - let (ast, output) = rewrite("SELECT AVG(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT AVG(price) FROM menu"); assert!(!output.plan.is_noop()); assert_eq!(output.plan.drop_columns(), &[1]); assert_eq!(output.plan.helpers().len(), 1); @@ -351,7 +326,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&ast)); + let aggregate = Aggregate::parse(select(&mut ast)); assert_eq!(aggregate.targets().len(), 2); assert!(aggregate .targets() @@ -362,14 +337,14 @@ mod tests { #[test] fn rewrite_engine_skips_when_count_exists() { let sql = "SELECT COUNT(price), AVG(price) FROM menu"; - let (ast, output) = rewrite(sql); + let (mut ast, output) = rewrite(sql); assert!(output.plan.is_noop()); - assert_eq!(select(&ast).target_list.len(), 2); + assert_eq!(select(&mut ast).target_list.len(), 2); } #[test] fn rewrite_engine_handles_mismatched_pair() { - let (ast, output) = rewrite("SELECT COUNT(price::numeric), AVG(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT COUNT(price::numeric), AVG(price) FROM menu"); assert_eq!(output.plan.drop_columns(), &[2]); assert_eq!(output.plan.helpers().len(), 1); let helper = &output.plan.helpers()[0]; @@ -378,7 +353,7 @@ mod tests { assert!(!helper.distinct); assert!(matches!(helper.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&ast)); + let aggregate = Aggregate::parse(select(&mut ast)); assert_eq!(aggregate.targets().len(), 3); assert!( aggregate @@ -392,7 +367,7 @@ mod tests { #[test] fn rewrite_engine_multiple_avg_helpers() { - let (ast, output) = rewrite("SELECT AVG(price), AVG(discount) FROM menu"); + let (mut ast, output) = rewrite("SELECT AVG(price), AVG(discount) FROM menu"); assert_eq!(output.plan.drop_columns(), &[2, 3]); assert_eq!(output.plan.helpers().len(), 2); @@ -406,7 +381,7 @@ mod tests { assert_eq!(helper_discount.helper_column, 3); assert!(matches!(helper_discount.kind, HelperKind::Count)); - let aggregate = Aggregate::parse(select(&ast)); + let aggregate = Aggregate::parse(select(&mut ast)); assert_eq!(aggregate.targets().len(), 4); assert_eq!( aggregate @@ -420,7 +395,7 @@ mod tests { #[test] fn rewrite_engine_stddev_helpers() { - let (ast, output) = rewrite("SELECT STDDEV(price) FROM menu"); + let (mut ast, output) = rewrite("SELECT STDDEV(price) FROM menu"); assert!(!output.plan.is_noop()); assert_eq!(output.plan.drop_columns(), &[1, 2, 3]); assert_eq!(output.plan.helpers().len(), 3); @@ -440,6 +415,6 @@ mod tests { assert!(kinds.contains(&HelperKind::SumSquares)); // Expect original STDDEV plus three helpers. - assert_eq!(select(&ast).target_list.len(), 4); + assert_eq!(select(&mut ast).target_list.len(), 4); } } diff --git a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs index 0676031f0..0dca2c6ee 100644 --- a/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs +++ b/pgdog/src/frontend/router/parser/rewrite/statement/aggregate/mod.rs @@ -15,24 +15,24 @@ impl StatementRewrite<'_> { return Ok(()); } - let Some(raw_stmt) = self.stmt.stmts.first() else { + let Some(raw_stmt) = self.stmt.stmts.first_mut() else { return Ok(()); }; - let Some(stmt) = raw_stmt.stmt.as_ref() else { + let Some(stmt) = raw_stmt.stmt.as_mut() else { return Ok(()); }; - let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_ref() else { + let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_mut() else { return Ok(()); }; - let aggregate = Aggregate::parse(select); + let aggregate = Aggregate::parse(&select); if aggregate.is_empty() { return Ok(()); } - let output = AggregatesRewrite.rewrite_select(self.stmt, &aggregate); + let output = AggregatesRewrite.rewrite_select(select, &aggregate); if output.plan.is_noop() { return Ok(()); }