未验证 提交 5f3d7cac 编写于 作者: N Nikolai Kochetov 提交者: GitHub

Merge pull request #7898 from apbodrov/avg_weighted

avgWeighted
......@@ -14,7 +14,7 @@ template <typename T>
struct Avg
{
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType>>;
using Function = AggregateFunctionAvg<T, AggregateFunctionAvgData<FieldType, UInt64>>;
};
template <typename T>
......
#pragma once
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
template <typename T>
template <typename T, typename Denominator>
struct AggregateFunctionAvgData
{
T sum = 0;
UInt64 count = 0;
T numerator = 0;
Denominator denominator = 0;
template <typename ResultT>
ResultT NO_SANITIZE_UNDEFINED result() const
{
if constexpr (std::is_floating_point_v<ResultT>)
if constexpr (std::numeric_limits<ResultT>::is_iec559)
return static_cast<ResultT>(sum) / count; /// allow division by zero
return static_cast<ResultT>(numerator) / denominator; /// allow division by zero
if (count == 0)
if (denominator == 0)
return static_cast<ResultT>(0);
return static_cast<ResultT>(sum / count);
return static_cast<ResultT>(numerator / denominator);
}
};
/// Calculates arithmetic mean of numbers.
template <typename T, typename Data>
class AggregateFunctionAvg final : public IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>
template <typename T, typename Data, typename Derived>
class AggregateFunctionAvgBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
using ResultType = std::conditional_t<IsDecimalNumber<T>, T, Float64>;
......@@ -49,18 +47,13 @@ public:
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<Float64>>;
/// ctor for native types
AggregateFunctionAvg(const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {})
, scale(0)
{}
AggregateFunctionAvgBase(const DataTypes & argument_types_) : IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(0) {}
/// ctor for Decimals
AggregateFunctionAvg(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, AggregateFunctionAvg<T, Data>>(argument_types_, {})
, scale(getDecimalScale(data_type))
{}
String getName() const override { return "avg"; }
AggregateFunctionAvgBase(const IDataType & data_type, const DataTypes & argument_types_)
: IAggregateFunctionDataHelper<Data, Derived>(argument_types_, {}), scale(getDecimalScale(data_type))
{
}
DataTypePtr getReturnType() const override
{
......@@ -70,29 +63,22 @@ public:
return std::make_shared<ResultDataType>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & column = static_cast<const ColVecType &>(*columns[0]);
this->data(place).sum += column.getData()[row_num];
++this->data(place).count;
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).sum += this->data(rhs).sum;
this->data(place).count += this->data(rhs).count;
this->data(place).numerator += this->data(rhs).numerator;
this->data(place).denominator += this->data(rhs).denominator;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
writeBinary(this->data(place).sum, buf);
writeVarUInt(this->data(place).count, buf);
writeBinary(this->data(place).numerator, buf);
writeVarUInt(this->data(place).denominator, buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
readBinary(this->data(place).sum, buf);
readVarUInt(this->data(place).count, buf);
readBinary(this->data(place).numerator, buf);
readVarUInt(this->data(place).denominator, buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
......@@ -103,9 +89,25 @@ public:
const char * getHeaderFilePath() const override { return __FILE__; }
private:
protected:
UInt32 scale;
};
template <typename T, typename Data>
class AggregateFunctionAvg final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>
{
public:
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvg<T, Data>>::AggregateFunctionAvgBase;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & column = static_cast<const ColVecType &>(*columns[0]);
this->data(place).numerator += column.getData()[row_num];
this->data(place).denominator += 1;
}
String getName() const override { return "avg"; }
};
}
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionAvgWeighted.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include "registerAggregateFunctions.h"
namespace DB
{
namespace
{
template <typename T>
struct AvgWeighted
{
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128, NearestFieldType<T>>;
using Function = AggregateFunctionAvgWeighted<T, AggregateFunctionAvgData<FieldType, FieldType>>;
};
template <typename T>
using AggregateFuncAvgWeighted = typename AvgWeighted<T>::Function;
AggregateFunctionPtr createAggregateFunctionAvgWeighted(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
assertBinary(name, argument_types);
AggregateFunctionPtr res;
const auto data_type = static_cast<const DataTypePtr>(argument_types[0]);
const auto data_type_weight = static_cast<const DataTypePtr>(argument_types[1]);
if (!data_type->equals(*data_type_weight))
throw Exception("Different types " + data_type->getName() + " and " + data_type_weight->getName() + " of arguments for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (isDecimal(data_type))
res.reset(createWithDecimalType<AggregateFuncAvgWeighted>(*data_type, *data_type, argument_types));
else
res.reset(createWithNumericType<AggregateFuncAvgWeighted>(*data_type, argument_types));
if (!res)
throw Exception("Illegal type " + data_type->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
}
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory & factory)
{
factory.registerFunction("avgWeighted", createAggregateFunctionAvgWeighted, AggregateFunctionFactory::CaseSensitive);
}
}
#pragma once
#include <AggregateFunctions/AggregateFunctionAvg.h>
namespace DB
{
template <typename T, typename Data>
class AggregateFunctionAvgWeighted final : public AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>
{
public:
using AggregateFunctionAvgBase<T, Data, AggregateFunctionAvgWeighted<T, Data>>::AggregateFunctionAvgBase;
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & values = static_cast<const ColVecType &>(*columns[0]);
const auto & weights = static_cast<const ColVecType &>(*columns[1]);
this->data(place).numerator += values.getData()[row_num] * weights.getData()[row_num];
this->data(place).denominator += weights.getData()[row_num];
}
String getName() const override { return "avgWeighted"; }
};
}
......@@ -13,6 +13,7 @@ void registerAggregateFunctions()
auto & factory = AggregateFunctionFactory::instance();
registerAggregateFunctionAvg(factory);
registerAggregateFunctionAvgWeighted(factory);
registerAggregateFunctionCount(factory);
registerAggregateFunctionGroupArray(factory);
registerAggregateFunctionGroupUniqArray(factory);
......
......@@ -5,6 +5,7 @@ namespace DB
class AggregateFunctionFactory;
void registerAggregateFunctionAvg(AggregateFunctionFactory &);
void registerAggregateFunctionAvgWeighted(AggregateFunctionFactory &);
void registerAggregateFunctionCount(AggregateFunctionFactory &);
void registerAggregateFunctionGroupArray(AggregateFunctionFactory &);
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory &);
......
#!/usr/bin/env bash
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. $CUR_DIR/../shell_config.sh
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 5), (2, 4), (3, 3), (4, 2), (5, 1)]) AS t));"
${CLICKHOUSE_CLIENT} --query="SELECT avgWeighted(x, weight) FROM (SELECT t.1 AS x, t.2 AS weight FROM (SELECT arrayJoin([(1, 0), (2, 0), (3, 0), (4, 0), (5, 0)]) AS t));"
echo `${CLICKHOUSE_CLIENT} --server_logs_file=/dev/null --query="SELECT avgWeighted(toDecimal64(0, 0), toFloat64(0))" 2>&1` \
| grep -c 'Code: 43. DB::Exception: .* DB::Exception:.* Different types .* of arguments for aggregate function avgWeighted'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册