未验证 提交 85396fc8 编写于 作者: M Maksim Kita 提交者: GitHub

Merge pull request #20057 from rf/rf/add-deltasum

Add `deltaSum` aggregate function, docs & test
---
toc_priority: 141
---
# deltaSum {#agg_functions-deltasum}
Syntax: `deltaSum(value)`
Sums the arithmetic differences between consecutive rows. If a difference is negative, it is ignored.
`value` must be of an integer or floating-point type.
Example:
```sql
select deltaSum(arrayJoin([1, 2, 3])); -- => 2
select deltaSum(arrayJoin([1, 2, 3, 0, 3, 4, 2, 3])); -- => 7
select deltaSum(arrayJoin([2.25, 3, 4.5])); -- => 2.25
```
#include <AggregateFunctions/AggregateFunctionDeltaSum.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
namespace
{
/// Factory callback that builds a `deltaSum` aggregate function instance.
///
/// The function accepts no parameters and exactly one argument, which has to be
/// of an integer or floating point type.
///
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the argument count is not one and
/// ILLEGAL_TYPE_OF_ARGUMENT when the argument is neither an integer nor a float.
AggregateFunctionPtr createAggregateFunctionDeltaSum(
    const String & name,
    const DataTypes & arguments,
    const Array & params)
{
    assertNoParameters(name, params);

    if (arguments.size() != 1)
        throw Exception("Incorrect number of arguments for aggregate function " + name,
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    const DataTypePtr & argument_type = arguments[0];

    /// Reject unsupported types up front so that the happy path is the last statement.
    const bool is_supported_type = isInteger(argument_type) || isFloat(argument_type);
    if (!is_supported_type)
        throw Exception("Illegal type " + arguments[0]->getName() + " of argument for aggregate function " + name,
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    /// Dispatch on the concrete numeric type of the argument.
    return AggregateFunctionPtr(createWithNumericType<AggregationFunctionDeltaSum>(
        *argument_type, arguments, params));
}
}
/// Registers the `deltaSum` aggregate function in the given factory.
void registerAggregateFunctionDeltaSum(AggregateFunctionFactory & factory)
{
    /// The function is flagged as order dependent and as returning the default value
    /// when only NULLs are seen.
    const AggregateFunctionProperties delta_sum_properties{ .returns_default_when_only_null = true, .is_order_dependent = true };

    factory.registerFunction("deltaSum", { createAggregateFunctionDeltaSum, delta_sum_properties });
}
}
#pragma once
#include <type_traits>
#include <experimental/type_traits>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
/// Aggregation state of `deltaSum` for a numeric value type T.
template <typename T>
struct AggregationFunctionDeltaSumData
{
    /// Accumulated sum of the positive differences seen so far.
    T sum{};
    /// Most recent value that was added to this state.
    T last{};
    /// Very first value that was added to this state.
    T first{};
    /// True once `last` holds a real value.
    bool seen_last{false};
    /// True once `first` holds a real value.
    bool seen_first{false};
};
/** Aggregate function `deltaSum(value)`.
  *
  * Sums the differences between consecutive values; a difference that is not
  * positive is ignored. The state keeps the running sum together with the first
  * and the last value that were added, so that two states can be merged.
  *
  * The result depends on the order in which rows are added and states are merged.
  */
template <typename T>
class AggregationFunctionDeltaSum final
    : public IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>
{
public:
    AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
        : IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params}
    {}

    AggregationFunctionDeltaSum()
        : IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{}
    {}

    String getName() const override { return "deltaSum"; }

    /// The result has the same numeric type as the argument.
    DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<T>>(); }

    /// Adds the value at `row_num` of the first argument column to the state.
    /// NOTE(review): NO_SANITIZE_UNDEFINED is presumably here because the subtraction and
    /// the accumulation may overflow for integer T - confirm.
    void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        auto value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
        auto & state = this->data(place);

        // Only a strictly positive step from the previous value contributes to the sum.
        if ((state.last < value) && state.seen_last)
        {
            state.sum += (value - state.last);
        }

        state.last = value;
        state.seen_last = true;

        // Remember the very first value; merge() needs it to compute the delta between states.
        if (!state.seen_first)
        {
            state.first = value;
            state.seen_first = true;
        }
    }

    /// Merges the state `rhs` into the state `place`.
    /// The states are assumed to be sorted by time, i.e. all values of `place` precede all
    /// values of `rhs`.
    void NO_SANITIZE_UNDEFINED ALWAYS_INLINE merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        auto place_data = &this->data(place);
        auto rhs_data = &this->data(rhs);

        // An empty rhs contributes nothing, the lhs stays as it is.
        if (!rhs_data->seen_first)
            return;

        // The lhs is an empty state and the rhs does have some state, so we just take that state.
        // This must only happen for an empty lhs: overwriting a populated lhs would lose its sum.
        if (!place_data->seen_last)
        {
            place_data->first = rhs_data->first;
            place_data->seen_first = rhs_data->seen_first;
            place_data->last = rhs_data->last;
            place_data->seen_last = rhs_data->seen_last;
            place_data->sum = rhs_data->sum;
            return;
        }

        if (place_data->last < rhs_data->first)
        {
            // The lhs last number seen is less than the first number the rhs saw, for example
            // [0, 2] [4, 7]. So we want to add the deltasums, but also add the difference
            // between lhs last number and rhs first number (the 2 and 4).
            place_data->sum += rhs_data->sum + (rhs_data->first - place_data->last);
        }
        else
        {
            // The step from the lhs last number to the rhs first number is not positive, e.g.
            // [4, 6] [1, 2] or [0, 5] [5, 7]. Exactly as in add(), such a step (a counter reset
            // or a repeated value) is ignored, so only the rhs deltasum is added.
            place_data->sum += rhs_data->sum;
        }

        // The rhs follows the lhs, so the merged state ends where the rhs ends,
        // while `first` stays the first value of the lhs.
        place_data->last = rhs_data->last;
    }

    /// Writes the state as: sum, first, last, seen_first, seen_last.
    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        writeIntBinary(this->data(place).sum, buf);
        writeIntBinary(this->data(place).first, buf);
        writeIntBinary(this->data(place).last, buf);
        writePODBinary<bool>(this->data(place).seen_first, buf);
        writePODBinary<bool>(this->data(place).seen_last, buf);
    }

    /// Reads the state in exactly the order in which serialize() writes it.
    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
    {
        readIntBinary(this->data(place).sum, buf);
        readIntBinary(this->data(place).first, buf);
        readIntBinary(this->data(place).last, buf);
        readPODBinary<bool>(this->data(place).seen_first, buf);
        readPODBinary<bool>(this->data(place).seen_last, buf);
    }

    /// Appends the accumulated sum to the result column.
    void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
    {
        assert_cast<ColumnVector<T> &>(to).getData().push_back(this->data(place).sum);
    }
};
}
......@@ -11,6 +11,7 @@ class AggregateFunctionFactory;
void registerAggregateFunctionAvg(AggregateFunctionFactory &);
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory &);
void registerAggregateFunctionCount(AggregateFunctionFactory &);
void registerAggregateFunctionDeltaSum(AggregateFunctionFactory &);
void registerAggregateFunctionGroupArray(AggregateFunctionFactory &);
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory &);
void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory &);
......@@ -66,6 +67,7 @@ void registerAggregateFunctions()
registerAggregateFunctionAvg(factory);
registerAggregateFunctionAvgWeighted(factory);
registerAggregateFunctionCount(factory);
registerAggregateFunctionDeltaSum(factory);
registerAggregateFunctionGroupArray(factory);
registerAggregateFunctionGroupUniqArray(factory);
registerAggregateFunctionGroupArrayInsertAt(factory);
......
......@@ -19,6 +19,7 @@ SRCS(
AggregateFunctionCategoricalInformationValue.cpp
AggregateFunctionCombinatorFactory.cpp
AggregateFunctionCount.cpp
AggregateFunctionDeltaSum.cpp
AggregateFunctionDistinct.cpp
AggregateFunctionEntropy.cpp
AggregateFunctionFactory.cpp
......
-- Strictly increasing input: every step counts (1 + 1).
select deltaSum(arrayJoin([1, 2, 3]));
-- The drop from 3 to 0 is ignored, the rise afterwards counts again (1 + 1 + 3 + 1).
select deltaSum(arrayJoin([1, 2, 3, 0, 3, 4]));
-- Two drops (3 -> 0 and 4 -> 2) are ignored (1 + 1 + 3 + 1 + 1).
select deltaSum(arrayJoin([1, 2, 3, 0, 3, 4, 2, 3]));
-- Repeated values produce zero differences and do not change the sum.
select deltaSum(arrayJoin([1, 2, 3, 0, 3, 3, 3, 3, 3, 4, 2, 3]));
-- Same as above, with repeated values right after the drop as well.
select deltaSum(arrayJoin([1, 2, 3, 0, 0, 0, 0, 3, 3, 3, 3, 3, 4, 2, 3]));
-- Merge of two states where the first state ends below the start of the second one:
-- the gap between the states (1 -> 4) is expected to be counted too.
-- NOTE(review): the result depends on the order in which the states are merged - confirm it is deterministic here.
select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([0, 1])) as rows union all select deltaSumState(arrayJoin([4, 5])) as rows);
-- Merge of two states listed in the opposite order (higher values first).
-- NOTE(review): the result depends on the order in which the states are merged - confirm it is deterministic here.
select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([4, 5])) as rows union all select deltaSumState(arrayJoin([0, 1])) as rows);
-- Floating point input (0.75 + 1.5).
select deltaSum(arrayJoin([2.25, 3, 4.5]));
-- Merge of two floating point states.
select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([0.1, 0.3, 0.5])) as rows union all select deltaSumState(arrayJoin([4.1, 5.1, 6.6])) as rows);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册