未验证 提交 3792d688 编写于 作者: N Nikolai Kochetov 提交者: GitHub

Merge pull request #14741 from ClickHouse/expression-dag

ActionsDAG
......@@ -53,6 +53,7 @@ namespace ErrorCodes
extern const int TYPE_MISMATCH;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int INCORRECT_ELEMENT_OF_SET;
extern const int BAD_ARGUMENTS;
}
static NamesAndTypesList::iterator findColumn(const String & name, NamesAndTypesList & cols)
......@@ -328,7 +329,7 @@ Block createBlockForSet(
}
SetPtr makeExplicitSet(
const ASTFunction * node, const Block & sample_block, bool create_ordered_set,
const ASTFunction * node, const ActionsDAG & actions, bool create_ordered_set,
const Context & context, const SizeLimits & size_limits, PreparedSets & prepared_sets)
{
const IAST & args = *node->arguments;
......@@ -339,7 +340,11 @@ SetPtr makeExplicitSet(
const ASTPtr & left_arg = args.children.at(0);
const ASTPtr & right_arg = args.children.at(1);
const DataTypePtr & left_arg_type = sample_block.getByName(left_arg->getColumnName()).type;
const auto & index = actions.getIndex();
auto it = index.find(left_arg->getColumnName());
if (it == index.end())
throw Exception("Unknown identifier: '" + left_arg->getColumnName() + "'", ErrorCodes::UNKNOWN_IDENTIFIER);
const DataTypePtr & left_arg_type = it->second->result_type;
DataTypes set_element_types = {left_arg_type};
const auto * left_tuple_type = typeid_cast<const DataTypeTuple *>(left_arg_type.get());
......@@ -370,95 +375,145 @@ SetPtr makeExplicitSet(
return set;
}
ScopeStack::ScopeStack(const ExpressionActionsPtr & actions, const Context & context_)
ActionsMatcher::Data::Data(
const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_,
const NamesAndTypesList & source_columns_, ActionsDAGPtr actions,
PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_,
bool no_subqueries_, bool no_makeset_, bool only_consts_, bool no_storage_or_local_)
: context(context_)
, set_size_limit(set_size_limit_)
, subquery_depth(subquery_depth_)
, source_columns(source_columns_)
, prepared_sets(prepared_sets_)
, subqueries_for_sets(subqueries_for_sets_)
, no_subqueries(no_subqueries_)
, no_makeset(no_makeset_)
, only_consts(only_consts_)
, no_storage_or_local(no_storage_or_local_)
, visit_depth(0)
, actions_stack(std::move(actions), context)
, next_unique_suffix(actions_stack.getLastActions().getIndex().size() + 1)
{
stack.emplace_back();
stack.back().actions = actions;
}
const Block & sample_block = actions->getSampleBlock();
for (size_t i = 0, size = sample_block.columns(); i < size; ++i)
stack.back().new_columns.insert(sample_block.getByPosition(i).name);
bool ActionsMatcher::Data::hasColumn(const String & column_name) const
{
return actions_stack.getLastActions().getIndex().count(column_name) != 0;
}
void ScopeStack::pushLevel(const NamesAndTypesList & input_columns)
ScopeStack::ScopeStack(ActionsDAGPtr actions, const Context & context_)
: context(context_)
{
stack.emplace_back();
Level & prev = stack[stack.size() - 2];
auto & level = stack.emplace_back();
level.actions = std::move(actions);
ColumnsWithTypeAndName all_columns;
NameSet new_names;
for (const auto & [name, node] : level.actions->getIndex())
if (node->type == ActionsDAG::Type::INPUT)
level.inputs.emplace(name);
}
void ScopeStack::pushLevel(const NamesAndTypesList & input_columns)
{
auto & level = stack.emplace_back();
level.actions = std::make_shared<ActionsDAG>();
const auto & prev = stack[stack.size() - 2];
for (const auto & input_column : input_columns)
{
all_columns.emplace_back(nullptr, input_column.type, input_column.name);
new_names.insert(input_column.name);
stack.back().new_columns.insert(input_column.name);
level.actions->addInput(input_column.name, input_column.type);
level.inputs.emplace(input_column.name);
}
const Block & prev_sample_block = prev.actions->getSampleBlock();
for (size_t i = 0, size = prev_sample_block.columns(); i < size; ++i)
const auto & index = level.actions->getIndex();
for (const auto & [name, node] : prev.actions->getIndex())
{
const ColumnWithTypeAndName & col = prev_sample_block.getByPosition(i);
if (!new_names.count(col.name))
all_columns.push_back(col);
if (index.count(name) == 0)
level.actions->addInput({node->column, node->result_type, node->result_name});
}
stack.back().actions = std::make_shared<ExpressionActions>(all_columns, context);
}
size_t ScopeStack::getColumnLevel(const std::string & name)
{
for (int i = static_cast<int>(stack.size()) - 1; i >= 0; --i)
if (stack[i].new_columns.count(name))
for (size_t i = stack.size(); i > 0;)
{
--i;
if (stack[i].inputs.count(name))
return i;
const auto & index = stack[i].actions->getIndex();
auto it = index.find(name);
if (it != index.end() && it->second->type != ActionsDAG::Type::INPUT)
return i;
}
throw Exception("Unknown identifier: " + name, ErrorCodes::UNKNOWN_IDENTIFIER);
}
void ScopeStack::addAction(const ExpressionAction & action)
void ScopeStack::addColumn(ColumnWithTypeAndName column)
{
size_t level = 0;
Names required = action.getNeededColumns();
for (const auto & elem : required)
level = std::max(level, getColumnLevel(elem));
const auto & node = stack[0].actions->addColumn(std::move(column));
Names added;
stack[level].actions->add(action, added);
for (size_t j = 1; j < stack.size(); ++j)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
}
stack[level].new_columns.insert(added.begin(), added.end());
void ScopeStack::addAlias(const std::string & name, std::string alias)
{
auto level = getColumnLevel(name);
const auto & node = stack[level].actions->addAlias(name, std::move(alias));
for (const auto & elem : added)
{
const ColumnWithTypeAndName & col = stack[level].actions->getSampleBlock().getByName(elem);
for (size_t j = level + 1; j < stack.size(); ++j)
stack[j].actions->addInput(col);
}
for (size_t j = level + 1; j < stack.size(); ++j)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
}
void ScopeStack::addArrayJoin(const std::string & source_name, std::string result_name, std::string unique_column_name)
{
getColumnLevel(source_name);
if (stack.front().actions->getIndex().count(source_name) == 0)
throw Exception("Expression with arrayJoin cannot depend on lambda argument: " + source_name,
ErrorCodes::BAD_ARGUMENTS);
const auto & node = stack.front().actions->addArrayJoin(source_name, std::move(result_name), std::move(unique_column_name));
for (size_t j = 1; j < stack.size(); ++j)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
}
void ScopeStack::addActionNoInput(const ExpressionAction & action)
void ScopeStack::addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions)
{
size_t level = 0;
Names required = action.getNeededColumns();
for (const auto & elem : required)
level = std::max(level, getColumnLevel(elem));
for (const auto & argument : argument_names)
level = std::max(level, getColumnLevel(argument));
Names added;
stack[level].actions->add(action, added);
const auto & node = stack[level].actions->addFunction(function, argument_names, std::move(result_name), compile_expressions);
stack[level].new_columns.insert(added.begin(), added.end());
for (size_t j = level + 1; j < stack.size(); ++j)
stack[j].actions->addInput({node.column, node.result_type, node.result_name});
}
ExpressionActionsPtr ScopeStack::popLevel()
ActionsDAGPtr ScopeStack::popLevel()
{
ExpressionActionsPtr res = stack.back().actions;
auto res = std::move(stack.back());
stack.pop_back();
return res;
return res.actions;
}
const Block & ScopeStack::getSampleBlock() const
std::string ScopeStack::dumpNames() const
{
return stack.back().actions->getSampleBlock();
return stack.back().actions->dumpNames();
}
const ActionsDAG & ScopeStack::getLastActions() const
{
return *stack.back().actions;
}
struct CachedColumnName
......@@ -521,7 +576,7 @@ void ActionsMatcher::visit(const ASTIdentifier & identifier, const ASTPtr & ast,
/// Special check for WITH statement alias. Add alias action to be able to use this alias.
if (identifier.prefer_alias_to_column_name && !identifier.alias.empty())
data.addAction(ExpressionAction::addAliases({{identifier.name, identifier.alias}}));
data.addAlias(identifier.name, identifier.alias);
}
}
......@@ -545,14 +600,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
if (!data.only_consts)
{
String result_name = column_name.get(ast);
/// Here we copy argument because arrayJoin removes source column.
/// It makes possible to remove source column before arrayJoin if it won't be needed anymore.
/// It could have been possible to implement arrayJoin which keeps source column,
/// but in this case it will always be replicated (as many arrays), which is expensive.
String tmp_name = data.getUniqueName("_array_join_" + arg->getColumnName());
data.addActionNoInput(ExpressionAction::copyColumn(arg->getColumnName(), tmp_name));
data.addAction(ExpressionAction::arrayJoin(tmp_name, result_name));
data.addArrayJoin(arg->getColumnName(), result_name);
}
return;
......@@ -577,10 +625,10 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
auto argument_name = node.arguments->children.at(0)->getColumnName();
data.addAction(ExpressionAction::applyFunction(
data.addFunction(
FunctionFactory::instance().get(node.name + "IgnoreSet", data.context),
{ argument_name, argument_name },
column_name.get(ast)));
column_name.get(ast));
}
return;
}
......@@ -652,7 +700,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
column.column = ColumnConst::create(std::move(column_set), 1);
else
column.column = std::move(column_set);
data.addAction(ExpressionAction::addColumn(column));
data.addColumn(column);
}
argument_types.push_back(column.type);
......@@ -668,7 +716,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
ColumnConst::create(std::move(column_string), 1),
std::make_shared<DataTypeString>(),
data.getUniqueName("__" + node.name));
data.addAction(ExpressionAction::addColumn(column));
data.addColumn(column);
argument_types.push_back(column.type);
argument_names.push_back(column.name);
}
......@@ -688,9 +736,11 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
child_column_name = as_literal->unique_column_name;
}
if (data.hasColumn(child_column_name))
const auto & index = data.actions_stack.getLastActions().getIndex();
auto it = index.find(child_column_name);
if (it != index.end())
{
argument_types.push_back(data.getSampleBlock().getByName(child_column_name).type);
argument_types.push_back(it->second->result_type);
argument_names.push_back(child_column_name);
}
else
......@@ -698,7 +748,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
if (data.only_consts)
arguments_present = false;
else
throw Exception("Unknown identifier: " + child_column_name + " there are columns: " + data.getSampleBlock().dumpNames(),
throw Exception("Unknown identifier: " + child_column_name + " there are columns: " + data.actions_stack.dumpNames(),
ErrorCodes::UNKNOWN_IDENTIFIER);
}
}
......@@ -735,7 +785,8 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
data.actions_stack.pushLevel(lambda_arguments);
visit(lambda->arguments->children.at(1), data);
ExpressionActionsPtr lambda_actions = data.actions_stack.popLevel();
auto lambda_dag = data.actions_stack.popLevel();
auto lambda_actions = lambda_dag->buildExpressions(data.context);
String result_name = lambda->arguments->children.at(1)->getColumnName();
lambda_actions->finalize(Names(1, result_name));
......@@ -754,7 +805,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
auto function_capture = std::make_unique<FunctionCaptureOverloadResolver>(
lambda_actions, captured, lambda_arguments, result_type, result_name);
auto function_capture_adapter = std::make_shared<FunctionOverloadResolverAdaptor>(std::move(function_capture));
data.addAction(ExpressionAction::applyFunction(function_capture_adapter, captured, lambda_name));
data.addFunction(function_capture_adapter, captured, lambda_name);
argument_types[i] = std::make_shared<DataTypeFunction>(lambda_type->getArgumentTypes(), result_type);
argument_names[i] = lambda_name;
......@@ -776,7 +827,7 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
if (arguments_present)
{
data.addAction(ExpressionAction::applyFunction(function_builder, argument_names, column_name.get(ast)));
data.addFunction(function_builder, argument_names, column_name.get(ast));
}
}
......@@ -791,8 +842,12 @@ void ActionsMatcher::visit(const ASTLiteral & literal, const ASTPtr & /* ast */,
if (literal.unique_column_name.empty())
{
const auto default_name = literal.getColumnName();
const auto & block = data.getSampleBlock();
const auto * existing_column = block.findByName(default_name);
const auto & index = data.actions_stack.getLastActions().getIndex();
const ActionsDAG::Node * existing_column = nullptr;
auto it = index.find(default_name);
if (it != index.end())
existing_column = it->second;
/*
* To approximate CSE, bind all identical literals to a single temporary
......@@ -828,7 +883,7 @@ void ActionsMatcher::visit(const ASTLiteral & literal, const ASTPtr & /* ast */,
column.column = type->createColumnConst(1, value);
column.type = type;
data.addAction(ExpressionAction::addColumn(column));
data.addColumn(std::move(column));
}
SetPtr ActionsMatcher::makeSet(const ASTFunction & node, Data & data, bool no_subqueries)
......@@ -840,7 +895,6 @@ SetPtr ActionsMatcher::makeSet(const ASTFunction & node, Data & data, bool no_su
const IAST & args = *node.arguments;
const ASTPtr & left_in_operand = args.children.at(0);
const ASTPtr & right_in_operand = args.children.at(1);
const Block & sample_block = data.getSampleBlock();
/// If the subquery or table name for SELECT.
const auto * identifier = right_in_operand->as<ASTIdentifier>();
......@@ -902,9 +956,11 @@ SetPtr ActionsMatcher::makeSet(const ASTFunction & node, Data & data, bool no_su
}
else
{
if (sample_block.has(left_in_operand->getColumnName()))
const auto & last_actions = data.actions_stack.getLastActions();
const auto & index = last_actions.getIndex();
if (index.count(left_in_operand->getColumnName()) != 0)
/// An explicit enumeration of values in parentheses.
return makeExplicitSet(&node, sample_block, false, data.context, data.set_size_limit, data.prepared_sets);
return makeExplicitSet(&node, last_actions, false, data.context, data.set_size_limit, data.prepared_sets);
else
return {};
}
......
......@@ -16,9 +16,15 @@ struct ExpressionAction;
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
/// The case of an explicit enumeration of values.
SetPtr makeExplicitSet(
const ASTFunction * node, const Block & sample_block, bool create_ordered_set,
const ASTFunction * node, const ActionsDAG & actions, bool create_ordered_set,
const Context & context, const SizeLimits & limits, PreparedSets & prepared_sets);
/** Create a block for set from expression.
......@@ -59,8 +65,8 @@ struct ScopeStack
{
struct Level
{
ExpressionActionsPtr actions;
NameSet new_columns;
ActionsDAGPtr actions;
NameSet inputs;
};
using Levels = std::vector<Level>;
......@@ -69,19 +75,25 @@ struct ScopeStack
const Context & context;
ScopeStack(const ExpressionActionsPtr & actions, const Context & context_);
ScopeStack(ActionsDAGPtr actions, const Context & context_);
void pushLevel(const NamesAndTypesList & input_columns);
size_t getColumnLevel(const std::string & name);
void addAction(const ExpressionAction & action);
/// For arrayJoin() to avoid double columns in the input.
void addActionNoInput(const ExpressionAction & action);
void addColumn(ColumnWithTypeAndName column);
void addAlias(const std::string & name, std::string alias);
void addArrayJoin(const std::string & source_name, std::string result_name, std::string unique_column_name);
void addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions);
ExpressionActionsPtr popLevel();
ActionsDAGPtr popLevel();
const Block & getSampleBlock() const;
const ActionsDAG & getLastActions() const;
std::string dumpNames() const;
};
class ASTIdentifier;
......@@ -117,47 +129,38 @@ public:
int next_unique_suffix;
Data(const Context & context_, SizeLimits set_size_limit_, size_t subquery_depth_,
const NamesAndTypesList & source_columns_, const ExpressionActionsPtr & actions,
const NamesAndTypesList & source_columns_, ActionsDAGPtr actions,
PreparedSets & prepared_sets_, SubqueriesForSets & subqueries_for_sets_,
bool no_subqueries_, bool no_makeset_, bool only_consts_, bool no_storage_or_local_)
: context(context_),
set_size_limit(set_size_limit_),
subquery_depth(subquery_depth_),
source_columns(source_columns_),
prepared_sets(prepared_sets_),
subqueries_for_sets(subqueries_for_sets_),
no_subqueries(no_subqueries_),
no_makeset(no_makeset_),
only_consts(only_consts_),
no_storage_or_local(no_storage_or_local_),
visit_depth(0),
actions_stack(actions, context),
next_unique_suffix(actions_stack.getSampleBlock().columns() + 1)
{}
void updateActions(ExpressionActionsPtr & actions)
bool no_subqueries_, bool no_makeset_, bool only_consts_, bool no_storage_or_local_);
/// Does result of the calculation already exists in the block.
bool hasColumn(const String & column_name) const;
void addColumn(ColumnWithTypeAndName column)
{
actions = actions_stack.popLevel();
actions_stack.addColumn(std::move(column));
}
void addAction(const ExpressionAction & action)
void addAlias(const std::string & name, std::string alias)
{
actions_stack.addAction(action);
actions_stack.addAlias(name, std::move(alias));
}
void addActionNoInput(const ExpressionAction & action)
void addArrayJoin(const std::string & source_name, std::string result_name)
{
actions_stack.addActionNoInput(action);
actions_stack.addArrayJoin(source_name, std::move(result_name), getUniqueName("_array_join_" + source_name));
}
const Block & getSampleBlock() const
void addFunction(const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name)
{
return actions_stack.getSampleBlock();
actions_stack.addFunction(function, argument_names, std::move(result_name),
context.getSettingsRef().compile_expressions);
}
/// Does result of the calculation already exists in the block.
bool hasColumn(const String & columnName) const
ActionsDAGPtr getActions()
{
return actions_stack.getSampleBlock().has(columnName);
return actions_stack.popLevel();
}
/*
......@@ -166,12 +169,11 @@ public:
*/
String getUniqueName(const String & prefix)
{
const auto & block = getSampleBlock();
auto result = prefix;
// First, try the name without any suffix, because it is currently
// used both as a display name and a column id.
while (block.has(result))
while (hasColumn(result))
{
result = prefix + "_" + toString(next_unique_suffix);
++next_unique_suffix;
......
......@@ -35,11 +35,13 @@ ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool arr
}
void ArrayJoinAction::prepare(Block & sample_block)
void ArrayJoinAction::prepare(ColumnsWithTypeAndName & sample) const
{
for (const auto & name : columns)
for (auto & current : sample)
{
ColumnWithTypeAndName & current = sample_block.getByName(name);
if (columns.count(current.name) == 0)
continue;
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
if (!array_type)
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
......
......@@ -28,7 +28,7 @@ public:
FunctionOverloadResolverPtr function_builder;
ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context);
void prepare(Block & sample_block);
void prepare(ColumnsWithTypeAndName & sample) const;
void execute(Block & block);
};
......
......@@ -13,8 +13,10 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/IFunction.h>
#include <IO/Operators.h>
#include <optional>
#include <Columns/ColumnSet.h>
#include <queue>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
......@@ -186,7 +188,8 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
size_t result_position = sample_block.columns();
sample_block.insert({nullptr, result_type, result_name});
function = function_base->prepare(sample_block, arguments, result_position);
if (!function)
function = function_base->prepare(sample_block, arguments, result_position);
function->createLowCardinalityResultCache(settings.max_threads);
bool compile_expressions = false;
......@@ -198,7 +201,10 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
/// so we don't want to unfold non deterministic functions
if (all_const && function_base->isSuitableForConstantFolding() && (!compile_expressions || function_base->isDeterministic()))
{
function->execute(sample_block, arguments, result_position, sample_block.rows(), true);
if (added_column)
sample_block.getByPosition(result_position).column = added_column;
else
function->execute(sample_block, arguments, result_position, sample_block.rows(), true);
/// If the result is not a constant, just in case, we will consider the result as unknown.
ColumnWithTypeAndName & col = sample_block.safeGetByPosition(result_position);
......@@ -586,8 +592,11 @@ void ExpressionActions::addImpl(ExpressionAction action, Names & new_names)
arguments[i] = sample_block.getByName(action.argument_names[i]);
}
action.function_base = action.function_builder->build(arguments);
action.result_type = action.function_base->getReturnType();
if (!action.function_base)
{
action.function_base = action.function_builder->build(arguments);
action.result_type = action.function_base->getReturnType();
}
}
if (action.type == ExpressionAction::ADD_ALIASES)
......@@ -1256,8 +1265,14 @@ void ExpressionActionsChain::addStep()
if (steps.empty())
throw Exception("Cannot add action to empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR);
if (auto * step = typeid_cast<ExpressionActionsStep *>(steps.back().get()))
{
if (!step->actions)
step->actions = step->actions_dag->buildExpressions(context);
}
ColumnsWithTypeAndName columns = steps.back()->getResultColumns();
steps.push_back(std::make_unique<ExpressionActionsStep>(std::make_shared<ExpressionActions>(columns, context)));
steps.push_back(std::make_unique<ExpressionActionsStep>(std::make_shared<ActionsDAG>(columns)));
}
void ExpressionActionsChain::finalize()
......@@ -1404,14 +1419,383 @@ void ExpressionActionsChain::JoinStep::finalize(const Names & required_output_)
std::swap(result_columns, new_result_columns);
}
ExpressionActionsPtr & ExpressionActionsChain::Step::actions()
ActionsDAGPtr & ExpressionActionsChain::Step::actions()
{
return typeid_cast<ExpressionActionsStep *>(this)->actions;
return typeid_cast<ExpressionActionsStep *>(this)->actions_dag;
}
const ExpressionActionsPtr & ExpressionActionsChain::Step::actions() const
const ActionsDAGPtr & ExpressionActionsChain::Step::actions() const
{
return typeid_cast<const ExpressionActionsStep *>(this)->actions_dag;
}
ExpressionActionsPtr ExpressionActionsChain::Step::getExpression() const
{
return typeid_cast<const ExpressionActionsStep *>(this)->actions;
}
ActionsDAG::ActionsDAG(const NamesAndTypesList & inputs)
{
for (const auto & input : inputs)
addInput(input.name, input.type);
}
ActionsDAG::ActionsDAG(const ColumnsWithTypeAndName & inputs)
{
for (const auto & input : inputs)
addInput(input);
}
ActionsDAG::Node & ActionsDAG::addNode(Node node, bool can_replace)
{
auto it = index.find(node.result_name);
if (it != index.end() && !can_replace)
throw Exception("Column '" + node.result_name + "' already exists", ErrorCodes::DUPLICATE_COLUMN);
auto & res = nodes.emplace_back(std::move(node));
if (it != index.end())
it->second->renaming_parent = &res;
index[res.result_name] = &res;
return res;
}
ActionsDAG::Node & ActionsDAG::getNode(const std::string & name)
{
auto it = index.find(name);
if (it == index.end())
throw Exception("Unknown identifier: '" + name + "'", ErrorCodes::UNKNOWN_IDENTIFIER);
return *it->second;
}
const ActionsDAG::Node & ActionsDAG::addInput(std::string name, DataTypePtr type)
{
Node node;
node.type = Type::INPUT;
node.result_type = std::move(type);
node.result_name = std::move(name);
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::addInput(ColumnWithTypeAndName column)
{
Node node;
node.type = Type::INPUT;
node.result_type = std::move(column.type);
node.result_name = std::move(column.name);
node.column = std::move(column.column);
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::addColumn(ColumnWithTypeAndName column)
{
if (!column.column)
throw Exception("Cannot add column " + column.name + " because it is nullptr", ErrorCodes::LOGICAL_ERROR);
Node node;
node.type = Type::COLUMN;
node.result_type = std::move(column.type);
node.result_name = std::move(column.name);
node.column = std::move(column.column);
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::addAlias(const std::string & name, std::string alias, bool can_replace)
{
auto & child = getNode(name);
Node node;
node.type = Type::ALIAS;
node.result_type = child.result_type;
node.result_name = std::move(alias);
node.column = child.column;
node.allow_constant_folding = child.allow_constant_folding;
node.children.emplace_back(&child);
return addNode(std::move(node), can_replace);
}
const ActionsDAG::Node & ActionsDAG::addArrayJoin(
const std::string & source_name, std::string result_name, std::string unique_column_name)
{
auto & child = getNode(source_name);
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(child.result_type.get());
if (!array_type)
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
Node node;
node.type = Type::ARRAY_JOIN;
node.result_type = array_type->getNestedType();
node.result_name = std::move(result_name);
node.unique_column_name_for_array_join = std::move(unique_column_name);
node.children.emplace_back(&child);
return addNode(std::move(node));
}
const ActionsDAG::Node & ActionsDAG::addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions [[maybe_unused]])
{
size_t num_arguments = argument_names.size();
Node node;
node.type = Type::FUNCTION;
node.function_builder = function;
node.children.reserve(num_arguments);
bool all_const = true;
ColumnsWithTypeAndName arguments(num_arguments);
ColumnNumbers argument_numbers(num_arguments);
for (size_t i = 0; i < num_arguments; ++i)
{
auto & child = getNode(argument_names[i]);
node.children.emplace_back(&child);
node.allow_constant_folding = node.allow_constant_folding && child.allow_constant_folding;
ColumnWithTypeAndName argument;
argument.column = child.column;
argument.type = child.result_type;
if (!argument.column || !isColumnConst(*argument.column))
all_const = false;
arguments[i] = std::move(argument);
argument_numbers[i] = i;
}
node.function_base = function->build(arguments);
node.result_type = node.function_base->getReturnType();
Block sample_block(std::move(arguments));
sample_block.insert({nullptr, node.result_type, node.result_name});
node.function = node.function_base->prepare(sample_block, argument_numbers, num_arguments);
bool do_compile_expressions = false;
#if USE_EMBEDDED_COMPILER
do_compile_expressions = compile_expressions;
#endif
/// If all arguments are constants, and function is suitable to be executed in 'prepare' stage - execute function.
/// But if we compile expressions compiled version of this function maybe placed in cache,
/// so we don't want to unfold non deterministic functions
if (all_const && node.function_base->isSuitableForConstantFolding() && (!do_compile_expressions || node.function_base->isDeterministic()))
{
node.function->execute(sample_block, argument_numbers, num_arguments, sample_block.rows(), true);
/// If the result is not a constant, just in case, we will consider the result as unknown.
ColumnWithTypeAndName & col = sample_block.safeGetByPosition(num_arguments);
if (isColumnConst(*col.column))
{
/// All constant (literal) columns in block are added with size 1.
/// But if there was no columns in block before executing a function, the result has size 0.
/// Change the size to 1.
if (col.column->empty())
col.column = col.column->cloneResized(1);
node.column = std::move(col.column);
}
}
/// Some functions like ignore() or getTypeName() always return constant result even if arguments are not constant.
/// We can't do constant folding, but can specify in sample block that function result is constant to avoid
/// unnecessary materialization.
if (!node.column && node.function_base->isSuitableForConstantFolding())
{
if (auto col = node.function_base->getResultIfAlwaysReturnsConstantAndHasArguments(sample_block, argument_numbers))
{
node.column = std::move(col);
node.allow_constant_folding = false;
}
}
if (result_name.empty())
{
result_name = function->getName() + "(";
for (size_t i = 0; i < argument_names.size(); ++i)
{
if (i)
result_name += ", ";
result_name += argument_names[i];
}
result_name += ")";
}
node.result_name = std::move(result_name);
return addNode(std::move(node));
}
ColumnsWithTypeAndName ActionsDAG::getResultColumns() const
{
ColumnsWithTypeAndName result;
result.reserve(index.size());
for (const auto & node : nodes)
if (!node.renaming_parent)
result.emplace_back(node.column, node.result_type, node.result_name);
return result;
}
NamesAndTypesList ActionsDAG::getNamesAndTypesList() const
{
NamesAndTypesList result;
for (const auto & node : nodes)
if (!node.renaming_parent)
result.emplace_back(node.result_name, node.result_type);
return result;
}
Names ActionsDAG::getNames() const
{
Names names;
names.reserve(index.size());
for (const auto & node : nodes)
if (!node.renaming_parent)
names.emplace_back(node.result_name);
return names;
}
std::string ActionsDAG::dumpNames() const
{
WriteBufferFromOwnString out;
for (auto it = nodes.begin(); it != nodes.end(); ++it)
{
if (it != nodes.begin())
out << ", ";
out << it->result_name;
}
return out.str();
}
ExpressionActionsPtr ActionsDAG::buildExpressions(const Context & context)
{
struct Data
{
Node * node = nullptr;
size_t num_created_children = 0;
size_t num_expected_children = 0;
std::vector<Node *> parents;
Node * renamed_child = nullptr;
};
std::vector<Data> data(nodes.size());
std::unordered_map<Node *, size_t> reverse_index;
for (auto & node : nodes)
{
size_t id = reverse_index.size();
data[id].node = &node;
reverse_index[&node] = id;
}
std::queue<Node *> ready_nodes;
std::queue<Node *> ready_array_joins;
for (auto & node : nodes)
{
data[reverse_index[&node]].num_expected_children += node.children.size();
for (const auto & child : node.children)
data[reverse_index[child]].parents.emplace_back(&node);
if (node.renaming_parent)
{
auto & cur = data[reverse_index[node.renaming_parent]];
cur.renamed_child = &node;
cur.num_expected_children += 1;
}
}
for (auto & node : nodes)
{
if (node.children.empty() && data[reverse_index[&node]].renamed_child == nullptr)
ready_nodes.emplace(&node);
}
auto update_parent = [&](Node * parent)
{
auto & cur = data[reverse_index[parent]];
++cur.num_created_children;
if (cur.num_created_children == cur.num_expected_children)
{
auto & push_stack = parent->type == Type::ARRAY_JOIN ? ready_array_joins : ready_nodes;
push_stack.push(parent);
}
};
auto expressions = std::make_shared<ExpressionActions>(NamesAndTypesList(), context);
while (!ready_nodes.empty() || !ready_array_joins.empty())
{
auto & stack = ready_nodes.empty() ? ready_array_joins : ready_nodes;
Node * node = stack.front();
stack.pop();
Names argument_names;
for (const auto & child : node->children)
argument_names.emplace_back(child->result_name);
auto & cur = data[reverse_index[node]];
switch (node->type)
{
case Type::INPUT:
expressions->addInput({node->column, node->result_type, node->result_name});
break;
case Type::COLUMN:
expressions->add(ExpressionAction::addColumn({node->column, node->result_type, node->result_name}));
break;
case Type::ALIAS:
expressions->add(ExpressionAction::copyColumn(argument_names.at(0), node->result_name, cur.renamed_child != nullptr));
break;
case Type::ARRAY_JOIN:
/// Here we copy argument because arrayJoin removes source column.
/// It makes possible to remove source column before arrayJoin if it won't be needed anymore.
/// It could have been possible to implement arrayJoin which keeps source column,
/// but in this case it will always be replicated (as many arrays), which is expensive.
expressions->add(ExpressionAction::copyColumn(argument_names.at(0), node->unique_column_name_for_array_join));
expressions->add(ExpressionAction::arrayJoin(node->unique_column_name_for_array_join, node->result_name));
break;
case Type::FUNCTION:
{
ExpressionAction action;
action.type = ExpressionAction::APPLY_FUNCTION;
action.result_name = node->result_name;
action.result_type = node->result_type;
action.function_builder = node->function_builder;
action.function_base = node->function_base;
action.function = node->function;
action.argument_names = std::move(argument_names);
action.added_column = node->column;
expressions->add(action);
break;
}
}
for (const auto & parent : cur.parents)
update_parent(parent);
if (node->renaming_parent)
update_parent(node->renaming_parent);
}
return expressions;
}
}
......@@ -140,6 +140,89 @@ private:
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
class ActionsDAG
{
public:
enum class Type
{
/// Column which must be in input.
INPUT,
/// Constant column with known value.
COLUMN,
/// Another one name for column.
ALIAS,
/// Function arrayJoin. Specially separated because it changes the number of rows.
ARRAY_JOIN,
FUNCTION,
};
struct Node
{
std::vector<Node *> children;
/// This field is filled if current node is replaced by existing node with the same name.
Node * renaming_parent = nullptr;
Type type;
std::string result_name;
DataTypePtr result_type;
std::string unique_column_name_for_array_join;
FunctionOverloadResolverPtr function_builder;
/// Can be used after action was added to ExpressionActions if we want to get function signature or properties like monotonicity.
FunctionBasePtr function_base;
/// Prepared function which is used in function execution.
ExecutableFunctionPtr function;
/// For COLUMN node and propagated constants.
ColumnPtr column;
/// Some functions like `ignore()` always return constant but can't be replaced by constant it.
/// We calculate such constants in order to avoid unnecessary materialization, but prohibit it's folding.
bool allow_constant_folding = true;
};
using Index = std::unordered_map<std::string_view, Node *>;
private:
std::list<Node> nodes;
Index index;
public:
ActionsDAG() = default;
ActionsDAG(const ActionsDAG &) = delete;
ActionsDAG & operator=(const ActionsDAG &) = delete;
ActionsDAG(const NamesAndTypesList & inputs);
ActionsDAG(const ColumnsWithTypeAndName & inputs);
const Index & getIndex() const { return index; }
ColumnsWithTypeAndName getResultColumns() const;
NamesAndTypesList getNamesAndTypesList() const;
Names getNames() const;
std::string dumpNames() const;
const Node & addInput(std::string name, DataTypePtr type);
const Node & addInput(ColumnWithTypeAndName column);
const Node & addColumn(ColumnWithTypeAndName column);
const Node & addAlias(const std::string & name, std::string alias, bool can_replace = false);
const Node & addArrayJoin(const std::string & source_name, std::string result_name, std::string unique_column_name);
const Node & addFunction(
const FunctionOverloadResolverPtr & function,
const Names & argument_names,
std::string result_name,
bool compile_expressions);
ExpressionActionsPtr buildExpressions(const Context & context);
private:
Node & addNode(Node node, bool can_replace = false);
Node & getNode(const std::string & name);
};
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
/** Contains a sequence of actions on the block.
*/
class ExpressionActions
......@@ -287,17 +370,19 @@ struct ExpressionActionsChain
virtual std::string dump() const = 0;
/// Only for ExpressionActionsStep
ExpressionActionsPtr & actions();
const ExpressionActionsPtr & actions() const;
ActionsDAGPtr & actions();
const ActionsDAGPtr & actions() const;
ExpressionActionsPtr getExpression() const;
};
struct ExpressionActionsStep : public Step
{
ActionsDAGPtr actions_dag;
ExpressionActionsPtr actions;
explicit ExpressionActionsStep(ExpressionActionsPtr actions_, Names required_output_ = Names())
explicit ExpressionActionsStep(ActionsDAGPtr actions_, Names required_output_ = Names())
: Step(std::move(required_output_))
, actions(std::move(actions_))
, actions_dag(std::move(actions_))
{
}
......@@ -382,7 +467,9 @@ struct ExpressionActionsChain
throw Exception("Empty ExpressionActionsChain", ErrorCodes::LOGICAL_ERROR);
}
return steps.back()->actions();
auto * step = typeid_cast<ExpressionActionsStep *>(steps.back().get());
step->actions = step->actions_dag->buildExpressions(context);
return step->actions;
}
Step & getLastStep()
......@@ -396,7 +483,7 @@ struct ExpressionActionsChain
Step & lastStep(const NamesAndTypesList & columns)
{
if (steps.empty())
steps.emplace_back(std::make_unique<ExpressionActionsStep>(std::make_shared<ExpressionActions>(columns, context)));
steps.emplace_back(std::make_unique<ExpressionActionsStep>(std::make_shared<ActionsDAG>(columns)));
return *steps.back();
}
......
......@@ -153,38 +153,51 @@ void ExpressionAnalyzer::analyzeAggregation()
auto * select_query = query->as<ASTSelectQuery>();
ExpressionActionsPtr temp_actions = std::make_shared<ExpressionActions>(sourceColumns(), context);
auto temp_actions = std::make_shared<ActionsDAG>(sourceColumns());
if (select_query)
{
NamesAndTypesList array_join_columns;
columns_after_array_join = sourceColumns();
bool is_array_join_left;
if (ASTPtr array_join_expression_list = select_query->arrayJoinExpressionList(is_array_join_left))
{
getRootActionsNoMakeSet(array_join_expression_list, true, temp_actions, false);
if (auto array_join = addMultipleArrayJoinAction(temp_actions, is_array_join_left))
auto array_join = addMultipleArrayJoinAction(temp_actions, is_array_join_left);
auto sample_columns = temp_actions->getResultColumns();
array_join->prepare(sample_columns);
temp_actions = std::make_shared<ActionsDAG>(sample_columns);
NamesAndTypesList new_columns_after_array_join;
NameSet added_columns;
for (auto & column : temp_actions->getResultColumns())
{
auto sample_block = temp_actions->getSampleBlock();
array_join->prepare(sample_block);
temp_actions = std::make_shared<ExpressionActions>(sample_block.getColumnsWithTypeAndName(), context);
if (syntax->array_join_result_to_source.count(column.name))
{
new_columns_after_array_join.emplace_back(column.name, column.type);
added_columns.emplace(column.name);
}
}
for (auto & column : temp_actions->getSampleBlock().getNamesAndTypesList())
if (syntax->array_join_result_to_source.count(column.name))
array_join_columns.emplace_back(column);
for (auto & column : columns_after_array_join)
if (added_columns.count(column.name) == 0)
new_columns_after_array_join.emplace_back(column.name, column.type);
columns_after_array_join.swap(new_columns_after_array_join);
}
columns_after_array_join = sourceColumns();
columns_after_array_join.insert(columns_after_array_join.end(), array_join_columns.begin(), array_join_columns.end());
const ASTTablesInSelectQueryElement * join = select_query->join();
if (join)
{
getRootActionsNoMakeSet(analyzedJoin().leftKeysList(), true, temp_actions, false);
auto sample_columns = temp_actions->getSampleBlock().getColumnsWithTypeAndName();
auto sample_columns = temp_actions->getResultColumns();
analyzedJoin().addJoinedColumnsAndCorrectNullability(sample_columns);
temp_actions = std::make_shared<ExpressionActions>(sample_columns, context);
temp_actions = std::make_shared<ActionsDAG>(sample_columns);
}
columns_after_join = columns_after_array_join;
......@@ -212,15 +225,16 @@ void ExpressionAnalyzer::analyzeAggregation()
getRootActionsNoMakeSet(group_asts[i], true, temp_actions, false);
const auto & column_name = group_asts[i]->getColumnName();
const auto & block = temp_actions->getSampleBlock();
const auto & index = temp_actions->getIndex();
if (!block.has(column_name))
auto it = index.find(column_name);
if (it == index.end())
throw Exception("Unknown identifier (in GROUP BY): " + column_name, ErrorCodes::UNKNOWN_IDENTIFIER);
const auto & col = block.getByName(column_name);
const auto & node = it->second;
/// Constant expressions have non-null column pointer at this stage.
if (col.column && isColumnConst(*col.column))
if (node->column && isColumnConst(*node->column))
{
/// But don't remove last key column if no aggregate functions, otherwise aggregation will not work.
if (!aggregate_descriptions.empty() || size > 1)
......@@ -235,7 +249,7 @@ void ExpressionAnalyzer::analyzeAggregation()
}
}
NameAndTypePair key{column_name, col.type};
NameAndTypePair key{column_name, node->result_type};
/// Aggregation keys are uniqued.
if (!unique_keys.count(key.name))
......@@ -256,14 +270,14 @@ void ExpressionAnalyzer::analyzeAggregation()
}
}
else
aggregated_columns = temp_actions->getSampleBlock().getNamesAndTypesList();
aggregated_columns = temp_actions->getNamesAndTypesList();
for (const auto & desc : aggregate_descriptions)
aggregated_columns.emplace_back(desc.column_name, desc.function->getReturnType());
}
else
{
aggregated_columns = temp_actions->getSampleBlock().getNamesAndTypesList();
aggregated_columns = temp_actions->getNamesAndTypesList();
}
}
......@@ -362,12 +376,11 @@ void SelectQueryExpressionAnalyzer::makeSetsForIndex(const ASTPtr & node)
}
else
{
ExpressionActionsPtr temp_actions = std::make_shared<ExpressionActions>(columns_after_join, context);
auto temp_actions = std::make_shared<ActionsDAG>(columns_after_join);
getRootActions(left_in_operand, true, temp_actions);
Block sample_block_with_calculated_columns = temp_actions->getSampleBlock();
if (sample_block_with_calculated_columns.has(left_in_operand->getColumnName()))
makeExplicitSet(func, sample_block_with_calculated_columns, true, context,
if (temp_actions->getIndex().count(left_in_operand->getColumnName()) != 0)
makeExplicitSet(func, *temp_actions, true, context,
settings.size_limits_for_set, prepared_sets);
}
}
......@@ -375,29 +388,29 @@ void SelectQueryExpressionAnalyzer::makeSetsForIndex(const ASTPtr & node)
}
void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts)
void ExpressionAnalyzer::getRootActions(const ASTPtr & ast, bool no_subqueries, ActionsDAGPtr & actions, bool only_consts)
{
LogAST log;
ActionsVisitor::Data visitor_data(context, settings.size_limits_for_set, subquery_depth,
sourceColumns(), actions, prepared_sets, subqueries_for_sets,
sourceColumns(), std::move(actions), prepared_sets, subqueries_for_sets,
no_subqueries, false, only_consts, !isRemoteStorage());
ActionsVisitor(visitor_data, log.stream()).visit(ast);
visitor_data.updateActions(actions);
actions = visitor_data.getActions();
}
void ExpressionAnalyzer::getRootActionsNoMakeSet(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts)
void ExpressionAnalyzer::getRootActionsNoMakeSet(const ASTPtr & ast, bool no_subqueries, ActionsDAGPtr & actions, bool only_consts)
{
LogAST log;
ActionsVisitor::Data visitor_data(context, settings.size_limits_for_set, subquery_depth,
sourceColumns(), actions, prepared_sets, subqueries_for_sets,
sourceColumns(), std::move(actions), prepared_sets, subqueries_for_sets,
no_subqueries, true, only_consts, !isRemoteStorage());
ActionsVisitor(visitor_data, log.stream()).visit(ast);
visitor_data.updateActions(actions);
actions = visitor_data.getActions();
}
bool ExpressionAnalyzer::makeAggregateDescriptions(ExpressionActionsPtr & actions)
bool ExpressionAnalyzer::makeAggregateDescriptions(ActionsDAGPtr & actions)
{
for (const ASTFunction * node : aggregates())
{
......@@ -412,7 +425,7 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ExpressionActionsPtr & action
{
getRootActionsNoMakeSet(arguments[i], true, actions);
const std::string & name = arguments[i]->getColumnName();
types[i] = actions->getSampleBlock().getByName(name).type;
types[i] = actions->getIndex().find(name)->second->result_type;
aggregate.argument_names[i] = name;
}
......@@ -443,14 +456,14 @@ const ASTSelectQuery * SelectQueryExpressionAnalyzer::getAggregatingQuery() cons
}
/// "Big" ARRAY JOIN.
ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool array_join_is_left) const
ArrayJoinActionPtr ExpressionAnalyzer::addMultipleArrayJoinAction(ActionsDAGPtr & actions, bool array_join_is_left) const
{
NameSet result_columns;
for (const auto & result_source : syntax->array_join_result_to_source)
{
/// Assign new names to columns, if needed.
if (result_source.first != result_source.second)
actions->add(ExpressionAction::copyColumn(result_source.second, result_source.first));
actions->addAlias(result_source.second, result_source.first);
/// Make ARRAY JOIN (replace arrays with their insides) for the columns in these new names.
result_columns.insert(result_source.first);
......@@ -472,8 +485,8 @@ ArrayJoinActionPtr SelectQueryExpressionAnalyzer::appendArrayJoin(ExpressionActi
getRootActions(array_join_expression_list, only_types, step.actions());
before_array_join = chain.getLastActions();
auto array_join = addMultipleArrayJoinAction(step.actions(), is_array_join_left);
before_array_join = chain.getLastActions();
chain.steps.push_back(std::make_unique<ExpressionActionsChain::ArrayJoinStep>(
array_join, step.getResultColumns()));
......@@ -615,13 +628,14 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(const ASTTablesInSelectQuer
return subquery_for_join.join;
}
bool SelectQueryExpressionAnalyzer::appendPrewhere(
ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendPrewhere(
ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns)
{
const auto * select_query = getSelectQuery();
ExpressionActionsPtr prewhere_actions;
if (!select_query->prewhere())
return false;
return prewhere_actions;
auto & step = chain.lastStep(sourceColumns());
getRootActions(select_query->prewhere(), only_types, step.actions());
......@@ -629,15 +643,16 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
step.required_output.push_back(prewhere_column_name);
step.can_remove_required_output.push_back(true);
auto filter_type = step.actions()->getSampleBlock().getByName(prewhere_column_name).type;
auto filter_type = step.actions()->getIndex().find(prewhere_column_name)->second->result_type;
if (!filter_type->canBeUsedInBooleanContext())
throw Exception("Invalid type for filter in PREWHERE: " + filter_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER);
{
/// Remove unused source_columns from prewhere actions.
auto tmp_actions = std::make_shared<ExpressionActions>(sourceColumns(), context);
getRootActions(select_query->prewhere(), only_types, tmp_actions);
auto tmp_actions_dag = std::make_shared<ActionsDAG>(sourceColumns());
getRootActions(select_query->prewhere(), only_types, tmp_actions_dag);
auto tmp_actions = tmp_actions_dag->buildExpressions(context);
tmp_actions->finalize({prewhere_column_name});
auto required_columns = tmp_actions->getRequiredColumns();
NameSet required_source_columns(required_columns.begin(), required_columns.end());
......@@ -653,7 +668,7 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
}
}
auto names = step.actions()->getSampleBlock().getNames();
auto names = step.actions()->getNames();
NameSet name_set(names.begin(), names.end());
for (const auto & column : sourceColumns())
......@@ -661,7 +676,8 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
name_set.erase(column.name);
Names required_output(name_set.begin(), name_set.end());
step.actions()->finalize(required_output);
prewhere_actions = chain.getLastActions();
prewhere_actions->finalize(required_output);
}
{
......@@ -672,8 +688,8 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
/// 2. Store side columns which were calculated during prewhere actions execution if they are used.
/// Example: select F(A) prewhere F(A) > 0. F(A) can be saved from prewhere step.
/// 3. Check if we can remove filter column at prewhere step. If we can, action will store single REMOVE_COLUMN.
ColumnsWithTypeAndName columns = step.actions()->getSampleBlock().getColumnsWithTypeAndName();
auto required_columns = step.actions()->getRequiredColumns();
ColumnsWithTypeAndName columns = prewhere_actions->getSampleBlock().getColumnsWithTypeAndName();
auto required_columns = prewhere_actions->getRequiredColumns();
NameSet prewhere_input_names(required_columns.begin(), required_columns.end());
NameSet unused_source_columns;
......@@ -687,11 +703,13 @@ bool SelectQueryExpressionAnalyzer::appendPrewhere(
}
chain.steps.emplace_back(std::make_unique<ExpressionActionsChain::ExpressionActionsStep>(
std::make_shared<ExpressionActions>(std::move(columns), context)));
std::make_shared<ActionsDAG>(std::move(columns))));
chain.steps.back()->additional_input = std::move(unused_source_columns);
chain.getLastActions();
chain.addStep();
}
return true;
return prewhere_actions;
}
void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name)
......@@ -699,7 +717,8 @@ void SelectQueryExpressionAnalyzer::appendPreliminaryFilter(ExpressionActionsCha
ExpressionActionsChain::Step & step = chain.lastStep(sourceColumns());
// FIXME: assert(filter_info);
step.actions() = std::move(actions);
auto * expression_step = typeid_cast<ExpressionActionsChain::ExpressionActionsStep *>(&step);
expression_step->actions = std::move(actions);
step.required_output.push_back(std::move(column_name));
step.can_remove_required_output = {true};
......@@ -721,7 +740,7 @@ bool SelectQueryExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain,
getRootActions(select_query->where(), only_types, step.actions());
auto filter_type = step.actions()->getSampleBlock().getByName(where_column_name).type;
auto filter_type = step.actions()->getIndex().find(where_column_name)->second->result_type;
if (!filter_type->canBeUsedInBooleanContext())
throw Exception("Invalid type for filter in WHERE: " + filter_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_COLUMN_FOR_FILTER);
......@@ -750,8 +769,9 @@ bool SelectQueryExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain
{
for (auto & child : asts)
{
group_by_elements_actions.emplace_back(std::make_shared<ExpressionActions>(columns_after_join, context));
getRootActions(child, only_types, group_by_elements_actions.back());
auto actions_dag = std::make_shared<ActionsDAG>(columns_after_join);
getRootActions(child, only_types, actions_dag);
group_by_elements_actions.emplace_back(actions_dag->buildExpressions(context));
}
}
......@@ -838,8 +858,9 @@ bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain
{
for (auto & child : select_query->orderBy()->children)
{
order_by_elements_actions.emplace_back(std::make_shared<ExpressionActions>(columns_after_join, context));
getRootActions(child, only_types, order_by_elements_actions.back());
auto actions_dag = std::make_shared<ActionsDAG>(columns_after_join);
getRootActions(child, only_types, actions_dag);
order_by_elements_actions.emplace_back(actions_dag->buildExpressions(context));
}
}
return true;
......@@ -873,7 +894,7 @@ bool SelectQueryExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain
return true;
}
void SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & chain) const
ExpressionActionsPtr SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & chain) const
{
const auto * select_query = getSelectQuery();
......@@ -919,7 +940,9 @@ void SelectQueryExpressionAnalyzer::appendProjectResult(ExpressionActionsChain &
}
}
step.actions()->add(ExpressionAction::project(result_columns));
auto actions = chain.getLastActions();
actions->add(ExpressionAction::project(result_columns));
return actions;
}
......@@ -933,7 +956,7 @@ void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const
ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool project_result)
{
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(aggregated_columns, context);
auto actions_dag = std::make_shared<ActionsDAG>(aggregated_columns);
NamesWithAliases result_columns;
Names result_names;
......@@ -954,9 +977,11 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje
alias = name;
result_columns.emplace_back(name, alias);
result_names.push_back(alias);
getRootActions(ast, false, actions);
getRootActions(ast, false, actions_dag);
}
auto actions = actions_dag->buildExpressions(context);
if (add_aliases)
{
if (project_result)
......@@ -980,10 +1005,10 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje
ExpressionActionsPtr ExpressionAnalyzer::getConstActions()
{
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(NamesAndTypesList(), context);
auto actions = std::make_shared<ActionsDAG>(NamesAndTypesList());
getRootActions(query, true, actions, true);
return actions;
return actions->buildExpressions(context);
}
ExpressionActionsPtr SelectQueryExpressionAnalyzer::simpleSelectActions()
......@@ -1064,10 +1089,9 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
query_analyzer.appendPreliminaryFilter(chain, filter_info->actions, filter_info->column_name);
}
if (query_analyzer.appendPrewhere(chain, !first_stage, additional_required_columns_after_prewhere))
if (auto actions = query_analyzer.appendPrewhere(chain, !first_stage, additional_required_columns_after_prewhere))
{
prewhere_info = std::make_shared<PrewhereInfo>(
chain.steps.front()->actions(), query.prewhere()->getColumnName());
prewhere_info = std::make_shared<PrewhereInfo>(actions, query.prewhere()->getColumnName());
if (allowEarlyConstantFolding(*prewhere_info->prewhere_actions, settings))
{
......@@ -1081,7 +1105,6 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
}
chain.addStep();
}
array_join = query_analyzer.appendArrayJoin(chain, before_array_join, only_types || !first_stage);
......@@ -1167,8 +1190,7 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
chain.addStep();
}
query_analyzer.appendProjectResult(chain);
final_projection = chain.getLastActions();
final_projection = query_analyzer.appendProjectResult(chain);
finalize_chain(chain);
}
......
......@@ -37,6 +37,9 @@ using StorageMetadataPtr = std::shared_ptr<const StorageInMemoryMetadata>;
class ArrayJoinAction;
using ArrayJoinActionPtr = std::shared_ptr<ArrayJoinAction>;
class ActionsDAG;
using ActionsDAGPtr = std::shared_ptr<ActionsDAG>;
/// Create columns in block or return false if not possible
bool sanitizeBlock(Block & block, bool throw_if_cannot_create_column = false);
......@@ -137,15 +140,15 @@ protected:
/// Find global subqueries in the GLOBAL IN/JOIN sections. Fills in external_tables.
void initGlobalSubqueriesAndExternalTables(bool do_global);
ArrayJoinActionPtr addMultipleArrayJoinAction(ExpressionActionsPtr & actions, bool is_left) const;
ArrayJoinActionPtr addMultipleArrayJoinAction(ActionsDAGPtr & actions, bool is_left) const;
void getRootActions(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts = false);
void getRootActions(const ASTPtr & ast, bool no_subqueries, ActionsDAGPtr & actions, bool only_consts = false);
/** Similar to getRootActions but do not make sets when analyzing IN functions. It's used in
* analyzeAggregation which happens earlier than analyzing PREWHERE and WHERE. If we did, the
* prepared sets would not be applicable for MergeTree index optimization.
*/
void getRootActionsNoMakeSet(const ASTPtr & ast, bool no_subqueries, ExpressionActionsPtr & actions, bool only_consts = false);
void getRootActionsNoMakeSet(const ASTPtr & ast, bool no_subqueries, ActionsDAGPtr & actions, bool only_consts = false);
/** Add aggregation keys to aggregation_keys, aggregate functions to aggregate_descriptions,
* Create a set of columns aggregated_columns resulting after the aggregation, if any,
......@@ -153,7 +156,7 @@ protected:
* Set has_aggregation = true if there is GROUP BY or at least one aggregate function.
*/
void analyzeAggregation();
bool makeAggregateDescriptions(ExpressionActionsPtr & actions);
bool makeAggregateDescriptions(ActionsDAGPtr & actions);
const ASTSelectQuery * getSelectQuery() const;
......@@ -267,7 +270,7 @@ public:
/// These appends are public only for tests
void appendSelect(ExpressionActionsChain & chain, bool only_types);
/// Deletes all columns except mentioned by SELECT, arranges the remaining columns and renames them to aliases.
void appendProjectResult(ExpressionActionsChain & chain) const;
ExpressionActionsPtr appendProjectResult(ExpressionActionsChain & chain) const;
private:
StorageMetadataPtr metadata_snapshot;
......@@ -317,7 +320,7 @@ private:
void appendPreliminaryFilter(ExpressionActionsChain & chain, ExpressionActionsPtr actions, String column_name);
/// remove_filter is set in ExpressionActionsChain::finalize();
/// Columns in `additional_required_columns` will not be removed (they can be used for e.g. sampling or FINAL modifier).
bool appendPrewhere(ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns);
ExpressionActionsPtr appendPrewhere(ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns);
bool appendWhere(ExpressionActionsChain & chain, bool only_types);
bool appendGroupBy(ExpressionActionsChain & chain, bool only_types, bool optimize_aggregation_in_order, ManyExpressionActions &);
void appendAggregateFunctionsArguments(ExpressionActionsChain & chain, bool only_types);
......
......@@ -619,19 +619,20 @@ ASTPtr MutationsInterpreter::prepareInterpreterSelectQuery(std::vector<Stage> &
for (const auto & kv : stage.column_to_updated)
{
actions_chain.getLastActions()->add(ExpressionAction::copyColumn(
kv.second->getColumnName(), kv.first, /* can_replace = */ true));
actions_chain.getLastStep().actions()->addAlias(
kv.second->getColumnName(), kv.first, /* can_replace = */ true);
}
}
/// Remove all intermediate columns.
actions_chain.addStep();
actions_chain.getLastStep().required_output.assign(stage.output_columns.begin(), stage.output_columns.end());
actions_chain.getLastActions();
actions_chain.finalize();
/// Propagate information about columns needed as input.
for (const auto & column : actions_chain.steps.front()->actions()->getRequiredColumnsWithTypes())
for (const auto & column : actions_chain.steps.front()->getRequiredColumns())
prepared_stages[i - 1].output_columns.insert(column.name);
}
......@@ -675,12 +676,12 @@ QueryPipelinePtr MutationsInterpreter::addStreamsForLaterStages(const std::vecto
if (i < stage.filter_column_names.size())
{
/// Execute DELETEs.
plan.addStep(std::make_unique<FilterStep>(plan.getCurrentDataStream(), step->actions(), stage.filter_column_names[i], false));
plan.addStep(std::make_unique<FilterStep>(plan.getCurrentDataStream(), step->getExpression(), stage.filter_column_names[i], false));
}
else
{
/// Execute UPDATE or final projection.
plan.addStep(std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), step->actions()));
plan.addStep(std::make_unique<ExpressionStep>(plan.getCurrentDataStream(), step->getExpression()));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册