提交 c83ae92d 编写于 作者: A Alexey Milovidov

Analyzers: added AnalyzeLambdas step [#CLICKHOUSE-11].

上级 d99a642b
#pragma once
#include <DB/Parsers/IAST.h>
#include <vector>
namespace DB
{
class WriteBuffer;
/** For every lambda expression, rename its parameters to '_lambda0_arg0' form.
* Check correctness of lambda expressions.
* Find functions, that have lambda expressions as arguments (they are called "higher order" functions).
*
* This should be done before CollectAliases.
*/
struct AnalyzeLambdas
{
void process(ASTPtr & ast);
using HigherOrderFunctions = std::vector<ASTPtr>;
HigherOrderFunctions higher_order_functions;
/// Debug output
void dump(WriteBuffer & out) const;
};
}
......@@ -24,7 +24,6 @@ namespace ErrorCodes
extern const int AMBIGUOUS_COLUMN_NAME;
extern const int UNKNOWN_TABLE;
extern const int THERE_IS_NO_COLUMN;
extern const int BAD_LAMBDA;
}
......@@ -222,17 +221,8 @@ ASTs expandQualifiedAsterisk(
}
/// Parameters of lambda expressions.
using LambdaParameters = std::vector<String>;
/// Currently visible parameters in all scopes of lambda expressions.
/// Lambda expressions could be nested: arrayMap(x -> arrayMap(y -> x[y], x), [[1], [2, 3]])
using LambdaScopes = std::vector<LambdaParameters>;
void processIdentifier(
const ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables,
const LambdaScopes & lambda_scopes)
const ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables)
{
const ASTIdentifier & identifier = static_cast<const ASTIdentifier &>(*ast);
......@@ -248,13 +238,10 @@ void processIdentifier(
if (identifier.children.empty())
{
/** Lambda parameters are not columns from table. Just skip them.
* If identifier name are known as lambda parameter in any currently visible scope of lambda expressions.
* This step requires AnalyzeLambdas to be done on AST.
*/
if (lambda_scopes.end() != std::find_if(lambda_scopes.begin(), lambda_scopes.end(),
[&identifier] (const LambdaParameters & names) { return names.end() != std::find(names.begin(), names.end(), identifier.name); }))
{
if (startsWith(identifier.name, "_lambda"))
return;
}
table = findTableWithUnqualifiedName(tables, identifier.name);
if (table)
......@@ -338,67 +325,7 @@ void processIdentifier(
}
LambdaParameters extractLambdaParameters(ASTPtr & ast)
{
/// Lambda parameters could be specified in AST in two forms:
/// - just as single parameter: x -> x + 1
/// - parameters in tuple: (x, y) -> x + 1
#define LAMBDA_ERROR_MESSAGE " There are two valid forms of lambda expressions: x -> ... and (x, y...) -> ..."
if (!ast->tryGetAlias().empty())
throw Exception("Lambda parameters cannot have aliases."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (const ASTIdentifier * identifier = typeid_cast<const ASTIdentifier *>(ast.get()))
{
return { identifier->name };
}
else if (const ASTFunction * function = typeid_cast<const ASTFunction *>(ast.get()))
{
if (function->name != "tuple")
throw Exception("Left hand side of '->' or first argument of 'lambda' is a function, but this function is not tuple."
LAMBDA_ERROR_MESSAGE " Found function '" + function->name + "' instead.", ErrorCodes::BAD_LAMBDA);
if (!function->arguments || function->arguments->children.empty())
throw Exception("Left hand side of '->' or first argument of 'lambda' is empty tuple."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
LambdaParameters res;
res.reserve(function->arguments->children.size());
for (const ASTPtr & arg : function->arguments->children)
{
const ASTIdentifier * arg_identifier = typeid_cast<const ASTIdentifier *>(arg.get());
if (!arg_identifier)
throw Exception("Left hand side of '->' or first argument of 'lambda' contains something that is not just identifier."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (!arg_identifier->children.empty())
throw Exception("Left hand side of '->' or first argument of 'lambda' contains compound identifier."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (!arg_identifier->alias.empty())
throw Exception("Lambda parameters cannot have aliases."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
res.emplace_back(arg_identifier->name);
}
return res;
}
else
throw Exception("Unexpected left hand side of '->' or first argument of 'lambda'."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
#undef LAMBDA_ERROR_MESSAGE
}
void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables,
LambdaScopes & lambda_scopes)
void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectAliases & aliases, const CollectTables & tables)
{
/// Don't go into subqueries and table-like expressions.
if (typeid_cast<const ASTSelectQuery *>(ast.get())
......@@ -416,25 +343,6 @@ void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectA
{
func->arguments->children.clear();
}
/** Special case for lambda functions, like (x, y) -> x + y + column.
* We must memoize parameters from left hand side (x, y)
* and then analyze right hand side, skipping that parameters.
* In example, from right hand side "x + y + column", only "column" should be searched in tables,
* because x and y are just lambda parameters.
*/
if (func->name == "lambda")
{
auto num_arguments = func->arguments->children.size();
if (num_arguments != 2)
throw Exception("Lambda expression ('->' or 'lambda' function) must have exactly two arguments."
" Found " + toString(num_arguments) + " instead.", ErrorCodes::BAD_LAMBDA);
lambda_scopes.emplace_back(extractLambdaParameters(func->arguments->children[0]));
processImpl(func->arguments->children[1], columns, aliases, tables, lambda_scopes);
lambda_scopes.pop_back();
return;
}
}
else if (typeid_cast<ASTExpressionList *>(ast.get()))
{
......@@ -458,12 +366,12 @@ void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectA
}
else if (typeid_cast<const ASTIdentifier *>(ast.get()))
{
processIdentifier(ast, columns, aliases, tables, lambda_scopes);
processIdentifier(ast, columns, aliases, tables);
return;
}
for (auto & child : ast->children)
processImpl(child, columns, aliases, tables, lambda_scopes);
processImpl(child, columns, aliases, tables);
}
}
......@@ -471,9 +379,8 @@ void processImpl(ASTPtr & ast, AnalyzeColumns::Columns & columns, const CollectA
void AnalyzeColumns::process(ASTPtr & ast, const CollectAliases & aliases, const CollectTables & tables)
{
LambdaScopes lambda_scopes;
for (auto & child : ast->children)
processImpl(child, columns, aliases, tables, lambda_scopes);
processImpl(child, columns, aliases, tables);
}
......
#include <vector>
#include <DB/Analyzers/AnalyzeLambdas.h>
#include <DB/Parsers/formatAST.h>
#include <DB/Parsers/ASTSelectQuery.h>
#include <DB/Parsers/ASTTablesInSelectQuery.h>
#include <DB/Parsers/ASTIdentifier.h>
#include <DB/Parsers/ASTFunction.h>
#include <DB/IO/WriteBuffer.h>
#include <DB/IO/WriteHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_LAMBDA;
extern const int RESERVED_IDENTIFIER_NAME;
}
namespace
{
/// Parameters of lambda expressions.
using LambdaParameters = std::vector<String>;
/// Currently visible parameters in all scopes of lambda expressions.
/// Lambda expressions could be nested: arrayMap(x -> arrayMap(y -> x[y], x), [[1], [2, 3]])
using LambdaScopes = std::vector<LambdaParameters>;
LambdaParameters extractLambdaParameters(ASTPtr & ast)
{
/// Lambda parameters could be specified in AST in two forms:
/// - just as single parameter: x -> x + 1
/// - parameters in tuple: (x, y) -> x + 1
#define LAMBDA_ERROR_MESSAGE " There are two valid forms of lambda expressions: x -> ... and (x, y...) -> ..."
if (!ast->tryGetAlias().empty())
throw Exception("Lambda parameters cannot have aliases."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (const ASTIdentifier * identifier = typeid_cast<const ASTIdentifier *>(ast.get()))
{
return { identifier->name };
}
else if (const ASTFunction * function = typeid_cast<const ASTFunction *>(ast.get()))
{
if (function->name != "tuple")
throw Exception("Left hand side of '->' or first argument of 'lambda' is a function, but this function is not tuple."
LAMBDA_ERROR_MESSAGE " Found function '" + function->name + "' instead.", ErrorCodes::BAD_LAMBDA);
if (!function->arguments || function->arguments->children.empty())
throw Exception("Left hand side of '->' or first argument of 'lambda' is empty tuple."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
LambdaParameters res;
res.reserve(function->arguments->children.size());
for (const ASTPtr & arg : function->arguments->children)
{
const ASTIdentifier * arg_identifier = typeid_cast<const ASTIdentifier *>(arg.get());
if (!arg_identifier)
throw Exception("Left hand side of '->' or first argument of 'lambda' contains something that is not just identifier."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (!arg_identifier->children.empty())
throw Exception("Left hand side of '->' or first argument of 'lambda' contains compound identifier."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
if (!arg_identifier->alias.empty())
throw Exception("Lambda parameters cannot have aliases."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
res.emplace_back(arg_identifier->name);
}
return res;
}
else
throw Exception("Unexpected left hand side of '->' or first argument of 'lambda'."
LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);
#undef LAMBDA_ERROR_MESSAGE
}
void processIdentifier(ASTPtr & ast, LambdaScopes & lambda_scopes)
{
ASTIdentifier & identifier = static_cast<ASTIdentifier &>(*ast);
if (identifier.children.empty())
{
bool found = false;
/// From most inner scope towards outer scopes.
for (ssize_t num_scopes = lambda_scopes.size(), scope_idx = num_scopes - 1; scope_idx >= 0; --scope_idx)
{
for (size_t arg_idx = 0, num_args = lambda_scopes[scope_idx].size(); arg_idx < num_args; ++arg_idx)
{
if (lambda_scopes[scope_idx][arg_idx] == identifier.name)
{
identifier.name = "_lambda" + toString(scope_idx) + "_arg" + toString(arg_idx);
found = true;
break;
}
}
if (found)
break;
}
if (!found && startsWith(identifier.name, "_lambda"))
throw Exception("Identifier names starting with '_lambda' are reserved for parameters of lambda expressions.",
ErrorCodes::RESERVED_IDENTIFIER_NAME);
}
}
void processImpl(
ASTPtr & ast,
LambdaScopes & lambda_scopes,
const ASTPtr & parent_function_for_this_argument,
AnalyzeLambdas::HigherOrderFunctions & higher_order_functions)
{
/// Don't go into subqueries and table-like expressions.
if (typeid_cast<const ASTSelectQuery *>(ast.get())
|| typeid_cast<const ASTTableExpression *>(ast.get()))
{
return;
}
else if (ASTFunction * func = typeid_cast<ASTFunction *>(ast.get()))
{
/** We must memoize parameters from left hand side (x, y) and then analyze right hand side.
*/
if (func->name == "lambda")
{
auto num_arguments = func->arguments->children.size();
if (num_arguments != 2)
throw Exception("Lambda expression ('->' or 'lambda' function) must have exactly two arguments."
" Found " + toString(num_arguments) + " instead.", ErrorCodes::BAD_LAMBDA);
lambda_scopes.emplace_back(extractLambdaParameters(func->arguments->children[0]));
for (size_t i = 0; i < num_arguments; ++i)
processImpl(func->arguments->children[i], lambda_scopes, nullptr, higher_order_functions);
lambda_scopes.pop_back();
if (!parent_function_for_this_argument)
throw Exception("Lambda expression ('->' or 'lambda' function) must be presented as an argument of higher-order function."
" Found standalone lambda expression instead.", ErrorCodes::BAD_LAMBDA);
higher_order_functions.emplace_back(parent_function_for_this_argument);
}
else
{
/// When diving into function arguments, pass current ast node.
if (func->arguments)
for (auto & child : func->arguments->children)
processImpl(child, lambda_scopes, ast, higher_order_functions);
if (func->parameters)
for (auto & child : func->parameters->children)
processImpl(child, lambda_scopes, nullptr, higher_order_functions);
}
return;
}
else if (typeid_cast<ASTIdentifier *>(ast.get()))
{
processIdentifier(ast, lambda_scopes);
return;
}
for (auto & child : ast->children)
processImpl(child, lambda_scopes, nullptr, higher_order_functions);
}
}
void AnalyzeLambdas::process(ASTPtr & ast)
{
LambdaScopes lambda_scopes;
for (auto & child : ast->children)
processImpl(child, lambda_scopes, nullptr, higher_order_functions);
}
void AnalyzeLambdas::dump(WriteBuffer & out) const
{
for (const auto & ast : higher_order_functions)
{
writeString(ast->getColumnName(), out);
writeChar('\n', out);
}
}
}
#include <DB/Analyzers/AnalyzeResultOfQuery.h>
#include <DB/Analyzers/CollectAliases.h>
#include <DB/Analyzers/CollectTables.h>
#include <DB/Analyzers/AnalyzeLambdas.h>
#include <DB/Analyzers/AnalyzeColumns.h>
#include <DB/Analyzers/TypeAndConstantInference.h>
#include <DB/Interpreters/Context.h>
......@@ -27,6 +28,9 @@ void AnalyzeResultOfQuery::process(ASTPtr & ast, Context & context)
if (!select->select_expression_list)
throw Exception("SELECT query doesn't have select_expression_list", ErrorCodes::UNEXPECTED_AST_STRUCTURE);
AnalyzeLambdas analyze_lambdas;
analyze_lambdas.process(ast);
CollectAliases collect_aliases;
collect_aliases.process(ast);
......
......@@ -19,6 +19,7 @@
#include <DB/DataTypes/DataTypeTuple.h>
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
#include <DB/Functions/FunctionFactory.h>
#include <ext/range.hpp>
......@@ -29,6 +30,7 @@ namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int FUNCTION_CANNOT_HAVE_PARAMETERS;
extern const int BAD_LAMBDA;
}
......@@ -173,7 +175,7 @@ void processFunction(const String & column_name, ASTPtr & ast, TypeAndConstantIn
aggregate_function_ptr->setArguments(argument_types);
/// Replace function name to canonical one. Because same function could be referenced by different names.
/// (?) Replace function name to canonical one. Because same function could be referenced by different names.
// function->name = aggregate_function_ptr->getName();
TypeAndConstantInference::ExpressionInfo expression_info;
......@@ -323,7 +325,8 @@ void processImpl(
TypeAndConstantInference::Info & info,
LambdaScopes & lambda_scopes)
{
/// Depth-first
/// Top down part - collection of lambda scopes
/// Don't go into components of compound identifiers.
if (!typeid_cast<const ASTIdentifier *>(ast.get()))
......@@ -341,15 +344,16 @@ void processImpl(
}
}
/// Bottom up part
const ASTLiteral * literal = nullptr;
const ASTIdentifier * identifier = nullptr;
const ASTFunction * function = nullptr;
const ASTSubquery * subquery = nullptr;
false
function
|| (literal = typeid_cast<const ASTLiteral *>(ast.get()))
|| (identifier = typeid_cast<const ASTIdentifier *>(ast.get()))
|| (function = typeid_cast<const ASTFunction *>(ast.get()))
|| (subquery = typeid_cast<const ASTSubquery *>(ast.get()));
if (!literal && !identifier && !function && !subquery)
......
......@@ -18,3 +18,6 @@ target_link_libraries(translate_positional_arguments dbms)
add_executable(optimize_group_order_limit_by optimize_group_order_limit_by.cpp)
target_link_libraries(optimize_group_order_limit_by dbms)
add_executable(analyze_lambdas analyze_lambdas.cpp)
target_link_libraries(analyze_lambdas dbms)
#include <DB/Analyzers/CollectAliases.h>
#include <DB/Analyzers/CollectTables.h>
#include <DB/Analyzers/AnalyzeColumns.h>
#include <DB/Analyzers/AnalyzeLambdas.h>
#include <DB/Parsers/parseQuery.h>
#include <DB/Parsers/ParserSelectQuery.h>
#include <DB/Parsers/formatAST.h>
......@@ -38,6 +39,9 @@ try
system_database->attachTable("numbers", StorageSystemNumbers::create("numbers"));
context.setCurrentDatabase("system");
AnalyzeLambdas analyze_lambdas;
analyze_lambdas.process(ast);
CollectAliases collect_aliases;
collect_aliases.process(ast);
......
......@@ -10,9 +10,9 @@ SELECT dummy, number, one.dummy, numbers.number, system.one.dummy, system.number
c -> c UInt8. Database name: (none). Table name: (none). Alias: (none). Storage: (none). AST: c
SELECT arrayMap((x, y) -> arrayMap((y, z) -> x[y], x, c), [[1], [2, 3]]) FROM (SELECT 1 AS c, 2 AS d)
SELECT arrayMap((_lambda0_arg0, _lambda0_arg1) -> arrayMap((_lambda1_arg0, _lambda1_arg1) -> _lambda0_arg0[_lambda1_arg0], _lambda0_arg0, c), [[1], [2, 3]]) FROM (SELECT 1 AS c, 2 AS d)
c -> c UInt8. Database name: (none). Table name: (none). Alias: (none). Storage: (none). AST: c
x -> x UInt8. Database name: (none). Table name: (none). Alias: (none). Storage: (none). AST: x
SELECT x, arrayMap((x, y) -> (x + y), x, c) FROM (SELECT 1 AS x, 2 AS c)
SELECT x, arrayMap((_lambda0_arg0, _lambda0_arg1) -> (_lambda0_arg0 + _lambda0_arg1), x, c) FROM (SELECT 1 AS x, 2 AS c)
arrayMap(lambda(tuple(_lambda1_arg0), arrayElement(_lambda1_arg0, _lambda0_arg1)), arr3)
arrayMap(lambda(tuple(_lambda0_arg0, _lambda0_arg1), plus(_lambda0_arg0, arrayMap(lambda(tuple(_lambda1_arg0), arrayElement(_lambda1_arg0, _lambda0_arg1)), arr3))), arr1, arr2)
SELECT arrayMap((_lambda0_arg0, _lambda0_arg1) -> (_lambda0_arg0 + arrayMap(_lambda1_arg0 -> _lambda1_arg0[_lambda0_arg1], arr3)), arr1, arr2)
#!/bin/sh
echo "SELECT arrayMap((x, y) -> x + arrayMap(x -> x[y], arr3), arr1, arr2)" | ./analyze_lambdas
#include <DB/Analyzers/CollectAliases.h>
#include <DB/Analyzers/CollectTables.h>
#include <DB/Analyzers/AnalyzeColumns.h>
#include <DB/Analyzers/AnalyzeLambdas.h>
#include <DB/Analyzers/TypeAndConstantInference.h>
#include <DB/Analyzers/TranslatePositionalArguments.h>
#include <DB/Analyzers/OptimizeGroupOrderLimitBy.h>
......@@ -41,6 +42,9 @@ try
system_database->attachTable("numbers", StorageSystemNumbers::create("numbers"));
context.setCurrentDatabase("system");
AnalyzeLambdas analyze_lambdas;
analyze_lambdas.process(ast);
CollectAliases collect_aliases;
collect_aliases.process(ast);
......
#include <DB/Analyzers/CollectAliases.h>
#include <DB/Analyzers/CollectTables.h>
#include <DB/Analyzers/AnalyzeColumns.h>
#include <DB/Analyzers/AnalyzeLambdas.h>
#include <DB/Analyzers/TypeAndConstantInference.h>
#include <DB/Parsers/parseQuery.h>
#include <DB/Parsers/ParserSelectQuery.h>
......@@ -39,6 +40,9 @@ try
system_database->attachTable("numbers", StorageSystemNumbers::create("numbers"));
context.setCurrentDatabase("system");
AnalyzeLambdas analyze_lambdas;
analyze_lambdas.process(ast);
CollectAliases collect_aliases;
collect_aliases.process(ast);
......
......@@ -359,6 +359,7 @@ namespace ErrorCodes
extern const int ZLIB_INFLATE_FAILED = 354;
extern const int ZLIB_DEFLATE_FAILED = 355;
extern const int BAD_LAMBDA = 356;
extern const int RESERVED_IDENTIFIER_NAME = 357;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册