AnalyzeLambdas.cpp 6.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
#include <vector>
#include <DB/Analyzers/AnalyzeLambdas.h>
#include <DB/Parsers/formatAST.h>
#include <DB/Parsers/ASTSelectQuery.h>
#include <DB/Parsers/ASTTablesInSelectQuery.h>
#include <DB/Parsers/ASTIdentifier.h>
#include <DB/Parsers/ASTFunction.h>
#include <DB/IO/WriteBuffer.h>
#include <DB/IO/WriteHelpers.h>


namespace DB
{

namespace ErrorCodes
{
	extern const int BAD_LAMBDA;
	extern const int RESERVED_IDENTIFIER_NAME;
}


namespace
{

/// Parameters of one lambda expression: their names, in the order they are written.
using LambdaParameters = std::vector<String>;

/// Currently visible parameters in all scopes of lambda expressions.
/// Used as a stack: the outermost lambda is at the front, the innermost one is back().
/// Lambda expressions could be nested: arrayMap(x -> arrayMap(y -> x[y], x), [[1], [2, 3]])
using LambdaScopes = std::vector<LambdaParameters>;


/** Extracts names of lambda parameters from the left hand side of '->' (the first argument of 'lambda' function).
  *
  * Lambda parameters could be specified in AST in two forms:
  * - just as single parameter: x -> x + 1
  * - parameters in tuple: (x, y) -> x + 1
  *
  * Returns parameter names in the order they are written.
  * Throws BAD_LAMBDA if the expression has any other form, if some parameter is not a simple identifier
  * (compound identifier or arbitrary expression), if a parameter has an alias,
  * or if the same parameter name is declared more than once.
  */
LambdaParameters extractLambdaParameters(const ASTPtr & ast)
{
#define LAMBDA_ERROR_MESSAGE " There are two valid forms of lambda expressions: x -> ... and (x, y...) -> ..."

	if (!ast->tryGetAlias().empty())
		throw Exception("Lambda parameters cannot have aliases."
			LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

	if (const ASTIdentifier * identifier = typeid_cast<const ASTIdentifier *>(ast.get()))
	{
		/// Same restriction as for parameters in tuple below: 'a.b -> ...' is not a valid lambda.
		if (!identifier->children.empty())
			throw Exception("Left hand side of '->' or first argument of 'lambda' is compound identifier."
				LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

		return { identifier->name };
	}
	else if (const ASTFunction * function = typeid_cast<const ASTFunction *>(ast.get()))
	{
		if (function->name != "tuple")
			throw Exception("Left hand side of '->' or first argument of 'lambda' is a function, but this function is not tuple."
				LAMBDA_ERROR_MESSAGE " Found function '" + function->name + "' instead.", ErrorCodes::BAD_LAMBDA);

		if (!function->arguments || function->arguments->children.empty())
			throw Exception("Left hand side of '->' or first argument of 'lambda' is empty tuple."
				LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

		LambdaParameters res;
		res.reserve(function->arguments->children.size());

		for (const ASTPtr & arg : function->arguments->children)
		{
			const ASTIdentifier * arg_identifier = typeid_cast<const ASTIdentifier *>(arg.get());

			if (!arg_identifier)
				throw Exception("Left hand side of '->' or first argument of 'lambda' contains something that is not just identifier."
					LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

			if (!arg_identifier->children.empty())
				throw Exception("Left hand side of '->' or first argument of 'lambda' contains compound identifier."
					LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

			if (!arg_identifier->alias.empty())
				throw Exception("Lambda parameters cannot have aliases."
					LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

			/// Name lookup (processIdentifier) stops at the first matching parameter of a scope,
			/// so a repeated name would silently make the later parameter unreachable from the body.
			for (const String & previous : res)
				if (previous == arg_identifier->name)
					throw Exception("Left hand side of '->' or first argument of 'lambda' contains duplicate parameter '"
						+ arg_identifier->name + "'." LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

			res.emplace_back(arg_identifier->name);
		}

		return res;
	}
	else
		throw Exception("Unexpected left hand side of '->' or first argument of 'lambda'."
			LAMBDA_ERROR_MESSAGE, ErrorCodes::BAD_LAMBDA);

#undef LAMBDA_ERROR_MESSAGE
}


/** If a simple (non-compound) identifier refers to a parameter of an enclosing lambda expression,
  * renames it in place to "_lambda<scope index>_arg<parameter index>".
  * Inner scopes are searched first, so inner parameters shadow outer ones with the same name.
  * A simple identifier that matches no parameter but starts with "_lambda" is rejected
  * with RESERVED_IDENTIFIER_NAME. Compound identifiers are left untouched.
  */
void processIdentifier(ASTPtr & ast, LambdaScopes & lambda_scopes)
{
	auto & identifier = static_cast<ASTIdentifier &>(*ast);

	if (!identifier.children.empty())
		return;

	for (ssize_t depth = static_cast<ssize_t>(lambda_scopes.size()) - 1; depth >= 0; --depth)
	{
		const LambdaParameters & params = lambda_scopes[depth];

		for (size_t position = 0; position < params.size(); ++position)
		{
			if (params[position] != identifier.name)
				continue;

			identifier.name = "_lambda" + toString(depth) + "_arg" + toString(position);
			return;
		}
	}

	/// Not a lambda parameter: such names must not collide with the generated ones.
	if (startsWith(identifier.name, "_lambda"))
		throw Exception("Identifier names starting with '_lambda' are reserved for parameters of lambda expressions.",
			ErrorCodes::RESERVED_IDENTIFIER_NAME);
}


/** Recursively walks an expression tree.
  * - Renames references to lambda parameters (see processIdentifier).
  * - Validates lambda expressions ('->' or 'lambda' function).
  * - Appends to 'higher_order_functions' every function that directly has a lambda expression among its arguments.
  *
  * ast                               - node to process; identifiers inside it may be renamed in place.
  * lambda_scopes                     - parameters of the enclosing lambda expressions, innermost at back().
  * parent_function_for_this_argument - function whose argument list directly contains 'ast', or nullptr.
  * higher_order_functions            - output list.
  *
  * Throws BAD_LAMBDA for malformed or standalone lambda expressions.
  */
void processImpl(
	ASTPtr & ast,
	LambdaScopes & lambda_scopes,
	const ASTPtr & parent_function_for_this_argument,
	AnalyzeLambdas::HigherOrderFunctions & higher_order_functions)
{
	/// Don't go into subqueries and table-like expressions.
	if (typeid_cast<const ASTSelectQuery *>(ast.get())
		|| typeid_cast<const ASTTableExpression *>(ast.get()))
	{
		return;
	}
	else if (ASTFunction * func = typeid_cast<ASTFunction *>(ast.get()))
	{
		/** We must memoize parameters from left hand side (x, y) and then analyze right hand side.
		  */
		if (func->name == "lambda")
		{
			/// 'arguments' is checked for nullptr everywhere else in this file; do the same here
			/// instead of dereferencing it unconditionally. A missing list counts as zero arguments.
			const size_t num_arguments = func->arguments ? func->arguments->children.size() : 0;
			if (num_arguments != 2)
				throw Exception("Lambda expression ('->' or 'lambda' function) must have exactly two arguments."
					" Found " + toString(num_arguments) + " instead.", ErrorCodes::BAD_LAMBDA);

			/// Both the parameter list and the body are walked with the new scope pushed,
			/// so the parameter names on the left hand side are renamed as well.
			lambda_scopes.emplace_back(extractLambdaParameters(func->arguments->children[0]));
			for (size_t i = 0; i < num_arguments; ++i)
				processImpl(func->arguments->children[i], lambda_scopes, nullptr, higher_order_functions);
			lambda_scopes.pop_back();

			if (!parent_function_for_this_argument)
				throw Exception("Lambda expression ('->' or 'lambda' function) must be presented as an argument of higher-order function."
					" Found standalone lambda expression instead.", ErrorCodes::BAD_LAMBDA);

			higher_order_functions.emplace_back(parent_function_for_this_argument);
		}
		else
		{
			/// When diving into function arguments, pass current ast node.
			if (func->arguments)
				for (auto & child : func->arguments->children)
					processImpl(child, lambda_scopes, ast, higher_order_functions);

			/// Parameters of parametric functions are not arguments: a lambda there counts as standalone.
			if (func->parameters)
				for (auto & child : func->parameters->children)
					processImpl(child, lambda_scopes, nullptr, higher_order_functions);
		}

		return;
	}
	else if (typeid_cast<ASTIdentifier *>(ast.get()))
	{
		processIdentifier(ast, lambda_scopes);
		return;
	}

	for (auto & child : ast->children)
		processImpl(child, lambda_scopes, nullptr, higher_order_functions);
}

}


/** Analyzes every direct child of 'ast' (the root node itself is not analyzed).
  * Lambda parameter references are renamed in place, and functions that take lambda expressions
  * as arguments are appended to 'higher_order_functions'.
  */
void AnalyzeLambdas::process(ASTPtr & ast)
{
	/// Outside of any lambda expression: no parameters are visible yet.
	LambdaScopes scopes;

	for (ASTPtr & node : ast->children)
		processImpl(node, scopes, nullptr, higher_order_functions);
}


/** Writes the column name of every collected higher-order function to 'out', one per line.
  */
void AnalyzeLambdas::dump(WriteBuffer & out) const
{
	for (const auto & function : higher_order_functions)
	{
		const auto & column_name = function->getColumnName();
		writeString(column_name, out);
		writeChar('\n', out);
	}
}


}