Commit a6432a4d authored by Adam Lee

Serialize the aggstates while spilling hash table

AggStates are now pointers allocated in aggcontext with type INTERNAL;
just spilling the pointers does not decrease memory usage, and combining
states without freeing them can leak memory.

This commit serializes the aggstates, writes the real data into the
file, and frees the memory.

(cherry picked from commit 6dadce04)
Parent 53d428c6
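In essence the change replaces "spill the pointer" with "serialize, spill the bytes, free". Below is a minimal, self-contained C sketch of that pattern; the types and functions are toy stand-ins, not the Greenplum API — the real code uses the aggregate's serialfn/deserialfn and BufFile rather than memcpy and stdio.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Toy transition state; in the real code this is an INTERNAL pointer. */
typedef struct AvgState
{
	long	count;
	double	sum;
} AvgState;

/* Toy "serialfn": flatten the in-memory state into a byte buffer. */
static size_t
serialize(const AvgState *st, unsigned char *buf)
{
	memcpy(buf, st, sizeof(*st));
	return sizeof(*st);
}

/* Spill: write the real data into the file, then free the memory. */
static void
spill_state(FILE *f, AvgState *st)
{
	unsigned char buf[sizeof(AvgState)];
	size_t	n = serialize(st, buf);

	fwrite(&n, sizeof(n), 1, f);
	fwrite(buf, 1, n, f);
	free(st);	/* without this, spilling saves no memory */
}

/* Toy "deserialfn": rebuild a state that can later be combined. */
static AvgState *
reload_state(FILE *f)
{
	size_t	n;
	AvgState *st = malloc(sizeof(AvgState));

	if (fread(&n, sizeof(n), 1, f) != 1 || n != sizeof(AvgState) ||
		fread(st, 1, n, f) != n)
	{
		free(st);
		return NULL;
	}
	return st;
}

int
main(void)
{
	FILE	 *f = tmpfile();
	AvgState *st = malloc(sizeof(AvgState));

	if (f == NULL)
		return 1;
	st->count = 3;
	st->sum = 6.0;
	spill_state(f, st);
	rewind(f);
	st = reload_state(f);
	if (st != NULL)
	{
		printf("avg = %g\n", st->sum / st->count);
		free(st);
	}
	fclose(f);
	return 0;
}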
@@ -190,14 +190,20 @@ calc_hash_value(AggState* aggstate, TupleTableSlot *inputslot)
*/
static inline void
adjustInputGroup(AggState *aggstate,
void *input_group)
void *input_group,
bool temporary)
{
int32 tuple_size;
void *datum;
AggStatePerGroup pergroup;
AggStatePerAgg peragg = aggstate->peragg;
int aggno;
Size datum_size;
int16 byteaTranstypeLen = 0;
bool byteaTranstypeByVal = false;
/* INTERNAL aggtype is always set to BYTEA in cdbgrouping and upstream partial-aggregation */
get_typlenbyval(BYTEAOID, &byteaTranstypeLen, &byteaTranstypeByVal);
tuple_size = memtuple_get_size((MemTuple)input_group);
pergroup = (AggStatePerGroup) ((char *)input_group +
MAXALIGN(tuple_size));
@@ -209,11 +215,48 @@ adjustInputGroup(AggState *aggstate,
{
AggStatePerAgg peraggstate = &peragg[aggno];
AggStatePerGroup pergroupstate = &pergroup[aggno];
if (!peraggstate->transtypeByVal &&
!pergroupstate->transValueIsNull)
/* Skip null transValue */
if (pergroupstate->transValueIsNull)
continue;
/* Deserialize the aggregate states loaded from the spill file */
if (OidIsValid(peraggstate->deserialfn_oid))
{
FunctionCallInfoData _dsinfo;
FunctionCallInfo dsinfo = &_dsinfo;
MemoryContext oldContext;
InitFunctionCallInfoData(_dsinfo,
&peraggstate->deserialfn,
2,
InvalidOid,
(void *) aggstate, NULL);
dsinfo->arg[0] = PointerGetDatum(datum);
dsinfo->argnull[0] = pergroupstate->transValueIsNull;
/* Dummy second argument for type-safety reasons */
dsinfo->arg[1] = PointerGetDatum(NULL);
dsinfo->argnull[1] = false;
/*
* We run the deserialization function in the per-input-tuple memory
* context when its result is safe to drop afterwards.
*/
if (temporary)
oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
pergroupstate->transValue = FunctionCallInvoke(dsinfo);
if (temporary)
MemoryContextSwitchTo(oldContext);
datum_size = datumGetSize(PointerGetDatum(datum), byteaTranstypeByVal, byteaTranstypeLen);
Assert(MAXALIGN(datum_size) - datum_size <= MAXIMUM_ALIGNOF);
datum = (char *)datum + MAXALIGN(datum_size);
}
else if (!peraggstate->transtypeByVal)
{
Size datum_size;
pergroupstate->transValue = PointerGetDatum(datum);
datum_size = datumGetSize(pergroupstate->transValue,
peraggstate->transtypeByVal,
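The new temporary flag carries a lifetime decision into adjustInputGroup(); summarizing the two call sites of this commit as a comment:

/*
 * adjustInputGroup(aggstate, input_group, temporary):
 *
 *   makeHashAggEntryForGroup()        -> temporary = false
 *       The deserialized transValue becomes the new entry's own state,
 *       so it is allocated in the hash table's entry context and must
 *       outlive the current input tuple.
 *
 *   agg_hash_reload(), existing entry -> temporary = true
 *       The value is only an input to the combine function, so it is
 *       deserialized in the per-tuple context and reclaimed on the next
 *       tuple cycle instead of accumulating in aggcontext.
 */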
@@ -271,7 +314,7 @@ makeHashAggEntryForInput(AggState *aggstate, TupleTableSlot *inputslot, uint32 h
tup_len = 0;
aggs_len = aggstate->numaggs * sizeof(AggStatePerGroupData);
oldcxt = MemoryContextSwitchTo(hashtable->entry_cxt);
entry = getEmptyHashAggEntry(aggstate);
@@ -349,7 +392,7 @@ makeHashAggEntryForGroup(AggState *aggstate, void *tuple_and_aggs,
memcpy(copy_tuple_and_aggs, tuple_and_aggs, input_size);
oldcxt = MemoryContextSwitchTo(hashtable->entry_cxt);
entry = getEmptyHashAggEntry(aggstate);
entry->hashvalue = hashvalue;
entry->is_primodial = !(hashtable->is_spilling);
@@ -357,7 +400,7 @@ makeHashAggEntryForGroup(AggState *aggstate, void *tuple_and_aggs,
entry->next = NULL;
/* Initialize per group data */
adjustInputGroup(aggstate, entry->tuple_and_aggs);
adjustInputGroup(aggstate, entry->tuple_and_aggs, false);
MemoryContextSwitchTo(oldcxt);
@@ -1484,6 +1527,13 @@ writeHashEntry(AggState *aggstate, BatchFileInfo *file_info,
AggStatePerGroup pergroup;
int aggno;
AggStatePerAgg peragg = aggstate->peragg;
int32 aggstateSize = 0;
Size datum_size;
Datum serializedVal;
int16 byteaTranstypeLen = 0;
bool byteaTranstypeByVal = false;
/* INTERNAL aggtype is always set to BYTEA in cdbgrouping and upstream partial-aggregation */
get_typlenbyval(BYTEAOID, &byteaTranstypeLen, &byteaTranstypeByVal);
Assert(file_info != NULL);
Assert(file_info->wfile != NULL);
@@ -1496,21 +1546,6 @@ writeHashEntry(AggState *aggstate, BatchFileInfo *file_info,
aggstate->numaggs * sizeof(AggStatePerGroupData);
total_size = MAXALIGN(tuple_agg_size);
for (aggno = 0; aggno < aggstate->numaggs; aggno++)
{
AggStatePerAgg peraggstate = &peragg[aggno];
AggStatePerGroup pergroupstate = &pergroup[aggno];
if (!peraggstate->transtypeByVal &&
!pergroupstate->transValueIsNull)
{
Size datum_size = datumGetSize(pergroupstate->transValue,
peraggstate->transtypeByVal,
peraggstate->transtypeLen);
total_size += MAXALIGN(datum_size);
}
}
BufFileWriteOrError(file_info->wfile, (char *) &total_size, sizeof(total_size));
BufFileWriteOrError(file_info->wfile, entry->tuple_and_aggs, tuple_agg_size);
Assert(MAXALIGN(tuple_agg_size) - tuple_agg_size <= MAXIMUM_ALIGNOF);
@@ -1519,26 +1554,82 @@ writeHashEntry(AggState *aggstate, BatchFileInfo *file_info,
BufFileWriteOrError(file_info->wfile, padding_dummy, MAXALIGN(tuple_agg_size) - tuple_agg_size);
}
/* Write the transition aggstates */
for (aggno = 0; aggno < aggstate->numaggs; aggno++)
{
AggStatePerAgg peraggstate = &peragg[aggno];
AggStatePerGroup pergroupstate = &pergroup[aggno];
if (!peraggstate->transtypeByVal &&
!pergroupstate->transValueIsNull)
/* Skip null transValue */
if (pergroupstate->transValueIsNull)
continue;
/*
* If it has a serialization function, serialize it without checking
* transtypeByVal: the trans type is INTERNALOID, a pointer that is
* nevertheless marked as by-value.
*/
if (OidIsValid(peraggstate->serialfn_oid))
{
Size datum_size = datumGetSize(pergroupstate->transValue,
peraggstate->transtypeByVal,
peraggstate->transtypeLen);
FunctionCallInfoData fcinfo;
InitFunctionCallInfoData(fcinfo,
&peraggstate->serialfn,
1,
InvalidOid,
(void *) aggstate, NULL);
fcinfo.arg[0] = pergroupstate->transValue;
fcinfo.argnull[0] = pergroupstate->transValueIsNull;
serializedVal = FunctionCallInvoke(&fcinfo);
datum_size = datumGetSize(serializedVal, byteaTranstypeByVal, byteaTranstypeLen);
pfree(DatumGetPointer(pergroupstate->transValue));
BufFileWriteOrError(file_info->wfile,
DatumGetPointer(serializedVal), datum_size);
pfree(DatumGetPointer(serializedVal));
}
/* If it's a ByRef, write the data to the file */
else if (!peraggstate->transtypeByVal)
{
datum_size = datumGetSize(pergroupstate->transValue,
peraggstate->transtypeByVal,
peraggstate->transtypeLen);
BufFileWriteOrError(file_info->wfile,
DatumGetPointer(pergroupstate->transValue), datum_size);
Assert(MAXALIGN(datum_size) - datum_size <= MAXIMUM_ALIGNOF);
if (MAXALIGN(datum_size) - datum_size > 0)
{
BufFileWriteOrError(file_info->wfile,
padding_dummy, MAXALIGN(datum_size) - datum_size);
}
}
/* Otherwise it's a real ByVal, do nothing */
else
{
continue;
}
Assert(MAXALIGN(datum_size) - datum_size <= MAXIMUM_ALIGNOF);
if (MAXALIGN(datum_size) - datum_size > 0)
{
BufFileWriteOrError(file_info->wfile,
padding_dummy, MAXALIGN(datum_size) - datum_size);
}
aggstateSize += MAXALIGN(datum_size);
}
if (aggstateSize)
{
total_size += aggstateSize;
/* Rewind to write the correct total_size */
if (BufFileSeek(file_info->wfile, 0, -(aggstateSize + MAXALIGN(tuple_agg_size) + sizeof(total_size)), SEEK_CUR) != 0)
ereport(ERROR,
(errcode_for_file_access(),
errmsg("could not seek in hash agg temporary file: %m")));
BufFileWriteOrError(file_info->wfile, (char *) &total_size, sizeof(total_size));
/* Go back to the last offset */
if (BufFileSeek(file_info->wfile, 0, aggstateSize + MAXALIGN(tuple_agg_size), SEEK_CUR) != 0)
ereport(ERROR,
(errcode_for_file_access(),
errmsg("could not seek in hash agg temporary file: %m")));
}
return (total_size + sizeof(total_size) + sizeof(entry->hashvalue));
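The seek arithmetic above depends on the record layout that writeHashEntry() produces; reconstructed from this hunk (the hash value counted in the return statement is apparently written by the caller):

/*
 * Spill-file record layout (reconstructed from writeHashEntry):
 *
 *   [ total_size ]         sizeof(total_size) bytes, patched afterwards
 *   [ tuple + pergroups ]  tuple_agg_size bytes, padded to MAXALIGN
 *   [ per-agg datums ]     serialized (bytea) or by-ref transValues,
 *                          each padded to MAXALIGN; sum = aggstateSize
 *
 * total_size is only known after the serialization functions have run,
 * so the code seeks back
 *     aggstateSize + MAXALIGN(tuple_agg_size) + sizeof(total_size)
 * bytes to the start of the record, rewrites total_size, then seeks
 * forward aggstateSize + MAXALIGN(tuple_agg_size) bytes to resume
 * writing at the end of the record.
 */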
@@ -1808,6 +1899,7 @@ agg_hash_reload(AggState *aggstate)
hashkey, &isNew);
}
/* Combine it into the group state if it's not a new entry */
if (!isNew)
{
int aggno;
@@ -1816,7 +1908,12 @@ agg_hash_reload(AggState *aggstate)
setGroupAggs(hashtable, entry);
adjustInputGroup(aggstate, input);
/*
* Adjust the input in the per-tuple memory context: the value will be
* combined into the group state, so we don't need to keep the memory
* storing the transValue.
*/
adjustInputGroup(aggstate, input, true);
/* Advance the aggregates for the group by applying the combine function. */
for (aggno = 0; aggno < aggstate->numaggs; aggno++)
@@ -1829,6 +1926,7 @@ agg_hash_reload(AggState *aggstate)
fcinfo.arg[1] = input_pergroupstate[aggno].transValue;
fcinfo.argnull[1] = input_pergroupstate[aggno].transValueIsNull;
/* Combine to the transition aggstate */
pergroupstate->transValue =
invoke_agg_trans_func(aggstate,
peraggstate,
......
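When the reloaded group already exists in the hash table, its deserialized transValue is merged into the resident state through the combine function. In terms of the toy AvgState sketch near the top of this commit, the combine step amounts to:

/* Toy "combinefn": merge a reloaded state into the resident one. */
static void
combine(AvgState *group, const AvgState *input)
{
	group->count += input->count;
	group->sum += input->sum;
}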
@@ -560,8 +560,6 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
/* DISTINCT and/or ORDER BY case */
Assert(slot->PRIVATE_tts_nvalid == peraggstate->numInputs);
Assert(peraggstate->deserialfn_oid == InvalidOid);
/*
* If the transfn is strict, we want to check for nullity before
* storing the row in the sorter, to save space if there are a lot
@@ -613,7 +611,9 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
* deserialfn_oid will be set if we must deserialize the input state
* before calling the combine function
*/
if (OidIsValid(peraggstate->deserialfn_oid))
if (OidIsValid(peraggstate->deserialfn_oid) &&
(peraggstate->aggref->aggstage == AGGSTAGE_INTERMEDIATE ||
peraggstate->aggref->aggstage == AGGSTAGE_FINAL))
{
Datum serialized = fcinfo->arg[1];
bool serializednull = fcinfo->argnull[1];
@@ -887,8 +887,6 @@ finalize_aggregate(AggState *aggstate,
{
int numFinalArgs = peraggstate->numFinalArgs;
Assert(peraggstate->serialfn_oid == InvalidOid);
/* set up aggstate->curperagg for AggGetAggref() */
aggstate->curperagg = peraggstate;
@@ -927,7 +925,9 @@ finalize_aggregate(AggState *aggstate,
* serialfn_oid will be set if we must serialize the transvalue before
* returning it
*/
else if (OidIsValid(peraggstate->serialfn_oid))
else if (OidIsValid(peraggstate->serialfn_oid) &&
(peraggstate->aggref->aggstage == AGGSTAGE_INTERMEDIATE ||
peraggstate->aggref->aggstage == AGGSTAGE_PARTIAL))
{
/* Don't call a strict serialization function with NULL input. */
if (peraggstate->serialfn.fn_strict && pergroupstate->transValueIsNull)
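Taken together, the two staged conditions define which conversions each stage performs; summarizing the gating introduced above:

/*
 * Staged (de)serialization, as gated in this commit:
 *
 *   aggstage                deserialize input    serialize output
 *   AGGSTAGE_PARTIAL               no                  yes
 *   AGGSTAGE_INTERMEDIATE          yes                 yes
 *   AGGSTAGE_FINAL                 yes                 no
 *
 * The first stage consumes raw input rows, inner stages exchange
 * serialized (bytea) states, and the final stage produces the result.
 */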
@@ -2249,17 +2249,21 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
{
if (!OidIsValid(aggform->aggserialfn))
elog(ERROR, "serialfunc not provided for serialization aggregation");
peraggstate->serialfn_oid = aggform->aggserialfn;
}
if (OidIsValid(aggform->aggserialfn))
peraggstate->serialfn_oid = aggform->aggserialfn;
/* Likewise for deserialization functions */
if (aggref->aggstage == AGGSTAGE_INTERMEDIATE ||
aggref->aggstage == AGGSTAGE_FINAL)
{
if (!OidIsValid(aggform->aggdeserialfn))
elog(ERROR, "deserialfunc not provided for deserialization aggregation");
peraggstate->deserialfn_oid = aggform->aggdeserialfn;
}
if (OidIsValid(aggform->aggdeserialfn))
peraggstate->deserialfn_oid = aggform->aggdeserialfn;
}
if (OidIsValid(peraggstate->serialfn_oid))
......
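The net effect of the ExecInitAgg() change, inferred from the removed error checks and the staged conditions above:

/*
 * Before: a missing aggserialfn/aggdeserialfn at a stage that needs it
 *         raised elog(ERROR) during executor initialization.
 * After:  the OIDs are recorded whenever the catalog provides them, and
 *         the aggstage checks in advance_aggregates() and
 *         finalize_aggregate() decide at run time whether to call them.
 */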
@@ -89,6 +89,13 @@ create table aggspill (i int, j int, t text) distributed by (i);
insert into aggspill select i, i*2, i::text from generate_series(1, 10000) i;
insert into aggspill select i, i*2, i::text from generate_series(1, 100000) i;
insert into aggspill select i, i*2, i::text from generate_series(1, 1000000) i;
-- Test the spilling with serial/deserial functions involved
-- The transition type of numeric is internal, and hence it uses the serial/deserial functions when spilling
drop table if exists aggspill_numeric_avg;
create table aggspill_numeric_avg (a int, b int, c numeric) distributed by (a);
insert into aggspill_numeric_avg (select i, i + 1, i * 1.1111 from generate_series(1, 500000) as i);
insert into aggspill_numeric_avg (select i, i + 1, i * 1.1111 from generate_series(1, 500000) as i);
analyze aggspill_numeric_avg;
-- No spill with large statement memory
set statement_mem = '125MB';
select count(*) from (select i, count(*) from aggspill group by i,j having count(*) = 1) g;
@@ -112,6 +119,12 @@ select count(*) from (select i, count(*) from aggspill group by i,j having count
90000
(1 row)
select count(*) from (select a, avg(b), avg(c) from aggspill_numeric_avg group by a) g;
count
--------
500000
(1 row)
-- Reduce the statement memory, nbatches and entrysize even further to cause multiple overflows
set gp_hashagg_default_nbatches = 4;
set statement_mem = '5MB';
@@ -128,6 +141,12 @@ select count(*) from (select i, count(*) from aggspill group by i,j,t having cou
10000
(1 row)
select count(*) from (select a, avg(b), avg(c) from aggspill_numeric_avg group by a) g;
count
--------
500000
(1 row)
drop schema hashagg_spill cascade;
NOTICE: drop cascades to 3 other objects
DETAIL: drop cascades to function hashagg_spill.is_workfile_created(text)
......
@@ -74,6 +74,14 @@ insert into aggspill select i, i*2, i::text from generate_series(1, 10000) i;
insert into aggspill select i, i*2, i::text from generate_series(1, 100000) i;
insert into aggspill select i, i*2, i::text from generate_series(1, 1000000) i;
-- Test the spilling with serial/deserial functions involved
-- The transition type of numeric is internal, and hence it uses the serial/deserial functions when spilling
drop table if exists aggspill_numeric_avg;
create table aggspill_numeric_avg (a int, b int, c numeric) distributed by (a);
insert into aggspill_numeric_avg (select i, i + 1, i * 1.1111 from generate_series(1, 500000) as i);
insert into aggspill_numeric_avg (select i, i + 1, i * 1.1111 from generate_series(1, 500000) as i);
analyze aggspill_numeric_avg;
-- No spill with large statement memory
set statement_mem = '125MB';
select count(*) from (select i, count(*) from aggspill group by i,j having count(*) = 1) g;
@@ -83,6 +91,7 @@ set statement_mem = '10MB';
select overflows >= 1 from hashagg_spill.num_hashagg_overflows('explain analyze
select count(*) from (select i, count(*) from aggspill group by i,j having count(*) = 2) g') overflows;
select count(*) from (select i, count(*) from aggspill group by i,j having count(*) = 2) g;
select count(*) from (select a, avg(b), avg(c) from aggspill_numeric_avg group by a) g;
-- Reduce the statement memory, nbatches and entrysize even further to cause multiple overflows
set gp_hashagg_default_nbatches = 4;
@@ -92,5 +101,6 @@ select overflows > 1 from hashagg_spill.num_hashagg_overflows('explain analyze
select count(*) from (select i, count(*) from aggspill group by i,j,t having count(*) = 3) g') overflows;
select count(*) from (select i, count(*) from aggspill group by i,j,t having count(*) = 3) g;
select count(*) from (select a, avg(b), avg(c) from aggspill_numeric_avg group by a) g;
drop schema hashagg_spill cascade;