diff --git a/src/backend/optimizer/util/var.c b/src/backend/optimizer/util/var.c
index 66f5b598f0c..2a792e3223a 100644
--- a/src/backend/optimizer/util/var.c
+++ b/src/backend/optimizer/util/var.c
@@ -776,14 +776,6 @@ pull_var_clause_walker(Node *node, pull_var_clause_context *context)
  * PlaceHolderVar or constructed from those, we can just add the
  * varnullingrels bits to the existing nullingrels field(s); otherwise
  * we have to add a PlaceHolderVar wrapper.
- *
- * NOTE: this is also used by the parser, to expand join alias Vars before
- * checking GROUP BY validity. For that use-case, root will be NULL, which
- * is why we have to pass the Query separately. We need the root itself only
- * for making PlaceHolderVars. We can avoid making PlaceHolderVars in the
- * parser's usage because it won't be dealing with arbitrary expressions:
- * so long as adjust_standard_join_alias_expression can handle everything
- * the parser would make as a join alias expression, we're OK.
  */
 Node *
 flatten_join_alias_vars(PlannerInfo *root, Query *query, Node *node)
@@ -808,6 +800,44 @@ flatten_join_alias_vars(PlannerInfo *root, Query *query, Node *node)
 	return flatten_join_alias_vars_mutator(node, &context);
 }
 
+/*
+ * flatten_join_alias_for_parser
+ *
+ * This variant of flatten_join_alias_vars is used by the parser, to expand
+ * join alias Vars before checking GROUP BY validity. In that case we lack
+ * a root structure. Fortunately, we'd only need the root for making
+ * PlaceHolderVars. We can avoid making PlaceHolderVars in the parser's
+ * usage because it won't be dealing with arbitrary expressions: so long as
+ * adjust_standard_join_alias_expression can handle everything the parser
+ * would make as a join alias expression, we're OK.
+ *
+ * The "node" might be part of a sub-query of the Query whose join alias
+ * Vars are to be expanded. "sublevels_up" indicates how far below the
+ * given query we are starting.
+ */
+Node *
+flatten_join_alias_for_parser(Query *query, Node *node, int sublevels_up)
+{
+	flatten_join_alias_vars_context context;
+
+	/*
+	 * We do not expect this to be applied to the whole Query, only to
+	 * expressions or LATERAL subqueries. Hence, if the top node is a Query,
+	 * it's okay to immediately increment sublevels_up.
+	 */
+	Assert(node != (Node *) query);
+
+	context.root = NULL;
+	context.query = query;
+	context.sublevels_up = sublevels_up;
+	/* flag whether join aliases could possibly contain SubLinks */
+	context.possible_sublink = query->hasSubLinks;
+	/* if hasSubLinks is already true, no need to work hard */
+	context.inserted_sublink = query->hasSubLinks;
+
+	return flatten_join_alias_vars_mutator(node, &context);
+}
+
 static Node *
 flatten_join_alias_vars_mutator(Node *node,
 								flatten_join_alias_vars_context *context)
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index d0187ea84a0..33fd2cccae5 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -48,9 +48,10 @@ typedef struct
 	ParseState *pstate;
 	Query *qry;
 	bool hasJoinRTEs;
-	List *groupClauses;
-	List *groupClauseCommonVars;
-	List *gset_common;
+	List *groupClauses; /* list of TargetEntry */
+	List *groupClauseCommonVars; /* list of Vars */
+	List *groupClauseSubLevels; /* list of lists of TargetEntry */
+	List *gset_common; /* integer list of sortgrouprefs */
 	bool have_non_var_grouping;
 	List **func_grouped_rels;
 	int sublevels_up;
@@ -1255,8 +1256,8 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 	 * entries are RTE_JOIN kind.
 	 */
 	if (hasJoinRTEs)
-		groupClauses = (List *) flatten_join_alias_vars(NULL, qry,
-														(Node *) groupClauses);
+		groupClauses = (List *)
+			flatten_join_alias_for_parser(qry, (Node *) groupClauses, 0);
 
 	/*
 	 * Detect whether any of the grouping expressions aren't simple Vars; if
@@ -1301,7 +1302,7 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 							groupClauses, hasJoinRTEs,
 							have_non_var_grouping);
 	if (hasJoinRTEs)
-		clause = flatten_join_alias_vars(NULL, qry, clause);
+		clause = flatten_join_alias_for_parser(qry, clause, 0);
 	qry->targetList = (List *)
 		substitute_grouped_columns(clause, pstate, qry,
 								   groupClauses, groupClauseCommonVars,
@@ -1314,7 +1315,7 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
 							groupClauses, hasJoinRTEs,
 							have_non_var_grouping);
 	if (hasJoinRTEs)
-		clause = flatten_join_alias_vars(NULL, qry, clause);
+		clause = flatten_join_alias_for_parser(qry, clause, 0);
 	qry->havingQual =
 		substitute_grouped_columns(clause, pstate, qry,
 								   groupClauses, groupClauseCommonVars,
@@ -1344,17 +1345,6 @@ parseCheckAggregates(ParseState *pstate, Query *qry)
  *
  * NOTE: we assume that the given clause has been transformed suitably for
  * parser output. This means we can use expression_tree_mutator.
- *
- * NOTE: we recognize grouping expressions in the main query, but only
- * grouping Vars in subqueries. For example, this will be rejected,
- * although it could be allowed:
- *		SELECT
- *			(SELECT x FROM bar where y = (foo.a + foo.b))
- *		FROM foo
- *		GROUP BY a + b;
- * The difficulty is the need to account for different sublevels_up.
- * This appears to require a whole custom version of equal(), which is
- * way more pain than the feature seems worth.
  */
 static Node *
 substitute_grouped_columns(Node *node, ParseState *pstate, Query *qry,
@@ -1370,6 +1360,7 @@ substitute_grouped_columns(Node *node, ParseState *pstate, Query *qry,
 	context.hasJoinRTEs = false; /* assume caller flattened join Vars */
 	context.groupClauses = groupClauses;
 	context.groupClauseCommonVars = groupClauseCommonVars;
+	context.groupClauseSubLevels = NIL;
 	context.gset_common = gset_common;
 	context.have_non_var_grouping = have_non_var_grouping;
 	context.func_grouped_rels = func_grouped_rels;
@@ -1437,14 +1428,22 @@ substitute_grouped_columns_mutator(Node *node,
 	 * If we have any GROUP BY items that are not simple Vars, check to see if
 	 * subexpression as a whole matches any GROUP BY item. We need to do this
 	 * at every recursion level so that we recognize GROUPed-BY expressions
-	 * before reaching variables within them. But this only works at the outer
-	 * query level, as noted above.
+	 * before reaching variables within them. (Since this approach is pretty
+	 * expensive, we don't do it this way if the items are all simple Vars.)
 	 */
-	if (context->have_non_var_grouping && context->sublevels_up == 0)
+	if (context->have_non_var_grouping)
 	{
+		List *groupClauses;
 		int attnum = 0;
 
-		foreach(gl, context->groupClauses)
+		/* Within a subquery, we need a mutated version of the groupClauses */
+		if (context->sublevels_up == 0)
+			groupClauses = context->groupClauses;
+		else
+			groupClauses = list_nth(context->groupClauseSubLevels,
+									context->sublevels_up - 1);
+
+		foreach(gl, groupClauses)
 		{
 			TargetEntry *tle = (TargetEntry *) lfirst(gl);
 
@@ -1487,7 +1486,7 @@ substitute_grouped_columns_mutator(Node *node,
 		/*
 		 * Check for a match, if we didn't do it above.
 		 */
-		if (!context->have_non_var_grouping || context->sublevels_up != 0)
+		if (!context->have_non_var_grouping)
 		{
 			int attnum = 0;
 
@@ -1570,6 +1569,24 @@ substitute_grouped_columns_mutator(Node *node,
 		Query *newnode;
 
 		context->sublevels_up++;
+
+		/*
+		 * If we have non-Var grouping expressions, we'll need a copy of the
+		 * groupClauses list that's mutated to match this sublevels_up depth.
+		 * Build one if we've not yet visited a subquery at this depth.
+		 */
+		if (context->have_non_var_grouping &&
+			context->sublevels_up > list_length(context->groupClauseSubLevels))
+		{
+			List *subGroupClauses = copyObject(context->groupClauses);
+
+			IncrementVarSublevelsUp((Node *) subGroupClauses,
+									context->sublevels_up, 0);
+			context->groupClauseSubLevels =
+				lappend(context->groupClauseSubLevels, subGroupClauses);
+			Assert(context->sublevels_up == list_length(context->groupClauseSubLevels));
+		}
+
 		newnode = query_tree_mutator((Query *) node,
 									 substitute_grouped_columns_mutator,
 									 context,
@@ -1604,6 +1621,7 @@ finalize_grouping_exprs(Node *node, ParseState *pstate, Query *qry,
 	context.hasJoinRTEs = hasJoinRTEs;
 	context.groupClauses = groupClauses;
 	context.groupClauseCommonVars = NIL;
+	context.groupClauseSubLevels = NIL;
 	context.gset_common = NIL;
 	context.have_non_var_grouping = have_non_var_grouping;
 	context.func_grouped_rels = NULL;
@@ -1676,7 +1694,9 @@ finalize_grouping_exprs_walker(Node *node,
 				Index ref = 0;
 
 				if (context->hasJoinRTEs)
-					expr = flatten_join_alias_vars(NULL, context->qry, expr);
+					expr = flatten_join_alias_for_parser(context->qry,
+														 expr,
+														 context->sublevels_up);
 
 				/*
 				 * Each expression must match a grouping entry at the current
@@ -1706,10 +1726,21 @@ finalize_grouping_exprs_walker(Node *node,
 						}
 					}
 				}
-				else if (context->have_non_var_grouping &&
-						 context->sublevels_up == 0)
+				else if (context->have_non_var_grouping)
 				{
-					foreach(gl, context->groupClauses)
+					List *groupClauses;
+
+					/*
+					 * Within a subquery, we need a mutated version of the
+					 * groupClauses
+					 */
+					if (context->sublevels_up == 0)
+						groupClauses = context->groupClauses;
+					else
+						groupClauses = list_nth(context->groupClauseSubLevels,
+												context->sublevels_up - 1);
+
+					foreach(gl, groupClauses)
 					{
 						TargetEntry *tle = lfirst(gl);
 
@@ -1744,6 +1775,24 @@ finalize_grouping_exprs_walker(Node *node,
 		bool result;
 
 		context->sublevels_up++;
+
+		/*
+		 * If we have non-Var grouping expressions, we'll need a copy of the
+		 * groupClauses list that's mutated to match this sublevels_up depth.
+		 * Build one if we've not yet visited a subquery at this depth.
+		 */
+		if (context->have_non_var_grouping &&
+			context->sublevels_up > list_length(context->groupClauseSubLevels))
+		{
+			List *subGroupClauses = copyObject(context->groupClauses);
+
+			IncrementVarSublevelsUp((Node *) subGroupClauses,
+									context->sublevels_up, 0);
+			context->groupClauseSubLevels =
+				lappend(context->groupClauseSubLevels, subGroupClauses);
+			Assert(context->sublevels_up == list_length(context->groupClauseSubLevels));
+		}
+
 		result = query_tree_walker((Query *) node,
 								   finalize_grouping_exprs_walker,
 								   context,
diff --git a/src/include/optimizer/optimizer.h b/src/include/optimizer/optimizer.h
index 3d27a019609..b562ca380a8 100644
--- a/src/include/optimizer/optimizer.h
+++ b/src/include/optimizer/optimizer.h
@@ -204,6 +204,8 @@ extern bool contain_vars_returning_old_or_new(Node *node);
 extern int locate_var_of_level(Node *node, int levelsup);
 extern List *pull_var_clause(Node *node, int flags);
 extern Node *flatten_join_alias_vars(PlannerInfo *root, Query *query, Node *node);
+extern Node *flatten_join_alias_for_parser(Query *query, Node *node,
+										   int sublevels_up);
 extern Node *flatten_group_exprs(PlannerInfo *root, Query *query, Node *node);
 
 #endif /* OPTIMIZER_H */
diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out
index bbd4554fa4f..ff80869fb33 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -862,6 +862,26 @@ select array(select sum(x+y) s
  {4,5,6}
 (3 rows)
 
+-- Test handling of grouping-expression references within sublinks
+select two + four as g, (select f1 from int4_tbl where f1 = (two + four))
+from tenk1 t1
+group by two + four order by 1;
+ g | f1 
+---+----
+ 0 |  0
+ 2 |   
+ 4 |   
+(3 rows)
+
+select q1, (select q1) as ss -- q1 is actually a COALESCE expression here
+from int8_tbl i81 full outer join int8_tbl i82 using (q1)
+group by q1 order by q1;
+        q1        |        ss        
+------------------+------------------
+              123 |              123
+ 4567890123456789 | 4567890123456789
+(2 rows)
+
 --
 -- test for bitwise integer aggregates
 --
diff --git a/src/test/regress/expected/groupingsets.out b/src/test/regress/expected/groupingsets.out
index 921017489c0..b08083ec54c 100644
--- a/src/test/regress/expected/groupingsets.out
+++ b/src/test/regress/expected/groupingsets.out
@@ -348,6 +348,43 @@ select a, b, grouping(a, b), sum(t1.v), max(t2.c)
    |   |        3 | 172 |   2
 (3 rows)
 
+select a, b, grouping(a, b), sum(t1.v), max(t2.c)
+  from gstest1 t1 full join gstest2 t2 using (a,b)
+  group by grouping sets ((a, b), ());
+ a | b | grouping | sum | max 
+---+---+----------+-----+-----
+ 1 | 1 |        0 | 147 |   2
+ 1 | 2 |        0 |  25 |   2
+ 1 | 3 |        0 |  14 |    
+ 2 | 2 |        0 |     |   2
+ 2 | 3 |        0 |  15 |    
+ 3 | 3 |        0 |  16 |    
+ 3 | 4 |        0 |  17 |    
+ 4 | 1 |        0 |  37 |    
+   |   |        3 | 271 |   2
+(9 rows)
+
+-- references in subqueries should work too
+select (select a),
+       (select b),
+       (select grouping(a, b)),
+       (select sum(t1.v)),
+       (select max(t2.c))
+  from gstest1 t1 full join gstest2 t2 using (a,b)
+  group by grouping sets ((a, b), ());
+ a | b | grouping | sum | max 
+---+---+----------+-----+-----
+ 1 | 1 |        0 | 147 |   2
+ 1 | 2 |        0 |  25 |   2
+ 1 | 3 |        0 |  14 |    
+ 2 | 2 |        0 |     |   2
+ 2 | 3 |        0 |  15 |    
+ 3 | 3 |        0 |  16 |    
+ 3 | 4 |        0 |  17 |    
+ 4 | 1 |        0 |  37 |    
+   |   |        3 | 271 |   2
+(9 rows)
+
 -- check that functionally dependent cols are not nulled
 select a, d, grouping(a,b,c)
   from gstest3
diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql
index 5992cb7ca9b..89bb83718e0 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -246,6 +246,16 @@ select array(select sum(x+y) s
             from generate_series(1,3) y group by y order by s)
   from generate_series(1,3) x;
 
+-- Test handling of grouping-expression references within sublinks
+
+select two + four as g, (select f1 from int4_tbl where f1 = (two + four))
+from tenk1 t1
+group by two + four order by 1;
+
+select q1, (select q1) as ss -- q1 is actually a COALESCE expression here
+from int8_tbl i81 full outer join int8_tbl i82 using (q1)
+group by q1 order by q1;
+
 --
 -- test for bitwise integer aggregates
 --
diff --git a/src/test/regress/sql/groupingsets.sql b/src/test/regress/sql/groupingsets.sql
index 826ac5f5dbf..a594449b697 100644
--- a/src/test/regress/sql/groupingsets.sql
+++ b/src/test/regress/sql/groupingsets.sql
@@ -136,6 +136,19 @@ select a, b, grouping(a, b), sum(t1.v), max(t2.c)
   from gstest1 t1 join gstest2 t2 using (a,b)
   group by grouping sets ((a, b), ());
 
+select a, b, grouping(a, b), sum(t1.v), max(t2.c)
+  from gstest1 t1 full join gstest2 t2 using (a,b)
+  group by grouping sets ((a, b), ());
+
+-- references in subqueries should work too
+select (select a),
+       (select b),
+       (select grouping(a, b)),
+       (select sum(t1.v)),
+       (select max(t2.c))
+  from gstest1 t1 full join gstest2 t2 using (a,b)
+  group by grouping sets ((a, b), ());
+
 -- check that functionally dependent cols are not nulled
 select a, d, grouping(a,b,c)
   from gstest3