diff --git a/dbms/src/Functions/runningAccumulate.cpp b/dbms/src/Functions/runningAccumulate.cpp index 761f3692e3d9269550a39422b5a4aa2666bf43da..53dc5e197774b7824b8a2705e8432bae568da942 100644 --- a/dbms/src/Functions/runningAccumulate.cpp +++ b/dbms/src/Functions/runningAccumulate.cpp @@ -60,6 +60,10 @@ public: DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { + if (arguments.size() < 1 || arguments.size() > 2) + throw Exception("Incorrect number of arguments of function " + getName() + ". Must be 1 or 2.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + const DataTypeAggregateFunction * type = checkAndGetDataType(arguments[0].get()); if (!type) throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.", @@ -70,12 +74,6 @@ public: void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override { - size_t number_of_arguments = arguments.size(); - - if (number_of_arguments == 0) - throw Exception("Incorrect number of arguments of function " + getName() + ". Must be 1 or 2.", - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const ColumnAggregateFunction * column_with_states = typeid_cast(&*block.getByPosition(arguments.at(0)).column); @@ -87,7 +85,7 @@ public: ColumnPtr column_with_groups; - if (number_of_arguments == 2) + if (arguments.size() == 2) column_with_groups = block.getByPosition(arguments[1]).column; AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction(); @@ -104,37 +102,31 @@ public: const auto & states = column_with_states->getData(); - size_t i = 0; - + bool state_created = false; SCOPE_EXIT({ - if (i > 0) + if (state_created) agg_func.destroy(place.data()); }); + size_t row_number = 0; for (const auto & state_to_add : states) { - if (i == 0 || (column_with_groups && column_with_groups->compareAt(i, i - 1, *column_with_groups, 1) != 0)) + if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0)) { - if (i > 0) - agg_func.destroy(place.data()); - - try + if (state_created) { - agg_func.create(place.data()); + agg_func.destroy(place.data()); + state_created = false; } - catch (...) - { - // prevent destroy after creation failure - i = 0; - throw; - } + agg_func.create(place.data()); + state_created = true; } agg_func.merge(place.data(), state_to_add, arena.get()); agg_func.insertResultInto(place.data(), result_column); - ++i; + ++row_number; } block.getByPosition(result).column = std::move(result_column_ptr);