diff --git a/dbms/include/DB/Columns/ColumnArray.h b/dbms/include/DB/Columns/ColumnArray.h index 889210ce22153550aa0429e087654df910b57f61..03f2b907bae9a41fe0986ab02f101bce63666792 100644 --- a/dbms/include/DB/Columns/ColumnArray.h +++ b/dbms/include/DB/Columns/ColumnArray.h @@ -27,9 +27,18 @@ public: typedef ColumnVector ColumnOffsets_t; /** Создать пустой столбец массивов, с типом значений, как в столбце nested_column */ - ColumnArray(ColumnPtr nested_column) - : data(nested_column), offsets(new ColumnOffsets_t) + explicit ColumnArray(ColumnPtr nested_column, ColumnPtr offsets_column = NULL) + : data(nested_column), offsets(offsets_column) { + if (!offsets_column) + { + offsets = new ColumnOffsets_t; + } + else + { + if (!dynamic_cast(&*offsets_column)) + throw Exception("offsets_column must be a ColumnVector", ErrorCodes::ILLEGAL_COLUMN); + } } std::string getName() const { return "ColumnArray(" + data->getName() + ")"; } diff --git a/dbms/include/DB/Columns/ColumnExpression.h b/dbms/include/DB/Columns/ColumnExpression.h index aef535d6d515e2211ef821e1906d5884364474cc..57bd2f69ce8a9549818f9c023314f22a4f3ef41e 100644 --- a/dbms/include/DB/Columns/ColumnExpression.h +++ b/dbms/include/DB/Columns/ColumnExpression.h @@ -13,8 +13,8 @@ namespace DB class ColumnExpression : public IColumnDummy { public: - ColumnExpression(size_t s_, ExpressionPtr expression_, const NamesAndTypes & arguments_, DataTypePtr return_type_) - : IColumnDummy(s_), expression(expression_), arguments(arguments_), return_type(return_type_) {} + ColumnExpression(size_t s_, ExpressionPtr expression_, const NamesAndTypes & arguments_, DataTypePtr return_type_, std::string return_name_) + : IColumnDummy(s_), expression(expression_), arguments(arguments_), return_type(return_type_), return_name(return_name_) {} std::string getName() const { return "ColumnExpression"; } ColumnPtr cloneDummy(size_t s_) const { return new ColumnExpression(s_, expression, arguments, return_type); } @@ -22,11 +22,13 @@ public: ExpressionPtr & getExpression() { return expression; } const NamesAndTypes & getArguments() const { return arguments; } const DataTypePtr & getReturnType() const { return return_type; } + const std::string & getReturnName() const { return return_name; } private: ExpressionPtr expression; NamesAndTypes arguments; DataTypePtr return_type; + std::string return_name; }; } diff --git a/dbms/include/DB/Columns/IColumnDummy.h b/dbms/include/DB/Columns/IColumnDummy.h index 5ae12dae1dd0fde00cc799bbcb3a0293d69878a7..ee06221ddfcffc33927f7364fa86b3c7de082e78 100644 --- a/dbms/include/DB/Columns/IColumnDummy.h +++ b/dbms/include/DB/Columns/IColumnDummy.h @@ -18,13 +18,15 @@ public: virtual ColumnPtr cloneDummy(size_t s_) const = 0; ColumnPtr cloneEmpty() const { return cloneDummy(0); } + bool isConst() { return true; } size_t size() const { return s; } - Field operator[](size_t n) const { throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - void get(size_t n, Field & res) const { throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED); }; - void insert(const Field & x) { throw Exception("Cannot insert element into " + getName(), ErrorCodes::NOT_IMPLEMENTED); } void insertDefault() { ++s; } size_t byteSize() const { return 0; } int compareAt(size_t n, size_t m, const IColumn & rhs_) const { return 0; } + + Field operator[](size_t n) const { throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED); } + void get(size_t n, Field & res) const { throw Exception("Cannot get value from " + getName(), ErrorCodes::NOT_IMPLEMENTED); }; + void insert(const Field & x) { throw Exception("Cannot insert element into " + getName(), ErrorCodes::NOT_IMPLEMENTED); } StringRef getDataAt(size_t n) const { throw Exception("Method getDataAt is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } void insertData(const char * pos, size_t length) { throw Exception("Method insertData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } diff --git a/dbms/include/DB/Functions/FunctionsHigherOrder.h b/dbms/include/DB/Functions/FunctionsHigherOrder.h new file mode 100644 index 0000000000000000000000000000000000000000..0a1080afd5e6e7ad0fb9b29112f4f4ded6fc6f3e --- /dev/null +++ b/dbms/include/DB/Functions/FunctionsHigherOrder.h @@ -0,0 +1,264 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include + + +namespace DB +{ + +/** Функции высшего порядка для массивов: + * + * arrayMap(x -> expression, array) - применить выражение к каждому элементу массива. + * arrayFilter(x -> predicate, array) - оставить в массиве только элементы, для которых выражение истинно. + * arrayCount(x -> predicate, array) - для скольки элементов массива выражение истинно. + * arrayExists(x -> predicate, array) - истинно ли выражение для хотя бы одного элемента массива. + */ + +struct ArrayMapImpl +{ + static bool needBooleanExpression() { return false; } + + static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & array_element) + { + return new DataTypeArray(expression_return); + } + + static ColumnPtr execute(ColumnArray * array, ColumnPtr mapped) + { + return new ColumnArray(mapped, array->getOffsetsColumn()); + } +}; + +struct ArrayFilterImpl +{ + static bool needBooleanExpression() { return true; } + + static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & array_element) + { + if (!dynamic_cast(&*expression_return)) + throw Exception("Filter expression must be UInt8", ErrorCodes::TYPE_MISMATCH); + return new DataTypeArray(array_element); + } + + static ColumnPtr execute(ColumnArray * array, ColumnPtr mapped) + { + ColumnVector * column_filter = dynamic_cast *>(&*mapped); + if (!column_filter) + throw Exception("Unexpected type of filter column", ErrorCodes::ILLEGAL_COLUMN); + + const IColumn::Filter & filter = column_filter->getData(); + ColumnPtr filtered = array->getData().filter(filter); + + const IColumn::Offsets_t & in_offsets = array->getOffsets(); + ColumnArray::ColumnOffsets_t * column_offsets = new ColumnOffsets_t(in_offsets.size()); + ColumnPtr column_offsets_ptr = column_offsets; + IColumn::Offsets_t & out_offsets = column_offsets->getData(); + + size_t in_pos = 0; + size_t out_pos = 0; + for (size_t i = 0; i < in_offsets.size(); ++i) + { + for (; in_pos < in_offsets[i]; ++in_pos) + { + if (filter[in_pos]) + ++out_pos; + } + out_offsets[i] = out_pos; + } + + return new ColumnArray(filtered, column_offsets_ptr); + } +}; + +struct ArrayCountImpl +{ + static bool needBooleanExpression() { return true; } + + static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & array_element) + { + if (!dynamic_cast(&*expression_return)) + throw Exception("Filter expression must be UInt8", ErrorCodes::TYPE_MISMATCH); + return new DataUInt32; + } + + static ColumnPtr execute(ColumnArray * array, ColumnPtr mapped) + { + ColumnVector * column_filter = dynamic_cast *>(&*mapped); + if (!column_filter) + throw Exception("Unexpected type of filter column", ErrorCodes::ILLEGAL_COLUMN); + + const IColumn::Filter & filter = column_filter->getData(); + const IColumn::Offsets_t & offsets = array->getOffsets(); + ColumnVector & out_column = new ColumnVector(offsets.size()); + ColumnPtr out_column_ptr = out_column; + ColumnVector::Container_t & out_counts = out_column->getData(); + + size_t pos = 0; + for (size_t i = 0; i < offsets.size(); ++i) + { + size_t count = 0; + for (; pos < offsets[i]; ++pos) + { + if (filter[pos]) + ++count; + } + out_counts[i] = count; + } + + return out_column_ptr; + } +}; + +struct ArrayExistsImpl +{ + static bool needBooleanExpression() { return true; } + + static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypePtr & array_element) + { + if (!dynamic_cast(&*expression_return)) + throw Exception("Filter expression must be UInt8", ErrorCodes::TYPE_MISMATCH); + return new DataUInt8; + } + + static ColumnPtr execute(ColumnArray * array, ColumnPtr mapped) + { + ColumnVector * column_filter = dynamic_cast *>(&*mapped); + if (!column_filter) + throw Exception("Unexpected type of filter column", ErrorCodes::ILLEGAL_COLUMN); + + const IColumn::Filter & filter = column_filter->getData(); + const IColumn::Offsets_t & offsets = array->getOffsets(); + ColumnVector & out_column = new ColumnVector(offsets.size()); + ColumnPtr out_column_ptr = out_column; + ColumnVector::Container_t & out_exists = out_column->getData(); + + size_t pos = 0; + for (size_t i = 0; i < offsets.size(); ++i) + { + UInt8 exists = 0; + for (; pos < offsets[i]; ++pos) + { + if (filter[pos]) + { + exists = true; + break; + } + } + out_exists[i] = exists; + } + + return out_column_ptr; + } +}; + +template +class FunctionArrayMapped : public IFunction +{ +public: + /// Получить имя функции. + String getName() const + { + return Name::get(); + } + + void checkTypes(const DataTypes & arguments, const DataTypeExpression *& expression_type, const DataTypeArray *& array_type) { + if (arguments.size() != 2) + throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " + + Poco::NumberFormatter::format(arguments.size()) + ", should be 2.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + array_type = dynamic_cast(&*arguments[1]); + if (!array_type) + throw Exception("Second argument for function " + getName() + " must be array. Found " + arguments[1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + expression_type = dynamic_cast(&*arguments[0]); + if (!expression_type || expression_type.getArgumentTypes().size() != 1) + throw Exception("First argument for function " + getName() + " must be an expression with one argument. Found " + arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + + /// Вызывается, если хоть один агрумент функции - лямбда-выражение. + /// Для аргументов-лямбда-выражений определяет типы аргументов этих выражений. + void getLambdaArgumentTypes(DataTypes & arguments) const + { + const DataTypeArray * array_type; + const DataTypeExpression * expression_type; + checkTypes(arguments, expression_type, array_type); + arguments[0] = new DataTypeExpression(DataTypes(1, array_type->getNestedType())); + } + + /// Получить типы результата по типам аргументов. Если функция неприменима для данных аргументов - кинуть исключение. + DataTypePtr getReturnType(const DataTypes & arguments) const + { + const DataTypeArray * array_type; + const DataTypeExpression * expression_type; + checkTypes(arguments, expression_type, array_type); + + if (Impl::needBooleanExpression() && !dynamic_cast(&*expression_type->getReturnType())) + throw Exception("Expression for function " + getName() + " must return UInt8, found " + expression_type->getReturnType()->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return Impl::getReturnType(expression_type->getReturnType(), array_type->getNestedType()); + } + + /// Выполнить функцию над блоком. + void execute(Block & block, const ColumnNumbers & arguments, size_t result) + { + const ColumnArray * column_array = dynamic_cast(&*block.getByPosition(arguments[1]).column); + ColumnExpression * column_expression = dynamic_cast(&*block.getByPosition(arguments[1]).column); + + Block temp_block; + Expression & expression = column_expression->getExpression(); + Name argument_name = column_expression->getArguments()[0].first; + DataTypePtr element_type = column_expression->getArguments()[0].second; + /// Положим в блок аргумент выражения. + temp_block.insert(ColumnWithNameAndType(column_array->getDataPtr(), argument_name, element_type)); + + Names columns = expression.getRequiredColumns(); + + /// Положим в блок все нужные столбцы, размноженные по размерам массивов. + for (size_t i = 0; i < columns.size(); ++i) + { + const String & name = columns[i]; + if (name == argument_name) + continue; + String replicated_name = "replicate(" + name + "," + block.getByPosition(arguments[1]).name + ")"; + ColumnWithNameAndType replicated_column; + if (block.has(replicated_name)) + { + replicated_column = block.getByName(replicated_name); + } + else + { + ColumnWithNameAndType needed_column = block.getByName(name); + replicated_column = + ColumnWithNameAndType(needed_column->replicate(column_array->getOffsets()), needed_column.type, replicated_name); + block.insert(replicated_column); + } + replicated_column.name = name; + temp_block.insert(replicated_column); + } + + expression.execute(temp_block); + + block.getByPosition(result).column = Impl::execute(temp_block.getByName(column_expression->getReturnName()).column); + } +}; + + +struct NameArrayMap { static const char * get() { return "arrayMap"; } }; +struct NameArrayFilter { static const char * get() { return "arrayFilter"; } }; +struct NameArrayCount { static const char * get() { return "arrayCount"; } }; +struct NameArrayExists { static const char * get() { return "arrayExists"; } }; + +typedef FunctionArrayMapped FunctionArrayMap; +typedef FunctionArrayMapped FunctionArrayFilter; +typedef FunctionArrayMapped FunctionArrayCount; +typedef FunctionArrayMapped FunctionArrayExists; + +} diff --git a/dbms/include/DB/Functions/IFunction.h b/dbms/include/DB/Functions/IFunction.h index d64b1a3b7567ccb329666944206604cd09f8974f..a608e4d13e93cdadf487812e5faa11959002dd44 100644 --- a/dbms/include/DB/Functions/IFunction.h +++ b/dbms/include/DB/Functions/IFunction.h @@ -34,7 +34,7 @@ public: virtual DataTypePtr getReturnType(const DataTypes & arguments) const = 0; /// Вызывается, если хоть один агрумент функции - лямбда-выражение. - /// Для аргументов-лямбда-выражений определяет типы аргументов этих выражений. + /// Для аргументов-лямбда-выражений определяет типы аргументов этих выражений и кладет результат в arguments. virtual void getLambdaArgumentTypes(DataTypes & arguments) const { throw Exception("Function " + getName() + " can't have lambda-expressions as arguments", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/include/DB/Interpreters/Expression.h b/dbms/include/DB/Interpreters/Expression.h index c90a0b23a6a35877ab422409a06fe3055dd3b694..d4e331410d2965b51927de9af27645d05e4d2af2 100644 --- a/dbms/include/DB/Interpreters/Expression.h +++ b/dbms/include/DB/Interpreters/Expression.h @@ -131,7 +131,6 @@ private: typedef std::set SetOfASTs; typedef std::map MapOfASTs; - void init() { createAliasesDict(ast);