提交 d7a36519 编写于 作者: A Alexey Milovidov 提交者: alexey-milovidov

Simplification of aggregate functions: development [#CLICKHOUSE-2].

上级 b3f4d439
......@@ -71,6 +71,7 @@ list (APPEND dbms_headers src/Functions/IFunction.h src/Functions/FunctionFactor
list (APPEND dbms_sources
src/AggregateFunctions/AggregateFunctionFactory.cpp
src/AggregateFunctions/FactoryHelpers.cpp
src/AggregateFunctions/AggregateFunctionState.cpp
src/AggregateFunctions/AggregateFunctionArray.cpp
src/AggregateFunctions/AggregateFunctionNull.cpp
......@@ -84,6 +85,7 @@ list (APPEND dbms_sources
list (APPEND dbms_headers
src/AggregateFunctions/IAggregateFunction.h
src/AggregateFunctions/AggregateFunctionFactory.h
src/AggregateFunctions/FactoryHelpers.h
src/AggregateFunctions/AggregateFunctionState.h
src/AggregateFunctions/AggregateFunctionArray.h
src/AggregateFunctions/AggregateFunctionNull.h
......
......@@ -42,7 +42,8 @@ AggregateFunctionPtr createAggregateFunctionForEach(AggregateFunctionPtr & neste
AggregateFunctionPtr createAggregateFunctionIf(AggregateFunctionPtr & nested, const DataTypes & argument_types);
AggregateFunctionPtr createAggregateFunctionState(AggregateFunctionPtr & nested, const DataTypes & argument_types, const Array & parameters);
AggregateFunctionPtr createAggregateFunctionMerge(const String & name, AggregateFunctionPtr & nested, const DataTypes & argument_types);
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested, const DataTypes & argument_types);
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested);
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested, const DataTypes & argument_types);
AggregateFunctionPtr createAggregateFunctionCountNotNull(const String & name, const DataTypes & argument_types, const Array & parameters);
AggregateFunctionPtr createAggregateFunctionNothing();
......@@ -120,7 +121,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
}
if (argument_types.size() == 1)
return createAggregateFunctionNullUnary(nested_function, argument_types);
return createAggregateFunctionNullUnary(nested_function);
else
return createAggregateFunctionNullVariadic(nested_function, argument_types);
}
......
......@@ -36,7 +36,7 @@ namespace
M(Float64)
template <template <typename> class Data, typename Name, bool returns_float, bool returns_many>
template <template <typename> class Data, typename Name, bool have_second_arg, bool returns_float, bool returns_many>
AggregateFunctionPtr createAggregateFunctionQuantile(const std::string & name, const DataTypes & argument_types, const Array & params)
{
assertUnary(name, argument_types);
......@@ -44,58 +44,21 @@ AggregateFunctionPtr createAggregateFunctionQuantile(const std::string & name, c
#define CREATE(TYPE) \
if (typeid_cast<const DataType ## TYPE *>(argument_type.get())) \
return std::make_shared<AggregateFunctionQuantile<TYPE, void, Data<TYPE>, Name, returns_float, returns_many>>(argument_type, params);
return std::make_shared<AggregateFunctionQuantile<TYPE, Data<TYPE>, Name, have_second_arg, returns_float, returns_many>>(argument_type, params);
FOR_NUMERIC_TYPES(CREATE)
#undef CREATE
if (typeid_cast<const DataTypeDate *>(argument_type.get()))
return std::make_shared<AggregateFunctionQuantile<DataTypeDate::FieldType, void, Data<DataTypeDate::FieldType>, Name, false, returns_many>>(argument_type, params);
return std::make_shared<AggregateFunctionQuantile<
DataTypeDate::FieldType, Data<DataTypeDate::FieldType>, Name, have_second_arg, false, returns_many>>(argument_type, params);
if (typeid_cast<const DataTypeDateTime *>(argument_type.get()))
return std::make_shared<AggregateFunctionQuantile<DataTypeDateTime::FieldType, void, Data<DataTypeDateTime::FieldType>, Name, false, returns_many>>(argument_type, params);
return std::make_shared<AggregateFunctionQuantile<
DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>, Name, have_second_arg, false, returns_many>>(argument_type, params);
throw Exception("Illegal type " + argument_type->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
template <typename FirstArg, template <typename> class Data, typename Name, bool returns_float, bool returns_many>
AggregateFunctionPtr createAggregateFunctionQuantileTwoArgsForSecondArg(const std::string & name, const DataTypes & argument_types, const Array & params)
{
const DataTypePtr & second_argument_type = argument_types[0];
#define CREATE(TYPE) \
if (typeid_cast<const DataType ## TYPE *>(second_argument_type.get())) \
return std::make_shared<AggregateFunctionQuantile<FirstArg, TYPE, Data<FirstArg>, Name, returns_float, returns_many>>(argument_types[0], params);
FOR_NUMERIC_TYPES(CREATE)
#undef CREATE
throw Exception("Illegal type " + second_argument_type->getName() + " of second argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
template <template <typename> class Data, typename Name, bool returns_float, bool returns_many>
AggregateFunctionPtr createAggregateFunctionQuantileTwoArgs(const std::string & name, const DataTypes & argument_types, const Array & params)
{
assertBinary(name, argument_types);
const DataTypePtr & argument_type = argument_types[0];
#define CREATE(TYPE) \
if (typeid_cast<const DataType ## TYPE *>(argument_type.get())) \
return createAggregateFunctionQuantileTwoArgsForSecondArg<TYPE, Data, Name, returns_float, returns_many>(name, argument_types, params);
FOR_NUMERIC_TYPES(CREATE)
#undef CREATE
if (typeid_cast<const DataTypeDate *>(argument_type.get()))
return createAggregateFunctionQuantileTwoArgsForSecondArg<
DataTypeDate::FieldType, Data, Name, false, returns_many>(name, argument_types, params);
if (typeid_cast<const DataTypeDateTime *>(argument_type.get()))
return createAggregateFunctionQuantileTwoArgsForSecondArg<
DataTypeDateTime::FieldType, Data, Name, false, returns_many>(name, argument_types, params);
throw Exception("Illegal type " + argument_type->getName() + " of first argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
#undef FOR_NUMERIC_TYPES
......@@ -108,7 +71,7 @@ struct NameQuantileExact { static constexpr auto name = "quantileExact"; };
struct NameQuantileExactWeighted { static constexpr auto name = "quantileExactWeighted"; };
struct NameQuantilesExact { static constexpr auto name = "quantilesExact"; };
struct NameQuantilesExactWeighted { static constexpr auto name = "quantilesExactWeighted"; };
/*
struct NameQuantileTiming { static constexpr auto name = "quantileTiming"; };
struct NameQuantileTimingWeighted { static constexpr auto name = "quantileTimingWeighted"; };
struct NameQuantilesTiming { static constexpr auto name = "quantilesTiming"; };
......@@ -118,38 +81,76 @@ struct NameQuantileTDigest { static constexpr auto name = "quantileTDigest"; };
struct NameQuantileTDigestWeighted { static constexpr auto name = "quantileTDigestWeighted"; };
struct NameQuantilesTDigest { static constexpr auto name = "quantilesTDigest"; };
struct NameQuantilesTDigestWeighted { static constexpr auto name = "quantilesTDigestWeighted"; };
*/
}
void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory)
{
factory.registerFunction("quantile", createAggregateFunctionQuantile<QuantileReservoirSampler, NameQuantile, true, false>);
factory.registerFunction("quantiles", createAggregateFunctionQuantile<QuantileReservoirSampler, NameQuantiles, true, true>);
factory.registerFunction(NameQuantile::name,
createAggregateFunctionQuantile<QuantileReservoirSampler, NameQuantile, false, true, false>);
factory.registerFunction(NameQuantiles::name,
createAggregateFunctionQuantile<QuantileReservoirSampler, NameQuantiles, false, true, true>);
factory.registerFunction(NameQuantileDeterministic::name,
createAggregateFunctionQuantile<QuantileReservoirSamplerDeterministic, NameQuantileDeterministic, true, true, false>);
factory.registerFunction(NameQuantilesDeterministic::name,
createAggregateFunctionQuantile<QuantileReservoirSamplerDeterministic, NameQuantilesDeterministic, true, true, true>);
factory.registerFunction(NameQuantileExact::name,
createAggregateFunctionQuantile<QuantileExact, NameQuantileExact, false, false, false>);
factory.registerFunction(NameQuantilesExact::name,
createAggregateFunctionQuantile<QuantileExact, NameQuantilesExact, false, false, true>);
factory.registerFunction(NameQuantileExactWeighted::name,
createAggregateFunctionQuantile<QuantileExactWeighted, NameQuantileExactWeighted, true, false, false>);
factory.registerFunction(NameQuantilesExactWeighted::name,
createAggregateFunctionQuantile<QuantileExactWeighted, NameQuantilesExactWeighted, true, false, true>);
factory.registerFunction(NameQuantileTiming::name,
createAggregateFunctionQuantile<QuantileTiming, NameQuantileTiming, false, false, false>);
factory.registerFunction(NameQuantilesTiming::name,
createAggregateFunctionQuantile<QuantileTiming, NameQuantilesTiming, false, false, true>);
factory.registerFunction(NameQuantileTimingWeighted::name,
createAggregateFunctionQuantile<QuantileTiming, NameQuantileTimingWeighted, true, false, false>);
factory.registerFunction(NameQuantilesTimingWeighted::name,
createAggregateFunctionQuantile<QuantileTiming, NameQuantilesTimingWeighted, true, false, true>);
factory.registerFunction(NameQuantileTDigest::name,
createAggregateFunctionQuantile<QuantileTDigest, NameQuantileTDigest, false, true, false>);
factory.registerFunction(NameQuantilesTDigest::name,
createAggregateFunctionQuantile<QuantileTDigest, NameQuantilesTDigest, false, true, true>);
factory.registerFunction(NameQuantileTDigestWeighted::name,
createAggregateFunctionQuantile<QuantileTDigest, NameQuantileTDigestWeighted, true, true, false>);
factory.registerFunction(NameQuantilesTDigestWeighted::name,
createAggregateFunctionQuantile<QuantileTDigest, NameQuantilesTDigestWeighted, true, true, true>);
/// 'median' is an alias for 'quantile'
factory.registerFunction("median",
createAggregateFunctionQuantile<QuantileReservoirSampler, NameQuantile, false, true, false>);
factory.registerFunction("medianDeterministic",
createAggregateFunctionQuantile<QuantileReservoirSamplerDeterministic, NameQuantileDeterministic, true, true, false>);
factory.registerFunction("quantileDeterministic",
createAggregateFunctionQuantileTwoArgs<QuantileReservoirSamplerDeterministic, NameQuantileDeterministic, true, false>);
factory.registerFunction("quantilesDeterministic",
createAggregateFunctionQuantileTwoArgs<QuantileReservoirSamplerDeterministic, NameQuantilesDeterministic, true, true>);
factory.registerFunction("medianExact",
createAggregateFunctionQuantile<QuantileExact, NameQuantileExact, false, false, false>);
factory.registerFunction("quantileExact", createAggregateFunctionQuantile<QuantileExact, NameQuantileExact, false, false>);
factory.registerFunction("quantilesExact", createAggregateFunctionQuantile<QuantileExact, NameQuantilesExact, false, true>);
factory.registerFunction("medianExactWeighted",
createAggregateFunctionQuantile<QuantileExactWeighted, NameQuantileExactWeighted, true, false, false>);
factory.registerFunction("quantileExactWeighted", createAggregateFunctionQuantileTwoArgs<QuantileExactWeighted, NameQuantileExactWeighted, false, false>);
factory.registerFunction("quantilesExactWeighted", createAggregateFunctionQuantileTwoArgs<QuantileExactWeighted, NameQuantilesExactWeighted, false, true>);
/*
factory.registerFunction("quantileTiming", createAggregateFunctionQuantile<QuantileTiming, NameQuantileTiming, false, false>);
factory.registerFunction("quantilesTiming", createAggregateFunctionQuantile<QuantileTiming, NameQuantilesTiming, false, true>);
factory.registerFunction("medianTiming",
createAggregateFunctionQuantile<QuantileTiming, NameQuantileTiming, false, false, false>);
factory.registerFunction("quantileTimingWeighted", createAggregateFunctionQuantileTwoArgs<QuantileTiming, NameQuantileTimingWeighted, false, false>);
factory.registerFunction("quantilesTimingWeighted", createAggregateFunctionQuantileTwoArgs<QuantileTiming, NameQuantilesTimingWeighted, false, true>);
factory.registerFunction("medianTimingWeighted",
createAggregateFunctionQuantile<QuantileTiming, NameQuantileTimingWeighted, true, false, false>);
factory.registerFunction("quantileTDigest", createAggregateFunctionQuantile<QuantileTDigest, NameQuantileTDigest, true, false>);
factory.registerFunction("quantilesTDigest", createAggregateFunctionQuantile<QuantileTDigest, NameQuantilesTDigest, true, true>);
factory.registerFunction("medianTDigest",
createAggregateFunctionQuantile<QuantileTDigest, NameQuantileTDigest, false, true, false>);
factory.registerFunction("quantileTDigestWeighted", createAggregateFunctionQuantileTwoArgs<QuantileTDigest, NameQuantileTDigestWeighted, true, false>);
factory.registerFunction("quantilesTDigestWeighted", createAggregateFunctionQuantileTwoArgs<QuantileTDigest, NameQuantilesTDigestWeighted, true, true>);
*/
/// TODO Aliases
factory.registerFunction("medianTDigestWeighted",
createAggregateFunctionQuantile<QuantileTDigest, NameQuantileTDigestWeighted, true, true, false>);
}
}
......@@ -25,20 +25,20 @@ namespace ErrorCodes
/** Generic aggregate function for calculation of quantiles.
* It depends on quantile calculation data structure.
* It depends on quantile calculation data structure. Look at Quantile*.h for various implementations.
*/
template <
/// Type of first argument.
typename Value,
/// If the function accept second argument, the type of this argument
/// (in can be "weight" to calculate quantiles or "determinator" that is used instead of PRNG).
typename SecondArg,
/// Data structure and implementation of calculation. Look at QuantileExact.h for example.
typename Data,
/// Structure with static member "name", containing the name of aggregate function.
/// Structure with static member "name", containing the name of the aggregate function.
typename Name,
/// If true, the function accept second argument
/// (in can be "weight" to calculate quantiles or "determinator" that is used instead of PRNG).
/// Second argument is always obtained through 'getUInt' method.
bool have_second_arg,
/// If true, the function will return float with possibly interpolated results and NaN if there was no values.
/// Otherwise it will return Value type and default value if there was no values.
/// As an example, the function cannot return floats, if the SQL type of argument is Date or DateTime.
......@@ -48,13 +48,14 @@ template <
bool returns_many
>
class AggregateFunctionQuantile final : public IAggregateFunctionDataHelper<Data,
AggregateFunctionQuantile<Value, SecondArg, Data, Name, returns_float, returns_many>>
AggregateFunctionQuantile<Value, Data, Name, have_second_arg, returns_float, returns_many>>
{
private:
static constexpr bool have_second_arg = !std::is_same_v<SecondArg, void>;
QuantileLevels<Float64> levels;
/// Used when there are single level to get.
Float64 level = 0.5;
DataTypePtr argument_type;
public:
......@@ -87,7 +88,7 @@ public:
if constexpr (have_second_arg)
this->data(place).add(
static_cast<const ColumnVector<Value> &>(*columns[0]).getData()[row_num],
static_cast<const ColumnVector<SecondArg> &>(*columns[1]).getData()[row_num]);
columns[1]->getUInt(row_num));
else
this->data(place).add(
static_cast<const ColumnVector<Value> &>(*columns[0]).getData()[row_num]);
......@@ -100,6 +101,7 @@ public:
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
/// const_cast is required because some data structures apply finalizaton (like compactization) before serializing.
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
}
......@@ -110,6 +112,7 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
/// const_cast is required because some data structures apply finalizaton (like sorting) for obtain a result.
auto & data = this->data(const_cast<AggregateDataPtr>(place));
if constexpr (returns_many)
......
......@@ -4,23 +4,32 @@
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (!AggregateFunctionSequenceCount::sufficientArgs(argument_types.size()))
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
return std::make_shared<AggregateFunctionSequenceCount>();
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, pattern);
}
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (!AggregateFunctionSequenceMatch::sufficientArgs(argument_types.size()))
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
return std::make_shared<AggregateFunctionSequenceMatch>();
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, pattern);
}
}
......
......@@ -140,7 +140,38 @@ template <typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
{
public:
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
AggregateFunctionSequenceBase(const DataTypes & arguments, const String & pattern)
: pattern(pattern)
{
arg_count = arguments.size();
if (!sufficientArgs(arg_count))
throw Exception{"Aggregate function " + derived().getName() + " requires at least 3 arguments.",
ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > AggregateFunctionSequenceMatchData::max_events)
throw Exception{"Aggregate function " + derived().getName() + " supports up to " +
toString(AggregateFunctionSequenceMatchData::max_events) + " event arguments.",
ErrorCodes::TOO_MUCH_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = arguments.front().get();
if (!typeid_cast<const DataTypeDateTime *>(time_arg))
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ derived().getName() + ", must be DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = arguments[i].get();
if (!typeid_cast<const DataTypeUInt8 *>(cond_arg))
throw Exception{
"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
" of aggregate function " + derived().getName() + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
parsePattern();
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
......@@ -158,7 +189,7 @@ public:
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(data(rhs));
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
......@@ -197,6 +228,9 @@ private:
static constexpr size_t bytes_on_stack = 64;
using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
Derived & derived() { return static_cast<Derived &>(*this); }
void parsePattern()
{
......@@ -210,7 +244,7 @@ private:
auto throw_exception = [&](const std::string & msg)
{
throw Exception{
msg + " '" + std::string(pos, end) + "' at position " + std::to_string(pos - begin),
msg + " '" + std::string(pos, end) + "' at position " + toString(pos - begin),
ErrorCodes::SYNTAX_ERROR};
};
......@@ -269,7 +303,7 @@ private:
if (event_number > arg_count - 1)
throw Exception{
"Event number " + std::to_string(event_number) + " is out of range",
"Event number " + toString(event_number) + " is out of range",
ErrorCodes::BAD_ARGUMENTS};
actions.emplace_back(PatternActionType::SpecificEvent, event_number - 1);
......@@ -431,6 +465,8 @@ private:
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
{
public:
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt8>(); }
......@@ -453,6 +489,8 @@ public:
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
{
public:
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeUInt64>(); }
......
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/AggregateFunctionStatistics.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionVarPop(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionVarPop(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionVarPop>(*argument_types[0]));
......@@ -21,10 +28,10 @@ AggregateFunctionPtr createAggregateFunctionVarPop(const std::string & name, con
return res;
}
AggregateFunctionPtr createAggregateFunctionVarSamp(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionVarSamp(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionVarSamp>(*argument_types[0]));
......@@ -34,10 +41,10 @@ AggregateFunctionPtr createAggregateFunctionVarSamp(const std::string & name, co
return res;
}
AggregateFunctionPtr createAggregateFunctionStdDevPop(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionStdDevPop(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionStdDevPop>(*argument_types[0]));
......@@ -47,10 +54,10 @@ AggregateFunctionPtr createAggregateFunctionStdDevPop(const std::string & name,
return res;
}
AggregateFunctionPtr createAggregateFunctionStdDevSamp(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionStdDevSamp(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionStdDevSamp>(*argument_types[0]));
......@@ -60,10 +67,10 @@ AggregateFunctionPtr createAggregateFunctionStdDevSamp(const std::string & name,
return res;
}
AggregateFunctionPtr createAggregateFunctionCovarPop(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionCovarPop(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 2)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<AggregateFunctionCovarPop>(*argument_types[0], *argument_types[1]));
if (!res)
......@@ -73,10 +80,10 @@ AggregateFunctionPtr createAggregateFunctionCovarPop(const std::string & name, c
return res;
}
AggregateFunctionPtr createAggregateFunctionCovarSamp(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionCovarSamp(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 2)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<AggregateFunctionCovarSamp>(*argument_types[0], *argument_types[1]));
if (!res)
......@@ -87,10 +94,10 @@ AggregateFunctionPtr createAggregateFunctionCovarSamp(const std::string & name,
}
AggregateFunctionPtr createAggregateFunctionCorr(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionCorr(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (argument_types.size() != 2)
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res(createWithTwoNumericTypes<AggregateFunctionCorr>(*argument_types[0], *argument_types[1]));
if (!res)
......
......@@ -49,7 +49,8 @@ struct AggregateFunctionSumMapData
*/
template <typename T>
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<AggregateFunctionSumMapData<typename NearestFieldType<T>::Type>>
class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<typename NearestFieldType<T>::Type>, AggregateFunctionSumMap<T>>
{
private:
DataTypePtr keys_type;
......
......@@ -5,6 +5,12 @@
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
......
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionUniq.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
......@@ -12,6 +13,13 @@
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
......@@ -20,8 +28,10 @@ namespace
*/
template <typename Data, typename DataForVariadic>
AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const DataTypes & argument_types, const Array & params)
{
assertNoParameters(name, params);
if (argument_types.size() == 1)
{
const IDataType & argument_type = *argument_types[0];
......@@ -37,7 +47,7 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
else if (typeid_cast<const DataTypeString *>(&argument_type) || typeid_cast<const DataTypeFixedString *>(&argument_type))
return std::make_shared<AggregateFunctionUniq<String, Data>>();
else if (typeid_cast<const DataTypeTuple *>(&argument_type))
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, true>>();
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, true>>(argument_types);
else if (typeid_cast<const DataTypeUUID *>(&argument_type))
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data>>();
else
......@@ -52,7 +62,7 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
throw Exception("Tuple argument of function " + name + " must be the only argument",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, false>>();
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, false>>(argument_types);
}
else
throw Exception("Incorrect number of arguments for aggregate function " + name,
......@@ -60,8 +70,10 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
}
template <template <typename> class Data, typename DataForVariadic>
AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const DataTypes & argument_types, const Array & params)
{
assertNoParameters(name, params);
if (argument_types.size() == 1)
{
const IDataType & argument_type = *argument_types[0];
......@@ -77,7 +89,7 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
else if (typeid_cast<const DataTypeString *>(&argument_type) || typeid_cast<const DataTypeFixedString *>(&argument_type))
return std::make_shared<AggregateFunctionUniq<String, Data<String>>>();
else if (typeid_cast<const DataTypeTuple *>(&argument_type))
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, true>>();
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, true>>(argument_types);
else if (typeid_cast<const DataTypeUUID *>(&argument_type))
return std::make_shared<AggregateFunctionUniq<DataTypeUUID::FieldType, Data<DataTypeUUID::FieldType>>>();
else
......@@ -92,7 +104,7 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
throw Exception("Tuple argument of function " + name + " must be the only argument",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, false>>();
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, false>>(argument_types);
}
else
throw Exception("Incorrect number of arguments for aggregate function " + name,
......@@ -112,15 +124,6 @@ void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory)
factory.registerFunction("uniqExact",
createAggregateFunctionUniq<AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>);
factory.registerFunction("uniqCombinedRaw",
createAggregateFunctionUniq<AggregateFunctionUniqCombinedRawData, AggregateFunctionUniqCombinedRawData<UInt64>>);
factory.registerFunction("uniqCombinedLinearCounting",
createAggregateFunctionUniq<AggregateFunctionUniqCombinedLinearCountingData, AggregateFunctionUniqCombinedLinearCountingData<UInt64>>);
factory.registerFunction("uniqCombinedBiasCorrected",
createAggregateFunctionUniq<AggregateFunctionUniqCombinedBiasCorrectedData, AggregateFunctionUniqCombinedBiasCorrectedData<UInt64>>);
factory.registerFunction("uniqCombined",
createAggregateFunctionUniq<AggregateFunctionUniqCombinedData, AggregateFunctionUniqCombinedData<UInt64>>);
}
......
......@@ -378,7 +378,7 @@ public:
* But (for the possibility of effective implementation), you can not pass several arguments, among which there are tuples.
*/
template <typename Data, bool argument_is_tuple>
class AggregateFunctionUniqVariadic final : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data>, argument_is_tuple>
class AggregateFunctionUniqVariadic final : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data, argument_is_tuple>>
{
private:
static constexpr bool is_exact = std::is_same<Data, AggregateFunctionUniqExactData<String>>::value;
......@@ -386,8 +386,7 @@ private:
size_t num_args = 0;
public:
AggregateFunctionUniqVariadic(const DataTypes & arguments, UInt8 threshold)
: threshold(threshold)
AggregateFunctionUniqVariadic(const DataTypes & arguments)
{
if (argument_is_tuple)
num_args = typeid_cast<const DataTypeTuple &>(*arguments[0]).getElements().size();
......
......@@ -5,29 +5,50 @@
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, const DataTypes & argument_types, const Array & /*parameters*/)
static constexpr UInt8 uniq_upto_max_threshold = 100;
AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception("Aggregate function " + name + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
UInt64 threshold_param = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[0]);
if (threshold_param > uniq_upto_max_threshold)
throw Exception("Too large parameter for aggregate function " + name + ". Maximum: " + toString(uniq_upto_max_threshold),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);
UInt8 threshold = threshold_param;
if (argument_types.size() == 1)
{
const IDataType & argument_type = *argument_types[0];
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniqUpTo>(*argument_types[0]));
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionUniqUpTo>(*argument_types[0], threshold));
if (res)
return res;
else if (typeid_cast<const DataTypeDate *>(&argument_type))
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDate::FieldType>>();
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDate::FieldType>>(threshold);
else if (typeid_cast<const DataTypeDateTime*>(&argument_type))
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDateTime::FieldType>>();
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeDateTime::FieldType>>(threshold);
else if (typeid_cast<const DataTypeString*>(&argument_type) || typeid_cast<const DataTypeFixedString*>(&argument_type))
return std::make_shared<AggregateFunctionUniqUpTo<String>>();
return std::make_shared<AggregateFunctionUniqUpTo<String>>(threshold);
else if (typeid_cast<const DataTypeTuple *>(&argument_type))
return std::make_shared<AggregateFunctionUniqUpToVariadic<true>>();
return std::make_shared<AggregateFunctionUniqUpToVariadic<true>>(argument_types, threshold);
else if (typeid_cast<const DataTypeUUID *>(&argument_type))
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeUUID::FieldType>>();
return std::make_shared<AggregateFunctionUniqUpTo<DataTypeUUID::FieldType>>(threshold);
else
throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
......@@ -39,7 +60,7 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c
throw Exception("Tuple argument of function " + name + " must be the only argument",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return std::make_shared<AggregateFunctionUniqUpToVariadic<false>>();
return std::make_shared<AggregateFunctionUniqUpToVariadic<false>>(argument_types, threshold);
}
else
throw Exception("Incorrect number of arguments for aggregate function " + name,
......
......@@ -124,15 +124,19 @@ struct AggregateFunctionUniqUpToData<UInt128> : AggregateFunctionUniqUpToData<UI
}
};
constexpr UInt8 uniq_upto_max_threshold = 100;
template <typename T>
class AggregateFunctionUniqUpTo final : public IAggregateFunctionDataHelper<AggregateFunctionUniqUpToData<T>, AggregateFunctionUniqUpTo<T>>
{
private:
UInt8 threshold = 5; /// Default value if the parameter is not specified.
UInt8 threshold;
public:
AggregateFunctionUniqUpTo(UInt8 threshold)
: threshold(threshold)
{
}
size_t sizeOfData() const override
{
return sizeof(AggregateFunctionUniqUpToData<T>) + sizeof(T) * threshold;
......@@ -232,6 +236,8 @@ public:
{
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).size());
}
const char * getHeaderFilePath() const override { return __FILE__; }
};
......
......@@ -12,6 +12,7 @@ list(REMOVE_ITEM clickhouse_aggregate_functions_sources
AggregateFunctionMerge.cpp
AggregateFunctionCount.cpp
parseAggregateFunctionParameters.cpp
FactoryHelpers.cpp
)
list(REMOVE_ITEM clickhouse_aggregate_functions_headers
......@@ -26,6 +27,7 @@ list(REMOVE_ITEM clickhouse_aggregate_functions_headers
AggregateFunctionMerge.h
AggregateFunctionCount.h
parseAggregateFunctionParameters.h
FactoryHelpers.h
)
add_library(clickhouse_aggregate_functions ${clickhouse_aggregate_functions_sources})
......
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
......
......@@ -20,58 +20,58 @@ static IAggregateFunction * createAggregateFunctionSingleValue(const String & na
assertNoParameters(name, parameters);
assertUnary(name, argument_types);
const IDataType & argument_type = *argument_types[0];
const DataTypePtr & argument_type = argument_types[0];
if (typeid_cast<const DataTypeUInt8 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt8>>>(argument_type);
if (typeid_cast<const DataTypeUInt16 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt16>>>(argument_type);
if (typeid_cast<const DataTypeUInt32 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>>(argument_type);
if (typeid_cast<const DataTypeUInt64 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>>(argument_type);
if (typeid_cast<const DataTypeInt8 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int8>>>(argument_type);
if (typeid_cast<const DataTypeInt16 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int16>>>(argument_type);
if (typeid_cast<const DataTypeInt32 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int32>>>(argument_type);
if (typeid_cast<const DataTypeInt64 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>>(argument_type);
if (typeid_cast<const DataTypeFloat32 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Float32>>>(argument_type);
if (typeid_cast<const DataTypeFloat64 *>(&argument_type)) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Float64>>>(argument_type);
if (typeid_cast<const DataTypeDate *>(&argument_type))
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>;
if (typeid_cast<const DataTypeDateTime*>(&argument_type))
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>;
if (typeid_cast<const DataTypeString*>(&argument_type))
return new AggregateFunctionTemplate<Data<SingleValueDataString>>;
if (typeid_cast<const DataTypeUInt8 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt8>>>(argument_type);
if (typeid_cast<const DataTypeUInt16 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt16>>>(argument_type);
if (typeid_cast<const DataTypeUInt32 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>>(argument_type);
if (typeid_cast<const DataTypeUInt64 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>>(argument_type);
if (typeid_cast<const DataTypeInt8 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int8>>>(argument_type);
if (typeid_cast<const DataTypeInt16 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int16>>>(argument_type);
if (typeid_cast<const DataTypeInt32 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int32>>>(argument_type);
if (typeid_cast<const DataTypeInt64 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>>(argument_type);
if (typeid_cast<const DataTypeFloat32 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Float32>>>(argument_type);
if (typeid_cast<const DataTypeFloat64 *>(argument_type.get())) return new AggregateFunctionTemplate<Data<SingleValueDataFixed<Float64>>>(argument_type);
if (typeid_cast<const DataTypeDate *>(argument_type.get()))
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDate::FieldType>>>(argument_type);
if (typeid_cast<const DataTypeDateTime *>(argument_type.get()))
return new AggregateFunctionTemplate<Data<SingleValueDataFixed<DataTypeDateTime::FieldType>>>(argument_type);
if (typeid_cast<const DataTypeString *>(argument_type.get()))
return new AggregateFunctionTemplate<Data<SingleValueDataString>>(argument_type);
return new AggregateFunctionTemplate<Data<SingleValueDataGeneric>>;
return new AggregateFunctionTemplate<Data<SingleValueDataGeneric>>(argument_type);
}
/// argMin, argMax
template <template <typename> class MinMaxData, typename ResData>
static IAggregateFunction * createAggregateFunctionArgMinMaxSecond(const IDataType * res_type, const IDataType * val_type)
static IAggregateFunction * createAggregateFunctionArgMinMaxSecond(const DataTypePtr & res_type, const DataTypePtr & val_type)
{
if (typeid_cast<const DataTypeUInt8 *>(val_type))
if (typeid_cast<const DataTypeUInt8 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<UInt8>>>>(res_type, val_type);
if (typeid_cast<const DataTypeUInt16 *>(val_type))
if (typeid_cast<const DataTypeUInt16 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<UInt16>>>>(res_type, val_type);
if (typeid_cast<const DataTypeUInt32 *>(val_type))
if (typeid_cast<const DataTypeUInt32 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<UInt32>>>>(res_type, val_type);
if (typeid_cast<const DataTypeUInt64 *>(val_type))
if (typeid_cast<const DataTypeUInt64 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<UInt64>>>>(res_type, val_type);
if (typeid_cast<const DataTypeInt8 *>(val_type))
if (typeid_cast<const DataTypeInt8 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Int8>>>>(res_type, val_type);
if (typeid_cast<const DataTypeInt16 *>(val_type))
if (typeid_cast<const DataTypeInt16 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Int16>>>>(res_type, val_type);
if (typeid_cast<const DataTypeInt32 *>(val_type))
if (typeid_cast<const DataTypeInt32 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Int32>>>>(res_type, val_type);
if (typeid_cast<const DataTypeInt64 *>(val_type))
if (typeid_cast<const DataTypeInt64 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Int64>>>>(res_type, val_type);
if (typeid_cast<const DataTypeFloat32 *>(val_type))
if (typeid_cast<const DataTypeFloat32 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Float32>>>>(res_type, val_type);
if (typeid_cast<const DataTypeFloat64 *>(val_type))
if (typeid_cast<const DataTypeFloat64 *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<Float64>>>>(res_type, val_type);
if (typeid_cast<const DataTypeDate *>(val_type))
if (typeid_cast<const DataTypeDate *>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDate::FieldType>>>>(res_type, val_type);
if (typeid_cast<const DataTypeDateTime*>(val_type))
if (typeid_cast<const DataTypeDateTime*>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataFixed<DataTypeDateTime::FieldType>>>>(res_type, val_type);
if (typeid_cast<const DataTypeString*>(val_type))
if (typeid_cast<const DataTypeString*>(val_type.get()))
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataString>>>(res_type, val_type);
return new AggregateFunctionArgMinMax<AggregateFunctionArgMinMaxData<ResData, MinMaxData<SingleValueDataGeneric>>>(res_type, val_type);
......@@ -83,34 +83,34 @@ static IAggregateFunction * createAggregateFunctionArgMinMax(const String & name
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
const IDataType * res_type = argument_types[0].get();
const IDataType * val_type = argument_types[1].get();
const DataTypePtr & res_type = argument_types[0];
const DataTypePtr & val_type = argument_types[1];
if (typeid_cast<const DataTypeUInt8 *>(&res_type))
if (typeid_cast<const DataTypeUInt8 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<UInt8>>(res_type, val_type);
if (typeid_cast<const DataTypeUInt16 *>(&res_type))
if (typeid_cast<const DataTypeUInt16 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<UInt16>>(res_type, val_type);
if (typeid_cast<const DataTypeUInt32 *>(&res_type))
if (typeid_cast<const DataTypeUInt32 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<UInt32>>(res_type, val_type);
if (typeid_cast<const DataTypeUInt64 *>(&res_type))
if (typeid_cast<const DataTypeUInt64 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<UInt64>>(res_type, val_type);
if (typeid_cast<const DataTypeInt8 *>(&res_type))
if (typeid_cast<const DataTypeInt8 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Int8>>(res_type, val_type);
if (typeid_cast<const DataTypeInt16 *>(&res_type))
if (typeid_cast<const DataTypeInt16 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Int16>>(res_type, val_type);
if (typeid_cast<const DataTypeInt32 *>(&res_type))
if (typeid_cast<const DataTypeInt32 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Int32>>(res_type, val_type);
if (typeid_cast<const DataTypeInt64 *>(&res_type))
if (typeid_cast<const DataTypeInt64 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Int64>>(res_type, val_type);
if (typeid_cast<const DataTypeFloat32 *>(&res_type))
if (typeid_cast<const DataTypeFloat32 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Float32>>(res_type, val_type);
if (typeid_cast<const DataTypeFloat64 *>(&res_type))
if (typeid_cast<const DataTypeFloat64 *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<Float64>>(res_type, val_type);
if (typeid_cast<const DataTypeDate *>(&res_type))
if (typeid_cast<const DataTypeDate *>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DataTypeDate::FieldType>>(res_type, val_type);
if (typeid_cast<const DataTypeDateTime*>(&res_type))
if (typeid_cast<const DataTypeDateTime*>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataFixed<DataTypeDateTime::FieldType>>(res_type, val_type);
if (typeid_cast<const DataTypeString*>(&res_type))
if (typeid_cast<const DataTypeString*>(res_type.get()))
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataString>(res_type, val_type);
return createAggregateFunctionArgMinMaxSecond<MinMaxData, SingleValueDataGeneric>(res_type, val_type);
......
#pragma once
#include <AggregateFunctions/AggregateFunctionQuantile.h>
namespace DB
{
......@@ -11,7 +9,12 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
}
/** */
/** Calculates quantile by collecting all values into array
* and applying n-th element (introselect) algorithm for the resulting array.
*
* It use O(N) memory and it is very inefficient in case of high amount of identical values.
* But it is very CPU efficient for not large datasets.
*/
template <typename Value>
struct QuantileExact
{
......
#pragma once
#include <AggregateFunctions/AggregateFunctionQuantile.h>
#include <Common/HashTable/HashMap.h>
namespace DB
......@@ -11,7 +11,11 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
}
/** */
/** Calculates quantile by counting number of occurences for each value in a hash map.
*
* It use O(distinct(N)) memory. Can be naturally applied for values with weight.
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
*/
template <typename Value>
struct QuantileExactWeighted
{
......
#pragma once
#include <AggregateFunctions/AggregateFunctionQuantile.h>
#include <AggregateFunctions/ReservoirSampler.h>
......@@ -12,7 +11,14 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
}
/** */
/** Quantile calculation with "reservoir sample" algorithm.
* It collects pseudorandom subset of limited size from a stream of values,
* and approximate quantile from it.
* The result is non-deterministic. Also look at QuantileReservoirSamplerDeterministic.
*
* This algorithm is quite inefficient in terms of precision for memory usage,
* but very efficient in CPU (though less efficient than QuantileTiming and than QuantileExact for small sets).
*/
template <typename Value>
struct QuantileReservoirSampler
{
......
#pragma once
#include <AggregateFunctions/AggregateFunctionQuantile.h>
#include <AggregateFunctions/ReservoirSamplerDeterministic.h>
......@@ -12,7 +11,14 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
}
/** */
/** Quantile calculation with "reservoir sample" algorithm.
* It collects pseudorandom subset of limited size from a stream of values,
* and approximate quantile from it.
* The function accept second argument, named "determinator"
* and a hash function from it is calculated and used as a source for randomness
* to apply random sampling.
* The function is deterministic, but care should be taken with choose of "determinator" argument.
*/
template <typename Value>
struct QuantileReservoirSamplerDeterministic
{
......
#pragma once
#include <AggregateFunctions/AggregateFunctionQuantile.h>
#include <Common/RadixSort.h>
......@@ -31,9 +30,12 @@ namespace ErrorCodes
* does not depend on the expected number of points. Also an variant on java
* uses asin, which slows down the algorithm a bit.
*/
template <typename Value, typename Count>
template <typename T>
class QuantileTDigest
{
using Value = Float32;
using Count = Float32;
/** The centroid stores the weight of points around their mean value
*/
struct Centroid
......@@ -112,20 +114,20 @@ class QuantileTDigest
/** Adds a centroid `c` to the digest
*/
void add(const Centroid & c)
void addCentroid(const Centroid & c)
{
summary.push_back(c);
count += c.count;
++unmerged;
if (unmerged >= params.max_unmerged)
compress(params);
compress();
}
/** Performs compression of accumulated centroids
* When merging, the invariant is retained to the maximum size of each
* centroid that does not exceed `4 q (1 - q) \ delta N`.
*/
void compress(const Params & params)
void compress()
{
if (unmerged > 0)
{
......@@ -190,20 +192,20 @@ class QuantileTDigest
public:
/** Adds to the digest a change in `x` with a weight of `cnt` (default 1)
*/
void add(Value x, Count cnt = 1)
void add(T x, UInt64 cnt = 1)
{
add(params, Centroid(x, cnt));
addCentroid(Centroid(Value(x), Count(cnt)));
}
void merge(const QuantileTDigest & other)
{
for (const auto & c : other.summary)
add(params, c);
addCentroid(c);
}
void serialize(WriteBuffer & buf)
{
compress(params);
compress();
writeVarUInt(summary.size(), buf);
buf.write(reinterpret_cast<const char *>(&summary[0]), summary.size() * sizeof(summary[0]));
}
......@@ -224,24 +226,24 @@ public:
* For an empty digest returns NaN.
*/
template <typename ResultType>
ResultType getImpl(Value level)
ResultType getImpl(Float64 level)
{
if (summary.empty())
return std::is_floating_point_v<ResultType> ? NAN : 0;
compress(params);
compress();
if (summary.size() == 1)
return summary.front().mean;
Value x = level * count;
Float64 x = level * count;
Float64 prev_x = 0;
Count sum = 0;
Value prev_mean = summary.front().mean;
Value prev_x = 0;
for (const auto & c : summary)
{
Value current_x = sum + c.count * 0.5;
Float64 current_x = sum + c.count * 0.5;
if (current_x >= x)
return interpolate(x, prev_x, prev_mean, current_x, c.mean);
......@@ -260,7 +262,7 @@ public:
* result - the array where the results are added, in order of `levels`,
*/
template <typename ResultType>
void getManyImpl(const Value * levels, const size_t * levels_permutation, size_t size, ResultType * result)
void getManyImpl(const Float64 * levels, const size_t * levels_permutation, size_t size, ResultType * result)
{
if (summary.empty())
{
......@@ -269,7 +271,7 @@ public:
return;
}
compress(params);
compress();
if (summary.size() == 1)
{
......@@ -278,15 +280,15 @@ public:
return;
}
Value x = levels[levels_permutation[0]] * count;
Float64 x = levels[levels_permutation[0]] * count;
Float64 prev_x = 0;
Count sum = 0;
Value prev_mean = summary.front().mean;
Value prev_x = 0;
size_t result_num = 0;
for (const auto & c : summary)
{
Value current_x = sum + c.count * 0.5;
Float64 current_x = sum + c.count * 0.5;
while (current_x >= x)
{
......@@ -309,22 +311,22 @@ public:
result[levels_permutation[result_num]] = rest_of_results;
}
Value get(Float64 level) const
T get(Float64 level)
{
return getImpl<Value>(level);
return getImpl<T>(level);
}
float getFloat(Float64 level) const
float getFloat(Float64 level)
{
return getImpl<float>(level);
}
void getMany(const Float64 * levels, const size_t * indices, size_t size, Value * result) const
void getMany(const Float64 * levels, const size_t * indices, size_t size, T * result)
{
getManyImpl(levels, indices, size, result);
}
void getManyFloat(const Float64 * levels, const size_t * indices, size_t size, float * result) const
void getManyFloat(const Float64 * levels, const size_t * indices, size_t size, float * result)
{
getManyImpl(levels, indices, size, result);
}
......
......@@ -2,7 +2,6 @@
#include <Common/HashTable/Hash.h>
#include <Common/MemoryTracker.h>
#include <AggregateFunctions/AggregateFunctionQuantile.h>
namespace DB
......@@ -443,7 +442,8 @@ namespace detail
while (index != indices_end)
{
result[*index] = BIG_THRESHOLD;
result[*index] = std::numeric_limits<ResultType>::max() < BIG_THRESHOLD
? std::numeric_limits<ResultType>::max() : BIG_THRESHOLD;
++index;
}
}
......@@ -471,6 +471,7 @@ namespace detail
/** sizeof - 64 bytes.
* If there are not enough of them - allocates up to 20 KB of memory in addition.
*/
template <typename> /// Unused template parameter is for AggregateFunctionQuantile.
class QuantileTiming : private boost::noncopyable
{
private:
......@@ -560,7 +561,7 @@ public:
}
}
void insert(UInt64 x)
void add(UInt64 x)
{
if (tiny.count < TINY_MAX_ELEMS)
{
......@@ -586,7 +587,7 @@ public:
}
}
void insert(UInt64 x, size_t weight)
void add(UInt64 x, size_t weight)
{
/// NOTE: First condition is to avoid overflow.
if (weight < TINY_MAX_ELEMS && tiny.count + weight <= TINY_MAX_ELEMS)
......
......@@ -11,11 +11,6 @@ void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory);
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory & factory);
void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileExact(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileExactWeighted(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileDeterministic(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileTiming(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileTDigest(AggregateFunctionFactory & factory);
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory);
void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory);
void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory);
......@@ -37,11 +32,6 @@ void registerAggregateFunctions()
registerAggregateFunctionGroupUniqArray(factory);
registerAggregateFunctionGroupArrayInsertAt(factory);
registerAggregateFunctionsQuantile(factory);
registerAggregateFunctionsQuantileExact(factory);
registerAggregateFunctionsQuantileExactWeighted(factory);
registerAggregateFunctionsQuantileDeterministic(factory);
registerAggregateFunctionsQuantileTiming(factory);
registerAggregateFunctionsQuantileTDigest(factory);
registerAggregateFunctionsSequenceMatch(factory);
registerAggregateFunctionsMinMaxAny(factory);
registerAggregateFunctionsStatistics(factory);
......
......@@ -61,6 +61,7 @@ namespace ErrorCodes
{
extern const int NO_ELEMENTS_IN_CONFIG;
extern const int SUPPORT_IS_DISABLED;
extern const int ARGUMENT_OUT_OF_BOUND;
}
......@@ -152,9 +153,9 @@ int Server::main(const std::vector<std::string> & /*args*/)
int rc = setrlimit(RLIMIT_NOFILE, &rlim);
if (rc != 0)
LOG_WARNING(log,
std::string("Cannot set max number of file descriptors to ") + std::to_string(rlim.rlim_cur)
+ ". Try to specify max_open_files according to your system limits. error: "
+ strerror(errno));
"Cannot set max number of file descriptors to " << rlim.rlim_cur
<< ". Try to specify max_open_files according to your system limits. error: "
<< strerror(errno));
else
LOG_DEBUG(log, "Set max number of file descriptors to " << rlim.rlim_cur << " (was " << old << ").");
}
......@@ -491,7 +492,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
LOG_DEBUG(log,
"Closed all listening sockets."
<< (current_connections ? " Waiting for " + std::to_string(current_connections) + " outstanding connections." : ""));
<< (current_connections ? " Waiting for " + toString(current_connections) + " outstanding connections." : ""));
if (current_connections)
{
......@@ -511,7 +512,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
}
LOG_DEBUG(
log, "Closed connections." << (current_connections ? " But " + std::to_string(current_connections) + " remains."
log, "Closed connections." << (current_connections ? " But " + toString(current_connections) + " remains."
" Tip: To increase wait time add to config: <shutdown_wait_unfinished>60</shutdown_wait_unfinished>" : ""));
main_config_reloader.reset();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册