提交 f4f2ed7a 编写于 作者: A Alexey Milovidov

Fixed errors: checking the number of arguments; managing of state #8326

上级 4f052954
......@@ -60,6 +60,10 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.size() < 1 || arguments.size() > 2)
throw Exception("Incorrect number of arguments of function " + getName() + ". Must be 1 or 2.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const DataTypeAggregateFunction * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
if (!type)
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
......@@ -70,12 +74,6 @@ public:
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
{
size_t number_of_arguments = arguments.size();
if (number_of_arguments == 0)
throw Exception("Incorrect number of arguments of function " + getName() + ". Must be 1 or 2.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const ColumnAggregateFunction * column_with_states
= typeid_cast<const ColumnAggregateFunction *>(&*block.getByPosition(arguments.at(0)).column);
......@@ -87,7 +85,7 @@ public:
ColumnPtr column_with_groups;
if (number_of_arguments == 2)
if (arguments.size() == 2)
column_with_groups = block.getByPosition(arguments[1]).column;
AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction();
......@@ -104,37 +102,31 @@ public:
const auto & states = column_with_states->getData();
size_t i = 0;
bool state_created = false;
SCOPE_EXIT({
if (i > 0)
if (state_created)
agg_func.destroy(place.data());
});
size_t row_number = 0;
for (const auto & state_to_add : states)
{
if (i == 0 || (column_with_groups && column_with_groups->compareAt(i, i - 1, *column_with_groups, 1) != 0))
if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0))
{
if (i > 0)
agg_func.destroy(place.data());
try
if (state_created)
{
agg_func.create(place.data());
agg_func.destroy(place.data());
state_created = false;
}
catch (...)
{
// prevent destroy after creation failure
i = 0;
throw;
}
agg_func.create(place.data());
state_created = true;
}
agg_func.merge(place.data(), state_to_add, arena.get());
agg_func.insertResultInto(place.data(), result_column);
++i;
++row_number;
}
block.getByPosition(result).column = std::move(result_column_ptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册