/*-------------------------------------------------------------------------
 *
 * parse_agg.c--
 *	  handle aggregates in parser
 *
 * Copyright (c) 1994, Regents of the University of California
 *
 *
 * IDENTIFICATION
 *	  $Header: /cvsroot/pgsql/src/backend/parser/parse_agg.c,v 1.16 1999/01/24 00:28:29 momjian Exp $
 *
 *-------------------------------------------------------------------------
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "postgres.h"
#include "access/heapam.h"
#include "catalog/pg_aggregate.h"
#include "catalog/pg_type.h"
#include "nodes/nodeFuncs.h"
#include "nodes/primnodes.h"
#include "nodes/relation.h"
#include "optimizer/clauses.h"
#include "parser/parse_agg.h"
#include "parser/parse_expr.h"
#include "parser/parse_node.h"
#include "parser/parse_target.h"
#include "parser/parse_coerce.h"
#include "parser/parse_type.h"
#include "utils/syscache.h"
#include "utils/lsyscache.h"

static bool contain_agg_clause(Node *clause);
static bool exprIsAggOrGroupCol(Node *expr, List *groupClause);
static bool tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause);

38 39
/*
 * contain_agg_clause--
B
Bruce Momjian 已提交
40
 *	  Recursively find aggref nodes from a clause.
41 42 43
 *
 *	  Returns true if any aggregate found.
 */
44
static bool
45 46 47 48
contain_agg_clause(Node *clause)
{
	if (clause == NULL)
		return FALSE;
B
Bruce Momjian 已提交
49
	else if (IsA(clause, Aggref))
50 51 52 53 54
		return TRUE;
	else if (IsA(clause, Iter))
		return contain_agg_clause(((Iter *) clause)->iterexpr);
	else if (single_node(clause))
		return FALSE;
55
	else if (or_clause(clause) || and_clause(clause))
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
	{
		List	   *temp;

		foreach(temp, ((Expr *) clause)->args)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		return FALSE;
	}
	else if (is_funcclause(clause))
	{
		List	   *temp;

		foreach(temp, ((Expr *) clause)->args)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		return FALSE;
	}
	else if (IsA(clause, ArrayRef))
	{
		List	   *temp;

		foreach(temp, ((ArrayRef *) clause)->refupperindexpr)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		foreach(temp, ((ArrayRef *) clause)->reflowerindexpr)
			if (contain_agg_clause(lfirst(temp)))
			return TRUE;
		if (contain_agg_clause(((ArrayRef *) clause)->refexpr))
			return TRUE;
		if (contain_agg_clause(((ArrayRef *) clause)->refassgnexpr))
			return TRUE;
		return FALSE;
	}
	else if (not_clause(clause))
		return contain_agg_clause((Node *) get_notclausearg((Expr *) clause));
	else if (is_opclause(clause))
		return (contain_agg_clause((Node *) get_leftop((Expr *) clause)) ||
			  contain_agg_clause((Node *) get_rightop((Expr *) clause)));

	return FALSE;
}

/*
 * exprIsAggOrGroupCol -
 *	  returns true if the expression does not contain non-group columns.
 */
102
static bool
103 104 105 106 107
exprIsAggOrGroupCol(Node *expr, List *groupClause)
{
	List	   *gl;

	if (expr == NULL || IsA(expr, Const) ||
B
Bruce Momjian 已提交
108
		IsA(expr, Param) || IsA(expr, Aggref) || 
109
		IsA(expr, SubLink))		/* can't handle currently !!! */
110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
		return TRUE;

	foreach(gl, groupClause)
	{
		GroupClause *grpcl = lfirst(gl);

		if (equal(expr, grpcl->entry->expr))
			return TRUE;
	}

	if (IsA(expr, Expr))
	{
		List	   *temp;

		foreach(temp, ((Expr *) expr)->args)
			if (!exprIsAggOrGroupCol(lfirst(temp), groupClause))
			return FALSE;
		return TRUE;
	}

	return FALSE;
}

/*
 * tleIsAggOrGroupCol -
 *	  returns true if the TargetEntry is Agg or GroupCol.
 */
137
static bool
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
tleIsAggOrGroupCol(TargetEntry *tle, List *groupClause)
{
	Node	   *expr = tle->expr;
	List	   *gl;

	if (expr == NULL || IsA(expr, Const) ||IsA(expr, Param))
		return TRUE;

	foreach(gl, groupClause)
	{
		GroupClause *grpcl = lfirst(gl);

		if (tle->resdom->resno == grpcl->entry->resdom->resno)
		{
			if (contain_agg_clause((Node *) expr))
153
				elog(ERROR, "Aggregates not allowed in GROUP BY clause");
154 155 156 157
			return TRUE;
		}
	}

B
Bruce Momjian 已提交
158
	if (IsA(expr, Aggref))
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
		return TRUE;

	if (IsA(expr, Expr))
	{
		List	   *temp;

		foreach(temp, ((Expr *) expr)->args)
			if (!exprIsAggOrGroupCol(lfirst(temp), groupClause))
			return FALSE;
		return TRUE;
	}

	return FALSE;
}

/*
 * parseCheckAggregates -
 *	  this should really be done earlier but the current grammar
 *	  cannot differentiate functions from aggregates. So we have to check
 *	  here when the target list and the qualifications are finalized.
 *
 *	  Raises an error if an aggregate appears in the WHERE clause, or if
 *	  the target list or HAVING clause uses a column that is neither a
 *	  GROUP BY column nor inside an aggregate.  Must only be called when
 *	  pstate->p_hasAggs is set.
 */
void
parseCheckAggregates(ParseState *pstate, Query *qry)
{
	List	   *tl;

	Assert(pstate->p_hasAggs);

	/*
	 * aggregates never appear in WHERE clauses. (we have to check where
	 * clause first because if there is an aggregate, the check for
	 * non-group column in target list may fail.)
	 */
	if (contain_agg_clause(qry->qual))
		elog(ERROR, "Aggregates not allowed in WHERE clause");

	/*
	 * the target list can only contain aggregates, group columns and
	 * functions thereof.
	 */
	foreach(tl, qry->targetList)
	{
		TargetEntry *tle = lfirst(tl);

		if (!tleIsAggOrGroupCol(tle, qry->groupClause))
			elog(ERROR,
				 "Illegal use of aggregates or non-group column in target list");
	}

	/*
	 * the expression specified in the HAVING clause has the same
	 * restriction as those in the target list.
	 */
	if (!exprIsAggOrGroupCol(qry->havingQual, qry->groupClause))
		elog(ERROR,
			 "Illegal use of aggregates or non-group column in HAVING clause");
}


B
Bruce Momjian 已提交
220
Aggref *
221
ParseAgg(ParseState *pstate, char *aggname, Oid basetype,
222
		 List *target, int precedence)
223 224 225 226 227
{
	Oid			fintype;
	Oid			vartype;
	Oid			xfn1;
	Form_pg_aggregate aggform;
B
Bruce Momjian 已提交
228
	Aggref	   *aggref;
229
	HeapTuple	theAggTuple;
230
	bool		usenulls = false;
231

232 233
	theAggTuple = SearchSysCacheTuple(AGGNAME,
									  PointerGetDatum(aggname),
234 235 236
									  ObjectIdGetDatum(basetype),
									  0, 0);
	if (!HeapTupleIsValid(theAggTuple))
237
		elog(ERROR, "Aggregate %s does not exist", aggname);
238

239
	/*
240
	 * We do a major hack for count(*) here.
241
	 *
242 243 244 245 246
	 * Count(*) poses several problems.  First, we need a field that is
	 * guaranteed to be in the range table, and unique.  Using a constant
	 * causes the optimizer to properly remove the aggragate from any
	 * elements of the query. Using just 'oid', which can not be null, in
	 * the parser fails on:
247
	 *
248 249
	 * select count(*) from tab1, tab2	   -- oid is not unique select
	 * count(*) from viewtable		-- views don't have real oids
250
	 *
251 252 253
	 * So, for an aggregate with parameter '*', we use the first valid range
	 * table entry, and pick the first column from the table. We set a
	 * flag to count nulls, because we could have nulls in that column.
B
Bruce Momjian 已提交
254
	 *
255
	 * It's an ugly job, but someone has to do it. bjm 1998/1/18
B
Bruce Momjian 已提交
256
	 */
257

258 259
	if (nodeTag(lfirst(target)) == T_Const)
	{
260 261
		Const	   *con = (Const *) lfirst(target);

262 263
		if (con->consttype == UNKNOWNOID && VARSIZE(con->constvalue) == VARHDRSZ)
		{
264 265 266
			Attr	   *attr = makeNode(Attr);
			List	   *rtable,
					   *rlist;
267 268 269 270 271 272 273 274
			RangeTblEntry *first_valid_rte;

			Assert(lnext(target) == NULL);

			if (pstate->p_is_rule)
				rtable = lnext(lnext(pstate->p_rtable));
			else
				rtable = pstate->p_rtable;
275

276 277 278 279
			first_valid_rte = NULL;
			foreach(rlist, rtable)
			{
				RangeTblEntry *rte = lfirst(rlist);
280

281 282 283 284
				/* only entries on outer(non-function?) scope */
				if (!rte->inFromCl && rte != pstate->p_target_rangetblentry)
					continue;

B
Bruce Momjian 已提交
285
				first_valid_rte = rte;
286 287 288
				break;
			}
			if (first_valid_rte == NULL)
289
				elog(ERROR, "Can't find column to do aggregate(*) on.");
290

291 292
			attr->relname = first_valid_rte->refname;
			attr->attrs = lcons(makeString(
293
						   get_attname(first_valid_rte->relid, 1)), NIL);
294 295 296 297 298

			lfirst(target) = transformExpr(pstate, (Node *) attr, precedence);
			usenulls = true;
		}
	}
299

300 301 302 303 304 305 306 307
	aggform = (Form_pg_aggregate) GETSTRUCT(theAggTuple);
	fintype = aggform->aggfinaltype;
	xfn1 = aggform->aggtransfn1;

	/* only aggregates with transfn1 need a base type */
	if (OidIsValid(xfn1))
	{
		basetype = aggform->aggbasetype;
308 309
		if (nodeTag(lfirst(target)) == T_Var)
			vartype = ((Var *) lfirst(target))->vartype;
310
		else
311
			vartype = ((Expr *) lfirst(target))->typeOid;
312

313 314
		if ((basetype != vartype)
			&& (! IS_BINARY_COMPATIBLE(basetype, vartype)))
315 316 317 318 319 320
		{
			Type		tp1,
						tp2;

			tp1 = typeidType(basetype);
			tp2 = typeidType(vartype);
321 322 323
			elog(ERROR, "Aggregate type mismatch"
						"\n\t%s() works on %s, not on %s",
						 aggname, typeTypeName(tp1), typeTypeName(tp2));
324 325 326
		}
	}

B
Bruce Momjian 已提交
327 328 329 330
	aggref = makeNode(Aggref);
	aggref->aggname = pstrdup(aggname);
	aggref->basetype = aggform->aggbasetype;
	aggref->aggtype = fintype;
331

B
Bruce Momjian 已提交
332
	aggref->target = lfirst(target);
333
	if (usenulls)
B
Bruce Momjian 已提交
334
		aggref->usenulls = true;
335

336 337
	pstate->p_hasAggs = true;

B
Bruce Momjian 已提交
338
	return aggref;
339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
}

/*
 * agg_error -
 *	  Error message when aggregate lookup fails that gives details of the
 *	  basetype.  Always raises elog(ERROR); 'caller' prefixes the message.
 *
 * A basetypeID that is Invalid (zero) means the aggregate was looked up
 * over all types (as for count), and the message says so instead of
 * naming a type.
 */
void
agg_error(char *caller, char *aggname, Oid basetypeID)
{
	if (basetypeID == InvalidOid)
		elog(ERROR, "%s: aggregate '%s' for all types does not exist", caller, aggname);
	else
		elog(ERROR, "%s: aggregate '%s' for '%s' does not exist", caller, aggname,
			 typeidTypeName(basetypeID));
}