/*-------------------------------------------------------------------------
 *
 * parse_agg.c
 *	  handle aggregates in parser
 *
 * Copyright (c) 1994, Regents of the University of California
 *
 *
 * IDENTIFICATION
 *	  $Header: /cvsroot/pgsql/src/backend/parser/parse_agg.c,v 1.19 1999/05/12 15:01:48 wieck Exp $
 *
 *-------------------------------------------------------------------------
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "postgres.h"
#include "access/heapam.h"
#include "catalog/pg_aggregate.h"
21
#include "catalog/pg_type.h"
22 23 24 25 26
#include "nodes/nodeFuncs.h"
#include "nodes/primnodes.h"
#include "nodes/relation.h"
#include "optimizer/clauses.h"
#include "parser/parse_agg.h"
27
#include "parser/parse_expr.h"
28 29
#include "parser/parse_node.h"
#include "parser/parse_target.h"
30
#include "parser/parse_coerce.h"
31
#include "utils/syscache.h"
32
#include "utils/lsyscache.h"
33

34
static bool contain_agg_clause(Node *clause);
35 36 37
static bool exprIsAggOrGroupCol(Node *expr, List *groupClause, List *tlist);
static bool tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause, 
													List *tlist);
38

39
/*
40
 * contain_agg_clause
B
Bruce Momjian 已提交
41
 *	  Recursively find aggref nodes from a clause.
42 43 44
 *
 *	  Returns true if any aggregate found.
 */
45
static bool
46 47 48 49
contain_agg_clause(Node *clause)
{
	if (clause == NULL)
		return FALSE;
B
Bruce Momjian 已提交
50
	else if (IsA(clause, Aggref))
51 52 53 54 55
		return TRUE;
	else if (IsA(clause, Iter))
		return contain_agg_clause(((Iter *) clause)->iterexpr);
	else if (single_node(clause))
		return FALSE;
56
	else if (or_clause(clause) || and_clause(clause))
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
	{
		List	   *temp;

		foreach(temp, ((Expr *) clause)->args)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		return FALSE;
	}
	else if (is_funcclause(clause))
	{
		List	   *temp;

		foreach(temp, ((Expr *) clause)->args)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		return FALSE;
	}
	else if (IsA(clause, ArrayRef))
	{
		List	   *temp;

		foreach(temp, ((ArrayRef *) clause)->refupperindexpr)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		foreach(temp, ((ArrayRef *) clause)->reflowerindexpr)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		if (contain_agg_clause(((ArrayRef *) clause)->refexpr))
			return TRUE;
		if (contain_agg_clause(((ArrayRef *) clause)->refassgnexpr))
			return TRUE;
		return FALSE;
	}
	else if (not_clause(clause))
		return contain_agg_clause((Node *) get_notclausearg((Expr *) clause));
	else if (is_opclause(clause))
		return (contain_agg_clause((Node *) get_leftop((Expr *) clause)) ||
			  contain_agg_clause((Node *) get_rightop((Expr *) clause)));

	return FALSE;
}

/*
 * exprIsAggOrGroupCol -
 *	  returns true if the expression does not contain non-group columns.
 */
103
static bool
104
exprIsAggOrGroupCol(Node *expr, List *groupClause, List *tlist)
105 106 107 108
{
	List	   *gl;

	if (expr == NULL || IsA(expr, Const) ||
B
Bruce Momjian 已提交
109
		IsA(expr, Param) || IsA(expr, Aggref) || 
110
		IsA(expr, SubLink))		/* can't handle currently !!! */
111 112 113 114 115 116
		return TRUE;

	foreach(gl, groupClause)
	{
		GroupClause *grpcl = lfirst(gl);

117
		if (equal(expr, get_groupclause_expr(grpcl, tlist)))
118 119 120 121 122 123 124 125
			return TRUE;
	}

	if (IsA(expr, Expr))
	{
		List	   *temp;

		foreach(temp, ((Expr *) expr)->args)
126
			if (!exprIsAggOrGroupCol(lfirst(temp), groupClause, tlist))
127 128 129 130 131 132 133 134 135 136 137
			return FALSE;
		return TRUE;
	}

	return FALSE;
}

/*
 * tleIsAggOrGroupCol -
 *	  returns true if the TargetEntry is Agg or GroupCol.
 */
138
static bool
139
tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause, List *tlist)
140 141 142 143 144 145 146 147 148 149 150
{
	Node	   *expr = tle->expr;
	List	   *gl;

	if (expr == NULL || IsA(expr, Const) ||IsA(expr, Param))
		return TRUE;

	foreach(gl, groupClause)
	{
		GroupClause *grpcl = lfirst(gl);

151
		if (tle->resdom->resgroupref == grpcl->tleGroupref)
152 153
		{
			if (contain_agg_clause((Node *) expr))
154
				elog(ERROR, "Aggregates not allowed in GROUP BY clause");
155 156 157 158
			return TRUE;
		}
	}

B
Bruce Momjian 已提交
159
	if (IsA(expr, Aggref))
160 161 162 163 164 165 166
		return TRUE;

	if (IsA(expr, Expr))
	{
		List	   *temp;

		foreach(temp, ((Expr *) expr)->args)
167
			if (!exprIsAggOrGroupCol(lfirst(temp), groupClause, tlist))
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
			return FALSE;
		return TRUE;
	}

	return FALSE;
}

/*
 * parseCheckAggregates -
 *	  this should really be done earlier but the current grammar
 *	  cannot differentiate functions from aggregates. So we have to check
 *	  here when the target list and the qualifications are finalized.
 *
 *	  Reports via elog(ERROR) in three cases:
 *	  - an aggregate appears in the WHERE clause;
 *	  - a target list entry fails tleIsAggOrGroupCol();
 *	  - the HAVING clause fails exprIsAggOrGroupCol().
 *	  A NULL havingQual is accepted by exprIsAggOrGroupCol().
 */
void
parseCheckAggregates(ParseState *pstate, Query *qry)
{
	List	   *tl;

	Assert(pstate->p_hasAggs);

	/*
	 * aggregates never appear in WHERE clauses. (we have to check where
	 * clause first because if there is an aggregate, the check for
	 * non-group column in target list may fail.)
	 */
	if (contain_agg_clause(qry->qual))
		elog(ERROR, "Aggregates not allowed in WHERE clause");

	/*
	 * the target list can only contain aggregates, group columns and
	 * functions thereof.
	 */
	foreach(tl, qry->targetList)
	{
		TargetEntry *tle = lfirst(tl);

		if (!tleIsAggOrGroupCol(tle, qry->groupClause, qry->targetList))
			elog(ERROR,
				 "Illegal use of aggregates or non-group column in target list");
	}

	/*
	 * the expression specified in the HAVING clause has the same
	 * restriction as those in the target list.
	 */
	if (!exprIsAggOrGroupCol(qry->havingQual, qry->groupClause, qry->targetList))
		elog(ERROR,
			 "Illegal use of aggregates or non-group column in HAVING clause");
}


B
Bruce Momjian 已提交
221
Aggref *
222
ParseAgg(ParseState *pstate, char *aggname, Oid basetype,
223
		 List *target, int precedence)
224 225 226 227 228
{
	Oid			fintype;
	Oid			vartype;
	Oid			xfn1;
	Form_pg_aggregate aggform;
B
Bruce Momjian 已提交
229
	Aggref	   *aggref;
230
	HeapTuple	theAggTuple;
231
	bool		usenulls = false;
232

233 234
	theAggTuple = SearchSysCacheTuple(AGGNAME,
									  PointerGetDatum(aggname),
235 236 237
									  ObjectIdGetDatum(basetype),
									  0, 0);
	if (!HeapTupleIsValid(theAggTuple))
238
		elog(ERROR, "Aggregate %s does not exist", aggname);
239

240
	/*
241
	 * We do a major hack for count(*) here.
242
	 *
243 244 245 246 247
	 * Count(*) poses several problems.  First, we need a field that is
	 * guaranteed to be in the range table, and unique.  Using a constant
	 * causes the optimizer to properly remove the aggragate from any
	 * elements of the query. Using just 'oid', which can not be null, in
	 * the parser fails on:
248
	 *
249 250
	 * select count(*) from tab1, tab2	   -- oid is not unique select
	 * count(*) from viewtable		-- views don't have real oids
251
	 *
252 253 254
	 * So, for an aggregate with parameter '*', we use the first valid range
	 * table entry, and pick the first column from the table. We set a
	 * flag to count nulls, because we could have nulls in that column.
B
Bruce Momjian 已提交
255
	 *
256
	 * It's an ugly job, but someone has to do it. bjm 1998/1/18
B
Bruce Momjian 已提交
257
	 */
258

259 260
	if (nodeTag(lfirst(target)) == T_Const)
	{
261 262
		Const	   *con = (Const *) lfirst(target);

263 264
		if (con->consttype == UNKNOWNOID && VARSIZE(con->constvalue) == VARHDRSZ)
		{
265 266 267
			Attr	   *attr = makeNode(Attr);
			List	   *rtable,
					   *rlist;
268 269 270 271 272 273 274 275
			RangeTblEntry *first_valid_rte;

			Assert(lnext(target) == NULL);

			if (pstate->p_is_rule)
				rtable = lnext(lnext(pstate->p_rtable));
			else
				rtable = pstate->p_rtable;
276

277 278 279 280
			first_valid_rte = NULL;
			foreach(rlist, rtable)
			{
				RangeTblEntry *rte = lfirst(rlist);
281

282 283 284 285
				/* only entries on outer(non-function?) scope */
				if (!rte->inFromCl && rte != pstate->p_target_rangetblentry)
					continue;

B
Bruce Momjian 已提交
286
				first_valid_rte = rte;
287 288 289
				break;
			}
			if (first_valid_rte == NULL)
290
				elog(ERROR, "Can't find column to do aggregate(*) on.");
291

292 293
			attr->relname = first_valid_rte->refname;
			attr->attrs = lcons(makeString(
294
						   get_attname(first_valid_rte->relid, 1)), NIL);
295 296 297 298 299

			lfirst(target) = transformExpr(pstate, (Node *) attr, precedence);
			usenulls = true;
		}
	}
300

301 302 303 304 305 306 307 308
	aggform = (Form_pg_aggregate) GETSTRUCT(theAggTuple);
	fintype = aggform->aggfinaltype;
	xfn1 = aggform->aggtransfn1;

	/* only aggregates with transfn1 need a base type */
	if (OidIsValid(xfn1))
	{
		basetype = aggform->aggbasetype;
309
		vartype = exprType(lfirst(target));
310 311
		if ((basetype != vartype)
			&& (! IS_BINARY_COMPATIBLE(basetype, vartype)))
312 313 314 315 316 317
		{
			Type		tp1,
						tp2;

			tp1 = typeidType(basetype);
			tp2 = typeidType(vartype);
318 319 320
			elog(ERROR, "Aggregate type mismatch"
						"\n\t%s() works on %s, not on %s",
						 aggname, typeTypeName(tp1), typeTypeName(tp2));
321 322 323
		}
	}

B
Bruce Momjian 已提交
324 325 326 327
	aggref = makeNode(Aggref);
	aggref->aggname = pstrdup(aggname);
	aggref->basetype = aggform->aggbasetype;
	aggref->aggtype = fintype;
328

B
Bruce Momjian 已提交
329
	aggref->target = lfirst(target);
330
	if (usenulls)
B
Bruce Momjian 已提交
331
		aggref->usenulls = true;
332

333 334
	pstate->p_hasAggs = true;

B
Bruce Momjian 已提交
335
	return aggref;
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
}

/*
 * Error message when aggregate lookup fails that gives details of the
 * basetype
 *
 * 'caller' is prefixed to the message.  A basetypeID that is Invalid
 * (zero) means aggregate over all types (count); otherwise the type's name
 * is looked up with typeidTypeName() for the message.  Both cases report
 * via elog(ERROR).
 */
void
agg_error(char *caller, char *aggname, Oid basetypeID)
{
	if (basetypeID == InvalidOid)
		elog(ERROR, "%s: aggregate '%s' for all types does not exist",
			 caller, aggname);
	else
		elog(ERROR, "%s: aggregate '%s' for '%s' does not exist",
			 caller, aggname, typeidTypeName(basetypeID));
}