提交 ec5fbce4 编写于 作者: A Alexey Milovidov

Merge branch 'add_topkweighed' of https://github.com/andrewgolman/ClickHouse...

Merge branch 'add_topkweighed' of https://github.com/andrewgolman/ClickHouse into andrewgolman-add_topkweighed
......@@ -48,16 +48,21 @@ static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_ty
return new AggregateFunctionTopKGeneric<false>(threshold, argument_type);
}
template<bool weighed>
AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params)
{
assertUnary(name, argument_types);
if (!weighed)
assertUnary(name, argument_types);
else
assertBinary(name, argument_types);
UInt64 threshold = 10; /// default value
if (!params.empty())
{
if (params.size() != 1)
throw Exception("Aggregate function " + name + " requires one parameter or less.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
throw Exception("Aggregate function " + name + " requires one parameter or less.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
UInt64 k = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
......@@ -72,7 +77,12 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
threshold = k;
}
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK>(*argument_types[0], threshold));
AggregateFunctionPtr res;
if (!weighed)
res = AggregateFunctionPtr(createWithNumericType<AggregateFunctionTopK>(*argument_types[0], threshold));
else
res = AggregateFunctionPtr(createWithNumericType<AggregateFunctionTopKWeighed>(*argument_types[0], threshold));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes(argument_types[0], threshold));
......@@ -88,7 +98,8 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
void registerAggregateFunctionTopK(AggregateFunctionFactory & factory)
{
factory.registerFunction("topK", createAggregateFunctionTopK);
factory.registerFunction("topK", createAggregateFunctionTopK<false>);
factory.registerFunction("topKWeighed", createAggregateFunctionTopK<true>);
}
}
......@@ -37,14 +37,12 @@ struct AggregateFunctionTopKData
Set value;
};
template <typename T>
class AggregateFunctionTopK
: public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T>>
{
private:
protected:
using State = AggregateFunctionTopKData<T>;
UInt64 threshold;
UInt64 reserved;
......@@ -108,6 +106,25 @@ public:
};
template <typename T>
class AggregateFunctionTopKWeighed : public AggregateFunctionTopK<T>
{
public:
AggregateFunctionTopKWeighed(UInt64 threshold)
: AggregateFunctionTopK<T>(threshold) {}
String getName() const override { return "topKWeighed"; }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
auto & set = this->data(place).value;
if (set.capacity() != AggregateFunctionTopK<T>::reserved)
set.resize(AggregateFunctionTopK<T>::reserved);
set.insert(static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T> &>(*columns[1]).getData()[row_num]);
}
};
/// Generic implementation, it uses serialized representation as object descriptor.
struct AggregateFunctionTopKGenericData
{
......@@ -226,7 +243,6 @@ public:
const char * getHeaderFilePath() const override { return __FILE__; }
};
#undef TOP_K_LOAD_FACTOR
}
......@@ -4,6 +4,8 @@
1 6 3 3
1 6 [3,2]
1 6 [3,2]
1 6 [3,2]
1 6 [3,2]
1 0.5
1 0.5
1 0.1
......
......@@ -73,6 +73,21 @@ select k, sum(c), topKMerge(2)(x) from test.summing_merge_tree_aggregate_functio
drop table test.summing_merge_tree_aggregate_function;
---- sum + topKWeighed
create table test.summing_merge_tree_aggregate_function (d materialized today(), k UInt64, c UInt64, x AggregateFunction(topKWeighed(2), UInt8, UInt8)) engine=SummingMergeTree(d, k, 8192);
insert into test.summing_merge_tree_aggregate_function select 1, 1, topKWeighedState(2)(1, 1);
insert into test.summing_merge_tree_aggregate_function select 1, 1, topKWeighedState(2)(1, 1);
insert into test.summing_merge_tree_aggregate_function select 1, 1, topKWeighedState(2)(1, 1);
insert into test.summing_merge_tree_aggregate_function select 1, 1, topKWeighedState(2)(2, 2);
insert into test.summing_merge_tree_aggregate_function select 1, 1, topKWeighedState(2)(2, 2);
insert into test.summing_merge_tree_aggregate_function select 1, 1, topKWeighedState(2)(3, 5);
select k, sum(c), topKWeighedMerge(2)(x) from test.summing_merge_tree_aggregate_function group by k;
optimize table test.summing_merge_tree_aggregate_function;
select k, sum(c), topKWeighedMerge(2)(x) from test.summing_merge_tree_aggregate_function group by k;
drop table test.summing_merge_tree_aggregate_function;
---- avg
create table test.summing_merge_tree_aggregate_function (d materialized today(), k UInt64, x AggregateFunction(avg, Float64)) engine=SummingMergeTree(d, k, 8192);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册