/*-------------------------------------------------------------------------
 *
 * parse_agg.c
 *	  handle aggregates in parser
 *
 * Copyright (c) 1994, Regents of the University of California
 *
 *
 * IDENTIFICATION
 *	  $Header: /cvsroot/pgsql/src/backend/parser/parse_agg.c,v 1.20 1999/05/23 21:41:14 tgl Exp $
 *
 *-------------------------------------------------------------------------
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "postgres.h"
#include "access/heapam.h"
#include "catalog/pg_aggregate.h"
#include "catalog/pg_type.h"
#include "nodes/nodeFuncs.h"
#include "nodes/parsenodes.h"
#include "nodes/primnodes.h"
#include "nodes/relation.h"
#include "optimizer/clauses.h"
#include "parser/parse_agg.h"
#include "parser/parse_expr.h"
#include "parser/parse_node.h"
#include "parser/parse_target.h"
#include "parser/parse_coerce.h"
#include "utils/syscache.h"
#include "utils/lsyscache.h"

34
static bool contain_agg_clause(Node *clause);
35 36 37
static bool exprIsAggOrGroupCol(Node *expr, List *groupClause, List *tlist);
static bool tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause, 
													List *tlist);
38

39
/*
40
 * contain_agg_clause
B
Bruce Momjian 已提交
41
 *	  Recursively find aggref nodes from a clause.
42 43 44
 *
 *	  Returns true if any aggregate found.
 */
45
static bool
46 47 48 49
contain_agg_clause(Node *clause)
{
	if (clause == NULL)
		return FALSE;
B
Bruce Momjian 已提交
50
	else if (IsA(clause, Aggref))
51 52 53 54 55
		return TRUE;
	else if (IsA(clause, Iter))
		return contain_agg_clause(((Iter *) clause)->iterexpr);
	else if (single_node(clause))
		return FALSE;
56
	else if (or_clause(clause) || and_clause(clause))
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
	{
		List	   *temp;

		foreach(temp, ((Expr *) clause)->args)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		return FALSE;
	}
	else if (is_funcclause(clause))
	{
		List	   *temp;

		foreach(temp, ((Expr *) clause)->args)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		return FALSE;
	}
	else if (IsA(clause, ArrayRef))
	{
		List	   *temp;

		foreach(temp, ((ArrayRef *) clause)->refupperindexpr)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		foreach(temp, ((ArrayRef *) clause)->reflowerindexpr)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		if (contain_agg_clause(((ArrayRef *) clause)->refexpr))
			return TRUE;
		if (contain_agg_clause(((ArrayRef *) clause)->refassgnexpr))
			return TRUE;
		return FALSE;
	}
	else if (not_clause(clause))
		return contain_agg_clause((Node *) get_notclausearg((Expr *) clause));
	else if (is_opclause(clause))
		return (contain_agg_clause((Node *) get_leftop((Expr *) clause)) ||
			  contain_agg_clause((Node *) get_rightop((Expr *) clause)));

	return FALSE;
}

/*
 * exprIsAggOrGroupCol -
 *	  returns true if the expression does not contain non-group columns.
 */
103
static bool
104
exprIsAggOrGroupCol(Node *expr, List *groupClause, List *tlist)
105 106 107 108
{
	List	   *gl;

	if (expr == NULL || IsA(expr, Const) ||
B
Bruce Momjian 已提交
109
		IsA(expr, Param) || IsA(expr, Aggref) || 
110
		IsA(expr, SubLink))		/* can't handle currently !!! */
111 112 113 114 115 116
		return TRUE;

	foreach(gl, groupClause)
	{
		GroupClause *grpcl = lfirst(gl);

117
		if (equal(expr, get_groupclause_expr(grpcl, tlist)))
118 119 120 121 122 123 124 125
			return TRUE;
	}

	if (IsA(expr, Expr))
	{
		List	   *temp;

		foreach(temp, ((Expr *) expr)->args)
126
			if (!exprIsAggOrGroupCol(lfirst(temp), groupClause, tlist))
127 128 129 130 131 132 133 134 135 136 137
			return FALSE;
		return TRUE;
	}

	return FALSE;
}

/*
 * tleIsAggOrGroupCol -
 *	  returns true if the TargetEntry is Agg or GroupCol.
 */
138
static bool
139
tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause, List *tlist)
140 141 142 143 144 145 146 147 148 149 150
{
	Node	   *expr = tle->expr;
	List	   *gl;

	if (expr == NULL || IsA(expr, Const) ||IsA(expr, Param))
		return TRUE;

	foreach(gl, groupClause)
	{
		GroupClause *grpcl = lfirst(gl);

151
		if (tle->resdom->resgroupref == grpcl->tleGroupref)
152 153
		{
			if (contain_agg_clause((Node *) expr))
154
				elog(ERROR, "Aggregates not allowed in GROUP BY clause");
155 156 157 158
			return TRUE;
		}
	}

B
Bruce Momjian 已提交
159
	if (IsA(expr, Aggref))
160 161 162 163 164 165 166
		return TRUE;

	if (IsA(expr, Expr))
	{
		List	   *temp;

		foreach(temp, ((Expr *) expr)->args)
167
			if (!exprIsAggOrGroupCol(lfirst(temp), groupClause, tlist))
168 169 170 171 172 173 174 175
			return FALSE;
		return TRUE;
	}

	return FALSE;
}

/*
 * parseCheckAggregates
 *	  Check for aggregates where they shouldn't be and improper grouping.
 *
 *	  Aggregates cannot easily be told apart from plain function calls at
 *	  the grammar level, so the check is done here instead.  Call this only
 *	  after the target list and qualifications are finalized.
 *
 *	  Three rules are enforced, each raising ERROR when violated:
 *		1. the WHERE clause contains no aggregates;
 *		2. the target list contains only aggregates, group columns and
 *		   functions thereof;
 *		3. the HAVING clause obeys the same restriction as the target list.
 */
void
parseCheckAggregates(ParseState *pstate, Query *qry)
{
	List	   *groupList = qry->groupClause;
	List	   *targets = qry->targetList;
	List	   *cell;

	/* only meaningful when the query has aggregates or grouping */
	Assert(pstate->p_hasAggs || groupList);

	/*
	 * WHERE is checked first on purpose.  Otherwise an aggregate in WHERE
	 * would most likely be reported by the generic target-list message,
	 * which would be misleading.
	 */
	if (contain_agg_clause(qry->qual))
		elog(ERROR, "Aggregates not allowed in WHERE clause");

	foreach(cell, targets)
	{
		if (!tleIsAggOrGroupCol((TargetEntry *) lfirst(cell), groupList, targets))
			elog(ERROR,
				 "Illegal use of aggregates or non-group column in target list");
	}

	/* HAVING is held to the target list's rules */
	if (!exprIsAggOrGroupCol(qry->havingQual, groupList, targets))
		elog(ERROR,
			 "Illegal use of aggregates or non-group column in HAVING clause");
}


B
Bruce Momjian 已提交
226
Aggref *
227
ParseAgg(ParseState *pstate, char *aggname, Oid basetype,
228
		 List *target, int precedence)
229 230 231 232 233
{
	Oid			fintype;
	Oid			vartype;
	Oid			xfn1;
	Form_pg_aggregate aggform;
B
Bruce Momjian 已提交
234
	Aggref	   *aggref;
235
	HeapTuple	theAggTuple;
236
	bool		usenulls = false;
237

238 239
	theAggTuple = SearchSysCacheTuple(AGGNAME,
									  PointerGetDatum(aggname),
240 241 242
									  ObjectIdGetDatum(basetype),
									  0, 0);
	if (!HeapTupleIsValid(theAggTuple))
243
		elog(ERROR, "Aggregate %s does not exist", aggname);
244

245
	/*
246
	 * We do a major hack for count(*) here.
247
	 *
248 249 250 251 252
	 * Count(*) poses several problems.  First, we need a field that is
	 * guaranteed to be in the range table, and unique.  Using a constant
	 * causes the optimizer to properly remove the aggragate from any
	 * elements of the query. Using just 'oid', which can not be null, in
	 * the parser fails on:
253
	 *
254 255
	 * select count(*) from tab1, tab2	   -- oid is not unique select
	 * count(*) from viewtable		-- views don't have real oids
256
	 *
257 258 259
	 * So, for an aggregate with parameter '*', we use the first valid range
	 * table entry, and pick the first column from the table. We set a
	 * flag to count nulls, because we could have nulls in that column.
B
Bruce Momjian 已提交
260
	 *
261
	 * It's an ugly job, but someone has to do it. bjm 1998/1/18
B
Bruce Momjian 已提交
262
	 */
263

264 265
	if (nodeTag(lfirst(target)) == T_Const)
	{
266 267
		Const	   *con = (Const *) lfirst(target);

268 269
		if (con->consttype == UNKNOWNOID && VARSIZE(con->constvalue) == VARHDRSZ)
		{
270 271 272
			Attr	   *attr = makeNode(Attr);
			List	   *rtable,
					   *rlist;
273 274 275 276 277 278 279 280
			RangeTblEntry *first_valid_rte;

			Assert(lnext(target) == NULL);

			if (pstate->p_is_rule)
				rtable = lnext(lnext(pstate->p_rtable));
			else
				rtable = pstate->p_rtable;
281

282 283 284 285
			first_valid_rte = NULL;
			foreach(rlist, rtable)
			{
				RangeTblEntry *rte = lfirst(rlist);
286

287 288 289 290
				/* only entries on outer(non-function?) scope */
				if (!rte->inFromCl && rte != pstate->p_target_rangetblentry)
					continue;

B
Bruce Momjian 已提交
291
				first_valid_rte = rte;
292 293 294
				break;
			}
			if (first_valid_rte == NULL)
295
				elog(ERROR, "Can't find column to do aggregate(*) on.");
296

297 298
			attr->relname = first_valid_rte->refname;
			attr->attrs = lcons(makeString(
299
						   get_attname(first_valid_rte->relid, 1)), NIL);
300 301 302 303 304

			lfirst(target) = transformExpr(pstate, (Node *) attr, precedence);
			usenulls = true;
		}
	}
305

306 307 308 309 310 311 312 313
	aggform = (Form_pg_aggregate) GETSTRUCT(theAggTuple);
	fintype = aggform->aggfinaltype;
	xfn1 = aggform->aggtransfn1;

	/* only aggregates with transfn1 need a base type */
	if (OidIsValid(xfn1))
	{
		basetype = aggform->aggbasetype;
314
		vartype = exprType(lfirst(target));
315 316
		if ((basetype != vartype)
			&& (! IS_BINARY_COMPATIBLE(basetype, vartype)))
317 318 319 320 321 322
		{
			Type		tp1,
						tp2;

			tp1 = typeidType(basetype);
			tp2 = typeidType(vartype);
323 324 325
			elog(ERROR, "Aggregate type mismatch"
						"\n\t%s() works on %s, not on %s",
						 aggname, typeTypeName(tp1), typeTypeName(tp2));
326 327 328
		}
	}

B
Bruce Momjian 已提交
329 330 331 332
	aggref = makeNode(Aggref);
	aggref->aggname = pstrdup(aggname);
	aggref->basetype = aggform->aggbasetype;
	aggref->aggtype = fintype;
333

B
Bruce Momjian 已提交
334
	aggref->target = lfirst(target);
335
	if (usenulls)
B
Bruce Momjian 已提交
336
		aggref->usenulls = true;
337

338 339
	pstate->p_hasAggs = true;

B
Bruce Momjian 已提交
340
	return aggref;
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
}

/*
 * agg_error
 *	  Report, via elog(ERROR, ...), that an aggregate lookup failed.
 *
 *	  caller is used as the message prefix.  If basetypeID is InvalidOid
 *	  (zero), the aggregate is one over all types (count), and the message
 *	  says so.  Otherwise the type's name is included in the message.
 */
void
agg_error(char *caller, char *aggname, Oid basetypeID)
{
	if (basetypeID != InvalidOid)
	{
		elog(ERROR, "%s: aggregate '%s' for '%s' does not exist", caller, aggname,
			 typeidTypeName(basetypeID));
	}
	else
		elog(ERROR, "%s: aggregate '%s' for all types does not exist", caller, aggname);
}