提交 7ddede66 编写于 作者: A Alexey Milovidov

Added support for generic case for aggregate function topKWeighted; fixed errors #4245

上级 ec5fbce4
......@@ -15,6 +15,7 @@ namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ARGUMENT_OUT_OF_BOUND;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
......@@ -22,39 +23,51 @@ namespace
{
/// Substitute return type for Date and DateTime
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType>
template <bool is_weighted>
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>
{
using AggregateFunctionTopK<DataTypeDate::FieldType>::AggregateFunctionTopK;
using AggregateFunctionTopK<DataTypeDate::FieldType, is_weighted>::AggregateFunctionTopK;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
};
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType>
template <bool is_weighted>
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>
{
using AggregateFunctionTopK<DataTypeDateTime::FieldType>::AggregateFunctionTopK;
using AggregateFunctionTopK<DataTypeDateTime::FieldType, is_weighted>::AggregateFunctionTopK;
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
};
template <bool is_weighted>
static IAggregateFunction * createWithExtraTypes(const DataTypePtr & argument_type, UInt64 threshold)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return new AggregateFunctionTopKDate(threshold);
if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTopKDateTime(threshold);
if (which.idx == TypeIndex::Date)
return new AggregateFunctionTopKDate<is_weighted>(threshold);
if (which.idx == TypeIndex::DateTime)
return new AggregateFunctionTopKDateTime<is_weighted>(threshold);
/// Check that we can use plain version of AggregateFunctionTopKGeneric
if (argument_type->isValueUnambiguouslyRepresentedInContiguousMemoryRegion())
return new AggregateFunctionTopKGeneric<true>(threshold, argument_type);
return new AggregateFunctionTopKGeneric<true, is_weighted>(threshold, argument_type);
else
return new AggregateFunctionTopKGeneric<false>(threshold, argument_type);
return new AggregateFunctionTopKGeneric<false, is_weighted>(threshold, argument_type);
}
template<bool weighed>
template <bool is_weighted>
AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (!weighed)
if (!is_weighted)
{
assertUnary(name, argument_types);
}
else
{
assertBinary(name, argument_types);
if (!isNumber(argument_types[1]))
throw Exception("The second argument for aggregate function 'topKWeighted' must have numeric type", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
UInt64 threshold = 10; /// default value
......@@ -77,15 +90,10 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
threshold = k;
}
AggregateFunctionPtr res;
if (!weighed)
res = AggregateFunctionPtr(createWithNumericType<AggregateFunctionTopK>(*argument_types[0], threshold));
else
res = AggregateFunctionPtr(createWithNumericType<AggregateFunctionTopKWeighed>(*argument_types[0], threshold));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK, is_weighted>(*argument_types[0], threshold));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes(argument_types[0], threshold));
res = AggregateFunctionPtr(createWithExtraTypes<is_weighted>(argument_types[0], threshold));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() +
......@@ -99,7 +107,7 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
void registerAggregateFunctionTopK(AggregateFunctionFactory & factory)
{
factory.registerFunction("topK", createAggregateFunctionTopK<false>);
factory.registerFunction("topKWeighed", createAggregateFunctionTopK<true>);
factory.registerFunction("topKWeighted", createAggregateFunctionTopK<true>);
}
}
......@@ -37,9 +37,10 @@ struct AggregateFunctionTopKData
Set value;
};
template <typename T>
template <typename T, bool is_weighted>
class AggregateFunctionTopK
: public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T>>
: public IAggregateFunctionDataHelper<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T, is_weighted>>
{
protected:
using State = AggregateFunctionTopKData<T>;
......@@ -50,7 +51,7 @@ public:
AggregateFunctionTopK(UInt64 threshold)
: threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold) {}
String getName() const override { return "topK"; }
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
DataTypePtr getReturnType() const override
{
......@@ -62,7 +63,11 @@ public:
auto & set = this->data(place).value;
if (set.capacity() != reserved)
set.resize(reserved);
set.insert(static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
if constexpr (is_weighted)
set.insert(static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num));
else
set.insert(static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
......@@ -106,25 +111,6 @@ public:
};
template <typename T>
class AggregateFunctionTopKWeighed : public AggregateFunctionTopK<T>
{
public:
AggregateFunctionTopKWeighed(UInt64 threshold)
: AggregateFunctionTopK<T>(threshold) {}
String getName() const override { return "topKWeighed"; }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
auto & set = this->data(place).value;
if (set.capacity() != AggregateFunctionTopK<T>::reserved)
set.resize(AggregateFunctionTopK<T>::reserved);
set.insert(static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<T> &>(*columns[1]).getData()[row_num]);
}
};
/// Generic implementation, it uses serialized representation as object descriptor.
struct AggregateFunctionTopKGenericData
{
......@@ -142,8 +128,8 @@ struct AggregateFunctionTopKGenericData
/** Template parameter with true value should be used for columns that store their elements in memory continuously.
* For such columns topK() can be implemented more efficiently (especially for small numeric arrays).
*/
template <bool is_plain_column = false>
class AggregateFunctionTopKGeneric : public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column>>
template <bool is_plain_column, bool is_weighted>
class AggregateFunctionTopKGeneric : public IAggregateFunctionDataHelper<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column, is_weighted>>
{
private:
using State = AggregateFunctionTopKGenericData;
......@@ -158,7 +144,7 @@ public:
AggregateFunctionTopKGeneric(UInt64 threshold, const DataTypePtr & input_data_type)
: threshold(threshold), reserved(TOP_K_LOAD_FACTOR * threshold), input_data_type(input_data_type) {}
String getName() const override { return "topK"; }
String getName() const override { return is_weighted ? "topKWeighted" : "topK"; }
DataTypePtr getReturnType() const override
{
......@@ -206,13 +192,19 @@ public:
if constexpr (is_plain_column)
{
set.insert(columns[0]->getDataAt(row_num));
if constexpr (is_weighted)
set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num));
else
set.insert(columns[0]->getDataAt(row_num));
}
else
{
const char * begin = nullptr;
StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin);
set.insert(str_serialized);
if constexpr (is_weighted)
set.insert(str_serialized, columns[1]->getUInt(row_num));
else
set.insert(str_serialized);
arena->rollback(str_serialized.size);
}
}
......
......@@ -33,6 +33,19 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
return nullptr;
}
template <template <typename, bool> class AggregateFunctionTemplate, bool bool_param, typename... TArgs>
static IAggregateFunction * createWithNumericType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<TYPE, bool_param>(std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<Int8, bool_param>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<Int16, bool_param>(std::forward<TArgs>(args)...);
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename... TArgs>
static IAggregateFunction * createWithNumericType(const IDataType & argument_type, TArgs && ... args)
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册