From 8936f854be64581f5c404fe99d7f51ff70d09380 Mon Sep 17 00:00:00 2001 From: Alexander Tokmakov Date: Wed, 4 Sep 2019 19:54:20 +0300 Subject: [PATCH] better checking of types of literals --- dbms/src/Common/ErrorCodes.cpp | 2 - dbms/src/Functions/IFunction.h | 1 - .../Impl/ConstantExpressionTemplate.cpp | 286 ++++++++++++------ .../Formats/Impl/ConstantExpressionTemplate.h | 18 +- .../Formats/Impl/ValuesBlockInputFormat.cpp | 4 +- 5 files changed, 214 insertions(+), 97 deletions(-) diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index 800315c8d7..1001dee1db 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -449,9 +449,7 @@ namespace ErrorCodes extern const int READONLY_SETTING = 472; extern const int DEADLOCK_AVOIDED = 473; extern const int INVALID_TEMPLATE_FORMAT = 474; - extern const int CANNOT_CREATE_EXPRESSION_TEMPLATE = 475; extern const int CANNOT_PARSE_EXPRESSION_USING_TEMPLATE = 476; - extern const int CANNOT_EVALUATE_EXPRESSION_TEMPLATE = 477; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Functions/IFunction.h b/dbms/src/Functions/IFunction.h index 66975d7ead..0a180ee03e 100644 --- a/dbms/src/Functions/IFunction.h +++ b/dbms/src/Functions/IFunction.h @@ -501,7 +501,6 @@ public: bool isVariadic() const override { return function->isVariadic(); } size_t getNumberOfArguments() const override { return function->getNumberOfArguments(); } - // FIXME it's a temporary hack for ConstantExpressionTemplate ColumnNumbers getArgumentsThatAreAlwaysConstant() const { return function->getArgumentsThatAreAlwaysConstant(); } protected: diff --git a/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp b/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp index 23ff0771ab..ff23e01a64 100644 --- a/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp +++ b/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -23,52 +22,53 @@ namespace DB namespace ErrorCodes { - extern const int CANNOT_CREATE_EXPRESSION_TEMPLATE; extern const int CANNOT_PARSE_EXPRESSION_USING_TEMPLATE; - extern const int CANNOT_EVALUATE_EXPRESSION_TEMPLATE; extern const int SYNTAX_ERROR; } -class ReplaceLiteralsVisitor +struct LiteralInfo { -public: - - struct LiteralInfo - { - typedef std::shared_ptr ASTLiteralPtr; - LiteralInfo(const ASTLiteralPtr & literal_, const String & column_name_) : literal(literal_), dummy_column_name(column_name_) { } - ASTLiteralPtr literal; - String dummy_column_name; - }; + typedef std::shared_ptr ASTLiteralPtr; + LiteralInfo(const ASTLiteralPtr & literal_, const String & column_name_, bool force_nullable_) + : literal(literal_), dummy_column_name(column_name_), force_nullable(force_nullable_) { } + ASTLiteralPtr literal; + String dummy_column_name; + /// Make column nullable even if expression type is not. + /// (for literals in functions like ifNull and assumeNotNul, which never return NULL even for NULL arguments) + bool force_nullable; +}; - using LiteralsInfo = std::vector; +using LiteralsInfo = std::vector; +class ReplaceLiteralsVisitor +{ +public: LiteralsInfo replaced_literals; const Context & context; explicit ReplaceLiteralsVisitor(const Context & context_) : context(context_) { } - void visit(ASTPtr & ast) + void visit(ASTPtr & ast, bool force_nullable = {}) { - if (visitIfLiteral(ast)) + if (visitIfLiteral(ast, force_nullable)) return; if (auto function = ast->as()) - visit(*function); + visit(*function, force_nullable); else if (ast->as()) throw DB::Exception("Identifier in constant expression", ErrorCodes::SYNTAX_ERROR); else - visitChildren(ast, {}); + visitChildren(ast, {}, std::vector(ast->children.size(), force_nullable)); } private: - void visitChildren(ASTPtr & ast, const ColumnNumbers & dont_visit_children) + void visitChildren(ASTPtr & ast, const ColumnNumbers & dont_visit_children, const std::vector & force_nullable) { for (size_t i = 0; i < ast->children.size(); ++i) if (std::find(dont_visit_children.begin(), dont_visit_children.end(), i) == dont_visit_children.end()) - visit(ast->children[i]); + visit(ast->children[i], force_nullable[i]); } - void visit(ASTFunction & function) + void visit(ASTFunction & function, bool force_nullable) { /// Do not replace literals which must be constant ColumnNumbers dont_visit_children; @@ -78,53 +78,70 @@ private: dont_visit_children = default_builder->getArgumentsThatAreAlwaysConstant(); else if (dynamic_cast(builder.get())) dont_visit_children.push_back(1); + /// FIXME suppose there is no other functions, which require constant arguments (it's true, until the new one is added) + + /// Allow nullable arguments if function never returns NULL + bool return_not_null = function.name == "isNull" || function.name == "isNotNull" || + function.name == "ifNull" || function.name == "assumeNotNull" || + function.name == "coalesce"; + + std::vector force_nullable_arguments(function.arguments->children.size(), force_nullable || return_not_null); - visitChildren(function.arguments, dont_visit_children); + /// coalesce may return NULL if the last argument is NULL + if (!force_nullable && function.name == "coalesce") + force_nullable_arguments.back() = false; + + visitChildren(function.arguments, dont_visit_children, force_nullable_arguments); } - bool visitIfLiteral(ASTPtr & ast) + bool visitIfLiteral(ASTPtr & ast, bool force_nullable) { auto literal = std::dynamic_pointer_cast(ast); if (!literal) return false; if (literal->begin && literal->end) { + /// Do not replace empty array + if (literal->value.getType() == Field::Types::Array && literal->value.get().empty()) + return false; String column_name = "_dummy_" + std::to_string(replaced_literals.size()); - replaced_literals.emplace_back(literal, column_name); + replaced_literals.emplace_back(literal, column_name, force_nullable); ast = std::make_shared(column_name); } return true; } }; -using LiteralInfo = ReplaceLiteralsVisitor::LiteralInfo; -using LiteralsInfo = ReplaceLiteralsVisitor::LiteralsInfo; - -ConstantExpressionTemplate::ConstantExpressionTemplate(const IDataType & result_column_type, +/// Expression template is a sequence of tokens and data types of literals. +/// E.g. template of "position('some string', 'other string') != 0" is +/// ["position", "(", DataTypeString, ",", DataTypeString, ")", "!=", DataTypeUInt64] +ConstantExpressionTemplate::ConstantExpressionTemplate(DataTypePtr result_column_type_, TokenIterator expression_begin, TokenIterator expression_end, const ASTPtr & expression_, const Context & context) + : result_column_type(result_column_type_) { + /// Extract ASTLiterals from expression and replace them with ASTIdentifiers where needed ASTPtr expression = expression_->clone(); ReplaceLiteralsVisitor visitor(context); visitor.visit(expression); - LiteralsInfo replaced_literals = visitor.replaced_literals; - - token_after_literal_idx.reserve(replaced_literals.size()); - need_special_parser.resize(replaced_literals.size(), true); + LiteralsInfo & replaced_literals = visitor.replaced_literals; std::sort(replaced_literals.begin(), replaced_literals.end(), [](const LiteralInfo & a, const LiteralInfo & b) { return a.literal->begin.value() < b.literal->begin.value(); }); - bool allow_nulls = result_column_type.isNullable(); + /// Make sequence of tokens and determine IDataType by Field::Types:Which for each literal. + token_after_literal_idx.reserve(replaced_literals.size()); + use_special_parser.resize(replaced_literals.size(), true); + TokenIterator prev_end = expression_begin; for (size_t i = 0; i < replaced_literals.size(); ++i) { const LiteralInfo & info = replaced_literals[i]; if (info.literal->begin.value() < prev_end) - throw Exception("Cannot replace literals", ErrorCodes::CANNOT_CREATE_EXPRESSION_TEMPLATE); + throw Exception("Cannot replace literals", ErrorCodes::LOGICAL_ERROR); while (prev_end < info.literal->begin.value()) { @@ -133,21 +150,8 @@ ConstantExpressionTemplate::ConstantExpressionTemplate(const IDataType & result_ } token_after_literal_idx.push_back(tokens.size()); - DataTypePtr type = applyVisitor(FieldToDataType(), info.literal->value); - - WhichDataType type_info(type); - if (type_info.isNativeInt()) - type = std::make_shared(); - else if (type_info.isNativeUInt()) - type = std::make_shared(); - else if (!type_info.isFloat()) - need_special_parser[i] = false; - - /// Allow literal to be NULL, if result column has nullable type - // TODO also allow NULL literals inside functions, which return not NULL for NULL arguments, - // even if result_column_type is not nullable - if (allow_nulls && type->canBeInsideNullable()) - type = makeNullable(type); + DataTypePtr type; + use_special_parser[i] = getDataType(info, type); literals.insert({nullptr, type, info.dummy_column_name}); @@ -161,68 +165,102 @@ ConstantExpressionTemplate::ConstantExpressionTemplate(const IDataType & result_ } columns = literals.cloneEmptyColumns(); - addNodesToCastResult(result_column_type, expression); + addNodesToCastResult(*result_column_type, expression); result_column_name = expression->getColumnName(); - // TODO convert SyntaxAnalyzer and ExpressionAnalyzer exceptions to CANNOT_CREATE_EXPRESSION_TEMPLATE auto syntax_result = SyntaxAnalyzer(context).analyze(expression, literals.getNamesAndTypesList()); - actions_on_literals = ExpressionAnalyzer(expression, syntax_result, context).getActions(false); } +bool ConstantExpressionTemplate::getDataType(const LiteralInfo & info, DataTypePtr & type) const +{ + /// Type (Field::Types:Which) of literal in AST can be: String, UInt64, Int64, Float64, Null or Array of simple literals (not of Arrays). + /// Null and empty Array literals are considered as tokens, because template with Nullable or Array is useless. + + Field::Types::Which field_type = info.literal->value.getType(); + + /// We have to use ParserNumber instead of type->deserializeAsTextQuoted() for arithmetic types + /// to check actual type of literal and avoid possible overflow and precision issues. + bool need_special_parser = true; + + /// Do not use 8, 16 and 32 bit types, so template will match all integers + if (field_type == Field::Types::UInt64) + type = std::make_shared(); + else if (field_type == Field::Types::Int64) + type = std::make_shared(); + else if (field_type == Field::Types::Float64) + type = std::make_shared(); + else if (field_type == Field::Types::String) + { + need_special_parser = false; + type = std::make_shared(); + } + else if (field_type == Field::Types::Array) + { + type = applyVisitor(FieldToDataType(), info.literal->value); + auto nested_type = dynamic_cast(*type).getNestedType(); + + /// It can be Array> + bool array_of_nullable = false; + if (auto nullable = dynamic_cast(type.get())) + { + nested_type = nullable->getNestedType(); + array_of_nullable = true; + } + + WhichDataType type_info{nested_type}; + /// Promote integers to 64 bit types + if (type_info.isNativeUInt()) + nested_type = std::make_shared(); + else if (type_info.isNativeInt()) + nested_type = std::make_shared(); + else if (type_info.isFloat64()) + ; + else if (type_info.isString()) + need_special_parser = false; + else + throw Exception("Unexpected literal type inside Array: " + nested_type->getName() + ". It's a bug", + ErrorCodes::LOGICAL_ERROR); + + if (array_of_nullable) + nested_type = std::make_shared(nested_type); + + type = std::make_shared(nested_type); + } + else + throw Exception(String("Unexpected literal type ") + info.literal->value.getTypeName() + ". It's a bug", + ErrorCodes::LOGICAL_ERROR); + + /// Allow literal to be NULL, if result column has nullable type or if function never returns NULL + bool allow_nulls = result_column_type->isNullable(); + if ((allow_nulls || info.force_nullable) && type->canBeInsideNullable()) + type = makeNullable(type); + + return need_special_parser; +} + void ConstantExpressionTemplate::parseExpression(ReadBuffer & istr, const FormatSettings & settings) { size_t cur_column = 0; try { - ParserKeyword parser_null("NULL"); - ParserNumber parser_number; size_t cur_token = 0; while (cur_column < literals.columns()) { size_t skip_tokens_until = token_after_literal_idx[cur_column]; while (cur_token < skip_tokens_until) { - // TODO skip comments + /// TODO skip comments skipWhitespaceIfAny(istr); assertString(tokens[cur_token++], istr); } skipWhitespaceIfAny(istr); const IDataType & type = *literals.getByPosition(cur_column).type; - if (need_special_parser[cur_column]) - { - WhichDataType type_info(type); - bool nullable = type_info.isNullable(); - if (nullable) - type_info = WhichDataType(dynamic_cast(type).getNestedType()); - - Tokens tokens_number(istr.position(), istr.buffer().end()); - IParser::Pos iterator(tokens_number); - Expected expected; - ASTPtr ast; - if (nullable && parser_null.parse(iterator, ast, expected)) - ast = std::make_shared(Field()); - else if (!parser_number.parse(iterator, ast, expected)) - throw DB::Exception("Cannot parse literal", ErrorCodes::CANNOT_PARSE_EXPRESSION_USING_TEMPLATE); - istr.position() = const_cast(iterator->begin); - Field & number = ast->as().value; - - // TODO also check type of Array(T), if T is arithmetic - if ((number.getType() == Field::Types::UInt64 && type_info.isUInt64()) - || (number.getType() == Field::Types::Int64 && type_info.isInt64()) - || (number.getType() == Field::Types::Float64 && type_info.isFloat64()) - || nullable) - { - columns[cur_column]->insert(number); - } - else - throw DB::Exception("Cannot parse literal", ErrorCodes::CANNOT_PARSE_EXPRESSION_USING_TEMPLATE); - } + if (use_special_parser[cur_column]) + parseLiteralAndAssertType(istr, type, cur_column); else - { type.deserializeAsTextQuoted(*columns[cur_column], istr, settings); - } ++cur_column; } @@ -245,6 +283,75 @@ void ConstantExpressionTemplate::parseExpression(ReadBuffer & istr, const Format } } +void ConstantExpressionTemplate::parseLiteralAndAssertType(ReadBuffer & istr, const IDataType & type, size_t column_idx) +{ + /// TODO faster way to check types without using Parsers + ParserKeyword parser_null("NULL"); + ParserNumber parser_number; + ParserArrayOfLiterals parser_array; + + WhichDataType type_info(type); + + bool is_array = type_info.isArray(); + if (is_array) + type_info = WhichDataType(dynamic_cast(type).getNestedType()); + + bool is_nullable = type_info.isNullable(); + if (is_nullable) + type_info = WhichDataType(dynamic_cast(type).getNestedType()); + + /// If literal does not fit entirely in the buffer, parsing error will happen. + /// However, it's possible to deduce new template after error like it was template mismatch. + /// TODO fix it + Tokens tokens_number(istr.position(), istr.buffer().end()); + IParser::Pos iterator(tokens_number); + Expected expected; + ASTPtr ast; + + if (is_array) + { + if (!parser_array.parse(iterator, ast, expected)) + throw DB::Exception("Cannot parse literal", ErrorCodes::CANNOT_PARSE_EXPRESSION_USING_TEMPLATE); + istr.position() = const_cast(iterator->begin); + + const Field & array = ast->as().value; + auto array_type = applyVisitor(FieldToDataType(), array); + auto nested_type = dynamic_cast(*array_type).getNestedType(); + if (is_nullable) + if (auto nullable = dynamic_cast(nested_type.get())) + nested_type = nullable->getNestedType(); + + WhichDataType nested_type_info(nested_type); + if ((nested_type_info.isNativeUInt() && type_info.isUInt64()) || + (nested_type_info.isNativeInt() && type_info.isInt64()) || + (nested_type_info.isFloat64() && type_info.isFloat64())) + { + columns[column_idx]->insert(array); + return; + } + } + else + { + if (is_nullable && parser_null.parse(iterator, ast, expected)) + ast = std::make_shared(Field()); + else if (!parser_number.parse(iterator, ast, expected)) + throw DB::Exception("Cannot parse literal", ErrorCodes::CANNOT_PARSE_EXPRESSION_USING_TEMPLATE); + istr.position() = const_cast(iterator->begin); + + Field & number = ast->as().value; + + if ((number.getType() == Field::Types::UInt64 && type_info.isUInt64()) || + (number.getType() == Field::Types::Int64 && type_info.isInt64()) || + (number.getType() == Field::Types::Float64 && type_info.isFloat64()) || + is_nullable) + { + columns[column_idx]->insert(number); + return; + } + } + throw DB::Exception("Cannot parse literal: type mismatch", ErrorCodes::CANNOT_PARSE_EXPRESSION_USING_TEMPLATE); +} + ColumnPtr ConstantExpressionTemplate::evaluateAll() { Block evaluated = literals.cloneWithColumns(std::move(columns)); @@ -253,13 +360,14 @@ ColumnPtr ConstantExpressionTemplate::evaluateAll() evaluated.insert({ColumnConst::create(ColumnUInt8::create(1, 0), rows_count), std::make_shared(), "_dummy"}); actions_on_literals->execute(evaluated); - if (!evaluated || evaluated.rows() == 0) - throw Exception("Logical error: empty block after evaluation of batch of constant expressions", + if (!evaluated || evaluated.rows() != rows_count) + throw Exception("Number of rows mismatch after evaluation of batch of constant expressions: got " + + std::to_string(evaluated.rows()) + " rows for " + std::to_string(rows_count) + " expressions", ErrorCodes::LOGICAL_ERROR); if (!evaluated.has(result_column_name)) throw Exception("Cannot evaluate template " + result_column_name + ", block structure:\n" + evaluated.dumpStructure(), - ErrorCodes::CANNOT_EVALUATE_EXPRESSION_TEMPLATE); + ErrorCodes::LOGICAL_ERROR); rows_count = 0; return evaluated.getByName(result_column_name).column->convertToFullColumnIfConst(); diff --git a/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.h b/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.h index ca1e0d61ff..aa2f21734b 100644 --- a/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.h +++ b/dbms/src/Processors/Formats/Impl/ConstantExpressionTemplate.h @@ -8,22 +8,36 @@ namespace DB { +struct LiteralInfo; + +/// Deduces template of an expression by replacing literals with dummy columns. +/// It allows to parse and evaluate similar expressions without using heavy IParsers and ExpressionAnalyzer. +/// Using ConstantExpressionTemplate for one expression is slower then evaluateConstantExpression(...), +/// but it's significantly faster for batch of expressions class ConstantExpressionTemplate { public: - ConstantExpressionTemplate(const IDataType & result_column_type, TokenIterator expression_begin, TokenIterator expression_end, + /// Deduce template of expression of type result_column_type + ConstantExpressionTemplate(DataTypePtr result_column_type_, TokenIterator expression_begin, TokenIterator expression_end, const ASTPtr & expression, const Context & context); + /// Read expression from istr, assert it has the same structure and the same types of literals (template matches) + /// and parse literals into temporary columns void parseExpression(ReadBuffer & istr, const FormatSettings & settings); + /// Evaluate batch of expressions were parsed using template ColumnPtr evaluateAll(); size_t rowsCount() const { return rows_count; } private: static void addNodesToCastResult(const IDataType & result_column_type, ASTPtr & expr); + bool getDataType(const LiteralInfo & info, DataTypePtr & type) const; + void parseLiteralAndAssertType(ReadBuffer & istr, const IDataType & type, size_t column_idx); private: + DataTypePtr result_column_type; + std::vector tokens; std::vector token_after_literal_idx; @@ -32,7 +46,7 @@ private: Block literals; MutableColumns columns; - std::vector need_special_parser; + std::vector use_special_parser; /// For expressions without literals (e.g. "now()") size_t rows_count = 0; diff --git a/dbms/src/Processors/Formats/Impl/ValuesBlockInputFormat.cpp b/dbms/src/Processors/Formats/Impl/ValuesBlockInputFormat.cpp index 63d557379e..5773947980 100644 --- a/dbms/src/Processors/Formats/Impl/ValuesBlockInputFormat.cpp +++ b/dbms/src/Processors/Formats/Impl/ValuesBlockInputFormat.cpp @@ -25,9 +25,7 @@ namespace ErrorCodes extern const int CANNOT_PARSE_DATE; extern const int SYNTAX_ERROR; extern const int VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE; - extern const int CANNOT_CREATE_EXPRESSION_TEMPLATE; extern const int CANNOT_PARSE_EXPRESSION_USING_TEMPLATE; - extern const int CANNOT_EVALUATE_EXPRESSION_TEMPLATE; } @@ -233,7 +231,7 @@ ValuesBlockInputFormat::parseExpression(IColumn & column, size_t column_idx, boo ErrorCodes::LOGICAL_ERROR); try { - templates[column_idx] = ConstantExpressionTemplate(type, TokenIterator(tokens), token_iterator, ast, *context); + templates[column_idx] = ConstantExpressionTemplate(header.getByPosition(column_idx).type, TokenIterator(tokens), token_iterator, ast, *context); buf.rollbackToCheckpoint(); templates[column_idx].value().parseExpression(buf, format_settings); assertDelimiterAfterValue(column_idx); -- GitLab